mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-13 23:47:44 +00:00
omg fix watchers and subreddit things (#79)
* omg fix watchers and subreddit things * don't send thumbnails for posts that are over 18
This commit is contained in:
parent
7e9cf6e78a
commit
8dc4ac350e
6 changed files with 149 additions and 41 deletions
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/domain"
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||||
)
|
)
|
||||||
|
|
||||||
type watcherCriteria struct {
|
type watcherCriteria struct {
|
||||||
|
@ -97,8 +98,6 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
|
||||||
|
|
||||||
watcher := domain.Watcher{
|
watcher := domain.Watcher{
|
||||||
Label: cwr.Label,
|
Label: cwr.Label,
|
||||||
DeviceID: dev.ID,
|
DeviceID: dev.ID,
|
||||||
|
@ -112,9 +111,14 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cwr.Type == "subreddit" || cwr.Type == "trending" {
|
if cwr.Type == "subreddit" || cwr.Type == "trending" {
|
||||||
srr, err := ac.SubredditAbout(ctx, cwr.Subreddit)
|
srr, err := a.reddit.SubredditAbout(ctx, cwr.Subreddit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.errorResponse(w, r, 422, err)
|
switch err {
|
||||||
|
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
|
||||||
|
a.errorResponse(w, r, 403, err)
|
||||||
|
default:
|
||||||
|
a.errorResponse(w, r, 422, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,6 +145,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
watcher.WatcheeID = sr.ID
|
watcher.WatcheeID = sr.ID
|
||||||
} else if cwr.Type == "user" {
|
} else if cwr.Type == "user" {
|
||||||
|
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||||
urr, err := ac.UserAbout(ctx, cwr.User)
|
urr, err := ac.UserAbout(ctx, cwr.User)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.errorResponse(w, r, 500, err)
|
a.errorResponse(w, r, 500, err)
|
||||||
|
@ -235,12 +240,10 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account := accs[0]
|
|
||||||
found := false
|
found := false
|
||||||
for _, acc := range accs {
|
for _, acc := range accs {
|
||||||
if acc.AccountID == rid {
|
if acc.AccountID == rid {
|
||||||
found = true
|
found = true
|
||||||
account = acc
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,11 +253,14 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
srr, err := a.reddit.SubredditAbout(ctx, lsr)
|
||||||
|
|
||||||
srr, err := ac.SubredditAbout(ctx, lsr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.errorResponse(w, r, 422, err)
|
switch err {
|
||||||
|
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
|
||||||
|
a.errorResponse(w, r, 403, err)
|
||||||
|
default:
|
||||||
|
a.errorResponse(w, r, 422, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -188,6 +188,108 @@ func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rc *Client) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||||
|
bb, _, err := rc.doRequest(ctx, r)
|
||||||
|
|
||||||
|
if err != nil && err != ErrOauthRevoked && r.retry {
|
||||||
|
for _, backoff := range backoffSchedule {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
time.AfterFunc(backoff, func() {
|
||||||
|
_ = rc.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
||||||
|
bb, _, err = rc.doRequest(ctx, r)
|
||||||
|
done <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
<-done
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
|
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||||
|
return nil, ErrTimeout
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes {
|
||||||
|
return empty, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := rc.pool.Get()
|
||||||
|
defer rc.pool.Put(parser)
|
||||||
|
|
||||||
|
val, err := parser.ParseBytes(bb)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return rh(val), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Client) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
url := fmt.Sprintf("https://www.reddit.com/r/%s/%s.json", subreddit, sort)
|
||||||
|
opts = append(rc.defaultOpts, opts...)
|
||||||
|
opts = append(opts, []RequestOption{
|
||||||
|
WithMethod("GET"),
|
||||||
|
WithURL(url),
|
||||||
|
}...)
|
||||||
|
req := NewRequest(opts...)
|
||||||
|
|
||||||
|
lr, err := rc.request(ctx, req, NewListingResponse, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return lr.(*ListingResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Client) SubredditHot(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
return rc.subredditPosts(ctx, subreddit, "hot", opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Client) SubredditTop(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
return rc.subredditPosts(ctx, subreddit, "top", opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Client) SubredditNew(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
return rc.subredditPosts(ctx, subreddit, "new", opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Client) SubredditAbout(ctx context.Context, subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
|
||||||
|
url := fmt.Sprintf("https://www.reddit.com/r/%s/about.json", subreddit)
|
||||||
|
opts = append(rc.defaultOpts, opts...)
|
||||||
|
opts = append(opts, []RequestOption{
|
||||||
|
WithMethod("GET"),
|
||||||
|
WithURL(url),
|
||||||
|
}...)
|
||||||
|
req := NewRequest(opts...)
|
||||||
|
srr, err := rc.request(ctx, req, NewSubredditResponse, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrOauthRevoked {
|
||||||
|
return nil, ErrSubredditIsPrivate
|
||||||
|
} else if serr, ok := err.(ServerError); ok {
|
||||||
|
if serr.StatusCode == 404 {
|
||||||
|
return nil, ErrSubredditNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sr := srr.(*SubredditResponse)
|
||||||
|
if sr.Quarantined {
|
||||||
|
return nil, ErrSubredditIsQuarantined
|
||||||
|
}
|
||||||
|
|
||||||
|
return sr, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||||
if rac.isRateLimited() {
|
if rac.isRateLimited() {
|
||||||
return nil, ErrRateLimited
|
return nil, ErrRateLimited
|
||||||
|
@ -385,13 +487,25 @@ func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit st
|
||||||
WithURL(url),
|
WithURL(url),
|
||||||
}...)
|
}...)
|
||||||
req := NewRequest(opts...)
|
req := NewRequest(opts...)
|
||||||
sr, err := rac.request(ctx, req, NewSubredditResponse, nil)
|
srr, err := rac.request(ctx, req, NewSubredditResponse, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == ErrOauthRevoked {
|
||||||
|
return nil, ErrSubredditIsPrivate
|
||||||
|
} else if serr, ok := err.(ServerError); ok {
|
||||||
|
if serr.StatusCode == 404 {
|
||||||
|
return nil, ErrSubredditNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return sr.(*SubredditResponse), nil
|
sr := srr.(*SubredditResponse)
|
||||||
|
if sr.Quarantined {
|
||||||
|
return nil, ErrSubredditIsQuarantined
|
||||||
|
}
|
||||||
|
|
||||||
|
return sr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
|
func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
|
|
@ -25,4 +25,10 @@ var (
|
||||||
ErrRequiresRedditId = errors.New("requires reddit id")
|
ErrRequiresRedditId = errors.New("requires reddit id")
|
||||||
// ErrInvalidBasicAuth .
|
// ErrInvalidBasicAuth .
|
||||||
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
ErrInvalidBasicAuth = errors.New("invalid basic auth")
|
||||||
|
// ErrSubredditIsPrivate .
|
||||||
|
ErrSubredditIsPrivate = errors.New("subreddit is private")
|
||||||
|
// ErrSubredditIsQuarantined .
|
||||||
|
ErrSubredditIsQuarantined = errors.New("subreddit is quarantined")
|
||||||
|
// ErrSubredditNotFound .
|
||||||
|
ErrSubredditNotFound = errors.New("subreddit not found")
|
||||||
)
|
)
|
||||||
|
|
|
@ -83,6 +83,7 @@ type Thing struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
Flair string `json:"flair"`
|
Flair string `json:"flair"`
|
||||||
Thumbnail string `json:"thumbnail"`
|
Thumbnail string `json:"thumbnail"`
|
||||||
|
Over18 bool `json:"over_18"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Thing) FullName() string {
|
func (t *Thing) FullName() string {
|
||||||
|
@ -120,6 +121,7 @@ func NewThing(val *fastjson.Value) *Thing {
|
||||||
t.URL = string(data.GetStringBytes("url"))
|
t.URL = string(data.GetStringBytes("url"))
|
||||||
t.Flair = string(data.GetStringBytes("link_flair_text"))
|
t.Flair = string(data.GetStringBytes("link_flair_text"))
|
||||||
t.Thumbnail = string(data.GetStringBytes("thumbnail"))
|
t.Thumbnail = string(data.GetStringBytes("thumbnail"))
|
||||||
|
t.Over18 = data.GetBool("over_18")
|
||||||
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
@ -156,7 +158,8 @@ func NewListingResponse(val *fastjson.Value) interface{} {
|
||||||
type SubredditResponse struct {
|
type SubredditResponse struct {
|
||||||
Thing
|
Thing
|
||||||
|
|
||||||
Name string
|
Name string
|
||||||
|
Quarantined bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSubredditResponse(val *fastjson.Value) interface{} {
|
func NewSubredditResponse(val *fastjson.Value) interface{} {
|
||||||
|
@ -167,6 +170,7 @@ func NewSubredditResponse(val *fastjson.Value) interface{} {
|
||||||
data := val.Get("data")
|
data := val.Get("data")
|
||||||
sr.ID = string(data.GetStringBytes("id"))
|
sr.ID = string(data.GetStringBytes("id"))
|
||||||
sr.Name = string(data.GetStringBytes("display_name"))
|
sr.Name = string(data.GetStringBytes("display_name"))
|
||||||
|
sr.Quarantined = data.GetBool("quarantined")
|
||||||
|
|
||||||
return sr
|
return sr
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package worker
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -192,13 +191,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
zap.Int("page", page),
|
zap.Int("page", page),
|
||||||
)
|
)
|
||||||
|
|
||||||
i := rand.Intn(len(watchers))
|
sps, err := sc.reddit.SubredditNew(sc,
|
||||||
watcher := watchers[i]
|
|
||||||
|
|
||||||
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
|
|
||||||
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
|
||||||
|
|
||||||
sps, err := rac.SubredditNew(sc,
|
|
||||||
subreddit.Name,
|
subreddit.Name,
|
||||||
reddit.WithQuery("before", before),
|
reddit.WithQuery("before", before),
|
||||||
reddit.WithQuery("limit", "100"),
|
reddit.WithQuery("limit", "100"),
|
||||||
|
@ -259,12 +252,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
zap.String("subreddit#name", subreddit.NormalizedName()),
|
zap.String("subreddit#name", subreddit.NormalizedName()),
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
i := rand.Intn(len(watchers))
|
sps, err := sc.reddit.SubredditHot(sc,
|
||||||
watcher := watchers[i]
|
|
||||||
|
|
||||||
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
|
|
||||||
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
|
|
||||||
sps, err := rac.SubredditHot(sc,
|
|
||||||
subreddit.Name,
|
subreddit.Name,
|
||||||
reddit.WithQuery("limit", "100"),
|
reddit.WithQuery("limit", "100"),
|
||||||
)
|
)
|
||||||
|
@ -452,7 +440,7 @@ func payloadFromPost(post *reddit.Thing) *payload.Payload {
|
||||||
MutableContent().
|
MutableContent().
|
||||||
Sound("traloop.wav")
|
Sound("traloop.wav")
|
||||||
|
|
||||||
if post.Thumbnail != "" {
|
if post.Thumbnail != "" && !post.Over18 {
|
||||||
payload.Custom("thumbnail", post.Thumbnail)
|
payload.Custom("thumbnail", post.Thumbnail)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ package worker
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -167,12 +166,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab last month's top posts so we calculate a trending average
|
tps, err := tc.reddit.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
|
||||||
i := rand.Intn(len(watchers))
|
|
||||||
watcher := watchers[i]
|
|
||||||
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
|
||||||
|
|
||||||
tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tc.logger.Error("failed to fetch weeks's top posts",
|
tc.logger.Error("failed to fetch weeks's top posts",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
|
@ -219,11 +213,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
||||||
zap.Int64("score", medianScore),
|
zap.Int64("score", medianScore),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Grab hot posts and filter out anything that's > 2 days old
|
hps, err := tc.reddit.SubredditHot(tc, subreddit.Name)
|
||||||
i = rand.Intn(len(watchers))
|
|
||||||
watcher = watchers[i]
|
|
||||||
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
|
|
||||||
hps, err := rac.SubredditHot(tc, subreddit.Name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tc.logger.Error("failed to fetch hot posts",
|
tc.logger.Error("failed to fetch hot posts",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
|
@ -340,7 +330,7 @@ func payloadFromTrendingPost(post *reddit.Thing) *payload.Payload {
|
||||||
MutableContent().
|
MutableContent().
|
||||||
Sound("traloop.wav")
|
Sound("traloop.wav")
|
||||||
|
|
||||||
if post.Thumbnail != "" {
|
if post.Thumbnail != "" && !post.Over18 {
|
||||||
payload.Custom("thumbnail", post.Thumbnail)
|
payload.Custom("thumbnail", post.Thumbnail)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue