From 5fa9bfaa459b042aec17ba012280ed7fbf8378c3 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Mon, 23 May 2022 17:26:40 -0400 Subject: [PATCH] distinguish between 401 and 403 --- internal/reddit/client.go | 16 +++----- internal/reddit/client_test.go | 72 ---------------------------------- internal/reddit/errors.go | 5 ++- internal/reddit/request.go | 4 ++ 4 files changed, 14 insertions(+), 83 deletions(-) delete mode 100644 internal/reddit/client_test.go diff --git a/internal/reddit/client.go b/internal/reddit/client.go index cdbeb1f..49c66bd 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -165,15 +165,18 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit rli.Timestamp = time.Now().String() } + bb, err := ioutil.ReadAll(resp.Body) + switch resp.StatusCode { case 200: - bb, err := ioutil.ReadAll(resp.Body) return bb, rli, err - case 401, 403: + case 401: + return nil, rli, ErrInvalidBasicAuth + case 403: return nil, rli, ErrOauthRevoked default: _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, rli, ServerError{resp.StatusCode} + return nil, rli, ServerError{string(bb), resp.StatusCode} } } @@ -299,13 +302,6 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil) if err != nil { - switch rerr := err.(type) { - case ServerError: - if rerr.StatusCode == 400 { - return nil, ErrOauthRevoked - } - } - return nil, err } diff --git a/internal/reddit/client_test.go b/internal/reddit/client_test.go deleted file mode 100644 index 8e3b714..0000000 --- a/internal/reddit/client_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package reddit_test - -import ( - "bytes" - "context" - "io/ioutil" - "net/http" - "testing" - - "github.com/DataDog/datadog-go/statsd" - "github.com/christianselig/apollo-backend/internal/reddit" - "github.com/go-redis/redismock/v8" - "github.com/stretchr/testify/assert" -) - -// RoundTripFunc . -type RoundTripFunc func(req *http.Request) *http.Response - -// RoundTrip . -func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req), nil -} - -//NewTestClient returns *http.Client with Transport replaced to avoid making real calls -func NewTestClient(fn RoundTripFunc) *http.Client { - return &http.Client{Transport: fn} -} - -func TestErrorResponse(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - db, _ := redismock.NewClientMock() - - errortests := map[string]struct { - call func(*reddit.AuthenticatedClient) error - - status int - body string - err error - }{ - "/api/v1/me 500 returns ServerError": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 500, "", reddit.ServerError{500}}, - "/api/v1/access_token 400 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.RefreshTokens(ctx); return err }, 400, "", reddit.ErrOauthRevoked}, - "/api/v1/message/inbox 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageInbox(ctx); return err }, 403, "", reddit.ErrOauthRevoked}, - "/api/v1/message/unread 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageUnread(ctx); return err }, 403, "", reddit.ErrOauthRevoked}, - "/api/v1/me 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 403, "", reddit.ErrOauthRevoked}, - } - - for scenario, tt := range errortests { - tt := tt - - t.Run(scenario, func(t *testing.T) { - t.Parallel() - - tc := NewTestClient(func(req *http.Request) *http.Response { - return &http.Response{ - StatusCode: tt.status, - Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)), - Header: make(http.Header), - } - }) - - rc := reddit.NewClient("", "", &statsd.NoOpClient{}, db, 1, reddit.WithRetry(false), reddit.WithClient(tc)) - rac := rc.NewAuthenticatedClient(reddit.SkipRateLimiting, "", "") - - err := tt.call(rac) - - assert.ErrorIs(t, err, tt.err) - }) - } -} diff --git a/internal/reddit/errors.go b/internal/reddit/errors.go index d0a6959..01629b1 100644 --- a/internal/reddit/errors.go +++ b/internal/reddit/errors.go @@ -6,11 +6,12 @@ import ( ) type ServerError struct { + Body string StatusCode int } func (se ServerError) Error() string { - return fmt.Sprintf("error from reddit: %d", se.StatusCode) + return fmt.Sprintf("error from reddit: %d (%s)", se.StatusCode, se.Body) } var ( @@ -22,4 +23,6 @@ var ( ErrRateLimited = errors.New("rate limited") // ErrRequiresRedditId . ErrRequiresRedditId = errors.New("requires reddit id") + // ErrInvalidBasicAuth . + ErrInvalidBasicAuth = errors.New("invalid basic auth") ) diff --git a/internal/reddit/request.go b/internal/reddit/request.go index 230a7a7..02269da 100644 --- a/internal/reddit/request.go +++ b/internal/reddit/request.go @@ -59,6 +59,10 @@ func (r *Request) HTTPRequest(ctx context.Context) (*http.Request, error) { req.Header.Add("Accept", "application/json") req.Header.Add("User-Agent", userAgent) + if len(r.body) > 0 { + req.Header.Add("Content-Type", "multipart/form-data") + } + if r.token != "" { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.token)) }