mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-10 22:17:44 +00:00
Merge pull request #25 from christianselig/feature/subreddit-watchers
Subreddit notifications
This commit is contained in:
commit
b9e6950cb1
19 changed files with 14017 additions and 18 deletions
|
@ -26,6 +26,8 @@ type api struct {
|
||||||
|
|
||||||
accountRepo domain.AccountRepository
|
accountRepo domain.AccountRepository
|
||||||
deviceRepo domain.DeviceRepository
|
deviceRepo domain.DeviceRepository
|
||||||
|
subredditRepo domain.SubredditRepository
|
||||||
|
watcherRepo domain.WatcherRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
|
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
|
||||||
|
@ -52,14 +54,19 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p
|
||||||
|
|
||||||
accountRepo := repository.NewPostgresAccount(pool)
|
accountRepo := repository.NewPostgresAccount(pool)
|
||||||
deviceRepo := repository.NewPostgresDevice(pool)
|
deviceRepo := repository.NewPostgresDevice(pool)
|
||||||
|
subredditRepo := repository.NewPostgresSubreddit(pool)
|
||||||
|
watcherRepo := repository.NewPostgresWatcher(pool)
|
||||||
|
|
||||||
return &api{
|
return &api{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
statsd: statsd,
|
statsd: statsd,
|
||||||
reddit: reddit,
|
reddit: reddit,
|
||||||
apns: apns,
|
apns: apns,
|
||||||
|
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
deviceRepo: deviceRepo,
|
deviceRepo: deviceRepo,
|
||||||
|
subredditRepo: subredditRepo,
|
||||||
|
watcherRepo: watcherRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,6 +90,10 @@ func (a *api) Routes() *mux.Router {
|
||||||
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
||||||
|
|
||||||
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
||||||
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watchers", a.listWatchersHandler).Methods("GET")
|
||||||
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
||||||
|
|
||||||
r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST")
|
r.HandleFunc("/v1/receipt", a.checkReceiptHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST")
|
r.HandleFunc("/v1/receipt/{apns}", a.checkReceiptHandler).Methods("POST")
|
||||||
|
|
||||||
|
|
175
internal/api/watcher.go
Normal file
175
internal/api/watcher.go
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
type watcherCriteria struct {
|
||||||
|
Upvotes int64
|
||||||
|
Keyword string
|
||||||
|
Flair string
|
||||||
|
Domain string
|
||||||
|
}
|
||||||
|
|
||||||
|
type createWatcherRequest struct {
|
||||||
|
Subreddit string
|
||||||
|
Criteria watcherCriteria
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
apns := vars["apns"]
|
||||||
|
redditID := vars["redditID"]
|
||||||
|
|
||||||
|
cwr := &createWatcherRequest{
|
||||||
|
Criteria: watcherCriteria{
|
||||||
|
Upvotes: 0,
|
||||||
|
Keyword: "",
|
||||||
|
Flair: "",
|
||||||
|
Domain: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(cwr); err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
|
||||||
|
if err != nil || len(accs) == 0 {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account := accs[0]
|
||||||
|
found := false
|
||||||
|
for _, acc := range accs {
|
||||||
|
if acc.AccountID == redditID {
|
||||||
|
found = true
|
||||||
|
account = acc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
a.errorResponse(w, r, 422, "yeah nice try")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||||
|
srr, err := ac.SubredditAbout(cwr.Subreddit)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sr, err := a.subredditRepo.GetByName(ctx, cwr.Subreddit)
|
||||||
|
if err != nil {
|
||||||
|
switch err {
|
||||||
|
case domain.ErrNotFound:
|
||||||
|
// Might be that we don't know about that subreddit yet
|
||||||
|
sr = domain.Subreddit{SubredditID: srr.ID, Name: srr.Name}
|
||||||
|
_ = a.subredditRepo.CreateOrUpdate(ctx, &sr)
|
||||||
|
default:
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher := domain.Watcher{
|
||||||
|
SubredditID: sr.ID,
|
||||||
|
DeviceID: dev.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Upvotes: cwr.Criteria.Upvotes,
|
||||||
|
Keyword: cwr.Criteria.Keyword,
|
||||||
|
Flair: cwr.Criteria.Flair,
|
||||||
|
Domain: cwr.Criteria.Domain,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.watcherRepo.Create(ctx, &watcher); err != nil {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
id, err := strconv.ParseInt(vars["watcherID"], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apns := vars["apns"]
|
||||||
|
|
||||||
|
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher, err := a.watcherRepo.GetByID(ctx, id)
|
||||||
|
if err != nil || watcher.DeviceID != dev.ID {
|
||||||
|
a.errorResponse(w, r, 422, "nice try")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = a.watcherRepo.Delete(ctx, id)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
type watcherItem struct {
|
||||||
|
ID int64
|
||||||
|
Upvotes int64
|
||||||
|
Keyword string
|
||||||
|
Flair string
|
||||||
|
Domain string
|
||||||
|
Hits int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
apns := vars["apns"]
|
||||||
|
redditID := vars["redditID"]
|
||||||
|
|
||||||
|
watchers, err := a.watcherRepo.GetByDeviceAPNSTokenAndAccountRedditID(ctx, apns, redditID)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 400, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wis := make([]watcherItem, len(watchers))
|
||||||
|
for i, watcher := range watchers {
|
||||||
|
wi := watcherItem{
|
||||||
|
ID: watcher.ID,
|
||||||
|
Upvotes: watcher.Upvotes,
|
||||||
|
Keyword: watcher.Keyword,
|
||||||
|
Flair: watcher.Flair,
|
||||||
|
Domain: watcher.Domain,
|
||||||
|
Hits: watcher.Hits,
|
||||||
|
}
|
||||||
|
|
||||||
|
wis[i] = wi
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(wis)
|
||||||
|
}
|
|
@ -22,7 +22,9 @@ import (
|
||||||
const (
|
const (
|
||||||
batchSize = 250
|
batchSize = 250
|
||||||
checkTimeout = 60 // how long until we force a check
|
checkTimeout = 60 // how long until we force a check
|
||||||
enqueueTimeout = 5 // how long until we try to re-enqueue
|
|
||||||
|
accountEnqueueTimeout = 5 // how frequently we want to check (seconds)
|
||||||
|
subredditEnqueueTimeout = 5 * 60 // how frequently we want to check (seconds)
|
||||||
|
|
||||||
staleAccountThreshold = 7200 // 2 hours
|
staleAccountThreshold = 7200 // 2 hours
|
||||||
staleDeviceThreshold = 604800 // 1 week
|
staleDeviceThreshold = 604800 // 1 week
|
||||||
|
@ -70,8 +72,14 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
subredditQueue, err := queue.OpenQueue("subreddits")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
s := gocron.NewScheduler(time.UTC)
|
s := gocron.NewScheduler(time.UTC)
|
||||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
|
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
|
||||||
|
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) })
|
||||||
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
|
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
|
||||||
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
|
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
|
||||||
_, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) })
|
_, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) })
|
||||||
|
@ -195,6 +203,70 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
||||||
|
now := time.Now()
|
||||||
|
ready := now.Unix() - subredditEnqueueTimeout
|
||||||
|
|
||||||
|
ids := []int64{}
|
||||||
|
|
||||||
|
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||||
|
stmt := `
|
||||||
|
WITH subreddit AS (
|
||||||
|
SELECT id
|
||||||
|
FROM subreddits
|
||||||
|
WHERE last_checked_at < $1
|
||||||
|
ORDER BY last_checked_at
|
||||||
|
LIMIT 100
|
||||||
|
)
|
||||||
|
UPDATE subreddits
|
||||||
|
SET last_checked_at = $2
|
||||||
|
WHERE subreddits.id IN(SELECT id FROM subreddit)
|
||||||
|
RETURNING subreddits.id`
|
||||||
|
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
_ = rows.Scan(&id)
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch batch of subreddits")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"count": len(ids),
|
||||||
|
"start": ready,
|
||||||
|
}).Debug("enqueueing subreddit batch")
|
||||||
|
|
||||||
|
batchIds := make([]string, len(ids))
|
||||||
|
for i, id := range ids {
|
||||||
|
batchIds[i] = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = queue.Publish(batchIds...); err != nil {
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to enqueue subreddit")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = statsd.Histogram("apollo.queue.subreddits.enqueued", float64(len(ids)), []string{}, 1)
|
||||||
|
_ = statsd.Histogram("apollo.queue.subreddits.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
|
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
|
@ -204,7 +276,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
|
||||||
// and at most 6 seconds ago. Also look for accounts that haven't been checked
|
// and at most 6 seconds ago. Also look for accounts that haven't been checked
|
||||||
// in over a minute.
|
// in over a minute.
|
||||||
ts := start.Unix()
|
ts := start.Unix()
|
||||||
ready := ts - enqueueTimeout
|
ready := ts - accountEnqueueTimeout
|
||||||
expired := ts - checkTimeout
|
expired := ts - checkTimeout
|
||||||
|
|
||||||
ids := []int64{}
|
ids := []int64{}
|
||||||
|
@ -292,9 +364,9 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = statsd.Histogram("apollo.queue.enqueued", float64(enqueued), []string{}, 1)
|
_ = statsd.Histogram("apollo.queue.notifications.enqueued", float64(enqueued), []string{}, 1)
|
||||||
_ = statsd.Histogram("apollo.queue.skipped", float64(skipped), []string{}, 1)
|
_ = statsd.Histogram("apollo.queue.notifications.skipped", float64(skipped), []string{}, 1)
|
||||||
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(start).Milliseconds()), []string{}, 1)
|
_ = statsd.Histogram("apollo.queue.notifications.runtime", float64(time.Since(start).Milliseconds()), []string{}, 1)
|
||||||
|
|
||||||
logger.WithFields(logrus.Fields{
|
logger.WithFields(logrus.Fields{
|
||||||
"count": enqueued,
|
"count": enqueued,
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
var (
|
var (
|
||||||
queues = map[string]worker.NewWorkerFn{
|
queues = map[string]worker.NewWorkerFn{
|
||||||
"notifications": worker.NewNotificationsWorker,
|
"notifications": worker.NewNotificationsWorker,
|
||||||
|
"subreddits": worker.NewSubredditsWorker,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ type Device struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceRepository interface {
|
type DeviceRepository interface {
|
||||||
|
GetByID(ctx context.Context, id int64) (Device, error)
|
||||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||||
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
|
|
||||||
|
|
25
internal/domain/subreddit.go
Normal file
25
internal/domain/subreddit.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Subreddit struct {
|
||||||
|
ID int64
|
||||||
|
|
||||||
|
// Reddit information
|
||||||
|
SubredditID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sr *Subreddit) NormalizedName() string {
|
||||||
|
return strings.ToLower(sr.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
type SubredditRepository interface {
|
||||||
|
GetByID(ctx context.Context, id int64) (Subreddit, error)
|
||||||
|
GetByName(ctx context.Context, name string) (Subreddit, error)
|
||||||
|
|
||||||
|
CreateOrUpdate(ctx context.Context, sr *Subreddit) error
|
||||||
|
}
|
29
internal/domain/watcher.go
Normal file
29
internal/domain/watcher.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type Watcher struct {
|
||||||
|
ID int64
|
||||||
|
CreatedAt float64
|
||||||
|
|
||||||
|
DeviceID int64
|
||||||
|
AccountID int64
|
||||||
|
SubredditID int64
|
||||||
|
|
||||||
|
Upvotes int64
|
||||||
|
Keyword string
|
||||||
|
Flair string
|
||||||
|
Domain string
|
||||||
|
Hits int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type WatcherRepository interface {
|
||||||
|
GetByID(ctx context.Context, id int64) (Watcher, error)
|
||||||
|
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
|
||||||
|
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
|
||||||
|
|
||||||
|
Create(ctx context.Context, watcher *Watcher) error
|
||||||
|
Update(ctx context.Context, watcher *Watcher) error
|
||||||
|
IncrementHits(ctx context.Context, id int64) error
|
||||||
|
Delete(ctx context.Context, id int64) error
|
||||||
|
}
|
|
@ -276,7 +276,7 @@ func (iapr *IAPResponse) handleAppleResponse() {
|
||||||
// For sandbox environment, be more lenient (just ensure bundle ID is accurate) because otherwise you'll break
|
// For sandbox environment, be more lenient (just ensure bundle ID is accurate) because otherwise you'll break
|
||||||
// things for TestFlight users (see: https://twitter.com/ChristianSelig/status/1414990459861098496)
|
// things for TestFlight users (see: https://twitter.com/ChristianSelig/status/1414990459861098496)
|
||||||
// TODO(andremedeiros): let this through for now
|
// TODO(andremedeiros): let this through for now
|
||||||
if iapr.Environment == Sandbox && false {
|
if iapr.Environment == Sandbox && true {
|
||||||
ultraProduct := VerificationProduct{Name: "ultra", Status: "SANDBOX", SubscriptionType: "SANDBOX"}
|
ultraProduct := VerificationProduct{Name: "ultra", Status: "SANDBOX", SubscriptionType: "SANDBOX"}
|
||||||
proProduct := VerificationProduct{Name: "pro", Status: "SANDBOX"}
|
proProduct := VerificationProduct{Name: "pro", Status: "SANDBOX"}
|
||||||
communityIconsProduct := VerificationProduct{Name: "community_icons", Status: "SANDBOX"}
|
communityIconsProduct := VerificationProduct{Name: "community_icons", Status: "SANDBOX"}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package reddit
|
package reddit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
|
@ -179,6 +180,48 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
|
||||||
|
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about.json", subreddit)
|
||||||
|
opts = append([]RequestOption{
|
||||||
|
WithMethod("GET"),
|
||||||
|
WithToken(rac.accessToken),
|
||||||
|
WithURL(url),
|
||||||
|
}, opts...)
|
||||||
|
req := NewRequest(opts...)
|
||||||
|
sr, err := rac.request(req, NewSubredditResponse, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sr.(*SubredditResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) subredditPosts(subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/%s.json", subreddit, sort)
|
||||||
|
opts = append([]RequestOption{
|
||||||
|
WithMethod("GET"),
|
||||||
|
WithToken(rac.accessToken),
|
||||||
|
WithURL(url),
|
||||||
|
}, opts...)
|
||||||
|
req := NewRequest(opts...)
|
||||||
|
|
||||||
|
lr, err := rac.request(req, NewListingResponse, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return lr.(*ListingResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) SubredditHot(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
return rac.subredditPosts(subreddit, "hot", opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
return rac.subredditPosts(subreddit, "new", opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) {
|
func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) {
|
||||||
opts = append([]RequestOption{
|
opts = append([]RequestOption{
|
||||||
WithTags([]string{"url:/api/v1/message/inbox"}),
|
WithTags([]string{"url:/api/v1/message/inbox"}),
|
||||||
|
|
|
@ -87,6 +87,10 @@ func WithBody(key, val string) RequestOption {
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithQuery(key, val string) RequestOption {
|
func WithQuery(key, val string) RequestOption {
|
||||||
|
if val == "" {
|
||||||
|
return func(req *Request) {}
|
||||||
|
}
|
||||||
|
|
||||||
return func(req *Request) {
|
return func(req *Request) {
|
||||||
req.query.Set(key, val)
|
req.query.Set(key, val)
|
||||||
}
|
}
|
||||||
|
|
110
internal/reddit/testdata/subreddit_about.json
vendored
Normal file
110
internal/reddit/testdata/subreddit_about.json
vendored
Normal file
File diff suppressed because one or more lines are too long
12766
internal/reddit/testdata/subreddit_new.json
vendored
Normal file
12766
internal/reddit/testdata/subreddit_new.json
vendored
Normal file
File diff suppressed because one or more lines are too long
|
@ -73,6 +73,11 @@ type Thing struct {
|
||||||
LinkTitle string `json:"link_title"`
|
LinkTitle string `json:"link_title"`
|
||||||
Destination string `json:"dest"`
|
Destination string `json:"dest"`
|
||||||
Subreddit string `json:"subreddit"`
|
Subreddit string `json:"subreddit"`
|
||||||
|
Score int64 `json:"score"`
|
||||||
|
SelfText string `json:"selftext"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Flair string `json:"flair"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Thing) FullName() string {
|
func (t *Thing) FullName() string {
|
||||||
|
@ -98,6 +103,12 @@ func NewThing(val *fastjson.Value) *Thing {
|
||||||
t.Destination = string(data.GetStringBytes("dest"))
|
t.Destination = string(data.GetStringBytes("dest"))
|
||||||
t.Subreddit = string(data.GetStringBytes("subreddit"))
|
t.Subreddit = string(data.GetStringBytes("subreddit"))
|
||||||
|
|
||||||
|
t.Score = data.GetInt64("score")
|
||||||
|
t.Title = string(data.GetStringBytes("title"))
|
||||||
|
t.SelfText = string(data.GetStringBytes("selftext"))
|
||||||
|
t.URL = string(data.GetStringBytes("url"))
|
||||||
|
t.Flair = string(data.GetStringBytes("link_flair_text"))
|
||||||
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,4 +142,22 @@ func NewListingResponse(val *fastjson.Value) interface{} {
|
||||||
return lr
|
return lr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SubredditResponse struct {
|
||||||
|
Thing
|
||||||
|
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSubredditResponse(val *fastjson.Value) interface{} {
|
||||||
|
sr := &SubredditResponse{}
|
||||||
|
|
||||||
|
sr.Kind = string(val.GetStringBytes("kind"))
|
||||||
|
|
||||||
|
data := val.Get("data")
|
||||||
|
sr.ID = string(data.GetStringBytes("id"))
|
||||||
|
sr.Name = string(data.GetStringBytes("display_name"))
|
||||||
|
|
||||||
|
return sr
|
||||||
|
}
|
||||||
|
|
||||||
var EmptyListingResponse = &ListingResponse{}
|
var EmptyListingResponse = &ListingResponse{}
|
||||||
|
|
|
@ -43,6 +43,7 @@ func TestRefreshTokenResponseParsing(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListingResponseParsing(t *testing.T) {
|
func TestListingResponseParsing(t *testing.T) {
|
||||||
|
// Message list
|
||||||
bb, err := ioutil.ReadFile("testdata/message_inbox.json")
|
bb, err := ioutil.ReadFile("testdata/message_inbox.json")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -74,4 +75,39 @@ func TestListingResponseParsing(t *testing.T) {
|
||||||
assert.Equal(t, "t1_h46tec3", thing.ParentID)
|
assert.Equal(t, "t1_h46tec3", thing.ParentID)
|
||||||
assert.Equal(t, "hello i am a cat", thing.LinkTitle)
|
assert.Equal(t, "hello i am a cat", thing.LinkTitle)
|
||||||
assert.Equal(t, "calicosummer", thing.Subreddit)
|
assert.Equal(t, "calicosummer", thing.Subreddit)
|
||||||
|
|
||||||
|
// Post list
|
||||||
|
bb, err = ioutil.ReadFile("testdata/subreddit_new.json")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
val, err = parser.ParseBytes(bb)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ret = NewListingResponse(val)
|
||||||
|
l = ret.(*ListingResponse)
|
||||||
|
assert.NotNil(t, l)
|
||||||
|
|
||||||
|
assert.Equal(t, 100, l.Count)
|
||||||
|
|
||||||
|
thing = l.Children[1]
|
||||||
|
assert.Equal(t, "Riven boss", thing.Title)
|
||||||
|
assert.Equal(t, "Question", thing.Flair)
|
||||||
|
assert.Contains(t, thing.SelfText, "never done riven")
|
||||||
|
assert.Equal(t, int64(1), thing.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubredditResponseParsing(t *testing.T) {
|
||||||
|
bb, err := ioutil.ReadFile("testdata/subreddit_about.json")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
val, err := parser.ParseBytes(bb)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ret := NewSubredditResponse(val)
|
||||||
|
s := ret.(*SubredditResponse)
|
||||||
|
assert.NotNil(t, s)
|
||||||
|
|
||||||
|
assert.Equal(t, "t5", s.Kind)
|
||||||
|
assert.Equal(t, "2vq0w", s.ID)
|
||||||
|
assert.Equal(t, "DestinyTheGame", s.Name)
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,6 +40,23 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
|
||||||
return devs, nil
|
return devs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domain.Device, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, apns_token, sandbox, active_until
|
||||||
|
FROM devices
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
devs, err := p.fetch(ctx, query, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return domain.Device{}, err
|
||||||
|
}
|
||||||
|
if len(devs) == 0 {
|
||||||
|
return domain.Device{}, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
return devs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) {
|
func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, apns_token, sandbox, active_until
|
SELECT id, apns_token, sandbox, active_until
|
||||||
|
|
90
internal/repository/postgres_subreddit.go
Normal file
90
internal/repository/postgres_subreddit.go
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgresSubredditRepository struct {
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresSubreddit(pool *pgxpool.Pool) domain.SubredditRepository {
|
||||||
|
return &postgresSubredditRepository{pool: pool}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Subreddit, error) {
|
||||||
|
rows, err := p.pool.Query(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var srs []domain.Subreddit
|
||||||
|
for rows.Next() {
|
||||||
|
var sr domain.Subreddit
|
||||||
|
if err := rows.Scan(
|
||||||
|
&sr.ID,
|
||||||
|
&sr.SubredditID,
|
||||||
|
&sr.Name,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
srs = append(srs, sr)
|
||||||
|
}
|
||||||
|
return srs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, subreddit_id, name
|
||||||
|
FROM subreddits
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
srs, err := p.fetch(ctx, query, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return domain.Subreddit{}, err
|
||||||
|
}
|
||||||
|
if len(srs) == 0 {
|
||||||
|
return domain.Subreddit{}, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
return srs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, subreddit_id, name
|
||||||
|
FROM subreddits
|
||||||
|
WHERE name = $1`
|
||||||
|
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
|
||||||
|
srs, err := p.fetch(ctx, query, name)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return domain.Subreddit{}, err
|
||||||
|
}
|
||||||
|
if len(srs) == 0 {
|
||||||
|
return domain.Subreddit{}, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
return srs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *domain.Subreddit) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO subreddits (subreddit_id, name)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
ON CONFLICT(subreddit_id) DO NOTHING
|
||||||
|
RETURNING id`
|
||||||
|
|
||||||
|
return p.pool.QueryRow(
|
||||||
|
ctx,
|
||||||
|
query,
|
||||||
|
sr.SubredditID,
|
||||||
|
sr.NormalizedName(),
|
||||||
|
).Scan(&sr.ID)
|
||||||
|
}
|
165
internal/repository/postgres_watcher.go
Normal file
165
internal/repository/postgres_watcher.go
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgresWatcherRepository struct {
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresWatcher(pool *pgxpool.Pool) domain.WatcherRepository {
|
||||||
|
return &postgresWatcherRepository{pool: pool}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Watcher, error) {
|
||||||
|
rows, err := p.pool.Query(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var watchers []domain.Watcher
|
||||||
|
for rows.Next() {
|
||||||
|
var watcher domain.Watcher
|
||||||
|
if err := rows.Scan(
|
||||||
|
&watcher.ID,
|
||||||
|
&watcher.CreatedAt,
|
||||||
|
&watcher.DeviceID,
|
||||||
|
&watcher.AccountID,
|
||||||
|
&watcher.SubredditID,
|
||||||
|
&watcher.Upvotes,
|
||||||
|
&watcher.Keyword,
|
||||||
|
&watcher.Flair,
|
||||||
|
&watcher.Domain,
|
||||||
|
&watcher.Hits,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
watchers = append(watchers, watcher)
|
||||||
|
}
|
||||||
|
return watchers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits
|
||||||
|
FROM watchers
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
watchers, err := p.fetch(ctx, query, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return domain.Watcher{}, err
|
||||||
|
}
|
||||||
|
if len(watchers) == 0 {
|
||||||
|
return domain.Watcher{}, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
return watchers[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits
|
||||||
|
FROM watchers
|
||||||
|
WHERE subreddit_id = $1`
|
||||||
|
|
||||||
|
return p.fetch(ctx, query, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]domain.Watcher, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
watchers.id,
|
||||||
|
watchers.created_at,
|
||||||
|
watchers.device_id,
|
||||||
|
watchers.account_id,
|
||||||
|
watchers.subreddit_id,
|
||||||
|
watchers.upvotes,
|
||||||
|
watchers.keyword,
|
||||||
|
watchers.flair,
|
||||||
|
watchers.domain,
|
||||||
|
watchers.hits
|
||||||
|
FROM watchers
|
||||||
|
INNER JOIN accounts ON watchers.account_id = accounts.id
|
||||||
|
INNER JOIN devices ON watchers.device_id = devices.id
|
||||||
|
WHERE
|
||||||
|
devices.apns_token = $1 AND
|
||||||
|
accounts.account_id = $2`
|
||||||
|
|
||||||
|
return p.fetch(ctx, query, apns, rid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.Watcher) error {
|
||||||
|
now := float64(time.Now().UTC().Unix())
|
||||||
|
|
||||||
|
query := `
|
||||||
|
INSERT INTO watchers
|
||||||
|
(created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
RETURNING id`
|
||||||
|
|
||||||
|
return p.pool.QueryRow(
|
||||||
|
ctx,
|
||||||
|
query,
|
||||||
|
now,
|
||||||
|
watcher.DeviceID,
|
||||||
|
watcher.AccountID,
|
||||||
|
watcher.SubredditID,
|
||||||
|
watcher.Upvotes,
|
||||||
|
watcher.Keyword,
|
||||||
|
watcher.Flair,
|
||||||
|
watcher.Domain,
|
||||||
|
).Scan(&watcher.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error {
|
||||||
|
query := `
|
||||||
|
UPDATE watchers
|
||||||
|
SET upvotes = $2,
|
||||||
|
keyword = $3,
|
||||||
|
flair = $4,
|
||||||
|
domain = $5,
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
res, err := p.pool.Exec(
|
||||||
|
ctx,
|
||||||
|
query,
|
||||||
|
watcher.ID,
|
||||||
|
watcher.Upvotes,
|
||||||
|
watcher.Keyword,
|
||||||
|
watcher.Flair,
|
||||||
|
watcher.Domain,
|
||||||
|
)
|
||||||
|
|
||||||
|
if res.RowsAffected() != 1 {
|
||||||
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error {
|
||||||
|
query := `UPDATE watchers SET hits = hits + 1 WHERE id = $1`
|
||||||
|
res, err := p.pool.Exec(ctx, query, id)
|
||||||
|
|
||||||
|
if res.RowsAffected() != 1 {
|
||||||
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) Delete(ctx context.Context, id int64) error {
|
||||||
|
query := `DELETE FROM watchers WHERE id = $1`
|
||||||
|
res, err := p.pool.Exec(ctx, query, id)
|
||||||
|
|
||||||
|
if res.RowsAffected() != 1 {
|
||||||
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
|
@ -165,7 +165,6 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
account, err := nc.accountRepo.GetByID(ctx, id)
|
account, err := nc.accountRepo.GetByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"account#username": account.NormalizedUsername(),
|
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to fetch account from database")
|
}).Error("failed to fetch account from database")
|
||||||
return
|
return
|
||||||
|
|
426
internal/worker/subreddits.go
Normal file
426
internal/worker/subreddits.go
Normal file
|
@ -0,0 +1,426 @@
|
||||||
|
package worker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/adjust/rmq/v4"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
"github.com/sideshow/apns2"
|
||||||
|
"github.com/sideshow/apns2/payload"
|
||||||
|
"github.com/sideshow/apns2/token"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
type subredditsWorker struct {
|
||||||
|
logger *logrus.Logger
|
||||||
|
statsd *statsd.Client
|
||||||
|
db *pgxpool.Pool
|
||||||
|
redis *redis.Client
|
||||||
|
queue rmq.Connection
|
||||||
|
reddit *reddit.Client
|
||||||
|
apns *token.Token
|
||||||
|
|
||||||
|
consumers int
|
||||||
|
|
||||||
|
accountRepo domain.AccountRepository
|
||||||
|
deviceRepo domain.DeviceRepository
|
||||||
|
subredditRepo domain.SubredditRepository
|
||||||
|
watcherRepo domain.WatcherRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
||||||
|
reddit := reddit.NewClient(
|
||||||
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
|
statsd,
|
||||||
|
consumers,
|
||||||
|
)
|
||||||
|
|
||||||
|
var apns *token.Token
|
||||||
|
{
|
||||||
|
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
apns = &token.Token{
|
||||||
|
AuthKey: authKey,
|
||||||
|
KeyID: os.Getenv("APPLE_KEY_ID"),
|
||||||
|
TeamID: os.Getenv("APPLE_TEAM_ID"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &subredditsWorker{
|
||||||
|
logger,
|
||||||
|
statsd,
|
||||||
|
db,
|
||||||
|
redis,
|
||||||
|
queue,
|
||||||
|
reddit,
|
||||||
|
apns,
|
||||||
|
consumers,
|
||||||
|
|
||||||
|
repository.NewPostgresAccount(db),
|
||||||
|
repository.NewPostgresDevice(db),
|
||||||
|
repository.NewPostgresSubreddit(db),
|
||||||
|
repository.NewPostgresWatcher(db),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *subredditsWorker) Start() error {
|
||||||
|
queue, err := sw.queue.OpenQueue("subreddits")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sw.logger.WithFields(logrus.Fields{
|
||||||
|
"numConsumers": sw.consumers,
|
||||||
|
}).Info("starting up subreddits worker")
|
||||||
|
|
||||||
|
prefetchLimit := int64(sw.consumers * 2)
|
||||||
|
|
||||||
|
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _ := os.Hostname()
|
||||||
|
|
||||||
|
for i := 0; i < sw.consumers; i++ {
|
||||||
|
name := fmt.Sprintf("consumer %s-%d", host, i)
|
||||||
|
|
||||||
|
consumer := NewSubredditsConsumer(sw, i)
|
||||||
|
if _, err := queue.AddConsumer(name, consumer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *subredditsWorker) Stop() {
|
||||||
|
<-sw.queue.StopAllConsuming() // wait for all Consume() calls to finish
|
||||||
|
}
|
||||||
|
|
||||||
|
type subredditsConsumer struct {
|
||||||
|
*subredditsWorker
|
||||||
|
tag int
|
||||||
|
|
||||||
|
apnsSandbox *apns2.Client
|
||||||
|
apnsProduction *apns2.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSubredditsConsumer(sw *subredditsWorker, tag int) *subredditsConsumer {
|
||||||
|
return &subredditsConsumer{
|
||||||
|
sw,
|
||||||
|
tag,
|
||||||
|
apns2.NewTokenClient(sw.apns),
|
||||||
|
apns2.NewTokenClient(sw.apns).Production(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": delivery.Payload(),
|
||||||
|
}).Debug("starting job")
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": delivery.Payload(),
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to parse subreddit ID")
|
||||||
|
|
||||||
|
_ = delivery.Reject()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = delivery.Ack() }()
|
||||||
|
|
||||||
|
subreddit, err := sc.subredditRepo.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch subreddit from database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
watchers, err := sc.watcherRepo.GetBySubredditID(ctx, subreddit.ID)
|
||||||
|
if err != nil {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch watchers from database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(watchers) == 0 {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Info("no watchers for subreddit, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
threshold := float64(time.Now().AddDate(0, 0, -1).UTC().Unix())
|
||||||
|
posts := []*reddit.Thing{}
|
||||||
|
before := ""
|
||||||
|
finished := false
|
||||||
|
seenPosts := map[string]bool{}
|
||||||
|
|
||||||
|
// Load 500 newest posts
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
}).Debug("loading up to 500 new posts")
|
||||||
|
|
||||||
|
for page := 0; page < 5; page++ {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"page": page,
|
||||||
|
}).Debug("loading new posts")
|
||||||
|
|
||||||
|
i := rand.Intn(len(watchers))
|
||||||
|
watcher := watchers[i]
|
||||||
|
|
||||||
|
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||||
|
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||||
|
|
||||||
|
sps, err := rac.SubredditNew(
|
||||||
|
subreddit.Name,
|
||||||
|
reddit.WithQuery("before", before),
|
||||||
|
reddit.WithQuery("limit", "100"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch new posts")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"count": sps.Count,
|
||||||
|
"page": page,
|
||||||
|
}).Debug("loaded new posts for page")
|
||||||
|
|
||||||
|
// If it's empty, we're done
|
||||||
|
if sps.Count == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we don't have 100 posts, we're going to be done
|
||||||
|
if sps.Count < 100 {
|
||||||
|
finished = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, post := range sps.Children {
|
||||||
|
if post.CreatedAt < threshold {
|
||||||
|
finished = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := seenPosts[post.ID]; !ok {
|
||||||
|
posts = append(posts, post)
|
||||||
|
seenPosts[post.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finished {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"page": page,
|
||||||
|
}).Debug("reached date threshold")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load hot posts
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
}).Debug("loading hot posts")
|
||||||
|
{
|
||||||
|
i := rand.Intn(len(watchers))
|
||||||
|
watcher := watchers[i]
|
||||||
|
|
||||||
|
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||||
|
rac := sc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||||
|
sps, err := rac.SubredditHot(
|
||||||
|
subreddit.Name,
|
||||||
|
reddit.WithQuery("limit", "100"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch hot posts")
|
||||||
|
} else {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"count": sps.Count,
|
||||||
|
}).Debug("loaded hot posts")
|
||||||
|
|
||||||
|
for _, post := range sps.Children {
|
||||||
|
if post.CreatedAt < threshold {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, ok := seenPosts[post.ID]; !ok {
|
||||||
|
posts = append(posts, post)
|
||||||
|
seenPosts[post.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"count": len(posts),
|
||||||
|
}).Debug("checking posts for hits")
|
||||||
|
for _, post := range posts {
|
||||||
|
ids := []int64{}
|
||||||
|
|
||||||
|
for _, watcher := range watchers {
|
||||||
|
// Make sure we only alert on posts created after the search
|
||||||
|
if watcher.CreatedAt > post.CreatedAt {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
matched := true
|
||||||
|
|
||||||
|
if watcher.Upvotes > 0 && post.Score < watcher.Upvotes {
|
||||||
|
matched = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if watcher.Keyword != "" && !strings.Contains(post.Title, watcher.Keyword) {
|
||||||
|
matched = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if watcher.Flair != "" && !strings.Contains(post.Flair, watcher.Flair) {
|
||||||
|
matched = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if watcher.Domain != "" && !strings.Contains(post.URL, watcher.Domain) {
|
||||||
|
matched = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = sc.watcherRepo.IncrementHits(ctx, watcher.ID)
|
||||||
|
|
||||||
|
lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID)
|
||||||
|
notified, _ := sc.redis.Get(ctx, lockKey).Bool()
|
||||||
|
|
||||||
|
if notified {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"watcher#id": watcher.ID,
|
||||||
|
"post#id": post.ID,
|
||||||
|
}).Debug("already notified, skipping")
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"watcher#id": watcher.ID,
|
||||||
|
"post#id": post.ID,
|
||||||
|
}).Debug("got a hit")
|
||||||
|
|
||||||
|
sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour)
|
||||||
|
ids = append(ids, watcher.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
"post#id": post.ID,
|
||||||
|
"count": len(ids),
|
||||||
|
}).Debug("got hits for post")
|
||||||
|
|
||||||
|
notification := &apns2.Notification{}
|
||||||
|
notification.Topic = "com.christianselig.Apollo"
|
||||||
|
notification.Payload = payloadFromPost(post)
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
device, _ := sc.deviceRepo.GetByID(ctx, id)
|
||||||
|
notification.DeviceToken = device.APNSToken
|
||||||
|
|
||||||
|
client := sc.apnsProduction
|
||||||
|
if device.Sandbox {
|
||||||
|
client = sc.apnsSandbox
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Push(notification)
|
||||||
|
if err != nil {
|
||||||
|
_ = sc.statsd.Incr("apns.notification.errors", []string{}, 1)
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"device#id": device.ID,
|
||||||
|
"err": err,
|
||||||
|
"status": res.StatusCode,
|
||||||
|
"reason": res.Reason,
|
||||||
|
}).Error("failed to send notification")
|
||||||
|
} else {
|
||||||
|
_ = sc.statsd.Incr("apns.notification.sent", []string{}, 1)
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"device#id": device.ID,
|
||||||
|
"device#token": device.APNSToken,
|
||||||
|
}).Info("sent notification")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"subreddit#name": subreddit.Name,
|
||||||
|
}).Debug("finishing job")
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadFromPost(post *reddit.Thing) *payload.Payload {
|
||||||
|
payload := payload.
|
||||||
|
NewPayload().
|
||||||
|
AlertTitle(post.Title).
|
||||||
|
AlertSubtitle(fmt.Sprintf("in r/%s", post.Subreddit)).
|
||||||
|
AlertSummaryArg(post.Subreddit).
|
||||||
|
Category("post-watch").
|
||||||
|
Custom("post_title", post.Title).
|
||||||
|
Custom("post_id", post.ID).
|
||||||
|
Custom("subreddit", post.Subreddit).
|
||||||
|
Custom("author", post.Author).
|
||||||
|
Custom("post_age", post.CreatedAt).
|
||||||
|
MutableContent().
|
||||||
|
Sound("traloop.wav")
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
Loading…
Reference in a new issue