mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-10 22:17:44 +00:00
token refresh mechanism
This commit is contained in:
parent
47a0aa47dd
commit
df96aaa768
3 changed files with 41 additions and 55 deletions
|
@ -115,8 +115,10 @@ type AuthenticatedClient struct {
|
||||||
client *Client
|
client *Client
|
||||||
|
|
||||||
redditId string
|
redditId string
|
||||||
refreshToken string
|
tokenRefreshed bool
|
||||||
accessToken string
|
|
||||||
|
RefreshToken string
|
||||||
|
AccessToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
|
func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
|
||||||
|
@ -132,7 +134,7 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
|
||||||
panic("requires a refresh token")
|
panic("requires a refresh token")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
|
return &AuthenticatedClient{rc, redditId, false, refreshToken, accessToken}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) {
|
func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) {
|
||||||
|
@ -199,6 +201,16 @@ func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh Resp
|
||||||
|
|
||||||
bb, rli, err := rac.client.doRequest(ctx, r)
|
bb, rli, err := rac.client.doRequest(ctx, r)
|
||||||
|
|
||||||
|
if err == ErrInvalidBasicAuth {
|
||||||
|
tokens, err := rac.RefreshTokens(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidBasicAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
rac.RefreshToken = tokens.RefreshToken
|
||||||
|
rac.AccessToken = tokens.AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil && err != ErrOauthRevoked && r.retry {
|
if err != nil && err != ErrOauthRevoked && r.retry {
|
||||||
for _, backoff := range backoffSchedule {
|
for _, backoff := range backoffSchedule {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -297,13 +309,19 @@ func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) {
|
func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) {
|
||||||
|
if rac.tokenRefreshed {
|
||||||
|
return nil, ErrTokenAlreadyRefreshed
|
||||||
|
}
|
||||||
|
|
||||||
|
rac.tokenRefreshed = true
|
||||||
|
|
||||||
opts = append(rac.client.defaultOpts, opts...)
|
opts = append(rac.client.defaultOpts, opts...)
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithTags([]string{"url:/api/v1/access_token"}),
|
WithTags([]string{"url:/api/v1/access_token"}),
|
||||||
WithMethod("POST"),
|
WithMethod("POST"),
|
||||||
WithURL("https://www.reddit.com/api/v1/access_token"),
|
WithURL("https://www.reddit.com/api/v1/access_token"),
|
||||||
WithBody("grant_type", "refresh_token"),
|
WithBody("grant_type", "refresh_token"),
|
||||||
WithBody("refresh_token", rac.refreshToken),
|
WithBody("refresh_token", rac.RefreshToken),
|
||||||
WithBasicAuth(rac.client.id, rac.client.secret),
|
WithBasicAuth(rac.client.id, rac.client.secret),
|
||||||
}...)
|
}...)
|
||||||
req := NewRequest(opts...)
|
req := NewRequest(opts...)
|
||||||
|
@ -315,7 +333,7 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque
|
||||||
|
|
||||||
ret := rtr.(*RefreshTokenResponse)
|
ret := rtr.(*RefreshTokenResponse)
|
||||||
if ret.RefreshToken == "" {
|
if ret.RefreshToken == "" {
|
||||||
ret.RefreshToken = rac.refreshToken
|
ret.RefreshToken = rac.RefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret, nil
|
return ret, nil
|
||||||
|
@ -325,7 +343,7 @@ func (rac *AuthenticatedClient) AboutInfo(ctx context.Context, fullname string,
|
||||||
opts = append(rac.client.defaultOpts, opts...)
|
opts = append(rac.client.defaultOpts, opts...)
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL("https://oauth.reddit.com/api/info"),
|
WithURL("https://oauth.reddit.com/api/info"),
|
||||||
WithQuery("id", fullname),
|
WithQuery("id", fullname),
|
||||||
}...)
|
}...)
|
||||||
|
@ -344,7 +362,7 @@ func (rac *AuthenticatedClient) UserPosts(ctx context.Context, user string, opts
|
||||||
opts = append(rac.client.defaultOpts, opts...)
|
opts = append(rac.client.defaultOpts, opts...)
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL(url),
|
WithURL(url),
|
||||||
}...)
|
}...)
|
||||||
req := NewRequest(opts...)
|
req := NewRequest(opts...)
|
||||||
|
@ -362,7 +380,7 @@ func (rac *AuthenticatedClient) UserAbout(ctx context.Context, user string, opts
|
||||||
opts = append(rac.client.defaultOpts, opts...)
|
opts = append(rac.client.defaultOpts, opts...)
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL(url),
|
WithURL(url),
|
||||||
}...)
|
}...)
|
||||||
req := NewRequest(opts...)
|
req := NewRequest(opts...)
|
||||||
|
@ -381,7 +399,7 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st
|
||||||
opts = append(rac.client.defaultOpts, opts...)
|
opts = append(rac.client.defaultOpts, opts...)
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL(url),
|
WithURL(url),
|
||||||
}...)
|
}...)
|
||||||
req := NewRequest(opts...)
|
req := NewRequest(opts...)
|
||||||
|
@ -399,7 +417,7 @@ func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit st
|
||||||
opts = append(rac.client.defaultOpts, opts...)
|
opts = append(rac.client.defaultOpts, opts...)
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL(url),
|
WithURL(url),
|
||||||
}...)
|
}...)
|
||||||
req := NewRequest(opts...)
|
req := NewRequest(opts...)
|
||||||
|
@ -429,7 +447,7 @@ func (rac *AuthenticatedClient) MessageInbox(ctx context.Context, opts ...Reques
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithTags([]string{"url:/api/v1/message/inbox"}),
|
WithTags([]string{"url:/api/v1/message/inbox"}),
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL("https://oauth.reddit.com/message/inbox"),
|
WithURL("https://oauth.reddit.com/message/inbox"),
|
||||||
WithEmptyResponseBytes(122),
|
WithEmptyResponseBytes(122),
|
||||||
}...)
|
}...)
|
||||||
|
@ -447,7 +465,7 @@ func (rac *AuthenticatedClient) MessageUnread(ctx context.Context, opts ...Reque
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithTags([]string{"url:/api/v1/message/unread"}),
|
WithTags([]string{"url:/api/v1/message/unread"}),
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL("https://oauth.reddit.com/message/unread"),
|
WithURL("https://oauth.reddit.com/message/unread"),
|
||||||
WithEmptyResponseBytes(122),
|
WithEmptyResponseBytes(122),
|
||||||
}...)
|
}...)
|
||||||
|
@ -466,7 +484,7 @@ func (rac *AuthenticatedClient) Me(ctx context.Context, opts ...RequestOption) (
|
||||||
opts = append(opts, []RequestOption{
|
opts = append(opts, []RequestOption{
|
||||||
WithTags([]string{"url:/api/v1/me"}),
|
WithTags([]string{"url:/api/v1/me"}),
|
||||||
WithMethod("GET"),
|
WithMethod("GET"),
|
||||||
WithToken(rac.accessToken),
|
WithToken(rac.AccessToken),
|
||||||
WithURL("https://oauth.reddit.com/api/v1/me"),
|
WithURL("https://oauth.reddit.com/api/v1/me"),
|
||||||
}...)
|
}...)
|
||||||
|
|
||||||
|
|
|
@ -25,4 +25,6 @@ var (
|
||||||
ErrRequiresRedditId = errors.New("requires reddit id")
|
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||||
// ErrInvalidBasicAuth .
|
// ErrInvalidBasicAuth .
|
||||||
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
||||||
|
// ErrTokenAlreadyRefreshed .
|
||||||
|
ErrTokenAlreadyRefreshed = errors.New("token already refreshed")
|
||||||
)
|
)
|
||||||
|
|
|
@ -167,7 +167,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
|
|
||||||
account.CheckCount++
|
account.CheckCount++
|
||||||
|
|
||||||
defer func(acc *domain.Account) {
|
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||||
|
|
||||||
|
defer func(acc *domain.Account, rac *reddit.AuthenticatedClient) {
|
||||||
|
account.AccessToken = rac.AccessToken
|
||||||
|
account.RefreshToken = rac.RefreshToken
|
||||||
|
|
||||||
if err = nc.accountRepo.Update(nc, acc); err != nil {
|
if err = nc.accountRepo.Update(nc, acc); err != nil {
|
||||||
nc.logger.Error("failed to update account",
|
nc.logger.Error("failed to update account",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
|
@ -175,46 +180,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
zap.String("account#username", account.NormalizedUsername()),
|
zap.String("account#username", account.NormalizedUsername()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}(&account)
|
}(&account, rac)
|
||||||
|
|
||||||
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
|
||||||
if account.TokenExpiresAt.Before(now.Add(5 * time.Minute)) {
|
|
||||||
nc.logger.Debug("refreshing reddit token",
|
|
||||||
zap.Int64("account#id", id),
|
|
||||||
zap.String("account#username", account.NormalizedUsername()),
|
|
||||||
)
|
|
||||||
|
|
||||||
tokens, err := rac.RefreshTokens(nc)
|
|
||||||
if err != nil {
|
|
||||||
if err != reddit.ErrOauthRevoked {
|
|
||||||
nc.logger.Error("failed to refresh reddit tokens",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.Int64("account#id", id),
|
|
||||||
zap.String("account#username", account.NormalizedUsername()),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = nc.deleteAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
nc.logger.Error("failed to remove revoked account",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.Int64("account#id", id),
|
|
||||||
zap.String("account#username", account.NormalizedUsername()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update account
|
|
||||||
account.AccessToken = tokens.AccessToken
|
|
||||||
account.RefreshToken = tokens.RefreshToken
|
|
||||||
account.TokenExpiresAt = now.Add(tokens.Expiry)
|
|
||||||
|
|
||||||
// Refresh client
|
|
||||||
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only update delay on accounts we can actually check, otherwise it skews
|
// Only update delay on accounts we can actually check, otherwise it skews
|
||||||
// the numbers too much.
|
// the numbers too much.
|
||||||
|
|
Loading…
Reference in a new issue