mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-21 19:37:41 +00:00
omg fix watchers and subreddit things (#79)
* omg fix watchers and subreddit things * don't send thumbnails for posts that are over 18
This commit is contained in:
parent
7e9cf6e78a
commit
8dc4ac350e
6 changed files with 149 additions and 41 deletions
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||
)
|
||||
|
||||
type watcherCriteria struct {
|
||||
|
@ -97,8 +98,6 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
|
||||
watcher := domain.Watcher{
|
||||
Label: cwr.Label,
|
||||
DeviceID: dev.ID,
|
||||
|
@ -112,9 +111,14 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if cwr.Type == "subreddit" || cwr.Type == "trending" {
|
||||
srr, err := ac.SubredditAbout(ctx, cwr.Subreddit)
|
||||
srr, err := a.reddit.SubredditAbout(ctx, cwr.Subreddit)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err)
|
||||
switch err {
|
||||
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
|
||||
a.errorResponse(w, r, 403, err)
|
||||
default:
|
||||
a.errorResponse(w, r, 422, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -141,6 +145,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
watcher.WatcheeID = sr.ID
|
||||
} else if cwr.Type == "user" {
|
||||
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
urr, err := ac.UserAbout(ctx, cwr.User)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err)
|
||||
|
@ -235,12 +240,10 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
account := accs[0]
|
||||
found := false
|
||||
for _, acc := range accs {
|
||||
if acc.AccountID == rid {
|
||||
found = true
|
||||
account = acc
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,11 +253,14 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
|
||||
srr, err := ac.SubredditAbout(ctx, lsr)
|
||||
srr, err := a.reddit.SubredditAbout(ctx, lsr)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 422, err)
|
||||
switch err {
|
||||
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
|
||||
a.errorResponse(w, r, 403, err)
|
||||
default:
|
||||
a.errorResponse(w, r, 422, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -188,6 +188,108 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit
|
|||
}
|
||||
}
|
||||
|
||||
func (rc *Client) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||
bb, _, err := rc.doRequest(ctx, r)
|
||||
|
||||
if err != nil && err != ErrOauthRevoked && r.retry {
|
||||
for _, backoff := range backoffSchedule {
|
||||
done := make(chan struct{})
|
||||
|
||||
time.AfterFunc(backoff, func() {
|
||||
_ = rc.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
||||
bb, _, err = rc.doRequest(ctx, r)
|
||||
done <- struct{}{}
|
||||
})
|
||||
|
||||
<-done
|
||||
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
|
||||
return empty, nil
|
||||
}
|
||||
|
||||
parser := rc.pool.Get()
|
||||
defer rc.pool.Put(parser)
|
||||
|
||||
val, err := parser.ParseBytes(bb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rh(val), nil
|
||||
}
|
||||
|
||||
func (rc *Client) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
url := fmt.Sprintf("https://www.reddit.com/r/%s/%s.json", subreddit, sort)
|
||||
opts = append(rc.defaultOpts, opts...)
|
||||
opts = append(opts, []RequestOption{
|
||||
WithMethod("GET"),
|
||||
WithURL(url),
|
||||
}...)
|
||||
req := NewRequest(opts...)
|
||||
|
||||
lr, err := rc.request(ctx, req, NewListingResponse, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return lr.(*ListingResponse), nil
|
||||
}
|
||||
|
||||
func (rc *Client) SubredditHot(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
return rc.subredditPosts(ctx, subreddit, "hot", opts...)
|
||||
}
|
||||
|
||||
func (rc *Client) SubredditTop(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
return rc.subredditPosts(ctx, subreddit, "top", opts...)
|
||||
}
|
||||
|
||||
func (rc *Client) SubredditNew(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
return rc.subredditPosts(ctx, subreddit, "new", opts...)
|
||||
}
|
||||
|
||||
func (rc *Client) SubredditAbout(ctx context.Context, subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
|
||||
url := fmt.Sprintf("https://www.reddit.com/r/%s/about.json", subreddit)
|
||||
opts = append(rc.defaultOpts, opts...)
|
||||
opts = append(opts, []RequestOption{
|
||||
WithMethod("GET"),
|
||||
WithURL(url),
|
||||
}...)
|
||||
req := NewRequest(opts...)
|
||||
srr, err := rc.request(ctx, req, NewSubredditResponse, nil)
|
||||
|
||||
if err != nil {
|
||||
if err == ErrOauthRevoked {
|
||||
return nil, ErrSubredditIsPrivate
|
||||
} else if serr, ok := err.(ServerError); ok {
|
||||
if serr.StatusCode == 404 {
|
||||
return nil, ErrSubredditNotFound
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr := srr.(*SubredditResponse)
|
||||
if sr.Quarantined {
|
||||
return nil, ErrSubredditIsQuarantined
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||
if rac.isRateLimited() {
|
||||
return nil, ErrRateLimited
|
||||
|
@ -385,13 +487,25 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st
|
|||
WithURL(url),
|
||||
}...)
|
||||
req := NewRequest(opts...)
|
||||
sr, err := rac.request(ctx, req, NewSubredditResponse, nil)
|
||||
srr, err := rac.request(ctx, req, NewSubredditResponse, nil)
|
||||
|
||||
if err != nil {
|
||||
if err == ErrOauthRevoked {
|
||||
return nil, ErrSubredditIsPrivate
|
||||
} else if serr, ok := err.(ServerError); ok {
|
||||
if serr.StatusCode == 404 {
|
||||
return nil, ErrSubredditNotFound
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sr.(*SubredditResponse), nil
|
||||
sr := srr.(*SubredditResponse)
|
||||
if sr.Quarantined {
|
||||
return nil, ErrSubredditIsQuarantined
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
|
||||
|
|
|
@ -25,4 +25,10 @@ var (
|
|||
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||
// ErrInvalidBasicAuth .
|
||||
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
||||
// ErrSubredditIsPrivate .
|
||||
ErrSubredditIsPrivate = errors.New("subreddit is private")
|
||||
// ErrSubredditIsQuarantined .
|
||||
ErrSubredditIsQuarantined = errors.New("subreddit is quarantined")
|
||||
// ErrSubredditNotFound .
|
||||
ErrSubredditNotFound = errors.New("subreddit not found")
|
||||
)
|
||||
|
|
|
@ -83,6 +83,7 @@ type Thing struct {
|
|||
URL string `json:"url"`
|
||||
Flair string `json:"flair"`
|
||||
Thumbnail string `json:"thumbnail"`
|
||||
Over18 bool `json:"over_18"`
|
||||
}
|
||||
|
||||
func (t *Thing) FullName() string {
|
||||
|
@ -120,6 +121,7 @@ func NewThing(val *fastjson.Value) *Thing {
|
|||
t.URL = string(data.GetStringBytes("url"))
|
||||
t.Flair = string(data.GetStringBytes("link_flair_text"))
|
||||
t.Thumbnail = string(data.GetStringBytes("thumbnail"))
|
||||
t.Over18 = data.GetBool("over_18")
|
||||
|
||||
return t
|
||||
}
|
||||
|
@ -156,7 +158,8 @@ func NewListingResponse(val *fastjson.Value) interface{} {
|
|||
type SubredditResponse struct {
|
||||
Thing
|
||||
|
||||
Name string
|
||||
Name string
|
||||
Quarantined bool
|
||||
}
|
||||
|
||||
func NewSubredditResponse(val *fastjson.Value) interface{} {
|
||||
|
@ -167,6 +170,7 @@ func NewSubredditResponse(val *fastjson.Value) interface{} {
|
|||
data := val.Get("data")
|
||||
sr.ID = string(data.GetStringBytes("id"))
|
||||
sr.Name = string(data.GetStringBytes("display_name"))
|
||||
sr.Quarantined = data.GetBool("quarantined")
|
||||
|
||||
return sr
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package worker
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -192,13 +191,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
zap.Int("page", page),
|
||||
)
|
||||
|
||||
i := rand.Intn(len(watchers))
|
||||
watcher := watchers[i]
|
||||
|
||||
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
|
||||
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||
|
||||
sps, err := rac.SubredditNew(sc,
|
||||
sps, err := sc.reddit.SubredditNew(sc,
|
||||
subreddit.Name,
|
||||
reddit.WithQuery("before", before),
|
||||
reddit.WithQuery("limit", "100"),
|
||||
|
@ -259,12 +252,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
zap.String("subreddit#name", subreddit.NormalizedName()),
|
||||
)
|
||||
{
|
||||
i := rand.Intn(len(watchers))
|
||||
watcher := watchers[i]
|
||||
|
||||
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
|
||||
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
||||
sps, err := rac.SubredditHot(sc,
|
||||
sps, err := sc.reddit.SubredditHot(sc,
|
||||
subreddit.Name,
|
||||
reddit.WithQuery("limit", "100"),
|
||||
)
|
||||
|
@ -452,7 +440,7 @@ func payloadFromPost(post *reddit.Thing) *payload.Payload {
|
|||
MutableContent().
|
||||
Sound("traloop.wav")
|
||||
|
||||
if post.Thumbnail != "" {
|
||||
if post.Thumbnail != "" && !post.Over18 {
|
||||
payload.Custom("thumbnail", post.Thumbnail)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package worker
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -167,12 +166,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
// Grab last month's top posts so we calculate a trending average
|
||||
i := rand.Intn(len(watchers))
|
||||
watcher := watchers[i]
|
||||
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
|
||||
tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
|
||||
tps, err := tc.reddit.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
|
||||
if err != nil {
|
||||
tc.logger.Error("failed to fetch weeks's top posts",
|
||||
zap.Error(err),
|
||||
|
@ -219,11 +213,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
|||
zap.Int64("score", medianScore),
|
||||
)
|
||||
|
||||
// Grab hot posts and filter out anything that's > 2 days old
|
||||
i = rand.Intn(len(watchers))
|
||||
watcher = watchers[i]
|
||||
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
||||
hps, err := rac.SubredditHot(tc, subreddit.Name)
|
||||
hps, err := tc.reddit.SubredditHot(tc, subreddit.Name)
|
||||
if err != nil {
|
||||
tc.logger.Error("failed to fetch hot posts",
|
||||
zap.Error(err),
|
||||
|
@ -340,7 +330,7 @@ func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
|
|||
MutableContent().
|
||||
Sound("traloop.wav")
|
||||
|
||||
if post.Thumbnail != "" {
|
||||
if post.Thumbnail != "" && !post.Over18 {
|
||||
payload.Custom("thumbnail", post.Thumbnail)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue