diff --git a/internal/api/watcher.go b/internal/api/watcher.go index ff71fc9..6109cc1 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -13,6 +13,7 @@ import ( "github.com/gorilla/mux" "github.com/christianselig/apollo-backend/internal/domain" + "github.com/christianselig/apollo-backend/internal/reddit" ) type watcherCriteria struct { @@ -97,8 +98,6 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { return } - ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) - watcher := domain.Watcher{ Label: cwr.Label, DeviceID: dev.ID, @@ -112,9 +111,14 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { } if cwr.Type == "subreddit" || cwr.Type == "trending" { - srr, err := ac.SubredditAbout(ctx, cwr.Subreddit) + srr, err := a.reddit.SubredditAbout(ctx, cwr.Subreddit) if err != nil { - a.errorResponse(w, r, 422, err) + switch err { + case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined: + a.errorResponse(w, r, 403, err) + default: + a.errorResponse(w, r, 422, err) + } return } @@ -141,6 +145,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { watcher.WatcheeID = sr.ID } else if cwr.Type == "user" { + ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) urr, err := ac.UserAbout(ctx, cwr.User) if err != nil { a.errorResponse(w, r, 500, err) @@ -235,12 +240,10 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { return } - account := accs[0] found := false for _, acc := range accs { if acc.AccountID == rid { found = true - account = acc } } @@ -250,11 +253,14 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { return } - ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) - - srr, err := ac.SubredditAbout(ctx, lsr) + srr, err := a.reddit.SubredditAbout(ctx, lsr) if err != nil { - a.errorResponse(w, r, 422, err) + switch err { + case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined: + a.errorResponse(w, r, 403, err) + default: + a.errorResponse(w, r, 422, err) + } return } diff --git a/internal/reddit/client.go b/internal/reddit/client.go index c7af450..0496f4d 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -188,6 +188,108 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit } } +func (rc *Client) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { + bb, _, err := rc.doRequest(ctx, r) + + if err != nil && err != ErrOauthRevoked && r.retry { + for _, backoff := range backoffSchedule { + done := make(chan struct{}) + + time.AfterFunc(backoff, func() { + _ = rc.statsd.Incr("reddit.api.retries", r.tags, 0.1) + bb, _, err = rc.doRequest(ctx, r) + done <- struct{}{} + }) + + <-done + + if err == nil { + break + } + } + } + + if err != nil { + _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) + if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { + return nil, ErrTimeout + } + return nil, err + } + + if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes { + return empty, nil + } + + parser := rc.pool.Get() + defer rc.pool.Put(parser) + + val, err := parser.ParseBytes(bb) + if err != nil { + return nil, err + } + + return rh(val), nil +} + +func (rc *Client) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) { + url := fmt.Sprintf("https://www.reddit.com/r/%s/%s.json", subreddit, sort) + opts = append(rc.defaultOpts, opts...) + opts = append(opts, []RequestOption{ + WithMethod("GET"), + WithURL(url), + }...) + req := NewRequest(opts...) + + lr, err := rc.request(ctx, req, NewListingResponse, nil) + if err != nil { + return nil, err + } + + return lr.(*ListingResponse), nil +} + +func (rc *Client) SubredditHot(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) { + return rc.subredditPosts(ctx, subreddit, "hot", opts...) +} + +func (rc *Client) SubredditTop(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) { + return rc.subredditPosts(ctx, subreddit, "top", opts...) +} + +func (rc *Client) SubredditNew(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) { + return rc.subredditPosts(ctx, subreddit, "new", opts...) +} + +func (rc *Client) SubredditAbout(ctx context.Context, subreddit string, opts ...RequestOption) (*SubredditResponse, error) { + url := fmt.Sprintf("https://www.reddit.com/r/%s/about.json", subreddit) + opts = append(rc.defaultOpts, opts...) + opts = append(opts, []RequestOption{ + WithMethod("GET"), + WithURL(url), + }...) + req := NewRequest(opts...) + srr, err := rc.request(ctx, req, NewSubredditResponse, nil) + + if err != nil { + if err == ErrOauthRevoked { + return nil, ErrSubredditIsPrivate + } else if serr, ok := err.(ServerError); ok { + if serr.StatusCode == 404 { + return nil, ErrSubredditNotFound + } + } + return nil, err + } + + sr := srr.(*SubredditResponse) + if sr.Quarantined { + return nil, ErrSubredditIsQuarantined + } + + return sr, nil +} + func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { if rac.isRateLimited() { return nil, ErrRateLimited @@ -385,13 +487,25 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st WithURL(url), }...) req := NewRequest(opts...) - sr, err := rac.request(ctx, req, NewSubredditResponse, nil) + srr, err := rac.request(ctx, req, NewSubredditResponse, nil) if err != nil { + if err == ErrOauthRevoked { + return nil, ErrSubredditIsPrivate + } else if serr, ok := err.(ServerError); ok { + if serr.StatusCode == 404 { + return nil, ErrSubredditNotFound + } + } return nil, err } - return sr.(*SubredditResponse), nil + sr := srr.(*SubredditResponse) + if sr.Quarantined { + return nil, ErrSubredditIsQuarantined + } + + return sr, nil } func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) { diff --git a/internal/reddit/errors.go b/internal/reddit/errors.go index 01629b1..c1fb179 100644 --- a/internal/reddit/errors.go +++ b/internal/reddit/errors.go @@ -25,4 +25,10 @@ var ( ErrRequiresRedditId = errors.New("requires reddit id") // ErrInvalidBasicAuth . ErrInvalidBasicAuth = errors.New("invalid basic auth") + // ErrSubredditIsPrivate . + ErrSubredditIsPrivate = errors.New("subreddit is private") + // ErrSubredditIsQuarantined . + ErrSubredditIsQuarantined = errors.New("subreddit is quarantined") + // ErrSubredditNotFound . + ErrSubredditNotFound = errors.New("subreddit not found") ) diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 192d38d..d092518 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -83,6 +83,7 @@ type Thing struct { URL string `json:"url"` Flair string `json:"flair"` Thumbnail string `json:"thumbnail"` + Over18 bool `json:"over_18"` } func (t *Thing) FullName() string { @@ -120,6 +121,7 @@ func NewThing(val *fastjson.Value) *Thing { t.URL = string(data.GetStringBytes("url")) t.Flair = string(data.GetStringBytes("link_flair_text")) t.Thumbnail = string(data.GetStringBytes("thumbnail")) + t.Over18 = data.GetBool("over_18") return t } @@ -156,7 +158,8 @@ func NewListingResponse(val *fastjson.Value) interface{} { type SubredditResponse struct { Thing - Name string + Name string + Quarantined bool } func NewSubredditResponse(val *fastjson.Value) interface{} { @@ -167,6 +170,7 @@ func NewSubredditResponse(val *fastjson.Value) interface{} { data := val.Get("data") sr.ID = string(data.GetStringBytes("id")) sr.Name = string(data.GetStringBytes("display_name")) + sr.Quarantined = data.GetBool("quarantined") return sr } diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index a2f1d8c..5c9beac 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -3,7 +3,6 @@ package worker import ( "context" "fmt" - "math/rand" "os" "strconv" "strings" @@ -192,13 +191,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { zap.Int("page", page), ) - i := rand.Intn(len(watchers)) - watcher := watchers[i] - - acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID) - rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) - - sps, err := rac.SubredditNew(sc, + sps, err := sc.reddit.SubredditNew(sc, subreddit.Name, reddit.WithQuery("before", before), reddit.WithQuery("limit", "100"), @@ -259,12 +252,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { zap.String("subreddit#name", subreddit.NormalizedName()), ) { - i := rand.Intn(len(watchers)) - watcher := watchers[i] - - acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID) - rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken) - sps, err := rac.SubredditHot(sc, + sps, err := sc.reddit.SubredditHot(sc, subreddit.Name, reddit.WithQuery("limit", "100"), ) @@ -452,7 +440,7 @@ func payloadFromPost(post *reddit.Thing) *payload.Payload { MutableContent(). Sound("traloop.wav") - if post.Thumbnail != "" { + if post.Thumbnail != "" && !post.Over18 { payload.Custom("thumbnail", post.Thumbnail) } diff --git a/internal/worker/trending.go b/internal/worker/trending.go index 7b8d4e2..504b423 100644 --- a/internal/worker/trending.go +++ b/internal/worker/trending.go @@ -3,7 +3,6 @@ package worker import ( "context" "fmt" - "math/rand" "os" "strconv" "time" @@ -167,12 +166,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { return } - // Grab last month's top posts so we calculate a trending average - i := rand.Intn(len(watchers)) - watcher := watchers[i] - rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) - - tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week")) + tps, err := tc.reddit.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week")) if err != nil { tc.logger.Error("failed to fetch weeks's top posts", zap.Error(err), @@ -219,11 +213,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { zap.Int64("score", medianScore), ) - // Grab hot posts and filter out anything that's > 2 days old - i = rand.Intn(len(watchers)) - watcher = watchers[i] - rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken) - hps, err := rac.SubredditHot(tc, subreddit.Name) + hps, err := tc.reddit.SubredditHot(tc, subreddit.Name) if err != nil { tc.logger.Error("failed to fetch hot posts", zap.Error(err), @@ -340,7 +330,7 @@ func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload { MutableContent(). Sound("traloop.wav") - if post.Thumbnail != "" { + if post.Thumbnail != "" && !post.Over18 { payload.Custom("thumbnail", post.Thumbnail) }