Merge pull request #19 from christianselig/chore/test-reddit

This commit is contained in:
André Medeiros 2021-08-14 13:43:53 -04:00 committed by GitHub
commit 15db5d8b7a
8 changed files with 230 additions and 108 deletions

View file

@ -11,6 +11,8 @@ type Device struct {
type DeviceRepository interface {
GetByAPNSToken(ctx context.Context, token string) (Device, error)
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
CreateOrUpdate(ctx context.Context, dev *Device) error
Update(ctx context.Context, dev *Device) error
Create(ctx context.Context, dev *Device) error

View file

@ -1,7 +1,6 @@
package reddit
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptrace"
@ -19,7 +18,7 @@ type Client struct {
client *http.Client
tracer *httptrace.ClientTrace
pool *fastjson.ParserPool
statsd *statsd.Client
statsd statsd.ClientInterface
}
func SplitID(id string) (string, string) {
@ -45,7 +44,7 @@ func PostIDFromContext(context string) string {
return ""
}
func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client {
func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client {
tracer := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
if info.Reused {
@ -127,9 +126,9 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
val, jerr := parser.ParseBytes(bb)
if jerr != nil {
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode)
return nil, ServerError{resp.StatusCode}
}
return nil, NewError(val)
return nil, NewError(val, resp.StatusCode)
}
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
@ -159,6 +158,13 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
rtr, err := rac.request(req, NewRefreshTokenResponse, nil)
if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 400 {
return nil, ErrOauthRevoked
}
}
return nil, err
}
@ -182,6 +188,13 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 403 {
return nil, ErrOauthRevoked
}
}
return nil, err
}
return lr.(*ListingResponse), nil
@ -200,6 +213,13 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 403 {
return nil, ErrOauthRevoked
}
}
return nil, err
}
return lr.(*ListingResponse), nil
@ -215,6 +235,13 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
mr, err := rac.request(req, NewMeResponse, nil)
if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 403 {
return nil, ErrOauthRevoked
}
}
return nil, err
}
return mr.(*MeResponse), nil

View file

@ -0,0 +1,62 @@
package reddit
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/DataDog/datadog-go/statsd"
"github.com/stretchr/testify/assert"
)
// RoundTripFunc .
type RoundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
func NewTestClient(fn RoundTripFunc) *http.Client {
return &http.Client{
Transport: RoundTripFunc(fn),
}
}
func TestErrorResponse(t *testing.T) {
rc := NewClient("", "", &statsd.NoOpClient{}, 1)
rac := rc.NewAuthenticatedClient("", "")
errortests := []struct {
name string
call func() error
status int
body string
err error
}{
{"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(); return err }, 500, "", ServerError{500}},
{"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(); return err }, 400, "", ErrOauthRevoked},
{"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(); return err }, 403, "", ErrOauthRevoked},
{"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(); return err }, 403, "", ErrOauthRevoked},
{"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(); return err }, 403, "", ErrOauthRevoked},
}
for _, tt := range errortests {
t.Run(tt.name, func(t *testing.T) {
rac.client = NewTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: tt.status,
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
Header: make(http.Header),
}
})
err := tt.call()
assert.ErrorIs(t, err, tt.err)
})
}
}

19
internal/reddit/errors.go Normal file
View file

@ -0,0 +1,19 @@
package reddit
import (
"errors"
"fmt"
)
type ServerError struct {
StatusCode int
}
func (se ServerError) Error() string {
return fmt.Sprintf("errror from reddit: %d", se.StatusCode)
}
var (
// ErrOauthRevoked .
ErrOauthRevoked = errors.New("oauth revoked")
)

View file

@ -12,13 +12,14 @@ type ResponseHandler func(*fastjson.Value) interface{}
type Error struct {
Message string `json:"message"`
Code int `json:"error"`
StatusCode int
}
func (err *Error) Error() string {
return fmt.Sprintf("%s (%d)", err.Message, err.Code)
}
func NewError(val *fastjson.Value) *Error {
func NewError(val *fastjson.Value, status int) *Error {
err := &Error{}
err.Message = string(val.GetStringBytes("message"))

View file

@ -134,6 +134,8 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
res, err := p.pool.Exec(
ctx,
query,
acc.ID,
acc.Username,
acc.AccountID,
acc.AccessToken,
acc.RefreshToken,

View file

@ -57,6 +57,16 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str
return devs[0], nil
}
func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, last_pinged_at
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1`
return p.fetch(ctx, query, id)
}
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
query := `
INSERT INTO devices (apns_token, sandbox, last_pinged_at)

View file

@ -10,15 +10,15 @@ import (
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/payload"
"github.com/sideshow/apns2/token"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
)
const (
@ -35,7 +35,11 @@ type notificationsWorker struct {
queue rmq.Connection
reddit *reddit.Client
apns *token.Token
consumers int
accountRepo domain.AccountRepository
deviceRepo domain.DeviceRepository
}
func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
@ -69,6 +73,9 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
reddit,
apns,
consumers,
repository.NewPostgresAccount(db),
repository.NewPostgresDevice(db),
}
}
@ -137,13 +144,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}()
nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(),
"account#id": delivery.Payload(),
}).Debug("starting job")
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(),
"account#id": delivery.Payload(),
"err": err,
}).Error("failed to parse account ID")
@ -155,45 +162,21 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000
stmt := `SELECT
id,
username,
account_id,
access_token,
refresh_token,
expires_at,
last_message_id,
last_checked_at
FROM accounts
WHERE id = $1`
account := &data.Account{}
if err := nc.db.QueryRow(ctx, stmt, id).Scan(
&account.ID,
&account.Username,
&account.AccountID,
&account.AccessToken,
&account.RefreshToken,
&account.ExpiresAt,
&account.LastMessageID,
&account.LastCheckedAt,
); err != nil {
account, err := nc.accountRepo.GetByID(ctx, id)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to fetch account from database")
return
}
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
UPDATE accounts
SET last_checked_at = $1
WHERE id = $2`
_, err := tx.Exec(ctx, stmt, now, account.ID)
return err
}); err != nil {
newAccount := (account.LastCheckedAt == 0)
account.LastCheckedAt = now
if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to update last_checked_at for account")
return
@ -202,18 +185,29 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
if account.ExpiresAt < int64(now) {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
}).Debug("refreshing reddit token")
tokens, err := rac.RefreshTokens()
if err != nil {
if err != reddit.ErrOauthRevoked {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to refresh reddit tokens")
return
}
err = nc.deleteAccount(ctx, account)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to remove revoked account")
return
}
}
// Update account
account.AccessToken = tokens.AccessToken
account.RefreshToken = tokens.RefreshToken
@ -222,16 +216,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Refresh client
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
UPDATE accounts
SET access_token = $1, refresh_token = $2, expires_at = $3 WHERE id = $4`
_, err := tx.Exec(ctx, stmt, account.AccessToken, account.RefreshToken, account.ExpiresAt, account.ID)
return err
})
if err != nil {
if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to update reddit tokens for account")
return
@ -240,13 +227,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Only update delay on accounts we can actually check, otherwise it skews
// the numbers too much.
if account.LastCheckedAt > 0 {
if !newAccount {
latency := now - account.LastCheckedAt - float64(backoff)
nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
}
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
}).Debug("fetching message inbox")
opts := []reddit.RequestOption{reddit.WithQuery("limit", "10")}
@ -256,70 +243,66 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
msgs, err := rac.MessageInbox(opts...)
if err != nil {
if err != reddit.ErrOauthRevoked {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to fetch message inbox")
}
err = nc.deleteAccount(ctx, account)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to remove revoked account")
return
}
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
}).Info("removed revoked account")
return
}
// Figure out where we stand
if msgs.Count == 0 {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
}).Debug("no new messages, bailing early")
return
}
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"count": msgs.Count,
}).Debug("fetched messages")
// Set latest message we alerted on
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
UPDATE accounts
SET last_message_id = $1
WHERE id = $2`
_, err := tx.Exec(ctx, stmt, msgs.Children[0].FullName(), account.ID)
return err
}); err != nil {
account.LastMessageID = msgs.Children[0].FullName()
if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to update last_message_id for account")
return
}
// Let's populate this with the latest message so we don't flood users with stuff
if account.LastMessageID == "" && account.LastCheckedAt == 0 {
if newAccount {
nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(),
"account#username": account.NormalizedUsername(),
}).Debug("populating first message ID to prevent spamming")
return
}
devices := []data.Device{}
stmt = `
SELECT apns_token, sandbox
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1`
rows, err := nc.db.Query(ctx, stmt, account.ID)
devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to fetch account devices")
return
}
defer rows.Close()
for rows.Next() {
var device data.Device
rows.Scan(&device.APNSToken, &device.Sandbox)
devices = append(devices, device)
}
// Iterate backwards so we notify from older to newer
for i := msgs.Count - 1; i >= 0; i-- {
@ -359,11 +342,27 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
nc.statsd.SimpleEvent(ev, "")
nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(),
"account#username": account.NormalizedUsername(),
}).Debug("finishing job")
}
func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
func (nc *notificationsConsumer) deleteAccount(ctx context.Context, account domain.Account) error {
// Disassociate account from devices
devs, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
if err != nil {
return err
}
for _, dev := range devs {
if err := nc.accountRepo.Disassociate(ctx, &account, &dev); err != nil {
return err
}
}
return nc.accountRepo.Delete(ctx, account.ID)
}
func payloadFromMessage(acct domain.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
postBody := msg.Body
if len(postBody) > 2000 {
postBody = msg.Body[:2000]