ch-ch-changes

This commit is contained in:
Andre Medeiros 2022-03-12 12:50:05 -05:00
parent c149606f24
commit a1a098b448
14 changed files with 222 additions and 48 deletions

View file

@ -11,10 +11,12 @@ import (
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
)
type accountNotificationsRequest struct {
Enabled bool
InboxNotifications bool `json:"inbox_notifications"`
WatcherNotifications bool `json:"watcher_notifications"`
}
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
@ -42,7 +44,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
return
}
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil {
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
@ -50,6 +52,37 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
w.WriteHeader(http.StatusOK)
}
func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
apns := vars["apns"]
rid := vars["redditID"]
ctx := context.Background()
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
inbox, watchers, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
w.WriteHeader(http.StatusOK)
an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers}
_ = json.NewEncoder(w).Encode(an)
}
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
apns := vars["apns"]
@ -108,7 +141,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
for _, acc := range raccs {
delete(accsMap, acc.NormalizedUsername())
ac := a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
tokens, err := ac.RefreshTokens()
if err != nil {
a.errorResponse(w, r, 422, err.Error())
@ -120,7 +153,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
acc.RefreshToken = tokens.RefreshToken
acc.AccessToken = tokens.AccessToken
ac = a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
me, err := ac.Me()
if err != nil {
@ -167,7 +200,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
}
// Here we check whether the account is supplied with a valid token.
ac := a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
tokens, err := ac.RefreshTokens()
if err != nil {
a.logger.WithFields(logrus.Fields{
@ -182,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
acct.RefreshToken = tokens.RefreshToken
acct.AccessToken = tokens.AccessToken
ac = a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
me, err := ac.Me()
if err != nil {

View file

@ -3,11 +3,13 @@ package api
import (
"context"
"fmt"
"net"
"net/http"
"os"
"time"
"github.com/DataDog/datadog-go/statsd"
"github.com/go-redis/redis/v8"
"github.com/gorilla/mux"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2/token"
@ -31,11 +33,12 @@ type api struct {
userRepo domain.UserRepository
}
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool) *api {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
16,
)
@ -93,6 +96,7 @@ func (a *api) Routes() *mux.Router {
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.getNotificationsAccountHandler).Methods("GET")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
@ -142,10 +146,19 @@ func (a *api) loggingMiddleware(next http.Handler) http.Handler {
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(lrw, r)
remoteAddr := r.Header.Get("X-Forwarded-For")
if remoteAddr == "" {
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
remoteAddr = "unknown"
} else {
remoteAddr = ip
}
}
logEntry := a.logger.WithFields(logrus.Fields{
"duration": time.Since(start).Milliseconds(),
"method": r.Method,
"remote#addr": r.RemoteAddr,
"remote#addr": remoteAddr,
"response#bytes": lrw.bytes,
"status": lrw.statusCode,
"uri": r.RequestURI,

View file

@ -7,8 +7,10 @@ import (
"strconv"
"strings"
"github.com/christianselig/apollo-backend/internal/domain"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/gorilla/mux"
"github.com/christianselig/apollo-backend/internal/domain"
)
type watcherCriteria struct {
@ -28,6 +30,14 @@ type createWatcherRequest struct {
Criteria watcherCriteria
}
func (cwr *createWatcherRequest) Validate() error {
return validation.ValidateStruct(cwr,
validation.Field(&cwr.Type, validation.Required),
validation.Field(&cwr.User, validation.Required.When(cwr.Type == "user")),
validation.Field(&cwr.Subreddit, validation.Required.When(cwr.Type == "subreddit" || cwr.Type == "trending")),
)
}
type watcherCreatedResponse struct {
ID int64 `json:"id"`
}
@ -47,6 +57,11 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
return
}
if err := cwr.Validate(); err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
@ -73,7 +88,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
return
}
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
watcher := domain.Watcher{
Label: cwr.Label,
@ -220,11 +235,12 @@ type watcherItem struct {
Type string `json:"type"`
Label string `json:"label"`
SourceLabel string `json:"source_label"`
Upvotes int64 `json:"upvotes"`
Keyword string `json:"keyword"`
Flair string `json:"flair"`
Domain string `json:"domain"`
Upvotes *int64 `json:"upvotes,omitempty"`
Keyword string `json:"keyword,omitempty"`
Flair string `json:"flair,omitempty"`
Domain string `json:"domain,omitempty"`
Hits int64 `json:"hits"`
Author string `json:"author,omitempty"`
}
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
@ -248,11 +264,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
Type: watcher.Type.String(),
Label: watcher.Label,
SourceLabel: watcher.WatcheeLabel,
Upvotes: watcher.Upvotes,
Keyword: watcher.Keyword,
Flair: watcher.Flair,
Domain: watcher.Domain,
Hits: watcher.Hits,
Author: watcher.Author,
}
if watcher.Upvotes != 0 {
wi.Upvotes = &watcher.Upvotes
}
wis[i] = wi

View file

@ -44,7 +44,7 @@ func APICmd(ctx context.Context) *cobra.Command {
}
defer redis.Close()
api := api.NewAPI(ctx, logger, statsd, db)
api := api.NewAPI(ctx, logger, statsd, redis, db)
srv := api.Server(port)
go func() { _ = srv.ListenAndServe() }()

View file

@ -27,14 +27,16 @@ func (dev *Device) Validate() error {
type DeviceRepository interface {
GetByID(ctx context.Context, id int64) (Device, error)
GetByAPNSToken(ctx context.Context, token string) (Device, error)
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
CreateOrUpdate(ctx context.Context, dev *Device) error
Update(ctx context.Context, dev *Device) error
Create(ctx context.Context, dev *Device) error
Delete(ctx context.Context, token string) error
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error
SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher bool) error
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, error)
PruneStale(ctx context.Context, before int64) (int64, error)
}

View file

@ -1,18 +1,26 @@
package reddit
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptrace"
"regexp"
"strconv"
"strings"
"time"
"github.com/DataDog/datadog-go/statsd"
"github.com/go-redis/redis/v8"
"github.com/valyala/fastjson"
)
const (
SkipRateLimiting = "<SKIP_RATE_LIMITING>"
RequestRemainingBuffer = 50
)
type Client struct {
id string
secret string
@ -20,6 +28,7 @@ type Client struct {
tracer *httptrace.ClientTrace
pool *fastjson.ParserPool
statsd statsd.ClientInterface
redis *redis.Client
}
var backoffSchedule = []time.Duration{
@ -51,7 +60,7 @@ func PostIDFromContext(context string) string {
return ""
}
func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client {
func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int) *Client {
tracer := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
if info.Reused {
@ -84,25 +93,30 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int)
tracer,
pool,
statsd,
redis,
}
}
type AuthenticatedClient struct {
*Client
redditId string
refreshToken string
accessToken string
expiry *time.Time
}
func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *AuthenticatedClient {
return &AuthenticatedClient{rc, refreshToken, accessToken, nil}
func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
if redditId == "" {
panic("requires a redditId")
}
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
}
func (rc *Client) doRequest(r *Request) ([]byte, error) {
func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) {
req, err := r.HTTPRequest()
if err != nil {
return nil, err
return nil, 0, 0, err
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
@ -117,34 +131,53 @@ func (rc *Client) doRequest(r *Request) ([]byte, error) {
if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
return nil, ErrTimeout
return nil, 0, 0, ErrTimeout
}
return nil, err
return nil, 0, 0, err
}
defer resp.Body.Close()
remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining"))
if err != nil {
remaining = 0
}
reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset"))
if err != nil {
reset = 0
}
if resp.StatusCode != 200 {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, ServerError{resp.StatusCode}
return nil, remaining, reset, ServerError{resp.StatusCode}
}
bb, err := ioutil.ReadAll(resp.Body)
if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, err
return nil, remaining, reset, err
}
return bb, nil
return bb, remaining, reset, nil
}
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
bb, err := rac.doRequest(r)
if rl, err := rac.isRateLimited(); rl || err != nil {
return nil, ErrRateLimited
}
bb, remaining, reset, err := rac.doRequest(r)
if remaining <= RequestRemainingBuffer {
rac.markRateLimited(time.Duration(reset) * time.Second)
}
if err != nil && r.retry {
for _, backoff := range backoffSchedule {
done := make(chan struct{})
time.AfterFunc(backoff, func() {
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
bb, err = rac.doRequest(r)
bb, remaining, reset, err = rac.doRequest(r)
done <- struct{}{}
})
@ -179,6 +212,31 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
return rh(val), nil
}
func (rac *AuthenticatedClient) isRateLimited() (bool, error) {
if rac.redditId == SkipRateLimiting {
return false, nil
}
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
res, err := rac.redis.Exists(context.Background(), key).Result()
if err != nil {
return false, err
}
return res > 0, nil
}
func (rac *AuthenticatedClient) markRateLimited(duration time.Duration) error {
if rac.redditId == SkipRateLimiting {
return ErrRequiresRedditId
}
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
_, err := rac.redis.SetEX(context.Background(), key, true, duration).Result()
return err
}
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
req := NewRequest(
WithTags([]string{"url:/api/v1/access_token"}),

View file

@ -10,7 +10,7 @@ type ServerError struct {
}
func (se ServerError) Error() string {
return fmt.Sprintf("errror from reddit: %d", se.StatusCode)
return fmt.Sprintf("error from reddit: %d", se.StatusCode)
}
var (
@ -18,4 +18,8 @@ var (
ErrOauthRevoked = errors.New("oauth revoked")
// ErrTimeout .
ErrTimeout = errors.New("timeout")
// ErrRateLimited .
ErrRateLimited = errors.New("rate limited")
// ErrRequiresRedditId .
ErrRequiresRedditId = errors.New("requires reddit id")
)

View file

@ -84,12 +84,22 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
return p.fetch(ctx, query, id)
}
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE`
WHERE devices_accounts.account_id = $1 AND devices_accounts.inbox_notifiable = TRUE`
return p.fetch(ctx, query, id)
}
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND devices_accounts.watcher_notifiable = TRUE`
return p.fetch(ctx, query, id)
}
@ -160,13 +170,15 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
return err
}
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error {
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher bool) error {
query := `
UPDATE devices_accounts
SET notifiable = $1
WHERE device_id = $2 AND account_id = $3`
SET
inbox_notifiable = $1,
watcher_notifiable = $2,
WHERE device_id = $3 AND account_id = $4`
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID)
res, err := p.pool.Exec(ctx, query, inbox, watcher, dev.ID, acct.ID)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -175,6 +187,28 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai
}
func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, error) {
query := `
SELECT inbox_notifiable, watcher_notifiable
FROM devices_accounts
WHERE device_id = $1 AND account_id = $2`
rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID)
if err != nil {
return false, false, err
}
defer rows.Close()
for rows.Next() {
var inbox, watcher bool
if err := rows.Scan(&inbox, &watcher); err != nil {
return false, false, err
}
return inbox, watcher, nil
}
return false, false, domain.ErrNotFound
}
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
query := `DELETE FROM devices WHERE active_until < $1`

View file

@ -50,6 +50,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
&watcher.Device.APNSToken,
&watcher.Device.Sandbox,
&watcher.Account.ID,
&watcher.Account.AccountID,
&watcher.Account.AccessToken,
&watcher.Account.RefreshToken,
&subredditLabel,
@ -92,6 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.account_id,
accounts.access_token,
accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label,
@ -136,6 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.account_id,
accounts.access_token,
accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label,
@ -143,9 +146,10 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
FROM watchers
INNER JOIN devices ON watchers.device_id = devices.id
INNER JOIN accounts ON watchers.account_id = accounts.id
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id AND accounts.id = devices_accounts.account_id
LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
WHERE watchers.type = $1 AND watchers.watchee_id = $2`
WHERE watchers.type = $1 AND watchers.watchee_id = $2 AND devices_accounts.watcher_notifiable = TRUE`
return p.fetch(ctx, query, typ, id)
}
@ -184,6 +188,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.account_id,
accounts.access_token,
accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label,

View file

@ -47,6 +47,7 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
consumers,
)
@ -182,7 +183,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if account.ExpiresAt < int64(now) {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
@ -215,7 +216,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
account.ExpiresAt = int64(now + 3540)
// Refresh client
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
@ -304,7 +305,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID)
devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(ctx, account.ID)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),

View file

@ -35,6 +35,7 @@ func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, d
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
consumers,
)
@ -132,7 +133,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
rac := snc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
rac := snc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
snc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),

View file

@ -47,6 +47,7 @@ func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpo
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
consumers,
)
@ -199,7 +200,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditNew(
subreddit.Name,
@ -264,7 +265,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditHot(
subreddit.Name,
reddit.WithQuery("limit", "100"),

View file

@ -45,6 +45,7 @@ func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
consumers,
)
@ -176,7 +177,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
// Grab last month's top posts so we calculate a trending average
i := rand.Intn(len(watchers))
watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
if err != nil {
@ -217,7 +218,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
// Grab hot posts and filter out anything that's > 2 days old
i = rand.Intn(len(watchers))
watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(subreddit.Name)
if err != nil {
tc.logger.WithFields(logrus.Fields{

View file

@ -46,6 +46,7 @@ func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Po
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
redis,
consumers,
)
@ -180,7 +181,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i]
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
ru, err := rac.UserAbout(user.Name)
if err != nil {