mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-10 22:17:44 +00:00
Merge pull request #28 from christianselig/feature/follow-users
This commit is contained in:
commit
38e596b27e
18 changed files with 787 additions and 55 deletions
5
go.sum
5
go.sum
|
@ -64,7 +64,6 @@ github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq
|
||||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||||
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
|
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
|
||||||
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
||||||
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
|
|
||||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
@ -122,8 +121,6 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
|
||||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||||
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8=
|
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8=
|
||||||
github.com/go-redis/redis/v8 v8.3.2/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU=
|
github.com/go-redis/redis/v8 v8.3.2/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU=
|
||||||
github.com/go-redis/redis/v8 v8.11.3 h1:GCjoYp8c+yQTJfc0n69iwSiHjvuAdruxl7elnZCxgt8=
|
|
||||||
github.com/go-redis/redis/v8 v8.11.3/go.mod h1:xNJ9xDG09FsIPwh3bWdk+0oDWHbtF9rPN0F/oD9XeKc=
|
|
||||||
github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg=
|
github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg=
|
||||||
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
|
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
|
@ -349,8 +346,6 @@ github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vv
|
||||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||||
github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
|
github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
|
||||||
github.com/onsi/gomega v1.15.0 h1:WjP/FQ/sk43MRmnEcT+MlDw2TFvkrXlprrPST/IudjU=
|
|
||||||
github.com/onsi/gomega v1.15.0/go.mod h1:cIuvLEne0aoVhAgh/O6ac0Op8WWw9H6eYCriF+tEHG0=
|
|
||||||
github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c=
|
github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c=
|
||||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||||
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||||
|
|
|
@ -28,6 +28,7 @@ type api struct {
|
||||||
deviceRepo domain.DeviceRepository
|
deviceRepo domain.DeviceRepository
|
||||||
subredditRepo domain.SubredditRepository
|
subredditRepo domain.SubredditRepository
|
||||||
watcherRepo domain.WatcherRepository
|
watcherRepo domain.WatcherRepository
|
||||||
|
userRepo domain.UserRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
|
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
|
||||||
|
@ -56,6 +57,7 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p
|
||||||
deviceRepo := repository.NewPostgresDevice(pool)
|
deviceRepo := repository.NewPostgresDevice(pool)
|
||||||
subredditRepo := repository.NewPostgresSubreddit(pool)
|
subredditRepo := repository.NewPostgresSubreddit(pool)
|
||||||
watcherRepo := repository.NewPostgresWatcher(pool)
|
watcherRepo := repository.NewPostgresWatcher(pool)
|
||||||
|
userRepo := repository.NewPostgresUser(pool)
|
||||||
|
|
||||||
return &api{
|
return &api{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
@ -67,6 +69,7 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p
|
||||||
deviceRepo: deviceRepo,
|
deviceRepo: deviceRepo,
|
||||||
subredditRepo: subredditRepo,
|
subredditRepo: subredditRepo,
|
||||||
watcherRepo: watcherRepo,
|
watcherRepo: watcherRepo,
|
||||||
|
userRepo: userRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/domain"
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
@ -18,6 +19,8 @@ type watcherCriteria struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type createWatcherRequest struct {
|
type createWatcherRequest struct {
|
||||||
|
Type string
|
||||||
|
User string
|
||||||
Subreddit string
|
Subreddit string
|
||||||
Criteria watcherCriteria
|
Criteria watcherCriteria
|
||||||
}
|
}
|
||||||
|
@ -69,6 +72,17 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
|
||||||
|
|
||||||
|
watcher := domain.Watcher{
|
||||||
|
DeviceID: dev.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Upvotes: cwr.Criteria.Upvotes,
|
||||||
|
Keyword: strings.ToLower(cwr.Criteria.Keyword),
|
||||||
|
Flair: strings.ToLower(cwr.Criteria.Flair),
|
||||||
|
Domain: strings.ToLower(cwr.Criteria.Domain),
|
||||||
|
}
|
||||||
|
|
||||||
|
if cwr.Type == "subreddit" {
|
||||||
srr, err := ac.SubredditAbout(cwr.Subreddit)
|
srr, err := ac.SubredditAbout(cwr.Subreddit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.errorResponse(w, r, 422, err.Error())
|
a.errorResponse(w, r, 422, err.Error())
|
||||||
|
@ -88,14 +102,33 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
watcher := domain.Watcher{
|
watcher.Type = domain.SubredditWatcher
|
||||||
SubredditID: sr.ID,
|
watcher.WatcheeID = sr.ID
|
||||||
DeviceID: dev.ID,
|
} else if cwr.Type == "user" {
|
||||||
AccountID: account.ID,
|
urr, err := ac.UserAbout(cwr.User)
|
||||||
Upvotes: cwr.Criteria.Upvotes,
|
if err != nil {
|
||||||
Keyword: cwr.Criteria.Keyword,
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
Flair: cwr.Criteria.Flair,
|
return
|
||||||
Domain: cwr.Criteria.Domain,
|
}
|
||||||
|
|
||||||
|
if !urr.AcceptFollowers {
|
||||||
|
a.errorResponse(w, r, 422, "no followers accepted")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u := domain.User{UserID: urr.ID, Name: urr.Name}
|
||||||
|
err = a.userRepo.CreateOrUpdate(ctx, &u)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher.Type = domain.UserWatcher
|
||||||
|
watcher.WatcheeID = u.ID
|
||||||
|
} else {
|
||||||
|
a.errorResponse(w, r, 422, "unknown watcher type")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.watcherRepo.Create(ctx, &watcher); err != nil {
|
if err := a.watcherRepo.Create(ctx, &watcher); err != nil {
|
||||||
|
|
|
@ -24,7 +24,8 @@ const (
|
||||||
checkTimeout = 60 // how long until we force a check
|
checkTimeout = 60 // how long until we force a check
|
||||||
|
|
||||||
accountEnqueueTimeout = 5 // how frequently we want to check (seconds)
|
accountEnqueueTimeout = 5 // how frequently we want to check (seconds)
|
||||||
subredditEnqueueTimeout = 5 * 60 // how frequently we want to check (seconds)
|
subredditEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds)
|
||||||
|
userEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds)
|
||||||
|
|
||||||
staleAccountThreshold = 7200 // 2 hours
|
staleAccountThreshold = 7200 // 2 hours
|
||||||
)
|
)
|
||||||
|
@ -76,9 +77,15 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userQueue, err := queue.OpenQueue("users")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
s := gocron.NewScheduler(time.UTC)
|
s := gocron.NewScheduler(time.UTC)
|
||||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
|
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
|
||||||
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) })
|
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) })
|
||||||
|
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
|
||||||
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
|
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
|
||||||
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
|
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
|
||||||
_, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) })
|
_, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) })
|
||||||
|
@ -202,6 +209,70 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
||||||
|
now := time.Now()
|
||||||
|
ready := now.Unix() - userEnqueueTimeout
|
||||||
|
|
||||||
|
ids := []int64{}
|
||||||
|
|
||||||
|
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||||
|
stmt := `
|
||||||
|
WITH userb AS (
|
||||||
|
SELECT id
|
||||||
|
FROM users
|
||||||
|
WHERE last_checked_at < $1
|
||||||
|
ORDER BY last_checked_at
|
||||||
|
LIMIT 100
|
||||||
|
)
|
||||||
|
UPDATE users
|
||||||
|
SET last_checked_at = $2
|
||||||
|
WHERE users.id IN(SELECT id FROM userb)
|
||||||
|
RETURNING users.id`
|
||||||
|
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
_ = rows.Scan(&id)
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch batch of users")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"count": len(ids),
|
||||||
|
"start": ready,
|
||||||
|
}).Debug("enqueueing user batch")
|
||||||
|
|
||||||
|
batchIds := make([]string, len(ids))
|
||||||
|
for i, id := range ids {
|
||||||
|
batchIds[i] = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = queue.Publish(batchIds...); err != nil {
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to enqueue user")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = statsd.Histogram("apollo.queue.users.enqueued", float64(len(ids)), []string{}, 1)
|
||||||
|
_ = statsd.Histogram("apollo.queue.users.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
ready := now.Unix() - subredditEnqueueTimeout
|
ready := now.Unix() - subredditEnqueueTimeout
|
||||||
|
|
|
@ -15,6 +15,7 @@ var (
|
||||||
queues = map[string]worker.NewWorkerFn{
|
queues = map[string]worker.NewWorkerFn{
|
||||||
"notifications": worker.NewNotificationsWorker,
|
"notifications": worker.NewNotificationsWorker,
|
||||||
"subreddits": worker.NewSubredditsWorker,
|
"subreddits": worker.NewSubredditsWorker,
|
||||||
|
"users": worker.NewUsersWorker,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,10 @@ func NewRedisClient(ctx context.Context) (*redis.Client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabasePool(ctx context.Context, maxConns int) (*pgxpool.Pool, error) {
|
func NewDatabasePool(ctx context.Context, maxConns int) (*pgxpool.Pool, error) {
|
||||||
|
if maxConns == 0 {
|
||||||
|
maxConns = 1
|
||||||
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf(
|
url := fmt.Sprintf(
|
||||||
"%s?pool_max_conns=%d&pool_min_conns=%d",
|
"%s?pool_max_conns=%d&pool_min_conns=%d",
|
||||||
os.Getenv("DATABASE_CONNECTION_POOL_URL"),
|
os.Getenv("DATABASE_CONNECTION_POOL_URL"),
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
type Subreddit struct {
|
type Subreddit struct {
|
||||||
ID int64
|
ID int64
|
||||||
|
LastCheckedAt float64
|
||||||
|
|
||||||
// Reddit information
|
// Reddit information
|
||||||
SubredditID string
|
SubredditID string
|
||||||
|
|
27
internal/domain/user.go
Normal file
27
internal/domain/user.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int64
|
||||||
|
LastCheckedAt float64
|
||||||
|
|
||||||
|
// Reddit information
|
||||||
|
UserID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) NormalizedName() string {
|
||||||
|
return strings.ToLower(u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserRepository interface {
|
||||||
|
GetByID(context.Context, int64) (User, error)
|
||||||
|
GetByName(context.Context, string) (User, error)
|
||||||
|
|
||||||
|
CreateOrUpdate(context.Context, *User) error
|
||||||
|
Delete(context.Context, int64) error
|
||||||
|
}
|
|
@ -2,13 +2,22 @@ package domain
|
||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
|
type WatcherType int64
|
||||||
|
|
||||||
|
const (
|
||||||
|
SubredditWatcher WatcherType = iota
|
||||||
|
UserWatcher
|
||||||
|
)
|
||||||
|
|
||||||
type Watcher struct {
|
type Watcher struct {
|
||||||
ID int64
|
ID int64
|
||||||
CreatedAt float64
|
CreatedAt float64
|
||||||
|
LastNotifiedAt float64
|
||||||
|
|
||||||
DeviceID int64
|
DeviceID int64
|
||||||
AccountID int64
|
AccountID int64
|
||||||
SubredditID int64
|
Type WatcherType
|
||||||
|
WatcheeID int64
|
||||||
|
|
||||||
Upvotes int64
|
Upvotes int64
|
||||||
Keyword string
|
Keyword string
|
||||||
|
@ -20,10 +29,12 @@ type Watcher struct {
|
||||||
type WatcherRepository interface {
|
type WatcherRepository interface {
|
||||||
GetByID(ctx context.Context, id int64) (Watcher, error)
|
GetByID(ctx context.Context, id int64) (Watcher, error)
|
||||||
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
|
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
|
||||||
|
GetByUserID(ctx context.Context, id int64) ([]Watcher, error)
|
||||||
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
|
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
|
||||||
|
|
||||||
Create(ctx context.Context, watcher *Watcher) error
|
Create(ctx context.Context, watcher *Watcher) error
|
||||||
Update(ctx context.Context, watcher *Watcher) error
|
Update(ctx context.Context, watcher *Watcher) error
|
||||||
IncrementHits(ctx context.Context, id int64) error
|
IncrementHits(ctx context.Context, id int64) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
DeleteByTypeAndWatcheeID(context.Context, WatcherType, int64) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -171,6 +171,41 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*ListingResponse, error) {
|
||||||
|
url := fmt.Sprintf("https://oauth.reddit.com/u/%s/submitted.json", user)
|
||||||
|
opts = append([]RequestOption{
|
||||||
|
WithMethod("GET"),
|
||||||
|
WithToken(rac.accessToken),
|
||||||
|
WithURL(url),
|
||||||
|
}, opts...)
|
||||||
|
req := NewRequest(opts...)
|
||||||
|
|
||||||
|
lr, err := rac.request(req, NewListingResponse, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return lr.(*ListingResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*UserResponse, error) {
|
||||||
|
url := fmt.Sprintf("https://oauth.reddit.com/u/%s/about.json", user)
|
||||||
|
opts = append([]RequestOption{
|
||||||
|
WithMethod("GET"),
|
||||||
|
WithToken(rac.accessToken),
|
||||||
|
WithURL(url),
|
||||||
|
}, opts...)
|
||||||
|
req := NewRequest(opts...)
|
||||||
|
ur, err := rac.request(req, NewUserResponse, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ur.(*UserResponse), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
|
func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
|
||||||
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about.json", subreddit)
|
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about.json", subreddit)
|
||||||
opts = append([]RequestOption{
|
opts = append([]RequestOption{
|
||||||
|
|
75
internal/reddit/testdata/user_about.json
vendored
Normal file
75
internal/reddit/testdata/user_about.json
vendored
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
{
|
||||||
|
"kind": "t2",
|
||||||
|
"data": {
|
||||||
|
"is_employee": false,
|
||||||
|
"is_friend": false,
|
||||||
|
"subreddit": {
|
||||||
|
"default_set": true,
|
||||||
|
"user_is_contributor": false,
|
||||||
|
"banner_img": "",
|
||||||
|
"restrict_posting": true,
|
||||||
|
"user_is_banned": false,
|
||||||
|
"free_form_reports": true,
|
||||||
|
"community_icon": null,
|
||||||
|
"show_media": true,
|
||||||
|
"icon_color": "#FFD635",
|
||||||
|
"user_is_muted": false,
|
||||||
|
"display_name": "u_changelog",
|
||||||
|
"header_img": null,
|
||||||
|
"title": "",
|
||||||
|
"previous_names": [],
|
||||||
|
"over_18": false,
|
||||||
|
"icon_size": [
|
||||||
|
256,
|
||||||
|
256
|
||||||
|
],
|
||||||
|
"primary_color": "",
|
||||||
|
"icon_img": "https://www.redditstatic.com/avatars/defaults/v2/avatar_default_2.png",
|
||||||
|
"description": "",
|
||||||
|
"submit_link_label": "",
|
||||||
|
"header_size": null,
|
||||||
|
"restrict_commenting": false,
|
||||||
|
"subscribers": 0,
|
||||||
|
"submit_text_label": "",
|
||||||
|
"is_default_icon": true,
|
||||||
|
"link_flair_position": "",
|
||||||
|
"display_name_prefixed": "u/changelog",
|
||||||
|
"key_color": "",
|
||||||
|
"name": "t5_c30sw",
|
||||||
|
"is_default_banner": true,
|
||||||
|
"url": "/user/changelog/",
|
||||||
|
"quarantine": false,
|
||||||
|
"banner_size": null,
|
||||||
|
"user_is_moderator": false,
|
||||||
|
"accept_followers": true,
|
||||||
|
"public_description": "",
|
||||||
|
"link_flair_enabled": false,
|
||||||
|
"disable_contributor_requests": false,
|
||||||
|
"subreddit_type": "user",
|
||||||
|
"user_is_subscriber": false
|
||||||
|
},
|
||||||
|
"snoovatar_size": null,
|
||||||
|
"awardee_karma": 23,
|
||||||
|
"id": "1ia22",
|
||||||
|
"verified": true,
|
||||||
|
"is_gold": false,
|
||||||
|
"is_mod": true,
|
||||||
|
"awarder_karma": 0,
|
||||||
|
"has_verified_email": true,
|
||||||
|
"icon_img": "https://www.redditstatic.com/avatars/defaults/v2/avatar_default_2.png",
|
||||||
|
"hide_from_robots": true,
|
||||||
|
"link_karma": 2676,
|
||||||
|
"pref_show_snoovatar": false,
|
||||||
|
"is_blocked": false,
|
||||||
|
"total_karma": 4132,
|
||||||
|
"accept_chats": false,
|
||||||
|
"name": "changelog",
|
||||||
|
"created": 1176721866.0,
|
||||||
|
"created_utc": 1176721866.0,
|
||||||
|
"snoovatar_img": "",
|
||||||
|
"comment_karma": 1433,
|
||||||
|
"accept_followers": true,
|
||||||
|
"has_subscribed": true,
|
||||||
|
"accept_pms": true
|
||||||
|
}
|
||||||
|
}
|
|
@ -160,4 +160,23 @@ func NewSubredditResponse(val *fastjson.Value) interface{} {
|
||||||
return sr
|
return sr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserResponse struct {
|
||||||
|
Thing
|
||||||
|
|
||||||
|
AcceptFollowers bool
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserResponse(val *fastjson.Value) interface{} {
|
||||||
|
ur := &UserResponse{}
|
||||||
|
ur.Kind = string(val.GetStringBytes("kind"))
|
||||||
|
|
||||||
|
data := val.Get("data")
|
||||||
|
ur.ID = string(data.GetStringBytes("id"))
|
||||||
|
ur.Name = string(data.GetStringBytes("name"))
|
||||||
|
ur.AcceptFollowers = data.GetBool("accept_followers")
|
||||||
|
|
||||||
|
return ur
|
||||||
|
}
|
||||||
|
|
||||||
var EmptyListingResponse = &ListingResponse{}
|
var EmptyListingResponse = &ListingResponse{}
|
||||||
|
|
|
@ -111,3 +111,20 @@ func TestSubredditResponseParsing(t *testing.T) {
|
||||||
assert.Equal(t, "2vq0w", s.ID)
|
assert.Equal(t, "2vq0w", s.ID)
|
||||||
assert.Equal(t, "DestinyTheGame", s.Name)
|
assert.Equal(t, "DestinyTheGame", s.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUserResponseParsing(t *testing.T) {
|
||||||
|
bb, err := ioutil.ReadFile("testdata/user_about.json")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
val, err := parser.ParseBytes(bb)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ret := NewUserResponse(val)
|
||||||
|
u := ret.(*UserResponse)
|
||||||
|
assert.NotNil(t, u)
|
||||||
|
|
||||||
|
assert.Equal(t, "t2", u.Kind)
|
||||||
|
assert.Equal(t, "1ia22", u.ID)
|
||||||
|
assert.Equal(t, "changelog", u.Name)
|
||||||
|
assert.Equal(t, true, u.AcceptFollowers)
|
||||||
|
}
|
||||||
|
|
|
@ -30,6 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
|
||||||
&sr.ID,
|
&sr.ID,
|
||||||
&sr.SubredditID,
|
&sr.SubredditID,
|
||||||
&sr.Name,
|
&sr.Name,
|
||||||
|
&sr.LastCheckedAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -40,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
|
||||||
|
|
||||||
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
|
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, subreddit_id, name
|
SELECT id, subreddit_id, name, last_checked_at
|
||||||
FROM subreddits
|
FROM subreddits
|
||||||
WHERE id = $1`
|
WHERE id = $1`
|
||||||
|
|
||||||
|
@ -57,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do
|
||||||
|
|
||||||
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
|
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, subreddit_id, name
|
SELECT id, subreddit_id, name, last_checked_at
|
||||||
FROM subreddits
|
FROM subreddits
|
||||||
WHERE name = $1`
|
WHERE name = $1`
|
||||||
|
|
||||||
|
@ -78,7 +79,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO subreddits (subreddit_id, name)
|
INSERT INTO subreddits (subreddit_id, name)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
ON CONFLICT(subreddit_id) DO NOTHING
|
ON CONFLICT(subreddit_id) DO
|
||||||
|
UPDATE SET last_checked_at = $3
|
||||||
RETURNING id`
|
RETURNING id`
|
||||||
|
|
||||||
return p.pool.QueryRow(
|
return p.pool.QueryRow(
|
||||||
|
@ -86,5 +88,6 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
|
||||||
query,
|
query,
|
||||||
sr.SubredditID,
|
sr.SubredditID,
|
||||||
sr.NormalizedName(),
|
sr.NormalizedName(),
|
||||||
|
sr.LastCheckedAt,
|
||||||
).Scan(&sr.ID)
|
).Scan(&sr.ID)
|
||||||
}
|
}
|
||||||
|
|
104
internal/repository/postgres_user.go
Normal file
104
internal/repository/postgres_user.go
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgresUserRepository struct {
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresUser(pool *pgxpool.Pool) domain.UserRepository {
|
||||||
|
return &postgresUserRepository{pool: pool}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresUserRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.User, error) {
|
||||||
|
rows, err := p.pool.Query(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var uu []domain.User
|
||||||
|
for rows.Next() {
|
||||||
|
var u domain.User
|
||||||
|
if err := rows.Scan(
|
||||||
|
&u.ID,
|
||||||
|
&u.UserID,
|
||||||
|
&u.Name,
|
||||||
|
&u.LastCheckedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
uu = append(uu, u)
|
||||||
|
}
|
||||||
|
return uu, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, user_id, name, last_checked_at
|
||||||
|
FROM users
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
srs, err := p.fetch(ctx, query, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return domain.User{}, err
|
||||||
|
}
|
||||||
|
if len(srs) == 0 {
|
||||||
|
return domain.User{}, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
return srs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, user_id, name, last_checked_at
|
||||||
|
FROM users
|
||||||
|
WHERE name = $1`
|
||||||
|
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
|
||||||
|
srs, err := p.fetch(ctx, query, name)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return domain.User{}, err
|
||||||
|
}
|
||||||
|
if len(srs) == 0 {
|
||||||
|
return domain.User{}, domain.ErrNotFound
|
||||||
|
}
|
||||||
|
return srs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.User) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO users (user_id, name)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
ON CONFLICT(user_id) DO
|
||||||
|
UPDATE SET last_checked_at = $3
|
||||||
|
RETURNING id`
|
||||||
|
|
||||||
|
return p.pool.QueryRow(
|
||||||
|
ctx,
|
||||||
|
query,
|
||||||
|
u.UserID,
|
||||||
|
u.NormalizedName(),
|
||||||
|
u.LastCheckedAt,
|
||||||
|
).Scan(&u.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresUserRepository) Delete(ctx context.Context, id int64) error {
|
||||||
|
query := `DELETE FROM users WHERE id = $1`
|
||||||
|
res, err := p.pool.Exec(ctx, query, id)
|
||||||
|
|
||||||
|
if res.RowsAffected() != 1 {
|
||||||
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
|
@ -31,9 +31,11 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
|
||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&watcher.ID,
|
&watcher.ID,
|
||||||
&watcher.CreatedAt,
|
&watcher.CreatedAt,
|
||||||
|
&watcher.LastNotifiedAt,
|
||||||
&watcher.DeviceID,
|
&watcher.DeviceID,
|
||||||
&watcher.AccountID,
|
&watcher.AccountID,
|
||||||
&watcher.SubredditID,
|
&watcher.Type,
|
||||||
|
&watcher.WatcheeID,
|
||||||
&watcher.Upvotes,
|
&watcher.Upvotes,
|
||||||
&watcher.Keyword,
|
&watcher.Keyword,
|
||||||
&watcher.Flair,
|
&watcher.Flair,
|
||||||
|
@ -49,7 +51,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
|
||||||
|
|
||||||
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
|
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits
|
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
|
||||||
FROM watchers
|
FROM watchers
|
||||||
WHERE id = $1`
|
WHERE id = $1`
|
||||||
|
|
||||||
|
@ -64,13 +66,21 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
|
||||||
return watchers[0], nil
|
return watchers[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
|
func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits
|
SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
|
||||||
FROM watchers
|
FROM watchers
|
||||||
WHERE subreddit_id = $1`
|
WHERE type = $1 AND watchee_id = $2`
|
||||||
|
|
||||||
return p.fetch(ctx, query, id)
|
return p.fetch(ctx, query, typ, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
|
||||||
|
return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) GetByUserID(ctx context.Context, id int64) ([]domain.Watcher, error) {
|
||||||
|
return p.GetByTypeAndWatcheeID(ctx, domain.UserWatcher, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]domain.Watcher, error) {
|
func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]domain.Watcher, error) {
|
||||||
|
@ -78,9 +88,11 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
|
||||||
SELECT
|
SELECT
|
||||||
watchers.id,
|
watchers.id,
|
||||||
watchers.created_at,
|
watchers.created_at,
|
||||||
|
watchers.last_notified_at
|
||||||
watchers.device_id,
|
watchers.device_id,
|
||||||
watchers.account_id,
|
watchers.account_id,
|
||||||
watchers.subreddit_id,
|
watchers.type,
|
||||||
|
watchers.watchee_id,
|
||||||
watchers.upvotes,
|
watchers.upvotes,
|
||||||
watchers.keyword,
|
watchers.keyword,
|
||||||
watchers.flair,
|
watchers.flair,
|
||||||
|
@ -101,8 +113,8 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO watchers
|
INSERT INTO watchers
|
||||||
(created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain)
|
(created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
RETURNING id`
|
RETURNING id`
|
||||||
|
|
||||||
return p.pool.QueryRow(
|
return p.pool.QueryRow(
|
||||||
|
@ -111,7 +123,8 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
|
||||||
now,
|
now,
|
||||||
watcher.DeviceID,
|
watcher.DeviceID,
|
||||||
watcher.AccountID,
|
watcher.AccountID,
|
||||||
watcher.SubredditID,
|
watcher.Type,
|
||||||
|
watcher.WatcheeID,
|
||||||
watcher.Upvotes,
|
watcher.Upvotes,
|
||||||
watcher.Keyword,
|
watcher.Keyword,
|
||||||
watcher.Flair,
|
watcher.Flair,
|
||||||
|
@ -145,8 +158,9 @@ func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error {
|
func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error {
|
||||||
query := `UPDATE watchers SET hits = hits + 1 WHERE id = $1`
|
now := time.Now().Unix()
|
||||||
res, err := p.pool.Exec(ctx, query, id)
|
query := `UPDATE watchers SET hits = hits + 1, last_notified_at = $2 WHERE id = $1`
|
||||||
|
res, err := p.pool.Exec(ctx, query, id, now)
|
||||||
|
|
||||||
if res.RowsAffected() != 1 {
|
if res.RowsAffected() != 1 {
|
||||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
@ -163,3 +177,13 @@ func (p *postgresWatcherRepository) Delete(ctx context.Context, id int64) error
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postgresWatcherRepository) DeleteByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) error {
|
||||||
|
query := `DELETE FROM watchers WHERE type = $1 AND watchee_id = $2`
|
||||||
|
res, err := p.pool.Exec(ctx, query, typ, id)
|
||||||
|
|
||||||
|
if res.RowsAffected() == 0 {
|
||||||
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -170,7 +170,6 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
if len(watchers) == 0 {
|
if len(watchers) == 0 {
|
||||||
sc.logger.WithFields(logrus.Fields{
|
sc.logger.WithFields(logrus.Fields{
|
||||||
"subreddit#id": subreddit.ID,
|
"subreddit#id": subreddit.ID,
|
||||||
"err": err,
|
|
||||||
}).Info("no watchers for subreddit, skipping")
|
}).Info("no watchers for subreddit, skipping")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -299,6 +298,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
"count": len(posts),
|
"count": len(posts),
|
||||||
}).Debug("checking posts for hits")
|
}).Debug("checking posts for hits")
|
||||||
for _, post := range posts {
|
for _, post := range posts {
|
||||||
|
lowcaseTitle := strings.ToLower(post.Title)
|
||||||
|
lowcaseFlair := strings.ToLower(post.Flair)
|
||||||
|
lowcaseDomain := strings.ToLower(post.URL)
|
||||||
|
|
||||||
ids := []int64{}
|
ids := []int64{}
|
||||||
|
|
||||||
for _, watcher := range watchers {
|
for _, watcher := range watchers {
|
||||||
|
@ -313,15 +316,15 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
matched = false
|
matched = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if watcher.Keyword != "" && !strings.Contains(post.Title, watcher.Keyword) {
|
if watcher.Keyword != "" && !strings.Contains(lowcaseTitle, watcher.Keyword) {
|
||||||
matched = false
|
matched = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if watcher.Flair != "" && !strings.Contains(post.Flair, watcher.Flair) {
|
if watcher.Flair != "" && !strings.Contains(lowcaseFlair, watcher.Flair) {
|
||||||
matched = false
|
matched = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if watcher.Domain != "" && !strings.Contains(post.URL, watcher.Domain) {
|
if watcher.Domain != "" && !strings.Contains(lowcaseDomain, watcher.Domain) {
|
||||||
matched = false
|
matched = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,8 +332,6 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = sc.watcherRepo.IncrementHits(ctx, watcher.ID)
|
|
||||||
|
|
||||||
lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID)
|
lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID)
|
||||||
notified, _ := sc.redis.Get(ctx, lockKey).Bool()
|
notified, _ := sc.redis.Get(ctx, lockKey).Bool()
|
||||||
|
|
||||||
|
@ -345,6 +346,15 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := sc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
|
||||||
|
sc.logger.WithFields(logrus.Fields{
|
||||||
|
"subreddit#id": subreddit.ID,
|
||||||
|
"watcher#id": watcher.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("could not increment hits")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
sc.logger.WithFields(logrus.Fields{
|
sc.logger.WithFields(logrus.Fields{
|
||||||
"subreddit#id": subreddit.ID,
|
"subreddit#id": subreddit.ID,
|
||||||
"subreddit#name": subreddit.Name,
|
"subreddit#name": subreddit.Name,
|
||||||
|
|
299
internal/worker/users.go
Normal file
299
internal/worker/users.go
Normal file
|
@ -0,0 +1,299 @@
|
||||||
|
package worker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/adjust/rmq/v4"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
"github.com/sideshow/apns2"
|
||||||
|
"github.com/sideshow/apns2/payload"
|
||||||
|
"github.com/sideshow/apns2/token"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||||
|
"github.com/christianselig/apollo-backend/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usersWorker struct {
|
||||||
|
logger *logrus.Logger
|
||||||
|
statsd *statsd.Client
|
||||||
|
db *pgxpool.Pool
|
||||||
|
redis *redis.Client
|
||||||
|
queue rmq.Connection
|
||||||
|
reddit *reddit.Client
|
||||||
|
apns *token.Token
|
||||||
|
|
||||||
|
consumers int
|
||||||
|
|
||||||
|
accountRepo domain.AccountRepository
|
||||||
|
deviceRepo domain.DeviceRepository
|
||||||
|
userRepo domain.UserRepository
|
||||||
|
watcherRepo domain.WatcherRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
|
||||||
|
reddit := reddit.NewClient(
|
||||||
|
os.Getenv("REDDIT_CLIENT_ID"),
|
||||||
|
os.Getenv("REDDIT_CLIENT_SECRET"),
|
||||||
|
statsd,
|
||||||
|
consumers,
|
||||||
|
)
|
||||||
|
|
||||||
|
var apns *token.Token
|
||||||
|
{
|
||||||
|
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
apns = &token.Token{
|
||||||
|
AuthKey: authKey,
|
||||||
|
KeyID: os.Getenv("APPLE_KEY_ID"),
|
||||||
|
TeamID: os.Getenv("APPLE_TEAM_ID"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &usersWorker{
|
||||||
|
logger,
|
||||||
|
statsd,
|
||||||
|
db,
|
||||||
|
redis,
|
||||||
|
queue,
|
||||||
|
reddit,
|
||||||
|
apns,
|
||||||
|
consumers,
|
||||||
|
|
||||||
|
repository.NewPostgresAccount(db),
|
||||||
|
repository.NewPostgresDevice(db),
|
||||||
|
repository.NewPostgresUser(db),
|
||||||
|
repository.NewPostgresWatcher(db),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uw *usersWorker) Start() error {
|
||||||
|
queue, err := uw.queue.OpenQueue("users")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
uw.logger.WithFields(logrus.Fields{
|
||||||
|
"numConsumers": uw.consumers,
|
||||||
|
}).Info("starting up users worker")
|
||||||
|
|
||||||
|
prefetchLimit := int64(uw.consumers * 2)
|
||||||
|
|
||||||
|
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _ := os.Hostname()
|
||||||
|
|
||||||
|
for i := 0; i < uw.consumers; i++ {
|
||||||
|
name := fmt.Sprintf("consumer %s-%d", host, i)
|
||||||
|
|
||||||
|
consumer := NewUsersConsumer(uw, i)
|
||||||
|
if _, err := queue.AddConsumer(name, consumer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uw *usersWorker) Stop() {
|
||||||
|
<-uw.queue.StopAllConsuming() // wait for all Consume() calls to finish
|
||||||
|
}
|
||||||
|
|
||||||
|
type usersConsumer struct {
|
||||||
|
*usersWorker
|
||||||
|
tag int
|
||||||
|
|
||||||
|
apnsSandbox *apns2.Client
|
||||||
|
apnsProduction *apns2.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUsersConsumer(uw *usersWorker, tag int) *usersConsumer {
|
||||||
|
return &usersConsumer{
|
||||||
|
uw,
|
||||||
|
tag,
|
||||||
|
apns2.NewTokenClient(uw.apns),
|
||||||
|
apns2.NewTokenClient(uw.apns).Production(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": delivery.Payload(),
|
||||||
|
}).Debug("starting job")
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": delivery.Payload(),
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to parse user ID")
|
||||||
|
|
||||||
|
_ = delivery.Reject()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = delivery.Ack() }()
|
||||||
|
|
||||||
|
user, err := uc.userRepo.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch user from database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
watchers, err := uc.watcherRepo.GetByUserID(ctx, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch watchers from database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(watchers) == 0 {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
}).Info("no watchers for user, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load 25 newest posts
|
||||||
|
i := rand.Intn(len(watchers))
|
||||||
|
watcher := watchers[i]
|
||||||
|
|
||||||
|
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
|
||||||
|
rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
|
||||||
|
|
||||||
|
ru, err := rac.UserAbout(user.Name)
|
||||||
|
if err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch user details")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ru.AcceptFollowers {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
}).Info("user disabled followers, removing")
|
||||||
|
|
||||||
|
if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(ctx, domain.UserWatcher, user.ID); err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to delete watchers for user who does not allow followers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := uc.userRepo.Delete(ctx, user.ID); err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to delete user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
posts, err := rac.UserPosts(user.Name)
|
||||||
|
if err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("failed to fetch user activity")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, post := range posts.Children {
|
||||||
|
notification := &apns2.Notification{}
|
||||||
|
notification.Topic = "com.christianselig.Apollo"
|
||||||
|
notification.Payload = payloadFromUserPost(post)
|
||||||
|
|
||||||
|
for _, watcher := range watchers {
|
||||||
|
// Make sure we only alert on activities created after the search
|
||||||
|
if watcher.CreatedAt > post.CreatedAt {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if watcher.LastNotifiedAt > post.CreatedAt {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := uc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"watcher#id": watcher.ID,
|
||||||
|
"err": err,
|
||||||
|
}).Error("could not increment hits")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
|
||||||
|
notification.DeviceToken = device.APNSToken
|
||||||
|
|
||||||
|
client := uc.apnsProduction
|
||||||
|
if device.Sandbox {
|
||||||
|
client = uc.apnsSandbox
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Push(notification)
|
||||||
|
if err != nil {
|
||||||
|
_ = uc.statsd.Incr("apns.notification.errors", []string{}, 1)
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"device#id": device.ID,
|
||||||
|
"err": err,
|
||||||
|
"status": res.StatusCode,
|
||||||
|
"reason": res.Reason,
|
||||||
|
}).Error("failed to send notification")
|
||||||
|
} else {
|
||||||
|
_ = uc.statsd.Incr("apns.notification.sent", []string{}, 1)
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"device#id": device.ID,
|
||||||
|
"device#token": device.APNSToken,
|
||||||
|
}).Info("sent notification")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uc.logger.WithFields(logrus.Fields{
|
||||||
|
"user#id": user.ID,
|
||||||
|
"user#name": user.Name,
|
||||||
|
}).Debug("finishing job")
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadFromUserPost(post *reddit.Thing) *payload.Payload {
|
||||||
|
title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author)
|
||||||
|
|
||||||
|
payload := payload.
|
||||||
|
NewPayload().
|
||||||
|
AlertTitle(title).
|
||||||
|
AlertBody(post.Title).
|
||||||
|
AlertSummaryArg(post.Author).
|
||||||
|
Category("user-watch").
|
||||||
|
Custom("post_title", post.Title).
|
||||||
|
Custom("post_id", post.ID).
|
||||||
|
Custom("subreddit", post.Subreddit).
|
||||||
|
Custom("author", post.Author).
|
||||||
|
Custom("post_age", post.CreatedAt).
|
||||||
|
MutableContent().
|
||||||
|
Sound("traloop.wav")
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
Loading…
Reference in a new issue