From 7da47176a2da63de0877506183c977dd24850ec7 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Wed, 25 May 2022 20:17:03 -0400 Subject: [PATCH] Revert "token refresh mechanism" This reverts commit df96aaa7687cd5d002b9001566ca9fd37a522d53. --- internal/reddit/client.go | 46 ++++++++++-------------------- internal/reddit/errors.go | 2 -- internal/worker/notifications.go | 48 +++++++++++++++++++++++++++----- 3 files changed, 55 insertions(+), 41 deletions(-) diff --git a/internal/reddit/client.go b/internal/reddit/client.go index c921ab8..c7af450 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -114,11 +114,9 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Cl type AuthenticatedClient struct { client *Client - redditId string - tokenRefreshed bool - - RefreshToken string - AccessToken string + redditId string + refreshToken string + accessToken string } func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient { @@ -134,7 +132,7 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str panic("requires a refresh token") } - return &AuthenticatedClient{rc, redditId, false, refreshToken, accessToken} + return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} } func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) { @@ -201,16 +199,6 @@ func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh Resp bb, rli, err := rac.client.doRequest(ctx, r) - if err == ErrInvalidBasicAuth { - tokens, err := rac.RefreshTokens(ctx) - if err != nil { - return nil, ErrInvalidBasicAuth - } - - rac.RefreshToken = tokens.RefreshToken - rac.AccessToken = tokens.AccessToken - } - if err != nil && err != ErrOauthRevoked && r.retry { for _, backoff := range backoffSchedule { done := make(chan struct{}) @@ -309,19 +297,13 @@ func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error { } func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) { - if rac.tokenRefreshed { - return nil, ErrTokenAlreadyRefreshed - } - - rac.tokenRefreshed = true - opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/access_token"}), WithMethod("POST"), WithURL("https://www.reddit.com/api/v1/access_token"), WithBody("grant_type", "refresh_token"), - WithBody("refresh_token", rac.RefreshToken), + WithBody("refresh_token", rac.refreshToken), WithBasicAuth(rac.client.id, rac.client.secret), }...) req := NewRequest(opts...) @@ -333,7 +315,7 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque ret := rtr.(*RefreshTokenResponse) if ret.RefreshToken == "" { - ret.RefreshToken = rac.RefreshToken + ret.RefreshToken = rac.refreshToken } return ret, nil @@ -343,7 +325,7 @@ func (rac *AuthenticatedClient) AboutInfo(ctx context.Context, fullname string, opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/api/info"), WithQuery("id", fullname), }...) @@ -362,7 +344,7 @@ func (rac *AuthenticatedClient) UserPosts(ctx context.Context, user string, opts opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -380,7 +362,7 @@ func (rac *AuthenticatedClient) UserAbout(ctx context.Context, user string, opts opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -399,7 +381,7 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -417,7 +399,7 @@ func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit st opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -447,7 +429,7 @@ func (rac *AuthenticatedClient) MessageInbox(ctx context.Context, opts ...Reques opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/message/inbox"}), WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/message/inbox"), WithEmptyResponseBytes(122), }...) @@ -465,7 +447,7 @@ func (rac *AuthenticatedClient) MessageUnread(ctx context.Context, opts ...Reque opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/message/unread"}), WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/message/unread"), WithEmptyResponseBytes(122), }...) @@ -484,7 +466,7 @@ func (rac *AuthenticatedClient) Me(ctx context.Context, opts ...RequestOption) ( opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/me"}), WithMethod("GET"), - WithToken(rac.AccessToken), + WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/api/v1/me"), }...) diff --git a/internal/reddit/errors.go b/internal/reddit/errors.go index da5b17d..01629b1 100644 --- a/internal/reddit/errors.go +++ b/internal/reddit/errors.go @@ -25,6 +25,4 @@ var ( ErrRequiresRedditId = errors.New("requires reddit id") // ErrInvalidBasicAuth . ErrInvalidBasicAuth = errors.New("invalid basic auth") - // ErrTokenAlreadyRefreshed . - ErrTokenAlreadyRefreshed = errors.New("token already refreshed") ) diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index bbe8111..750af91 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -167,12 +167,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { account.CheckCount++ - rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) - - defer func(acc *domain.Account, rac *reddit.AuthenticatedClient) { - account.AccessToken = rac.AccessToken - account.RefreshToken = rac.RefreshToken - + defer func(acc *domain.Account) { if err = nc.accountRepo.Update(nc, acc); err != nil { nc.logger.Error("failed to update account", zap.Error(err), @@ -180,7 +175,46 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { zap.String("account#username", account.NormalizedUsername()), ) } - }(&account, rac) + }(&account) + + rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) + if account.TokenExpiresAt.Before(now.Add(5 * time.Minute)) { + nc.logger.Debug("refreshing reddit token", + zap.Int64("account#id", id), + zap.String("account#username", account.NormalizedUsername()), + ) + + tokens, err := rac.RefreshTokens(nc) + if err != nil { + if err != reddit.ErrOauthRevoked { + nc.logger.Error("failed to refresh reddit tokens", + zap.Error(err), + zap.Int64("account#id", id), + zap.String("account#username", account.NormalizedUsername()), + ) + return + } + + err = nc.deleteAccount(account) + if err != nil { + nc.logger.Error("failed to remove revoked account", + zap.Error(err), + zap.Int64("account#id", id), + zap.String("account#username", account.NormalizedUsername()), + ) + } + + return + } + + // Update account + account.AccessToken = tokens.AccessToken + account.RefreshToken = tokens.RefreshToken + account.TokenExpiresAt = now.Add(tokens.Expiry) + + // Refresh client + rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken) + } // Only update delay on accounts we can actually check, otherwise it skews // the numbers too much.