Revert "token refresh mechanism"

This reverts commit df96aaa768.
This commit is contained in:
Andre Medeiros 2022-05-25 20:17:03 -04:00
parent df96aaa768
commit 7da47176a2
3 changed files with 55 additions and 41 deletions

View file

@ -115,10 +115,8 @@ type AuthenticatedClient struct {
client *Client client *Client
redditId string redditId string
tokenRefreshed bool refreshToken string
accessToken string
RefreshToken string
AccessToken string
} }
func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient { func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
@ -134,7 +132,7 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
panic("requires a refresh token") panic("requires a refresh token")
} }
return &AuthenticatedClient{rc, redditId, false, refreshToken, accessToken} return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
} }
func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) { func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) {
@ -201,16 +199,6 @@ func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh Resp
bb, rli, err := rac.client.doRequest(ctx, r) bb, rli, err := rac.client.doRequest(ctx, r)
if err == ErrInvalidBasicAuth {
tokens, err := rac.RefreshTokens(ctx)
if err != nil {
return nil, ErrInvalidBasicAuth
}
rac.RefreshToken = tokens.RefreshToken
rac.AccessToken = tokens.AccessToken
}
if err != nil && err != ErrOauthRevoked && r.retry { if err != nil && err != ErrOauthRevoked && r.retry {
for _, backoff := range backoffSchedule { for _, backoff := range backoffSchedule {
done := make(chan struct{}) done := make(chan struct{})
@ -309,19 +297,13 @@ func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error {
} }
func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) { func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) {
if rac.tokenRefreshed {
return nil, ErrTokenAlreadyRefreshed
}
rac.tokenRefreshed = true
opts = append(rac.client.defaultOpts, opts...) opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/access_token"}), WithTags([]string{"url:/api/v1/access_token"}),
WithMethod("POST"), WithMethod("POST"),
WithURL("https://www.reddit.com/api/v1/access_token"), WithURL("https://www.reddit.com/api/v1/access_token"),
WithBody("grant_type", "refresh_token"), WithBody("grant_type", "refresh_token"),
WithBody("refresh_token", rac.RefreshToken), WithBody("refresh_token", rac.refreshToken),
WithBasicAuth(rac.client.id, rac.client.secret), WithBasicAuth(rac.client.id, rac.client.secret),
}...) }...)
req := NewRequest(opts...) req := NewRequest(opts...)
@ -333,7 +315,7 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque
ret := rtr.(*RefreshTokenResponse) ret := rtr.(*RefreshTokenResponse)
if ret.RefreshToken == "" { if ret.RefreshToken == "" {
ret.RefreshToken = rac.RefreshToken ret.RefreshToken = rac.refreshToken
} }
return ret, nil return ret, nil
@ -343,7 +325,7 @@ func (rac *AuthenticatedClient) AboutInfo(ctx context.Context, fullname string,
opts = append(rac.client.defaultOpts, opts...) opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL("https://oauth.reddit.com/api/info"), WithURL("https://oauth.reddit.com/api/info"),
WithQuery("id", fullname), WithQuery("id", fullname),
}...) }...)
@ -362,7 +344,7 @@ func (rac *AuthenticatedClient) UserPosts(ctx context.Context, user string, opts
opts = append(rac.client.defaultOpts, opts...) opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL(url), WithURL(url),
}...) }...)
req := NewRequest(opts...) req := NewRequest(opts...)
@ -380,7 +362,7 @@ func (rac *AuthenticatedClient) UserAbout(ctx context.Context, user string, opts
opts = append(rac.client.defaultOpts, opts...) opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL(url), WithURL(url),
}...) }...)
req := NewRequest(opts...) req := NewRequest(opts...)
@ -399,7 +381,7 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st
opts = append(rac.client.defaultOpts, opts...) opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL(url), WithURL(url),
}...) }...)
req := NewRequest(opts...) req := NewRequest(opts...)
@ -417,7 +399,7 @@ func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit st
opts = append(rac.client.defaultOpts, opts...) opts = append(rac.client.defaultOpts, opts...)
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL(url), WithURL(url),
}...) }...)
req := NewRequest(opts...) req := NewRequest(opts...)
@ -447,7 +429,7 @@ func (rac *AuthenticatedClient) MessageInbox(ctx context.Context, opts ...Reques
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/message/inbox"}), WithTags([]string{"url:/api/v1/message/inbox"}),
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL("https://oauth.reddit.com/message/inbox"), WithURL("https://oauth.reddit.com/message/inbox"),
WithEmptyResponseBytes(122), WithEmptyResponseBytes(122),
}...) }...)
@ -465,7 +447,7 @@ func (rac *AuthenticatedClient) MessageUnread(ctx context.Context, opts ...Reque
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/message/unread"}), WithTags([]string{"url:/api/v1/message/unread"}),
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL("https://oauth.reddit.com/message/unread"), WithURL("https://oauth.reddit.com/message/unread"),
WithEmptyResponseBytes(122), WithEmptyResponseBytes(122),
}...) }...)
@ -484,7 +466,7 @@ func (rac *AuthenticatedClient) Me(ctx context.Context, opts ...RequestOption) (
opts = append(opts, []RequestOption{ opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/me"}), WithTags([]string{"url:/api/v1/me"}),
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.AccessToken), WithToken(rac.accessToken),
WithURL("https://oauth.reddit.com/api/v1/me"), WithURL("https://oauth.reddit.com/api/v1/me"),
}...) }...)

View file

@ -25,6 +25,4 @@ var (
ErrRequiresRedditId = errors.New("requires reddit id") ErrRequiresRedditId = errors.New("requires reddit id")
// ErrInvalidBasicAuth . // ErrInvalidBasicAuth .
ErrInvalidBasicAuth = errors.New("invalid basic auth") ErrInvalidBasicAuth = errors.New("invalid basic auth")
// ErrTokenAlreadyRefreshed .
ErrTokenAlreadyRefreshed = errors.New("token already refreshed")
) )

View file

@ -167,12 +167,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
account.CheckCount++ account.CheckCount++
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) defer func(acc *domain.Account) {
defer func(acc *domain.Account, rac *reddit.AuthenticatedClient) {
account.AccessToken = rac.AccessToken
account.RefreshToken = rac.RefreshToken
if err = nc.accountRepo.Update(nc, acc); err != nil { if err = nc.accountRepo.Update(nc, acc); err != nil {
nc.logger.Error("failed to update account", nc.logger.Error("failed to update account",
zap.Error(err), zap.Error(err),
@ -180,7 +175,46 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
zap.String("account#username", account.NormalizedUsername()), zap.String("account#username", account.NormalizedUsername()),
) )
} }
}(&account, rac) }(&account)
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if account.TokenExpiresAt.Before(now.Add(5 * time.Minute)) {
nc.logger.Debug("refreshing reddit token",
zap.Int64("account#id", id),
zap.String("account#username", account.NormalizedUsername()),
)
tokens, err := rac.RefreshTokens(nc)
if err != nil {
if err != reddit.ErrOauthRevoked {
nc.logger.Error("failed to refresh reddit tokens",
zap.Error(err),
zap.Int64("account#id", id),
zap.String("account#username", account.NormalizedUsername()),
)
return
}
err = nc.deleteAccount(account)
if err != nil {
nc.logger.Error("failed to remove revoked account",
zap.Error(err),
zap.Int64("account#id", id),
zap.String("account#username", account.NormalizedUsername()),
)
}
return
}
// Update account
account.AccessToken = tokens.AccessToken
account.RefreshToken = tokens.RefreshToken
account.TokenExpiresAt = now.Add(tokens.Expiry)
// Refresh client
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
}
// Only update delay on accounts we can actually check, otherwise it skews // Only update delay on accounts we can actually check, otherwise it skews
// the numbers too much. // the numbers too much.