mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-22 03:37:43 +00:00
distinguish between 401 and 403
This commit is contained in:
parent
be926f9118
commit
5fa9bfaa45
4 changed files with 14 additions and 83 deletions
|
@ -165,15 +165,18 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit
|
||||||
rli.Timestamp = time.Now().String()
|
rli.Timestamp = time.Now().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bb, err := ioutil.ReadAll(resp.Body)
|
||||||
|
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case 200:
|
case 200:
|
||||||
bb, err := ioutil.ReadAll(resp.Body)
|
|
||||||
return bb, rli, err
|
return bb, rli, err
|
||||||
case 401, 403:
|
case 401:
|
||||||
|
return nil, rli, ErrInvalidBasicAuth
|
||||||
|
case 403:
|
||||||
return nil, rli, ErrOauthRevoked
|
return nil, rli, ErrOauthRevoked
|
||||||
default:
|
default:
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
return nil, rli, ServerError{resp.StatusCode}
|
return nil, rli, ServerError{string(bb), resp.StatusCode}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,13 +302,6 @@ func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...Reque
|
||||||
|
|
||||||
rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil)
|
rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch rerr := err.(type) {
|
|
||||||
case ServerError:
|
|
||||||
if rerr.StatusCode == 400 {
|
|
||||||
return nil, ErrOauthRevoked
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
package reddit_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
|
||||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
|
||||||
"github.com/go-redis/redismock/v8"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RoundTripFunc .
|
|
||||||
type RoundTripFunc func(req *http.Request) *http.Response
|
|
||||||
|
|
||||||
// RoundTrip .
|
|
||||||
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
return f(req), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
|
|
||||||
func NewTestClient(fn RoundTripFunc) *http.Client {
|
|
||||||
return &http.Client{Transport: fn}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorResponse(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
db, _ := redismock.NewClientMock()
|
|
||||||
|
|
||||||
errortests := map[string]struct {
|
|
||||||
call func(*reddit.AuthenticatedClient) error
|
|
||||||
|
|
||||||
status int
|
|
||||||
body string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"/api/v1/me 500 returns ServerError": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 500, "", reddit.ServerError{500}},
|
|
||||||
"/api/v1/access_token 400 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.RefreshTokens(ctx); return err }, 400, "", reddit.ErrOauthRevoked},
|
|
||||||
"/api/v1/message/inbox 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageInbox(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
|
|
||||||
"/api/v1/message/unread 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageUnread(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
|
|
||||||
"/api/v1/me 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
|
|
||||||
}
|
|
||||||
|
|
||||||
for scenario, tt := range errortests {
|
|
||||||
tt := tt
|
|
||||||
|
|
||||||
t.Run(scenario, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tc := NewTestClient(func(req *http.Request) *http.Response {
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: tt.status,
|
|
||||||
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
|
|
||||||
Header: make(http.Header),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
rc := reddit.NewClient("", "", &statsd.NoOpClient{}, db, 1, reddit.WithRetry(false), reddit.WithClient(tc))
|
|
||||||
rac := rc.NewAuthenticatedClient(reddit.SkipRateLimiting, "", "")
|
|
||||||
|
|
||||||
err := tt.call(rac)
|
|
||||||
|
|
||||||
assert.ErrorIs(t, err, tt.err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -6,11 +6,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServerError struct {
|
type ServerError struct {
|
||||||
|
Body string
|
||||||
StatusCode int
|
StatusCode int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (se ServerError) Error() string {
|
func (se ServerError) Error() string {
|
||||||
return fmt.Sprintf("error from reddit: %d", se.StatusCode)
|
return fmt.Sprintf("error from reddit: %d (%s)", se.StatusCode, se.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -22,4 +23,6 @@ var (
|
||||||
ErrRateLimited = errors.New("rate limited")
|
ErrRateLimited = errors.New("rate limited")
|
||||||
// ErrRequiresRedditId .
|
// ErrRequiresRedditId .
|
||||||
ErrRequiresRedditId = errors.New("requires reddit id")
|
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||||
|
// ErrInvalidBasicAuth .
|
||||||
|
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
||||||
)
|
)
|
||||||
|
|
|
@ -59,6 +59,10 @@ func (r *Request) HTTPRequest(ctx context.Context) (*http.Request, error) {
|
||||||
req.Header.Add("Accept", "application/json")
|
req.Header.Add("Accept", "application/json")
|
||||||
req.Header.Add("User-Agent", userAgent)
|
req.Header.Add("User-Agent", userAgent)
|
||||||
|
|
||||||
|
if len(r.body) > 0 {
|
||||||
|
req.Header.Add("Content-Type", "multipart/form-data")
|
||||||
|
}
|
||||||
|
|
||||||
if r.token != "" {
|
if r.token != "" {
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.token))
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.token))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue