test reddit

This commit is contained in:
Andre Medeiros 2021-08-14 13:42:28 -04:00
parent 9a5d699f66
commit d17151a3b3
8 changed files with 230 additions and 108 deletions

View file

@ -11,6 +11,8 @@ type Device struct {
type DeviceRepository interface { type DeviceRepository interface {
GetByAPNSToken(ctx context.Context, token string) (Device, error) GetByAPNSToken(ctx context.Context, token string) (Device, error)
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
CreateOrUpdate(ctx context.Context, dev *Device) error CreateOrUpdate(ctx context.Context, dev *Device) error
Update(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error
Create(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error

View file

@ -1,7 +1,6 @@
package reddit package reddit
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -19,7 +18,7 @@ type Client struct {
client *http.Client client *http.Client
tracer *httptrace.ClientTrace tracer *httptrace.ClientTrace
pool *fastjson.ParserPool pool *fastjson.ParserPool
statsd *statsd.Client statsd statsd.ClientInterface
} }
func SplitID(id string) (string, string) { func SplitID(id string) (string, string) {
@ -45,7 +44,7 @@ func PostIDFromContext(context string) string {
return "" return ""
} }
func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client { func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client {
tracer := &httptrace.ClientTrace{ tracer := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) { GotConn: func(info httptrace.GotConnInfo) {
if info.Reused { if info.Reused {
@ -127,9 +126,9 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
val, jerr := parser.ParseBytes(bb) val, jerr := parser.ParseBytes(bb)
if jerr != nil { if jerr != nil {
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode) return nil, ServerError{resp.StatusCode}
} }
return nil, NewError(val) return nil, NewError(val, resp.StatusCode)
} }
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes { if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
@ -159,6 +158,13 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
rtr, err := rac.request(req, NewRefreshTokenResponse, nil) rtr, err := rac.request(req, NewRefreshTokenResponse, nil)
if err != nil { if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 400 {
return nil, ErrOauthRevoked
}
}
return nil, err return nil, err
} }
@ -182,6 +188,13 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse) lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
if err != nil { if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 403 {
return nil, ErrOauthRevoked
}
}
return nil, err return nil, err
} }
return lr.(*ListingResponse), nil return lr.(*ListingResponse), nil
@ -200,6 +213,13 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse) lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
if err != nil { if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 403 {
return nil, ErrOauthRevoked
}
}
return nil, err return nil, err
} }
return lr.(*ListingResponse), nil return lr.(*ListingResponse), nil
@ -215,6 +235,13 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
mr, err := rac.request(req, NewMeResponse, nil) mr, err := rac.request(req, NewMeResponse, nil)
if err != nil { if err != nil {
switch rerr := err.(type) {
case ServerError:
if rerr.StatusCode == 403 {
return nil, ErrOauthRevoked
}
}
return nil, err return nil, err
} }
return mr.(*MeResponse), nil return mr.(*MeResponse), nil

View file

@ -0,0 +1,62 @@
package reddit
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/DataDog/datadog-go/statsd"
"github.com/stretchr/testify/assert"
)
// RoundTripFunc .
type RoundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
func NewTestClient(fn RoundTripFunc) *http.Client {
return &http.Client{
Transport: RoundTripFunc(fn),
}
}
func TestErrorResponse(t *testing.T) {
rc := NewClient("", "", &statsd.NoOpClient{}, 1)
rac := rc.NewAuthenticatedClient("", "")
errortests := []struct {
name string
call func() error
status int
body string
err error
}{
{"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(); return err }, 500, "", ServerError{500}},
{"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(); return err }, 400, "", ErrOauthRevoked},
{"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(); return err }, 403, "", ErrOauthRevoked},
{"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(); return err }, 403, "", ErrOauthRevoked},
{"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(); return err }, 403, "", ErrOauthRevoked},
}
for _, tt := range errortests {
t.Run(tt.name, func(t *testing.T) {
rac.client = NewTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: tt.status,
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
Header: make(http.Header),
}
})
err := tt.call()
assert.ErrorIs(t, err, tt.err)
})
}
}

19
internal/reddit/errors.go Normal file
View file

@ -0,0 +1,19 @@
package reddit
import (
"errors"
"fmt"
)
type ServerError struct {
StatusCode int
}
func (se ServerError) Error() string {
return fmt.Sprintf("errror from reddit: %d", se.StatusCode)
}
var (
// ErrOauthRevoked .
ErrOauthRevoked = errors.New("oauth revoked")
)

View file

@ -12,13 +12,14 @@ type ResponseHandler func(*fastjson.Value) interface{}
type Error struct { type Error struct {
Message string `json:"message"` Message string `json:"message"`
Code int `json:"error"` Code int `json:"error"`
StatusCode int
} }
func (err *Error) Error() string { func (err *Error) Error() string {
return fmt.Sprintf("%s (%d)", err.Message, err.Code) return fmt.Sprintf("%s (%d)", err.Message, err.Code)
} }
func NewError(val *fastjson.Value) *Error { func NewError(val *fastjson.Value, status int) *Error {
err := &Error{} err := &Error{}
err.Message = string(val.GetStringBytes("message")) err.Message = string(val.GetStringBytes("message"))

View file

@ -134,6 +134,8 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
res, err := p.pool.Exec( res, err := p.pool.Exec(
ctx, ctx,
query, query,
acc.ID,
acc.Username,
acc.AccountID, acc.AccountID,
acc.AccessToken, acc.AccessToken,
acc.RefreshToken, acc.RefreshToken,

View file

@ -57,6 +57,16 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str
return devs[0], nil return devs[0], nil
} }
func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, last_pinged_at
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1`
return p.fetch(ctx, query, id)
}
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
query := ` query := `
INSERT INTO devices (apns_token, sandbox, last_pinged_at) INSERT INTO devices (apns_token, sandbox, last_pinged_at)

View file

@ -10,15 +10,15 @@ import (
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4" "github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2" "github.com/sideshow/apns2"
"github.com/sideshow/apns2/payload" "github.com/sideshow/apns2/payload"
"github.com/sideshow/apns2/token" "github.com/sideshow/apns2/token"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data" "github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit" "github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
) )
const ( const (
@ -35,7 +35,11 @@ type notificationsWorker struct {
queue rmq.Connection queue rmq.Connection
reddit *reddit.Client reddit *reddit.Client
apns *token.Token apns *token.Token
consumers int consumers int
accountRepo domain.AccountRepository
deviceRepo domain.DeviceRepository
} }
func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
@ -69,6 +73,9 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
reddit, reddit,
apns, apns,
consumers, consumers,
repository.NewPostgresAccount(db),
repository.NewPostgresDevice(db),
} }
} }
@ -137,13 +144,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}() }()
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(), "account#id": delivery.Payload(),
}).Debug("starting job") }).Debug("starting job")
id, err := strconv.ParseInt(delivery.Payload(), 10, 64) id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil { if err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(), "account#id": delivery.Payload(),
"err": err, "err": err,
}).Error("failed to parse account ID") }).Error("failed to parse account ID")
@ -155,45 +162,21 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000 now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000
stmt := `SELECT account, err := nc.accountRepo.GetByID(ctx, id)
id, if err != nil {
username,
account_id,
access_token,
refresh_token,
expires_at,
last_message_id,
last_checked_at
FROM accounts
WHERE id = $1`
account := &data.Account{}
if err := nc.db.QueryRow(ctx, stmt, id).Scan(
&account.ID,
&account.Username,
&account.AccountID,
&account.AccessToken,
&account.RefreshToken,
&account.ExpiresAt,
&account.LastMessageID,
&account.LastCheckedAt,
); err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to fetch account from database") }).Error("failed to fetch account from database")
return return
} }
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { newAccount := (account.LastCheckedAt == 0)
stmt := ` account.LastCheckedAt = now
UPDATE accounts
SET last_checked_at = $1 if err = nc.accountRepo.Update(ctx, &account); err != nil {
WHERE id = $2`
_, err := tx.Exec(ctx, stmt, now, account.ID)
return err
}); err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to update last_checked_at for account") }).Error("failed to update last_checked_at for account")
return return
@ -202,18 +185,29 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
if account.ExpiresAt < int64(now) { if account.ExpiresAt < int64(now) {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
}).Debug("refreshing reddit token") }).Debug("refreshing reddit token")
tokens, err := rac.RefreshTokens() tokens, err := rac.RefreshTokens()
if err != nil { if err != nil {
if err != reddit.ErrOauthRevoked {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to refresh reddit tokens") }).Error("failed to refresh reddit tokens")
return return
} }
err = nc.deleteAccount(ctx, account)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to remove revoked account")
return
}
}
// Update account // Update account
account.AccessToken = tokens.AccessToken account.AccessToken = tokens.AccessToken
account.RefreshToken = tokens.RefreshToken account.RefreshToken = tokens.RefreshToken
@ -222,16 +216,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Refresh client // Refresh client
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken) rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { if err = nc.accountRepo.Update(ctx, &account); err != nil {
stmt := `
UPDATE accounts
SET access_token = $1, refresh_token = $2, expires_at = $3 WHERE id = $4`
_, err := tx.Exec(ctx, stmt, account.AccessToken, account.RefreshToken, account.ExpiresAt, account.ID)
return err
})
if err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to update reddit tokens for account") }).Error("failed to update reddit tokens for account")
return return
@ -240,13 +227,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Only update delay on accounts we can actually check, otherwise it skews // Only update delay on accounts we can actually check, otherwise it skews
// the numbers too much. // the numbers too much.
if account.LastCheckedAt > 0 { if !newAccount {
latency := now - account.LastCheckedAt - float64(backoff) latency := now - account.LastCheckedAt - float64(backoff)
nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate) nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
} }
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
}).Debug("fetching message inbox") }).Debug("fetching message inbox")
opts := []reddit.RequestOption{reddit.WithQuery("limit", "10")} opts := []reddit.RequestOption{reddit.WithQuery("limit", "10")}
@ -256,70 +243,66 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
msgs, err := rac.MessageInbox(opts...) msgs, err := rac.MessageInbox(opts...)
if err != nil { if err != nil {
if err != reddit.ErrOauthRevoked {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to fetch message inbox") }).Error("failed to fetch message inbox")
}
err = nc.deleteAccount(ctx, account)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to remove revoked account")
return
}
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
}).Info("removed revoked account")
return return
} }
// Figure out where we stand // Figure out where we stand
if msgs.Count == 0 { if msgs.Count == 0 {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
}).Debug("no new messages, bailing early") }).Debug("no new messages, bailing early")
return return
} }
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"count": msgs.Count, "count": msgs.Count,
}).Debug("fetched messages") }).Debug("fetched messages")
// Set latest message we alerted on account.LastMessageID = msgs.Children[0].FullName()
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` if err = nc.accountRepo.Update(ctx, &account); err != nil {
UPDATE accounts
SET last_message_id = $1
WHERE id = $2`
_, err := tx.Exec(ctx, stmt, msgs.Children[0].FullName(), account.ID)
return err
}); err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to update last_message_id for account") }).Error("failed to update last_message_id for account")
return return
} }
// Let's populate this with the latest message so we don't flood users with stuff // Let's populate this with the latest message so we don't flood users with stuff
if account.LastMessageID == "" && account.LastCheckedAt == 0 { if newAccount {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(), "account#username": account.NormalizedUsername(),
}).Debug("populating first message ID to prevent spamming") }).Debug("populating first message ID to prevent spamming")
return return
} }
devices := []data.Device{} devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
stmt = `
SELECT apns_token, sandbox
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1`
rows, err := nc.db.Query(ctx, stmt, account.ID)
if err != nil { if err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to fetch account devices") }).Error("failed to fetch account devices")
return return
} }
defer rows.Close()
for rows.Next() {
var device data.Device
rows.Scan(&device.APNSToken, &device.Sandbox)
devices = append(devices, device)
}
// Iterate backwards so we notify from older to newer // Iterate backwards so we notify from older to newer
for i := msgs.Count - 1; i >= 0; i-- { for i := msgs.Count - 1; i >= 0; i-- {
@ -359,11 +342,27 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
nc.statsd.SimpleEvent(ev, "") nc.statsd.SimpleEvent(ev, "")
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": delivery.Payload(), "account#username": account.NormalizedUsername(),
}).Debug("finishing job") }).Debug("finishing job")
} }
func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload { func (nc *notificationsConsumer) deleteAccount(ctx context.Context, account domain.Account) error {
// Disassociate account from devices
devs, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
if err != nil {
return err
}
for _, dev := range devs {
if err := nc.accountRepo.Disassociate(ctx, &account, &dev); err != nil {
return err
}
}
return nc.accountRepo.Delete(ctx, account.ID)
}
func payloadFromMessage(acct domain.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
postBody := msg.Body postBody := msg.Body
if len(postBody) > 2000 { if len(postBody) > 2000 {
postBody = msg.Body[:2000] postBody = msg.Body[:2000]