mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-13 07:27:43 +00:00
ch-ch-changes
This commit is contained in:
parent
c149606f24
commit
a1a098b448
14 changed files with 222 additions and 48 deletions
|
@ -11,10 +11,12 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||
)
|
||||
|
||||
type accountNotificationsRequest struct {
|
||||
Enabled bool
|
||||
InboxNotifications bool `json:"inbox_notifications"`
|
||||
WatcherNotifications bool `json:"watcher_notifications"`
|
||||
}
|
||||
|
||||
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -42,7 +44,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil {
|
||||
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications); err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
@ -50,6 +52,37 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request
|
|||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
apns := vars["apns"]
|
||||
rid := vars["redditID"]
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
inbox, watchers, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers}
|
||||
_ = json.NewEncoder(w).Encode(an)
|
||||
}
|
||||
|
||||
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
apns := vars["apns"]
|
||||
|
@ -108,7 +141,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
|
|||
for _, acc := range raccs {
|
||||
delete(accsMap, acc.NormalizedUsername())
|
||||
|
||||
ac := a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
|
||||
tokens, err := ac.RefreshTokens()
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
|
@ -120,7 +153,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
|
|||
acc.RefreshToken = tokens.RefreshToken
|
||||
acc.AccessToken = tokens.AccessToken
|
||||
|
||||
ac = a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
|
||||
me, err := ac.Me()
|
||||
|
||||
if err != nil {
|
||||
|
@ -167,7 +200,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Here we check whether the account is supplied with a valid token.
|
||||
ac := a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
|
||||
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
|
||||
tokens, err := ac.RefreshTokens()
|
||||
if err != nil {
|
||||
a.logger.WithFields(logrus.Fields{
|
||||
|
@ -182,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||
acct.RefreshToken = tokens.RefreshToken
|
||||
acct.AccessToken = tokens.AccessToken
|
||||
|
||||
ac = a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
|
||||
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
|
||||
me, err := ac.Me()
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -3,11 +3,13 @@ package api
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/DataDog/datadog-go/statsd"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/sideshow/apns2/token"
|
||||
|
@ -31,11 +33,12 @@ type api struct {
|
|||
userRepo domain.UserRepository
|
||||
}
|
||||
|
||||
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
|
||||
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool) *api {
|
||||
reddit := reddit.NewClient(
|
||||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
redis,
|
||||
16,
|
||||
)
|
||||
|
||||
|
@ -93,6 +96,7 @@ func (a *api) Routes() *mux.Router {
|
|||
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.getNotificationsAccountHandler).Methods("GET")
|
||||
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
||||
|
@ -142,10 +146,19 @@ func (a *api) loggingMiddleware(next http.Handler) http.Handler {
|
|||
// Call the next handler, which can be another middleware in the chain, or the final handler.
|
||||
next.ServeHTTP(lrw, r)
|
||||
|
||||
remoteAddr := r.Header.Get("X-Forwarded-For")
|
||||
if remoteAddr == "" {
|
||||
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
|
||||
remoteAddr = "unknown"
|
||||
} else {
|
||||
remoteAddr = ip
|
||||
}
|
||||
}
|
||||
|
||||
logEntry := a.logger.WithFields(logrus.Fields{
|
||||
"duration": time.Since(start).Milliseconds(),
|
||||
"method": r.Method,
|
||||
"remote#addr": r.RemoteAddr,
|
||||
"remote#addr": remoteAddr,
|
||||
"response#bytes": lrw.bytes,
|
||||
"status": lrw.statusCode,
|
||||
"uri": r.RequestURI,
|
||||
|
|
|
@ -7,8 +7,10 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
)
|
||||
|
||||
type watcherCriteria struct {
|
||||
|
@ -28,6 +30,14 @@ type createWatcherRequest struct {
|
|||
Criteria watcherCriteria
|
||||
}
|
||||
|
||||
func (cwr *createWatcherRequest) Validate() error {
|
||||
return validation.ValidateStruct(cwr,
|
||||
validation.Field(&cwr.Type, validation.Required),
|
||||
validation.Field(&cwr.User, validation.Required.When(cwr.Type == "user")),
|
||||
validation.Field(&cwr.Subreddit, validation.Required.When(cwr.Type == "subreddit" || cwr.Type == "trending")),
|
||||
)
|
||||
}
|
||||
|
||||
type watcherCreatedResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
@ -47,6 +57,11 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := cwr.Validate(); err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
|
@ -73,7 +88,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
|
||||
watcher := domain.Watcher{
|
||||
Label: cwr.Label,
|
||||
|
@ -220,11 +235,12 @@ type watcherItem struct {
|
|||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
SourceLabel string `json:"source_label"`
|
||||
Upvotes int64 `json:"upvotes"`
|
||||
Keyword string `json:"keyword"`
|
||||
Flair string `json:"flair"`
|
||||
Domain string `json:"domain"`
|
||||
Upvotes *int64 `json:"upvotes,omitempty"`
|
||||
Keyword string `json:"keyword,omitempty"`
|
||||
Flair string `json:"flair,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Hits int64 `json:"hits"`
|
||||
Author string `json:"author,omitempty"`
|
||||
}
|
||||
|
||||
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -248,11 +264,15 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
|||
Type: watcher.Type.String(),
|
||||
Label: watcher.Label,
|
||||
SourceLabel: watcher.WatcheeLabel,
|
||||
Upvotes: watcher.Upvotes,
|
||||
Keyword: watcher.Keyword,
|
||||
Flair: watcher.Flair,
|
||||
Domain: watcher.Domain,
|
||||
Hits: watcher.Hits,
|
||||
Author: watcher.Author,
|
||||
}
|
||||
|
||||
if watcher.Upvotes != 0 {
|
||||
wi.Upvotes = &watcher.Upvotes
|
||||
}
|
||||
|
||||
wis[i] = wi
|
||||
|
|
|
@ -44,7 +44,7 @@ func APICmd(ctx context.Context) *cobra.Command {
|
|||
}
|
||||
defer redis.Close()
|
||||
|
||||
api := api.NewAPI(ctx, logger, statsd, db)
|
||||
api := api.NewAPI(ctx, logger, statsd, redis, db)
|
||||
srv := api.Server(port)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
|
|
|
@ -27,14 +27,16 @@ func (dev *Device) Validate() error {
|
|||
type DeviceRepository interface {
|
||||
GetByID(ctx context.Context, id int64) (Device, error)
|
||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||
GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||
GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||
|
||||
CreateOrUpdate(ctx context.Context, dev *Device) error
|
||||
Update(ctx context.Context, dev *Device) error
|
||||
Create(ctx context.Context, dev *Device) error
|
||||
Delete(ctx context.Context, token string) error
|
||||
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error
|
||||
SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher bool) error
|
||||
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, error)
|
||||
|
||||
PruneStale(ctx context.Context, before int64) (int64, error)
|
||||
}
|
||||
|
|
|
@ -1,18 +1,26 @@
|
|||
package reddit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DataDog/datadog-go/statsd"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/valyala/fastjson"
|
||||
)
|
||||
|
||||
const (
|
||||
SkipRateLimiting = "<SKIP_RATE_LIMITING>"
|
||||
RequestRemainingBuffer = 50
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
id string
|
||||
secret string
|
||||
|
@ -20,6 +28,7 @@ type Client struct {
|
|||
tracer *httptrace.ClientTrace
|
||||
pool *fastjson.ParserPool
|
||||
statsd statsd.ClientInterface
|
||||
redis *redis.Client
|
||||
}
|
||||
|
||||
var backoffSchedule = []time.Duration{
|
||||
|
@ -51,7 +60,7 @@ func PostIDFromContext(context string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int) *Client {
|
||||
func NewClient(id, secret string, statsd statsd.ClientInterface, redis *redis.Client, connLimit int) *Client {
|
||||
tracer := &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
if info.Reused {
|
||||
|
@ -84,25 +93,30 @@ func NewClient(id, secret string, statsd statsd.ClientInterface, connLimit int)
|
|||
tracer,
|
||||
pool,
|
||||
statsd,
|
||||
redis,
|
||||
}
|
||||
}
|
||||
|
||||
type AuthenticatedClient struct {
|
||||
*Client
|
||||
|
||||
redditId string
|
||||
refreshToken string
|
||||
accessToken string
|
||||
expiry *time.Time
|
||||
}
|
||||
|
||||
func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *AuthenticatedClient {
|
||||
return &AuthenticatedClient{rc, refreshToken, accessToken, nil}
|
||||
func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken string) *AuthenticatedClient {
|
||||
if redditId == "" {
|
||||
panic("requires a redditId")
|
||||
}
|
||||
|
||||
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
|
||||
}
|
||||
|
||||
func (rc *Client) doRequest(r *Request) ([]byte, error) {
|
||||
func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) {
|
||||
req, err := r.HTTPRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
|
||||
|
@ -117,34 +131,53 @@ func (rc *Client) doRequest(r *Request) ([]byte, error) {
|
|||
if err != nil {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||
return nil, ErrTimeout
|
||||
return nil, 0, 0, ErrTimeout
|
||||
}
|
||||
return nil, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining"))
|
||||
if err != nil {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset"))
|
||||
if err != nil {
|
||||
reset = 0
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
return nil, ServerError{resp.StatusCode}
|
||||
return nil, remaining, reset, ServerError{resp.StatusCode}
|
||||
}
|
||||
|
||||
bb, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
return nil, err
|
||||
return nil, remaining, reset, err
|
||||
}
|
||||
return bb, nil
|
||||
return bb, remaining, reset, nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||
bb, err := rac.doRequest(r)
|
||||
if rl, err := rac.isRateLimited(); rl || err != nil {
|
||||
return nil, ErrRateLimited
|
||||
}
|
||||
|
||||
bb, remaining, reset, err := rac.doRequest(r)
|
||||
|
||||
if remaining <= RequestRemainingBuffer {
|
||||
rac.markRateLimited(time.Duration(reset) * time.Second)
|
||||
}
|
||||
|
||||
if err != nil && r.retry {
|
||||
for _, backoff := range backoffSchedule {
|
||||
done := make(chan struct{})
|
||||
|
||||
time.AfterFunc(backoff, func() {
|
||||
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
||||
bb, err = rac.doRequest(r)
|
||||
bb, remaining, reset, err = rac.doRequest(r)
|
||||
done <- struct{}{}
|
||||
})
|
||||
|
||||
|
@ -179,6 +212,31 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
|||
return rh(val), nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) isRateLimited() (bool, error) {
|
||||
if rac.redditId == SkipRateLimiting {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
|
||||
res, err := rac.redis.Exists(context.Background(), key).Result()
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return res > 0, nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) markRateLimited(duration time.Duration) error {
|
||||
if rac.redditId == SkipRateLimiting {
|
||||
return ErrRequiresRedditId
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId)
|
||||
_, err := rac.redis.SetEX(context.Background(), key, true, duration).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||
req := NewRequest(
|
||||
WithTags([]string{"url:/api/v1/access_token"}),
|
||||
|
|
|
@ -10,7 +10,7 @@ type ServerError struct {
|
|||
}
|
||||
|
||||
func (se ServerError) Error() string {
|
||||
return fmt.Sprintf("errror from reddit: %d", se.StatusCode)
|
||||
return fmt.Sprintf("error from reddit: %d", se.StatusCode)
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -18,4 +18,8 @@ var (
|
|||
ErrOauthRevoked = errors.New("oauth revoked")
|
||||
// ErrTimeout .
|
||||
ErrTimeout = errors.New("timeout")
|
||||
// ErrRateLimited .
|
||||
ErrRateLimited = errors.New("rate limited")
|
||||
// ErrRequiresRedditId .
|
||||
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||
)
|
||||
|
|
|
@ -84,12 +84,22 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
|
|||
return p.fetch(ctx, query, id)
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
query := `
|
||||
SELECT devices.id, apns_token, sandbox, active_until
|
||||
FROM devices
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE`
|
||||
WHERE devices_accounts.account_id = $1 AND devices_accounts.inbox_notifiable = TRUE`
|
||||
|
||||
return p.fetch(ctx, query, id)
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
query := `
|
||||
SELECT devices.id, apns_token, sandbox, active_until
|
||||
FROM devices
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||
WHERE devices_accounts.account_id = $1 AND devices_accounts.watcher_notifiable = TRUE`
|
||||
|
||||
return p.fetch(ctx, query, id)
|
||||
}
|
||||
|
@ -160,13 +170,15 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
|
|||
return err
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error {
|
||||
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher bool) error {
|
||||
query := `
|
||||
UPDATE devices_accounts
|
||||
SET notifiable = $1
|
||||
WHERE device_id = $2 AND account_id = $3`
|
||||
SET
|
||||
inbox_notifiable = $1,
|
||||
watcher_notifiable = $2,
|
||||
WHERE device_id = $3 AND account_id = $4`
|
||||
|
||||
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID)
|
||||
res, err := p.pool.Exec(ctx, query, inbox, watcher, dev.ID, acct.ID)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||
|
@ -175,6 +187,28 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai
|
|||
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, error) {
|
||||
query := `
|
||||
SELECT inbox_notifiable, watcher_notifiable
|
||||
FROM devices_accounts
|
||||
WHERE device_id = $1 AND account_id = $2`
|
||||
|
||||
rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var inbox, watcher bool
|
||||
if err := rows.Scan(&inbox, &watcher); err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
return inbox, watcher, nil
|
||||
}
|
||||
|
||||
return false, false, domain.ErrNotFound
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
||||
query := `DELETE FROM devices WHERE active_until < $1`
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
|
|||
&watcher.Device.APNSToken,
|
||||
&watcher.Device.Sandbox,
|
||||
&watcher.Account.ID,
|
||||
&watcher.Account.AccountID,
|
||||
&watcher.Account.AccessToken,
|
||||
&watcher.Account.RefreshToken,
|
||||
&subredditLabel,
|
||||
|
@ -92,6 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
|
|||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.account_id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token,
|
||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||
|
@ -136,6 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
|
|||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.account_id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token,
|
||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||
|
@ -143,9 +146,10 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
|
|||
FROM watchers
|
||||
INNER JOIN devices ON watchers.device_id = devices.id
|
||||
INNER JOIN accounts ON watchers.account_id = accounts.id
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id AND accounts.id = devices_accounts.account_id
|
||||
LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id
|
||||
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
|
||||
WHERE watchers.type = $1 AND watchers.watchee_id = $2`
|
||||
WHERE watchers.type = $1 AND watchers.watchee_id = $2 AND devices_accounts.watcher_notifiable = TRUE`
|
||||
|
||||
return p.fetch(ctx, query, typ, id)
|
||||
}
|
||||
|
@ -184,6 +188,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
|
|||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.account_id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token,
|
||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||
|
|
|
@ -47,6 +47,7 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
|
|||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
redis,
|
||||
consumers,
|
||||
)
|
||||
|
||||
|
@ -182,7 +183,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
rac := nc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
if account.ExpiresAt < int64(now) {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"account#username": account.NormalizedUsername(),
|
||||
|
@ -215,7 +216,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
account.ExpiresAt = int64(now + 3540)
|
||||
|
||||
// Refresh client
|
||||
rac = nc.reddit.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
|
||||
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
|
||||
|
||||
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
|
@ -304,7 +305,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID)
|
||||
devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"account#username": account.NormalizedUsername(),
|
||||
|
|
|
@ -35,6 +35,7 @@ func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, d
|
|||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
redis,
|
||||
consumers,
|
||||
)
|
||||
|
||||
|
@ -132,7 +133,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
rac := snc.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||
rac := snc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
|
||||
snc.logger.WithFields(logrus.Fields{
|
||||
"account#username": account.NormalizedUsername(),
|
||||
|
|
|
@ -47,6 +47,7 @@ func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpo
|
|||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
redis,
|
||||
consumers,
|
||||
)
|
||||
|
||||
|
@ -199,7 +200,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
watcher := watchers[i]
|
||||
|
||||
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||
|
||||
sps, err := rac.SubredditNew(
|
||||
subreddit.Name,
|
||||
|
@ -264,7 +265,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
watcher := watchers[i]
|
||||
|
||||
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||
sps, err := rac.SubredditHot(
|
||||
subreddit.Name,
|
||||
reddit.WithQuery("limit", "100"),
|
||||
|
|
|
@ -45,6 +45,7 @@ func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool
|
|||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
redis,
|
||||
consumers,
|
||||
)
|
||||
|
||||
|
@ -176,7 +177,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
|||
// Grab last month's top posts so we calculate a trending average
|
||||
i := rand.Intn(len(watchers))
|
||||
watcher := watchers[i]
|
||||
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
|
||||
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
|
||||
if err != nil {
|
||||
|
@ -217,7 +218,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
|||
// Grab hot posts and filter out anything that's > 2 days old
|
||||
i = rand.Intn(len(watchers))
|
||||
watcher = watchers[i]
|
||||
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
hps, err := rac.SubredditHot(subreddit.Name)
|
||||
if err != nil {
|
||||
tc.logger.WithFields(logrus.Fields{
|
||||
|
|
|
@ -46,6 +46,7 @@ func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Po
|
|||
os.Getenv("REDDIT_CLIENT_ID"),
|
||||
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||
statsd,
|
||||
redis,
|
||||
consumers,
|
||||
)
|
||||
|
||||
|
@ -180,7 +181,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
|||
watcher := watchers[i]
|
||||
|
||||
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||
rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||
rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||
|
||||
ru, err := rac.UserAbout(user.Name)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue