diff --git a/go.sum b/go.sum index 81569a4..9556f7d 100644 --- a/go.sum +++ b/go.sum @@ -64,7 +64,6 @@ github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= -github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -122,8 +121,6 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-redis/redis/v8 v8.3.2/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU= -github.com/go-redis/redis/v8 v8.11.3 h1:GCjoYp8c+yQTJfc0n69iwSiHjvuAdruxl7elnZCxgt8= -github.com/go-redis/redis/v8 v8.11.3/go.mod h1:xNJ9xDG09FsIPwh3bWdk+0oDWHbtF9rPN0F/oD9XeKc= github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg= github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -349,8 +346,6 @@ github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vv github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= -github.com/onsi/gomega v1.15.0 h1:WjP/FQ/sk43MRmnEcT+MlDw2TFvkrXlprrPST/IudjU= -github.com/onsi/gomega v1.15.0/go.mod h1:cIuvLEne0aoVhAgh/O6ac0Op8WWw9H6eYCriF+tEHG0= github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= diff --git a/internal/api/api.go b/internal/api/api.go index df8f564..26125d9 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -28,6 +28,7 @@ type api struct { deviceRepo domain.DeviceRepository subredditRepo domain.SubredditRepository watcherRepo domain.WatcherRepository + userRepo domain.UserRepository } func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api { @@ -56,6 +57,7 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p deviceRepo := repository.NewPostgresDevice(pool) subredditRepo := repository.NewPostgresSubreddit(pool) watcherRepo := repository.NewPostgresWatcher(pool) + userRepo := repository.NewPostgresUser(pool) return &api{ logger: logger, @@ -67,6 +69,7 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, p deviceRepo: deviceRepo, subredditRepo: subredditRepo, watcherRepo: watcherRepo, + userRepo: userRepo, } } diff --git a/internal/api/watcher.go b/internal/api/watcher.go index 77a8528..46f8c08 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "strconv" + "strings" "github.com/christianselig/apollo-backend/internal/domain" "github.com/gorilla/mux" @@ -18,6 +19,8 @@ type watcherCriteria struct { } type createWatcherRequest struct { + Type string + User string Subreddit string Criteria watcherCriteria } @@ -69,33 +72,63 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { } ac := a.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken) - srr, err := ac.SubredditAbout(cwr.Subreddit) - if err != nil { - a.errorResponse(w, r, 422, err.Error()) - return + + watcher := domain.Watcher{ + DeviceID: dev.ID, + AccountID: account.ID, + Upvotes: cwr.Criteria.Upvotes, + Keyword: strings.ToLower(cwr.Criteria.Keyword), + Flair: strings.ToLower(cwr.Criteria.Flair), + Domain: strings.ToLower(cwr.Criteria.Domain), } - sr, err := a.subredditRepo.GetByName(ctx, cwr.Subreddit) - if err != nil { - switch err { - case domain.ErrNotFound: - // Might be that we don't know about that subreddit yet - sr = domain.Subreddit{SubredditID: srr.ID, Name: srr.Name} - _ = a.subredditRepo.CreateOrUpdate(ctx, &sr) - default: + if cwr.Type == "subreddit" { + srr, err := ac.SubredditAbout(cwr.Subreddit) + if err != nil { + a.errorResponse(w, r, 422, err.Error()) + return + } + + sr, err := a.subredditRepo.GetByName(ctx, cwr.Subreddit) + if err != nil { + switch err { + case domain.ErrNotFound: + // Might be that we don't know about that subreddit yet + sr = domain.Subreddit{SubredditID: srr.ID, Name: srr.Name} + _ = a.subredditRepo.CreateOrUpdate(ctx, &sr) + default: + a.errorResponse(w, r, 500, err.Error()) + return + } + } + + watcher.Type = domain.SubredditWatcher + watcher.WatcheeID = sr.ID + } else if cwr.Type == "user" { + urr, err := ac.UserAbout(cwr.User) + if err != nil { a.errorResponse(w, r, 500, err.Error()) return } - } - watcher := domain.Watcher{ - SubredditID: sr.ID, - DeviceID: dev.ID, - AccountID: account.ID, - Upvotes: cwr.Criteria.Upvotes, - Keyword: cwr.Criteria.Keyword, - Flair: cwr.Criteria.Flair, - Domain: cwr.Criteria.Domain, + if !urr.AcceptFollowers { + a.errorResponse(w, r, 422, "no followers accepted") + return + } + + u := domain.User{UserID: urr.ID, Name: urr.Name} + err = a.userRepo.CreateOrUpdate(ctx, &u) + + if err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + watcher.Type = domain.UserWatcher + watcher.WatcheeID = u.ID + } else { + a.errorResponse(w, r, 422, "unknown watcher type") + return } if err := a.watcherRepo.Create(ctx, &watcher); err != nil { diff --git a/internal/cmd/scheduler.go b/internal/cmd/scheduler.go index 482a29b..7836aa4 100644 --- a/internal/cmd/scheduler.go +++ b/internal/cmd/scheduler.go @@ -24,7 +24,8 @@ const ( checkTimeout = 60 // how long until we force a check accountEnqueueTimeout = 5 // how frequently we want to check (seconds) - subredditEnqueueTimeout = 5 * 60 // how frequently we want to check (seconds) + subredditEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds) + userEnqueueTimeout = 2 * 60 // how frequently we want to check (seconds) staleAccountThreshold = 7200 // 2 hours ) @@ -76,9 +77,15 @@ func SchedulerCmd(ctx context.Context) *cobra.Command { return err } + userQueue, err := queue.OpenQueue("users") + if err != nil { + return err + } + s := gocron.NewScheduler(time.UTC) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) }) _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, subredditQueue) }) + _, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) }) _, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) }) _, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) }) _, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) }) @@ -202,6 +209,70 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie } } +func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { + now := time.Now() + ready := now.Unix() - userEnqueueTimeout + + ids := []int64{} + + err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { + stmt := ` + WITH userb AS ( + SELECT id + FROM users + WHERE last_checked_at < $1 + ORDER BY last_checked_at + LIMIT 100 + ) + UPDATE users + SET last_checked_at = $2 + WHERE users.id IN(SELECT id FROM userb) + RETURNING users.id` + rows, err := tx.Query(ctx, stmt, ready, now.Unix()) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + _ = rows.Scan(&id) + ids = append(ids, id) + } + return nil + }) + + if err != nil { + logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to fetch batch of users") + return + } + + if len(ids) == 0 { + return + } + + logger.WithFields(logrus.Fields{ + "count": len(ids), + "start": ready, + }).Debug("enqueueing user batch") + + batchIds := make([]string, len(ids)) + for i, id := range ids { + batchIds[i] = strconv.FormatInt(id, 10) + } + + if err = queue.Publish(batchIds...); err != nil { + logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to enqueue user") + } + + _ = statsd.Histogram("apollo.queue.users.enqueued", float64(len(ids)), []string{}, 1) + _ = statsd.Histogram("apollo.queue.users.runtime", float64(time.Since(now).Milliseconds()), []string{}, 1) + +} + func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { now := time.Now() ready := now.Unix() - subredditEnqueueTimeout diff --git a/internal/cmd/worker.go b/internal/cmd/worker.go index 7c502f7..dd6f36c 100644 --- a/internal/cmd/worker.go +++ b/internal/cmd/worker.go @@ -15,6 +15,7 @@ var ( queues = map[string]worker.NewWorkerFn{ "notifications": worker.NewNotificationsWorker, "subreddits": worker.NewSubredditsWorker, + "users": worker.NewUsersWorker, } ) diff --git a/internal/cmdutil/cmdutil.go b/internal/cmdutil/cmdutil.go index e0e1cc3..613f5ac 100644 --- a/internal/cmdutil/cmdutil.go +++ b/internal/cmdutil/cmdutil.go @@ -51,6 +51,10 @@ func NewRedisClient(ctx context.Context) (*redis.Client, error) { } func NewDatabasePool(ctx context.Context, maxConns int) (*pgxpool.Pool, error) { + if maxConns == 0 { + maxConns = 1 + } + url := fmt.Sprintf( "%s?pool_max_conns=%d&pool_min_conns=%d", os.Getenv("DATABASE_CONNECTION_POOL_URL"), diff --git a/internal/domain/subreddit.go b/internal/domain/subreddit.go index 22ce7e9..4563f22 100644 --- a/internal/domain/subreddit.go +++ b/internal/domain/subreddit.go @@ -6,7 +6,8 @@ import ( ) type Subreddit struct { - ID int64 + ID int64 + LastCheckedAt float64 // Reddit information SubredditID string diff --git a/internal/domain/user.go b/internal/domain/user.go new file mode 100644 index 0000000..c8b6f8f --- /dev/null +++ b/internal/domain/user.go @@ -0,0 +1,26 @@ +package domain + +import ( + "context" + "strings" +) + +type User struct { + ID int64 + LastCheckedAt float64 + + // Reddit information + UserID string + Name string +} + +func (u *User) NormalizedName() string { + return strings.ToLower(u.Name) +} + +type UserRepository interface { + GetByID(context.Context, int64) (User, error) + GetByName(context.Context, string) (User, error) + + CreateOrUpdate(context.Context, *User) error +} diff --git a/internal/domain/watcher.go b/internal/domain/watcher.go index 3a77404..99f20b7 100644 --- a/internal/domain/watcher.go +++ b/internal/domain/watcher.go @@ -2,13 +2,21 @@ package domain import "context" +type WatcherType int64 + +const ( + SubredditWatcher WatcherType = iota + UserWatcher +) + type Watcher struct { ID int64 CreatedAt float64 - DeviceID int64 - AccountID int64 - SubredditID int64 + DeviceID int64 + AccountID int64 + Type WatcherType + WatcheeID int64 Upvotes int64 Keyword string @@ -20,6 +28,7 @@ type Watcher struct { type WatcherRepository interface { GetByID(ctx context.Context, id int64) (Watcher, error) GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error) + GetByUserID(ctx context.Context, id int64) ([]Watcher, error) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]Watcher, error) Create(ctx context.Context, watcher *Watcher) error diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 37adb9d..73de11b 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -171,6 +171,41 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { return ret, nil } +func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*ListingResponse, error) { + url := fmt.Sprintf("https://oauth.reddit.com/u/%s/submitted.json", user) + opts = append([]RequestOption{ + WithMethod("GET"), + WithToken(rac.accessToken), + WithURL(url), + }, opts...) + req := NewRequest(opts...) + + lr, err := rac.request(req, NewListingResponse, nil) + if err != nil { + return nil, err + } + + return lr.(*ListingResponse), nil +} + +func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*UserResponse, error) { + url := fmt.Sprintf("https://oauth.reddit.com/u/%s/about.json", user) + opts = append([]RequestOption{ + WithMethod("GET"), + WithToken(rac.accessToken), + WithURL(url), + }, opts...) + req := NewRequest(opts...) + ur, err := rac.request(req, NewUserResponse, nil) + + if err != nil { + return nil, err + } + + return ur.(*UserResponse), nil + +} + func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) { url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about.json", subreddit) opts = append([]RequestOption{ diff --git a/internal/reddit/testdata/user_about.json b/internal/reddit/testdata/user_about.json new file mode 100644 index 0000000..8cc5c56 --- /dev/null +++ b/internal/reddit/testdata/user_about.json @@ -0,0 +1,75 @@ +{ + "kind": "t2", + "data": { + "is_employee": false, + "is_friend": false, + "subreddit": { + "default_set": true, + "user_is_contributor": false, + "banner_img": "", + "restrict_posting": true, + "user_is_banned": false, + "free_form_reports": true, + "community_icon": null, + "show_media": true, + "icon_color": "#FFD635", + "user_is_muted": false, + "display_name": "u_changelog", + "header_img": null, + "title": "", + "previous_names": [], + "over_18": false, + "icon_size": [ + 256, + 256 + ], + "primary_color": "", + "icon_img": "https://www.redditstatic.com/avatars/defaults/v2/avatar_default_2.png", + "description": "", + "submit_link_label": "", + "header_size": null, + "restrict_commenting": false, + "subscribers": 0, + "submit_text_label": "", + "is_default_icon": true, + "link_flair_position": "", + "display_name_prefixed": "u/changelog", + "key_color": "", + "name": "t5_c30sw", + "is_default_banner": true, + "url": "/user/changelog/", + "quarantine": false, + "banner_size": null, + "user_is_moderator": false, + "accept_followers": true, + "public_description": "", + "link_flair_enabled": false, + "disable_contributor_requests": false, + "subreddit_type": "user", + "user_is_subscriber": false + }, + "snoovatar_size": null, + "awardee_karma": 23, + "id": "1ia22", + "verified": true, + "is_gold": false, + "is_mod": true, + "awarder_karma": 0, + "has_verified_email": true, + "icon_img": "https://www.redditstatic.com/avatars/defaults/v2/avatar_default_2.png", + "hide_from_robots": true, + "link_karma": 2676, + "pref_show_snoovatar": false, + "is_blocked": false, + "total_karma": 4132, + "accept_chats": false, + "name": "changelog", + "created": 1176721866.0, + "created_utc": 1176721866.0, + "snoovatar_img": "", + "comment_karma": 1433, + "accept_followers": true, + "has_subscribed": true, + "accept_pms": true + } +} diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 6770dbe..582fd62 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -160,4 +160,23 @@ func NewSubredditResponse(val *fastjson.Value) interface{} { return sr } +type UserResponse struct { + Thing + + AcceptFollowers bool + Name string +} + +func NewUserResponse(val *fastjson.Value) interface{} { + ur := &UserResponse{} + ur.Kind = string(val.GetStringBytes("kind")) + + data := val.Get("data") + ur.ID = string(data.GetStringBytes("id")) + ur.Name = string(data.GetStringBytes("name")) + ur.AcceptFollowers = data.GetBool("accept_followers") + + return ur +} + var EmptyListingResponse = &ListingResponse{} diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index cd62b9a..7bc8884 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -111,3 +111,20 @@ func TestSubredditResponseParsing(t *testing.T) { assert.Equal(t, "2vq0w", s.ID) assert.Equal(t, "DestinyTheGame", s.Name) } + +func TestUserResponseParsing(t *testing.T) { + bb, err := ioutil.ReadFile("testdata/user_about.json") + assert.NoError(t, err) + + val, err := parser.ParseBytes(bb) + assert.NoError(t, err) + + ret := NewUserResponse(val) + u := ret.(*UserResponse) + assert.NotNil(t, u) + + assert.Equal(t, "t2", u.Kind) + assert.Equal(t, "1ia22", u.ID) + assert.Equal(t, "changelog", u.Name) + assert.Equal(t, true, u.AcceptFollowers) +} diff --git a/internal/repository/postgres_subreddit.go b/internal/repository/postgres_subreddit.go index 5754060..7ed2cf2 100644 --- a/internal/repository/postgres_subreddit.go +++ b/internal/repository/postgres_subreddit.go @@ -30,6 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a &sr.ID, &sr.SubredditID, &sr.Name, + &sr.LastCheckedAt, ); err != nil { return nil, err } @@ -40,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) { query := ` - SELECT id, subreddit_id, name + SELECT id, subreddit_id, name, last_checked_at FROM subreddits WHERE id = $1` @@ -57,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) { query := ` - SELECT id, subreddit_id, name + SELECT id, subreddit_id, name, last_checked_at FROM subreddits WHERE name = $1` @@ -78,7 +79,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do query := ` INSERT INTO subreddits (subreddit_id, name) VALUES ($1, $2) - ON CONFLICT(subreddit_id) DO NOTHING + ON CONFLICT(subreddit_id) DO + UPDATE SET last_checked_at = $3 RETURNING id` return p.pool.QueryRow( @@ -86,5 +88,6 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do query, sr.SubredditID, sr.NormalizedName(), + sr.LastCheckedAt, ).Scan(&sr.ID) } diff --git a/internal/repository/postgres_user.go b/internal/repository/postgres_user.go new file mode 100644 index 0000000..283f940 --- /dev/null +++ b/internal/repository/postgres_user.go @@ -0,0 +1,93 @@ +package repository + +import ( + "context" + "strings" + + "github.com/christianselig/apollo-backend/internal/domain" + "github.com/jackc/pgx/v4/pgxpool" +) + +type postgresUserRepository struct { + pool *pgxpool.Pool +} + +func NewPostgresUser(pool *pgxpool.Pool) domain.UserRepository { + return &postgresUserRepository{pool: pool} +} + +func (p *postgresUserRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.User, error) { + rows, err := p.pool.Query(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var uu []domain.User + for rows.Next() { + var u domain.User + if err := rows.Scan( + &u.ID, + &u.UserID, + &u.Name, + &u.LastCheckedAt, + ); err != nil { + return nil, err + } + uu = append(uu, u) + } + return uu, nil +} + +func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) { + query := ` + SELECT id, user_id, name, last_checked_at + FROM users + WHERE id = $1` + + srs, err := p.fetch(ctx, query, id) + + if err != nil { + return domain.User{}, err + } + if len(srs) == 0 { + return domain.User{}, domain.ErrNotFound + } + return srs[0], nil +} + +func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) { + query := ` + SELECT id, user_id, name, last_checked_at + FROM users + WHERE name = $1` + + name = strings.ToLower(name) + + srs, err := p.fetch(ctx, query, name) + + if err != nil { + return domain.User{}, err + } + if len(srs) == 0 { + return domain.User{}, domain.ErrNotFound + } + return srs[0], nil +} + +func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.User) error { + query := ` + INSERT INTO users (user_id, name) + VALUES ($1, $2) + ON CONFLICT(user_id) DO + UPDATE SET last_checked_at = $3 + RETURNING id` + + return p.pool.QueryRow( + ctx, + query, + u.UserID, + u.NormalizedName(), + u.LastCheckedAt, + ).Scan(&u.ID) +} diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index 6de0191..92d27e1 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -33,7 +33,8 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg &watcher.CreatedAt, &watcher.DeviceID, &watcher.AccountID, - &watcher.SubredditID, + &watcher.Type, + &watcher.WatcheeID, &watcher.Upvotes, &watcher.Keyword, &watcher.Flair, @@ -49,7 +50,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) { query := ` - SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits + SELECT id, created_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits FROM watchers WHERE id = $1` @@ -64,13 +65,21 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma return watchers[0], nil } -func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) { +func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) { query := ` - SELECT id, created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain, hits + SELECT id, created_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits FROM watchers - WHERE subreddit_id = $1` + WHERE type = $1 AND watchee_id = $2` - return p.fetch(ctx, query, id) + return p.fetch(ctx, query, typ, id) +} + +func (p *postgresWatcherRepository) GetBySubredditID(ctx context.Context, id int64) ([]domain.Watcher, error) { + return p.GetByTypeAndWatcheeID(ctx, domain.SubredditWatcher, id) +} + +func (p *postgresWatcherRepository) GetByUserID(ctx context.Context, id int64) ([]domain.Watcher, error) { + return p.GetByTypeAndWatcheeID(ctx, domain.UserWatcher, id) } func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx context.Context, apns string, rid string) ([]domain.Watcher, error) { @@ -80,7 +89,8 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c watchers.created_at, watchers.device_id, watchers.account_id, - watchers.subreddit_id, + watchers.type, + watchers.watchee_id, watchers.upvotes, watchers.keyword, watchers.flair, @@ -101,7 +111,7 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain. query := ` INSERT INTO watchers - (created_at, device_id, account_id, subreddit_id, upvotes, keyword, flair, domain) + (created_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` @@ -111,7 +121,8 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain. now, watcher.DeviceID, watcher.AccountID, - watcher.SubredditID, + watcher.Type, + watcher.WatcheeID, watcher.Upvotes, watcher.Keyword, watcher.Flair, diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index 65b788e..42488c9 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -170,7 +170,6 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { if len(watchers) == 0 { sc.logger.WithFields(logrus.Fields{ "subreddit#id": subreddit.ID, - "err": err, }).Info("no watchers for subreddit, skipping") return } @@ -299,6 +298,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { "count": len(posts), }).Debug("checking posts for hits") for _, post := range posts { + lowcaseTitle := strings.ToLower(post.Title) + lowcaseFlair := strings.ToLower(post.Flair) + lowcaseDomain := strings.ToLower(post.URL) + ids := []int64{} for _, watcher := range watchers { @@ -313,15 +316,15 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { matched = false } - if watcher.Keyword != "" && !strings.Contains(post.Title, watcher.Keyword) { + if watcher.Keyword != "" && !strings.Contains(lowcaseTitle, watcher.Keyword) { matched = false } - if watcher.Flair != "" && !strings.Contains(post.Flair, watcher.Flair) { + if watcher.Flair != "" && !strings.Contains(lowcaseFlair, watcher.Flair) { matched = false } - if watcher.Domain != "" && !strings.Contains(post.URL, watcher.Domain) { + if watcher.Domain != "" && !strings.Contains(lowcaseDomain, watcher.Domain) { matched = false } diff --git a/internal/worker/users.go b/internal/worker/users.go new file mode 100644 index 0000000..eda8f3a --- /dev/null +++ b/internal/worker/users.go @@ -0,0 +1,259 @@ +package worker + +import ( + "context" + "fmt" + "math/rand" + "os" + "strconv" + + "github.com/DataDog/datadog-go/statsd" + "github.com/adjust/rmq/v4" + "github.com/go-redis/redis/v8" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/sideshow/apns2" + "github.com/sideshow/apns2/payload" + "github.com/sideshow/apns2/token" + "github.com/sirupsen/logrus" + + "github.com/christianselig/apollo-backend/internal/domain" + "github.com/christianselig/apollo-backend/internal/reddit" + "github.com/christianselig/apollo-backend/internal/repository" +) + +type usersWorker struct { + logger *logrus.Logger + statsd *statsd.Client + db *pgxpool.Pool + redis *redis.Client + queue rmq.Connection + reddit *reddit.Client + apns *token.Token + + consumers int + + accountRepo domain.AccountRepository + deviceRepo domain.DeviceRepository + userRepo domain.UserRepository + watcherRepo domain.WatcherRepository +} + +func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker { + reddit := reddit.NewClient( + os.Getenv("REDDIT_CLIENT_ID"), + os.Getenv("REDDIT_CLIENT_SECRET"), + statsd, + consumers, + ) + + var apns *token.Token + { + authKey, err := token.AuthKeyFromFile(os.Getenv("APPLE_KEY_PATH")) + if err != nil { + panic(err) + } + + apns = &token.Token{ + AuthKey: authKey, + KeyID: os.Getenv("APPLE_KEY_ID"), + TeamID: os.Getenv("APPLE_TEAM_ID"), + } + } + + return &usersWorker{ + logger, + statsd, + db, + redis, + queue, + reddit, + apns, + consumers, + + repository.NewPostgresAccount(db), + repository.NewPostgresDevice(db), + repository.NewPostgresUser(db), + repository.NewPostgresWatcher(db), + } +} + +func (uw *usersWorker) Start() error { + queue, err := uw.queue.OpenQueue("users") + if err != nil { + return err + } + + uw.logger.WithFields(logrus.Fields{ + "numConsumers": uw.consumers, + }).Info("starting up users worker") + + prefetchLimit := int64(uw.consumers * 2) + + if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil { + return err + } + + host, _ := os.Hostname() + + for i := 0; i < uw.consumers; i++ { + name := fmt.Sprintf("consumer %s-%d", host, i) + + consumer := NewUsersConsumer(uw, i) + if _, err := queue.AddConsumer(name, consumer); err != nil { + return err + } + } + + return nil +} + +func (uw *usersWorker) Stop() { + <-uw.queue.StopAllConsuming() // wait for all Consume() calls to finish +} + +type usersConsumer struct { + *usersWorker + tag int + + apnsSandbox *apns2.Client + apnsProduction *apns2.Client +} + +func NewUsersConsumer(uw *usersWorker, tag int) *usersConsumer { + return &usersConsumer{ + uw, + tag, + apns2.NewTokenClient(uw.apns), + apns2.NewTokenClient(uw.apns).Production(), + } +} + +func (uc *usersConsumer) Consume(delivery rmq.Delivery) { + ctx := context.Background() + + uc.logger.WithFields(logrus.Fields{ + "user#id": delivery.Payload(), + }).Debug("starting job") + + id, err := strconv.ParseInt(delivery.Payload(), 10, 64) + if err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": delivery.Payload(), + "err": err, + }).Error("failed to parse user ID") + + _ = delivery.Reject() + return + } + + defer func() { _ = delivery.Ack() }() + + user, err := uc.userRepo.GetByID(ctx, id) + if err != nil { + uc.logger.WithFields(logrus.Fields{ + "err": err, + }).Error("failed to fetch user from database") + return + } + + watchers, err := uc.watcherRepo.GetByUserID(ctx, user.ID) + if err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "err": err, + }).Error("failed to fetch watchers from database") + return + } + + if len(watchers) == 0 { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "err": err, + }).Info("no watchers for user, skipping") + return + } + + // Load 25 newest posts + i := rand.Intn(len(watchers)) + watcher := watchers[i] + + acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID) + rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + + posts, err := rac.UserPosts(user.Name) + if err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "err": err, + }).Error("failed to fetch user activity") + return + } + + for _, post := range posts.Children { + if post.CreatedAt < user.LastCheckedAt { + break + } + + notification := &apns2.Notification{} + notification.Topic = "com.christianselig.Apollo" + notification.Payload = payloadFromUserPost(post) + + for _, watcher := range watchers { + // Make sure we only alert on activities created after the search + if watcher.CreatedAt > post.CreatedAt { + continue + } + device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID) + notification.DeviceToken = device.APNSToken + + client := uc.apnsProduction + if device.Sandbox { + client = uc.apnsSandbox + } + + res, err := client.Push(notification) + if err != nil { + _ = uc.statsd.Incr("apns.notification.errors", []string{}, 1) + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "device#id": device.ID, + "err": err, + "status": res.StatusCode, + "reason": res.Reason, + }).Error("failed to send notification") + } else { + _ = uc.statsd.Incr("apns.notification.sent", []string{}, 1) + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "device#id": device.ID, + "device#token": device.APNSToken, + }).Info("sent notification") + } + } + } + + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "user#name": user.Name, + }).Debug("finishing job") +} + +func payloadFromUserPost(post *reddit.Thing) *payload.Payload { + title := fmt.Sprintf("👨\u200d🚀 User post! (u/%s)", post.Author) + + payload := payload. + NewPayload(). + AlertTitle(title). + AlertBody(post.Title). + AlertSummaryArg(post.Author). + Category("user-watch"). + Custom("post_title", post.Title). + Custom("post_id", post.ID). + Custom("subreddit", post.Subreddit). + Custom("author", post.Author). + Custom("post_age", post.CreatedAt). + MutableContent(). + Sound("traloop.wav") + + return payload +}