mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-15 00:17:42 +00:00
287 lines
7.9 KiB
Go
287 lines
7.9 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"os"
|
|
"sort"
|
|
"time"
|
|
|
|
"github.com/DataDog/datadog-go/statsd"
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/jackc/pgx/v4/pgxpool"
|
|
"github.com/sideshow/apns2"
|
|
"github.com/sideshow/apns2/payload"
|
|
"github.com/sideshow/apns2/token"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/christianselig/apollo-backend/internal/domain"
|
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
|
"github.com/christianselig/apollo-backend/internal/repository"
|
|
)
|
|
|
|
// trendingWorker sends "trending post" push notifications for watched
// subreddits. Its Process method compares a subreddit's hot posts against the
// median score of the past week's top posts and notifies the devices of the
// watchers subscribed to that subreddit's trending posts.
//
// Field order matters to any unkeyed composite literal of this type.
type trendingWorker struct {
	logger *zap.Logger
	statsd *statsd.Client // counters for APNs notifications sent / failed
	redis  *redis.Client  // per-device, per-post "already notified" keys
	reddit *reddit.Client // used to build per-account authenticated Reddit clients
	apns   *apns2.Client  // APNs client used to push notifications

	// accountRepo and deviceRepo are populated by the constructor but are not
	// read by Process.
	accountRepo   domain.AccountRepository
	deviceRepo    domain.DeviceRepository
	subredditRepo domain.SubredditRepository
	watcherRepo   domain.WatcherRepository
}
|
|
|
|
// trendingNotificationTitleFormat is the fmt pattern used for the alert title
// of a trending-post notification; its single %s verb receives the subreddit
// name.
const (
	trendingNotificationTitleFormat = "🔥 r/%s Trending"
)
|
|
|
|
func NewTrendingWorker(ctx context.Context, logger *zap.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, consumers int) Worker {
|
|
reddit := reddit.NewClient(
|
|
os.Getenv("REDDIT_CLIENT_ID"),
|
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
|
statsd,
|
|
redis,
|
|
consumers,
|
|
)
|
|
|
|
var apns *apns2.Client
|
|
{
|
|
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
tok := &token.Token{
|
|
AuthKey: authKey,
|
|
KeyID: os.Getenv("APPLE_KEY_ID"),
|
|
TeamID: os.Getenv("APPLE_TEAM_ID"),
|
|
}
|
|
apns = apns2.NewTokenClient(tok).Production()
|
|
}
|
|
|
|
return &trendingWorker{
|
|
logger,
|
|
statsd,
|
|
redis,
|
|
reddit,
|
|
apns,
|
|
|
|
repository.NewPostgresAccount(db),
|
|
repository.NewPostgresDevice(db),
|
|
repository.NewPostgresSubreddit(db),
|
|
repository.NewPostgresWatcher(db),
|
|
}
|
|
}
|
|
|
|
func (tw *trendingWorker) Process(ctx context.Context, args ...interface{}) error {
|
|
id := int64(args[0].(float64))
|
|
tw.logger.Debug("starting job", zap.Int64("subreddit#id", id))
|
|
|
|
subreddit, err := tw.subredditRepo.GetByID(ctx, id)
|
|
if err != nil {
|
|
tw.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id))
|
|
return nil
|
|
}
|
|
|
|
watchers, err := tw.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID)
|
|
if err != nil {
|
|
tw.logger.Error("failed to fetch watchers from database",
|
|
zap.Error(err),
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
)
|
|
return err
|
|
}
|
|
|
|
if len(watchers) == 0 {
|
|
tw.logger.Debug("no watchers for subreddit, bailing early",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
// Grab last month's top posts so we calculate a trending average
|
|
i := rand.Intn(len(watchers))
|
|
watcher := watchers[i]
|
|
rac := tw.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
|
|
|
tps, err := rac.SubredditTop(ctx, subreddit.Name, reddit.WithQuery("t", "week"), reddit.WithQuery("show", "all"), reddit.WithQuery("limit", "25"))
|
|
if err != nil {
|
|
tw.logger.Error("failed to fetch weeks's top posts",
|
|
zap.Error(err),
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
tw.logger.Debug("loaded weeks's top posts",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.Int("count", tps.Count),
|
|
)
|
|
|
|
if tps.Count == 0 {
|
|
tw.logger.Debug("no top posts, bailing early",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
if tps.Count < 20 {
|
|
tw.logger.Debug("no top posts, bailing early",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.Int("count", tps.Count),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
sort.SliceStable(tps.Children, func(i, j int) bool {
|
|
return tps.Children[i].Score > tps.Children[j].Score
|
|
})
|
|
|
|
middlePost := tps.Count / 2
|
|
medianScore := tps.Children[middlePost].Score
|
|
tw.logger.Debug("calculated median score",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.Int64("score", medianScore),
|
|
)
|
|
|
|
// Grab hot posts and filter out anything that's > 2 days old
|
|
i = rand.Intn(len(watchers))
|
|
watcher = watchers[i]
|
|
rac = tw.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
|
|
|
hps, err := rac.SubredditHot(ctx, subreddit.Name, reddit.WithQuery("show", "all"), reddit.WithQuery("always_show_media", "1"))
|
|
if err != nil {
|
|
tw.logger.Error("failed to fetch hot posts",
|
|
zap.Error(err),
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
)
|
|
return err
|
|
}
|
|
tw.logger.Debug("loaded hot posts",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.Int("count", hps.Count),
|
|
)
|
|
|
|
// Trending only counts for posts less than 2 days old
|
|
threshold := time.Now().Add(-24 * time.Hour * 2)
|
|
|
|
for _, post := range hps.Children {
|
|
if post.Score < medianScore {
|
|
continue
|
|
}
|
|
|
|
if post.CreatedAt.Before(threshold) {
|
|
break
|
|
}
|
|
|
|
notification := &apns2.Notification{}
|
|
notification.Topic = "com.christianselig.Apollo"
|
|
notification.Payload = payloadFromTrendingPost(post)
|
|
|
|
for _, watcher := range watchers {
|
|
if watcher.CreatedAt.After(post.CreatedAt) {
|
|
continue
|
|
}
|
|
|
|
lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID)
|
|
notified, _ := tw.redis.Get(ctx, lockKey).Bool()
|
|
|
|
if notified {
|
|
tw.logger.Debug("already notified, skipping",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.Int64("watcher#id", watcher.ID),
|
|
zap.String("post#id", post.ID),
|
|
)
|
|
continue
|
|
}
|
|
|
|
tw.redis.SetEX(ctx, lockKey, true, 48*time.Hour)
|
|
|
|
if err := tw.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
|
|
tw.logger.Error("could not increment hits",
|
|
zap.Error(err),
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.Int64("watcher#id", watcher.ID),
|
|
)
|
|
return err
|
|
}
|
|
|
|
notification.DeviceToken = watcher.Device.APNSToken
|
|
|
|
res, err := tw.apns.Push(notification)
|
|
if err != nil {
|
|
_ = tw.statsd.Incr("apns.notification.errors", []string{}, 1)
|
|
tw.logger.Error("failed to send notification",
|
|
zap.Error(err),
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.String("post#id", post.ID),
|
|
zap.String("apns", watcher.Device.APNSToken),
|
|
)
|
|
} else if !res.Sent() {
|
|
_ = tw.statsd.Incr("apns.notification.errors", []string{}, 1)
|
|
tw.logger.Error("notification not sent",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.String("post#id", post.ID),
|
|
zap.String("apns", watcher.Device.APNSToken),
|
|
zap.Int("response#status", res.StatusCode),
|
|
zap.String("response#reason", res.Reason),
|
|
)
|
|
} else {
|
|
_ = tw.statsd.Incr("apns.notification.sent", []string{}, 1)
|
|
tw.logger.Info("sent notification",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
zap.String("post#id", post.ID),
|
|
zap.Int64("post#score", post.Score),
|
|
zap.String("device#token", watcher.Device.APNSToken),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
tw.logger.Debug("finishing job",
|
|
zap.Int64("subreddit#id", id),
|
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
|
|
title := fmt.Sprintf(trendingNotificationTitleFormat, post.Subreddit)
|
|
|
|
payload := payload.
|
|
NewPayload().
|
|
AlertTitle(title).
|
|
AlertBody(post.Title).
|
|
AlertSummaryArg(post.Subreddit).
|
|
Category("trending-post").
|
|
Custom("post_title", post.Title).
|
|
Custom("post_id", post.ID).
|
|
Custom("subreddit", post.Subreddit).
|
|
Custom("author", post.Author).
|
|
Custom("post_age", post.CreatedAt).
|
|
ThreadID("trending-post").
|
|
MutableContent().
|
|
Sound("traloop.wav")
|
|
|
|
if post.Thumbnail != "" && !post.Over18 {
|
|
payload.Custom("thumbnail", post.Thumbnail)
|
|
}
|
|
|
|
return payload
|
|
}
|