finalize context fixing

This commit is contained in:
Andre Medeiros 2022-10-26 20:46:17 -04:00
parent b0025e2367
commit 1a861ea628
4 changed files with 42 additions and 30 deletions

View file

@ -101,6 +101,9 @@ func NewStuckNotificationsConsumer(snw *stuckNotificationsWorker, tag int) *stuc
} }
func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
ctx, cancel := context.WithCancel(snc)
defer cancel()
now := time.Now() now := time.Now()
defer func() { defer func() {
elapsed := time.Now().Sub(now).Milliseconds() elapsed := time.Now().Sub(now).Milliseconds()
@ -119,7 +122,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }() defer func() { _ = delivery.Ack() }()
account, err := snc.accountRepo.GetByID(snc, id) account, err := snc.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
snc.logger.Error("failed to fetch account from database", zap.Error(err), zap.Int64("account#id", id)) snc.logger.Error("failed to fetch account from database", zap.Error(err), zap.Int64("account#id", id))
return return
@ -149,7 +152,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
zap.String("account#username", account.NormalizedUsername()), zap.String("account#username", account.NormalizedUsername()),
) )
things, err = rac.MessageInbox(snc) things, err = rac.MessageInbox(ctx)
if err != nil { if err != nil {
if err != reddit.ErrRateLimited { if err != reddit.ErrRateLimited {
snc.logger.Error("failed to fetch last thing via inbox", snc.logger.Error("failed to fetch last thing via inbox",
@ -161,7 +164,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
} else { } else {
things, err = rac.AboutInfo(snc, account.LastMessageID) things, err = rac.AboutInfo(ctx, account.LastMessageID)
if err != nil { if err != nil {
snc.logger.Error("failed to fetch last thing", snc.logger.Error("failed to fetch last thing",
zap.Error(err), zap.Error(err),
@ -186,7 +189,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
sthings, err := rac.MessageInbox(snc) sthings, err := rac.MessageInbox(ctx)
if err != nil { if err != nil {
snc.logger.Error("failed to check inbox", snc.logger.Error("failed to check inbox",
zap.Error(err), zap.Error(err),
@ -233,7 +236,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
zap.String("account#username", account.NormalizedUsername()), zap.String("account#username", account.NormalizedUsername()),
) )
things, err = rac.MessageInbox(snc) things, err = rac.MessageInbox(ctx)
if err != nil { if err != nil {
snc.logger.Error("failed to check inbox", snc.logger.Error("failed to check inbox",
zap.Error(err), zap.Error(err),
@ -270,7 +273,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
zap.String("thing#id", account.LastMessageID), zap.String("thing#id", account.LastMessageID),
) )
if err := snc.accountRepo.Update(snc, &account); err != nil { if err := snc.accountRepo.Update(ctx, &account); err != nil {
snc.logger.Error("failed to update account's last message id", snc.logger.Error("failed to update account's last message id",
zap.Error(err), zap.Error(err),
zap.Int64("account#id", id), zap.Int64("account#id", id),

View file

@ -138,6 +138,9 @@ func NewSubredditsConsumer(sw *subredditsWorker, tag int) *subredditsConsumer {
} }
func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
ctx, cancel := context.WithCancel(sc)
defer cancel()
id, err := strconv.ParseInt(delivery.Payload(), 10, 64) id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil { if err != nil {
sc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload())) sc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload()))
@ -149,13 +152,13 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }() defer func() { _ = delivery.Ack() }()
subreddit, err := sc.subredditRepo.GetByID(sc, id) subreddit, err := sc.subredditRepo.GetByID(ctx, id)
if err != nil { if err != nil {
sc.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id)) sc.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id))
return return
} }
watchers, err := sc.watcherRepo.GetBySubredditID(sc, subreddit.ID) watchers, err := sc.watcherRepo.GetBySubredditID(ctx, subreddit.ID)
if err != nil { if err != nil {
sc.logger.Error("failed to fetch watchers from database", sc.logger.Error("failed to fetch watchers from database",
zap.Error(err), zap.Error(err),
@ -196,7 +199,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i] watcher := watchers[i]
rac := sc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) rac := sc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
sps, err := rac.SubredditNew(sc, sps, err := rac.SubredditNew(ctx,
subreddit.Name, subreddit.Name,
reddit.WithQuery("before", before), reddit.WithQuery("before", before),
reddit.WithQuery("limit", "100"), reddit.WithQuery("limit", "100"),
@ -262,9 +265,9 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
i := rand.Intn(len(watchers)) i := rand.Intn(len(watchers))
watcher := watchers[i] watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID) acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditHot(sc, sps, err := rac.SubredditHot(ctx,
subreddit.Name, subreddit.Name,
reddit.WithQuery("limit", "100"), reddit.WithQuery("limit", "100"),
reddit.WithQuery("show", "all"), reddit.WithQuery("show", "all"),
@ -349,7 +352,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
) )
lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID) lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID)
notified, _ := sc.redis.Get(sc, lockKey).Bool() notified, _ := sc.redis.Get(ctx, lockKey).Bool()
if notified { if notified {
sc.logger.Debug("already notified, skipping", sc.logger.Debug("already notified, skipping",
@ -361,7 +364,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
continue continue
} }
if err := sc.watcherRepo.IncrementHits(sc, watcher.ID); err != nil { if err := sc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
sc.logger.Error("could not increment hits", sc.logger.Error("could not increment hits",
zap.Error(err), zap.Error(err),
zap.Int64("subreddit#id", id), zap.Int64("subreddit#id", id),
@ -377,7 +380,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
zap.String("post#id", post.ID), zap.String("post#id", post.ID),
) )
sc.redis.SetEX(sc, lockKey, true, 24*time.Hour) sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour)
notifs = append(notifs, watcher) notifs = append(notifs, watcher)
} }

View file

@ -133,6 +133,9 @@ func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer {
} }
func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
ctx, cancel := context.WithCancel(tc)
defer cancel()
id, err := strconv.ParseInt(delivery.Payload(), 10, 64) id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil { if err != nil {
tc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload())) tc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload()))
@ -144,13 +147,13 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }() defer func() { _ = delivery.Ack() }()
subreddit, err := tc.subredditRepo.GetByID(tc, id) subreddit, err := tc.subredditRepo.GetByID(ctx, id)
if err != nil { if err != nil {
tc.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id)) tc.logger.Error("failed to fetch subreddit from database", zap.Error(err), zap.Int64("subreddit#id", id))
return return
} }
watchers, err := tc.watcherRepo.GetByTrendingSubredditID(tc, subreddit.ID) watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID)
if err != nil { if err != nil {
tc.logger.Error("failed to fetch watchers from database", tc.logger.Error("failed to fetch watchers from database",
zap.Error(err), zap.Error(err),
@ -173,7 +176,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i] watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"), reddit.WithQuery("show", "all"), reddit.WithQuery("limit", "25")) tps, err := rac.SubredditTop(ctx, subreddit.Name, reddit.WithQuery("t", "week"), reddit.WithQuery("show", "all"), reddit.WithQuery("limit", "25"))
if err != nil { if err != nil {
tc.logger.Error("failed to fetch weeks's top posts", tc.logger.Error("failed to fetch weeks's top posts",
zap.Error(err), zap.Error(err),
@ -223,7 +226,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
watcher = watchers[i] watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(tc, subreddit.Name, reddit.WithQuery("show", "all"), reddit.WithQuery("always_show_media", "1")) hps, err := rac.SubredditHot(ctx, subreddit.Name, reddit.WithQuery("show", "all"), reddit.WithQuery("always_show_media", "1"))
if err != nil { if err != nil {
tc.logger.Error("failed to fetch hot posts", tc.logger.Error("failed to fetch hot posts",
zap.Error(err), zap.Error(err),
@ -260,7 +263,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
} }
lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID) lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID)
notified, _ := tc.redis.Get(tc, lockKey).Bool() notified, _ := tc.redis.Get(ctx, lockKey).Bool()
if notified { if notified {
tc.logger.Debug("already notified, skipping", tc.logger.Debug("already notified, skipping",
@ -272,9 +275,9 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
continue continue
} }
tc.redis.SetEX(tc, lockKey, true, 48*time.Hour) tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour)
if err := tc.watcherRepo.IncrementHits(tc, watcher.ID); err != nil { if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
tc.logger.Error("could not increment hits", tc.logger.Error("could not increment hits",
zap.Error(err), zap.Error(err),
zap.Int64("subreddit#id", id), zap.Int64("subreddit#id", id),

View file

@ -134,6 +134,9 @@ func NewUsersConsumer(uw *usersWorker, tag int) *usersConsumer {
} }
func (uc *usersConsumer) Consume(delivery rmq.Delivery) { func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
ctx, cancel := context.WithCancel(uc)
defer cancel()
id, err := strconv.ParseInt(delivery.Payload(), 10, 64) id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil { if err != nil {
uc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload())) uc.logger.Error("failed to parse subreddit id from payload", zap.Error(err), zap.String("payload", delivery.Payload()))
@ -145,13 +148,13 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }() defer func() { _ = delivery.Ack() }()
user, err := uc.userRepo.GetByID(uc, id) user, err := uc.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
uc.logger.Error("failed to fetch user from database", zap.Error(err), zap.Int64("subreddit#id", id)) uc.logger.Error("failed to fetch user from database", zap.Error(err), zap.Int64("subreddit#id", id))
return return
} }
watchers, err := uc.watcherRepo.GetByUserID(uc, user.ID) watchers, err := uc.watcherRepo.GetByUserID(ctx, user.ID)
if err != nil { if err != nil {
uc.logger.Error("failed to fetch watchers from database", uc.logger.Error("failed to fetch watchers from database",
zap.Error(err), zap.Error(err),
@ -173,10 +176,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
i := rand.Intn(len(watchers)) i := rand.Intn(len(watchers))
watcher := watchers[i] watcher := watchers[i]
acc, _ := uc.accountRepo.GetByID(uc, watcher.AccountID) acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
ru, err := rac.UserAbout(uc, user.Name) ru, err := rac.UserAbout(ctx, user.Name)
if err != nil { if err != nil {
uc.logger.Error("failed to fetch user details", uc.logger.Error("failed to fetch user details",
zap.Error(err), zap.Error(err),
@ -192,7 +195,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
zap.String("user#name", user.NormalizedName()), zap.String("user#name", user.NormalizedName()),
) )
if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(uc, domain.UserWatcher, user.ID); err != nil { if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(ctx, domain.UserWatcher, user.ID); err != nil {
uc.logger.Error("failed to remove watchers for user who disallows followers", uc.logger.Error("failed to remove watchers for user who disallows followers",
zap.Error(err), zap.Error(err),
zap.Int64("user#id", id), zap.Int64("user#id", id),
@ -201,7 +204,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
return return
} }
if err := uc.userRepo.Delete(uc, user.ID); err != nil { if err := uc.userRepo.Delete(ctx, user.ID); err != nil {
uc.logger.Error("failed to remove user", uc.logger.Error("failed to remove user",
zap.Error(err), zap.Error(err),
zap.Int64("user#id", id), zap.Int64("user#id", id),
@ -211,7 +214,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
} }
} }
posts, err := rac.UserPosts(uc, user.Name) posts, err := rac.UserPosts(ctx, user.Name)
if err != nil { if err != nil {
uc.logger.Error("failed to fetch user activity", uc.logger.Error("failed to fetch user activity",
zap.Error(err), zap.Error(err),
@ -257,7 +260,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
notification.Topic = "com.christianselig.Apollo" notification.Topic = "com.christianselig.Apollo"
for _, watcher := range notifs { for _, watcher := range notifs {
if err := uc.watcherRepo.IncrementHits(uc, watcher.ID); err != nil { if err := uc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
uc.logger.Error("failed to increment watcher hits", uc.logger.Error("failed to increment watcher hits",
zap.Error(err), zap.Error(err),
zap.Int64("user#id", id), zap.Int64("user#id", id),
@ -267,7 +270,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
return return
} }
device, _ := uc.deviceRepo.GetByID(uc, watcher.DeviceID) device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label) title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label)
payload.AlertTitle(title) payload.AlertTitle(title)