add subreddit trending posts watcher

This commit is contained in:
Andre Medeiros 2021-10-10 11:51:42 -04:00
parent 009d60dc2f
commit 65792abd94
12 changed files with 577 additions and 69 deletions

View file

@ -94,8 +94,9 @@ func (a *api) Routes() *mux.Router {
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.editWatcherHandler).Methods("PATCH")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST") r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST")
r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST") r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST")

View file

@ -12,6 +12,7 @@ import (
) )
type watcherCriteria struct { type watcherCriteria struct {
Author string
Upvotes int64 Upvotes int64
Keyword string Keyword string
Flair string Flair string
@ -22,6 +23,7 @@ type createWatcherRequest struct {
Type string Type string
User string User string
Subreddit string Subreddit string
Label string
Criteria watcherCriteria Criteria watcherCriteria
} }
@ -74,15 +76,17 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
watcher := domain.Watcher{ watcher := domain.Watcher{
Label: cwr.Label,
DeviceID: dev.ID, DeviceID: dev.ID,
AccountID: account.ID, AccountID: account.ID,
Author: strings.ToLower(cwr.Criteria.Author),
Upvotes: cwr.Criteria.Upvotes, Upvotes: cwr.Criteria.Upvotes,
Keyword: strings.ToLower(cwr.Criteria.Keyword), Keyword: strings.ToLower(cwr.Criteria.Keyword),
Flair: strings.ToLower(cwr.Criteria.Flair), Flair: strings.ToLower(cwr.Criteria.Flair),
Domain: strings.ToLower(cwr.Criteria.Domain), Domain: strings.ToLower(cwr.Criteria.Domain),
} }
if cwr.Type == "subreddit" { if cwr.Type == "subreddit" || cwr.Type == "trending" {
srr, err := ac.SubredditAbout(cwr.Subreddit) srr, err := ac.SubredditAbout(cwr.Subreddit)
if err != nil { if err != nil {
a.errorResponse(w, r, 422, err.Error()) a.errorResponse(w, r, 422, err.Error())
@ -102,7 +106,13 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
switch cwr.Type {
case "subreddit":
watcher.Type = domain.SubredditWatcher watcher.Type = domain.SubredditWatcher
case "trending":
watcher.Type = domain.TrendingWatcher
}
watcher.WatcheeID = sr.ID watcher.WatcheeID = sr.ID
} else if cwr.Type == "user" { } else if cwr.Type == "user" {
urr, err := ac.UserAbout(cwr.User) urr, err := ac.UserAbout(cwr.User)
@ -139,6 +149,50 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
vars := mux.Vars(r)
id, err := strconv.ParseInt(vars["watcherID"], 10, 64)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
watcher, err := a.watcherRepo.GetByID(ctx, id)
if err != nil || watcher.Device.APNSToken != vars["apns"] {
a.errorResponse(w, r, 422, "nice try")
return
}
ewr := &createWatcherRequest{
Criteria: watcherCriteria{
Upvotes: 0,
Keyword: "",
Flair: "",
Domain: "",
},
}
if err := json.NewDecoder(r.Body).Decode(ewr); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
watcher.Label = ewr.Label
watcher.Upvotes = ewr.Criteria.Upvotes
watcher.Keyword = ewr.Criteria.Keyword
watcher.Flair = ewr.Criteria.Flair
watcher.Domain = ewr.Criteria.Domain
if err := a.watcherRepo.Update(ctx, &watcher); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
w.WriteHeader(http.StatusOK)
}
func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) { func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() ctx := context.Background()
@ -149,16 +203,8 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
apns := vars["apns"]
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
watcher, err := a.watcherRepo.GetByID(ctx, id) watcher, err := a.watcherRepo.GetByID(ctx, id)
if err != nil || watcher.DeviceID != dev.ID { if err != nil || watcher.Device.APNSToken != vars["apns"] {
a.errorResponse(w, r, 422, "nice try") a.errorResponse(w, r, 422, "nice try")
return return
} }
@ -169,6 +215,9 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
type watcherItem struct { type watcherItem struct {
ID int64 ID int64
CreatedAt float64
Type string
Label string
Upvotes int64 Upvotes int64
Keyword string Keyword string
Flair string Flair string
@ -193,6 +242,9 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
for i, watcher := range watchers { for i, watcher := range watchers {
wi := watcherItem{ wi := watcherItem{
ID: watcher.ID, ID: watcher.ID,
CreatedAt: watcher.CreatedAt,
Type: watcher.Type.String(),
Label: watcher.Label,
Upvotes: watcher.Upvotes, Upvotes: watcher.Upvotes,
Keyword: watcher.Keyword, Keyword: watcher.Keyword,
Flair: watcher.Flair, Flair: watcher.Flair,

View file

@ -77,6 +77,11 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
return err return err
} }
trendingQueue, err := queue.OpenQueue("trending")
if err != nil {
return err
}
userQueue, err := queue.OpenQueue("users") userQueue, err := queue.OpenQueue("users")
if err != nil { if err != nil {
return err return err
@ -84,7 +89,7 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
s := gocron.NewScheduler(time.UTC) s := gocron.NewScheduler(time.UTC)
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) }) _, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) }) _, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
@ -273,7 +278,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
} }
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) {
now := time.Now() now := time.Now()
ready := now.Unix() - subredditEnqueueTimeout ready := now.Unix() - subredditEnqueueTimeout
@ -326,15 +331,17 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
batchIds[i] = strconv.FormatInt(id, 10) batchIds[i] = strconv.FormatInt(id, 10)
} }
for _, queue := range queues {
if err = queue.Publish(batchIds...); err != nil { if err = queue.Publish(batchIds...); err != nil {
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"queue": queue,
"err": err, "err": err,
}).Error("failed to enqueue subreddit") }).Error("failed to enqueue subreddit")
} }
}
_ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1) _ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1)
_ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1) _ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1)
} }
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) { func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {

View file

@ -15,6 +15,7 @@ var (
queues = map[string]worker.NewWorkerFn{ queues = map[string]worker.NewWorkerFn{
"notifications": worker.NewNotificationsWorker, "notifications": worker.NewNotificationsWorker,
"subreddits": worker.NewSubredditsWorker, "subreddits": worker.NewSubredditsWorker,
"trending": worker.NewTrendingWorker,
"users": worker.NewUsersWorker, "users": worker.NewUsersWorker,
} }
) )

View file

@ -7,29 +7,50 @@ type WatcherType int64
const ( const (
SubredditWatcher WatcherType = iota SubredditWatcher WatcherType = iota
UserWatcher UserWatcher
TrendingWatcher
) )
func (wt WatcherType) String() string {
switch wt {
case SubredditWatcher:
return "subreddit"
case UserWatcher:
return "user"
case TrendingWatcher:
return "trending"
}
return "unknown"
}
type Watcher struct { type Watcher struct {
ID int64 ID int64
CreatedAt float64 CreatedAt float64
LastNotifiedAt float64 LastNotifiedAt float64
Label string
DeviceID int64 DeviceID int64
AccountID int64 AccountID int64
Type WatcherType Type WatcherType
WatcheeID int64 WatcheeID int64
Author string
Upvotes int64 Upvotes int64
Keyword string Keyword string
Flair string Flair string
Domain string Domain string
Hits int64 Hits int64
// Related models
Device Device
Account Account
} }
type WatcherRepository interface { type WatcherRepository interface {
GetByID(ctx context.Context, id int64) (Watcher, error) GetByID(ctx context.Context, id int64) (Watcher, error)
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error) GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
GetByUserID(ctx context.Context, id int64) ([]Watcher, error) GetByUserID(ctx context.Context, id int64) ([]Watcher, error)
GetByTrendingSubredditID(ctx context.Context, id int64) ([]Watcher, error)
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
Create(ctx context.Context, watcher *Watcher) error Create(ctx context.Context, watcher *Watcher) error

View file

@ -244,6 +244,10 @@ func (rac *AuthenticatedClient) SubredditHot(subreddit string, opts ...RequestOp
return rac.subredditPosts(subreddit, "hot", opts...) return rac.subredditPosts(subreddit, "hot", opts...)
} }
func (rac *AuthenticatedClient) SubredditTop(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "top", opts...)
}
func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) { func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "new", opts...) return rac.subredditPosts(subreddit, "new", opts...)
} }

View file

@ -79,8 +79,7 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
query := ` query := `
INSERT INTO subreddits (subreddit_id, name) INSERT INTO subreddits (subreddit_id, name)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT(subreddit_id) DO ON CONFLICT(subreddit_id) DO NOTHING
UPDATE SET last_checked_at = $3
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
@ -88,6 +87,5 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
query, query,
sr.SubredditID, sr.SubredditID,
sr.NormalizedName(), sr.NormalizedName(),
sr.LastCheckedAt,
).Scan(&sr.ID) ).Scan(&sr.ID)
} }

View file

@ -32,15 +32,23 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
&watcher.ID, &watcher.ID,
&watcher.CreatedAt, &watcher.CreatedAt,
&watcher.LastNotifiedAt, &watcher.LastNotifiedAt,
&watcher.Label,
&watcher.DeviceID, &watcher.DeviceID,
&watcher.AccountID, &watcher.AccountID,
&watcher.Type, &watcher.Type,
&watcher.WatcheeID, &watcher.WatcheeID,
&watcher.Author,
&watcher.Upvotes, &watcher.Upvotes,
&watcher.Keyword, &watcher.Keyword,
&watcher.Flair, &watcher.Flair,
&watcher.Domain, &watcher.Domain,
&watcher.Hits, &watcher.Hits,
&watcher.Device.ID,
&watcher.Device.APNSToken,
&watcher.Device.Sandbox,
&watcher.Account.ID,
&watcher.Account.AccessToken,
&watcher.Account.RefreshToken,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -51,9 +59,31 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) { func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
query := ` query := `
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits SELECT
watchers.id,
watchers.created_at,
watchers.last_notified_at,
watchers.label,
watchers.device_id,
watchers.account_id,
watchers.type,
watchers.watchee_id,
watchers.author,
watchers.upvotes,
watchers.keyword,
watchers.flair,
watchers.domain,
watchers.hits,
devices.id,
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.access_token,
accounts.refresh_token
FROM watchers FROM watchers
WHERE id = $1` INNER JOIN devices ON watchers.device_id = devices.id
INNER JOIN accounts ON watchers.account_id = accounts.id
WHERE watchers.id = $1`
watchers, err := p.fetch(ctx, query, id) watchers, err := p.fetch(ctx, query, id)
@ -68,13 +98,39 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) { func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) {
query := ` query := `
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits SELECT
watchers.id,
watchers.created_at,
watchers.last_notified_at,
watchers.label,
watchers.device_id,
watchers.account_id,
watchers.type,
watchers.watchee_id,
watchers.author,
watchers.upvotes,
watchers.keyword,
watchers.flair,
watchers.domain,
watchers.hits,
devices.id,
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.access_token,
accounts.refresh_token
FROM watchers FROM watchers
WHERE type = $1 AND watchee_id = $2` INNER JOIN devices ON watchers.device_id = devices.id
INNER JOIN accounts ON watchers.account_id = accounts.id
WHERE watchers.type = $1 AND watchers.watchee_id = $2`
return p.fetch(ctx, query, typ, id) return p.fetch(ctx, query, typ, id)
} }
func (p *postgresWatcherRepository) GetByTrendingSubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
return p.GetByTypeAndWatcheeID(ctx, domain.TrendingWatcher, id)
}
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) { func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id) return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id)
} }
@ -88,16 +144,24 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
SELECT SELECT
watchers.id, watchers.id,
watchers.created_at, watchers.created_at,
watchers.last_notified_at watchers.last_notified_at,
watchers.label,
watchers.device_id, watchers.device_id,
watchers.account_id, watchers.account_id,
watchers.type, watchers.type,
watchers.watchee_id, watchers.watchee_id,
watchers.author,
watchers.upvotes, watchers.upvotes,
watchers.keyword, watchers.keyword,
watchers.flair, watchers.flair,
watchers.domain, watchers.domain,
watchers.hits watchers.hits,
devices.id,
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.access_token,
accounts.refresh_token
FROM watchers FROM watchers
INNER JOIN accounts ON watchers.account_id = accounts.id INNER JOIN accounts ON watchers.account_id = accounts.id
INNER JOIN devices ON watchers.device_id = devices.id INNER JOIN devices ON watchers.device_id = devices.id
@ -113,18 +177,20 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
query := ` query := `
INSERT INTO watchers INSERT INTO watchers
(created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain) (created_at, last_notified_at, label, device_id, account_id, type, watchee_id, author, upvotes, keyword, flair, domain)
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9) VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
ctx, ctx,
query, query,
now, now,
watcher.Label,
watcher.DeviceID, watcher.DeviceID,
watcher.AccountID, watcher.AccountID,
watcher.Type, watcher.Type,
watcher.WatcheeID, watcher.WatcheeID,
watcher.Author,
watcher.Upvotes, watcher.Upvotes,
watcher.Keyword, watcher.Keyword,
watcher.Flair, watcher.Flair,
@ -135,20 +201,24 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error { func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error {
query := ` query := `
UPDATE watchers UPDATE watchers
SET upvotes = $2, SET author = $2,
keyword = $3, upvotes = $3,
flair = $4, keyword = $4,
domain = $5, flair = $5,
domain = $6,
label = $7
WHERE id = $1` WHERE id = $1`
res, err := p.pool.Exec( res, err := p.pool.Exec(
ctx, ctx,
query, query,
watcher.ID, watcher.ID,
watcher.Author,
watcher.Upvotes, watcher.Upvotes,
watcher.Keyword, watcher.Keyword,
watcher.Flair, watcher.Flair,
watcher.Domain, watcher.Domain,
watcher.Label,
) )
if res.RowsAffected() != 1 { if res.RowsAffected() != 1 {

View file

@ -330,6 +330,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
"status": res.StatusCode, "status": res.StatusCode,
"reason": res.Reason, "reason": res.Reason,
}).Error("failed to send notification") }).Error("failed to send notification")
// Delete device as notifications might have been disabled here
_ = nc.deviceRepo.Delete(ctx, device.APNSToken)
} else { } else {
_ = nc.statsd.Incr("apns.notification.sent", []string{}, 1) _ = nc.statsd.Incr("apns.notification.sent", []string{}, 1)
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{

View file

@ -40,6 +40,8 @@ type subredditsWorker struct {
watcherRepo domain.WatcherRepository watcherRepo domain.WatcherRepository
} }
const subredditNotificationTitleFormat = "📣 %s"
func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient( reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
@ -170,7 +172,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
if len(watchers) == 0 { if len(watchers) == 0 {
sc.logger.WithFields(logrus.Fields{ sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID, "subreddit#id": subreddit.ID,
}).Info("no watchers for subreddit, skipping") }).Debug("no watchers for subreddit, finishing job")
return return
} }
@ -298,11 +300,12 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
"count": len(posts), "count": len(posts),
}).Debug("checking posts for hits") }).Debug("checking posts for hits")
for _, post := range posts { for _, post := range posts {
lowcaseAuthor := strings.ToLower(post.Author)
lowcaseTitle := strings.ToLower(post.Title) lowcaseTitle := strings.ToLower(post.Title)
lowcaseFlair := strings.ToLower(post.Flair) lowcaseFlair := strings.ToLower(post.Flair)
lowcaseDomain := strings.ToLower(post.URL) lowcaseDomain := strings.ToLower(post.URL)
ids := []int64{} notifs := []domain.Watcher{}
for _, watcher := range watchers { for _, watcher := range watchers {
// Make sure we only alert on posts created after the search // Make sure we only alert on posts created after the search
@ -312,6 +315,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
matched := true matched := true
if watcher.Author != "" && lowcaseAuthor != watcher.Author {
matched = false
}
if watcher.Upvotes > 0 && post.Score < watcher.Upvotes { if watcher.Upvotes > 0 && post.Score < watcher.Upvotes {
matched = false matched = false
} }
@ -363,10 +370,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("got a hit") }).Debug("got a hit")
sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour) sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour)
ids = append(ids, watcher.DeviceID) notifs = append(notifs, watcher)
} }
if len(ids) == 0 { if len(notifs) == 0 {
continue continue
} }
@ -374,19 +381,22 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
"subreddit#id": subreddit.ID, "subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name, "subreddit#name": subreddit.Name,
"post#id": post.ID, "post#id": post.ID,
"count": len(ids), "count": len(notifs),
}).Debug("got hits for post") }).Debug("got hits for post")
payload := payloadFromPost(post)
for _, watcher := range notifs {
title := fmt.Sprintf(subredditNotificationTitleFormat, watcher.Label)
payload.AlertTitle(title)
notification := &apns2.Notification{} notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo" notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromPost(post) notification.DeviceToken = watcher.Device.APNSToken
notification.Payload = payload
for _, id := range ids {
device, _ := sc.deviceRepo.GetByID(ctx, id)
notification.DeviceToken = device.APNSToken
client := sc.apnsProduction client := sc.apnsProduction
if device.Sandbox { if watcher.Device.Sandbox {
client = sc.apnsSandbox client = sc.apnsSandbox
} }
@ -395,7 +405,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
_ = sc.statsd.Incr("apns.notification.errors", []string{}, 1) _ = sc.statsd.Incr("apns.notification.errors", []string{}, 1)
sc.logger.WithFields(logrus.Fields{ sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID, "subreddit#id": subreddit.ID,
"device#id": device.ID, "device#id": watcher.Device.ID,
"err": err, "err": err,
"status": res.StatusCode, "status": res.StatusCode,
"reason": res.Reason, "reason": res.Reason,
@ -404,8 +414,8 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
_ = sc.statsd.Incr("apns.notification.sent", []string{}, 1) _ = sc.statsd.Incr("apns.notification.sent", []string{}, 1)
sc.logger.WithFields(logrus.Fields{ sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID, "subreddit#id": subreddit.ID,
"device#id": device.ID, "device#id": watcher.Device.ID,
"device#token": device.APNSToken, "device#token": watcher.Device.APNSToken,
}).Info("sent notification") }).Info("sent notification")
} }
} }
@ -418,11 +428,11 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
} }
func payloadFromPost(post *reddit.Thing) *payload.Payload { func payloadFromPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf("📣 Subreddit Watch (r/%s)", post.Subreddit) subtitle := fmt.Sprintf("r/%s", post.Subreddit)
payload := payload. payload := payload.
NewPayload(). NewPayload().
AlertTitle(title). AlertSubtitle(subtitle).
AlertBody(post.Title). AlertBody(post.Title).
AlertSummaryArg(post.Subreddit). AlertSummaryArg(post.Subreddit).
Category("post-watch"). Category("post-watch").

336
internal/worker/trending.go Normal file
View file

@ -0,0 +1,336 @@
package worker
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"time"
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/payload"
"github.com/sideshow/apns2/token"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
)
type trendingWorker struct {
logger *logrus.Logger
statsd *statsd.Client
redis *redis.Client
queue rmq.Connection
reddit *reddit.Client
apns *token.Token
consumers int
accountRepo domain.AccountRepository
deviceRepo domain.DeviceRepository
subredditRepo domain.SubredditRepository
watcherRepo domain.WatcherRepository
}
const trendingNotificationTitleFormat = "🔥 Trending in r/%s"
func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
consumers,
)
var apns *token.Token
{
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
if err != nil {
panic(err)
}
apns = &token.Token{
AuthKey: authKey,
KeyID: os.Getenv("APPLE_KEY_ID"),
TeamID: os.Getenv("APPLE_TEAM_ID"),
}
}
return &trendingWorker{
logger,
statsd,
redis,
queue,
reddit,
apns,
consumers,
repository.NewPostgresAccount(db),
repository.NewPostgresDevice(db),
repository.NewPostgresSubreddit(db),
repository.NewPostgresWatcher(db),
}
}
func (tw *trendingWorker) Start() error {
queue, err := tw.queue.OpenQueue("trending")
if err != nil {
return err
}
tw.logger.WithFields(logrus.Fields{
"numConsumers": tw.consumers,
}).Info("starting up trending worker")
prefetchLimit := int64(tw.consumers * 2)
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
return err
}
host, _ := os.Hostname()
for i := 0; i < tw.consumers; i++ {
name := fmt.Sprintf("consumer %s-%d", host, i)
consumer := NewTrendingConsumer(tw, i)
if _, err := queue.AddConsumer(name, consumer); err != nil {
return err
}
}
return nil
}
func (tw *trendingWorker) Stop() {
<-tw.queue.StopAllConsuming() // wait for all Consume() calls to finish
}
type trendingConsumer struct {
*trendingWorker
tag int
apnsSandbox *apns2.Client
apnsProduction *apns2.Client
}
func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer {
return &trendingConsumer{
tw,
tag,
apns2.NewTokenClient(tw.apns),
apns2.NewTokenClient(tw.apns).Production(),
}
}
func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
tc.logger.WithFields(logrus.Fields{
"subreddit#id": delivery.Payload(),
}).Debug("starting job")
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": delivery.Payload(),
"err": err,
}).Error("failed to parse subreddit ID")
_ = delivery.Reject()
return
}
defer func() { _ = delivery.Ack() }()
subreddit, err := tc.subredditRepo.GetByID(ctx, id)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"err": err,
}).Error("failed to fetch subreddit from database")
return
}
watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"err": err,
}).Error("failed to fetch watchers from database")
return
}
if len(watchers) == 0 {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Debug("no watchers for trending, finishing job")
return
}
// Grab last month's top posts so we calculate a trending average
i := rand.Intn(len(watchers))
watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"err": err,
}).Error("failed to fetch month's top posts")
return
}
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"count": tps.Count,
}).Debug("loaded month's hot posts")
if tps.Count == 0 {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Debug("no top posts for subreddit, returning")
return
}
if tps.Count < 20 {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Debug("not enough posts, returning")
return
}
middlePost := tps.Count / 2
medianScore := tps.Children[middlePost].Score
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"score": medianScore,
}).Debug("calculated median score")
// Grab hot posts and filter out anything that's > 2 days old
i = rand.Intn(len(watchers))
watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(subreddit.Name)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"err": err,
}).Error("failed to fetch hot posts")
return
}
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"count": hps.Count,
}).Debug("loaded hot posts")
// Trending only counts for posts less than 2 days old
threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix())
for _, post := range hps.Children {
if post.Score < medianScore {
continue
}
if post.CreatedAt < threshold {
break
}
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromTrendingPost(post)
for _, watcher := range watchers {
if watcher.CreatedAt > post.CreatedAt {
continue
}
lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID)
notified, _ := tc.redis.Get(ctx, lockKey).Bool()
if notified {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"watcher#id": watcher.ID,
"post#id": post.ID,
}).Debug("already notified, skipping")
continue
}
tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour)
if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"watcher#id": watcher.ID,
"err": err,
}).Error("could not increment hits")
return
}
notification.DeviceToken = watcher.Device.APNSToken
client := tc.apnsProduction
if watcher.Device.Sandbox {
client = tc.apnsSandbox
}
res, err := client.Push(notification)
if err != nil {
_ = tc.statsd.Incr("apns.notification.errors", []string{}, 1)
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"post#id": post.ID,
"device#id": watcher.Device.ID,
"err": err,
"status": res.StatusCode,
"reason": res.Reason,
}).Error("failed to send notification")
} else {
_ = tc.statsd.Incr("apns.notification.sent", []string{}, 1)
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"post#id": post.ID,
"device#id": watcher.Device.ID,
"device#token": watcher.Device.APNSToken,
}).Info("sent notification")
}
}
}
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
}).Debug("finishing job")
}
func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf(trendingNotificationTitleFormat, post.Subreddit)
payload := payload.
NewPayload().
AlertTitle(title).
AlertSubtitle(post.Title).
AlertBody(post.Title).
AlertSummaryArg(post.Subreddit).
Category("post-watch").
Custom("post_title", post.Title).
Custom("post_id", post.ID).
Custom("subreddit", post.Subreddit).
Custom("author", post.Author).
Custom("post_age", post.CreatedAt).
MutableContent().
Sound("traloop.wav")
return payload
}

View file

@ -38,6 +38,8 @@ type usersWorker struct {
watcherRepo domain.WatcherRepository watcherRepo domain.WatcherRepository
} }
const userNotificationTitleFormat = "👨\u200d🚀 %s"
func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient( reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
@ -224,9 +226,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
continue continue
} }
notification := &apns2.Notification{} payload := payloadFromUserPost(post)
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromUserPost(post)
for _, watcher := range watchers { for _, watcher := range watchers {
// Make sure we only alert on activities created after the search // Make sure we only alert on activities created after the search
@ -248,6 +248,13 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
} }
device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID) device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label)
payload.AlertTitle(title)
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payload
notification.DeviceToken = device.APNSToken notification.DeviceToken = device.APNSToken
client := uc.apnsProduction client := uc.apnsProduction
@ -283,12 +290,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
} }
func payloadFromUserPost(post *reddit.Thing) *payload.Payload { func payloadFromUserPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author)
payload := payload. payload := payload.
NewPayload(). NewPayload().
AlertTitle(title).
AlertBody(post.Title). AlertBody(post.Title).
AlertSubtitle(post.Author).
AlertSummaryArg(post.Author). AlertSummaryArg(post.Author).
Category("user-watch"). Category("user-watch").
Custom("post_title", post.Title). Custom("post_title", post.Title).