mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-28 22:57:43 +00:00
add subreddit trending posts watcher
This commit is contained in:
parent
009d60dc2f
commit
65792abd94
12 changed files with 577 additions and 69 deletions
|
@ -94,8 +94,9 @@ func (a *api) Routes() *mux.Router {
|
|||
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
||||
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.editWatcherHandler).Methods("PATCH")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
|
||||
|
||||
r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST")
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
)
|
||||
|
||||
type watcherCriteria struct {
|
||||
Author string
|
||||
Upvotes int64
|
||||
Keyword string
|
||||
Flair string
|
||||
|
@ -22,6 +23,7 @@ type createWatcherRequest struct {
|
|||
Type string
|
||||
User string
|
||||
Subreddit string
|
||||
Label string
|
||||
Criteria watcherCriteria
|
||||
}
|
||||
|
||||
|
@ -74,15 +76,17 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||
|
||||
watcher := domain.Watcher{
|
||||
Label: cwr.Label,
|
||||
DeviceID: dev.ID,
|
||||
AccountID: account.ID,
|
||||
Author: strings.ToLower(cwr.Criteria.Author),
|
||||
Upvotes: cwr.Criteria.Upvotes,
|
||||
Keyword: strings.ToLower(cwr.Criteria.Keyword),
|
||||
Flair: strings.ToLower(cwr.Criteria.Flair),
|
||||
Domain: strings.ToLower(cwr.Criteria.Domain),
|
||||
}
|
||||
|
||||
if cwr.Type == "subreddit" {
|
||||
if cwr.Type == "subreddit" || cwr.Type == "trending" {
|
||||
srr, err := ac.SubredditAbout(cwr.Subreddit)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
|
@ -102,7 +106,13 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
watcher.Type = domain.SubredditWatcher
|
||||
switch cwr.Type {
|
||||
case "subreddit":
|
||||
watcher.Type = domain.SubredditWatcher
|
||||
case "trending":
|
||||
watcher.Type = domain.TrendingWatcher
|
||||
}
|
||||
|
||||
watcher.WatcheeID = sr.ID
|
||||
} else if cwr.Type == "user" {
|
||||
urr, err := ac.UserAbout(cwr.User)
|
||||
|
@ -139,6 +149,50 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.Background()
|
||||
|
||||
vars := mux.Vars(r)
|
||||
id, err := strconv.ParseInt(vars["watcherID"], 10, 64)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
watcher, err := a.watcherRepo.GetByID(ctx, id)
|
||||
if err != nil || watcher.Device.APNSToken != vars["apns"] {
|
||||
a.errorResponse(w, r, 422, "nice try")
|
||||
return
|
||||
}
|
||||
|
||||
ewr := &createWatcherRequest{
|
||||
Criteria: watcherCriteria{
|
||||
Upvotes: 0,
|
||||
Keyword: "",
|
||||
Flair: "",
|
||||
Domain: "",
|
||||
},
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(ewr); err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
watcher.Label = ewr.Label
|
||||
watcher.Upvotes = ewr.Criteria.Upvotes
|
||||
watcher.Keyword = ewr.Criteria.Keyword
|
||||
watcher.Flair = ewr.Criteria.Flair
|
||||
watcher.Domain = ewr.Criteria.Domain
|
||||
|
||||
if err := a.watcherRepo.Update(ctx, &watcher); err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -149,16 +203,8 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
apns := vars["apns"]
|
||||
|
||||
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
watcher, err := a.watcherRepo.GetByID(ctx, id)
|
||||
if err != nil || watcher.DeviceID != dev.ID {
|
||||
if err != nil || watcher.Device.APNSToken != vars["apns"] {
|
||||
a.errorResponse(w, r, 422, "nice try")
|
||||
return
|
||||
}
|
||||
|
@ -168,12 +214,15 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
type watcherItem struct {
|
||||
ID int64
|
||||
Upvotes int64
|
||||
Keyword string
|
||||
Flair string
|
||||
Domain string
|
||||
Hits int64
|
||||
ID int64
|
||||
CreatedAt float64
|
||||
Type string
|
||||
Label string
|
||||
Upvotes int64
|
||||
Keyword string
|
||||
Flair string
|
||||
Domain string
|
||||
Hits int64
|
||||
}
|
||||
|
||||
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -192,12 +241,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
|||
wis := make([]watcherItem, len(watchers))
|
||||
for i, watcher := range watchers {
|
||||
wi := watcherItem{
|
||||
ID: watcher.ID,
|
||||
Upvotes: watcher.Upvotes,
|
||||
Keyword: watcher.Keyword,
|
||||
Flair: watcher.Flair,
|
||||
Domain: watcher.Domain,
|
||||
Hits: watcher.Hits,
|
||||
ID: watcher.ID,
|
||||
CreatedAt: watcher.CreatedAt,
|
||||
Type: watcher.Type.String(),
|
||||
Label: watcher.Label,
|
||||
Upvotes: watcher.Upvotes,
|
||||
Keyword: watcher.Keyword,
|
||||
Flair: watcher.Flair,
|
||||
Domain: watcher.Domain,
|
||||
Hits: watcher.Hits,
|
||||
}
|
||||
|
||||
wis[i] = wi
|
||||
|
|
|
@ -77,6 +77,11 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
|
|||
return err
|
||||
}
|
||||
|
||||
trendingQueue, err := queue.OpenQueue("trending")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userQueue, err := queue.OpenQueue("users")
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -84,7 +89,7 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
|
|||
|
||||
s := gocron.NewScheduler(time.UTC)
|
||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
|
||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) })
|
||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) })
|
||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
|
||||
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
|
||||
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
|
||||
|
@ -273,7 +278,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
|
|||
|
||||
}
|
||||
|
||||
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
||||
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) {
|
||||
now := time.Now()
|
||||
ready := now.Unix() - subredditEnqueueTimeout
|
||||
|
||||
|
@ -326,15 +331,17 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
|
|||
batchIds[i] = strconv.FormatInt(id, 10)
|
||||
}
|
||||
|
||||
if err = queue.Publish(batchIds...); err != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"err": err,
|
||||
}).Error("failed to enqueue subreddit")
|
||||
for _, queue := range queues {
|
||||
if err = queue.Publish(batchIds...); err != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"queue": queue,
|
||||
"err": err,
|
||||
}).Error("failed to enqueue subreddit")
|
||||
}
|
||||
}
|
||||
|
||||
_ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1)
|
||||
_ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1)
|
||||
|
||||
}
|
||||
|
||||
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
|
||||
|
|
|
@ -15,6 +15,7 @@ var (
|
|||
queues = map[string]worker.NewWorkerFn{
|
||||
"notifications": worker.NewNotificationsWorker,
|
||||
"subreddits": worker.NewSubredditsWorker,
|
||||
"trending": worker.NewTrendingWorker,
|
||||
"users": worker.NewUsersWorker,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -7,29 +7,50 @@ type WatcherType int64
|
|||
const (
|
||||
SubredditWatcher WatcherType = iota
|
||||
UserWatcher
|
||||
TrendingWatcher
|
||||
)
|
||||
|
||||
func (wt WatcherType) String() string {
|
||||
switch wt {
|
||||
case SubredditWatcher:
|
||||
return "subreddit"
|
||||
case UserWatcher:
|
||||
return "user"
|
||||
case TrendingWatcher:
|
||||
return "trending"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type Watcher struct {
|
||||
ID int64
|
||||
CreatedAt float64
|
||||
LastNotifiedAt float64
|
||||
Label string
|
||||
|
||||
DeviceID int64
|
||||
AccountID int64
|
||||
Type WatcherType
|
||||
WatcheeID int64
|
||||
|
||||
Author string
|
||||
Upvotes int64
|
||||
Keyword string
|
||||
Flair string
|
||||
Domain string
|
||||
Hits int64
|
||||
|
||||
// Related models
|
||||
Device Device
|
||||
Account Account
|
||||
}
|
||||
|
||||
type WatcherRepository interface {
|
||||
GetByID(ctx context.Context, id int64) (Watcher, error)
|
||||
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
|
||||
GetByUserID(ctx context.Context, id int64) ([]Watcher, error)
|
||||
GetByTrendingSubredditID(ctx context.Context, id int64) ([]Watcher, error)
|
||||
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
|
||||
|
||||
Create(ctx context.Context, watcher *Watcher) error
|
||||
|
|
|
@ -244,6 +244,10 @@ func (rac *AuthenticatedClient) SubredditHot(subreddit string, opts ...RequestOp
|
|||
return rac.subredditPosts(subreddit, "hot", opts...)
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) SubredditTop(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
return rac.subredditPosts(subreddit, "top", opts...)
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
return rac.subredditPosts(subreddit, "new", opts...)
|
||||
}
|
||||
|
|
|
@ -79,8 +79,7 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
|
|||
query := `
|
||||
INSERT INTO subreddits (subreddit_id, name)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT(subreddit_id) DO
|
||||
UPDATE SET last_checked_at = $3
|
||||
ON CONFLICT(subreddit_id) DO NOTHING
|
||||
RETURNING id`
|
||||
|
||||
return p.pool.QueryRow(
|
||||
|
@ -88,6 +87,5 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
|
|||
query,
|
||||
sr.SubredditID,
|
||||
sr.NormalizedName(),
|
||||
sr.LastCheckedAt,
|
||||
).Scan(&sr.ID)
|
||||
}
|
||||
|
|
|
@ -32,15 +32,23 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
|
|||
&watcher.ID,
|
||||
&watcher.CreatedAt,
|
||||
&watcher.LastNotifiedAt,
|
||||
&watcher.Label,
|
||||
&watcher.DeviceID,
|
||||
&watcher.AccountID,
|
||||
&watcher.Type,
|
||||
&watcher.WatcheeID,
|
||||
&watcher.Author,
|
||||
&watcher.Upvotes,
|
||||
&watcher.Keyword,
|
||||
&watcher.Flair,
|
||||
&watcher.Domain,
|
||||
&watcher.Hits,
|
||||
&watcher.Device.ID,
|
||||
&watcher.Device.APNSToken,
|
||||
&watcher.Device.Sandbox,
|
||||
&watcher.Account.ID,
|
||||
&watcher.Account.AccessToken,
|
||||
&watcher.Account.RefreshToken,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -51,9 +59,31 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
|
|||
|
||||
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
|
||||
query := `
|
||||
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
|
||||
SELECT
|
||||
watchers.id,
|
||||
watchers.created_at,
|
||||
watchers.last_notified_at,
|
||||
watchers.label,
|
||||
watchers.device_id,
|
||||
watchers.account_id,
|
||||
watchers.type,
|
||||
watchers.watchee_id,
|
||||
watchers.author,
|
||||
watchers.upvotes,
|
||||
watchers.keyword,
|
||||
watchers.flair,
|
||||
watchers.domain,
|
||||
watchers.hits,
|
||||
devices.id,
|
||||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token
|
||||
FROM watchers
|
||||
WHERE id = $1`
|
||||
INNER JOIN devices ON watchers.device_id = devices.id
|
||||
INNER JOIN accounts ON watchers.account_id = accounts.id
|
||||
WHERE watchers.id = $1`
|
||||
|
||||
watchers, err := p.fetch(ctx, query, id)
|
||||
|
||||
|
@ -68,13 +98,39 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
|
|||
|
||||
func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) {
|
||||
query := `
|
||||
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
|
||||
SELECT
|
||||
watchers.id,
|
||||
watchers.created_at,
|
||||
watchers.last_notified_at,
|
||||
watchers.label,
|
||||
watchers.device_id,
|
||||
watchers.account_id,
|
||||
watchers.type,
|
||||
watchers.watchee_id,
|
||||
watchers.author,
|
||||
watchers.upvotes,
|
||||
watchers.keyword,
|
||||
watchers.flair,
|
||||
watchers.domain,
|
||||
watchers.hits,
|
||||
devices.id,
|
||||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token
|
||||
FROM watchers
|
||||
WHERE type = $1 AND watchee_id = $2`
|
||||
INNER JOIN devices ON watchers.device_id = devices.id
|
||||
INNER JOIN accounts ON watchers.account_id = accounts.id
|
||||
WHERE watchers.type = $1 AND watchers.watchee_id = $2`
|
||||
|
||||
return p.fetch(ctx, query, typ, id)
|
||||
}
|
||||
|
||||
func (p *postgresWatcherRepository) GetByTrendingSubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
|
||||
return p.GetByTypeAndWatcheeID(ctx, domain.TrendingWatcher, id)
|
||||
}
|
||||
|
||||
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
|
||||
return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id)
|
||||
}
|
||||
|
@ -88,16 +144,24 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
|
|||
SELECT
|
||||
watchers.id,
|
||||
watchers.created_at,
|
||||
watchers.last_notified_at
|
||||
watchers.last_notified_at,
|
||||
watchers.label,
|
||||
watchers.device_id,
|
||||
watchers.account_id,
|
||||
watchers.type,
|
||||
watchers.watchee_id,
|
||||
watchers.author,
|
||||
watchers.upvotes,
|
||||
watchers.keyword,
|
||||
watchers.flair,
|
||||
watchers.domain,
|
||||
watchers.hits
|
||||
watchers.hits,
|
||||
devices.id,
|
||||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token
|
||||
FROM watchers
|
||||
INNER JOIN accounts ON watchers.account_id = accounts.id
|
||||
INNER JOIN devices ON watchers.device_id = devices.id
|
||||
|
@ -113,18 +177,20 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
|
|||
|
||||
query := `
|
||||
INSERT INTO watchers
|
||||
(created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain)
|
||||
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
(created_at, last_notified_at, label, device_id, account_id, type, watchee_id, author, upvotes, keyword, flair, domain)
|
||||
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id`
|
||||
|
||||
return p.pool.QueryRow(
|
||||
ctx,
|
||||
query,
|
||||
now,
|
||||
watcher.Label,
|
||||
watcher.DeviceID,
|
||||
watcher.AccountID,
|
||||
watcher.Type,
|
||||
watcher.WatcheeID,
|
||||
watcher.Author,
|
||||
watcher.Upvotes,
|
||||
watcher.Keyword,
|
||||
watcher.Flair,
|
||||
|
@ -135,20 +201,24 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
|
|||
func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error {
|
||||
query := `
|
||||
UPDATE watchers
|
||||
SET upvotes = $2,
|
||||
keyword = $3,
|
||||
flair = $4,
|
||||
domain = $5,
|
||||
SET author = $2,
|
||||
upvotes = $3,
|
||||
keyword = $4,
|
||||
flair = $5,
|
||||
domain = $6,
|
||||
label = $7
|
||||
WHERE id = $1`
|
||||
|
||||
res, err := p.pool.Exec(
|
||||
ctx,
|
||||
query,
|
||||
watcher.ID,
|
||||
watcher.Author,
|
||||
watcher.Upvotes,
|
||||
watcher.Keyword,
|
||||
watcher.Flair,
|
||||
watcher.Domain,
|
||||
watcher.Label,
|
||||
)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
|
|
|
@ -330,6 +330,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
"status": res.StatusCode,
|
||||
"reason": res.Reason,
|
||||
}).Error("failed to send notification")
|
||||
|
||||
// Delete device as notifications might have been disabled here
|
||||
_ = nc.deviceRepo.Delete(ctx, device.APNSToken)
|
||||
} else {
|
||||
_ = nc.statsd.Incr("apns.notification.sent", []string{}, 1)
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
|
|
|
@ -40,6 +40,8 @@ type subredditsWorker struct {
|
|||
watcherRepo domain.WatcherRepository
|
||||
}
|
||||
|
||||
const subredditNotificationTitleFormat = "📣 %s"
|
||||
|
||||
func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
||||
reddit := reddit.NewClient(
|
||||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
|
@ -170,7 +172,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
if len(watchers) == 0 {
|
||||
sc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
}).Info("no watchers for subreddit, skipping")
|
||||
}).Debug("no watchers for subreddit, finishing job")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -298,11 +300,12 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
"count": len(posts),
|
||||
}).Debug("checking posts for hits")
|
||||
for _, post := range posts {
|
||||
lowcaseAuthor := strings.ToLower(post.Author)
|
||||
lowcaseTitle := strings.ToLower(post.Title)
|
||||
lowcaseFlair := strings.ToLower(post.Flair)
|
||||
lowcaseDomain := strings.ToLower(post.URL)
|
||||
|
||||
ids := []int64{}
|
||||
notifs := []domain.Watcher{}
|
||||
|
||||
for _, watcher := range watchers {
|
||||
// Make sure we only alert on posts created after the search
|
||||
|
@ -312,6 +315,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
|
||||
matched := true
|
||||
|
||||
if watcher.Author != "" && lowcaseAuthor != watcher.Author {
|
||||
matched = false
|
||||
}
|
||||
|
||||
if watcher.Upvotes > 0 && post.Score < watcher.Upvotes {
|
||||
matched = false
|
||||
}
|
||||
|
@ -363,10 +370,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
}).Debug("got a hit")
|
||||
|
||||
sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour)
|
||||
ids = append(ids, watcher.DeviceID)
|
||||
notifs = append(notifs, watcher)
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
if len(notifs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -374,19 +381,22 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
"post#id": post.ID,
|
||||
"count": len(ids),
|
||||
"count": len(notifs),
|
||||
}).Debug("got hits for post")
|
||||
|
||||
notification := &apns2.Notification{}
|
||||
notification.Topic = "com.christianselig.Apollo"
|
||||
notification.Payload = payloadFromPost(post)
|
||||
payload := payloadFromPost(post)
|
||||
|
||||
for _, id := range ids {
|
||||
device, _ := sc.deviceRepo.GetByID(ctx, id)
|
||||
notification.DeviceToken = device.APNSToken
|
||||
for _, watcher := range notifs {
|
||||
title := fmt.Sprintf(subredditNotificationTitleFormat, watcher.Label)
|
||||
payload.AlertTitle(title)
|
||||
|
||||
notification := &apns2.Notification{}
|
||||
notification.Topic = "com.christianselig.Apollo"
|
||||
notification.DeviceToken = watcher.Device.APNSToken
|
||||
notification.Payload = payload
|
||||
|
||||
client := sc.apnsProduction
|
||||
if device.Sandbox {
|
||||
if watcher.Device.Sandbox {
|
||||
client = sc.apnsSandbox
|
||||
}
|
||||
|
||||
|
@ -395,7 +405,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
_ = sc.statsd.Incr("apns.notification.errors", []string{}, 1)
|
||||
sc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"device#id": device.ID,
|
||||
"device#id": watcher.Device.ID,
|
||||
"err": err,
|
||||
"status": res.StatusCode,
|
||||
"reason": res.Reason,
|
||||
|
@ -404,8 +414,8 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
_ = sc.statsd.Incr("apns.notification.sent", []string{}, 1)
|
||||
sc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"device#id": device.ID,
|
||||
"device#token": device.APNSToken,
|
||||
"device#id": watcher.Device.ID,
|
||||
"device#token": watcher.Device.APNSToken,
|
||||
}).Info("sent notification")
|
||||
}
|
||||
}
|
||||
|
@ -418,11 +428,11 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
}
|
||||
|
||||
func payloadFromPost(post *reddit.Thing) *payload.Payload {
|
||||
title := fmt.Sprintf("📣 Subreddit Watch (r/%s)", post.Subreddit)
|
||||
subtitle := fmt.Sprintf("r/%s", post.Subreddit)
|
||||
|
||||
payload := payload.
|
||||
NewPayload().
|
||||
AlertTitle(title).
|
||||
AlertSubtitle(subtitle).
|
||||
AlertBody(post.Title).
|
||||
AlertSummaryArg(post.Subreddit).
|
||||
Category("post-watch").
|
||||
|
|
336
internal/worker/trending.go
Normal file
336
internal/worker/trending.go
Normal file
|
@ -0,0 +1,336 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/DataDog/datadog-go/statsd"
|
||||
"github.com/adjust/rmq/v4"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/sideshow/apns2"
|
||||
"github.com/sideshow/apns2/payload"
|
||||
"github.com/sideshow/apns2/token"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||
"github.com/christianselig/apollo-backend/internal/repository"
|
||||
)
|
||||
|
||||
type trendingWorker struct {
|
||||
logger *logrus.Logger
|
||||
statsd *statsd.Client
|
||||
redis *redis.Client
|
||||
queue rmq.Connection
|
||||
reddit *reddit.Client
|
||||
apns *token.Token
|
||||
|
||||
consumers int
|
||||
|
||||
accountRepo domain.AccountRepository
|
||||
deviceRepo domain.DeviceRepository
|
||||
subredditRepo domain.SubredditRepository
|
||||
watcherRepo domain.WatcherRepository
|
||||
}
|
||||
|
||||
const trendingNotificationTitleFormat = "🔥 Trending in r/%s"
|
||||
|
||||
func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
||||
reddit := reddit.NewClient(
|
||||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
consumers,
|
||||
)
|
||||
|
||||
var apns *token.Token
|
||||
{
|
||||
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
apns = &token.Token{
|
||||
AuthKey: authKey,
|
||||
KeyID: os.Getenv("APPLE_KEY_ID"),
|
||||
TeamID: os.Getenv("APPLE_TEAM_ID"),
|
||||
}
|
||||
}
|
||||
|
||||
return &trendingWorker{
|
||||
logger,
|
||||
statsd,
|
||||
redis,
|
||||
queue,
|
||||
reddit,
|
||||
apns,
|
||||
consumers,
|
||||
|
||||
repository.NewPostgresAccount(db),
|
||||
repository.NewPostgresDevice(db),
|
||||
repository.NewPostgresSubreddit(db),
|
||||
repository.NewPostgresWatcher(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *trendingWorker) Start() error {
|
||||
queue, err := tw.queue.OpenQueue("trending")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tw.logger.WithFields(logrus.Fields{
|
||||
"numConsumers": tw.consumers,
|
||||
}).Info("starting up trending worker")
|
||||
|
||||
prefetchLimit := int64(tw.consumers * 2)
|
||||
|
||||
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
host, _ := os.Hostname()
|
||||
|
||||
for i := 0; i < tw.consumers; i++ {
|
||||
name := fmt.Sprintf("consumer %s-%d", host, i)
|
||||
|
||||
consumer := NewTrendingConsumer(tw, i)
|
||||
if _, err := queue.AddConsumer(name, consumer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tw *trendingWorker) Stop() {
|
||||
<-tw.queue.StopAllConsuming() // wait for all Consume() calls to finish
|
||||
}
|
||||
|
||||
type trendingConsumer struct {
|
||||
*trendingWorker
|
||||
tag int
|
||||
|
||||
apnsSandbox *apns2.Client
|
||||
apnsProduction *apns2.Client
|
||||
}
|
||||
|
||||
func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer {
|
||||
return &trendingConsumer{
|
||||
tw,
|
||||
tag,
|
||||
apns2.NewTokenClient(tw.apns),
|
||||
apns2.NewTokenClient(tw.apns).Production(),
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
||||
ctx := context.Background()
|
||||
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": delivery.Payload(),
|
||||
}).Debug("starting job")
|
||||
|
||||
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
|
||||
if err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": delivery.Payload(),
|
||||
"err": err,
|
||||
}).Error("failed to parse subreddit ID")
|
||||
|
||||
_ = delivery.Reject()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() { _ = delivery.Ack() }()
|
||||
|
||||
subreddit, err := tc.subredditRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"err": err,
|
||||
}).Error("failed to fetch subreddit from database")
|
||||
return
|
||||
}
|
||||
|
||||
watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID)
|
||||
if err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"err": err,
|
||||
}).Error("failed to fetch watchers from database")
|
||||
return
|
||||
}
|
||||
|
||||
if len(watchers) == 0 {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
}).Debug("no watchers for trending, finishing job")
|
||||
return
|
||||
}
|
||||
|
||||
// Grab last month's top posts so we calculate a trending average
|
||||
i := rand.Intn(len(watchers))
|
||||
watcher := watchers[i]
|
||||
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
|
||||
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
|
||||
if err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
"err": err,
|
||||
}).Error("failed to fetch month's top posts")
|
||||
return
|
||||
}
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
"count": tps.Count,
|
||||
}).Debug("loaded month's hot posts")
|
||||
|
||||
if tps.Count == 0 {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
}).Debug("no top posts for subreddit, returning")
|
||||
return
|
||||
}
|
||||
|
||||
if tps.Count < 20 {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
}).Debug("not enough posts, returning")
|
||||
return
|
||||
}
|
||||
|
||||
middlePost := tps.Count / 2
|
||||
medianScore := tps.Children[middlePost].Score
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"score": medianScore,
|
||||
}).Debug("calculated median score")
|
||||
|
||||
// Grab hot posts and filter out anything that's > 2 days old
|
||||
i = rand.Intn(len(watchers))
|
||||
watcher = watchers[i]
|
||||
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
hps, err := rac.SubredditHot(subreddit.Name)
|
||||
if err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
"err": err,
|
||||
}).Error("failed to fetch hot posts")
|
||||
return
|
||||
}
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
"count": hps.Count,
|
||||
}).Debug("loaded hot posts")
|
||||
|
||||
// Trending only counts for posts less than 2 days old
|
||||
threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix())
|
||||
|
||||
for _, post := range hps.Children {
|
||||
if post.Score < medianScore {
|
||||
continue
|
||||
}
|
||||
|
||||
if post.CreatedAt < threshold {
|
||||
break
|
||||
}
|
||||
|
||||
notification := &apns2.Notification{}
|
||||
notification.Topic = "com.christianselig.Apollo"
|
||||
notification.Payload = payloadFromTrendingPost(post)
|
||||
|
||||
for _, watcher := range watchers {
|
||||
if watcher.CreatedAt > post.CreatedAt {
|
||||
continue
|
||||
}
|
||||
|
||||
lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID)
|
||||
notified, _ := tc.redis.Get(ctx, lockKey).Bool()
|
||||
|
||||
if notified {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
"watcher#id": watcher.ID,
|
||||
"post#id": post.ID,
|
||||
}).Debug("already notified, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour)
|
||||
|
||||
if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"watcher#id": watcher.ID,
|
||||
"err": err,
|
||||
}).Error("could not increment hits")
|
||||
return
|
||||
}
|
||||
|
||||
notification.DeviceToken = watcher.Device.APNSToken
|
||||
|
||||
client := tc.apnsProduction
|
||||
if watcher.Device.Sandbox {
|
||||
client = tc.apnsSandbox
|
||||
}
|
||||
|
||||
res, err := client.Push(notification)
|
||||
if err != nil {
|
||||
_ = tc.statsd.Incr("apns.notification.errors", []string{}, 1)
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"post#id": post.ID,
|
||||
"device#id": watcher.Device.ID,
|
||||
"err": err,
|
||||
"status": res.StatusCode,
|
||||
"reason": res.Reason,
|
||||
}).Error("failed to send notification")
|
||||
} else {
|
||||
_ = tc.statsd.Incr("apns.notification.sent", []string{}, 1)
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"post#id": post.ID,
|
||||
"device#id": watcher.Device.ID,
|
||||
"device#token": watcher.Device.APNSToken,
|
||||
}).Info("sent notification")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
"subreddit#id": subreddit.ID,
|
||||
"subreddit#name": subreddit.Name,
|
||||
}).Debug("finishing job")
|
||||
}
|
||||
|
||||
func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
|
||||
title := fmt.Sprintf(trendingNotificationTitleFormat, post.Subreddit)
|
||||
|
||||
payload := payload.
|
||||
NewPayload().
|
||||
AlertTitle(title).
|
||||
AlertSubtitle(post.Title).
|
||||
AlertBody(post.Title).
|
||||
AlertSummaryArg(post.Subreddit).
|
||||
Category("post-watch").
|
||||
Custom("post_title", post.Title).
|
||||
Custom("post_id", post.ID).
|
||||
Custom("subreddit", post.Subreddit).
|
||||
Custom("author", post.Author).
|
||||
Custom("post_age", post.CreatedAt).
|
||||
MutableContent().
|
||||
Sound("traloop.wav")
|
||||
|
||||
return payload
|
||||
}
|
|
@ -38,6 +38,8 @@ type usersWorker struct {
|
|||
watcherRepo domain.WatcherRepository
|
||||
}
|
||||
|
||||
const userNotificationTitleFormat = "👨\u200d🚀 %s"
|
||||
|
||||
func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
||||
reddit := reddit.NewClient(
|
||||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
|
@ -224,9 +226,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
|||
continue
|
||||
}
|
||||
|
||||
notification := &apns2.Notification{}
|
||||
notification.Topic = "com.christianselig.Apollo"
|
||||
notification.Payload = payloadFromUserPost(post)
|
||||
payload := payloadFromUserPost(post)
|
||||
|
||||
for _, watcher := range watchers {
|
||||
// Make sure we only alert on activities created after the search
|
||||
|
@ -248,6 +248,13 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
|||
}
|
||||
|
||||
device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
|
||||
|
||||
title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label)
|
||||
payload.AlertTitle(title)
|
||||
|
||||
notification := &apns2.Notification{}
|
||||
notification.Topic = "com.christianselig.Apollo"
|
||||
notification.Payload = payload
|
||||
notification.DeviceToken = device.APNSToken
|
||||
|
||||
client := uc.apnsProduction
|
||||
|
@ -283,12 +290,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
|||
}
|
||||
|
||||
func payloadFromUserPost(post *reddit.Thing) *payload.Payload {
|
||||
title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author)
|
||||
|
||||
payload := payload.
|
||||
NewPayload().
|
||||
AlertTitle(title).
|
||||
AlertBody(post.Title).
|
||||
AlertSubtitle(post.Author).
|
||||
AlertSummaryArg(post.Author).
|
||||
Category("user-watch").
|
||||
Custom("post_title", post.Title).
|
||||
|
|
Loading…
Reference in a new issue