mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-22 19:57:43 +00:00
ch-ch-changes
This commit is contained in:
parent
c149606f24
commit
a1a098b448
14 changed files with 222 additions and 48 deletions
|
@ -11,10 +11,12 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/domain"
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||||
)
|
)
|
||||||
|
|
||||||
type accountNotificationsRequest struct {
|
type accountNotificationsRequest struct {
|
||||||
Enabled bool
|
InboxNotifications bool `json:"inbox_notifications"`
|
||||||
|
WatcherNotifications bool `json:"watcher_notifications"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -42,7 +44,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil {
|
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications); err != nil {
|
||||||
a.errorResponse(w, r, 500, err.Error())
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -50,6 +52,37 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
apns := vars["apns"]
|
||||||
|
rid := vars["redditID"]
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inbox, watchers, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers}
|
||||||
|
_ = json.NewEncoder(w).Encode(an)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
apns := vars["apns"]
|
apns := vars["apns"]
|
||||||
|
@ -108,7 +141,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
for _, acc := range raccs {
|
for _, acc := range raccs {
|
||||||
delete(accsMap, acc.NormalizedUsername())
|
delete(accsMap, acc.NormalizedUsername())
|
||||||
|
|
||||||
ac := a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
|
||||||
tokens, err := ac.RefreshTokens()
|
tokens, err := ac.RefreshTokens()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.errorResponse(w, r, 422, err.Error())
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
@ -120,7 +153,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
acc.RefreshToken = tokens.RefreshToken
|
acc.RefreshToken = tokens.RefreshToken
|
||||||
acc.AccessToken = tokens.AccessToken
|
acc.AccessToken = tokens.AccessToken
|
||||||
|
|
||||||
ac = a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
|
||||||
me, err := ac.Me()
|
me, err := ac.Me()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -167,7 +200,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Here we check whether the account is supplied with a valid token.
|
// Here we check whether the account is supplied with a valid token.
|
||||||
ac := a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
|
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
|
||||||
tokens, err := ac.RefreshTokens()
|
tokens, err := ac.RefreshTokens()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.WithFields(logrus.Fields{
|
a.logger.WithFields(logrus.Fields{
|
||||||
|
@ -182,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
acct.RefreshToken = tokens.RefreshToken
|
acct.RefreshToken = tokens.RefreshToken
|
||||||
acct.AccessToken = tokens.AccessToken
|
acct.AccessToken = tokens.AccessToken
|
||||||
|
|
||||||
ac = a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
|
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
|
||||||
me, err := ac.Me()
|
me, err := ac.Me()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -3,11 +3,13 @@ package api
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/jackc/pgx/v4/pgxpool"
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
"github.com/sideshow/apns2/token"
|
"github.com/sideshow/apns2/token"
|
||||||
|
@ -31,11 +33,12 @@ type api struct {
|
||||||
userRepo domain.UserRepository
|
userRepo domain.UserRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
|
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool) *api {
|
||||||
reddit := reddit.NewClient(
|
reddit := reddit.NewClient(
|
||||||
os.Getenv("REDDIT_CLIENT_ID"),
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
16,
|
16,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,6 +96,7 @@ func (a *api) Routes() *mux.Router {
|
||||||
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
|
||||||
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.getNotificationsAccountHandler).Methods("GET")
|
||||||
|
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
||||||
|
@ -142,10 +146,19 @@ func (a *api) loggingMiddleware(next http.Handler) http.Handler {
|
||||||
// Call the next handler, which can be another middleware in the chain, or the final handler.
|
// Call the next handler, which can be another middleware in the chain, or the final handler.
|
||||||
next.ServeHTTP(lrw, r)
|
next.ServeHTTP(lrw, r)
|
||||||
|
|
||||||
|
remoteAddr := r.Header.Get("X-Forwarded-For")
|
||||||
|
if remoteAddr == "" {
|
||||||
|
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
|
||||||
|
remoteAddr = "unknown"
|
||||||
|
} else {
|
||||||
|
remoteAddr = ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logEntry := a.logger.WithFields(logrus.Fields{
|
logEntry := a.logger.WithFields(logrus.Fields{
|
||||||
"duration": time.Since(start).Milliseconds(),
|
"duration": time.Since(start).Milliseconds(),
|
||||||
"method": r.Method,
|
"method": r.Method,
|
||||||
"remote#addr": r.RemoteAddr,
|
"remote#addr": remoteAddr,
|
||||||
"response#bytes": lrw.bytes,
|
"response#bytes": lrw.bytes,
|
||||||
"status": lrw.statusCode,
|
"status": lrw.statusCode,
|
||||||
"uri": r.RequestURI,
|
"uri": r.RequestURI,
|
||||||
|
|
|
@ -7,8 +7,10 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/domain"
|
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type watcherCriteria struct {
|
type watcherCriteria struct {
|
||||||
|
@ -28,6 +30,14 @@ type createWatcherRequest struct {
|
||||||
Criteria watcherCriteria
|
Criteria watcherCriteria
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cwr *createWatcherRequest) Validate() error {
|
||||||
|
return validation.ValidateStruct(cwr,
|
||||||
|
validation.Field(&cwr.Type, validation.Required),
|
||||||
|
validation.Field(&cwr.User, validation.Required.When(cwr.Type == "user")),
|
||||||
|
validation.Field(&cwr.Subreddit, validation.Required.When(cwr.Type == "subreddit" || cwr.Type == "trending")),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
type watcherCreatedResponse struct {
|
type watcherCreatedResponse struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
}
|
}
|
||||||
|
@ -47,6 +57,11 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := cwr.Validate(); err != nil {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.errorResponse(w, r, 422, err.Error())
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
@ -73,7 +88,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||||
|
|
||||||
watcher := domain.Watcher{
|
watcher := domain.Watcher{
|
||||||
Label: cwr.Label,
|
Label: cwr.Label,
|
||||||
|
@ -220,11 +235,12 @@ type watcherItem struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
SourceLabel string `json:"source_label"`
|
SourceLabel string `json:"source_label"`
|
||||||
Upvotes int64 `json:"upvotes"`
|
Upvotes *int64 `json:"upvotes,omitempty"`
|
||||||
Keyword string `json:"keyword"`
|
Keyword string `json:"keyword,omitempty"`
|
||||||
Flair string `json:"flair"`
|
Flair string `json:"flair,omitempty"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain,omitempty"`
|
||||||
Hits int64 `json:"hits"`
|
Hits int64 `json:"hits"`
|
||||||
|
Author string `json:"author,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -248,11 +264,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
Type: watcher.Type.String(),
|
Type: watcher.Type.String(),
|
||||||
Label: watcher.Label,
|
Label: watcher.Label,
|
||||||
SourceLabel: watcher.WatcheeLabel,
|
SourceLabel: watcher.WatcheeLabel,
|
||||||
Upvotes: watcher.Upvotes,
|
|
||||||
Keyword: watcher.Keyword,
|
Keyword: watcher.Keyword,
|
||||||
Flair: watcher.Flair,
|
Flair: watcher.Flair,
|
||||||
Domain: watcher.Domain,
|
Domain: watcher.Domain,
|
||||||
Hits: watcher.Hits,
|
Hits: watcher.Hits,
|
||||||
|
Author: watcher.Author,
|
||||||
|
}
|
||||||
|
|
||||||
|
if watcher.Upvotes != 0 {
|
||||||
|
wi.Upvotes = &watcher.Upvotes
|
||||||
}
|
}
|
||||||
|
|
||||||
wis[i] = wi
|
wis[i] = wi
|
||||||
|
|
|
@ -44,7 +44,7 @@ func APICmd(ctx context.Context) *cobra.Command {
|
||||||
}
|
}
|
||||||
defer redis.Close()
|
defer redis.Close()
|
||||||
|
|
||||||
api := api.NewAPI(ctx, logger, statsd, db)
|
api := api.NewAPI(ctx, logger, statsd, redis, db)
|
||||||
srv := api.Server(port)
|
srv := api.Server(port)
|
||||||
|
|
||||||
go func() { _ = srv.ListenAndServe() }()
|
go func() { _ = srv.ListenAndServe() }()
|
||||||
|
|
|
@ -27,14 +27,16 @@ func (dev *Device) Validate() error {
|
||||||
type DeviceRepository interface {
|
type DeviceRepository interface {
|
||||||
GetByID(ctx context.Context, id int64) (Device, error)
|
GetByID(ctx context.Context, id int64) (Device, error)
|
||||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||||
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
|
GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
|
|
||||||
CreateOrUpdate(ctx context.Context, dev *Device) error
|
CreateOrUpdate(ctx context.Context, dev *Device) error
|
||||||
Update(ctx context.Context, dev *Device) error
|
Update(ctx context.Context, dev *Device) error
|
||||||
Create(ctx context.Context, dev *Device) error
|
Create(ctx context.Context, dev *Device) error
|
||||||
Delete(ctx context.Context, token string) error
|
Delete(ctx context.Context, token string) error
|
||||||
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error
|
SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher bool) error
|
||||||
|
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, error)
|
||||||
|
|
||||||
PruneStale(ctx context.Context, before int64) (int64, error)
|
PruneStale(ctx context.Context, before int64) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1,26 @@
|
||||||
package reddit
|
package reddit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/valyala/fastjson"
|
"github.com/valyala/fastjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SkipRateLimiting = "<SKIP_RATE_LIMITING>"
|
||||||
|
RequestRemainingBuffer = 50
|
||||||
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
id string
|
id string
|
||||||
secret string
|
secret string
|
||||||
|
@ -20,6 +28,7 @@ type Client struct {
|
||||||
tracer *httptrace.ClientTrace
|
tracer *httptrace.ClientTrace
|
||||||
pool *fastjson.ParserPool
|
pool *fastjson.ParserPool
|
||||||
statsd statsd.ClientInterface
|
statsd statsd.ClientInterface
|
||||||
|
redis *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
var backoffSchedule = []time.Duration{
|
var backoffSchedule = []time.Duration{
|
||||||
|
@ -51,7 +60,7 @@ func PostIDFromContext(context string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client {
|
func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int) *Client {
|
||||||
tracer := &httptrace.ClientTrace{
|
tracer := &httptrace.ClientTrace{
|
||||||
GotConn: func(info httptrace.GotConnInfo) {
|
GotConn: func(info httptrace.GotConnInfo) {
|
||||||
if info.Reused {
|
if info.Reused {
|
||||||
|
@ -84,25 +93,30 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int)
|
||||||
tracer,
|
tracer,
|
||||||
pool,
|
pool,
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthenticatedClient struct {
|
type AuthenticatedClient struct {
|
||||||
*Client
|
*Client
|
||||||
|
|
||||||
|
redditId string
|
||||||
refreshToken string
|
refreshToken string
|
||||||
accessToken string
|
accessToken string
|
||||||
expiry *time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *AuthenticatedClient {
|
func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
|
||||||
return &AuthenticatedClient{rc, refreshToken, accessToken, nil}
|
if redditId == "" {
|
||||||
|
panic("requires a redditId")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *Client) doRequest(r *Request) ([]byte, error) {
|
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) {
|
||||||
req, err := r.HTTPRequest()
|
req, err := r.HTTPRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
|
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
|
||||||
|
@ -117,34 +131,53 @@ func (rc *Client) doRequest(r *Request) ([]byte, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||||
return nil, ErrTimeout
|
return nil, 0, 0, ErrTimeout
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining"))
|
||||||
|
if err != nil {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset"))
|
||||||
|
if err != nil {
|
||||||
|
reset = 0
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
return nil, ServerError{resp.StatusCode}
|
return nil, remaining, reset, ServerError{resp.StatusCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
bb, err := ioutil.ReadAll(resp.Body)
|
bb, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
return nil, err
|
return nil, remaining, reset, err
|
||||||
}
|
}
|
||||||
return bb, nil
|
return bb, remaining, reset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||||
bb, err := rac.doRequest(r)
|
if rl, err := rac.isRateLimited(); rl || err != nil {
|
||||||
|
return nil, ErrRateLimited
|
||||||
|
}
|
||||||
|
|
||||||
|
bb, remaining, reset, err := rac.doRequest(r)
|
||||||
|
|
||||||
|
if remaining <= RequestRemainingBuffer {
|
||||||
|
rac.markRateLimited(time.Duration(reset) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil && r.retry {
|
if err != nil && r.retry {
|
||||||
for _, backoff := range backoffSchedule {
|
for _, backoff := range backoffSchedule {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
time.AfterFunc(backoff, func() {
|
time.AfterFunc(backoff, func() {
|
||||||
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
||||||
bb, err = rac.doRequest(r)
|
bb, remaining, reset, err = rac.doRequest(r)
|
||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -179,6 +212,31 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
||||||
return rh(val), nil
|
return rh(val), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) isRateLimited() (bool, error) {
|
||||||
|
if rac.redditId == SkipRateLimiting {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
|
||||||
|
res, err := rac.redis.Exists(context.Background(), key).Result()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return res > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) markRateLimited(duration time.Duration) error {
|
||||||
|
if rac.redditId == SkipRateLimiting {
|
||||||
|
return ErrRequiresRedditId
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
|
||||||
|
_, err := rac.redis.SetEX(context.Background(), key, true, duration).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||||
req := NewRequest(
|
req := NewRequest(
|
||||||
WithTags([]string{"url:/api/v1/access_token"}),
|
WithTags([]string{"url:/api/v1/access_token"}),
|
||||||
|
|
|
@ -10,7 +10,7 @@ type ServerError struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (se ServerError) Error() string {
|
func (se ServerError) Error() string {
|
||||||
return fmt.Sprintf("errror from reddit: %d", se.StatusCode)
|
return fmt.Sprintf("error from reddit: %d", se.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -18,4 +18,8 @@ var (
|
||||||
ErrOauthRevoked = errors.New("oauth revoked")
|
ErrOauthRevoked = errors.New("oauth revoked")
|
||||||
// ErrTimeout .
|
// ErrTimeout .
|
||||||
ErrTimeout = errors.New("timeout")
|
ErrTimeout = errors.New("timeout")
|
||||||
|
// ErrRateLimited .
|
||||||
|
ErrRateLimited = errors.New("rate limited")
|
||||||
|
// ErrRequiresRedditId .
|
||||||
|
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||||
)
|
)
|
||||||
|
|
|
@ -84,12 +84,22 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
|
||||||
return p.fetch(ctx, query, id)
|
return p.fetch(ctx, query, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT devices.id, apns_token, sandbox, active_until
|
SELECT devices.id, apns_token, sandbox, active_until
|
||||||
FROM devices
|
FROM devices
|
||||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||||
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE`
|
WHERE devices_accounts.account_id = $1 AND devices_accounts.inbox_notifiable = TRUE`
|
||||||
|
|
||||||
|
return p.fetch(ctx, query, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||||
|
query := `
|
||||||
|
SELECT devices.id, apns_token, sandbox, active_until
|
||||||
|
FROM devices
|
||||||
|
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||||
|
WHERE devices_accounts.account_id = $1 AND devices_accounts.watcher_notifiable = TRUE`
|
||||||
|
|
||||||
return p.fetch(ctx, query, id)
|
return p.fetch(ctx, query, id)
|
||||||
}
|
}
|
||||||
|
@ -160,13 +170,15 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error {
|
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher bool) error {
|
||||||
query := `
|
query := `
|
||||||
UPDATE devices_accounts
|
UPDATE devices_accounts
|
||||||
SET notifiable = $1
|
SET
|
||||||
WHERE device_id = $2 AND account_id = $3`
|
inbox_notifiable = $1,
|
||||||
|
watcher_notifiable = $2,
|
||||||
|
WHERE device_id = $3 AND account_id = $4`
|
||||||
|
|
||||||
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID)
|
res, err := p.pool.Exec(ctx, query, inbox, watcher, dev.ID, acct.ID)
|
||||||
|
|
||||||
if res.RowsAffected() != 1 {
|
if res.RowsAffected() != 1 {
|
||||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
@ -175,6 +187,28 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, error) {
|
||||||
|
query := `
|
||||||
|
SELECT inbox_notifiable, watcher_notifiable
|
||||||
|
FROM devices_accounts
|
||||||
|
WHERE device_id = $1 AND account_id = $2`
|
||||||
|
|
||||||
|
rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var inbox, watcher bool
|
||||||
|
if err := rows.Scan(&inbox, &watcher); err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
return inbox, watcher, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, false, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
||||||
query := `DELETE FROM devices WHERE active_until < $1`
|
query := `DELETE FROM devices WHERE active_until < $1`
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
|
||||||
&watcher.Device.APNSToken,
|
&watcher.Device.APNSToken,
|
||||||
&watcher.Device.Sandbox,
|
&watcher.Device.Sandbox,
|
||||||
&watcher.Account.ID,
|
&watcher.Account.ID,
|
||||||
|
&watcher.Account.AccountID,
|
||||||
&watcher.Account.AccessToken,
|
&watcher.Account.AccessToken,
|
||||||
&watcher.Account.RefreshToken,
|
&watcher.Account.RefreshToken,
|
||||||
&subredditLabel,
|
&subredditLabel,
|
||||||
|
@ -92,6 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
|
||||||
devices.apns_token,
|
devices.apns_token,
|
||||||
devices.sandbox,
|
devices.sandbox,
|
||||||
accounts.id,
|
accounts.id,
|
||||||
|
accounts.account_id,
|
||||||
accounts.access_token,
|
accounts.access_token,
|
||||||
accounts.refresh_token,
|
accounts.refresh_token,
|
||||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||||
|
@ -136,6 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
|
||||||
devices.apns_token,
|
devices.apns_token,
|
||||||
devices.sandbox,
|
devices.sandbox,
|
||||||
accounts.id,
|
accounts.id,
|
||||||
|
accounts.account_id,
|
||||||
accounts.access_token,
|
accounts.access_token,
|
||||||
accounts.refresh_token,
|
accounts.refresh_token,
|
||||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||||
|
@ -143,9 +146,10 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
|
||||||
FROM watchers
|
FROM watchers
|
||||||
INNER JOIN devices ON watchers.device_id = devices.id
|
INNER JOIN devices ON watchers.device_id = devices.id
|
||||||
INNER JOIN accounts ON watchers.account_id = accounts.id
|
INNER JOIN accounts ON watchers.account_id = accounts.id
|
||||||
|
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id AND accounts.id = devices_accounts.account_id
|
||||||
LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id
|
LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id
|
||||||
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
|
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
|
||||||
WHERE watchers.type = $1 AND watchers.watchee_id = $2`
|
WHERE watchers.type = $1 AND watchers.watchee_id = $2 AND devices_accounts.watcher_notifiable = TRUE`
|
||||||
|
|
||||||
return p.fetch(ctx, query, typ, id)
|
return p.fetch(ctx, query, typ, id)
|
||||||
}
|
}
|
||||||
|
@ -184,6 +188,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
|
||||||
devices.apns_token,
|
devices.apns_token,
|
||||||
devices.sandbox,
|
devices.sandbox,
|
||||||
accounts.id,
|
accounts.id,
|
||||||
|
accounts.account_id,
|
||||||
accounts.access_token,
|
accounts.access_token,
|
||||||
accounts.refresh_token,
|
accounts.refresh_token,
|
||||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||||
|
|
|
@ -47,6 +47,7 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
|
||||||
os.Getenv("REDDIT_CLIENT_ID"),
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
consumers,
|
consumers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -182,7 +183,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||||
if account.ExpiresAt < int64(now) {
|
if account.ExpiresAt < int64(now) {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"account#username": account.NormalizedUsername(),
|
"account#username": account.NormalizedUsername(),
|
||||||
|
@ -215,7 +216,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
account.ExpiresAt = int64(now + 3540)
|
account.ExpiresAt = int64(now + 3540)
|
||||||
|
|
||||||
// Refresh client
|
// Refresh client
|
||||||
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
|
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
|
||||||
|
|
||||||
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
|
@ -304,7 +305,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID)
|
devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(ctx, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"account#username": account.NormalizedUsername(),
|
"account#username": account.NormalizedUsername(),
|
||||||
|
|
|
@ -35,6 +35,7 @@ func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, d
|
||||||
os.Getenv("REDDIT_CLIENT_ID"),
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
consumers,
|
consumers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -132,7 +133,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rac := snc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
rac := snc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||||
|
|
||||||
snc.logger.WithFields(logrus.Fields{
|
snc.logger.WithFields(logrus.Fields{
|
||||||
"account#username": account.NormalizedUsername(),
|
"account#username": account.NormalizedUsername(),
|
||||||
|
|
|
@ -47,6 +47,7 @@ func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpo
|
||||||
os.Getenv("REDDIT_CLIENT_ID"),
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
consumers,
|
consumers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -199,7 +200,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
watcher := watchers[i]
|
watcher := watchers[i]
|
||||||
|
|
||||||
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||||
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||||
|
|
||||||
sps, err := rac.SubredditNew(
|
sps, err := rac.SubredditNew(
|
||||||
subreddit.Name,
|
subreddit.Name,
|
||||||
|
@ -264,7 +265,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
watcher := watchers[i]
|
watcher := watchers[i]
|
||||||
|
|
||||||
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||||
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||||
sps, err := rac.SubredditHot(
|
sps, err := rac.SubredditHot(
|
||||||
subreddit.Name,
|
subreddit.Name,
|
||||||
reddit.WithQuery("limit", "100"),
|
reddit.WithQuery("limit", "100"),
|
||||||
|
|
|
@ -45,6 +45,7 @@ func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool
|
||||||
os.Getenv("REDDIT_CLIENT_ID"),
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
consumers,
|
consumers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -176,7 +177,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
||||||
// Grab last month's top posts so we calculate a trending average
|
// Grab last month's top posts so we calculate a trending average
|
||||||
i := rand.Intn(len(watchers))
|
i := rand.Intn(len(watchers))
|
||||||
watcher := watchers[i]
|
watcher := watchers[i]
|
||||||
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||||
|
|
||||||
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
|
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -217,7 +218,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
||||||
// Grab hot posts and filter out anything that's > 2 days old
|
// Grab hot posts and filter out anything that's > 2 days old
|
||||||
i = rand.Intn(len(watchers))
|
i = rand.Intn(len(watchers))
|
||||||
watcher = watchers[i]
|
watcher = watchers[i]
|
||||||
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||||
hps, err := rac.SubredditHot(subreddit.Name)
|
hps, err := rac.SubredditHot(subreddit.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tc.logger.WithFields(logrus.Fields{
|
tc.logger.WithFields(logrus.Fields{
|
||||||
|
|
|
@ -46,6 +46,7 @@ func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Po
|
||||||
os.Getenv("REDDIT_CLIENT_ID"),
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
statsd,
|
statsd,
|
||||||
|
redis,
|
||||||
consumers,
|
consumers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -180,7 +181,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
||||||
watcher := watchers[i]
|
watcher := watchers[i]
|
||||||
|
|
||||||
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
|
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||||
rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||||
|
|
||||||
ru, err := rac.UserAbout(user.Name)
|
ru, err := rac.UserAbout(user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue