diff --git a/internal/api/api.go b/internal/api/api.go index 26125d9..f7a7ec3 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -94,8 +94,9 @@ func (a *api) Routes() *mux.Router { r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST") - r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE") + r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.editWatcherHandler).Methods("PATCH") + r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET") r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST") r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST") diff --git a/internal/api/watcher.go b/internal/api/watcher.go index 46f8c08..1065f42 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -12,6 +12,7 @@ import ( ) type watcherCriteria struct { + Author string Upvotes int64 Keyword string Flair string @@ -22,6 +23,7 @@ type createWatcherRequest struct { Type string User string Subreddit string + Label string Criteria watcherCriteria } @@ -74,15 +76,17 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) watcher := domain.Watcher{ + Label: cwr.Label, DeviceID: dev.ID, AccountID: account.ID, + Author: strings.ToLower(cwr.Criteria.Author), Upvotes: cwr.Criteria.Upvotes, Keyword: strings.ToLower(cwr.Criteria.Keyword), Flair: strings.ToLower(cwr.Criteria.Flair), Domain: strings.ToLower(cwr.Criteria.Domain), } - if cwr.Type == "subreddit" { + if cwr.Type == "subreddit" || cwr.Type == "trending" { srr, err := ac.SubredditAbout(cwr.Subreddit) if err != nil { a.errorResponse(w, r, 422, err.Error()) @@ -102,7 +106,13 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { } } - watcher.Type = domain.SubredditWatcher + switch cwr.Type { + case "subreddit": + watcher.Type = domain.SubredditWatcher + case "trending": + watcher.Type = domain.TrendingWatcher + } + watcher.WatcheeID = sr.ID } else if cwr.Type == "user" { urr, err := ac.UserAbout(cwr.User) @@ -139,6 +149,50 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } +func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() + + vars := mux.Vars(r) + id, err := strconv.ParseInt(vars["watcherID"], 10, 64) + if err != nil { + a.errorResponse(w, r, 422, err.Error()) + return + } + + watcher, err := a.watcherRepo.GetByID(ctx, id) + if err != nil || watcher.Device.APNSToken != vars["apns"] { + a.errorResponse(w, r, 422, "nice try") + return + } + + ewr := &createWatcherRequest{ + Criteria: watcherCriteria{ + Upvotes: 0, + Keyword: "", + Flair: "", + Domain: "", + }, + } + + if err := json.NewDecoder(r.Body).Decode(ewr); err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + watcher.Label = ewr.Label + watcher.Upvotes = ewr.Criteria.Upvotes + watcher.Keyword = ewr.Criteria.Keyword + watcher.Flair = ewr.Criteria.Flair + watcher.Domain = ewr.Criteria.Domain + + if err := a.watcherRepo.Update(ctx, &watcher); err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + w.WriteHeader(http.StatusOK) +} + func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) { ctx := context.Background() @@ -149,16 +203,8 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) { return } - apns := vars["apns"] - - dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) - if err != nil { - a.errorResponse(w, r, 422, err.Error()) - return - } - watcher, err := a.watcherRepo.GetByID(ctx, id) - if err != nil || watcher.DeviceID != dev.ID { + if err != nil || watcher.Device.APNSToken != vars["apns"] { a.errorResponse(w, r, 422, "nice try") return } @@ -168,12 +214,15 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) { } type watcherItem struct { - ID int64 - Upvotes int64 - Keyword string - Flair string - Domain string - Hits int64 + ID int64 + CreatedAt float64 + Type string + Label string + Upvotes int64 + Keyword string + Flair string + Domain string + Hits int64 } func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { @@ -192,12 +241,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { wis := make([]watcherItem, len(watchers)) for i, watcher := range watchers { wi := watcherItem{ - ID: watcher.ID, - Upvotes: watcher.Upvotes, - Keyword: watcher.Keyword, - Flair: watcher.Flair, - Domain: watcher.Domain, - Hits: watcher.Hits, + ID: watcher.ID, + CreatedAt: watcher.CreatedAt, + Type: watcher.Type.String(), + Label: watcher.Label, + Upvotes: watcher.Upvotes, + Keyword: watcher.Keyword, + Flair: watcher.Flair, + Domain: watcher.Domain, + Hits: watcher.Hits, } wis[i] = wi diff --git a/internal/cmd/scheduler.go b/internal/cmd/scheduler.go index 7836aa4..e68f577 100644 --- a/internal/cmd/scheduler.go +++ b/internal/cmd/scheduler.go @@ -77,6 +77,11 @@ func SchedulerCmd(ctx context.Context) *cobra.Command { return err } + trendingQueue, err := queue.OpenQueue("trending") + if err != nil { + return err + } + userQueue, err := queue.OpenQueue("users") if err != nil { return err @@ -84,7 +89,7 @@ func SchedulerCmd(ctx context.Context) *cobra.Command { s := gocron.NewScheduler(time.UTC) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) }) - _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) }) + _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) }) _, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) }) _, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) }) @@ -273,7 +278,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli } -func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { +func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) { now := time.Now() ready := now.Unix() - subredditEnqueueTimeout @@ -326,15 +331,17 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats batchIds[i] = strconv.FormatInt(id, 10) } - if err = queue.Publish(batchIds...); err != nil { - logger.WithFields(logrus.Fields{ - "err": err, - }).Error("failed to enqueue subreddit") + for _, queue := range queues { + if err = queue.Publish(batchIds...); err != nil { + logger.WithFields(logrus.Fields{ + "queue": queue, + "err": err, + }).Error("failed to enqueue subreddit") + } } _ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1) _ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1) - } func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) { diff --git a/internal/cmd/worker.go b/internal/cmd/worker.go index dd6f36c..beb1b55 100644 --- a/internal/cmd/worker.go +++ b/internal/cmd/worker.go @@ -15,6 +15,7 @@ var ( queues = map[string]worker.NewWorkerFn{ "notifications": worker.NewNotificationsWorker, "subreddits": worker.NewSubredditsWorker, + "trending": worker.NewTrendingWorker, "users": worker.NewUsersWorker, } ) diff --git a/internal/domain/watcher.go b/internal/domain/watcher.go index 0655e4f..a141473 100644 --- a/internal/domain/watcher.go +++ b/internal/domain/watcher.go @@ -7,29 +7,50 @@ type WatcherType int64 const ( SubredditWatcher WatcherType = iota UserWatcher + TrendingWatcher ) +func (wt WatcherType) String() string { + switch wt { + case SubredditWatcher: + return "subreddit" + case UserWatcher: + return "user" + case TrendingWatcher: + return "trending" + } + + return "unknown" +} + type Watcher struct { ID int64 CreatedAt float64 LastNotifiedAt float64 + Label string DeviceID int64 AccountID int64 Type WatcherType WatcheeID int64 + Author string Upvotes int64 Keyword string Flair string Domain string Hits int64 + + // Related models + Device Device + Account Account } type WatcherRepository interface { GetByID(ctx context.Context, id int64) (Watcher, error) GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error) GetByUserID(ctx context.Context, id int64) ([]Watcher, error) + GetByTrendingSubredditID(ctx context.Context, id int64) ([]Watcher, error) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error) Create(ctx context.Context, watcher *Watcher) error diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 73de11b..e2a061d 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -244,6 +244,10 @@ func (rac *AuthenticatedClient) SubredditHot(subreddit string, opts ...RequestOp return rac.subredditPosts(subreddit, "hot", opts...) } +func (rac *AuthenticatedClient) SubredditTop(subreddit string, opts ...RequestOption) (*ListingResponse, error) { + return rac.subredditPosts(subreddit, "top", opts...) +} + func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) { return rac.subredditPosts(subreddit, "new", opts...) } diff --git a/internal/repository/postgres_subreddit.go b/internal/repository/postgres_subreddit.go index 7ed2cf2..3a7e88e 100644 --- a/internal/repository/postgres_subreddit.go +++ b/internal/repository/postgres_subreddit.go @@ -79,8 +79,7 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do query := ` INSERT INTO subreddits (subreddit_id, name) VALUES ($1, $2) - ON CONFLICT(subreddit_id) DO - UPDATE SET last_checked_at = $3 + ON CONFLICT(subreddit_id) DO NOTHING RETURNING id` return p.pool.QueryRow( @@ -88,6 +87,5 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do query, sr.SubredditID, sr.NormalizedName(), - sr.LastCheckedAt, ).Scan(&sr.ID) } diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index 0dc47c4..efde594 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -32,15 +32,23 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg &watcher.ID, &watcher.CreatedAt, &watcher.LastNotifiedAt, + &watcher.Label, &watcher.DeviceID, &watcher.AccountID, &watcher.Type, &watcher.WatcheeID, + &watcher.Author, &watcher.Upvotes, &watcher.Keyword, &watcher.Flair, &watcher.Domain, &watcher.Hits, + &watcher.Device.ID, + &watcher.Device.APNSToken, + &watcher.Device.Sandbox, + &watcher.Account.ID, + &watcher.Account.AccessToken, + &watcher.Account.RefreshToken, ); err != nil { return nil, err } @@ -51,9 +59,31 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) { query := ` - SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits + SELECT + watchers.id, + watchers.created_at, + watchers.last_notified_at, + watchers.label, + watchers.device_id, + watchers.account_id, + watchers.type, + watchers.watchee_id, + watchers.author, + watchers.upvotes, + watchers.keyword, + watchers.flair, + watchers.domain, + watchers.hits, + devices.id, + devices.apns_token, + devices.sandbox, + accounts.id, + accounts.access_token, + accounts.refresh_token FROM watchers - WHERE id = $1` + INNER JOIN devices ON watchers.device_id = devices.id + INNER JOIN accounts ON watchers.account_id = accounts.id + WHERE watchers.id = $1` watchers, err := p.fetch(ctx, query, id) @@ -68,13 +98,39 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) { query := ` - SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits + SELECT + watchers.id, + watchers.created_at, + watchers.last_notified_at, + watchers.label, + watchers.device_id, + watchers.account_id, + watchers.type, + watchers.watchee_id, + watchers.author, + watchers.upvotes, + watchers.keyword, + watchers.flair, + watchers.domain, + watchers.hits, + devices.id, + devices.apns_token, + devices.sandbox, + accounts.id, + accounts.access_token, + accounts.refresh_token FROM watchers - WHERE type = $1 AND watchee_id = $2` + INNER JOIN devices ON watchers.device_id = devices.id + INNER JOIN accounts ON watchers.account_id = accounts.id + WHERE watchers.type = $1 AND watchers.watchee_id = $2` return p.fetch(ctx, query, typ, id) } +func (p *postgresWatcherRepository) GetByTrendingSubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) { + return p.GetByTypeAndWatcheeID(ctx, domain.TrendingWatcher, id) +} + func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) { return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id) } @@ -88,16 +144,24 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c SELECT watchers.id, watchers.created_at, - watchers.last_notified_at + watchers.last_notified_at, + watchers.label, watchers.device_id, watchers.account_id, watchers.type, watchers.watchee_id, + watchers.author, watchers.upvotes, watchers.keyword, watchers.flair, watchers.domain, - watchers.hits + watchers.hits, + devices.id, + devices.apns_token, + devices.sandbox, + accounts.id, + accounts.access_token, + accounts.refresh_token FROM watchers INNER JOIN accounts ON watchers.account_id = accounts.id INNER JOIN devices ON watchers.device_id = devices.id @@ -113,18 +177,20 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain. query := ` INSERT INTO watchers - (created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain) - VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9) + (created_at, last_notified_at, label, device_id, account_id, type, watchee_id, author, upvotes, keyword, flair, domain) + VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id` return p.pool.QueryRow( ctx, query, now, + watcher.Label, watcher.DeviceID, watcher.AccountID, watcher.Type, watcher.WatcheeID, + watcher.Author, watcher.Upvotes, watcher.Keyword, watcher.Flair, @@ -135,20 +201,24 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain. func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error { query := ` UPDATE watchers - SET upvotes = $2, - keyword = $3, - flair = $4, - domain = $5, + SET author = $2, + upvotes = $3, + keyword = $4, + flair = $5, + domain = $6, + label = $7 WHERE id = $1` res, err := p.pool.Exec( ctx, query, watcher.ID, + watcher.Author, watcher.Upvotes, watcher.Keyword, watcher.Flair, watcher.Domain, + watcher.Label, ) if res.RowsAffected() != 1 { diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index cb4c3af..ae78013 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -330,6 +330,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { "status": res.StatusCode, "reason": res.Reason, }).Error("failed to send notification") + + // Delete device as notifications might have been disabled here + _ = nc.deviceRepo.Delete(ctx, device.APNSToken) } else { _ = nc.statsd.Incr("apns.notification.sent", []string{}, 1) nc.logger.WithFields(logrus.Fields{ diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index 4ab9cbd..2f14ab8 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -40,6 +40,8 @@ type subredditsWorker struct { watcherRepo domain.WatcherRepository } +const subredditNotificationTitleFormat = "📣 %s" + func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { reddit := reddit.NewClient( os.Getenv("REDDIT_CLIENT_ID"), @@ -170,7 +172,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { if len(watchers) == 0 { sc.logger.WithFields(logrus.Fields{ "subreddit#id": subreddit.ID, - }).Info("no watchers for subreddit, skipping") + }).Debug("no watchers for subreddit, finishing job") return } @@ -298,11 +300,12 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { "count": len(posts), }).Debug("checking posts for hits") for _, post := range posts { + lowcaseAuthor := strings.ToLower(post.Author) lowcaseTitle := strings.ToLower(post.Title) lowcaseFlair := strings.ToLower(post.Flair) lowcaseDomain := strings.ToLower(post.URL) - ids := []int64{} + notifs := []domain.Watcher{} for _, watcher := range watchers { // Make sure we only alert on posts created after the search @@ -312,6 +315,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { matched := true + if watcher.Author != "" && lowcaseAuthor != watcher.Author { + matched = false + } + if watcher.Upvotes > 0 && post.Score < watcher.Upvotes { matched = false } @@ -363,10 +370,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { }).Debug("got a hit") sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour) - ids = append(ids, watcher.DeviceID) + notifs = append(notifs, watcher) } - if len(ids) == 0 { + if len(notifs) == 0 { continue } @@ -374,19 +381,22 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { "subreddit#id": subreddit.ID, "subreddit#name": subreddit.Name, "post#id": post.ID, - "count": len(ids), + "count": len(notifs), }).Debug("got hits for post") - notification := &apns2.Notification{} - notification.Topic = "com.christianselig.Apollo" - notification.Payload = payloadFromPost(post) + payload := payloadFromPost(post) - for _, id := range ids { - device, _ := sc.deviceRepo.GetByID(ctx, id) - notification.DeviceToken = device.APNSToken + for _, watcher := range notifs { + title := fmt.Sprintf(subredditNotificationTitleFormat, watcher.Label) + payload.AlertTitle(title) + + notification := &apns2.Notification{} + notification.Topic = "com.christianselig.Apollo" + notification.DeviceToken = watcher.Device.APNSToken + notification.Payload = payload client := sc.apnsProduction - if device.Sandbox { + if watcher.Device.Sandbox { client = sc.apnsSandbox } @@ -395,7 +405,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { _ = sc.statsd.Incr("apns.notification.errors", []string{}, 1) sc.logger.WithFields(logrus.Fields{ "subreddit#id": subreddit.ID, - "device#id": device.ID, + "device#id": watcher.Device.ID, "err": err, "status": res.StatusCode, "reason": res.Reason, @@ -404,8 +414,8 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { _ = sc.statsd.Incr("apns.notification.sent", []string{}, 1) sc.logger.WithFields(logrus.Fields{ "subreddit#id": subreddit.ID, - "device#id": device.ID, - "device#token": device.APNSToken, + "device#id": watcher.Device.ID, + "device#token": watcher.Device.APNSToken, }).Info("sent notification") } } @@ -418,11 +428,11 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { } func payloadFromPost(post *reddit.Thing) *payload.Payload { - title := fmt.Sprintf("📣 Subreddit Watch (r/%s)", post.Subreddit) + subtitle := fmt.Sprintf("r/%s", post.Subreddit) payload := payload. NewPayload(). - AlertTitle(title). + AlertSubtitle(subtitle). AlertBody(post.Title). AlertSummaryArg(post.Subreddit). Category("post-watch"). diff --git a/internal/worker/trending.go b/internal/worker/trending.go new file mode 100644 index 0000000..6712338 --- /dev/null +++ b/internal/worker/trending.go @@ -0,0 +1,336 @@ +package worker + +import ( + "context" + "fmt" + "math/rand" + "os" + "strconv" + "time" + + "github.com/DataDog/datadog-go/statsd" + "github.com/adjust/rmq/v4" + "github.com/go-redis/redis/v8" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/sideshow/apns2" + "github.com/sideshow/apns2/payload" + "github.com/sideshow/apns2/token" + "github.com/sirupsen/logrus" + + "github.com/christianselig/apollo-backend/internal/domain" + "github.com/christianselig/apollo-backend/internal/reddit" + "github.com/christianselig/apollo-backend/internal/repository" +) + +type trendingWorker struct { + logger *logrus.Logger + statsd *statsd.Client + redis *redis.Client + queue rmq.Connection + reddit *reddit.Client + apns *token.Token + + consumers int + + accountRepo domain.AccountRepository + deviceRepo domain.DeviceRepository + subredditRepo domain.SubredditRepository + watcherRepo domain.WatcherRepository +} + +const trendingNotificationTitleFormat = "🔥 Trending in r/%s" + +func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { + reddit := reddit.NewClient( + os.Getenv("REDDIT_CLIENT_ID"), + os.Getenv("REDDIT_CLIENT_SECRET"), + statsd, + consumers, + ) + + var apns *token.Token + { + authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH")) + if err != nil { + panic(err) + } + + apns = &token.Token{ + AuthKey: authKey, + KeyID: os.Getenv("APPLE_KEY_ID"), + TeamID: os.Getenv("APPLE_TEAM_ID"), + } + } + + return &trendingWorker{ + logger, + statsd, + redis, + queue, + reddit, + apns, + consumers, + + repository.NewPostgresAccount(db), + repository.NewPostgresDevice(db), + repository.NewPostgresSubreddit(db), + repository.NewPostgresWatcher(db), + } +} + +func (tw *trendingWorker) Start() error { + queue, err := tw.queue.OpenQueue("trending") + if err != nil { + return err + } + + tw.logger.WithFields(logrus.Fields{ + "numConsumers": tw.consumers, + }).Info("starting up trending worker") + + prefetchLimit := int64(tw.consumers * 2) + + if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil { + return err + } + + host, _ := os.Hostname() + + for i := 0; i < tw.consumers; i++ { + name := fmt.Sprintf("consumer %s-%d", host, i) + + consumer := NewTrendingConsumer(tw, i) + if _, err := queue.AddConsumer(name, consumer); err != nil { + return err + } + } + + return nil +} + +func (tw *trendingWorker) Stop() { + <-tw.queue.StopAllConsuming() // wait for all Consume() calls to finish +} + +type trendingConsumer struct { + *trendingWorker + tag int + + apnsSandbox *apns2.Client + apnsProduction *apns2.Client +} + +func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer { + return &trendingConsumer{ + tw, + tag, + apns2.NewTokenClient(tw.apns), + apns2.NewTokenClient(tw.apns).Production(), + } +} + +func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { + ctx := context.Background() + + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": delivery.Payload(), + }).Debug("starting job") + + id, err := strconv.ParseInt(delivery.Payload(), 10, 64) + if err != nil { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": delivery.Payload(), + "err": err, + }).Error("failed to parse subreddit ID") + + _ = delivery.Reject() + return + } + + defer func() { _ = delivery.Ack() }() + + subreddit, err := tc.subredditRepo.GetByID(ctx, id) + if err != nil { + tc.logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to fetch subreddit from database") + return + } + + watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID) + if err != nil { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "err": err, + }).Error("failed to fetch watchers from database") + return + } + + if len(watchers) == 0 { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + }).Debug("no watchers for trending, finishing job") + return + } + + // Grab last month's top posts so we calculate a trending average + i := rand.Intn(len(watchers)) + watcher := watchers[i] + rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken) + + tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week")) + if err != nil { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "subreddit#name": subreddit.Name, + "err": err, + }).Error("failed to fetch month's top posts") + return + } + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "subreddit#name": subreddit.Name, + "count": tps.Count, + }).Debug("loaded month's hot posts") + + if tps.Count == 0 { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + }).Debug("no top posts for subreddit, returning") + return + } + + if tps.Count < 20 { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + }).Debug("not enough posts, returning") + return + } + + middlePost := tps.Count / 2 + medianScore := tps.Children[middlePost].Score + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "score": medianScore, + }).Debug("calculated median score") + + // Grab hot posts and filter out anything that's > 2 days old + i = rand.Intn(len(watchers)) + watcher = watchers[i] + rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken) + hps, err := rac.SubredditHot(subreddit.Name) + if err != nil { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "subreddit#name": subreddit.Name, + "err": err, + }).Error("failed to fetch hot posts") + return + } + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "subreddit#name": subreddit.Name, + "count": hps.Count, + }).Debug("loaded hot posts") + + // Trending only counts for posts less than 2 days old + threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix()) + + for _, post := range hps.Children { + if post.Score < medianScore { + continue + } + + if post.CreatedAt < threshold { + break + } + + notification := &apns2.Notification{} + notification.Topic = "com.christianselig.Apollo" + notification.Payload = payloadFromTrendingPost(post) + + for _, watcher := range watchers { + if watcher.CreatedAt > post.CreatedAt { + continue + } + + lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID) + notified, _ := tc.redis.Get(ctx, lockKey).Bool() + + if notified { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "subreddit#name": subreddit.Name, + "watcher#id": watcher.ID, + "post#id": post.ID, + }).Debug("already notified, skipping") + continue + } + + tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour) + + if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil { + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "watcher#id": watcher.ID, + "err": err, + }).Error("could not increment hits") + return + } + + notification.DeviceToken = watcher.Device.APNSToken + + client := tc.apnsProduction + if watcher.Device.Sandbox { + client = tc.apnsSandbox + } + + res, err := client.Push(notification) + if err != nil { + _ = tc.statsd.Incr("apns.notification.errors", []string{}, 1) + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "post#id": post.ID, + "device#id": watcher.Device.ID, + "err": err, + "status": res.StatusCode, + "reason": res.Reason, + }).Error("failed to send notification") + } else { + _ = tc.statsd.Incr("apns.notification.sent", []string{}, 1) + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "post#id": post.ID, + "device#id": watcher.Device.ID, + "device#token": watcher.Device.APNSToken, + }).Info("sent notification") + } + } + } + + tc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "subreddit#name": subreddit.Name, + }).Debug("finishing job") +} + +func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload { + title := fmt.Sprintf(trendingNotificationTitleFormat, post.Subreddit) + + payload := payload. + NewPayload(). + AlertTitle(title). + AlertSubtitle(post.Title). + AlertBody(post.Title). + AlertSummaryArg(post.Subreddit). + Category("post-watch"). + Custom("post_title", post.Title). + Custom("post_id", post.ID). + Custom("subreddit", post.Subreddit). + Custom("author", post.Author). + Custom("post_age", post.CreatedAt). + MutableContent(). + Sound("traloop.wav") + + return payload +} diff --git a/internal/worker/users.go b/internal/worker/users.go index 0675c2b..50c34ab 100644 --- a/internal/worker/users.go +++ b/internal/worker/users.go @@ -38,6 +38,8 @@ type usersWorker struct { watcherRepo domain.WatcherRepository } +const userNotificationTitleFormat = "👨\u200d🚀 %s" + func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { reddit := reddit.NewClient( os.Getenv("REDDIT_CLIENT_ID"), @@ -224,9 +226,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { continue } - notification := &apns2.Notification{} - notification.Topic = "com.christianselig.Apollo" - notification.Payload = payloadFromUserPost(post) + payload := payloadFromUserPost(post) for _, watcher := range watchers { // Make sure we only alert on activities created after the search @@ -248,6 +248,13 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { } device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID) + + title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label) + payload.AlertTitle(title) + + notification := &apns2.Notification{} + notification.Topic = "com.christianselig.Apollo" + notification.Payload = payload notification.DeviceToken = device.APNSToken client := uc.apnsProduction @@ -283,12 +290,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { } func payloadFromUserPost(post *reddit.Thing) *payload.Payload { - title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author) - payload := payload. NewPayload(). - AlertTitle(title). AlertBody(post.Title). + AlertSubtitle(post.Author). AlertSummaryArg(post.Author). Category("user-watch"). Custom("post_title", post.Title).