ch-ch-changes

This commit is contained in:
Andre Medeiros 2022-03-12 12:50:05 -05:00
parent c149606f24
commit a1a098b448
14 changed files with 222 additions and 48 deletions

View file

@ -11,10 +11,12 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/domain" "github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
) )
type accountNotificationsRequest struct { type accountNotificationsRequest struct {
Enabled bool InboxNotifications bool `json:"inbox_notifications"`
WatcherNotifications bool `json:"watcher_notifications"`
} }
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) { func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
@ -42,7 +44,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
return return
} }
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil { if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications); err != nil {
a.errorResponse(w, r, 500, err.Error()) a.errorResponse(w, r, 500, err.Error())
return return
} }
@ -50,6 +52,37 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
apns := vars["apns"]
rid := vars["redditID"]
ctx := context.Background()
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
inbox, watchers, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
w.WriteHeader(http.StatusOK)
an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers}
_ = json.NewEncoder(w).Encode(an)
}
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) { func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
apns := vars["apns"] apns := vars["apns"]
@ -108,7 +141,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
for _, acc := range raccs { for _, acc := range raccs {
delete(accsMap, acc.NormalizedUsername()) delete(accsMap, acc.NormalizedUsername())
ac := a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
tokens, err := ac.RefreshTokens() tokens, err := ac.RefreshTokens()
if err != nil { if err != nil {
a.errorResponse(w, r, 422, err.Error()) a.errorResponse(w, r, 422, err.Error())
@ -120,7 +153,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
acc.RefreshToken = tokens.RefreshToken acc.RefreshToken = tokens.RefreshToken
acc.AccessToken = tokens.AccessToken acc.AccessToken = tokens.AccessToken
ac = a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
me, err := ac.Me() me, err := ac.Me()
if err != nil { if err != nil {
@ -167,7 +200,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
} }
// Here we check whether the account is supplied with a valid token. // Here we check whether the account is supplied with a valid token.
ac := a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken) ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
tokens, err := ac.RefreshTokens() tokens, err := ac.RefreshTokens()
if err != nil { if err != nil {
a.logger.WithFields(logrus.Fields{ a.logger.WithFields(logrus.Fields{
@ -182,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
acct.RefreshToken = tokens.RefreshToken acct.RefreshToken = tokens.RefreshToken
acct.AccessToken = tokens.AccessToken acct.AccessToken = tokens.AccessToken
ac = a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken) ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
me, err := ac.Me() me, err := ac.Me()
if err != nil { if err != nil {

View file

@ -3,11 +3,13 @@ package api
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"time" "time"
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/go-redis/redis/v8"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2/token" "github.com/sideshow/apns2/token"
@ -31,11 +33,12 @@ type api struct {
userRepo domain.UserRepository userRepo domain.UserRepository
} }
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api { func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool) *api {
reddit := reddit.NewClient( reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
statsd, statsd,
redis,
16, 16,
) )
@ -93,6 +96,7 @@ func (a *api) Routes() *mux.Router {
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH") r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.getNotificationsAccountHandler).Methods("GET")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
@ -142,10 +146,19 @@ func (a *api) loggingMiddleware(next http.Handler) http.Handler {
// Call the next handler, which can be another middleware in the chain, or the final handler. // Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(lrw, r) next.ServeHTTP(lrw, r)
remoteAddr := r.Header.Get("X-Forwarded-For")
if remoteAddr == "" {
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
remoteAddr = "unknown"
} else {
remoteAddr = ip
}
}
logEntry := a.logger.WithFields(logrus.Fields{ logEntry := a.logger.WithFields(logrus.Fields{
"duration": time.Since(start).Milliseconds(), "duration": time.Since(start).Milliseconds(),
"method": r.Method, "method": r.Method,
"remote#addr": r.RemoteAddr, "remote#addr": remoteAddr,
"response#bytes": lrw.bytes, "response#bytes": lrw.bytes,
"status": lrw.statusCode, "status": lrw.statusCode,
"uri": r.RequestURI, "uri": r.RequestURI,

View file

@ -7,8 +7,10 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/christianselig/apollo-backend/internal/domain" validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/christianselig/apollo-backend/internal/domain"
) )
type watcherCriteria struct { type watcherCriteria struct {
@ -28,6 +30,14 @@ type createWatcherRequest struct {
Criteria watcherCriteria Criteria watcherCriteria
} }
func (cwr *createWatcherRequest) Validate() error {
return validation.ValidateStruct(cwr,
validation.Field(&cwr.Type, validation.Required),
validation.Field(&cwr.User, validation.Required.When(cwr.Type == "user")),
validation.Field(&cwr.Subreddit, validation.Required.When(cwr.Type == "subreddit" || cwr.Type == "trending")),
)
}
type watcherCreatedResponse struct { type watcherCreatedResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
} }
@ -47,6 +57,11 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := cwr.Validate(); err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil { if err != nil {
a.errorResponse(w, r, 422, err.Error()) a.errorResponse(w, r, 422, err.Error())
@ -73,7 +88,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
watcher := domain.Watcher{ watcher := domain.Watcher{
Label: cwr.Label, Label: cwr.Label,
@ -220,11 +235,12 @@ type watcherItem struct {
Type string `json:"type"` Type string `json:"type"`
Label string `json:"label"` Label string `json:"label"`
SourceLabel string `json:"source_label"` SourceLabel string `json:"source_label"`
Upvotes int64 `json:"upvotes"` Upvotes *int64 `json:"upvotes,omitempty"`
Keyword string `json:"keyword"` Keyword string `json:"keyword,omitempty"`
Flair string `json:"flair"` Flair string `json:"flair,omitempty"`
Domain string `json:"domain"` Domain string `json:"domain,omitempty"`
Hits int64 `json:"hits"` Hits int64 `json:"hits"`
Author string `json:"author,omitempty"`
} }
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
@ -248,11 +264,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
Type: watcher.Type.String(), Type: watcher.Type.String(),
Label: watcher.Label, Label: watcher.Label,
SourceLabel: watcher.WatcheeLabel, SourceLabel: watcher.WatcheeLabel,
Upvotes: watcher.Upvotes,
Keyword: watcher.Keyword, Keyword: watcher.Keyword,
Flair: watcher.Flair, Flair: watcher.Flair,
Domain: watcher.Domain, Domain: watcher.Domain,
Hits: watcher.Hits, Hits: watcher.Hits,
Author: watcher.Author,
}
if watcher.Upvotes != 0 {
wi.Upvotes = &watcher.Upvotes
} }
wis[i] = wi wis[i] = wi

View file

@ -44,7 +44,7 @@ func APICmd(ctx context.Context) *cobra.Command {
} }
defer redis.Close() defer redis.Close()
api := api.NewAPI(ctx, logger, statsd, db) api := api.NewAPI(ctx, logger, statsd, redis, db)
srv := api.Server(port) srv := api.Server(port)
go func() { _ = srv.ListenAndServe() }() go func() { _ = srv.ListenAndServe() }()

View file

@ -27,14 +27,16 @@ func (dev *Device) Validate() error {
type DeviceRepository interface { type DeviceRepository interface {
GetByID(ctx context.Context, id int64) (Device, error) GetByID(ctx context.Context, id int64) (Device, error)
GetByAPNSToken(ctx context.Context, token string) (Device, error) GetByAPNSToken(ctx context.Context, token string) (Device, error)
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
GetByAccountID(ctx context.Context, id int64) ([]Device, error) GetByAccountID(ctx context.Context, id int64) ([]Device, error)
CreateOrUpdate(ctx context.Context, dev *Device) error CreateOrUpdate(ctx context.Context, dev *Device) error
Update(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error
Create(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error
Delete(ctx context.Context, token string) error Delete(ctx context.Context, token string) error
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher bool) error
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, error)
PruneStale(ctx context.Context, before int64) (int64, error) PruneStale(ctx context.Context, before int64) (int64, error)
} }

View file

@ -1,18 +1,26 @@
package reddit package reddit
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time" "time"
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/go-redis/redis/v8"
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
) )
const (
SkipRateLimiting = "<SKIP_RATE_LIMITING>"
RequestRemainingBuffer = 50
)
type Client struct { type Client struct {
id string id string
secret string secret string
@ -20,6 +28,7 @@ type Client struct {
tracer *httptrace.ClientTrace tracer *httptrace.ClientTrace
pool *fastjson.ParserPool pool *fastjson.ParserPool
statsd statsd.ClientInterface statsd statsd.ClientInterface
redis *redis.Client
} }
var backoffSchedule = []time.Duration{ var backoffSchedule = []time.Duration{
@ -51,7 +60,7 @@ func PostIDFromContext(context string) string {
return "" return ""
} }
func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client { func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int) *Client {
tracer := &httptrace.ClientTrace{ tracer := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) { GotConn: func(info httptrace.GotConnInfo) {
if info.Reused { if info.Reused {
@ -84,25 +93,30 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int)
tracer, tracer,
pool, pool,
statsd, statsd,
redis,
} }
} }
type AuthenticatedClient struct { type AuthenticatedClient struct {
*Client *Client
redditId string
refreshToken string refreshToken string
accessToken string accessToken string
expiry *time.Time
} }
func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *AuthenticatedClient { func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
return &AuthenticatedClient{rc, refreshToken, accessToken, nil} if redditId == "" {
panic("requires a redditId")
} }
func (rc *Client) doRequest(r *Request) ([]byte, error) { return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
}
func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) {
req, err := r.HTTPRequest() req, err := r.HTTPRequest()
if err != nil { if err != nil {
return nil, err return nil, 0, 0, err
} }
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer)) req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
@ -117,34 +131,53 @@ func (rc *Client) doRequest(r *Request) ([]byte, error) {
if err != nil { if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
return nil, ErrTimeout return nil, 0, 0, ErrTimeout
} }
return nil, err return nil, 0, 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining"))
if err != nil {
remaining = 0
}
reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset"))
if err != nil {
reset = 0
}
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, ServerError{resp.StatusCode} return nil, remaining, reset, ServerError{resp.StatusCode}
} }
bb, err := ioutil.ReadAll(resp.Body) bb, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, err return nil, remaining, reset, err
} }
return bb, nil return bb, remaining, reset, nil
} }
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
bb, err := rac.doRequest(r) if rl, err := rac.isRateLimited(); rl || err != nil {
return nil, ErrRateLimited
}
bb, remaining, reset, err := rac.doRequest(r)
if remaining <= RequestRemainingBuffer {
rac.markRateLimited(time.Duration(reset) * time.Second)
}
if err != nil && r.retry { if err != nil && r.retry {
for _, backoff := range backoffSchedule { for _, backoff := range backoffSchedule {
done := make(chan struct{}) done := make(chan struct{})
time.AfterFunc(backoff, func() { time.AfterFunc(backoff, func() {
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1) _ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
bb, err = rac.doRequest(r) bb, remaining, reset, err = rac.doRequest(r)
done <- struct{}{} done <- struct{}{}
}) })
@ -179,6 +212,31 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
return rh(val), nil return rh(val), nil
} }
func (rac *AuthenticatedClient) isRateLimited() (bool, error) {
if rac.redditId == SkipRateLimiting {
return false, nil
}
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
res, err := rac.redis.Exists(context.Background(), key).Result()
if err != nil {
return false, err
}
return res > 0, nil
}
func (rac *AuthenticatedClient) markRateLimited(duration time.Duration) error {
if rac.redditId == SkipRateLimiting {
return ErrRequiresRedditId
}
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
_, err := rac.redis.SetEX(context.Background(), key, true, duration).Result()
return err
}
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
req := NewRequest( req := NewRequest(
WithTags([]string{"url:/api/v1/access_token"}), WithTags([]string{"url:/api/v1/access_token"}),

View file

@ -10,7 +10,7 @@ type ServerError struct {
} }
func (se ServerError) Error() string { func (se ServerError) Error() string {
return fmt.Sprintf("errror from reddit: %d", se.StatusCode) return fmt.Sprintf("error from reddit: %d", se.StatusCode)
} }
var ( var (
@ -18,4 +18,8 @@ var (
ErrOauthRevoked = errors.New("oauth revoked") ErrOauthRevoked = errors.New("oauth revoked")
// ErrTimeout . // ErrTimeout .
ErrTimeout = errors.New("timeout") ErrTimeout = errors.New("timeout")
// ErrRateLimited .
ErrRateLimited = errors.New("rate limited")
// ErrRequiresRedditId .
ErrRequiresRedditId = errors.New("requires reddit id")
) )

View file

@ -84,12 +84,22 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
return p.fetch(ctx, query, id) return p.fetch(ctx, query, id)
} }
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := ` query := `
SELECT devices.id, apns_token, sandbox, active_until SELECT devices.id, apns_token, sandbox, active_until
FROM devices FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE` WHERE devices_accounts.account_id = $1 AND devices_accounts.inbox_notifiable = TRUE`
return p.fetch(ctx, query, id)
}
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND devices_accounts.watcher_notifiable = TRUE`
return p.fetch(ctx, query, id) return p.fetch(ctx, query, id)
} }
@ -160,13 +170,15 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
return err return err
} }
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error { func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher bool) error {
query := ` query := `
UPDATE devices_accounts UPDATE devices_accounts
SET notifiable = $1 SET
WHERE device_id = $2 AND account_id = $3` inbox_notifiable = $1,
watcher_notifiable = $2,
WHERE device_id = $3 AND account_id = $4`
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID) res, err := p.pool.Exec(ctx, query, inbox, watcher, dev.ID, acct.ID)
if res.RowsAffected() != 1 { if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -175,6 +187,28 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai
} }
func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, error) {
query := `
SELECT inbox_notifiable, watcher_notifiable
FROM devices_accounts
WHERE device_id = $1 AND account_id = $2`
rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID)
if err != nil {
return false, false, err
}
defer rows.Close()
for rows.Next() {
var inbox, watcher bool
if err := rows.Scan(&inbox, &watcher); err != nil {
return false, false, err
}
return inbox, watcher, nil
}
return false, false, domain.ErrNotFound
}
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
query := `DELETE FROM devices WHERE active_until < $1` query := `DELETE FROM devices WHERE active_until < $1`

View file

@ -50,6 +50,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
&watcher.Device.APNSToken, &watcher.Device.APNSToken,
&watcher.Device.Sandbox, &watcher.Device.Sandbox,
&watcher.Account.ID, &watcher.Account.ID,
&watcher.Account.AccountID,
&watcher.Account.AccessToken, &watcher.Account.AccessToken,
&watcher.Account.RefreshToken, &watcher.Account.RefreshToken,
&subredditLabel, &subredditLabel,
@ -92,6 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
devices.apns_token, devices.apns_token,
devices.sandbox, devices.sandbox,
accounts.id, accounts.id,
accounts.account_id,
accounts.access_token, accounts.access_token,
accounts.refresh_token, accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label, COALESCE(subreddits.name, '') AS subreddit_label,
@ -136,6 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
devices.apns_token, devices.apns_token,
devices.sandbox, devices.sandbox,
accounts.id, accounts.id,
accounts.account_id,
accounts.access_token, accounts.access_token,
accounts.refresh_token, accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label, COALESCE(subreddits.name, '') AS subreddit_label,
@ -143,9 +146,10 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
FROM watchers FROM watchers
INNER JOIN devices ON watchers.device_id = devices.id INNER JOIN devices ON watchers.device_id = devices.id
INNER JOIN accounts ON watchers.account_id = accounts.id INNER JOIN accounts ON watchers.account_id = accounts.id
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id AND accounts.id = devices_accounts.account_id
LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
WHERE watchers.type = $1 AND watchers.watchee_id = $2` WHERE watchers.type = $1 AND watchers.watchee_id = $2 AND devices_accounts.watcher_notifiable = TRUE`
return p.fetch(ctx, query, typ, id) return p.fetch(ctx, query, typ, id)
} }
@ -184,6 +188,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
devices.apns_token, devices.apns_token,
devices.sandbox, devices.sandbox,
accounts.id, accounts.id,
accounts.account_id,
accounts.access_token, accounts.access_token,
accounts.refresh_token, accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label, COALESCE(subreddits.name, '') AS subreddit_label,

View file

@ -47,6 +47,7 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
statsd, statsd,
redis,
consumers, consumers,
) )
@ -182,7 +183,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if account.ExpiresAt < int64(now) { if account.ExpiresAt < int64(now) {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),
@ -215,7 +216,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
account.ExpiresAt = int64(now + 3540) account.ExpiresAt = int64(now + 3540)
// Refresh client // Refresh client
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken) rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
if err = nc.accountRepo.Update(ctx, &account); err != nil { if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
@ -304,7 +305,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID) devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(ctx, account.ID)
if err != nil { if err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),

View file

@ -35,6 +35,7 @@ func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, d
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
statsd, statsd,
redis,
consumers, consumers,
) )
@ -132,7 +133,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
rac := snc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) rac := snc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
snc.logger.WithFields(logrus.Fields{ snc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),

View file

@ -47,6 +47,7 @@ func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpo
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
statsd, statsd,
redis,
consumers, consumers,
) )
@ -199,7 +200,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i] watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID) acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditNew( sps, err := rac.SubredditNew(
subreddit.Name, subreddit.Name,
@ -264,7 +265,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i] watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID) acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditHot( sps, err := rac.SubredditHot(
subreddit.Name, subreddit.Name,
reddit.WithQuery("limit", "100"), reddit.WithQuery("limit", "100"),

View file

@ -45,6 +45,7 @@ func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
statsd, statsd,
redis,
consumers, consumers,
) )
@ -176,7 +177,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
// Grab last month's top posts so we calculate a trending average // Grab last month's top posts so we calculate a trending average
i := rand.Intn(len(watchers)) i := rand.Intn(len(watchers))
watcher := watchers[i] watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken) rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week")) tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
if err != nil { if err != nil {
@ -217,7 +218,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
// Grab hot posts and filter out anything that's > 2 days old // Grab hot posts and filter out anything that's > 2 days old
i = rand.Intn(len(watchers)) i = rand.Intn(len(watchers))
watcher = watchers[i] watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken) rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(subreddit.Name) hps, err := rac.SubredditHot(subreddit.Name)
if err != nil { if err != nil {
tc.logger.WithFields(logrus.Fields{ tc.logger.WithFields(logrus.Fields{

View file

@ -46,6 +46,7 @@ func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Po
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
statsd, statsd,
redis,
consumers, consumers,
) )
@ -180,7 +181,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i] watcher := watchers[i]
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID) acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
ru, err := rac.UserAbout(user.Name) ru, err := rac.UserAbout(user.Name)
if err != nil { if err != nil {