omg fix watchers and subreddit things (#79)

* omg fix watchers and subreddit things

* don't send thumbnails for posts that are over 18
This commit is contained in:
André Medeiros 2022-05-26 17:54:02 -04:00 committed by GitHub
parent 7e9cf6e78a
commit 8dc4ac350e
6 changed files with 149 additions and 41 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/gorilla/mux"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
)
type watcherCriteria struct {
@ -97,8 +98,6 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
return
}
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
watcher := domain.Watcher{
Label: cwr.Label,
DeviceID: dev.ID,
@ -112,9 +111,14 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
}
if cwr.Type == "subreddit" || cwr.Type == "trending" {
srr, err := ac.SubredditAbout(ctx, cwr.Subreddit)
srr, err := a.reddit.SubredditAbout(ctx, cwr.Subreddit)
if err != nil {
a.errorResponse(w, r, 422, err)
switch err {
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
a.errorResponse(w, r, 403, err)
default:
a.errorResponse(w, r, 422, err)
}
return
}
@ -141,6 +145,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
watcher.WatcheeID = sr.ID
} else if cwr.Type == "user" {
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
urr, err := ac.UserAbout(ctx, cwr.User)
if err != nil {
a.errorResponse(w, r, 500, err)
@ -235,12 +240,10 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
return
}
account := accs[0]
found := false
for _, acc := range accs {
if acc.AccountID == rid {
found = true
account = acc
}
}
@ -250,11 +253,14 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
return
}
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
srr, err := ac.SubredditAbout(ctx, lsr)
srr, err := a.reddit.SubredditAbout(ctx, lsr)
if err != nil {
a.errorResponse(w, r, 422, err)
switch err {
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
a.errorResponse(w, r, 403, err)
default:
a.errorResponse(w, r, 422, err)
}
return
}

View file

@ -188,6 +188,108 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit
}
}
func (rc *Client) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
bb, _, err := rc.doRequest(ctx, r)
if err != nil && err != ErrOauthRevoked && r.retry {
for _, backoff := range backoffSchedule {
done := make(chan struct{})
time.AfterFunc(backoff, func() {
_ = rc.statsd.Incr("reddit.api.retries", r.tags, 0.1)
bb, _, err = rc.doRequest(ctx, r)
done <- struct{}{}
})
<-done
if err == nil {
break
}
}
}
if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
return nil, ErrTimeout
}
return nil, err
}
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
return empty, nil
}
parser := rc.pool.Get()
defer rc.pool.Put(parser)
val, err := parser.ParseBytes(bb)
if err != nil {
return nil, err
}
return rh(val), nil
}
func (rc *Client) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
url := fmt.Sprintf("https://www.reddit.com/r/%s/%s.json", subreddit, sort)
opts = append(rc.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithMethod("GET"),
WithURL(url),
}...)
req := NewRequest(opts...)
lr, err := rc.request(ctx, req, NewListingResponse, nil)
if err != nil {
return nil, err
}
return lr.(*ListingResponse), nil
}
func (rc *Client) SubredditHot(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rc.subredditPosts(ctx, subreddit, "hot", opts...)
}
func (rc *Client) SubredditTop(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rc.subredditPosts(ctx, subreddit, "top", opts...)
}
func (rc *Client) SubredditNew(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rc.subredditPosts(ctx, subreddit, "new", opts...)
}
func (rc *Client) SubredditAbout(ctx context.Context, subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
url := fmt.Sprintf("https://www.reddit.com/r/%s/about.json", subreddit)
opts = append(rc.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithMethod("GET"),
WithURL(url),
}...)
req := NewRequest(opts...)
srr, err := rc.request(ctx, req, NewSubredditResponse, nil)
if err != nil {
if err == ErrOauthRevoked {
return nil, ErrSubredditIsPrivate
} else if serr, ok := err.(ServerError); ok {
if serr.StatusCode == 404 {
return nil, ErrSubredditNotFound
}
}
return nil, err
}
sr := srr.(*SubredditResponse)
if sr.Quarantined {
return nil, ErrSubredditIsQuarantined
}
return sr, nil
}
func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
if rac.isRateLimited() {
return nil, ErrRateLimited
@ -385,13 +487,25 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st
WithURL(url),
}...)
req := NewRequest(opts...)
sr, err := rac.request(ctx, req, NewSubredditResponse, nil)
srr, err := rac.request(ctx, req, NewSubredditResponse, nil)
if err != nil {
if err == ErrOauthRevoked {
return nil, ErrSubredditIsPrivate
} else if serr, ok := err.(ServerError); ok {
if serr.StatusCode == 404 {
return nil, ErrSubredditNotFound
}
}
return nil, err
}
return sr.(*SubredditResponse), nil
sr := srr.(*SubredditResponse)
if sr.Quarantined {
return nil, ErrSubredditIsQuarantined
}
return sr, nil
}
func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {

View file

@ -25,4 +25,10 @@ var (
ErrRequiresRedditId = errors.New("requires reddit id")
// ErrInvalidBasicAuth .
ErrInvalidBasicAuth = errors.New("invalid basic auth")
// ErrSubredditIsPrivate .
ErrSubredditIsPrivate = errors.New("subreddit is private")
// ErrSubredditIsQuarantined .
ErrSubredditIsQuarantined = errors.New("subreddit is quarantined")
// ErrSubredditNotFound .
ErrSubredditNotFound = errors.New("subreddit not found")
)

View file

@ -83,6 +83,7 @@ type Thing struct {
URL string `json:"url"`
Flair string `json:"flair"`
Thumbnail string `json:"thumbnail"`
Over18 bool `json:"over_18"`
}
func (t *Thing) FullName() string {
@ -120,6 +121,7 @@ func NewThing(val *fastjson.Value) *Thing {
t.URL = string(data.GetStringBytes("url"))
t.Flair = string(data.GetStringBytes("link_flair_text"))
t.Thumbnail = string(data.GetStringBytes("thumbnail"))
t.Over18 = data.GetBool("over_18")
return t
}
@ -156,7 +158,8 @@ func NewListingResponse(val *fastjson.Value) interface{} {
type SubredditResponse struct {
Thing
Name string
Name string
Quarantined bool
}
func NewSubredditResponse(val *fastjson.Value) interface{} {
@ -167,6 +170,7 @@ func NewSubredditResponse(val *fastjson.Value) interface{} {
data := val.Get("data")
sr.ID = string(data.GetStringBytes("id"))
sr.Name = string(data.GetStringBytes("display_name"))
sr.Quarantined = data.GetBool("quarantined")
return sr
}

View file

@ -3,7 +3,6 @@ package worker
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"strings"
@ -192,13 +191,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
zap.Int("page", page),
)
i := rand.Intn(len(watchers))
watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditNew(sc,
sps, err := sc.reddit.SubredditNew(sc,
subreddit.Name,
reddit.WithQuery("before", before),
reddit.WithQuery("limit", "100"),
@ -259,12 +252,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
zap.String("subreddit#name", subreddit.NormalizedName()),
)
{
i := rand.Intn(len(watchers))
watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditHot(sc,
sps, err := sc.reddit.SubredditHot(sc,
subreddit.Name,
reddit.WithQuery("limit", "100"),
)
@ -452,7 +440,7 @@ func payloadFromPost(post *reddit.Thing) *payload.Payload {
MutableContent().
Sound("traloop.wav")
if post.Thumbnail != "" {
if post.Thumbnail != "" && !post.Over18 {
payload.Custom("thumbnail", post.Thumbnail)
}

View file

@ -3,7 +3,6 @@ package worker
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"time"
@ -167,12 +166,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
return
}
// Grab last month's top posts so we calculate a trending average
i := rand.Intn(len(watchers))
watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
tps, err := tc.reddit.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
if err != nil {
tc.logger.Error("failed to fetch weeks's top posts",
zap.Error(err),
@ -219,11 +213,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
zap.Int64("score", medianScore),
)
// Grab hot posts and filter out anything that's > 2 days old
i = rand.Intn(len(watchers))
watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(tc, subreddit.Name)
hps, err := tc.reddit.SubredditHot(tc, subreddit.Name)
if err != nil {
tc.logger.Error("failed to fetch hot posts",
zap.Error(err),
@ -340,7 +330,7 @@ func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
MutableContent().
Sound("traloop.wav")
if post.Thumbnail != "" {
if post.Thumbnail != "" && !post.Over18 {
payload.Custom("thumbnail", post.Thumbnail)
}