From df96aaa7687cd5d002b9001566ca9fd37a522d53 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Wed, 25 May 2022 20:12:54 -0400 Subject: [PATCH] token refresh mechanism --- internal/reddit/client.go | 46 ++++++++++++++++++++---------- internal/reddit/errors.go | 2 ++ internal/worker/notifications.go | 48 +++++--------------------------- 3 files changed, 41 insertions(+), 55 deletions(-) diff --git a/internal/reddit/client.go b/internal/reddit/client.go index c7af450..c921ab8 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -114,9 +114,11 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Cl type AuthenticatedClient struct { client *Client - redditId string - refreshToken string - accessToken string + redditId string + tokenRefreshed bool + + RefreshToken string + AccessToken string } func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient { @@ -132,7 +134,7 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str panic("requires a refresh token") } - return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} + return &AuthenticatedClient{rc, redditId, false, refreshToken, accessToken} } func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) { @@ -199,6 +201,16 @@ func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh Resp bb, rli, err := rac.client.doRequest(ctx, r) + if err == ErrInvalidBasicAuth { + tokens, err := rac.RefreshTokens(ctx) + if err != nil { + return nil, ErrInvalidBasicAuth + } + + rac.RefreshToken = tokens.RefreshToken + rac.AccessToken = tokens.AccessToken + } + if err != nil && err != ErrOauthRevoked && r.retry { for _, backoff := range backoffSchedule { done := make(chan struct{}) @@ -297,13 +309,19 @@ func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error { } func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) { + if rac.tokenRefreshed { + return nil, ErrTokenAlreadyRefreshed + } + + rac.tokenRefreshed = true + opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/access_token"}), WithMethod("POST"), WithURL("https://www.reddit.com/api/v1/access_token"), WithBody("grant_type", "refresh_token"), - WithBody("refresh_token", rac.refreshToken), + WithBody("refresh_token", rac.RefreshToken), WithBasicAuth(rac.client.id, rac.client.secret), }...) req := NewRequest(opts...) @@ -315,7 +333,7 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque ret := rtr.(*RefreshTokenResponse) if ret.RefreshToken == "" { - ret.RefreshToken = rac.refreshToken + ret.RefreshToken = rac.RefreshToken } return ret, nil @@ -325,7 +343,7 @@ func (rac *AuthenticatedClient) AboutInfo(ctx context.Context, fullname string, opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL("https://oauth.reddit.com/api/info"), WithQuery("id", fullname), }...) @@ -344,7 +362,7 @@ func (rac *AuthenticatedClient) UserPosts(ctx context.Context, user string, opts opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -362,7 +380,7 @@ func (rac *AuthenticatedClient) UserAbout(ctx context.Context, user string, opts opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -381,7 +399,7 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -399,7 +417,7 @@ func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit st opts = append(rac.client.defaultOpts, opts...) opts = append(opts, []RequestOption{ WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL(url), }...) req := NewRequest(opts...) @@ -429,7 +447,7 @@ func (rac *AuthenticatedClient) MessageInbox(ctx context.Context, opts ...Reques opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/message/inbox"}), WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL("https://oauth.reddit.com/message/inbox"), WithEmptyResponseBytes(122), }...) @@ -447,7 +465,7 @@ func (rac *AuthenticatedClient) MessageUnread(ctx context.Context, opts ...Reque opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/message/unread"}), WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL("https://oauth.reddit.com/message/unread"), WithEmptyResponseBytes(122), }...) @@ -466,7 +484,7 @@ func (rac *AuthenticatedClient) Me(ctx context.Context, opts ...RequestOption) ( opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/me"}), WithMethod("GET"), - WithToken(rac.accessToken), + WithToken(rac.AccessToken), WithURL("https://oauth.reddit.com/api/v1/me"), }...) diff --git a/internal/reddit/errors.go b/internal/reddit/errors.go index 01629b1..da5b17d 100644 --- a/internal/reddit/errors.go +++ b/internal/reddit/errors.go @@ -25,4 +25,6 @@ var ( ErrRequiresRedditId = errors.New("requires reddit id") // ErrInvalidBasicAuth . ErrInvalidBasicAuth = errors.New("invalid basic auth") + // ErrTokenAlreadyRefreshed . + ErrTokenAlreadyRefreshed = errors.New("token already refreshed") ) diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index 750af91..bbe8111 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -167,7 +167,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { account.CheckCount++ - defer func(acc *domain.Account) { + rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) + + defer func(acc *domain.Account, rac *reddit.AuthenticatedClient) { + account.AccessToken = rac.AccessToken + account.RefreshToken = rac.RefreshToken + if err = nc.accountRepo.Update(nc, acc); err != nil { nc.logger.Error("failed to update account", zap.Error(err), @@ -175,46 +180,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { zap.String("account#username", account.NormalizedUsername()), ) } - }(&account) - - rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) - if account.TokenExpiresAt.Before(now.Add(5 * time.Minute)) { - nc.logger.Debug("refreshing reddit token", - zap.Int64("account#id", id), - zap.String("account#username", account.NormalizedUsername()), - ) - - tokens, err := rac.RefreshTokens(nc) - if err != nil { - if err != reddit.ErrOauthRevoked { - nc.logger.Error("failed to refresh reddit tokens", - zap.Error(err), - zap.Int64("account#id", id), - zap.String("account#username", account.NormalizedUsername()), - ) - return - } - - err = nc.deleteAccount(account) - if err != nil { - nc.logger.Error("failed to remove revoked account", - zap.Error(err), - zap.Int64("account#id", id), - zap.String("account#username", account.NormalizedUsername()), - ) - } - - return - } - - // Update account - account.AccessToken = tokens.AccessToken - account.RefreshToken = tokens.RefreshToken - account.TokenExpiresAt = now.Add(tokens.Expiry) - - // Refresh client - rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken) - } + }(&account, rac) // Only update delay on accounts we can actually check, otherwise it skews // the numbers too much.