distinguish between 401 and 403

This commit is contained in:
Andre Medeiros 2022-05-23 17:26:40 -04:00
parent be926f9118
commit 5fa9bfaa45
4 changed files with 14 additions and 83 deletions

View file

@ -165,15 +165,18 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit
rli.Timestamp = time.Now().String() rli.Timestamp = time.Now().String()
} }
bb, err := ioutil.ReadAll(resp.Body)
switch resp.StatusCode { switch resp.StatusCode {
case 200: case 200:
bb, err := ioutil.ReadAll(resp.Body)
return bb, rli, err return bb, rli, err
case 401, 403: case 401:
return nil, rli, ErrInvalidBasicAuth
case 403:
return nil, rli, ErrOauthRevoked return nil, rli, ErrOauthRevoked
default: default:
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, rli, ServerError{resp.StatusCode} return nil, rli, ServerError{string(bb), resp.StatusCode}
} }
} }
@ -299,13 +302,6 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque
rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil) rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil)
if err != nil { if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 400 {
return nil, ErrOauthRevoked
}
}
return nil, err return nil, err
} }

View file

@ -1,72 +0,0 @@
package reddit_test
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"testing"
"github.com/DataDog/datadog-go/statsd"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/go-redis/redismock/v8"
"github.com/stretchr/testify/assert"
)
// RoundTripFunc .
type RoundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
func NewTestClient(fn RoundTripFunc) *http.Client {
return &http.Client{Transport: fn}
}
func TestErrorResponse(t *testing.T) {
t.Parallel()
ctx := context.Background()
db, _ := redismock.NewClientMock()
errortests := map[string]struct {
call func(*reddit.AuthenticatedClient) error
status int
body string
err error
}{
"/api/v1/me 500 returns ServerError": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 500, "", reddit.ServerError{500}},
"/api/v1/access_token 400 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.RefreshTokens(ctx); return err }, 400, "", reddit.ErrOauthRevoked},
"/api/v1/message/inbox 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageInbox(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
"/api/v1/message/unread 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageUnread(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
"/api/v1/me 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
}
for scenario, tt := range errortests {
tt := tt
t.Run(scenario, func(t *testing.T) {
t.Parallel()
tc := NewTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: tt.status,
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
Header: make(http.Header),
}
})
rc := reddit.NewClient("", "", &statsd.NoOpClient{}, db, 1, reddit.WithRetry(false), reddit.WithClient(tc))
rac := rc.NewAuthenticatedClient(reddit.SkipRateLimiting, "", "")
err := tt.call(rac)
assert.ErrorIs(t, err, tt.err)
})
}
}

View file

@ -6,11 +6,12 @@ import (
) )
type ServerError struct { type ServerError struct {
Body string
StatusCode int StatusCode int
} }
func (se ServerError) Error() string { func (se ServerError) Error() string {
return fmt.Sprintf("error from reddit: %d", se.StatusCode) return fmt.Sprintf("error from reddit: %d (%s)", se.StatusCode, se.Body)
} }
var ( var (
@ -22,4 +23,6 @@ var (
ErrRateLimited = errors.New("rate limited") ErrRateLimited = errors.New("rate limited")
// ErrRequiresRedditId . // ErrRequiresRedditId .
ErrRequiresRedditId = errors.New("requires reddit id") ErrRequiresRedditId = errors.New("requires reddit id")
// ErrInvalidBasicAuth .
ErrInvalidBasicAuth = errors.New("invalid basic auth")
) )

View file

@ -59,6 +59,10 @@ func (r *Request) HTTPRequest(ctx context.Context) (*http.Request, error) {
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", userAgent) req.Header.Add("User-Agent", userAgent)
if len(r.body) > 0 {
req.Header.Add("Content-Type", "multipart/form-data")
}
if r.token != "" { if r.token != "" {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.token)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.token))
} }