add subreddit trending posts watcher

This commit is contained in:
Andre Medeiros 2021-10-10 11:51:42 -04:00
parent 009d60dc2f
commit 65792abd94
12 changed files with 577 additions and 69 deletions

View file

@ -94,8 +94,9 @@ func (a *api) Routes() *mux.Router {
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.editWatcherHandler).Methods("PATCH")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST")
r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST")

View file

@ -12,6 +12,7 @@ import (
)
type watcherCriteria struct {
Author string
Upvotes int64
Keyword string
Flair string
@ -22,6 +23,7 @@ type createWatcherRequest struct {
Type string
User string
Subreddit string
Label string
Criteria watcherCriteria
}
@ -74,15 +76,17 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
watcher := domain.Watcher{
Label: cwr.Label,
DeviceID: dev.ID,
AccountID: account.ID,
Author: strings.ToLower(cwr.Criteria.Author),
Upvotes: cwr.Criteria.Upvotes,
Keyword: strings.ToLower(cwr.Criteria.Keyword),
Flair: strings.ToLower(cwr.Criteria.Flair),
Domain: strings.ToLower(cwr.Criteria.Domain),
}
if cwr.Type == "subreddit" {
if cwr.Type == "subreddit" || cwr.Type == "trending" {
srr, err := ac.SubredditAbout(cwr.Subreddit)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
@ -102,7 +106,13 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
}
}
switch cwr.Type {
case "subreddit":
watcher.Type = domain.SubredditWatcher
case "trending":
watcher.Type = domain.TrendingWatcher
}
watcher.WatcheeID = sr.ID
} else if cwr.Type == "user" {
urr, err := ac.UserAbout(cwr.User)
@ -139,6 +149,50 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
vars := mux.Vars(r)
id, err := strconv.ParseInt(vars["watcherID"], 10, 64)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
watcher, err := a.watcherRepo.GetByID(ctx, id)
if err != nil || watcher.Device.APNSToken != vars["apns"] {
a.errorResponse(w, r, 422, "nice try")
return
}
ewr := &createWatcherRequest{
Criteria: watcherCriteria{
Upvotes: 0,
Keyword: "",
Flair: "",
Domain: "",
},
}
if err := json.NewDecoder(r.Body).Decode(ewr); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
watcher.Label = ewr.Label
watcher.Upvotes = ewr.Criteria.Upvotes
watcher.Keyword = ewr.Criteria.Keyword
watcher.Flair = ewr.Criteria.Flair
watcher.Domain = ewr.Criteria.Domain
if err := a.watcherRepo.Update(ctx, &watcher); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
w.WriteHeader(http.StatusOK)
}
func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
@ -149,16 +203,8 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
return
}
apns := vars["apns"]
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
watcher, err := a.watcherRepo.GetByID(ctx, id)
if err != nil || watcher.DeviceID != dev.ID {
if err != nil || watcher.Device.APNSToken != vars["apns"] {
a.errorResponse(w, r, 422, "nice try")
return
}
@ -169,6 +215,9 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
type watcherItem struct {
ID int64
CreatedAt float64
Type string
Label string
Upvotes int64
Keyword string
Flair string
@ -193,6 +242,9 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
for i, watcher := range watchers {
wi := watcherItem{
ID: watcher.ID,
CreatedAt: watcher.CreatedAt,
Type: watcher.Type.String(),
Label: watcher.Label,
Upvotes: watcher.Upvotes,
Keyword: watcher.Keyword,
Flair: watcher.Flair,

View file

@ -77,6 +77,11 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
return err
}
trendingQueue, err := queue.OpenQueue("trending")
if err != nil {
return err
}
userQueue, err := queue.OpenQueue("users")
if err != nil {
return err
@ -84,7 +89,7 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
s := gocron.NewScheduler(time.UTC)
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
@ -273,7 +278,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
}
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) {
now := time.Now()
ready := now.Unix() - subredditEnqueueTimeout
@ -326,15 +331,17 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
batchIds[i] = strconv.FormatInt(id, 10)
}
for _, queue := range queues {
if err = queue.Publish(batchIds...); err != nil {
logger.WithFields(logrus.Fields{
"queue": queue,
"err": err,
}).Error("failed to enqueue subreddit")
}
}
_ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1)
_ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1)
}
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {

View file

@ -15,6 +15,7 @@ var (
queues = map[string]worker.NewWorkerFn{
"notifications": worker.NewNotificationsWorker,
"subreddits": worker.NewSubredditsWorker,
"trending": worker.NewTrendingWorker,
"users": worker.NewUsersWorker,
}
)

View file

@ -7,29 +7,50 @@ type WatcherType int64
const (
SubredditWatcher WatcherType = iota
UserWatcher
TrendingWatcher
)
func (wt WatcherType) String() string {
switch wt {
case SubredditWatcher:
return "subreddit"
case UserWatcher:
return "user"
case TrendingWatcher:
return "trending"
}
return "unknown"
}
type Watcher struct {
ID int64
CreatedAt float64
LastNotifiedAt float64
Label string
DeviceID int64
AccountID int64
Type WatcherType
WatcheeID int64
Author string
Upvotes int64
Keyword string
Flair string
Domain string
Hits int64
// Related models
Device Device
Account Account
}
type WatcherRepository interface {
GetByID(ctx context.Context, id int64) (Watcher, error)
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
GetByUserID(ctx context.Context, id int64) ([]Watcher, error)
GetByTrendingSubredditID(ctx context.Context, id int64) ([]Watcher, error)
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
Create(ctx context.Context, watcher *Watcher) error

View file

@ -244,6 +244,10 @@ func (rac *AuthenticatedClient) SubredditHot(subreddit string, opts ...RequestOp
return rac.subredditPosts(subreddit, "hot", opts...)
}
func (rac *AuthenticatedClient) SubredditTop(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "top", opts...)
}
func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "new", opts...)
}

View file

@ -79,8 +79,7 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
query := `
INSERT INTO subreddits (subreddit_id, name)
VALUES ($1, $2)
ON CONFLICT(subreddit_id) DO
UPDATE SET last_checked_at = $3
ON CONFLICT(subreddit_id) DO NOTHING
RETURNING id`
return p.pool.QueryRow(
@ -88,6 +87,5 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
query,
sr.SubredditID,
sr.NormalizedName(),
sr.LastCheckedAt,
).Scan(&sr.ID)
}

View file

@ -32,15 +32,23 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
&watcher.ID,
&watcher.CreatedAt,
&watcher.LastNotifiedAt,
&watcher.Label,
&watcher.DeviceID,
&watcher.AccountID,
&watcher.Type,
&watcher.WatcheeID,
&watcher.Author,
&watcher.Upvotes,
&watcher.Keyword,
&watcher.Flair,
&watcher.Domain,
&watcher.Hits,
&watcher.Device.ID,
&watcher.Device.APNSToken,
&watcher.Device.Sandbox,
&watcher.Account.ID,
&watcher.Account.AccessToken,
&watcher.Account.RefreshToken,
); err != nil {
return nil, err
}
@ -51,9 +59,31 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
query := `
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
SELECT
watchers.id,
watchers.created_at,
watchers.last_notified_at,
watchers.label,
watchers.device_id,
watchers.account_id,
watchers.type,
watchers.watchee_id,
watchers.author,
watchers.upvotes,
watchers.keyword,
watchers.flair,
watchers.domain,
watchers.hits,
devices.id,
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.access_token,
accounts.refresh_token
FROM watchers
WHERE id = $1`
INNER JOIN devices ON watchers.device_id = devices.id
INNER JOIN accounts ON watchers.account_id = accounts.id
WHERE watchers.id = $1`
watchers, err := p.fetch(ctx, query, id)
@ -68,13 +98,39 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) {
query := `
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
SELECT
watchers.id,
watchers.created_at,
watchers.last_notified_at,
watchers.label,
watchers.device_id,
watchers.account_id,
watchers.type,
watchers.watchee_id,
watchers.author,
watchers.upvotes,
watchers.keyword,
watchers.flair,
watchers.domain,
watchers.hits,
devices.id,
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.access_token,
accounts.refresh_token
FROM watchers
WHERE type = $1 AND watchee_id = $2`
INNER JOIN devices ON watchers.device_id = devices.id
INNER JOIN accounts ON watchers.account_id = accounts.id
WHERE watchers.type = $1 AND watchers.watchee_id = $2`
return p.fetch(ctx, query, typ, id)
}
func (p *postgresWatcherRepository) GetByTrendingSubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
return p.GetByTypeAndWatcheeID(ctx, domain.TrendingWatcher, id)
}
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id)
}
@ -88,16 +144,24 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
SELECT
watchers.id,
watchers.created_at,
watchers.last_notified_at
watchers.last_notified_at,
watchers.label,
watchers.device_id,
watchers.account_id,
watchers.type,
watchers.watchee_id,
watchers.author,
watchers.upvotes,
watchers.keyword,
watchers.flair,
watchers.domain,
watchers.hits
watchers.hits,
devices.id,
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.access_token,
accounts.refresh_token
FROM watchers
INNER JOIN accounts ON watchers.account_id = accounts.id
INNER JOIN devices ON watchers.device_id = devices.id
@ -113,18 +177,20 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
query := `
INSERT INTO watchers
(created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain)
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9)
(created_at, last_notified_at, label, device_id, account_id, type, watchee_id, author, upvotes, keyword, flair, domain)
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id`
return p.pool.QueryRow(
ctx,
query,
now,
watcher.Label,
watcher.DeviceID,
watcher.AccountID,
watcher.Type,
watcher.WatcheeID,
watcher.Author,
watcher.Upvotes,
watcher.Keyword,
watcher.Flair,
@ -135,20 +201,24 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error {
query := `
UPDATE watchers
SET upvotes = $2,
keyword = $3,
flair = $4,
domain = $5,
SET author = $2,
upvotes = $3,
keyword = $4,
flair = $5,
domain = $6,
label = $7
WHERE id = $1`
res, err := p.pool.Exec(
ctx,
query,
watcher.ID,
watcher.Author,
watcher.Upvotes,
watcher.Keyword,
watcher.Flair,
watcher.Domain,
watcher.Label,
)
if res.RowsAffected() != 1 {

View file

@ -330,6 +330,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
"status": res.StatusCode,
"reason": res.Reason,
}).Error("failed to send notification")
// Delete device as notifications might have been disabled here
_ = nc.deviceRepo.Delete(ctx, device.APNSToken)
} else {
_ = nc.statsd.Incr("apns.notification.sent", []string{}, 1)
nc.logger.WithFields(logrus.Fields{

View file

@ -40,6 +40,8 @@ type subredditsWorker struct {
watcherRepo domain.WatcherRepository
}
const subredditNotificationTitleFormat = "📣 %s"
func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
@ -170,7 +172,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
if len(watchers) == 0 {
sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Info("no watchers for subreddit, skipping")
}).Debug("no watchers for subreddit, finishing job")
return
}
@ -298,11 +300,12 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
"count": len(posts),
}).Debug("checking posts for hits")
for _, post := range posts {
lowcaseAuthor := strings.ToLower(post.Author)
lowcaseTitle := strings.ToLower(post.Title)
lowcaseFlair := strings.ToLower(post.Flair)
lowcaseDomain := strings.ToLower(post.URL)
ids := []int64{}
notifs := []domain.Watcher{}
for _, watcher := range watchers {
// Make sure we only alert on posts created after the search
@ -312,6 +315,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
matched := true
if watcher.Author != "" && lowcaseAuthor != watcher.Author {
matched = false
}
if watcher.Upvotes > 0 && post.Score < watcher.Upvotes {
matched = false
}
@ -363,10 +370,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("got a hit")
sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour)
ids = append(ids, watcher.DeviceID)
notifs = append(notifs, watcher)
}
if len(ids) == 0 {
if len(notifs) == 0 {
continue
}
@ -374,19 +381,22 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"post#id": post.ID,
"count": len(ids),
"count": len(notifs),
}).Debug("got hits for post")
payload := payloadFromPost(post)
for _, watcher := range notifs {
title := fmt.Sprintf(subredditNotificationTitleFormat, watcher.Label)
payload.AlertTitle(title)
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromPost(post)
for _, id := range ids {
device, _ := sc.deviceRepo.GetByID(ctx, id)
notification.DeviceToken = device.APNSToken
notification.DeviceToken = watcher.Device.APNSToken
notification.Payload = payload
client := sc.apnsProduction
if device.Sandbox {
if watcher.Device.Sandbox {
client = sc.apnsSandbox
}
@ -395,7 +405,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
_ = sc.statsd.Incr("apns.notification.errors", []string{}, 1)
sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"device#id": device.ID,
"device#id": watcher.Device.ID,
"err": err,
"status": res.StatusCode,
"reason": res.Reason,
@ -404,8 +414,8 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
_ = sc.statsd.Incr("apns.notification.sent", []string{}, 1)
sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"device#id": device.ID,
"device#token": device.APNSToken,
"device#id": watcher.Device.ID,
"device#token": watcher.Device.APNSToken,
}).Info("sent notification")
}
}
@ -418,11 +428,11 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}
func payloadFromPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf("📣 Subreddit Watch (r/%s)", post.Subreddit)
subtitle := fmt.Sprintf("r/%s", post.Subreddit)
payload := payload.
NewPayload().
AlertTitle(title).
AlertSubtitle(subtitle).
AlertBody(post.Title).
AlertSummaryArg(post.Subreddit).
Category("post-watch").

336
internal/worker/trending.go Normal file
View file

@ -0,0 +1,336 @@
package worker
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"time"
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/payload"
"github.com/sideshow/apns2/token"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
)
type trendingWorker struct {
logger *logrus.Logger
statsd *statsd.Client
redis *redis.Client
queue rmq.Connection
reddit *reddit.Client
apns *token.Token
consumers int
accountRepo domain.AccountRepository
deviceRepo domain.DeviceRepository
subredditRepo domain.SubredditRepository
watcherRepo domain.WatcherRepository
}
const trendingNotificationTitleFormat = "🔥 Trending in r/%s"
func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
consumers,
)
var apns *token.Token
{
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
if err != nil {
panic(err)
}
apns = &token.Token{
AuthKey: authKey,
KeyID: os.Getenv("APPLE_KEY_ID"),
TeamID: os.Getenv("APPLE_TEAM_ID"),
}
}
return &trendingWorker{
logger,
statsd,
redis,
queue,
reddit,
apns,
consumers,
repository.NewPostgresAccount(db),
repository.NewPostgresDevice(db),
repository.NewPostgresSubreddit(db),
repository.NewPostgresWatcher(db),
}
}
func (tw *trendingWorker) Start() error {
queue, err := tw.queue.OpenQueue("trending")
if err != nil {
return err
}
tw.logger.WithFields(logrus.Fields{
"numConsumers": tw.consumers,
}).Info("starting up trending worker")
prefetchLimit := int64(tw.consumers * 2)
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
return err
}
host, _ := os.Hostname()
for i := 0; i < tw.consumers; i++ {
name := fmt.Sprintf("consumer %s-%d", host, i)
consumer := NewTrendingConsumer(tw, i)
if _, err := queue.AddConsumer(name, consumer); err != nil {
return err
}
}
return nil
}
func (tw *trendingWorker) Stop() {
<-tw.queue.StopAllConsuming() // wait for all Consume() calls to finish
}
type trendingConsumer struct {
*trendingWorker
tag int
apnsSandbox *apns2.Client
apnsProduction *apns2.Client
}
func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer {
return &trendingConsumer{
tw,
tag,
apns2.NewTokenClient(tw.apns),
apns2.NewTokenClient(tw.apns).Production(),
}
}
func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
tc.logger.WithFields(logrus.Fields{
"subreddit#id": delivery.Payload(),
}).Debug("starting job")
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": delivery.Payload(),
"err": err,
}).Error("failed to parse subreddit ID")
_ = delivery.Reject()
return
}
defer func() { _ = delivery.Ack() }()
subreddit, err := tc.subredditRepo.GetByID(ctx, id)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"err": err,
}).Error("failed to fetch subreddit from database")
return
}
watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"err": err,
}).Error("failed to fetch watchers from database")
return
}
if len(watchers) == 0 {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Debug("no watchers for trending, finishing job")
return
}
// Grab last month's top posts so we calculate a trending average
i := rand.Intn(len(watchers))
watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"err": err,
}).Error("failed to fetch month's top posts")
return
}
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"count": tps.Count,
}).Debug("loaded month's hot posts")
if tps.Count == 0 {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Debug("no top posts for subreddit, returning")
return
}
if tps.Count < 20 {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
}).Debug("not enough posts, returning")
return
}
middlePost := tps.Count / 2
medianScore := tps.Children[middlePost].Score
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"score": medianScore,
}).Debug("calculated median score")
// Grab hot posts and filter out anything that's > 2 days old
i = rand.Intn(len(watchers))
watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(subreddit.Name)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"err": err,
}).Error("failed to fetch hot posts")
return
}
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"count": hps.Count,
}).Debug("loaded hot posts")
// Trending only counts for posts less than 2 days old
threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix())
for _, post := range hps.Children {
if post.Score < medianScore {
continue
}
if post.CreatedAt < threshold {
break
}
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromTrendingPost(post)
for _, watcher := range watchers {
if watcher.CreatedAt > post.CreatedAt {
continue
}
lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID)
notified, _ := tc.redis.Get(ctx, lockKey).Bool()
if notified {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
"watcher#id": watcher.ID,
"post#id": post.ID,
}).Debug("already notified, skipping")
continue
}
tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour)
if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"watcher#id": watcher.ID,
"err": err,
}).Error("could not increment hits")
return
}
notification.DeviceToken = watcher.Device.APNSToken
client := tc.apnsProduction
if watcher.Device.Sandbox {
client = tc.apnsSandbox
}
res, err := client.Push(notification)
if err != nil {
_ = tc.statsd.Incr("apns.notification.errors", []string{}, 1)
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"post#id": post.ID,
"device#id": watcher.Device.ID,
"err": err,
"status": res.StatusCode,
"reason": res.Reason,
}).Error("failed to send notification")
} else {
_ = tc.statsd.Incr("apns.notification.sent", []string{}, 1)
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"post#id": post.ID,
"device#id": watcher.Device.ID,
"device#token": watcher.Device.APNSToken,
}).Info("sent notification")
}
}
}
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name,
}).Debug("finishing job")
}
func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf(trendingNotificationTitleFormat, post.Subreddit)
payload := payload.
NewPayload().
AlertTitle(title).
AlertSubtitle(post.Title).
AlertBody(post.Title).
AlertSummaryArg(post.Subreddit).
Category("post-watch").
Custom("post_title", post.Title).
Custom("post_id", post.ID).
Custom("subreddit", post.Subreddit).
Custom("author", post.Author).
Custom("post_age", post.CreatedAt).
MutableContent().
Sound("traloop.wav")
return payload
}

View file

@ -38,6 +38,8 @@ type usersWorker struct {
watcherRepo domain.WatcherRepository
}
const userNotificationTitleFormat = "👨\u200d🚀 %s"
func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
@ -224,9 +226,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
continue
}
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromUserPost(post)
payload := payloadFromUserPost(post)
for _, watcher := range watchers {
// Make sure we only alert on activities created after the search
@ -248,6 +248,13 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
}
device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label)
payload.AlertTitle(title)
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payload
notification.DeviceToken = device.APNSToken
client := uc.apnsProduction
@ -283,12 +290,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
}
func payloadFromUserPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author)
payload := payload.
NewPayload().
AlertTitle(title).
AlertBody(post.Title).
AlertSubtitle(post.Author).
AlertSummaryArg(post.Author).
Category("user-watch").
Custom("post_title", post.Title).