mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-10 22:17:44 +00:00
test reddit
This commit is contained in:
parent
9a5d699f66
commit
d17151a3b3
8 changed files with 230 additions and 108 deletions
|
@ -11,6 +11,8 @@ type Device struct {
|
||||||
|
|
||||||
type DeviceRepository interface {
|
type DeviceRepository interface {
|
||||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||||
|
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
|
|
||||||
CreateOrUpdate(ctx context.Context, dev *Device) error
|
CreateOrUpdate(ctx context.Context, dev *Device) error
|
||||||
Update(ctx context.Context, dev *Device) error
|
Update(ctx context.Context, dev *Device) error
|
||||||
Create(ctx context.Context, dev *Device) error
|
Create(ctx context.Context, dev *Device) error
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package reddit
|
package reddit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
|
@ -19,7 +18,7 @@ type Client struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
tracer *httptrace.ClientTrace
|
tracer *httptrace.ClientTrace
|
||||||
pool *fastjson.ParserPool
|
pool *fastjson.ParserPool
|
||||||
statsd *statsd.Client
|
statsd statsd.ClientInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
func SplitID(id string) (string, string) {
|
func SplitID(id string) (string, string) {
|
||||||
|
@ -45,7 +44,7 @@ func PostIDFromContext(context string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client {
|
func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client {
|
||||||
tracer := &httptrace.ClientTrace{
|
tracer := &httptrace.ClientTrace{
|
||||||
GotConn: func(info httptrace.GotConnInfo) {
|
GotConn: func(info httptrace.GotConnInfo) {
|
||||||
if info.Reused {
|
if info.Reused {
|
||||||
|
@ -127,9 +126,9 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
||||||
|
|
||||||
val, jerr := parser.ParseBytes(bb)
|
val, jerr := parser.ParseBytes(bb)
|
||||||
if jerr != nil {
|
if jerr != nil {
|
||||||
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode)
|
return nil, ServerError{resp.StatusCode}
|
||||||
}
|
}
|
||||||
return nil, NewError(val)
|
return nil, NewError(val, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
|
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
|
||||||
|
@ -159,6 +158,13 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||||
|
|
||||||
rtr, err := rac.request(req, NewRefreshTokenResponse, nil)
|
rtr, err := rac.request(req, NewRefreshTokenResponse, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
switch rerr := err.(type) {
|
||||||
|
case ServerError:
|
||||||
|
if rerr.StatusCode == 400 {
|
||||||
|
return nil, ErrOauthRevoked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,6 +188,13 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes
|
||||||
|
|
||||||
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
|
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
switch rerr := err.(type) {
|
||||||
|
case ServerError:
|
||||||
|
if rerr.StatusCode == 403 {
|
||||||
|
return nil, ErrOauthRevoked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return lr.(*ListingResponse), nil
|
return lr.(*ListingResponse), nil
|
||||||
|
@ -200,6 +213,13 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe
|
||||||
|
|
||||||
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
|
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
switch rerr := err.(type) {
|
||||||
|
case ServerError:
|
||||||
|
if rerr.StatusCode == 403 {
|
||||||
|
return nil, ErrOauthRevoked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return lr.(*ListingResponse), nil
|
return lr.(*ListingResponse), nil
|
||||||
|
@ -215,6 +235,13 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
|
||||||
|
|
||||||
mr, err := rac.request(req, NewMeResponse, nil)
|
mr, err := rac.request(req, NewMeResponse, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
switch rerr := err.(type) {
|
||||||
|
case ServerError:
|
||||||
|
if rerr.StatusCode == 403 {
|
||||||
|
return nil, ErrOauthRevoked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return mr.(*MeResponse), nil
|
return mr.(*MeResponse), nil
|
||||||
|
|
62
internal/reddit/client_test.go
Normal file
62
internal/reddit/client_test.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package reddit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoundTripFunc .
|
||||||
|
type RoundTripFunc func(req *http.Request) *http.Response
|
||||||
|
|
||||||
|
// RoundTrip .
|
||||||
|
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
|
||||||
|
func NewTestClient(fn RoundTripFunc) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: RoundTripFunc(fn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorResponse(t *testing.T) {
|
||||||
|
rc := NewClient("", "", &statsd.NoOpClient{}, 1)
|
||||||
|
rac := rc.NewAuthenticatedClient("", "")
|
||||||
|
|
||||||
|
errortests := []struct {
|
||||||
|
name string
|
||||||
|
call func() error
|
||||||
|
|
||||||
|
status int
|
||||||
|
body string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(); return err }, 500, "", ServerError{500}},
|
||||||
|
{"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(); return err }, 400, "", ErrOauthRevoked},
|
||||||
|
{"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(); return err }, 403, "", ErrOauthRevoked},
|
||||||
|
{"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(); return err }, 403, "", ErrOauthRevoked},
|
||||||
|
{"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(); return err }, 403, "", ErrOauthRevoked},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range errortests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rac.client = NewTestClient(func(req *http.Request) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: tt.status,
|
||||||
|
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := tt.call()
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, tt.err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
19
internal/reddit/errors.go
Normal file
19
internal/reddit/errors.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package reddit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerError struct {
|
||||||
|
StatusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (se ServerError) Error() string {
|
||||||
|
return fmt.Sprintf("errror from reddit: %d", se.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrOauthRevoked .
|
||||||
|
ErrOauthRevoked = errors.New("oauth revoked")
|
||||||
|
)
|
|
@ -12,13 +12,14 @@ type ResponseHandler func(*fastjson.Value) interface{}
|
||||||
type Error struct {
|
type Error struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Code int `json:"error"`
|
Code int `json:"error"`
|
||||||
|
StatusCode int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (err *Error) Error() string {
|
func (err *Error) Error() string {
|
||||||
return fmt.Sprintf("%s (%d)", err.Message, err.Code)
|
return fmt.Sprintf("%s (%d)", err.Message, err.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewError(val *fastjson.Value) *Error {
|
func NewError(val *fastjson.Value, status int) *Error {
|
||||||
err := &Error{}
|
err := &Error{}
|
||||||
|
|
||||||
err.Message = string(val.GetStringBytes("message"))
|
err.Message = string(val.GetStringBytes("message"))
|
||||||
|
|
|
@ -134,6 +134,8 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
|
||||||
res, err := p.pool.Exec(
|
res, err := p.pool.Exec(
|
||||||
ctx,
|
ctx,
|
||||||
query,
|
query,
|
||||||
|
acc.ID,
|
||||||
|
acc.Username,
|
||||||
acc.AccountID,
|
acc.AccountID,
|
||||||
acc.AccessToken,
|
acc.AccessToken,
|
||||||
acc.RefreshToken,
|
acc.RefreshToken,
|
||||||
|
|
|
@ -57,6 +57,16 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str
|
||||||
return devs[0], nil
|
return devs[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||||
|
query := `
|
||||||
|
SELECT devices.id, apns_token, sandbox, last_pinged_at
|
||||||
|
FROM devices
|
||||||
|
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||||
|
WHERE devices_accounts.account_id = $1`
|
||||||
|
|
||||||
|
return p.fetch(ctx, query, id)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
|
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
|
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
|
||||||
|
|
|
@ -10,15 +10,15 @@ import (
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
"github.com/adjust/rmq/v4"
|
"github.com/adjust/rmq/v4"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/jackc/pgx/v4/pgxpool"
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
"github.com/sideshow/apns2"
|
"github.com/sideshow/apns2"
|
||||||
"github.com/sideshow/apns2/payload"
|
"github.com/sideshow/apns2/payload"
|
||||||
"github.com/sideshow/apns2/token"
|
"github.com/sideshow/apns2/token"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/data"
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -35,7 +35,11 @@ type notificationsWorker struct {
|
||||||
queue rmq.Connection
|
queue rmq.Connection
|
||||||
reddit *reddit.Client
|
reddit *reddit.Client
|
||||||
apns *token.Token
|
apns *token.Token
|
||||||
|
|
||||||
consumers int
|
consumers int
|
||||||
|
|
||||||
|
accountRepo domain.AccountRepository
|
||||||
|
deviceRepo domain.DeviceRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
||||||
|
@ -69,6 +73,9 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
|
||||||
reddit,
|
reddit,
|
||||||
apns,
|
apns,
|
||||||
consumers,
|
consumers,
|
||||||
|
|
||||||
|
repository.NewPostgresAccount(db),
|
||||||
|
repository.NewPostgresDevice(db),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,13 +144,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": delivery.Payload(),
|
"account#id": delivery.Payload(),
|
||||||
}).Debug("starting job")
|
}).Debug("starting job")
|
||||||
|
|
||||||
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
|
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": delivery.Payload(),
|
"account#id": delivery.Payload(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to parse account ID")
|
}).Error("failed to parse account ID")
|
||||||
|
|
||||||
|
@ -155,45 +162,21 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
|
|
||||||
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000
|
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000
|
||||||
|
|
||||||
stmt := `SELECT
|
account, err := nc.accountRepo.GetByID(ctx, id)
|
||||||
id,
|
if err != nil {
|
||||||
username,
|
|
||||||
account_id,
|
|
||||||
access_token,
|
|
||||||
refresh_token,
|
|
||||||
expires_at,
|
|
||||||
last_message_id,
|
|
||||||
last_checked_at
|
|
||||||
FROM accounts
|
|
||||||
WHERE id = $1`
|
|
||||||
account := &data.Account{}
|
|
||||||
if err := nc.db.QueryRow(ctx, stmt, id).Scan(
|
|
||||||
&account.ID,
|
|
||||||
&account.Username,
|
|
||||||
&account.AccountID,
|
|
||||||
&account.AccessToken,
|
|
||||||
&account.RefreshToken,
|
|
||||||
&account.ExpiresAt,
|
|
||||||
&account.LastMessageID,
|
|
||||||
&account.LastCheckedAt,
|
|
||||||
); err != nil {
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to fetch account from database")
|
}).Error("failed to fetch account from database")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
|
newAccount := (account.LastCheckedAt == 0)
|
||||||
stmt := `
|
account.LastCheckedAt = now
|
||||||
UPDATE accounts
|
|
||||||
SET last_checked_at = $1
|
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
||||||
WHERE id = $2`
|
|
||||||
_, err := tx.Exec(ctx, stmt, now, account.ID)
|
|
||||||
return err
|
|
||||||
}); err != nil {
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to update last_checked_at for account")
|
}).Error("failed to update last_checked_at for account")
|
||||||
return
|
return
|
||||||
|
@ -202,18 +185,29 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||||
if account.ExpiresAt < int64(now) {
|
if account.ExpiresAt < int64(now) {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
}).Debug("refreshing reddit token")
|
}).Debug("refreshing reddit token")
|
||||||
|
|
||||||
tokens, err := rac.RefreshTokens()
|
tokens, err := rac.RefreshTokens()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err != reddit.ErrOauthRevoked {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to refresh reddit tokens")
|
}).Error("failed to refresh reddit tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = nc.deleteAccount(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
nc.logger.WithFields(logrus.Fields{
|
||||||
|
"account#username": account.NormalizedUsername(),
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to remove revoked account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update account
|
// Update account
|
||||||
account.AccessToken = tokens.AccessToken
|
account.AccessToken = tokens.AccessToken
|
||||||
account.RefreshToken = tokens.RefreshToken
|
account.RefreshToken = tokens.RefreshToken
|
||||||
|
@ -222,16 +216,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
// Refresh client
|
// Refresh client
|
||||||
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
|
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
|
||||||
|
|
||||||
err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
|
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
||||||
stmt := `
|
|
||||||
UPDATE accounts
|
|
||||||
SET access_token = $1, refresh_token = $2, expires_at = $3 WHERE id = $4`
|
|
||||||
_, err := tx.Exec(ctx, stmt, account.AccessToken, account.RefreshToken, account.ExpiresAt, account.ID)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to update reddit tokens for account")
|
}).Error("failed to update reddit tokens for account")
|
||||||
return
|
return
|
||||||
|
@ -240,13 +227,13 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
|
|
||||||
// Only update delay on accounts we can actually check, otherwise it skews
|
// Only update delay on accounts we can actually check, otherwise it skews
|
||||||
// the numbers too much.
|
// the numbers too much.
|
||||||
if account.LastCheckedAt > 0 {
|
if !newAccount {
|
||||||
latency := now - account.LastCheckedAt - float64(backoff)
|
latency := now - account.LastCheckedAt - float64(backoff)
|
||||||
nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
|
nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
|
||||||
}
|
}
|
||||||
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
}).Debug("fetching message inbox")
|
}).Debug("fetching message inbox")
|
||||||
|
|
||||||
opts := []reddit.RequestOption{reddit.WithQuery("limit", "10")}
|
opts := []reddit.RequestOption{reddit.WithQuery("limit", "10")}
|
||||||
|
@ -256,70 +243,66 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
msgs, err := rac.MessageInbox(opts...)
|
msgs, err := rac.MessageInbox(opts...)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err != reddit.ErrOauthRevoked {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to fetch message inbox")
|
}).Error("failed to fetch message inbox")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = nc.deleteAccount(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
nc.logger.WithFields(logrus.Fields{
|
||||||
|
"account#username": account.NormalizedUsername(),
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to remove revoked account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nc.logger.WithFields(logrus.Fields{
|
||||||
|
"account#username": account.NormalizedUsername(),
|
||||||
|
}).Info("removed revoked account")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Figure out where we stand
|
// Figure out where we stand
|
||||||
if msgs.Count == 0 {
|
if msgs.Count == 0 {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
}).Debug("no new messages, bailing early")
|
}).Debug("no new messages, bailing early")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"count": msgs.Count,
|
"count": msgs.Count,
|
||||||
}).Debug("fetched messages")
|
}).Debug("fetched messages")
|
||||||
|
|
||||||
// Set latest message we alerted on
|
account.LastMessageID = msgs.Children[0].FullName()
|
||||||
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
|
|
||||||
stmt := `
|
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
||||||
UPDATE accounts
|
|
||||||
SET last_message_id = $1
|
|
||||||
WHERE id = $2`
|
|
||||||
_, err := tx.Exec(ctx, stmt, msgs.Children[0].FullName(), account.ID)
|
|
||||||
return err
|
|
||||||
}); err != nil {
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to update last_message_id for account")
|
}).Error("failed to update last_message_id for account")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's populate this with the latest message so we don't flood users with stuff
|
// Let's populate this with the latest message so we don't flood users with stuff
|
||||||
if account.LastMessageID == "" && account.LastCheckedAt == 0 {
|
if newAccount {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": delivery.Payload(),
|
"account#username": account.NormalizedUsername(),
|
||||||
}).Debug("populating first message ID to prevent spamming")
|
}).Debug("populating first message ID to prevent spamming")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
devices := []data.Device{}
|
devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
|
||||||
stmt = `
|
|
||||||
SELECT apns_token, sandbox
|
|
||||||
FROM devices
|
|
||||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
|
||||||
WHERE devices_accounts.account_id = $1`
|
|
||||||
rows, err := nc.db.Query(ctx, stmt, account.ID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": id,
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to fetch account devices")
|
}).Error("failed to fetch account devices")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
for rows.Next() {
|
|
||||||
var device data.Device
|
|
||||||
rows.Scan(&device.APNSToken, &device.Sandbox)
|
|
||||||
devices = append(devices, device)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate backwards so we notify from older to newer
|
// Iterate backwards so we notify from older to newer
|
||||||
for i := msgs.Count - 1; i >= 0; i-- {
|
for i := msgs.Count - 1; i >= 0; i-- {
|
||||||
|
@ -359,11 +342,27 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
nc.statsd.SimpleEvent(ev, "")
|
nc.statsd.SimpleEvent(ev, "")
|
||||||
|
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"accountID": delivery.Payload(),
|
"account#username": account.NormalizedUsername(),
|
||||||
}).Debug("finishing job")
|
}).Debug("finishing job")
|
||||||
}
|
}
|
||||||
|
|
||||||
func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
|
func (nc *notificationsConsumer) deleteAccount(ctx context.Context, account domain.Account) error {
|
||||||
|
// Disassociate account from devices
|
||||||
|
devs, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dev := range devs {
|
||||||
|
if err := nc.accountRepo.Disassociate(ctx, &account, &dev); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nc.accountRepo.Delete(ctx, account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadFromMessage(acct domain.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
|
||||||
postBody := msg.Body
|
postBody := msg.Body
|
||||||
if len(postBody) > 2000 {
|
if len(postBody) > 2000 {
|
||||||
postBody = msg.Body[:2000]
|
postBody = msg.Body[:2000]
|
||||||
|
|
Loading…
Reference in a new issue