Live activities

This commit is contained in:
Andre Medeiros 2022-10-19 09:37:41 -04:00
parent 01a4ae9559
commit 696f932baa
12 changed files with 2383 additions and 11 deletions

View file

@ -34,6 +34,7 @@ type api struct {
subredditRepo domain.SubredditRepository
watcherRepo domain.WatcherRepository
userRepo domain.UserRepository
liveActivityRepo domain.LiveActivityRepository
}
func NewAPI(ctx context.Context, logger *zap.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool) *api {
@ -64,6 +65,7 @@ func NewAPI(ctx context.Context, logger *zap.Logger, statsd *statsd.Client, redi
subredditRepo := repository.NewPostgresSubreddit(pool)
watcherRepo := repository.NewPostgresWatcher(pool)
userRepo := repository.NewPostgresUser(pool)
liveActivityRepo := repository.NewPostgresLiveActivity(pool)
client := &http.Client{}
@ -79,6 +81,7 @@ func NewAPI(ctx context.Context, logger *zap.Logger, statsd *statsd.Client, redi
subredditRepo: subredditRepo,
watcherRepo: watcherRepo,
userRepo: userRepo,
liveActivityRepo: liveActivityRepo,
}
}
@ -115,6 +118,8 @@ func (a *api) Routes() *mux.Router {
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.editWatcherHandler).Methods("PATCH")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
r.HandleFunc("/v1/live_activities", a.createLiveActivityHandler).Methods("POST")
r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST")
r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST")

View file

@ -0,0 +1,36 @@
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/christianselig/apollo-backend/internal/domain"
)
func (a *api) createLiveActivityHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
la := &domain.LiveActivity{}
if err := json.NewDecoder(r.Body).Decode(la); err != nil {
a.errorResponse(w, r, 500, err)
return
}
ac := a.reddit.NewAuthenticatedClient(la.RedditAccountID, la.RefreshToken, la.AccessToken)
rtr, err := ac.RefreshTokens(ctx)
if err != nil {
a.errorResponse(w, r, 500, err)
return
}
la.RefreshToken = rtr.RefreshToken
la.TokenExpiresAt = time.Now().Add(1 * time.Hour).UTC()
if err := a.liveActivityRepo.Create(ctx, la); err != nil {
a.errorResponse(w, r, 500, err)
return
}
w.WriteHeader(http.StatusOK)
}

View file

@ -91,10 +91,16 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
return err
}
liveActivitiesQueue, err := queue.OpenQueue("live-activities")
if err != nil {
return err
}
s := gocron.NewScheduler(time.UTC)
_, _ = s.Every(5).Seconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
_, _ = s.Every(5).Second().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) })
_, _ = s.Every(5).Second().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
_, _ = s.Every(5).Second().Do(func() { enqueueLiveActivities(ctx, logger, db, redis, luaSha, liveActivitiesQueue) })
_, _ = s.Every(5).Second().Do(func() { cleanQueues(logger, queue) })
_, _ = s.Every(5).Second().Do(func() { enqueueStuckAccounts(ctx, logger, statsd, db, stuckNotificationsQueue) })
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db) })
@ -134,6 +140,57 @@ func evalScript(ctx context.Context, redis *redis.Client) (string, error) {
return redis.ScriptLoad(ctx, lua).Result()
}
func enqueueLiveActivities(ctx context.Context, logger *zap.Logger, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
now := time.Now().UTC()
next := now.Add(domain.LiveActivityCheckInterval)
stmt := `UPDATE live_activities
SET next_check_at = $2
WHERE id IN (
SELECT id
FROM live_activities
WHERE next_check_at < $1
ORDER BY next_check_at
FOR UPDATE SKIP LOCKED
LIMIT 100
)
RETURNING live_activities.apns_token`
ats := []string{}
rows, err := pool.Query(ctx, stmt, now, next)
if err != nil {
logger.Error("failed to fetch batch of live activities", zap.Error(err))
return
}
for rows.Next() {
var at string
_ = rows.Scan(&at)
ats = append(ats, at)
}
rows.Close()
if len(ats) == 0 {
return
}
batch, err := redisConn.EvalSha(ctx, luaSha, []string{"locks:live-activities"}, ats).StringSlice()
if err != nil {
logger.Error("failed to lock live activities", zap.Error(err))
return
}
if len(batch) == 0 {
return
}
logger.Debug("enqueueing live activity batch", zap.Int("count", len(batch)), zap.Time("start", now))
if err = queue.Publish(batch...); err != nil {
logger.Error("failed to enqueue live activity batch", zap.Error(err))
}
}
func pruneAccounts(ctx context.Context, logger *zap.Logger, pool *pgxpool.Pool) {
expiry := time.Now().Add(-domain.StaleTokenThreshold)
ar := repository.NewPostgresAccount(pool)

View file

@ -13,6 +13,7 @@ import (
var (
queues = map[string]worker.NewWorkerFn{
"live-activities": worker.NewLiveActivitiesWorker,
"notifications": worker.NewNotificationsWorker,
"stuck-notifications": worker.NewStuckNotificationsWorker,
"subreddits": worker.NewSubredditsWorker,

View file

@ -0,0 +1,38 @@
package domain
import (
"context"
"time"
)
const (
LiveActivityDuration = 75 * time.Minute
LiveActivityCheckInterval = 30 * time.Second
)
type LiveActivity struct {
ID int64
APNSToken string `json:"apns_token"`
Sandbox bool `json:"sandbox"`
RedditAccountID string `json:"reddit_account_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenExpiresAt time.Time
ThreadID string `json:"thread_id"`
Subreddit string `json:"subreddit"`
NextCheckAt time.Time
ExpiresAt time.Time
}
type LiveActivityRepository interface {
Get(ctx context.Context, apnsToken string) (LiveActivity, error)
List(ctx context.Context) ([]LiveActivity, error)
Create(ctx context.Context, la *LiveActivity) error
Update(ctx context.Context, la *LiveActivity) error
RemoveStale(ctx context.Context) error
Delete(ctx context.Context, apns_token string) error
}

View file

@ -594,3 +594,25 @@ func (rac *AuthenticatedClient) Me(ctx context.Context, opts ...RequestOption) (
}
return mr.(*MeResponse), nil
}
func (rac *AuthenticatedClient) TopLevelComments(ctx context.Context, subreddit string, threadID string, opts ...RequestOption) (*ThreadResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/comments/%s/.json", subreddit, threadID)
opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithTags([]string{"url:/comments"}),
WithMethod("GET"),
WithToken(rac.accessToken),
WithURL(url),
WithQuery("sort", "new"),
WithQuery("limit", "100"),
WithQuery("depth", "1"),
}...)
req := NewRequest(opts...)
tr, err := rac.request(ctx, req, NewThreadResponse, nil)
if err != nil {
return nil, err
}
return tr.(*ThreadResponse), nil
}

1734
internal/reddit/testdata/thread.json vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -63,6 +63,27 @@ func NewMeResponse(val *fastjson.Value) interface{} {
return mr
}
type ThreadResponse struct {
Post *Thing
Children []*Thing
}
func NewThreadResponse(val *fastjson.Value) interface{} {
t := &ThreadResponse{}
listings := val.GetArray()
// Thread details comes in the first element of the array as a one item listing
t.Post = NewThing(listings[0].Get("data").GetArray("children")[0])
// Comments come in the second element of the array also as a listing
comments := listings[1].Get("data").GetArray("children")
t.Children = make([]*Thing, len(comments)-1)
for i, comment := range comments[:len(comments)-1] {
t.Children[i] = NewThing(comment)
}
return t
}
type Thing struct {
Kind string `json:"kind"`
ID string `json:"id"`
@ -84,6 +105,7 @@ type Thing struct {
Flair string `json:"flair"`
Thumbnail string `json:"thumbnail"`
Over18 bool `json:"over_18"`
NumComments int `json:"num_comments"`
}
func (t *Thing) FullName() string {
@ -122,6 +144,7 @@ func NewThing(val *fastjson.Value) *Thing {
t.Flair = string(data.GetStringBytes("link_flair_text"))
t.Thumbnail = string(data.GetStringBytes("thumbnail"))
t.Over18 = data.GetBool("over_18")
t.NumComments = data.GetInt("num_comments")
return t
}

View file

@ -178,3 +178,24 @@ func TestUserPostsParsing(t *testing.T) {
assert.Equal(t, "public", post.SubredditType)
}
func TestThreadResponseParsing(t *testing.T) {
t.Parallel()
bb, err := ioutil.ReadFile("testdata/thread.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := reddit.NewThreadResponse(val)
tr := ret.(*reddit.ThreadResponse)
assert.NotNil(t, tr)
assert.Equal(t, "When you buy $400 machine to run games that you can run using $15 RPi", tr.Post.Title)
assert.Equal(t, 20, len(tr.Children))
assert.Equal(t, "The Deck is a lot more portable than the Pi though.", tr.Children[0].Body)
assert.Equal(t, "PhonicUK", tr.Children[1].Author)
}

View file

@ -0,0 +1,123 @@
package repository
import (
"context"
"time"
"github.com/christianselig/apollo-backend/internal/domain"
)
type postgresLiveActivityRepository struct {
conn Connection
}
func NewPostgresLiveActivity(conn Connection) domain.LiveActivityRepository {
return &postgresLiveActivityRepository{conn: conn}
}
func (p *postgresLiveActivityRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.LiveActivity, error) {
rows, err := p.conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var las []domain.LiveActivity
for rows.Next() {
var la domain.LiveActivity
if err := rows.Scan(
&la.ID,
&la.APNSToken,
&la.Sandbox,
&la.RedditAccountID,
&la.AccessToken,
&la.RefreshToken,
&la.TokenExpiresAt,
&la.ThreadID,
&la.Subreddit,
&la.NextCheckAt,
&la.ExpiresAt,
); err != nil {
return nil, err
}
las = append(las, la)
}
return las, nil
}
func (p *postgresLiveActivityRepository) Get(ctx context.Context, apnsToken string) (domain.LiveActivity, error) {
query := `
SELECT id, apns_token, sandbox, reddit_account_id, access_token, refresh_token, token_expires_at, thread_id, subreddit, next_check_at, expires_at
FROM live_activities
WHERE apns_token = $1`
las, err := p.fetch(ctx, query, apnsToken)
if err != nil {
return domain.LiveActivity{}, err
}
if len(las) == 0 {
return domain.LiveActivity{}, domain.ErrNotFound
}
return las[0], nil
}
func (p *postgresLiveActivityRepository) List(ctx context.Context) ([]domain.LiveActivity, error) {
query := `
SELECT id, apns_token, sandbox, reddit_account_id, access_token, refresh_token, token_expires_at, thread_id, subreddit, next_check_at, expires_at
FROM live_activities
WHERE expires_at > NOW()`
return p.fetch(ctx, query)
}
func (p *postgresLiveActivityRepository) Create(ctx context.Context, la *domain.LiveActivity) error {
query := `
INSERT INTO live_activities (apns_token, sandbox, reddit_account_id, access_token, refresh_token, token_expires_at, thread_id, subreddit, next_check_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (apns_token) DO UPDATE SET expires_at = $10
RETURNING id`
return p.conn.QueryRow(ctx, query,
la.APNSToken,
la.Sandbox,
la.RedditAccountID,
la.AccessToken,
la.RefreshToken,
la.TokenExpiresAt,
la.ThreadID,
la.Subreddit,
time.Now().UTC(),
time.Now().Add(domain.LiveActivityDuration).UTC(),
).Scan(&la.ID)
}
func (p *postgresLiveActivityRepository) Update(ctx context.Context, la *domain.LiveActivity) error {
query := `
UPDATE live_activities
SET access_token = $1, refresh_token = $2, token_expires_at = $3, next_check_at = $4
WHERE id = $5`
_, err := p.conn.Exec(ctx, query,
la.AccessToken,
la.RefreshToken,
la.TokenExpiresAt,
la.NextCheckAt,
la.ID,
)
return err
}
func (p *postgresLiveActivityRepository) RemoveStale(ctx context.Context) error {
query := `DELETE FROM live_activities WHERE expires_at < NOW()`
_, err := p.conn.Exec(ctx, query)
return err
}
func (p *postgresLiveActivityRepository) Delete(ctx context.Context, apns_token string) error {
query := `DELETE FROM live_activities WHERE apns_token = $1`
_, err := p.conn.Exec(ctx, query, apns_token)
return err
}

View file

@ -0,0 +1,312 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"os"
"sort"
"time"
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/token"
"go.uber.org/zap"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
)
type DynamicIslandNotification struct {
PostCommentCount int `json:"postTotalComments"`
PostScore int64 `json:"postScore"`
CommentAuthor string `json:"commentAuthor"`
CommentBody string `json:"commentBody"`
CommentAge int64 `json:"commentAge"`
CommentScore int64 `json:"commentScore"`
}
type liveActivitiesWorker struct {
context.Context
logger *zap.Logger
statsd *statsd.Client
db *pgxpool.Pool
redis *redis.Client
queue rmq.Connection
reddit *reddit.Client
apns *token.Token
consumers int
liveActivityRepo domain.LiveActivityRepository
}
func NewLiveActivitiesWorker(ctx context.Context, logger *zap.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
consumers,
)
var apns *token.Token
{
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
if err != nil {
panic(err)
}
apns = &token.Token{
AuthKey: authKey,
KeyID: os.Getenv("APPLE_KEY_ID"),
TeamID: os.Getenv("APPLE_TEAM_ID"),
}
}
return &liveActivitiesWorker{
ctx,
logger,
statsd,
db,
redis,
queue,
reddit,
apns,
consumers,
repository.NewPostgresLiveActivity(db),
}
}
func (law *liveActivitiesWorker) Start() error {
queue, err := law.queue.OpenQueue("live-activities")
if err != nil {
return err
}
law.logger.Info("starting up live activities worker", zap.Int("consumers", law.consumers))
prefetchLimit := int64(law.consumers * 4)
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
return err
}
host, _ := os.Hostname()
for i := 0; i < law.consumers; i++ {
name := fmt.Sprintf("consumer %s-%d", host, i)
consumer := NewLiveActivitiesConsumer(law, i)
if _, err := queue.AddConsumer(name, consumer); err != nil {
return err
}
}
return nil
}
func (law *liveActivitiesWorker) Stop() {
<-law.queue.StopAllConsuming() // wait for all Consume() calls to finish
}
type liveActivitiesConsumer struct {
*liveActivitiesWorker
tag int
apnsSandbox *apns2.Client
apnsProduction *apns2.Client
}
func NewLiveActivitiesConsumer(law *liveActivitiesWorker, tag int) *liveActivitiesConsumer {
return &liveActivitiesConsumer{
law,
tag,
apns2.NewTokenClient(law.apns),
apns2.NewTokenClient(law.apns).Production(),
}
}
func (lac *liveActivitiesConsumer) Consume(delivery rmq.Delivery) {
now := time.Now().UTC()
defer func() {
elapsed := time.Now().Sub(now).Milliseconds()
_ = lac.statsd.Histogram("apollo.consumer.runtime", float64(elapsed), []string{"queue:live_activities"}, 0.1)
}()
at := delivery.Payload()
key := fmt.Sprintf("locks:live-activities:%s", at)
// Measure queue latency
ttl := lac.redis.TTL(lac, key).Val()
age := (domain.NotificationCheckTimeout - ttl)
_ = lac.statsd.Histogram("apollo.dequeue.latency", float64(age.Milliseconds()), []string{"queue:live_activities"}, 0.1)
defer func() {
if err := lac.redis.Del(lac, key).Err(); err != nil {
lac.logger.Error("failed to remove account lock", zap.Error(err), zap.String("key", key))
}
}()
lac.logger.Debug("starting job", zap.String("live_activity#apns_token", at))
defer func() {
if err := delivery.Ack(); err != nil {
lac.logger.Error("failed to acknowledge message", zap.Error(err), zap.String("live_activity#apns_token", at))
}
}()
la, err := lac.liveActivityRepo.Get(lac, at)
if err != nil {
lac.logger.Error("failed to get live activity", zap.Error(err), zap.String("live_activity#apns_token", at))
return
}
rac := lac.reddit.NewAuthenticatedClient(la.RedditAccountID, la.RefreshToken, la.AccessToken)
if la.TokenExpiresAt.Before(now.Add(5 * time.Minute)) {
lac.logger.Debug("refreshing reddit token",
zap.String("live_activity#apns_token", at),
)
tokens, err := rac.RefreshTokens(lac)
if err != nil {
if err != reddit.ErrOauthRevoked {
lac.logger.Error("failed to refresh reddit tokens",
zap.Error(err),
zap.String("live_activity#apns_token", at),
)
return
}
err = lac.liveActivityRepo.Delete(lac, at)
if err != nil {
lac.logger.Error("failed to remove revoked account",
zap.Error(err),
zap.String("live_activity#apns_token", at),
)
}
return
}
// Update account
la.AccessToken = tokens.AccessToken
la.RefreshToken = tokens.RefreshToken
la.TokenExpiresAt = now.Add(tokens.Expiry)
_ = lac.liveActivityRepo.Update(lac, &la)
// Refresh client
rac = lac.reddit.NewAuthenticatedClient(la.RedditAccountID, tokens.RefreshToken, tokens.AccessToken)
}
lac.logger.Debug("fetching latest comments", zap.String("live_activity#apns_token", at))
tr, err := rac.TopLevelComments(lac, la.Subreddit, la.ThreadID)
if err != nil {
lac.logger.Error("failed to fetch latest comments",
zap.Error(err),
zap.String("live_activity#apns_token", at),
)
return
}
if len(tr.Children) == 0 {
lac.logger.Debug("no comments found", zap.String("live_activity#apns_token", at))
return
}
// Filter out comments in the last minute
candidates := make([]*reddit.Thing, 0)
cutoff := now.Add(-domain.LiveActivityCheckInterval)
for _, t := range tr.Children {
if t.CreatedAt.After(cutoff) {
candidates = append(candidates, t)
}
}
if len(candidates) == 0 {
lac.logger.Debug("no new comments found", zap.String("live_activity#apns_token", at))
return
}
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].Score > candidates[j].Score
})
comment := candidates[0]
din := DynamicIslandNotification{
PostCommentCount: tr.Post.NumComments,
PostScore: tr.Post.Score,
CommentAuthor: comment.Author,
CommentBody: comment.Body,
CommentAge: comment.CreatedAt.Unix(),
CommentScore: comment.Score,
}
ev := "update"
if la.ExpiresAt.Before(now) {
ev = "end"
}
pl := map[string]interface{}{
"aps": map[string]interface{}{
"timestamp": time.Now().Unix(),
"event": ev,
"content-state": din,
},
}
bb, _ := json.Marshal(pl)
notification := &apns2.Notification{
DeviceToken: la.APNSToken,
Topic: "com.christianselig.Apollo.push-type.liveactivity",
PushType: "liveactivity",
Payload: bb,
}
client := lac.apnsProduction
if la.Sandbox {
client = lac.apnsSandbox
}
res, err := client.PushWithContext(lac, notification)
if err != nil {
_ = lac.statsd.Incr("apns.live_activities.errors", []string{}, 1)
lac.logger.Error("failed to send notification",
zap.Error(err),
zap.String("live_activity#apns_token", at),
)
_ = lac.liveActivityRepo.Delete(lac, at)
} else if !res.Sent() {
_ = lac.statsd.Incr("apns.live_activities.errors", []string{}, 1)
lac.logger.Error("notification not sent",
zap.String("live_activity#apns_token", at),
zap.Int("response#status", res.StatusCode),
zap.String("response#reason", res.Reason),
)
_ = lac.liveActivityRepo.Delete(lac, at)
} else {
_ = lac.statsd.Incr("apns.live_activities.sent", []string{}, 1)
lac.logger.Info("sent notification",
zap.String("live_activity#apns_token", at),
)
}
if la.ExpiresAt.Before(now) {
lac.logger.Debug("live activity expired, deleting", zap.String("live_activity#apns_token", at))
_ = lac.liveActivityRepo.Delete(lac, at)
}
lac.logger.Debug("finishing job",
zap.String("live_activity#apns_token", at),
)
}

View file

@ -321,7 +321,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
client = nc.apnsSandbox
}
res, err := client.Push(notification)
res, err := client.PushWithContext(nc, notification)
if err != nil {
_ = nc.statsd.Incr("apns.notification.errors", []string{}, 1)
nc.logger.Error("failed to send notification",