mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-22 03:37:43 +00:00
distinguish between 401 and 403
This commit is contained in:
parent
be926f9118
commit
5fa9bfaa45
4 changed files with 14 additions and 83 deletions
|
@ -165,15 +165,18 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit
|
|||
rli.Timestamp = time.Now().String()
|
||||
}
|
||||
|
||||
bb, err := ioutil.ReadAll(resp.Body)
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 200:
|
||||
bb, err := ioutil.ReadAll(resp.Body)
|
||||
return bb, rli, err
|
||||
case 401, 403:
|
||||
case 401:
|
||||
return nil, rli, ErrInvalidBasicAuth
|
||||
case 403:
|
||||
return nil, rli, ErrOauthRevoked
|
||||
default:
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
return nil, rli, ServerError{resp.StatusCode}
|
||||
return nil, rli, ServerError{string(bb), resp.StatusCode}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -299,13 +302,6 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque
|
|||
|
||||
rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil)
|
||||
if err != nil {
|
||||
switch rerr := err.(type) {
|
||||
case ServerError:
|
||||
if rerr.StatusCode == 400 {
|
||||
return nil, ErrOauthRevoked
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
package reddit_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/DataDog/datadog-go/statsd"
|
||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||
"github.com/go-redis/redismock/v8"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// RoundTripFunc .
|
||||
type RoundTripFunc func(req *http.Request) *http.Response
|
||||
|
||||
// RoundTrip .
|
||||
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req), nil
|
||||
}
|
||||
|
||||
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
|
||||
func NewTestClient(fn RoundTripFunc) *http.Client {
|
||||
return &http.Client{Transport: fn}
|
||||
}
|
||||
|
||||
func TestErrorResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
db, _ := redismock.NewClientMock()
|
||||
|
||||
errortests := map[string]struct {
|
||||
call func(*reddit.AuthenticatedClient) error
|
||||
|
||||
status int
|
||||
body string
|
||||
err error
|
||||
}{
|
||||
"/api/v1/me 500 returns ServerError": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 500, "", reddit.ServerError{500}},
|
||||
"/api/v1/access_token 400 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.RefreshTokens(ctx); return err }, 400, "", reddit.ErrOauthRevoked},
|
||||
"/api/v1/message/inbox 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageInbox(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
|
||||
"/api/v1/message/unread 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageUnread(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
|
||||
"/api/v1/me 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
|
||||
}
|
||||
|
||||
for scenario, tt := range errortests {
|
||||
tt := tt
|
||||
|
||||
t.Run(scenario, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tc := NewTestClient(func(req *http.Request) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: tt.status,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
})
|
||||
|
||||
rc := reddit.NewClient("", "", &statsd.NoOpClient{}, db, 1, reddit.WithRetry(false), reddit.WithClient(tc))
|
||||
rac := rc.NewAuthenticatedClient(reddit.SkipRateLimiting, "", "")
|
||||
|
||||
err := tt.call(rac)
|
||||
|
||||
assert.ErrorIs(t, err, tt.err)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -6,11 +6,12 @@ import (
|
|||
)
|
||||
|
||||
type ServerError struct {
|
||||
Body string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (se ServerError) Error() string {
|
||||
return fmt.Sprintf("error from reddit: %d", se.StatusCode)
|
||||
return fmt.Sprintf("error from reddit: %d (%s)", se.StatusCode, se.Body)
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -22,4 +23,6 @@ var (
|
|||
ErrRateLimited = errors.New("rate limited")
|
||||
// ErrRequiresRedditId .
|
||||
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||
// ErrInvalidBasicAuth .
|
||||
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
||||
)
|
||||
|
|
|
@ -59,6 +59,10 @@ func (r *Request) HTTPRequest(ctx context.Context) (*http.Request, error) {
|
|||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("User-Agent", userAgent)
|
||||
|
||||
if len(r.body) > 0 {
|
||||
req.Header.Add("Content-Type", "multipart/form-data")
|
||||
}
|
||||
|
||||
if r.token != "" {
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.token))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue