From 1a861ea6285e38547b0ef87d035eafc49c2f94e4 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Wed, 26 Oct 2022 20:46:17 -0400 Subject: [PATCH] finalize context fixing --- internal/worker/stuck_notifications.go | 15 +++++++++------ internal/worker/subreddits.go | 19 +++++++++++-------- internal/worker/trending.go | 17 ++++++++++------- internal/worker/users.go | 21 ++++++++++++--------- 4 files changed, 42 insertions(+), 30 deletions(-) diff --git a/internal/worker/stuck_notifications.go b/internal/worker/stuck_notifications.go index bab076d..84df0d9 100644 --- a/internal/worker/stuck_notifications.go +++ b/internal/worker/stuck_notifications.go @@ -101,6 +101,9 @@ func NewStuckNotificationsConsumer(snw *stuckNotificationsWorker, tag int) *stuc } func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { + ctx, cancel := context.WithCancel(snc) + defer cancel() + now := time.Now() defer func() { elapsed := time.Now().Sub(now).Milliseconds() @@ -119,7 +122,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { defer func() { _ = delivery.Ack() }() - account, err := snc.accountRepo.GetByID(snc, id) + account, err := snc.accountRepo.GetByID(ctx, id) if err != nil { snc.logger.Error("failed to fetch account from database", zap.Error(err), zap.Int64("account#id", id)) return @@ -149,7 +152,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { zap.String("account#username", account.NormalizedUsername()), ) - things, err = rac.MessageInbox(snc) + things, err = rac.MessageInbox(ctx) if err != nil { if err != reddit.ErrRateLimited { snc.logger.Error("failed to fetch last thing via inbox", @@ -161,7 +164,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { return } } else { - things, err = rac.AboutInfo(snc, account.LastMessageID) + things, err = rac.AboutInfo(ctx, account.LastMessageID) if err != nil { snc.logger.Error("failed to fetch last thing", zap.Error(err), @@ -186,7 +189,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { return } - sthings, err := rac.MessageInbox(snc) + sthings, err := rac.MessageInbox(ctx) if err != nil { snc.logger.Error("failed to check inbox", zap.Error(err), @@ -233,7 +236,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { zap.String("account#username", account.NormalizedUsername()), ) - things, err = rac.MessageInbox(snc) + things, err = rac.MessageInbox(ctx) if err != nil { snc.logger.Error("failed to check inbox", zap.Error(err), @@ -270,7 +273,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { zap.String("thing#id", account.LastMessageID), ) - if err := snc.accountRepo.Update(snc, &account); err != nil { + if err := snc.accountRepo.Update(ctx, &account); err != nil { snc.logger.Error("failed to update account's last message id", zap.Error(err), zap.Int64("account#id", id), diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index 4eb5927..e87e315 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -138,6 +138,9 @@ func NewSubredditsConsumer(sw *subredditsWorker, tag int) *subredditsConsumer { } func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { + ctx, cancel := context.WithCancel(sc) + defer cancel() + id, err := strconv.ParseInt(delivery.Payload(), 10, 64) if err != nil { sc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload())) @@ -149,13 +152,13 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { defer func() { _ = delivery.Ack() }() - subreddit, err := sc.subredditRepo.GetByID(sc, id) + subreddit, err := sc.subredditRepo.GetByID(ctx, id) if err != nil { sc.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id)) return } - watchers, err := sc.watcherRepo.GetBySubredditID(sc, subreddit.ID) + watchers, err := sc.watcherRepo.GetBySubredditID(ctx, subreddit.ID) if err != nil { sc.logger.Error("failed to fetch watchers from database", zap.Error(err), @@ -196,7 +199,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { watcher := watchers[i] rac := sc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) - sps, err := rac.SubredditNew(sc, + sps, err := rac.SubredditNew(ctx, subreddit.Name, reddit.WithQuery("before", before), reddit.WithQuery("limit", "100"), @@ -262,9 +265,9 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { i := rand.Intn(len(watchers)) watcher := watchers[i] - acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID) + acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID) rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) - sps, err := rac.SubredditHot(sc, + sps, err := rac.SubredditHot(ctx, subreddit.Name, reddit.WithQuery("limit", "100"), reddit.WithQuery("show", "all"), @@ -349,7 +352,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { ) lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID) - notified, _ := sc.redis.Get(sc, lockKey).Bool() + notified, _ := sc.redis.Get(ctx, lockKey).Bool() if notified { sc.logger.Debug("already notified, skipping", @@ -361,7 +364,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { continue } - if err := sc.watcherRepo.IncrementHits(sc, watcher.ID); err != nil { + if err := sc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil { sc.logger.Error("could not increment hits", zap.Error(err), zap.Int64("subreddit#id", id), @@ -377,7 +380,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { zap.String("post#id", post.ID), ) - sc.redis.SetEX(sc, lockKey, true, 24*time.Hour) + sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour) notifs = append(notifs, watcher) } diff --git a/internal/worker/trending.go b/internal/worker/trending.go index 493b6d9..45d4394 100644 --- a/internal/worker/trending.go +++ b/internal/worker/trending.go @@ -133,6 +133,9 @@ func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer { } func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { + ctx, cancel := context.WithCancel(tc) + defer cancel() + id, err := strconv.ParseInt(delivery.Payload(), 10, 64) if err != nil { tc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload())) @@ -144,13 +147,13 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { defer func() { _ = delivery.Ack() }() - subreddit, err := tc.subredditRepo.GetByID(tc, id) + subreddit, err := tc.subredditRepo.GetByID(ctx, id) if err != nil { tc.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id)) return } - watchers, err := tc.watcherRepo.GetByTrendingSubredditID(tc, subreddit.ID) + watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID) if err != nil { tc.logger.Error("failed to fetch watchers from database", zap.Error(err), @@ -173,7 +176,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { watcher := watchers[i] rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) - tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"), reddit.WithQuery("show", "all"), reddit.WithQuery("limit", "25")) + tps, err := rac.SubredditTop(ctx, subreddit.Name, reddit.WithQuery("t", "week"), reddit.WithQuery("show", "all"), reddit.WithQuery("limit", "25")) if err != nil { tc.logger.Error("failed to fetch weeks's top posts", zap.Error(err), @@ -223,7 +226,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { watcher = watchers[i] rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) - hps, err := rac.SubredditHot(tc, subreddit.Name, reddit.WithQuery("show", "all"), reddit.WithQuery("always_show_media", "1")) + hps, err := rac.SubredditHot(ctx, subreddit.Name, reddit.WithQuery("show", "all"), reddit.WithQuery("always_show_media", "1")) if err != nil { tc.logger.Error("failed to fetch hot posts", zap.Error(err), @@ -260,7 +263,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { } lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID) - notified, _ := tc.redis.Get(tc, lockKey).Bool() + notified, _ := tc.redis.Get(ctx, lockKey).Bool() if notified { tc.logger.Debug("already notified, skipping", @@ -272,9 +275,9 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { continue } - tc.redis.SetEX(tc, lockKey, true, 48*time.Hour) + tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour) - if err := tc.watcherRepo.IncrementHits(tc, watcher.ID); err != nil { + if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil { tc.logger.Error("could not increment hits", zap.Error(err), zap.Int64("subreddit#id", id), diff --git a/internal/worker/users.go b/internal/worker/users.go index fdcfab1..74e874a 100644 --- a/internal/worker/users.go +++ b/internal/worker/users.go @@ -134,6 +134,9 @@ func NewUsersConsumer(uw *usersWorker, tag int) *usersConsumer { } func (uc *usersConsumer) Consume(delivery rmq.Delivery) { + ctx, cancel := context.WithCancel(uc) + defer cancel() + id, err := strconv.ParseInt(delivery.Payload(), 10, 64) if err != nil { uc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload())) @@ -145,13 +148,13 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { defer func() { _ = delivery.Ack() }() - user, err := uc.userRepo.GetByID(uc, id) + user, err := uc.userRepo.GetByID(ctx, id) if err != nil { uc.logger.Error("failed to fetch user from database", zap.Error(err), zap.Int64("subreddit#id", id)) return } - watchers, err := uc.watcherRepo.GetByUserID(uc, user.ID) + watchers, err := uc.watcherRepo.GetByUserID(ctx, user.ID) if err != nil { uc.logger.Error("failed to fetch watchers from database", zap.Error(err), @@ -173,10 +176,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { i := rand.Intn(len(watchers)) watcher := watchers[i] - acc, _ := uc.accountRepo.GetByID(uc, watcher.AccountID) + acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID) rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) - ru, err := rac.UserAbout(uc, user.Name) + ru, err := rac.UserAbout(ctx, user.Name) if err != nil { uc.logger.Error("failed to fetch user details", zap.Error(err), @@ -192,7 +195,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { zap.String("user#name", user.NormalizedName()), ) - if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(uc, domain.UserWatcher, user.ID); err != nil { + if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(ctx, domain.UserWatcher, user.ID); err != nil { uc.logger.Error("failed to remove watchers for user who disallows followers", zap.Error(err), zap.Int64("user#id", id), @@ -201,7 +204,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { return } - if err := uc.userRepo.Delete(uc, user.ID); err != nil { + if err := uc.userRepo.Delete(ctx, user.ID); err != nil { uc.logger.Error("failed to remove user", zap.Error(err), zap.Int64("user#id", id), @@ -211,7 +214,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { } } - posts, err := rac.UserPosts(uc, user.Name) + posts, err := rac.UserPosts(ctx, user.Name) if err != nil { uc.logger.Error("failed to fetch user activity", zap.Error(err), @@ -257,7 +260,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { notification.Topic = "com.christianselig.Apollo" for _, watcher := range notifs { - if err := uc.watcherRepo.IncrementHits(uc, watcher.ID); err != nil { + if err := uc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil { uc.logger.Error("failed to increment watcher hits", zap.Error(err), zap.Int64("user#id", id), @@ -267,7 +270,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { return } - device, _ := uc.deviceRepo.GetByID(uc, watcher.DeviceID) + device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID) title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label) payload.AlertTitle(title)