From d17151a3b3e9a679df2f9296635b45edb0e5441a Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Sat, 14 Aug 2021 13:42:28 -0400 Subject: [PATCH] test reddit --- internal/domain/device.go | 2 + internal/reddit/client.go | 37 ++++- internal/reddit/client_test.go | 62 ++++++++ internal/reddit/errors.go | 19 +++ internal/reddit/types.go | 7 +- internal/repository/postgres_account.go | 2 + internal/repository/postgres_device.go | 10 ++ internal/worker/notifications.go | 199 ++++++++++++------------ 8 files changed, 230 insertions(+), 108 deletions(-) create mode 100644 internal/reddit/client_test.go create mode 100644 internal/reddit/errors.go diff --git a/internal/domain/device.go b/internal/domain/device.go index 40deeeb..51cc979 100644 --- a/internal/domain/device.go +++ b/internal/domain/device.go @@ -11,6 +11,8 @@ type Device struct { type DeviceRepository interface { GetByAPNSToken(ctx context.Context, token string) (Device, error) + GetByAccountID(ctx context.Context, id int64) ([]Device, error) + CreateOrUpdate(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error diff --git a/internal/reddit/client.go b/internal/reddit/client.go index b88b9ca..0b9938a 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -1,7 +1,6 @@ package reddit import ( - "fmt" "io/ioutil" "net/http" "net/http/httptrace" @@ -19,7 +18,7 @@ type Client struct { client *http.Client tracer *httptrace.ClientTrace pool *fastjson.ParserPool - statsd *statsd.Client + statsd statsd.ClientInterface } func SplitID(id string) (string, string) { @@ -45,7 +44,7 @@ func PostIDFromContext(context string) string { return "" } -func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client { +func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client { tracer := &httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { if info.Reused { @@ -127,9 +126,9 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in val, jerr := parser.ParseBytes(bb) if jerr != nil { - return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode) + return nil, ServerError{resp.StatusCode} } - return nil, NewError(val) + return nil, NewError(val, resp.StatusCode) } if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes { @@ -159,6 +158,13 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { rtr, err := rac.request(req, NewRefreshTokenResponse, nil) if err != nil { + switch rerr := err.(type) { + case ServerError: + if rerr.StatusCode == 400 { + return nil, ErrOauthRevoked + } + } + return nil, err } @@ -182,6 +188,13 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes lr, err := rac.request(req, NewListingResponse, EmptyListingResponse) if err != nil { + switch rerr := err.(type) { + case ServerError: + if rerr.StatusCode == 403 { + return nil, ErrOauthRevoked + } + } + return nil, err } return lr.(*ListingResponse), nil @@ -200,6 +213,13 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe lr, err := rac.request(req, NewListingResponse, EmptyListingResponse) if err != nil { + switch rerr := err.(type) { + case ServerError: + if rerr.StatusCode == 403 { + return nil, ErrOauthRevoked + } + } + return nil, err } return lr.(*ListingResponse), nil @@ -215,6 +235,13 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) { mr, err := rac.request(req, NewMeResponse, nil) if err != nil { + switch rerr := err.(type) { + case ServerError: + if rerr.StatusCode == 403 { + return nil, ErrOauthRevoked + } + } + return nil, err } return mr.(*MeResponse), nil diff --git a/internal/reddit/client_test.go b/internal/reddit/client_test.go new file mode 100644 index 0000000..76de88f --- /dev/null +++ b/internal/reddit/client_test.go @@ -0,0 +1,62 @@ +package reddit + +import ( + "bytes" + "io/ioutil" + "net/http" + "testing" + + "github.com/DataDog/datadog-go/statsd" + "github.com/stretchr/testify/assert" +) + +// RoundTripFunc . +type RoundTripFunc func(req *http.Request) *http.Response + +// RoundTrip . +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +//NewTestClient returns *http.Client with Transport replaced to avoid making real calls +func NewTestClient(fn RoundTripFunc) *http.Client { + return &http.Client{ + Transport: RoundTripFunc(fn), + } +} + +func TestErrorResponse(t *testing.T) { + rc := NewClient("", "", &statsd.NoOpClient{}, 1) + rac := rc.NewAuthenticatedClient("", "") + + errortests := []struct { + name string + call func() error + + status int + body string + err error + }{ + {"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(); return err }, 500, "", ServerError{500}}, + {"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(); return err }, 400, "", ErrOauthRevoked}, + {"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(); return err }, 403, "", ErrOauthRevoked}, + {"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(); return err }, 403, "", ErrOauthRevoked}, + {"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(); return err }, 403, "", ErrOauthRevoked}, + } + + for _, tt := range errortests { + t.Run(tt.name, func(t *testing.T) { + rac.client = NewTestClient(func(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: tt.status, + Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)), + Header: make(http.Header), + } + }) + + err := tt.call() + + assert.ErrorIs(t, err, tt.err) + }) + } +} diff --git a/internal/reddit/errors.go b/internal/reddit/errors.go new file mode 100644 index 0000000..8a6b01e --- /dev/null +++ b/internal/reddit/errors.go @@ -0,0 +1,19 @@ +package reddit + +import ( + "errors" + "fmt" +) + +type ServerError struct { + StatusCode int +} + +func (se ServerError) Error() string { + return fmt.Sprintf("errror from reddit: %d", se.StatusCode) +} + +var ( + // ErrOauthRevoked . + ErrOauthRevoked = errors.New("oauth revoked") +) diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 123ca21..6109149 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -10,15 +10,16 @@ import ( type ResponseHandler func(*fastjson.Value) interface{} type Error struct { - Message string `json:"message"` - Code int `json:"error"` + Message string `json:"message"` + Code int `json:"error"` + StatusCode int } func (err *Error) Error() string { return fmt.Sprintf("%s (%d)", err.Message, err.Code) } -func NewError(val *fastjson.Value) *Error { +func NewError(val *fastjson.Value, status int) *Error { err := &Error{} err.Message = string(val.GetStringBytes("message")) diff --git a/internal/repository/postgres_account.go b/internal/repository/postgres_account.go index 1694c86..03c636a 100644 --- a/internal/repository/postgres_account.go +++ b/internal/repository/postgres_account.go @@ -134,6 +134,8 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco res, err := p.pool.Exec( ctx, query, + acc.ID, + acc.Username, acc.AccountID, acc.AccessToken, acc.RefreshToken, diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go index 637023c..2adb0ce 100644 --- a/internal/repository/postgres_device.go +++ b/internal/repository/postgres_device.go @@ -57,6 +57,16 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str return devs[0], nil } +func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { + query := ` + SELECT devices.id, apns_token, sandbox, last_pinged_at + FROM devices + INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id + WHERE devices_accounts.account_id = $1` + + return p.fetch(ctx, query, id) +} + func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { query := ` INSERT INTO devices (apns_token, sandbox, last_pinged_at) diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index a3b55ff..f505b16 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -10,15 +10,15 @@ import ( "github.com/DataDog/datadog-go/statsd" "github.com/adjust/rmq/v4" "github.com/go-redis/redis/v8" - "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/sideshow/apns2" "github.com/sideshow/apns2/payload" "github.com/sideshow/apns2/token" "github.com/sirupsen/logrus" - "github.com/christianselig/apollo-backend/internal/data" + "github.com/christianselig/apollo-backend/internal/domain" "github.com/christianselig/apollo-backend/internal/reddit" + "github.com/christianselig/apollo-backend/internal/repository" ) const ( @@ -28,14 +28,18 @@ const ( ) type notificationsWorker struct { - logger *logrus.Logger - statsd *statsd.Client - db *pgxpool.Pool - redis *redis.Client - queue rmq.Connection - reddit *reddit.Client - apns *token.Token + logger *logrus.Logger + statsd *statsd.Client + db *pgxpool.Pool + redis *redis.Client + queue rmq.Connection + reddit *reddit.Client + apns *token.Token + consumers int + + accountRepo domain.AccountRepository + deviceRepo domain.DeviceRepository } func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { @@ -69,6 +73,9 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg reddit, apns, consumers, + + repository.NewPostgresAccount(db), + repository.NewPostgresDevice(db), } } @@ -137,14 +144,14 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { }() nc.logger.WithFields(logrus.Fields{ - "accountID": delivery.Payload(), + "account#id": delivery.Payload(), }).Debug("starting job") id, err := strconv.ParseInt(delivery.Payload(), 10, 64) if err != nil { nc.logger.WithFields(logrus.Fields{ - "accountID": delivery.Payload(), - "err": err, + "account#id": delivery.Payload(), + "err": err, }).Error("failed to parse account ID") delivery.Reject() @@ -155,46 +162,22 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000 - stmt := `SELECT - id, - username, - account_id, - access_token, - refresh_token, - expires_at, - last_message_id, - last_checked_at - FROM accounts - WHERE id = $1` - account := &data.Account{} - if err := nc.db.QueryRow(ctx, stmt, id).Scan( - &account.ID, - &account.Username, - &account.AccountID, - &account.AccessToken, - &account.RefreshToken, - &account.ExpiresAt, - &account.LastMessageID, - &account.LastCheckedAt, - ); err != nil { + account, err := nc.accountRepo.GetByID(ctx, id) + if err != nil { nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, + "account#username": account.NormalizedUsername(), + "err": err, }).Error("failed to fetch account from database") return } - if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := ` - UPDATE accounts - SET last_checked_at = $1 - WHERE id = $2` - _, err := tx.Exec(ctx, stmt, now, account.ID) - return err - }); err != nil { + newAccount := (account.LastCheckedAt == 0) + account.LastCheckedAt = now + + if err = nc.accountRepo.Update(ctx, &account); err != nil { nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, + "account#username": account.NormalizedUsername(), + "err": err, }).Error("failed to update last_checked_at for account") return } @@ -202,16 +185,27 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) if account.ExpiresAt < int64(now) { nc.logger.WithFields(logrus.Fields{ - "accountID": id, + "account#username": account.NormalizedUsername(), }).Debug("refreshing reddit token") tokens, err := rac.RefreshTokens() if err != nil { - nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, - }).Error("failed to refresh reddit tokens") - return + if err != reddit.ErrOauthRevoked { + nc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "err": err, + }).Error("failed to refresh reddit tokens") + return + } + + err = nc.deleteAccount(ctx, account) + if err != nil { + nc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "err": err, + }).Error("failed to remove revoked account") + return + } } // Update account @@ -222,17 +216,10 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { // Refresh client rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken) - err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := ` - UPDATE accounts - SET access_token = $1, refresh_token = $2, expires_at = $3 WHERE id = $4` - _, err := tx.Exec(ctx, stmt, account.AccessToken, account.RefreshToken, account.ExpiresAt, account.ID) - return err - }) - if err != nil { + if err = nc.accountRepo.Update(ctx, &account); err != nil { nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, + "account#username": account.NormalizedUsername(), + "err": err, }).Error("failed to update reddit tokens for account") return } @@ -240,13 +227,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { // Only update delay on accounts we can actually check, otherwise it skews // the numbers too much. - if account.LastCheckedAt > 0 { + if !newAccount { latency := now - account.LastCheckedAt - float64(backoff) nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate) } nc.logger.WithFields(logrus.Fields{ - "accountID": id, + "account#username": account.NormalizedUsername(), }).Debug("fetching message inbox") opts := []reddit.RequestOption{reddit.WithQuery("limit", "10")} @@ -256,70 +243,66 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { msgs, err := rac.MessageInbox(opts...) if err != nil { + if err != reddit.ErrOauthRevoked { + nc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "err": err, + }).Error("failed to fetch message inbox") + } + + err = nc.deleteAccount(ctx, account) + if err != nil { + nc.logger.WithFields(logrus.Fields{ + "account#username": account.NormalizedUsername(), + "err": err, + }).Error("failed to remove revoked account") + return + } nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, - }).Error("failed to fetch message inbox") + "account#username": account.NormalizedUsername(), + }).Info("removed revoked account") return } // Figure out where we stand if msgs.Count == 0 { nc.logger.WithFields(logrus.Fields{ - "accountID": id, + "account#username": account.NormalizedUsername(), }).Debug("no new messages, bailing early") return } nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "count": msgs.Count, + "account#username": account.NormalizedUsername(), + "count": msgs.Count, }).Debug("fetched messages") - // Set latest message we alerted on - if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := ` - UPDATE accounts - SET last_message_id = $1 - WHERE id = $2` - _, err := tx.Exec(ctx, stmt, msgs.Children[0].FullName(), account.ID) - return err - }); err != nil { + account.LastMessageID = msgs.Children[0].FullName() + + if err = nc.accountRepo.Update(ctx, &account); err != nil { nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, + "account#username": account.NormalizedUsername(), + "err": err, }).Error("failed to update last_message_id for account") return } // Let's populate this with the latest message so we don't flood users with stuff - if account.LastMessageID == "" && account.LastCheckedAt == 0 { + if newAccount { nc.logger.WithFields(logrus.Fields{ - "accountID": delivery.Payload(), + "account#username": account.NormalizedUsername(), }).Debug("populating first message ID to prevent spamming") return } - devices := []data.Device{} - stmt = ` - SELECT apns_token, sandbox - FROM devices - INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id - WHERE devices_accounts.account_id = $1` - rows, err := nc.db.Query(ctx, stmt, account.ID) + devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID) if err != nil { nc.logger.WithFields(logrus.Fields{ - "accountID": id, - "err": err, + "account#username": account.NormalizedUsername(), + "err": err, }).Error("failed to fetch account devices") return } - defer rows.Close() - for rows.Next() { - var device data.Device - rows.Scan(&device.APNSToken, &device.Sandbox) - devices = append(devices, device) - } // Iterate backwards so we notify from older to newer for i := msgs.Count - 1; i >= 0; i-- { @@ -359,11 +342,27 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { nc.statsd.SimpleEvent(ev, "") nc.logger.WithFields(logrus.Fields{ - "accountID": delivery.Payload(), + "account#username": account.NormalizedUsername(), }).Debug("finishing job") } -func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload { +func (nc *notificationsConsumer) deleteAccount(ctx context.Context, account domain.Account) error { + // Disassociate account from devices + devs, err := nc.deviceRepo.GetByAccountID(ctx, account.ID) + if err != nil { + return err + } + + for _, dev := range devs { + if err := nc.accountRepo.Disassociate(ctx, &account, &dev); err != nil { + return err + } + } + + return nc.accountRepo.Delete(ctx, account.ID) +} + +func payloadFromMessage(acct domain.Account, msg *reddit.Thing, badgeCount int) *payload.Payload { postBody := msg.Body if len(postBody) > 2000 { postBody = msg.Body[:2000]