diff --git a/internal/cmd/scheduler.go b/internal/cmd/scheduler.go index e68f577..ecd9add 100644 --- a/internal/cmd/scheduler.go +++ b/internal/cmd/scheduler.go @@ -23,9 +23,10 @@ const ( batchSize = 250 checkTimeout = 60 // how long until we force a check - accountEnqueueTimeout = 5 // how frequently we want to check (seconds) - subredditEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds) - userEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds) + accountEnqueueInterval = 5 // how frequently we want to check (seconds) + subredditEnqueueInterval = 2 * 60 // how frequently we want to check (seconds) + userEnqueueInterval = 2 * 60 // how frequently we want to check (seconds) + stuckAccountEnqueueInterval = 1 * 60 // how frequently we want to check (seconds) staleAccountThreshold = 7200 // 2 hours ) @@ -87,11 +88,17 @@ func SchedulerCmd(ctx context.Context) *cobra.Command { return err } + stuckNotificationsQueue, err := queue.OpenQueue("stuck-notifications") + if err != nil { + return err + } + s := gocron.NewScheduler(time.UTC) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) }) _, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) }) + _, _ = s.Every(1).Second().Do(func() { enqueueStuckAccounts(ctx, logger, statsd, db, stuckNotificationsQueue) }) _, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) }) _, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) }) _, _ = s.Every(1).Minute().Do(func() { pruneDevices(ctx, logger, db) }) @@ -185,9 +192,11 @@ func cleanQueues(ctx context.Context, logger *logrus.Logger, jobsConn rmq.Connec return } - logger.WithFields(logrus.Fields{ - "count": count, - }).Debug("returned jobs to queues") + if count > 0 { + logger.WithFields(logrus.Fields{ + "count": count, + }).Info("returned jobs to queues") + } } func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client) { @@ -200,6 +209,8 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie }{ {"SELECT COUNT(*) FROM accounts", "apollo.registrations.accounts"}, {"SELECT COUNT(*) FROM devices", "apollo.registrations.devices"}, + {"SELECT COUNT(*) FROM subreddits", "apollo.registrations.subreddits"}, + {"SELECT COUNT(*) FROM users", "apollo.registrations.users"}, } ) @@ -216,7 +227,7 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { now := time.Now() - ready := now.Unix() - userEnqueueTimeout + ready := now.Unix() - userEnqueueInterval ids := []int64{} @@ -273,14 +284,14 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli }).Error("failed to enqueue user") } - _ = statsd.Histogram("apollo.queue.users.enqueued", float64(len(ids)), []string{}, 1) - _ = statsd.Histogram("apollo.queue.users.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1) - + tags := []string{"queue:users"} + _ = statsd.Histogram("apollo.queue.enqueued", float64(len(ids)), tags, 1) + _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) } func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) { now := time.Now() - ready := now.Unix() - subredditEnqueueTimeout + ready := now.Unix() - subredditEnqueueInterval ids := []int64{} @@ -340,20 +351,85 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats } } - _ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1) - _ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1) + tags := []string{"queue:subreddits"} + _ = statsd.Histogram("apollo.queue.enqueued", float64(len(ids)), tags, 1) + _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) +} + +func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { + now := time.Now() + ready := now.Unix() - stuckAccountEnqueueInterval + + ids := []int64{} + + err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { + stmt := ` + WITH account AS ( + SELECT id + FROM accounts + WHERE + last_unstuck_at < $1 + ORDER BY last_unstuck_at + LIMIT 500 + ) + UPDATE accounts + SET last_unstuck_at = $2 + WHERE accounts.id IN(SELECT id FROM account) + RETURNING accounts.id` + rows, err := tx.Query(ctx, stmt, ready, now.Unix()) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + _ = rows.Scan(&id) + ids = append(ids, id) + } + return nil + }) + + if err != nil { + logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to fetch possible stuck accounts") + return + } + + if len(ids) == 0 { + return + } + + logger.WithFields(logrus.Fields{ + "count": len(ids), + "start": ready, + }).Debug("enqueueing stuck account batch") + + batchIds := make([]string, len(ids)) + for i, id := range ids { + batchIds[i] = strconv.FormatInt(id, 10) + } + + if err = queue.Publish(batchIds...); err != nil { + logger.WithFields(logrus.Fields{ + "queue": queue, + "err": err, + }).Error("failed to enqueue stuck accounts") + } + + tags := []string{"queue:stuck-accounts"} + _ = statsd.Histogram("apollo.queue.enqueued", float64(len(ids)), tags, 1) + _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) } func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) { - start := time.Now() - - now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000 + now := time.Now() // Start looking for accounts that were last checked at least 5 seconds ago // and at most 6 seconds ago. Also look for accounts that haven't been checked // in over a minute. - ts := start.Unix() - ready := ts - accountEnqueueTimeout + ts := now.Unix() + ready := ts - accountEnqueueInterval expired := ts - checkTimeout ids := []int64{} @@ -373,7 +449,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd. SET last_enqueued_at = $3 WHERE accounts.id IN(SELECT id FROM account) RETURNING accounts.id` - rows, err := tx.Query(ctx, stmt, ready, expired, now) + rows, err := tx.Query(ctx, stmt, ready, expired, ts) if err != nil { return err } @@ -393,6 +469,10 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd. return } + if len(ids) == 0 { + return + } + logger.WithFields(logrus.Fields{ "count": len(ids), "start": ready, @@ -441,9 +521,10 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd. } } - _ = statsd.Histogram("apollo.queue.notifications.enqueued", float64(enqueued), []string{}, 1) - _ = statsd.Histogram("apollo.queue.notifications.skipped", float64(skipped), []string{}, 1) - _ = statsd.Histogram("apollo.queue.notifications.runtime", float64(time.Since(start).Milliseconds()), []string{}, 1) + tags := []string{"queue:notifications"} + _ = statsd.Histogram("apollo.queue.enqueued", float64(enqueued), tags, 1) + _ = statsd.Histogram("apollo.queue.skipped", float64(skipped), tags, 1) + _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) logger.WithFields(logrus.Fields{ "count": enqueued, diff --git a/internal/cmd/worker.go b/internal/cmd/worker.go index 907e3f8..97416f0 100644 --- a/internal/cmd/worker.go +++ b/internal/cmd/worker.go @@ -13,10 +13,11 @@ import ( var ( queues = map[string]worker.NewWorkerFn{ - "notifications": worker.NewNotificationsWorker, - "subreddits": worker.NewSubredditsWorker, - "trending": worker.NewTrendingWorker, - "users": worker.NewUsersWorker, + "notifications": worker.NewNotificationsWorker, + "stuck-notifications": worker.NewStuckNotificationsWorker, + "subreddits": worker.NewSubredditsWorker, + "trending": worker.NewTrendingWorker, + "users": worker.NewUsersWorker, } ) diff --git a/internal/domain/account.go b/internal/domain/account.go index 97ef2bd..7e9cd0c 100644 --- a/internal/domain/account.go +++ b/internal/domain/account.go @@ -21,6 +21,7 @@ type Account struct { // Tracking how far behind we are LastMessageID string LastCheckedAt float64 + LastUnstuckAt float64 } func (acct *Account) NormalizedUsername() string { diff --git a/internal/reddit/client.go b/internal/reddit/client.go index e2a061d..8accdff 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -171,6 +171,23 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { return ret, nil } +func (rac *AuthenticatedClient) AboutInfo(fullname string, opts ...RequestOption) (*ListingResponse, error) { + opts = append([]RequestOption{ + WithMethod("GET"), + WithToken(rac.accessToken), + WithURL("https://oauth.reddit.com/api/info"), + WithQuery("id", fullname), + }, opts...) + req := NewRequest(opts...) + + lr, err := rac.request(req, NewListingResponse, nil) + if err != nil { + return nil, err + } + + return lr.(*ListingResponse), nil +} + func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*ListingResponse, error) { url := fmt.Sprintf("https://oauth.reddit.com/u/%s/submitted.json", user) opts = append([]RequestOption{ diff --git a/internal/repository/postgres_account.go b/internal/repository/postgres_account.go index 03c636a..9db52e6 100644 --- a/internal/repository/postgres_account.go +++ b/internal/repository/postgres_account.go @@ -36,6 +36,7 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg &acc.ExpiresAt, &acc.LastMessageID, &acc.LastCheckedAt, + &acc.LastUnstuckAt, ); err != nil { return nil, err } @@ -46,7 +47,7 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) { query := ` - SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at + SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at FROM accounts WHERE id = $1` @@ -80,8 +81,8 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string } func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error { query := ` - INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at) - VALUES ($1, $2, $3, $4, $5, '', 0) + INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at) + VALUES ($1, $2, $3, $4, $5, '', 0, 0) ON CONFLICT(username) DO UPDATE SET access_token = $3, refresh_token = $4, @@ -102,8 +103,8 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error { query := ` INSERT INTO accounts - (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` return p.pool.QueryRow( @@ -116,6 +117,7 @@ func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Acco acc.ExpiresAt, acc.LastMessageID, acc.LastCheckedAt, + acc.LastUnstuckAt, ).Scan(&acc.ID) } @@ -128,7 +130,8 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco refresh_token = $5, expires_at = $6, last_message_id = $7, - last_checked_at = $8 + last_checked_at = $8, + last_unstuck_at = $9 WHERE id = $1` res, err := p.pool.Exec( @@ -142,6 +145,7 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco acc.ExpiresAt, acc.LastMessageID, acc.LastCheckedAt, + acc.LastUnstuckAt, ) if res.RowsAffected() != 1 { diff --git a/internal/worker/stuck_notifications.go b/internal/worker/stuck_notifications.go new file mode 100644 index 0000000..cddc34a --- /dev/null +++ b/internal/worker/stuck_notifications.go @@ -0,0 +1,186 @@ +package worker + +import ( + "context" + "fmt" + "os" + "strconv" + + "github.com/DataDog/datadog-go/statsd" + "github.com/adjust/rmq/v4" + "github.com/go-redis/redis/v8" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/sirupsen/logrus" + + "github.com/christianselig/apollo-backend/internal/domain" + "github.com/christianselig/apollo-backend/internal/reddit" + "github.com/christianselig/apollo-backend/internal/repository" +) + +type stuckNotificationsWorker struct { + logger *logrus.Logger + statsd *statsd.Client + db *pgxpool.Pool + redis *redis.Client + queue rmq.Connection + reddit *reddit.Client + + consumers int + + accountRepo domain.AccountRepository +} + +func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { + reddit := reddit.NewClient( + os.Getenv("REDDIT_CLIENT_ID"), + os.Getenv("REDDIT_CLIENT_SECRET"), + statsd, + consumers, + ) + + return &stuckNotificationsWorker{ + logger, + statsd, + db, + redis, + queue, + reddit, + consumers, + + repository.NewPostgresAccount(db), + } +} + +func (snw *stuckNotificationsWorker) Start() error { + queue, err := snw.queue.OpenQueue("stuck-notifications") + if err != nil { + return err + } + + snw.logger.WithFields(logrus.Fields{ + "numConsumers": snw.consumers, + }).Info("starting up stuck notifications worker") + + prefetchLimit := int64(snw.consumers * 2) + + if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil { + return err + } + + host, _ := os.Hostname() + + for i := 0; i < snw.consumers; i++ { + name := fmt.Sprintf("consumer %s-%d", host, i) + + consumer := NewStuckNotificationsConsumer(snw, i) + if _, err := queue.AddConsumer(name, consumer); err != nil { + return err + } + } + + return nil +} + +func (snw *stuckNotificationsWorker) Stop() { + <-snw.queue.StopAllConsuming() // wait for all Consume() calls to finish +} + +type stuckNotificationsConsumer struct { + *stuckNotificationsWorker + tag int +} + +func NewStuckNotificationsConsumer(snw *stuckNotificationsWorker, tag int) *stuckNotificationsConsumer { + return &stuckNotificationsConsumer{ + snw, + tag, + } +} + +func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { + ctx := context.Background() + + snc.logger.WithFields(logrus.Fields{ + "account#id": delivery.Payload(), + }).Debug("starting job") + + id, err := strconv.ParseInt(delivery.Payload(), 10, 64) + if err != nil { + snc.logger.WithFields(logrus.Fields{ + "account#id": delivery.Payload(), + "err": err, + }).Error("failed to parse account ID") + + _ = delivery.Reject() + return + } + + defer func() { _ = delivery.Ack() }() + + account, err := snc.accountRepo.GetByID(ctx, id) + if err != nil { + snc.logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to fetch account from database") + return + } + + if account.LastMessageID == "" { + snc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + }).Debug("account has no messages, returning") + return + } + + rac := snc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) + + snc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "thing#id": account.LastMessageID, + }).Debug("fetching last thing") + + things, err := rac.AboutInfo(account.LastMessageID) + if err != nil { + snc.logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to fetch last thing") + return + } + + if things.Count == 1 { + thing := things.Children[0] + if thing.Author != "[deleted]" { + snc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "thing#id": account.LastMessageID, + }).Debug("thing exists, returning") + return + } + } + + snc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "thing#id": account.LastMessageID, + }).Debug("thing got deleted, resetting") + + things, err = rac.MessageInbox() + if err != nil { + snc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "err": err, + }).Error("failed to get message inbox") + return + } + + account.LastMessageID = "" + if things.Count > 0 { + account.LastMessageID = things.Children[0].FullName() + } + + if err := snc.accountRepo.Update(ctx, &account); err != nil { + snc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "err": err, + }).Error("failed to update account's message id") + } +}