From 294243b02d9f732df03fbbd961e8c3de79b2616f Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Sat, 26 Mar 2022 13:40:51 -0400 Subject: [PATCH] fix ci --- .github/workflows/test.yml | 2 +- internal/reddit/client.go | 77 ++++++++++++++++---------- internal/reddit/client_test.go | 12 ++-- internal/repository/postgres_device.go | 16 ++---- 4 files changed, 59 insertions(+), 48 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4fbb879..a70e84d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,4 +17,4 @@ jobs: - name: Lint uses: golangci/golangci-lint-action@v2 - name: Test - run: go test ./... -v + run: go test ./... -v -race -timeout 5s diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 4bdfb5b..6d4d023 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -26,13 +26,14 @@ const ( ) type Client struct { - id string - secret string - client *http.Client - tracer *httptrace.ClientTrace - pool *fastjson.ParserPool - statsd statsd.ClientInterface - redis *redis.Client + id string + secret string + client *http.Client + tracer *httptrace.ClientTrace + pool *fastjson.ParserPool + statsd statsd.ClientInterface + redis *redis.Client + defaultOpts []RequestOption } type RateLimitingInfo struct { @@ -72,7 +73,7 @@ func PostIDFromContext(context string) string { return "" } -func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int) *Client { +func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int, opts ...RequestOption) *Client { tracer := &httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { if info.Reused { @@ -106,6 +107,7 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Cl pool, statsd, redis, + opts, } } @@ -176,7 +178,10 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in return nil, ErrRateLimited } - rac.logRequest() + if err := rac.logRequest(); err != nil { + return nil, err + } + bb, rli, err := rac.doRequest(r) if err != nil && r.retry { @@ -185,7 +190,12 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in time.AfterFunc(backoff, func() { _ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1) - rac.logRequest() + + if err = rac.logRequest(); err != nil { + done <- struct{}{} + return + } + bb, rli, err = rac.doRequest(r) done <- struct{}{} }) @@ -205,7 +215,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in } return nil, err } else { - rac.markRateLimited(rli) + _ = rac.markRateLimited(rli) } if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes { @@ -272,14 +282,15 @@ func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error { } func (rac *AuthenticatedClient) RefreshTokens(opts ...RequestOption) (*RefreshTokenResponse, error) { - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/access_token"}), WithMethod("POST"), WithURL("https://www.reddit.com/api/v1/access_token"), WithBody("grant_type", "refresh_token"), WithBody("refresh_token", rac.refreshToken), WithBasicAuth(rac.id, rac.secret), - }, opts...) + }...) req := NewRequest(opts...) rtr, err := rac.request(req, NewRefreshTokenResponse, nil) @@ -303,12 +314,13 @@ func (rac *AuthenticatedClient) RefreshTokens(opts ...RequestOption) (*RefreshTo } func (rac *AuthenticatedClient) AboutInfo(fullname string, opts ...RequestOption) (*ListingResponse, error) { - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithMethod("GET"), WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/api/info"), WithQuery("id", fullname), - }, opts...) + }...) req := NewRequest(opts...) lr, err := rac.request(req, NewListingResponse, nil) @@ -321,11 +333,12 @@ func (rac *AuthenticatedClient) AboutInfo(fullname string, opts ...RequestOption func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*ListingResponse, error) { url := fmt.Sprintf("https://oauth.reddit.com/u/%s/submitted", user) - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithMethod("GET"), WithToken(rac.accessToken), WithURL(url), - }, opts...) + }...) req := NewRequest(opts...) lr, err := rac.request(req, NewListingResponse, nil) @@ -338,11 +351,12 @@ func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (* func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*UserResponse, error) { url := fmt.Sprintf("https://oauth.reddit.com/u/%s/about", user) - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithMethod("GET"), WithToken(rac.accessToken), WithURL(url), - }, opts...) + }...) req := NewRequest(opts...) ur, err := rac.request(req, NewUserResponse, nil) @@ -356,11 +370,12 @@ func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (* func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) { url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about", subreddit) - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithMethod("GET"), WithToken(rac.accessToken), WithURL(url), - }, opts...) + }...) req := NewRequest(opts...) sr, err := rac.request(req, NewSubredditResponse, nil) @@ -373,11 +388,12 @@ func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...Request func (rac *AuthenticatedClient) subredditPosts(subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) { url := fmt.Sprintf("https://oauth.reddit.com/r/%s/%s", subreddit, sort) - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithMethod("GET"), WithToken(rac.accessToken), WithURL(url), - }, opts...) + }...) req := NewRequest(opts...) lr, err := rac.request(req, NewListingResponse, nil) @@ -401,13 +417,14 @@ func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOp } func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) { - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/message/inbox"}), WithMethod("GET"), WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/message/inbox"), WithEmptyResponseBytes(122), - }, opts...) + }...) req := NewRequest(opts...) lr, err := rac.request(req, NewListingResponse, EmptyListingResponse) @@ -425,13 +442,14 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes } func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) { - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/message/unread"}), WithMethod("GET"), WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/message/unread"), WithEmptyResponseBytes(122), - }, opts...) + }...) req := NewRequest(opts...) @@ -450,12 +468,13 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe } func (rac *AuthenticatedClient) Me(opts ...RequestOption) (*MeResponse, error) { - opts = append([]RequestOption{ + opts = append(rac.defaultOpts, opts...) + opts = append(opts, []RequestOption{ WithTags([]string{"url:/api/v1/me"}), WithMethod("GET"), WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/api/v1/me"), - }, opts...) + }...) req := NewRequest(opts...) mr, err := rac.request(req, NewMeResponse, nil) diff --git a/internal/reddit/client_test.go b/internal/reddit/client_test.go index 82a1330..e68598a 100644 --- a/internal/reddit/client_test.go +++ b/internal/reddit/client_test.go @@ -29,7 +29,7 @@ func NewTestClient(fn RoundTripFunc) *http.Client { func TestErrorResponse(t *testing.T) { db, _ := redismock.NewClientMock() - rc := NewClient("", "", &statsd.NoOpClient{}, db, 1) + rc := NewClient("", "", &statsd.NoOpClient{}, db, 1, WithRetry(false)) rac := rc.NewAuthenticatedClient(SkipRateLimiting, "", "") errortests := []struct { @@ -40,11 +40,11 @@ func TestErrorResponse(t *testing.T) { body string err error }{ - {"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(WithRetry(false)); return err }, 500, "", ServerError{500}}, - {"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(WithRetry(false)); return err }, 400, "", ErrOauthRevoked}, - {"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(WithRetry(false)); return err }, 403, "", ErrOauthRevoked}, - {"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(WithRetry(false)); return err }, 403, "", ErrOauthRevoked}, - {"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(WithRetry(false)); return err }, 403, "", ErrOauthRevoked}, + {"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(); return err }, 500, "", ServerError{500}}, + {"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(); return err }, 400, "", ErrOauthRevoked}, + {"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(); return err }, 403, "", ErrOauthRevoked}, + {"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(); return err }, 403, "", ErrOauthRevoked}, + {"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(); return err }, 403, "", ErrOauthRevoked}, } for _, tt := range errortests { diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go index a6e6249..c4bf709 100644 --- a/internal/repository/postgres_device.go +++ b/internal/repository/postgres_device.go @@ -201,20 +201,12 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai FROM devices_accounts WHERE device_id = $1 AND account_id = $2` - rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID) - if err != nil { - return false, false, false, err - } - defer rows.Close() - for rows.Next() { - var inbox, watcher, global bool - if err := rows.Scan(&inbox, &watcher, &global); err != nil { - return false, false, false, err - } - return inbox, watcher, global, nil + var inbox, watcher, global bool + if err := p.pool.QueryRow(ctx, query, dev.ID, acct.ID).Scan(&inbox, &watcher, &global); err != nil { + return false, false, false, domain.ErrNotFound } - return false, false, false, domain.ErrNotFound + return inbox, watcher, global, nil } func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {