From dbcda74ab8764469df874bff469caa411a82ed84 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Mon, 28 Mar 2022 17:05:01 -0400 Subject: [PATCH 1/4] Changes to schema --- internal/api/accounts.go | 9 +- internal/api/devices.go | 4 +- internal/api/receipt.go | 9 +- internal/api/watcher.go | 23 ++--- internal/cmd/scheduler.go | 104 ++++++++++------------ internal/domain/account.go | 27 ++++-- internal/domain/device.go | 19 ++-- internal/domain/subreddit.go | 7 +- internal/domain/user.go | 7 +- internal/domain/watcher.go | 5 +- internal/reddit/types.go | 40 +++++---- internal/reddit/types_test.go | 4 +- internal/repository/postgres_account.go | 58 +++++++----- internal/repository/postgres_device.go | 39 ++++---- internal/repository/postgres_subreddit.go | 10 +-- internal/repository/postgres_user.go | 15 ++-- internal/worker/notifications.go | 20 +++-- internal/worker/subreddits.go | 8 +- internal/worker/trending.go | 6 +- internal/worker/users.go | 4 +- 20 files changed, 226 insertions(+), 192 deletions(-) diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 5b19018..3c381c3 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -150,7 +150,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { } // Reset expiration timer - acc.ExpiresAt = time.Now().Unix() + 3540 + acc.TokenExpiresAt = time.Now().Add(1 * time.Hour) acc.RefreshToken = tokens.RefreshToken acc.AccessToken = tokens.AccessToken @@ -175,7 +175,10 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { return } - _ = a.accountRepo.Associate(ctx, &acc, &dev) + if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil { + a.errorResponse(w, r, 422, err.Error()) + return + } } for _, acc := range accsMap { @@ -212,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { } // Reset expiration timer - acct.ExpiresAt = time.Now().Unix() + 3540 + acct.TokenExpiresAt = time.Now().Add(1 * time.Hour) acct.RefreshToken = tokens.RefreshToken acct.AccessToken = tokens.AccessToken diff --git a/internal/api/devices.go b/internal/api/devices.go index 8c5a834..3a963a4 100644 --- a/internal/api/devices.go +++ b/internal/api/devices.go @@ -28,8 +28,8 @@ func (a *api) upsertDeviceHandler(w http.ResponseWriter, r *http.Request) { return } - d.ActiveUntil = time.Now().Unix() + domain.DeviceGracePeriodDuration - d.GracePeriodUntil = d.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry + d.ExpiresAt = time.Now().Add(domain.DeviceReceiptCheckPeriodDuration) + d.GracePeriodExpiresAt = d.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry) if err := a.deviceRepo.CreateOrUpdate(ctx, d); err != nil { a.errorResponse(w, r, 500, err.Error()) diff --git a/internal/api/receipt.go b/internal/api/receipt.go index 98036f4..831e44c 100644 --- a/internal/api/receipt.go +++ b/internal/api/receipt.go @@ -39,6 +39,11 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) { } if iapr.DeleteDevice { + if dev.GracePeriodExpiresAt.After(time.Now()) { + w.WriteHeader(http.StatusOK) + return + } + accs, err := a.accountRepo.GetByAPNSToken(ctx, apns) if err != nil { a.errorResponse(w, r, 500, err.Error()) @@ -51,8 +56,8 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) { _ = a.deviceRepo.Delete(ctx, apns) } else { - dev.ActiveUntil = time.Now().Unix() + domain.DeviceActiveAfterReceitCheckDuration - dev.GracePeriodUntil = dev.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry + dev.ExpiresAt = time.Now().Add(domain.DeviceActiveAfterReceitCheckDuration) + dev.GracePeriodExpiresAt = dev.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry) _ = a.deviceRepo.Update(ctx, &dev) } } diff --git a/internal/api/watcher.go b/internal/api/watcher.go index 7748ce2..8882fb4 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -6,6 +6,7 @@ import ( "net/http" "strconv" "strings" + "time" validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/gorilla/mux" @@ -230,17 +231,17 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) { } type watcherItem struct { - ID int64 `json:"id"` - CreatedAt float64 `json:"created_at"` - Type string `json:"type"` - Label string `json:"label"` - SourceLabel string `json:"source_label"` - Upvotes *int64 `json:"upvotes,omitempty"` - Keyword string `json:"keyword,omitempty"` - Flair string `json:"flair,omitempty"` - Domain string `json:"domain,omitempty"` - Hits int64 `json:"hits"` - Author string `json:"author,omitempty"` + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Type string `json:"type"` + Label string `json:"label"` + SourceLabel string `json:"source_label"` + Upvotes *int64 `json:"upvotes,omitempty"` + Keyword string `json:"keyword,omitempty"` + Flair string `json:"flair,omitempty"` + Domain string `json:"domain,omitempty"` + Hits int64 `json:"hits"` + Author string `json:"author,omitempty"` } func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/cmd/scheduler.go b/internal/cmd/scheduler.go index 97412c4..debede5 100644 --- a/internal/cmd/scheduler.go +++ b/internal/cmd/scheduler.go @@ -9,27 +9,19 @@ import ( "github.com/DataDog/datadog-go/statsd" "github.com/adjust/rmq/v4" - "github.com/christianselig/apollo-backend/internal/cmdutil" - "github.com/christianselig/apollo-backend/internal/repository" "github.com/go-co-op/gocron" "github.com/go-redis/redis/v8" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + + "github.com/christianselig/apollo-backend/internal/cmdutil" + "github.com/christianselig/apollo-backend/internal/domain" + "github.com/christianselig/apollo-backend/internal/repository" ) -const ( - batchSize = 250 - checkTimeout = 60 // how long until we force a check - - accountEnqueueInterval = 5 // how frequently we want to check (seconds) - subredditEnqueueInterval = 2 * 60 // how frequently we want to check (seconds) - userEnqueueInterval = 2 * 60 // how frequently we want to check (seconds) - stuckAccountEnqueueInterval = 2 * 60 // how frequently we want to check (seconds) - - staleAccountThreshold = 7200 // 2 hours -) +const batchSize = 250 func SchedulerCmd(ctx context.Context) *cobra.Command { cmd := &cobra.Command{ @@ -129,16 +121,16 @@ func evalScript(ctx context.Context, redis *redis.Client) (string, error) { end return retv - `, checkTimeout) + `, int64(domain.NotificationCheckTimeout.Seconds())) return redis.ScriptLoad(ctx, lua).Result() } func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) { - before := time.Now().Unix() - staleAccountThreshold + expiry := time.Now().Add(-domain.StaleTokenThreshold) ar := repository.NewPostgresAccount(pool) - stale, err := ar.PruneStale(ctx, before) + stale, err := ar.PruneStale(ctx, expiry) if err != nil { logger.WithFields(logrus.Fields{ "err": err, @@ -158,16 +150,17 @@ func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Poo if count > 0 { logger.WithFields(logrus.Fields{ - "count": count, + "stale": stale, + "orphaned": orphaned, }).Info("pruned accounts") } } func pruneDevices(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) { - threshold := time.Now().Unix() + now := time.Now() dr := repository.NewPostgresDevice(pool) - count, err := dr.PruneStale(ctx, threshold) + count, err := dr.PruneStale(ctx, now) if err != nil { logger.WithFields(logrus.Fields{ "err": err, @@ -227,6 +220,8 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { now := time.Now() + next := now.Add(domain.NotificationCheckInterval) + ids := []int64{} defer func() { @@ -235,21 +230,20 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - ready := now.Unix() - userEnqueueInterval err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { stmt := ` - WITH userb AS ( + WITH batch AS ( SELECT id FROM users - WHERE last_checked_at < $1 - ORDER BY last_checked_at + WHERE next_check_at < $1 + ORDER BY next_check_at LIMIT 100 ) UPDATE users - SET last_checked_at = $2 - WHERE users.id IN(SELECT id FROM userb) + SET next_check_at = $2 + WHERE users.id IN(SELECT id FROM batch) RETURNING users.id` - rows, err := tx.Query(ctx, stmt, ready, now.Unix()) + rows, err := tx.Query(ctx, stmt, now, next) if err != nil { return err } @@ -275,7 +269,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli logger.WithFields(logrus.Fields{ "count": len(ids), - "start": ready, + "start": now, }).Debug("enqueueing user batch") batchIds := make([]string, len(ids)) @@ -292,6 +286,8 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) { now := time.Now() + next := now.Add(domain.SubredditCheckInterval) + ids := []int64{} defer func() { @@ -300,21 +296,20 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - ready := now.Unix() - subredditEnqueueInterval err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { stmt := ` - WITH subreddit AS ( + WITH batch AS ( SELECT id FROM subreddits - WHERE last_checked_at < $1 - ORDER BY last_checked_at + WHERE next_check_at < $1 + ORDER BY next_check_at LIMIT 100 ) UPDATE subreddits - SET last_checked_at = $2 - WHERE subreddits.id IN(SELECT id FROM subreddit) + SET next_check_at = $2 + WHERE subreddits.id IN(SELECT id FROM batch) RETURNING subreddits.id` - rows, err := tx.Query(ctx, stmt, ready, now.Unix()) + rows, err := tx.Query(ctx, stmt, now, next) if err != nil { return err } @@ -340,7 +335,7 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats logger.WithFields(logrus.Fields{ "count": len(ids), - "start": ready, + "start": now, }).Debug("enqueueing subreddit batch") batchIds := make([]string, len(ids)) @@ -361,6 +356,8 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { now := time.Now() + next := now.Add(domain.StuckNotificationCheckInterval) + ids := []int64{} defer func() { @@ -369,22 +366,21 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - ready := now.Unix() - stuckAccountEnqueueInterval err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { stmt := ` - WITH account AS ( + WITH batch AS ( SELECT id FROM accounts WHERE - last_unstuck_at < $1 - ORDER BY last_unstuck_at + next_stuck_notification_check_at < $1 + ORDER BY next_stuck_notification_check_at LIMIT 500 ) UPDATE accounts - SET last_unstuck_at = $2 - WHERE accounts.id IN(SELECT id FROM account) + SET next_stuck_notification_check_at = $2 + WHERE accounts.id IN(SELECT id FROM batch) RETURNING accounts.id` - rows, err := tx.Query(ctx, stmt, ready, now.Unix()) + rows, err := tx.Query(ctx, stmt, now, next) if err != nil { return err } @@ -410,7 +406,7 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st logger.WithFields(logrus.Fields{ "count": len(ids), - "start": ready, + "start": now, }).Debug("enqueueing stuck account batch") batchIds := make([]string, len(ids)) @@ -428,6 +424,8 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) { now := time.Now() + next := now.Add(domain.NotificationCheckInterval) + ids := []int64{} enqueued := 0 skipped := 0 @@ -439,29 +437,21 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd. _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - // Start looking for accounts that were last checked at least 5 seconds ago - // and at most 6 seconds ago. Also look for accounts that haven't been checked - // in over a minute. - ts := now.Unix() - ready := ts - accountEnqueueInterval - expired := ts - checkTimeout - err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { stmt := ` WITH account AS ( SELECT id FROM accounts WHERE - last_enqueued_at < $1 - OR last_checked_at < $2 - ORDER BY last_checked_at + next_notification_check_at < $1 + ORDER BY next_notification_check_at LIMIT 2500 ) UPDATE accounts - SET last_enqueued_at = $3 + SET next_notification_check_at = $2 WHERE accounts.id IN(SELECT id FROM account) RETURNING accounts.id` - rows, err := tx.Query(ctx, stmt, ready, expired, ts) + rows, err := tx.Query(ctx, stmt, now, next) if err != nil { return err } @@ -487,7 +477,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd. logger.WithFields(logrus.Fields{ "count": len(ids), - "start": ready, + "start": now, }).Debug("enqueueing account batch") // Split ids in batches for i := 0; i < len(ids); i += batchSize { @@ -532,7 +522,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd. logger.WithFields(logrus.Fields{ "count": enqueued, "skipped": skipped, - "start": ready, + "start": now, }).Debug("done enqueueing account batch") } diff --git a/internal/domain/account.go b/internal/domain/account.go index 7e9cd0c..5892fe3 100644 --- a/internal/domain/account.go +++ b/internal/domain/account.go @@ -3,25 +3,34 @@ package domain import ( "context" "strings" + "time" validation "github.com/go-ozzo/ozzo-validation/v4" ) +const ( + NotificationCheckInterval = 5 * time.Second // time between notification checks + NotificationCheckTimeout = 60 * time.Second // time before we give up an account check lock + StuckNotificationCheckInterval = 2 * time.Minute // time between stuck notification checks + StaleTokenThreshold = 2 * time.Hour // time an oauth token has to be expired for to be stale +) + // Account represents an account we need to periodically check in the notifications worker. type Account struct { ID int64 // Reddit information - Username string - AccountID string - AccessToken string - RefreshToken string - ExpiresAt int64 + Username string + AccountID string + AccessToken string + RefreshToken string + TokenExpiresAt time.Time // Tracking how far behind we are - LastMessageID string - LastCheckedAt float64 - LastUnstuckAt float64 + LastMessageID string + NextNotificationCheckAt time.Time + NextStuckNotificationCheckAt time.Time + CheckCount int64 } func (acct *Account) NormalizedUsername() string { @@ -49,5 +58,5 @@ type AccountRepository interface { Disassociate(ctx context.Context, acc *Account, dev *Device) error PruneOrphaned(ctx context.Context) (int64, error) - PruneStale(ctx context.Context, before int64) (int64, error) + PruneStale(ctx context.Context, expiry time.Time) (int64, error) } diff --git a/internal/domain/device.go b/internal/domain/device.go index 66dd5d6..a23023b 100644 --- a/internal/domain/device.go +++ b/internal/domain/device.go @@ -2,22 +2,23 @@ package domain import ( "context" + "time" validation "github.com/go-ozzo/ozzo-validation/v4" ) const ( - DeviceGracePeriodDuration = 3600 // 1 hour - DeviceActiveAfterReceitCheckDuration = 3600 * 24 * 30 // ~1 month - DeviceGracePeriodAfterReceiptExpiry = 3600 * 24 * 30 // ~1 month + DeviceReceiptCheckPeriodDuration = 1 * time.Hour + DeviceActiveAfterReceitCheckDuration = 30 * 24 * time.Hour // ~1 month + DeviceGracePeriodAfterReceiptExpiry = 30 * 24 * time.Hour // ~1 month ) type Device struct { - ID int64 - APNSToken string - Sandbox bool - ActiveUntil int64 - GracePeriodUntil int64 + ID int64 + APNSToken string + Sandbox bool + ExpiresAt time.Time + GracePeriodExpiresAt time.Time } func (dev *Device) Validate() error { @@ -40,5 +41,5 @@ type DeviceRepository interface { SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher, global bool) error GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, bool, error) - PruneStale(ctx context.Context, before int64) (int64, error) + PruneStale(ctx context.Context, expiry time.Time) (int64, error) } diff --git a/internal/domain/subreddit.go b/internal/domain/subreddit.go index efb2b2e..524a044 100644 --- a/internal/domain/subreddit.go +++ b/internal/domain/subreddit.go @@ -3,13 +3,16 @@ package domain import ( "context" "strings" + "time" validation "github.com/go-ozzo/ozzo-validation/v4" ) +const SubredditCheckInterval = 2 * time.Minute + type Subreddit struct { - ID int64 - LastCheckedAt float64 + ID int64 + NextCheckAt time.Time // Reddit information SubredditID string diff --git a/internal/domain/user.go b/internal/domain/user.go index d11aa1d..0d1d9c3 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -3,13 +3,16 @@ package domain import ( "context" "strings" + "time" validation "github.com/go-ozzo/ozzo-validation/v4" ) +const UserRefreshInterval = 2 * time.Minute + type User struct { - ID int64 - LastCheckedAt float64 + ID int64 + NextCheckAt time.Time // Reddit information UserID string diff --git a/internal/domain/watcher.go b/internal/domain/watcher.go index a7a7713..e985d17 100644 --- a/internal/domain/watcher.go +++ b/internal/domain/watcher.go @@ -2,6 +2,7 @@ package domain import ( "context" + "time" validation "github.com/go-ozzo/ozzo-validation/v4" ) @@ -29,8 +30,8 @@ func (wt WatcherType) String() string { type Watcher struct { ID int64 - CreatedAt float64 - LastNotifiedAt float64 + CreatedAt time.Time + LastNotifiedAt time.Time Label string DeviceID int64 diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 4a170ee..5efe4d1 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -3,6 +3,7 @@ package reddit import ( "fmt" "strings" + "time" "github.com/valyala/fastjson" ) @@ -61,24 +62,24 @@ func NewMeResponse(val *fastjson.Value) interface{} { } type Thing struct { - Kind string `json:"kind"` - ID string `json:"id"` - Type string `json:"type"` - Author string `json:"author"` - Subject string `json:"subject"` - Body string `json:"body"` - CreatedAt float64 `json:"created_utc"` - Context string `json:"context"` - ParentID string `json:"parent_id"` - LinkTitle string `json:"link_title"` - Destination string `json:"dest"` - Subreddit string `json:"subreddit"` - SubredditType string `json:"subreddit_type"` - Score int64 `json:"score"` - SelfText string `json:"selftext"` - Title string `json:"title"` - URL string `json:"url"` - Flair string `json:"flair"` + Kind string `json:"kind"` + ID string `json:"id"` + Type string `json:"type"` + Author string `json:"author"` + Subject string `json:"subject"` + Body string `json:"body"` + CreatedAt time.Time `json:"created_utc"` + Context string `json:"context"` + ParentID string `json:"parent_id"` + LinkTitle string `json:"link_title"` + Destination string `json:"dest"` + Subreddit string `json:"subreddit"` + SubredditType string `json:"subreddit_type"` + Score int64 `json:"score"` + SelfText string `json:"selftext"` + Title string `json:"title"` + URL string `json:"url"` + Flair string `json:"flair"` } func (t *Thing) FullName() string { @@ -95,13 +96,14 @@ func NewThing(val *fastjson.Value) *Thing { t.Kind = string(val.GetStringBytes("kind")) data := val.Get("data") + unix := int64(data.GetFloat64("created_utc")) t.ID = string(data.GetStringBytes("id")) t.Type = string(data.GetStringBytes("type")) t.Author = string(data.GetStringBytes("author")) t.Subject = string(data.GetStringBytes("subject")) t.Body = string(data.GetStringBytes("body")) - t.CreatedAt = data.GetFloat64("created_utc") + t.CreatedAt = time.Unix(unix, 0) t.Context = string(data.GetStringBytes("context")) t.ParentID = string(data.GetStringBytes("parent_id")) t.LinkTitle = string(data.GetStringBytes("link_title")) diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index 0c9a566..ea7d0fd 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -3,6 +3,7 @@ package reddit import ( "io/ioutil" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/valyala/fastjson" @@ -60,13 +61,14 @@ func TestListingResponseParsing(t *testing.T) { assert.Equal(t, "", l.Before) thing := l.Children[0] + created := time.Time(time.Date(2021, time.July, 14, 13, 56, 35, 0, time.Local)) assert.Equal(t, "t4", thing.Kind) assert.Equal(t, "138z6ke", thing.ID) assert.Equal(t, "unknown", thing.Type) assert.Equal(t, "iamthatis", thing.Author) assert.Equal(t, "how goes it", thing.Subject) assert.Equal(t, "how are you today", thing.Body) - assert.Equal(t, 1626285395.0, thing.CreatedAt) + assert.Equal(t, created, thing.CreatedAt) assert.Equal(t, "hugocat", thing.Destination) assert.Equal(t, "t4_138z6ke", thing.FullName()) diff --git a/internal/repository/postgres_account.go b/internal/repository/postgres_account.go index 798a204..5814164 100644 --- a/internal/repository/postgres_account.go +++ b/internal/repository/postgres_account.go @@ -3,6 +3,7 @@ package repository import ( "context" "fmt" + "time" "github.com/jackc/pgx/v4/pgxpool" @@ -33,10 +34,11 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg &acc.AccountID, &acc.AccessToken, &acc.RefreshToken, - &acc.ExpiresAt, + &acc.TokenExpiresAt, &acc.LastMessageID, - &acc.LastCheckedAt, - &acc.LastUnstuckAt, + &acc.NextNotificationCheckAt, + &acc.NextStuckNotificationCheckAt, + &acc.CheckCount, ); err != nil { return nil, err } @@ -47,7 +49,9 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) { query := ` - SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at + SELECT id, username, account_id, access_token, refresh_token, token_expires_at, + last_message_id, next_notification_check_at, next_stuck_notification_check_at, + check_count FROM accounts WHERE id = $1` @@ -64,7 +68,9 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) { query := ` - SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at + SELECT id, username, account_id, access_token, refresh_token, token_expires_at, + last_message_id, next_notification_check_at, next_stuck_notification_check_at, + check_count FROM accounts WHERE account_id = $1` @@ -81,12 +87,13 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string } func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error { query := ` - INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at) - VALUES ($1, $2, $3, $4, $5, '', 0, 0) + INSERT INTO accounts (username, account_id, access_token, refresh_token, token_expires_at, + last_message_id, next_notification_check_at, next_stuck_notification_check_at) + VALUES ($1, $2, $3, $4, $5, '', NOW(), NOW()) ON CONFLICT(username) DO UPDATE SET access_token = $3, refresh_token = $4, - expires_at = $5 + token_expires_at = $5 RETURNING id` return p.pool.QueryRow( @@ -96,14 +103,15 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom acc.AccountID, acc.AccessToken, acc.RefreshToken, - acc.ExpiresAt, + acc.TokenExpiresAt, ).Scan(&acc.ID) } func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error { query := ` INSERT INTO accounts - (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at) + (username, account_id, access_token, refresh_token, token_expires_at, + last_message_id, next_notification_check_at, next_stuck_notification_check_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` @@ -114,10 +122,10 @@ func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Acco acc.AccountID, acc.AccessToken, acc.RefreshToken, - acc.ExpiresAt, + acc.TokenExpiresAt, acc.LastMessageID, - acc.LastCheckedAt, - acc.LastUnstuckAt, + acc.NextNotificationCheckAt, + acc.NextStuckNotificationCheckAt, ).Scan(&acc.ID) } @@ -128,10 +136,11 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco account_id = $3, access_token = $4, refresh_token = $5, - expires_at = $6, + token_expires_at = $6, last_message_id = $7, - last_checked_at = $8, - last_unstuck_at = $9 + next_notification_check_at = $8, + next_stuck_notification_check_at = $9, + check_count = $10 WHERE id = $1` res, err := p.pool.Exec( @@ -142,10 +151,11 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco acc.AccountID, acc.AccessToken, acc.RefreshToken, - acc.ExpiresAt, + acc.TokenExpiresAt, acc.LastMessageID, - acc.LastCheckedAt, - acc.LastUnstuckAt, + acc.NextNotificationCheckAt, + acc.NextStuckNotificationCheckAt, + acc.CheckCount, ) if res.RowsAffected() != 1 { @@ -186,7 +196,9 @@ func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domai func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) { query := ` - SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at + SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, token_expires_at, + last_message_id, next_notification_check_at, next_stuck_notification_check_at, + check_count FROM accounts INNER JOIN devices_accounts ON accounts.id = devices_accounts.account_id INNER JOIN devices ON devices.id = devices_accounts.device_id @@ -195,12 +207,12 @@ func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token st return p.fetch(ctx, query, token) } -func (p *postgresAccountRepository) PruneStale(ctx context.Context, before int64) (int64, error) { +func (p *postgresAccountRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) { query := ` DELETE FROM accounts - WHERE expires_at < $1` + WHERE token_expires_at < $1` - res, err := p.pool.Exec(ctx, query, before) + res, err := p.pool.Exec(ctx, query, expiry) return res.RowsAffected(), err } diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go index c4bf709..64310b2 100644 --- a/internal/repository/postgres_device.go +++ b/internal/repository/postgres_device.go @@ -3,6 +3,7 @@ package repository import ( "context" "fmt" + "time" "github.com/jackc/pgx/v4/pgxpool" @@ -31,8 +32,8 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args &dev.ID, &dev.APNSToken, &dev.Sandbox, - &dev.ActiveUntil, - &dev.GracePeriodUntil, + &dev.ExpiresAt, + &dev.GracePeriodExpiresAt, ); err != nil { return nil, err } @@ -43,7 +44,7 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domain.Device, error) { query := ` - SELECT id, apns_token, sandbox, active_until, grace_period_until + SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at FROM devices WHERE id = $1` @@ -60,7 +61,7 @@ func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domai func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) { query := ` - SELECT id, apns_token, sandbox, active_until, grace_period_until + SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at FROM devices WHERE apns_token = $1` @@ -77,7 +78,7 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { query := ` - SELECT devices.id, apns_token, sandbox, active_until, grace_period_until + SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at FROM devices INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id WHERE devices_accounts.account_id = $1` @@ -87,7 +88,7 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { query := ` - SELECT devices.id, apns_token, sandbox, active_until, grace_period_until + SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at FROM devices INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id WHERE devices_accounts.account_id = $1 AND @@ -99,7 +100,7 @@ func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Con func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { query := ` - SELECT devices.id, apns_token, sandbox, active_until, grace_period_until + SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at FROM devices INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id WHERE devices_accounts.account_id = $1 AND @@ -111,10 +112,10 @@ func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.C func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { query := ` - INSERT INTO devices (apns_token, sandbox, active_until, grace_period_until) + INSERT INTO devices (apns_token, sandbox, expires_at, grace_period_expires_at) VALUES ($1, $2, $3, $4) ON CONFLICT(apns_token) DO - UPDATE SET active_until = $3, grace_period_until = $4 + UPDATE SET expires_at = $3, grace_period_expires_at = $4 RETURNING id` return p.pool.QueryRow( @@ -122,8 +123,8 @@ func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *doma query, dev.APNSToken, dev.Sandbox, - dev.ActiveUntil, - dev.GracePeriodUntil, + &dev.ExpiresAt, + &dev.GracePeriodExpiresAt, ).Scan(&dev.ID) } @@ -134,7 +135,7 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic query := ` INSERT INTO devices - (apns_token, sandbox, active_until, grace_period_until) + (apns_token, sandbox, expires_at, grace_period_expires_at) VALUES ($1, $2, $3, $4) RETURNING id` @@ -143,8 +144,8 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic query, dev.APNSToken, dev.Sandbox, - dev.ActiveUntil, - dev.GracePeriodUntil, + dev.ExpiresAt, + dev.GracePeriodExpiresAt, ).Scan(&dev.ID) } @@ -155,10 +156,10 @@ func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Devic query := ` UPDATE devices - SET active_until = $2, grace_period_until = $3 + SET expires_at = $2, grace_period_expires_at = $3 WHERE id = $1` - res, err := p.pool.Exec(ctx, query, dev.ID, dev.ActiveUntil, dev.GracePeriodUntil) + res, err := p.pool.Exec(ctx, query, dev.ID, dev.ExpiresAt, dev.GracePeriodExpiresAt) if res.RowsAffected() != 1 { return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) @@ -209,10 +210,10 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai return inbox, watcher, global, nil } -func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { - query := `DELETE FROM devices WHERE grace_period_until < $1` +func (p *postgresDeviceRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) { + query := `DELETE FROM devices WHERE grace_period_expires_at < $1` - res, err := p.pool.Exec(ctx, query, before) + res, err := p.pool.Exec(ctx, query, expiry) return res.RowsAffected(), err } diff --git a/internal/repository/postgres_subreddit.go b/internal/repository/postgres_subreddit.go index 663f92e..5336593 100644 --- a/internal/repository/postgres_subreddit.go +++ b/internal/repository/postgres_subreddit.go @@ -30,7 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a &sr.ID, &sr.SubredditID, &sr.Name, - &sr.LastCheckedAt, + &sr.NextCheckAt, ); err != nil { return nil, err } @@ -41,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) { query := ` - SELECT id, subreddit_id, name, last_checked_at + SELECT id, subreddit_id, name, next_check_at FROM subreddits WHERE id = $1` @@ -58,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) { query := ` - SELECT id, subreddit_id, name, last_checked_at + SELECT id, subreddit_id, name, next_check_at FROM subreddits WHERE name = $1` @@ -81,8 +81,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do } query := ` - INSERT INTO subreddits (subreddit_id, name) - VALUES ($1, $2) + INSERT INTO subreddits (subreddit_id, name, next_check_at) + VALUES ($1, $2, NOW()) ON CONFLICT(subreddit_id) DO NOTHING RETURNING id` diff --git a/internal/repository/postgres_user.go b/internal/repository/postgres_user.go index fdf9eab..0636352 100644 --- a/internal/repository/postgres_user.go +++ b/internal/repository/postgres_user.go @@ -32,7 +32,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args . &u.ID, &u.UserID, &u.Name, - &u.LastCheckedAt, + &u.NextCheckAt, ); err != nil { return nil, err } @@ -43,7 +43,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args . func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) { query := ` - SELECT id, user_id, name, last_checked_at + SELECT id, user_id, name, next_check_at FROM users WHERE id = $1` @@ -60,7 +60,7 @@ func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain. func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) { query := ` - SELECT id, user_id, name, last_checked_at + SELECT id, user_id, name, next_check_at FROM users WHERE name = $1` @@ -83,10 +83,9 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U } query := ` - INSERT INTO users (user_id, name) - VALUES ($1, $2) - ON CONFLICT(user_id) DO - UPDATE SET last_checked_at = $3 + INSERT INTO users (user_id, name, next_check_at) + VALUES ($1, $2, NOW()) + ON CONFLICT(user_id) DO NOTHING RETURNING id` return p.pool.QueryRow( @@ -94,7 +93,7 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U query, u.UserID, u.NormalizedName(), - u.LastCheckedAt, + u.NextCheckAt, ).Scan(&u.ID) } diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index d03d107..62e2a33 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -161,7 +161,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { defer func() { _ = delivery.Ack() }() - now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000 + now := time.Now() account, err := nc.accountRepo.GetByID(ctx, id) if err != nil { @@ -171,20 +171,22 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { return } - previousLastCheckedAt := account.LastCheckedAt - newAccount := (previousLastCheckedAt == 0) - account.LastCheckedAt = now + newAccount := account.CheckCount == 0 + previousNextCheck := account.NextNotificationCheckAt + + account.CheckCount++ + account.NextNotificationCheckAt = time.Now().Add(domain.NotificationCheckInterval) if err = nc.accountRepo.Update(ctx, &account); err != nil { nc.logger.WithFields(logrus.Fields{ "account#username": account.NormalizedUsername(), "err": err, - }).Error("failed to update last_checked_at for account") + }).Error("failed to update next_notification_check_at for account") return } rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) - if account.ExpiresAt < int64(now) { + if account.TokenExpiresAt.Before(now) { nc.logger.WithFields(logrus.Fields{ "account#username": account.NormalizedUsername(), }).Debug("refreshing reddit token") @@ -213,7 +215,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { // Update account account.AccessToken = tokens.AccessToken account.RefreshToken = tokens.RefreshToken - account.ExpiresAt = int64(now + 3540) + account.TokenExpiresAt = now.Add(3600 * time.Second) // Refresh client rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken) @@ -230,8 +232,8 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { // Only update delay on accounts we can actually check, otherwise it skews // the numbers too much. if !newAccount { - latency := now - previousLastCheckedAt - float64(backoff) - _ = nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate) + latency := now.Sub(previousNextCheck) - backoff*time.Second + _ = nc.statsd.Histogram("apollo.queue.delay", float64(latency.Milliseconds()), []string{}, rate) } nc.logger.WithFields(logrus.Fields{ diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index d5747ff..60ff33f 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -177,7 +177,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { return } - threshold := float64(time.Now().AddDate(0, 0, -1).UTC().Unix()) + threshold := time.Now().Add(-24 * time.Hour) posts := []*reddit.Thing{} before := "" finished := false @@ -234,7 +234,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { } for _, post := range sps.Children { - if post.CreatedAt < threshold { + if post.CreatedAt.Before(threshold) { finished = true break } @@ -284,7 +284,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { }).Debug("loaded hot posts") for _, post := range sps.Children { - if post.CreatedAt < threshold { + if post.CreatedAt.Before(threshold) { break } if _, ok := seenPosts[post.ID]; !ok { @@ -310,7 +310,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { for _, watcher := range watchers { // Make sure we only alert on posts created after the search - if watcher.CreatedAt > post.CreatedAt { + if watcher.CreatedAt.After(post.CreatedAt) { continue } diff --git a/internal/worker/trending.go b/internal/worker/trending.go index b27a3bd..f8b5d4e 100644 --- a/internal/worker/trending.go +++ b/internal/worker/trending.go @@ -235,14 +235,14 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { }).Debug("loaded hot posts") // Trending only counts for posts less than 2 days old - threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix()) + threshold := time.Now().Add(-24 * time.Hour * 2) for _, post := range hps.Children { if post.Score < medianScore { continue } - if post.CreatedAt < threshold { + if post.CreatedAt.Before(threshold) { break } @@ -251,7 +251,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) { notification.Payload = payloadFromTrendingPost(post) for _, watcher := range watchers { - if watcher.CreatedAt > post.CreatedAt { + if watcher.CreatedAt.After(post.CreatedAt) { continue } diff --git a/internal/worker/users.go b/internal/worker/users.go index ce1ffc8..567bf40 100644 --- a/internal/worker/users.go +++ b/internal/worker/users.go @@ -234,11 +234,11 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { for _, watcher := range watchers { // Make sure we only alert on activities created after the search - if watcher.CreatedAt > post.CreatedAt { + if watcher.CreatedAt.After(post.CreatedAt) { continue } - if watcher.LastNotifiedAt > post.CreatedAt { + if watcher.LastNotifiedAt.After(post.CreatedAt) { continue } From 69675d4d5c947f1700c26d7a5ad4e01a21202962 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Mon, 28 Mar 2022 17:27:07 -0400 Subject: [PATCH 2/4] add schema --- docs/schema.sql | 64 +++++++++++++++++++++++++ internal/repository/postgres_account.go | 14 +++--- internal/repository/postgres_watcher.go | 8 ++-- 3 files changed, 75 insertions(+), 11 deletions(-) create mode 100644 docs/schema.sql diff --git a/docs/schema.sql b/docs/schema.sql new file mode 100644 index 0000000..dffdf11 --- /dev/null +++ b/docs/schema.sql @@ -0,0 +1,64 @@ +CREATE TABLE accounts ( + id SERIAL PRIMARY KEY, + reddit_account_id character varying(32) DEFAULT ''::character varying, + username character varying(20) DEFAULT ''::character varying UNIQUE, + access_token character varying(64) DEFAULT ''::character varying, + refresh_token character varying(64) DEFAULT ''::character varying, + token_expires_at timestamp without time zone, + last_message_id character varying(32) DEFAULT ''::character varying, + next_notification_check_at timestamp without time zone, + next_stuck_notification_check_at timestamp without time zone, + check_count integer DEFAULT 0 +); + +CREATE TABLE devices ( + id SERIAL PRIMARY KEY, + apns_token character varying(100) UNIQUE, + sandbox boolean, + expires_at timestamp without time zone, + grace_period_expires_at timestamp without time zone +); + +CREATE TABLE devices_accounts ( + id SERIAL PRIMARY KEY, + account_id integer REFERENCES accounts(id) ON DELETE CASCADE, + device_id integer REFERENCES devices(id) ON DELETE CASCADE, + watcher_notifiable boolean DEFAULT true, + inbox_notifiable boolean DEFAULT true, + global_mute boolean DEFAULT false +); + +CREATE UNIQUE INDEX devices_accounts_account_id_device_id_idx ON devices_accounts(account_id int4_ops,device_id int4_ops); + +CREATE TABLE subreddits ( + id SERIAL PRIMARY KEY, + subreddit_id character varying(32) DEFAULT ''::character varying UNIQUE, + name character varying(32) DEFAULT ''::character varying, + next_check_at timestamp without time zone +); + +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + user_id character varying(32) DEFAULT ''::character varying UNIQUE, + name character varying(32) DEFAULT ''::character varying, + next_check_at timestamp without time zone +); + +CREATE TABLE watchers ( + id SERIAL PRIMARY KEY, + created_at timestamp without time zone, + last_notified_at timestamp without time zone, + device_id integer REFERENCES devices(id) ON DELETE CASCADE, + account_id integer REFERENCES accounts(id) ON DELETE CASCADE, + watchee_id integer, + upvotes integer DEFAULT 0, + keyword character varying(32) DEFAULT ''::character varying, + flair character varying(32) DEFAULT ''::character varying, + domain character varying(32) DEFAULT ''::character varying, + hits integer DEFAULT 0, + type integer DEFAULT 0, + label character varying(64) DEFAULT ''::character varying, + author character varying(32) DEFAULT ''::character varying, + subreddit character varying(32) DEFAULT ''::character varying +); + diff --git a/internal/repository/postgres_account.go b/internal/repository/postgres_account.go index 5814164..83115b9 100644 --- a/internal/repository/postgres_account.go +++ b/internal/repository/postgres_account.go @@ -49,7 +49,7 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) { query := ` - SELECT id, username, account_id, access_token, refresh_token, token_expires_at, + SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at, last_message_id, next_notification_check_at, next_stuck_notification_check_at, check_count FROM accounts @@ -68,11 +68,11 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) { query := ` - SELECT id, username, account_id, access_token, refresh_token, token_expires_at, + SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at, last_message_id, next_notification_check_at, next_stuck_notification_check_at, check_count FROM accounts - WHERE account_id = $1` + WHERE reddit_account_id = $1` accs, err := p.fetch(ctx, query, id) if err != nil { @@ -87,7 +87,7 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string } func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error { query := ` - INSERT INTO accounts (username, account_id, access_token, refresh_token, token_expires_at, + INSERT INTO accounts (username, reddit_account_id, access_token, refresh_token, token_expires_at, last_message_id, next_notification_check_at, next_stuck_notification_check_at) VALUES ($1, $2, $3, $4, $5, '', NOW(), NOW()) ON CONFLICT(username) DO @@ -110,7 +110,7 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error { query := ` INSERT INTO accounts - (username, account_id, access_token, refresh_token, token_expires_at, + (username, reddit_account_id, access_token, refresh_token, token_expires_at, last_message_id, next_notification_check_at, next_stuck_notification_check_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` @@ -133,7 +133,7 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco query := ` UPDATE accounts SET username = $2, - account_id = $3, + reddit_account_id = $3, access_token = $4, refresh_token = $5, token_expires_at = $6, @@ -196,7 +196,7 @@ func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domai func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) { query := ` - SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, token_expires_at, + SELECT accounts.id, username, accounts.reddit_account_id, access_token, refresh_token, token_expires_at, last_message_id, next_notification_check_at, next_stuck_notification_check_at, check_count FROM accounts diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index a9a6f79..25a96f6 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -93,7 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma devices.apns_token, devices.sandbox, accounts.id, - accounts.account_id, + accounts.reddit_account_id, accounts.access_token, accounts.refresh_token, COALESCE(subreddits.name, '') AS subreddit_label, @@ -138,7 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t devices.apns_token, devices.sandbox, accounts.id, - accounts.account_id, + accounts.reddit_account_id, accounts.access_token, accounts.refresh_token, COALESCE(subreddits.name, '') AS subreddit_label, @@ -191,7 +191,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c devices.apns_token, devices.sandbox, accounts.id, - accounts.account_id, + accounts.reddit_account_id, accounts.access_token, accounts.refresh_token, COALESCE(subreddits.name, '') AS subreddit_label, @@ -203,7 +203,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id WHERE devices.apns_token = $1 AND - accounts.account_id = $2` + accounts.reddit_account_id = $2` return p.fetch(ctx, query, apns, rid) } From 7c7e1e5e1c9068d55bc3b6d6fb794bffb47d488d Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Mon, 28 Mar 2022 17:33:01 -0400 Subject: [PATCH 3/4] More proper types --- internal/api/accounts.go | 4 ++-- internal/reddit/types.go | 6 ++++-- internal/reddit/types_test.go | 1 + internal/worker/notifications.go | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 3c381c3..95237cf 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -150,7 +150,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { } // Reset expiration timer - acc.TokenExpiresAt = time.Now().Add(1 * time.Hour) + acc.TokenExpiresAt = time.Now().Add(tokens.Expiry) acc.RefreshToken = tokens.RefreshToken acc.AccessToken = tokens.AccessToken @@ -215,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { } // Reset expiration timer - acct.TokenExpiresAt = time.Now().Add(1 * time.Hour) + acct.TokenExpiresAt = time.Now().Add(tokens.Expiry) acct.RefreshToken = tokens.RefreshToken acct.AccessToken = tokens.AccessToken diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 5efe4d1..f21b182 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -30,8 +30,9 @@ func NewError(val *fastjson.Value, status int) *Error { } type RefreshTokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Expiry time.Duration `json:"expires_in"` } func NewRefreshTokenResponse(val *fastjson.Value) interface{} { @@ -39,6 +40,7 @@ func NewRefreshTokenResponse(val *fastjson.Value) interface{} { rtr.AccessToken = string(val.GetStringBytes("access_token")) rtr.RefreshToken = string(val.GetStringBytes("refresh_token")) + rtr.Expiry = time.Duration(val.GetInt("expires_in")) * time.Second return rtr } diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index ea7d0fd..e7b2ead 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -41,6 +41,7 @@ func TestRefreshTokenResponseParsing(t *testing.T) { assert.Equal(t, "***REMOVED***", rtr.AccessToken) assert.Equal(t, "***REMOVED***", rtr.RefreshToken) + assert.Equal(t, 1*time.Hour, rtr.Expiry) } func TestListingResponseParsing(t *testing.T) { diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index 62e2a33..e1a4838 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -215,7 +215,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { // Update account account.AccessToken = tokens.AccessToken account.RefreshToken = tokens.RefreshToken - account.TokenExpiresAt = now.Add(3600 * time.Second) + account.TokenExpiresAt = now.Add(tokens.Expiry) // Refresh client rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken) From 4fad10bade05e5cec38aa1b81096b1dbdae780d9 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Sat, 7 May 2022 11:51:56 -0400 Subject: [PATCH 4/4] fix tests --- internal/reddit/types.go | 4 ++-- internal/reddit/types_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 4714f72..1629975 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -82,7 +82,7 @@ type Thing struct { Title string `json:"title"` URL string `json:"url"` Flair string `json:"flair"` - Thumbnail string `json:"thumbnail"` + Thumbnail string `json:"thumbnail"` } func (t *Thing) FullName() string { @@ -106,7 +106,7 @@ func NewThing(val *fastjson.Value) *Thing { t.Author = string(data.GetStringBytes("author")) t.Subject = string(data.GetStringBytes("subject")) t.Body = string(data.GetStringBytes("body")) - t.CreatedAt = time.Unix(unix, 0) + t.CreatedAt = time.Unix(unix, 0).UTC() t.Context = string(data.GetStringBytes("context")) t.ParentID = string(data.GetStringBytes("parent_id")) t.LinkTitle = string(data.GetStringBytes("link_title")) diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index e7b2ead..f44e76f 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -62,7 +62,7 @@ func TestListingResponseParsing(t *testing.T) { assert.Equal(t, "", l.Before) thing := l.Children[0] - created := time.Time(time.Date(2021, time.July, 14, 13, 56, 35, 0, time.Local)) + created := time.Time(time.Date(2021, time.July, 14, 17, 56, 35, 0, time.UTC)) assert.Equal(t, "t4", thing.Kind) assert.Equal(t, "138z6ke", thing.ID) assert.Equal(t, "unknown", thing.Type)