diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 9381d2f..85f1341 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -11,10 +11,12 @@ import ( "github.com/sirupsen/logrus" "github.com/christianselig/apollo-backend/internal/domain" + "github.com/christianselig/apollo-backend/internal/reddit" ) type accountNotificationsRequest struct { - Enabled bool + InboxNotifications bool `json:"inbox_notifications"` + WatcherNotifications bool `json:"watcher_notifications"` } func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) { @@ -42,7 +44,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request return } - if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil { + if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications); err != nil { a.errorResponse(w, r, 500, err.Error()) return } @@ -50,6 +52,37 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request w.WriteHeader(http.StatusOK) } +func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + apns := vars["apns"] + rid := vars["redditID"] + + ctx := context.Background() + + dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) + if err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + acct, err := a.accountRepo.GetByRedditID(ctx, rid) + if err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + inbox, watchers, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct) + if err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + w.WriteHeader(http.StatusOK) + + an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers} + _ = json.NewEncoder(w).Encode(an) +} + func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) apns := vars["apns"] @@ -108,7 +141,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { for _, acc := range raccs { delete(accsMap, acc.NormalizedUsername()) - ac := a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken) tokens, err := ac.RefreshTokens() if err != nil { a.errorResponse(w, r, 422, err.Error()) @@ -120,7 +153,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { acc.RefreshToken = tokens.RefreshToken acc.AccessToken = tokens.AccessToken - ac = a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken) me, err := ac.Me() if err != nil { @@ -167,7 +200,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { } // Here we check whether the account is supplied with a valid token. - ac := a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken) + ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken) tokens, err := ac.RefreshTokens() if err != nil { a.logger.WithFields(logrus.Fields{ @@ -182,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { acct.RefreshToken = tokens.RefreshToken acct.AccessToken = tokens.AccessToken - ac = a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken) + ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken) me, err := ac.Me() if err != nil { diff --git a/internal/api/api.go b/internal/api/api.go index c505fc9..2796800 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -3,11 +3,13 @@ package api import ( "context" "fmt" + "net" "net/http" "os" "time" "github.com/DataDog/datadog-go/statsd" + "github.com/go-redis/redis/v8" "github.com/gorilla/mux" "github.com/jackc/pgx/v4/pgxpool" "github.com/sideshow/apns2/token" @@ -31,11 +33,12 @@ type api struct { userRepo domain.UserRepository } -func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api { +func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool) *api { reddit := reddit.NewClient( os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), statsd, + redis, 16, ) @@ -93,6 +96,7 @@ func (a *api) Routes() *mux.Router { r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH") + r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.getNotificationsAccountHandler).Methods("GET") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE") @@ -142,10 +146,19 @@ func (a *api) loggingMiddleware(next http.Handler) http.Handler { // Call the next handler, which can be another middleware in the chain, or the final handler. next.ServeHTTP(lrw, r) + remoteAddr := r.Header.Get("X-Forwarded-For") + if remoteAddr == "" { + if ip, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + remoteAddr = "unknown" + } else { + remoteAddr = ip + } + } + logEntry := a.logger.WithFields(logrus.Fields{ "duration": time.Since(start).Milliseconds(), "method": r.Method, - "remote#addr": r.RemoteAddr, + "remote#addr": remoteAddr, "response#bytes": lrw.bytes, "status": lrw.statusCode, "uri": r.RequestURI, diff --git a/internal/api/watcher.go b/internal/api/watcher.go index e2e3226..7748ce2 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -7,8 +7,10 @@ import ( "strconv" "strings" - "github.com/christianselig/apollo-backend/internal/domain" + validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/gorilla/mux" + + "github.com/christianselig/apollo-backend/internal/domain" ) type watcherCriteria struct { @@ -28,6 +30,14 @@ type createWatcherRequest struct { Criteria watcherCriteria } +func (cwr *createWatcherRequest) Validate() error { + return validation.ValidateStruct(cwr, + validation.Field(&cwr.Type, validation.Required), + validation.Field(&cwr.User, validation.Required.When(cwr.Type == "user")), + validation.Field(&cwr.Subreddit, validation.Required.When(cwr.Type == "subreddit" || cwr.Type == "trending")), + ) +} + type watcherCreatedResponse struct { ID int64 `json:"id"` } @@ -47,6 +57,11 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { return } + if err := cwr.Validate(); err != nil { + a.errorResponse(w, r, 422, err.Error()) + return + } + dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { a.errorResponse(w, r, 422, err.Error()) @@ -73,7 +88,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { return } - ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) + ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) watcher := domain.Watcher{ Label: cwr.Label, @@ -220,11 +235,12 @@ type watcherItem struct { Type string `json:"type"` Label string `json:"label"` SourceLabel string `json:"source_label"` - Upvotes int64 `json:"upvotes"` - Keyword string `json:"keyword"` - Flair string `json:"flair"` - Domain string `json:"domain"` + Upvotes *int64 `json:"upvotes,omitempty"` + Keyword string `json:"keyword,omitempty"` + Flair string `json:"flair,omitempty"` + Domain string `json:"domain,omitempty"` Hits int64 `json:"hits"` + Author string `json:"author,omitempty"` } func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { @@ -248,11 +264,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { Type: watcher.Type.String(), Label: watcher.Label, SourceLabel: watcher.WatcheeLabel, - Upvotes: watcher.Upvotes, Keyword: watcher.Keyword, Flair: watcher.Flair, Domain: watcher.Domain, Hits: watcher.Hits, + Author: watcher.Author, + } + + if watcher.Upvotes != 0 { + wi.Upvotes = &watcher.Upvotes } wis[i] = wi diff --git a/internal/cmd/api.go b/internal/cmd/api.go index 46f0ad4..1760a81 100644 --- a/internal/cmd/api.go +++ b/internal/cmd/api.go @@ -44,7 +44,7 @@ func APICmd(ctx context.Context) *cobra.Command { } defer redis.Close() - api := api.NewAPI(ctx, logger, statsd, db) + api := api.NewAPI(ctx, logger, statsd, redis, db) srv := api.Server(port) go func() { _ = srv.ListenAndServe() }() diff --git a/internal/domain/device.go b/internal/domain/device.go index 34e7132..62c3420 100644 --- a/internal/domain/device.go +++ b/internal/domain/device.go @@ -27,14 +27,16 @@ func (dev *Device) Validate() error { type DeviceRepository interface { GetByID(ctx context.Context, id int64) (Device, error) GetByAPNSToken(ctx context.Context, token string) (Device, error) - GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error) + GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error) + GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error) GetByAccountID(ctx context.Context, id int64) ([]Device, error) CreateOrUpdate(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error Delete(ctx context.Context, token string) error - SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error + SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher bool) error + GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, error) PruneStale(ctx context.Context, before int64) (int64, error) } diff --git a/internal/reddit/client.go b/internal/reddit/client.go index cbb792b..f10e949 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -1,18 +1,26 @@ package reddit import ( + "context" "fmt" "io/ioutil" "net/http" "net/http/httptrace" "regexp" + "strconv" "strings" "time" "github.com/DataDog/datadog-go/statsd" + "github.com/go-redis/redis/v8" "github.com/valyala/fastjson" ) +const ( + SkipRateLimiting = "" + RequestRemainingBuffer = 50 +) + type Client struct { id string secret string @@ -20,6 +28,7 @@ type Client struct { tracer *httptrace.ClientTrace pool *fastjson.ParserPool statsd statsd.ClientInterface + redis *redis.Client } var backoffSchedule = []time.Duration{ @@ -51,7 +60,7 @@ func PostIDFromContext(context string) string { return "" } -func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client { +func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int) *Client { tracer := &httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { if info.Reused { @@ -84,25 +93,30 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) tracer, pool, statsd, + redis, } } type AuthenticatedClient struct { *Client + redditId string refreshToken string accessToken string - expiry *time.Time } -func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *AuthenticatedClient { - return &AuthenticatedClient{rc, refreshToken, accessToken, nil} +func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient { + if redditId == "" { + panic("requires a redditId") + } + + return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} } -func (rc *Client) doRequest(r *Request) ([]byte, error) { +func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) { req, err := r.HTTPRequest() if err != nil { - return nil, err + return nil, 0, 0, err } req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer)) @@ -117,34 +131,53 @@ func (rc *Client) doRequest(r *Request) ([]byte, error) { if err != nil { _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { - return nil, ErrTimeout + return nil, 0, 0, ErrTimeout } - return nil, err + return nil, 0, 0, err } defer resp.Body.Close() + remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining")) + if err != nil { + remaining = 0 + } + + reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset")) + if err != nil { + reset = 0 + } + if resp.StatusCode != 200 { _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, ServerError{resp.StatusCode} + return nil, remaining, reset, ServerError{resp.StatusCode} } bb, err := ioutil.ReadAll(resp.Body) if err != nil { _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, err + return nil, remaining, reset, err } - return bb, nil + return bb, remaining, reset, nil } func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { - bb, err := rac.doRequest(r) + if rl, err := rac.isRateLimited(); rl || err != nil { + return nil, ErrRateLimited + } + + bb, remaining, reset, err := rac.doRequest(r) + + if remaining <= RequestRemainingBuffer { + rac.markRateLimited(time.Duration(reset) * time.Second) + } + if err != nil && r.retry { for _, backoff := range backoffSchedule { done := make(chan struct{}) time.AfterFunc(backoff, func() { _ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1) - bb, err = rac.doRequest(r) + bb, remaining, reset, err = rac.doRequest(r) done <- struct{}{} }) @@ -179,6 +212,31 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in return rh(val), nil } +func (rac *AuthenticatedClient) isRateLimited() (bool, error) { + if rac.redditId == SkipRateLimiting { + return false, nil + } + + key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) + res, err := rac.redis.Exists(context.Background(), key).Result() + + if err != nil { + return false, err + } + + return res > 0, nil +} + +func (rac *AuthenticatedClient) markRateLimited(duration time.Duration) error { + if rac.redditId == SkipRateLimiting { + return ErrRequiresRedditId + } + + key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) + _, err := rac.redis.SetEX(context.Background(), key, true, duration).Result() + return err +} + func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { req := NewRequest( WithTags([]string{"url:/api/v1/access_token"}), diff --git a/internal/reddit/errors.go b/internal/reddit/errors.go index cb605df..d0a6959 100644 --- a/internal/reddit/errors.go +++ b/internal/reddit/errors.go @@ -10,7 +10,7 @@ type ServerError struct { } func (se ServerError) Error() string { - return fmt.Sprintf("errror from reddit: %d", se.StatusCode) + return fmt.Sprintf("error from reddit: %d", se.StatusCode) } var ( @@ -18,4 +18,8 @@ var ( ErrOauthRevoked = errors.New("oauth revoked") // ErrTimeout . ErrTimeout = errors.New("timeout") + // ErrRateLimited . + ErrRateLimited = errors.New("rate limited") + // ErrRequiresRedditId . + ErrRequiresRedditId = errors.New("requires reddit id") ) diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go index 99f0d07..c7b6594 100644 --- a/internal/repository/postgres_device.go +++ b/internal/repository/postgres_device.go @@ -84,12 +84,22 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) return p.fetch(ctx, query, id) } -func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { +func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { query := ` SELECT devices.id, apns_token, sandbox, active_until FROM devices INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id - WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE` + WHERE devices_accounts.account_id = $1 AND devices_accounts.inbox_notifiable = TRUE` + + return p.fetch(ctx, query, id) +} + +func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { + query := ` + SELECT devices.id, apns_token, sandbox, active_until + FROM devices + INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id + WHERE devices_accounts.account_id = $1 AND devices_accounts.watcher_notifiable = TRUE` return p.fetch(ctx, query, id) } @@ -160,13 +170,15 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err return err } -func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error { +func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher bool) error { query := ` UPDATE devices_accounts - SET notifiable = $1 - WHERE device_id = $2 AND account_id = $3` + SET + inbox_notifiable = $1, + watcher_notifiable = $2, + WHERE device_id = $3 AND account_id = $4` - res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID) + res, err := p.pool.Exec(ctx, query, inbox, watcher, dev.ID, acct.ID) if res.RowsAffected() != 1 { return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) @@ -175,6 +187,28 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai } +func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, error) { + query := ` + SELECT inbox_notifiable, watcher_notifiable + FROM devices_accounts + WHERE device_id = $1 AND account_id = $2` + + rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID) + if err != nil { + return false, false, err + } + defer rows.Close() + for rows.Next() { + var inbox, watcher bool + if err := rows.Scan(&inbox, &watcher); err != nil { + return false, false, err + } + return inbox, watcher, nil + } + + return false, false, domain.ErrNotFound +} + func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { query := `DELETE FROM devices WHERE active_until < $1` diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index d6f9462..a841e77 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -50,6 +50,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg &watcher.Device.APNSToken, &watcher.Device.Sandbox, &watcher.Account.ID, + &watcher.Account.AccountID, &watcher.Account.AccessToken, &watcher.Account.RefreshToken, &subredditLabel, @@ -92,6 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma devices.apns_token, devices.sandbox, accounts.id, + accounts.account_id, accounts.access_token, accounts.refresh_token, COALESCE(subreddits.name, '') AS subreddit_label, @@ -136,6 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t devices.apns_token, devices.sandbox, accounts.id, + accounts.account_id, accounts.access_token, accounts.refresh_token, COALESCE(subreddits.name, '') AS subreddit_label, @@ -143,9 +146,10 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t FROM watchers INNER JOIN devices ON watchers.device_id = devices.id INNER JOIN accounts ON watchers.account_id = accounts.id + INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id AND accounts.id = devices_accounts.account_id LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id - WHERE watchers.type = $1 AND watchers.watchee_id = $2` + WHERE watchers.type = $1 AND watchers.watchee_id = $2 AND devices_accounts.watcher_notifiable = TRUE` return p.fetch(ctx, query, typ, id) } @@ -184,6 +188,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c devices.apns_token, devices.sandbox, accounts.id, + accounts.account_id, accounts.access_token, accounts.refresh_token, COALESCE(subreddits.name, '') AS subreddit_label, diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index 94df6f3..0da3249 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -47,6 +47,7 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), statsd, + redis, consumers, ) @@ -182,7 +183,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { return } - rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) + rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) if account.ExpiresAt < int64(now) { nc.logger.WithFields(logrus.Fields{ "account#username": account.NormalizedUsername(), @@ -215,7 +216,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { account.ExpiresAt = int64(now + 3540) // Refresh client - rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken) + rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken) if err = nc.accountRepo.Update(ctx, &account); err != nil { nc.logger.WithFields(logrus.Fields{ @@ -304,7 +305,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { return } - devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID) + devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(ctx, account.ID) if err != nil { nc.logger.WithFields(logrus.Fields{ "account#username": account.NormalizedUsername(), diff --git a/internal/worker/stuck_notifications.go b/internal/worker/stuck_notifications.go index cd88fb7..d87f190 100644 --- a/internal/worker/stuck_notifications.go +++ b/internal/worker/stuck_notifications.go @@ -35,6 +35,7 @@ func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, d os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), statsd, + redis, consumers, ) @@ -132,7 +133,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) { return } - rac := snc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) + rac := snc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) snc.logger.WithFields(logrus.Fields{ "account#username": account.NormalizedUsername(), diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index 6bd14f8..d5747ff 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -47,6 +47,7 @@ func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpo os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), statsd, + redis, consumers, ) @@ -199,7 +200,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { watcher := watchers[i] acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID) - rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) sps, err := rac.SubredditNew( subreddit.Name, @@ -264,7 +265,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { watcher := watchers[i] acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID) - rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) sps, err := rac.SubredditHot( subreddit.Name, reddit.WithQuery("limit", "100"), diff --git a/internal/worker/trending.go b/internal/worker/trending.go index db86181..b27a3bd 100644 --- a/internal/worker/trending.go +++ b/internal/worker/trending.go @@ -45,6 +45,7 @@ func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), statsd, + redis, consumers, ) @@ -176,7 +177,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { // Grab last month's top posts so we calculate a trending average i := rand.Intn(len(watchers)) watcher := watchers[i] - rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken) + rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week")) if err != nil { @@ -217,7 +218,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { // Grab hot posts and filter out anything that's > 2 days old i = rand.Intn(len(watchers)) watcher = watchers[i] - rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken) + rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) hps, err := rac.SubredditHot(subreddit.Name) if err != nil { tc.logger.WithFields(logrus.Fields{ diff --git a/internal/worker/users.go b/internal/worker/users.go index ecb8ae8..ce1ffc8 100644 --- a/internal/worker/users.go +++ b/internal/worker/users.go @@ -46,6 +46,7 @@ func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Po os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), statsd, + redis, consumers, ) @@ -180,7 +181,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { watcher := watchers[i] acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID) - rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) ru, err := rac.UserAbout(user.Name) if err != nil {