Merge pull request #28 from christianselig/feature/follow-users

This commit is contained in:
André Medeiros 2021-10-09 11:51:50 -04:00 committed by GitHub
commit 38e596b27e
18 changed files with 787 additions and 55 deletions

5
go.sum
View file

@ -64,7 +64,6 @@ github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@ -122,8 +121,6 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8=
github.com/go-redis/redis/v8 v8.3.2/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU= github.com/go-redis/redis/v8 v8.3.2/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU=
github.com/go-redis/redis/v8 v8.11.3 h1:GCjoYp8c+yQTJfc0n69iwSiHjvuAdruxl7elnZCxgt8=
github.com/go-redis/redis/v8 v8.11.3/go.mod h1:xNJ9xDG09FsIPwh3bWdk+0oDWHbtF9rPN0F/oD9XeKc=
github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg= github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg=
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
@ -349,8 +346,6 @@ github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vv
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
github.com/onsi/gomega v1.15.0 h1:WjP/FQ/sk43MRmnEcT+MlDw2TFvkrXlprrPST/IudjU=
github.com/onsi/gomega v1.15.0/go.mod h1:cIuvLEne0aoVhAgh/O6ac0Op8WWw9H6eYCriF+tEHG0=
github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c=
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=

View file

@ -28,6 +28,7 @@ type api struct {
deviceRepo domain.DeviceRepository deviceRepo domain.DeviceRepository
subredditRepo domain.SubredditRepository subredditRepo domain.SubredditRepository
watcherRepo domain.WatcherRepository watcherRepo domain.WatcherRepository
userRepo domain.UserRepository
} }
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api { func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
@ -56,6 +57,7 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p
deviceRepo := repository.NewPostgresDevice(pool) deviceRepo := repository.NewPostgresDevice(pool)
subredditRepo := repository.NewPostgresSubreddit(pool) subredditRepo := repository.NewPostgresSubreddit(pool)
watcherRepo := repository.NewPostgresWatcher(pool) watcherRepo := repository.NewPostgresWatcher(pool)
userRepo := repository.NewPostgresUser(pool)
return &api{ return &api{
logger: logger, logger: logger,
@ -67,6 +69,7 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p
deviceRepo: deviceRepo, deviceRepo: deviceRepo,
subredditRepo: subredditRepo, subredditRepo: subredditRepo,
watcherRepo: watcherRepo, watcherRepo: watcherRepo,
userRepo: userRepo,
} }
} }

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"github.com/christianselig/apollo-backend/internal/domain" "github.com/christianselig/apollo-backend/internal/domain"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -18,6 +19,8 @@ type watcherCriteria struct {
} }
type createWatcherRequest struct { type createWatcherRequest struct {
Type string
User string
Subreddit string Subreddit string
Criteria watcherCriteria Criteria watcherCriteria
} }
@ -69,6 +72,17 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
} }
ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
watcher := domain.Watcher{
DeviceID: dev.ID,
AccountID: account.ID,
Upvotes: cwr.Criteria.Upvotes,
Keyword: strings.ToLower(cwr.Criteria.Keyword),
Flair: strings.ToLower(cwr.Criteria.Flair),
Domain: strings.ToLower(cwr.Criteria.Domain),
}
if cwr.Type == "subreddit" {
srr, err := ac.SubredditAbout(cwr.Subreddit) srr, err := ac.SubredditAbout(cwr.Subreddit)
if err != nil { if err != nil {
a.errorResponse(w, r, 422, err.Error()) a.errorResponse(w, r, 422, err.Error())
@ -88,14 +102,33 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
watcher := domain.Watcher{ watcher.Type = domain.SubredditWatcher
SubredditID: sr.ID, watcher.WatcheeID = sr.ID
DeviceID: dev.ID, } else if cwr.Type == "user" {
AccountID: account.ID, urr, err := ac.UserAbout(cwr.User)
Upvotes: cwr.Criteria.Upvotes, if err != nil {
Keyword: cwr.Criteria.Keyword, a.errorResponse(w, r, 500, err.Error())
Flair: cwr.Criteria.Flair, return
Domain: cwr.Criteria.Domain, }
if !urr.AcceptFollowers {
a.errorResponse(w, r, 422, "no followers accepted")
return
}
u := domain.User{UserID: urr.ID, Name: urr.Name}
err = a.userRepo.CreateOrUpdate(ctx, &u)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
watcher.Type = domain.UserWatcher
watcher.WatcheeID = u.ID
} else {
a.errorResponse(w, r, 422, "unknown watcher type")
return
} }
if err := a.watcherRepo.Create(ctx, &watcher); err != nil { if err := a.watcherRepo.Create(ctx, &watcher); err != nil {

View file

@ -24,7 +24,8 @@ const (
checkTimeout = 60 // how long until we force a check checkTimeout = 60 // how long until we force a check
accountEnqueueTimeout = 5 // how frequently we want to check (seconds) accountEnqueueTimeout = 5 // how frequently we want to check (seconds)
subredditEnqueueTimeout = 5 * 60 // how frequently we want to check (seconds) subredditEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds)
userEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds)
staleAccountThreshold = 7200 // 2 hours staleAccountThreshold = 7200 // 2 hours
) )
@ -76,9 +77,15 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
return err return err
} }
userQueue, err := queue.OpenQueue("users")
if err != nil {
return err
}
s := gocron.NewScheduler(time.UTC) s := gocron.NewScheduler(time.UTC)
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) }) _, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) }) _, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
_, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) }) _, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) })
@ -202,6 +209,70 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie
} }
} }
func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
now := time.Now()
ready := now.Unix() - userEnqueueTimeout
ids := []int64{}
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
WITH userb AS (
SELECT id
FROM users
WHERE last_checked_at < $1
ORDER BY last_checked_at
LIMIT 100
)
UPDATE users
SET last_checked_at = $2
WHERE users.id IN(SELECT id FROM userb)
RETURNING users.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var id int64
_ = rows.Scan(&id)
ids = append(ids, id)
}
return nil
})
if err != nil {
logger.WithFields(logrus.Fields{
"err": err,
}).Error("failed to fetch batch of users")
return
}
if len(ids) == 0 {
return
}
logger.WithFields(logrus.Fields{
"count": len(ids),
"start": ready,
}).Debug("enqueueing user batch")
batchIds := make([]string, len(ids))
for i, id := range ids {
batchIds[i] = strconv.FormatInt(id, 10)
}
if err = queue.Publish(batchIds...); err != nil {
logger.WithFields(logrus.Fields{
"err": err,
}).Error("failed to enqueue user")
}
_ = statsd.Histogram("apollo.queue.users.enqueued", float64(len(ids)), []string{}, 1)
_ = statsd.Histogram("apollo.queue.users.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1)
}
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
now := time.Now() now := time.Now()
ready := now.Unix() - subredditEnqueueTimeout ready := now.Unix() - subredditEnqueueTimeout

View file

@ -15,6 +15,7 @@ var (
queues = map[string]worker.NewWorkerFn{ queues = map[string]worker.NewWorkerFn{
"notifications": worker.NewNotificationsWorker, "notifications": worker.NewNotificationsWorker,
"subreddits": worker.NewSubredditsWorker, "subreddits": worker.NewSubredditsWorker,
"users": worker.NewUsersWorker,
} }
) )

View file

@ -51,6 +51,10 @@ func NewRedisClient(ctx context.Context) (*redis.Client, error) {
} }
func NewDatabasePool(ctx context.Context, maxConns int) (*pgxpool.Pool, error) { func NewDatabasePool(ctx context.Context, maxConns int) (*pgxpool.Pool, error) {
if maxConns == 0 {
maxConns = 1
}
url := fmt.Sprintf( url := fmt.Sprintf(
"%s?pool_max_conns=%d&pool_min_conns=%d", "%s?pool_max_conns=%d&pool_min_conns=%d",
os.Getenv("DATABASE_CONNECTION_POOL_URL"), os.Getenv("DATABASE_CONNECTION_POOL_URL"),

View file

@ -7,6 +7,7 @@ import (
type Subreddit struct { type Subreddit struct {
ID int64 ID int64
LastCheckedAt float64
// Reddit information // Reddit information
SubredditID string SubredditID string

27
internal/domain/user.go Normal file
View file

@ -0,0 +1,27 @@
package domain
import (
"context"
"strings"
)
type User struct {
ID int64
LastCheckedAt float64
// Reddit information
UserID string
Name string
}
func (u *User) NormalizedName() string {
return strings.ToLower(u.Name)
}
type UserRepository interface {
GetByID(context.Context, int64) (User, error)
GetByName(context.Context, string) (User, error)
CreateOrUpdate(context.Context, *User) error
Delete(context.Context, int64) error
}

View file

@ -2,13 +2,22 @@ package domain
import "context" import "context"
type WatcherType int64
const (
SubredditWatcher WatcherType = iota
UserWatcher
)
type Watcher struct { type Watcher struct {
ID int64 ID int64
CreatedAt float64 CreatedAt float64
LastNotifiedAt float64
DeviceID int64 DeviceID int64
AccountID int64 AccountID int64
SubredditID int64 Type WatcherType
WatcheeID int64
Upvotes int64 Upvotes int64
Keyword string Keyword string
@ -20,10 +29,12 @@ type Watcher struct {
type WatcherRepository interface { type WatcherRepository interface {
GetByID(ctx context.Context, id int64) (Watcher, error) GetByID(ctx context.Context, id int64) (Watcher, error)
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error) GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)
GetByUserID(ctx context.Context, id int64) ([]Watcher, error)
GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error)
Create(ctx context.Context, watcher *Watcher) error Create(ctx context.Context, watcher *Watcher) error
Update(ctx context.Context, watcher *Watcher) error Update(ctx context.Context, watcher *Watcher) error
IncrementHits(ctx context.Context, id int64) error IncrementHits(ctx context.Context, id int64) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
DeleteByTypeAndWatcheeID(context.Context, WatcherType, int64) error
} }

View file

@ -171,6 +171,41 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
return ret, nil return ret, nil
} }
func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*ListingResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/u/%s/submitted.json", user)
opts = append([]RequestOption{
WithMethod("GET"),
WithToken(rac.accessToken),
WithURL(url),
}, opts...)
req := NewRequest(opts...)
lr, err := rac.request(req, NewListingResponse, nil)
if err != nil {
return nil, err
}
return lr.(*ListingResponse), nil
}
func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*UserResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/u/%s/about.json", user)
opts = append([]RequestOption{
WithMethod("GET"),
WithToken(rac.accessToken),
WithURL(url),
}, opts...)
req := NewRequest(opts...)
ur, err := rac.request(req, NewUserResponse, nil)
if err != nil {
return nil, err
}
return ur.(*UserResponse), nil
}
func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) { func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about.json", subreddit) url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about.json", subreddit)
opts = append([]RequestOption{ opts = append([]RequestOption{

View file

@ -0,0 +1,75 @@
{
"kind": "t2",
"data": {
"is_employee": false,
"is_friend": false,
"subreddit": {
"default_set": true,
"user_is_contributor": false,
"banner_img": "",
"restrict_posting": true,
"user_is_banned": false,
"free_form_reports": true,
"community_icon": null,
"show_media": true,
"icon_color": "#FFD635",
"user_is_muted": false,
"display_name": "u_changelog",
"header_img": null,
"title": "",
"previous_names": [],
"over_18": false,
"icon_size": [
256,
256
],
"primary_color": "",
"icon_img": "https://www.redditstatic.com/avatars/defaults/v2/avatar_default_2.png",
"description": "",
"submit_link_label": "",
"header_size": null,
"restrict_commenting": false,
"subscribers": 0,
"submit_text_label": "",
"is_default_icon": true,
"link_flair_position": "",
"display_name_prefixed": "u/changelog",
"key_color": "",
"name": "t5_c30sw",
"is_default_banner": true,
"url": "/user/changelog/",
"quarantine": false,
"banner_size": null,
"user_is_moderator": false,
"accept_followers": true,
"public_description": "",
"link_flair_enabled": false,
"disable_contributor_requests": false,
"subreddit_type": "user",
"user_is_subscriber": false
},
"snoovatar_size": null,
"awardee_karma": 23,
"id": "1ia22",
"verified": true,
"is_gold": false,
"is_mod": true,
"awarder_karma": 0,
"has_verified_email": true,
"icon_img": "https://www.redditstatic.com/avatars/defaults/v2/avatar_default_2.png",
"hide_from_robots": true,
"link_karma": 2676,
"pref_show_snoovatar": false,
"is_blocked": false,
"total_karma": 4132,
"accept_chats": false,
"name": "changelog",
"created": 1176721866.0,
"created_utc": 1176721866.0,
"snoovatar_img": "",
"comment_karma": 1433,
"accept_followers": true,
"has_subscribed": true,
"accept_pms": true
}
}

View file

@ -160,4 +160,23 @@ func NewSubredditResponse(val *fastjson.Value) interface{} {
return sr return sr
} }
type UserResponse struct {
Thing
AcceptFollowers bool
Name string
}
func NewUserResponse(val *fastjson.Value) interface{} {
ur := &UserResponse{}
ur.Kind = string(val.GetStringBytes("kind"))
data := val.Get("data")
ur.ID = string(data.GetStringBytes("id"))
ur.Name = string(data.GetStringBytes("name"))
ur.AcceptFollowers = data.GetBool("accept_followers")
return ur
}
var EmptyListingResponse = &ListingResponse{} var EmptyListingResponse = &ListingResponse{}

View file

@ -111,3 +111,20 @@ func TestSubredditResponseParsing(t *testing.T) {
assert.Equal(t, "2vq0w", s.ID) assert.Equal(t, "2vq0w", s.ID)
assert.Equal(t, "DestinyTheGame", s.Name) assert.Equal(t, "DestinyTheGame", s.Name)
} }
func TestUserResponseParsing(t *testing.T) {
bb, err := ioutil.ReadFile("testdata/user_about.json")
assert.NoError(t, err)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewUserResponse(val)
u := ret.(*UserResponse)
assert.NotNil(t, u)
assert.Equal(t, "t2", u.Kind)
assert.Equal(t, "1ia22", u.ID)
assert.Equal(t, "changelog", u.Name)
assert.Equal(t, true, u.AcceptFollowers)
}

View file

@ -30,6 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
&sr.ID, &sr.ID,
&sr.SubredditID, &sr.SubredditID,
&sr.Name, &sr.Name,
&sr.LastCheckedAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -40,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) { func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
query := ` query := `
SELECT id, subreddit_id, name SELECT id, subreddit_id, name, last_checked_at
FROM subreddits FROM subreddits
WHERE id = $1` WHERE id = $1`
@ -57,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) { func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
query := ` query := `
SELECT id, subreddit_id, name SELECT id, subreddit_id, name, last_checked_at
FROM subreddits FROM subreddits
WHERE name = $1` WHERE name = $1`
@ -78,7 +79,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
query := ` query := `
INSERT INTO subreddits (subreddit_id, name) INSERT INTO subreddits (subreddit_id, name)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT(subreddit_id) DO NOTHING ON CONFLICT(subreddit_id) DO
UPDATE SET last_checked_at = $3
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
@ -86,5 +88,6 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
query, query,
sr.SubredditID, sr.SubredditID,
sr.NormalizedName(), sr.NormalizedName(),
sr.LastCheckedAt,
).Scan(&sr.ID) ).Scan(&sr.ID)
} }

View file

@ -0,0 +1,104 @@
package repository
import (
"context"
"fmt"
"strings"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/jackc/pgx/v4/pgxpool"
)
type postgresUserRepository struct {
pool *pgxpool.Pool
}
func NewPostgresUser(pool *pgxpool.Pool) domain.UserRepository {
return &postgresUserRepository{pool: pool}
}
func (p *postgresUserRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.User, error) {
rows, err := p.pool.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var uu []domain.User
for rows.Next() {
var u domain.User
if err := rows.Scan(
&u.ID,
&u.UserID,
&u.Name,
&u.LastCheckedAt,
); err != nil {
return nil, err
}
uu = append(uu, u)
}
return uu, nil
}
func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) {
query := `
SELECT id, user_id, name, last_checked_at
FROM users
WHERE id = $1`
srs, err := p.fetch(ctx, query, id)
if err != nil {
return domain.User{}, err
}
if len(srs) == 0 {
return domain.User{}, domain.ErrNotFound
}
return srs[0], nil
}
func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) {
query := `
SELECT id, user_id, name, last_checked_at
FROM users
WHERE name = $1`
name = strings.ToLower(name)
srs, err := p.fetch(ctx, query, name)
if err != nil {
return domain.User{}, err
}
if len(srs) == 0 {
return domain.User{}, domain.ErrNotFound
}
return srs[0], nil
}
func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.User) error {
query := `
INSERT INTO users (user_id, name)
VALUES ($1, $2)
ON CONFLICT(user_id) DO
UPDATE SET last_checked_at = $3
RETURNING id`
return p.pool.QueryRow(
ctx,
query,
u.UserID,
u.NormalizedName(),
u.LastCheckedAt,
).Scan(&u.ID)
}
func (p *postgresUserRepository) Delete(ctx context.Context, id int64) error {
query := `DELETE FROM users WHERE id = $1`
res, err := p.pool.Exec(ctx, query, id)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
}
return err
}

View file

@ -31,9 +31,11 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
if err := rows.Scan( if err := rows.Scan(
&watcher.ID, &watcher.ID,
&watcher.CreatedAt, &watcher.CreatedAt,
&watcher.LastNotifiedAt,
&watcher.DeviceID, &watcher.DeviceID,
&watcher.AccountID, &watcher.AccountID,
&watcher.SubredditID, &watcher.Type,
&watcher.WatcheeID,
&watcher.Upvotes, &watcher.Upvotes,
&watcher.Keyword, &watcher.Keyword,
&watcher.Flair, &watcher.Flair,
@ -49,7 +51,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg
func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) { func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) {
query := ` query := `
SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
FROM watchers FROM watchers
WHERE id = $1` WHERE id = $1`
@ -64,13 +66,21 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
return watchers[0], nil return watchers[0], nil
} }
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) { func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) {
query := ` query := `
SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits
FROM watchers FROM watchers
WHERE subreddit_id = $1` WHERE type = $1 AND watchee_id = $2`
return p.fetch(ctx, query, id) return p.fetch(ctx, query, typ, id)
}
func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) {
return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id)
}
func (p *postgresWatcherRepository) GetByUserID(ctx context.Context, id int64) ([]domain.Watcher, error) {
return p.GetByTypeAndWatcheeID(ctx, domain.UserWatcher, id)
} }
func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]domain.Watcher, error) { func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]domain.Watcher, error) {
@ -78,9 +88,11 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
SELECT SELECT
watchers.id, watchers.id,
watchers.created_at, watchers.created_at,
watchers.last_notified_at
watchers.device_id, watchers.device_id,
watchers.account_id, watchers.account_id,
watchers.subreddit_id, watchers.type,
watchers.watchee_id,
watchers.upvotes, watchers.upvotes,
watchers.keyword, watchers.keyword,
watchers.flair, watchers.flair,
@ -101,8 +113,8 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
query := ` query := `
INSERT INTO watchers INSERT INTO watchers
(created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain) (created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
@ -111,7 +123,8 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
now, now,
watcher.DeviceID, watcher.DeviceID,
watcher.AccountID, watcher.AccountID,
watcher.SubredditID, watcher.Type,
watcher.WatcheeID,
watcher.Upvotes, watcher.Upvotes,
watcher.Keyword, watcher.Keyword,
watcher.Flair, watcher.Flair,
@ -145,8 +158,9 @@ func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.
} }
func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error { func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error {
query := `UPDATE watchers SET hits = hits + 1 WHERE id = $1` now := time.Now().Unix()
res, err := p.pool.Exec(ctx, query, id) query := `UPDATE watchers SET hits = hits + 1, last_notified_at = $2 WHERE id = $1`
res, err := p.pool.Exec(ctx, query, id, now)
if res.RowsAffected() != 1 { if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -163,3 +177,13 @@ func (p *postgresWatcherRepository) Delete(ctx context.Context, id int64) error
} }
return err return err
} }
func (p *postgresWatcherRepository) DeleteByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) error {
query := `DELETE FROM watchers WHERE type = $1 AND watchee_id = $2`
res, err := p.pool.Exec(ctx, query, typ, id)
if res.RowsAffected() == 0 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
}
return err
}

View file

@ -170,7 +170,6 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
if len(watchers) == 0 { if len(watchers) == 0 {
sc.logger.WithFields(logrus.Fields{ sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID, "subreddit#id": subreddit.ID,
"err": err,
}).Info("no watchers for subreddit, skipping") }).Info("no watchers for subreddit, skipping")
return return
} }
@ -299,6 +298,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
"count": len(posts), "count": len(posts),
}).Debug("checking posts for hits") }).Debug("checking posts for hits")
for _, post := range posts { for _, post := range posts {
lowcaseTitle := strings.ToLower(post.Title)
lowcaseFlair := strings.ToLower(post.Flair)
lowcaseDomain := strings.ToLower(post.URL)
ids := []int64{} ids := []int64{}
for _, watcher := range watchers { for _, watcher := range watchers {
@ -313,15 +316,15 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
matched = false matched = false
} }
if watcher.Keyword != "" && !strings.Contains(post.Title, watcher.Keyword) { if watcher.Keyword != "" && !strings.Contains(lowcaseTitle, watcher.Keyword) {
matched = false matched = false
} }
if watcher.Flair != "" && !strings.Contains(post.Flair, watcher.Flair) { if watcher.Flair != "" && !strings.Contains(lowcaseFlair, watcher.Flair) {
matched = false matched = false
} }
if watcher.Domain != "" && !strings.Contains(post.URL, watcher.Domain) { if watcher.Domain != "" && !strings.Contains(lowcaseDomain, watcher.Domain) {
matched = false matched = false
} }
@ -329,8 +332,6 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
continue continue
} }
_ = sc.watcherRepo.IncrementHits(ctx, watcher.ID)
lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID) lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID)
notified, _ := sc.redis.Get(ctx, lockKey).Bool() notified, _ := sc.redis.Get(ctx, lockKey).Bool()
@ -345,6 +346,15 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
continue continue
} }
if err := sc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"watcher#id": watcher.ID,
"err": err,
}).Error("could not increment hits")
return
}
sc.logger.WithFields(logrus.Fields{ sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID, "subreddit#id": subreddit.ID,
"subreddit#name": subreddit.Name, "subreddit#name": subreddit.Name,

299
internal/worker/users.go Normal file
View file

@ -0,0 +1,299 @@
package worker
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/payload"
"github.com/sideshow/apns2/token"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
)
type usersWorker struct {
logger *logrus.Logger
statsd *statsd.Client
db *pgxpool.Pool
redis *redis.Client
queue rmq.Connection
reddit *reddit.Client
apns *token.Token
consumers int
accountRepo domain.AccountRepository
deviceRepo domain.DeviceRepository
userRepo domain.UserRepository
watcherRepo domain.WatcherRepository
}
func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
statsd,
consumers,
)
var apns *token.Token
{
authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH"))
if err != nil {
panic(err)
}
apns = &token.Token{
AuthKey: authKey,
KeyID: os.Getenv("APPLE_KEY_ID"),
TeamID: os.Getenv("APPLE_TEAM_ID"),
}
}
return &usersWorker{
logger,
statsd,
db,
redis,
queue,
reddit,
apns,
consumers,
repository.NewPostgresAccount(db),
repository.NewPostgresDevice(db),
repository.NewPostgresUser(db),
repository.NewPostgresWatcher(db),
}
}
func (uw *usersWorker) Start() error {
queue, err := uw.queue.OpenQueue("users")
if err != nil {
return err
}
uw.logger.WithFields(logrus.Fields{
"numConsumers": uw.consumers,
}).Info("starting up users worker")
prefetchLimit := int64(uw.consumers * 2)
if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
return err
}
host, _ := os.Hostname()
for i := 0; i < uw.consumers; i++ {
name := fmt.Sprintf("consumer %s-%d", host, i)
consumer := NewUsersConsumer(uw, i)
if _, err := queue.AddConsumer(name, consumer); err != nil {
return err
}
}
return nil
}
func (uw *usersWorker) Stop() {
<-uw.queue.StopAllConsuming() // wait for all Consume() calls to finish
}
type usersConsumer struct {
*usersWorker
tag int
apnsSandbox *apns2.Client
apnsProduction *apns2.Client
}
func NewUsersConsumer(uw *usersWorker, tag int) *usersConsumer {
return &usersConsumer{
uw,
tag,
apns2.NewTokenClient(uw.apns),
apns2.NewTokenClient(uw.apns).Production(),
}
}
func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
uc.logger.WithFields(logrus.Fields{
"user#id": delivery.Payload(),
}).Debug("starting job")
id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": delivery.Payload(),
"err": err,
}).Error("failed to parse user ID")
_ = delivery.Reject()
return
}
defer func() { _ = delivery.Ack() }()
user, err := uc.userRepo.GetByID(ctx, id)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"err": err,
}).Error("failed to fetch user from database")
return
}
watchers, err := uc.watcherRepo.GetByUserID(ctx, user.ID)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
}).Error("failed to fetch watchers from database")
return
}
if len(watchers) == 0 {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
}).Info("no watchers for user, skipping")
return
}
// Load 25 newest posts
i := rand.Intn(len(watchers))
watcher := watchers[i]
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
ru, err := rac.UserAbout(user.Name)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
}).Error("failed to fetch user details")
return
}
if !ru.AcceptFollowers {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
}).Info("user disabled followers, removing")
if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(ctx, domain.UserWatcher, user.ID); err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
}).Error("failed to delete watchers for user who does not allow followers")
return
}
if err := uc.userRepo.Delete(ctx, user.ID); err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
}).Error("failed to delete user")
return
}
}
posts, err := rac.UserPosts(user.Name)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
}).Error("failed to fetch user activity")
return
}
for _, post := range posts.Children {
notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromUserPost(post)
for _, watcher := range watchers {
// Make sure we only alert on activities created after the search
if watcher.CreatedAt > post.CreatedAt {
continue
}
if watcher.LastNotifiedAt > post.CreatedAt {
continue
}
if err := uc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"watcher#id": watcher.ID,
"err": err,
}).Error("could not increment hits")
return
}
device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
notification.DeviceToken = device.APNSToken
client := uc.apnsProduction
if device.Sandbox {
client = uc.apnsSandbox
}
res, err := client.Push(notification)
if err != nil {
_ = uc.statsd.Incr("apns.notification.errors", []string{}, 1)
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"device#id": device.ID,
"err": err,
"status": res.StatusCode,
"reason": res.Reason,
}).Error("failed to send notification")
} else {
_ = uc.statsd.Incr("apns.notification.sent", []string{}, 1)
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"device#id": device.ID,
"device#token": device.APNSToken,
}).Info("sent notification")
}
}
}
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"user#name": user.Name,
}).Debug("finishing job")
}
func payloadFromUserPost(post *reddit.Thing) *payload.Payload {
title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author)
payload := payload.
NewPayload().
AlertTitle(title).
AlertBody(post.Title).
AlertSummaryArg(post.Author).
Category("user-watch").
Custom("post_title", post.Title).
Custom("post_id", post.ID).
Custom("subreddit", post.Subreddit).
Custom("author", post.Author).
Custom("post_age", post.CreatedAt).
MutableContent().
Sound("traloop.wav")
return payload
}