Changes to schema

This commit is contained in:
Andre Medeiros 2022-03-28 17:05:01 -04:00
parent 59e435eb2d
commit dbcda74ab8
20 changed files with 226 additions and 192 deletions

View file

@ -150,7 +150,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
} }
// Reset expiration timer // Reset expiration timer
acc.ExpiresAt = time.Now().Unix() + 3540 acc.TokenExpiresAt = time.Now().Add(1 * time.Hour)
acc.RefreshToken = tokens.RefreshToken acc.RefreshToken = tokens.RefreshToken
acc.AccessToken = tokens.AccessToken acc.AccessToken = tokens.AccessToken
@ -175,7 +175,10 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
_ = a.accountRepo.Associate(ctx, &acc, &dev) if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
} }
for _, acc := range accsMap { for _, acc := range accsMap {
@ -212,7 +215,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
} }
// Reset expiration timer // Reset expiration timer
acct.ExpiresAt = time.Now().Unix() + 3540 acct.TokenExpiresAt = time.Now().Add(1 * time.Hour)
acct.RefreshToken = tokens.RefreshToken acct.RefreshToken = tokens.RefreshToken
acct.AccessToken = tokens.AccessToken acct.AccessToken = tokens.AccessToken

View file

@ -28,8 +28,8 @@ func (a *api) upsertDeviceHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
d.ActiveUntil = time.Now().Unix() + domain.DeviceGracePeriodDuration d.ExpiresAt = time.Now().Add(domain.DeviceReceiptCheckPeriodDuration)
d.GracePeriodUntil = d.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry d.GracePeriodExpiresAt = d.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry)
if err := a.deviceRepo.CreateOrUpdate(ctx, d); err != nil { if err := a.deviceRepo.CreateOrUpdate(ctx, d); err != nil {
a.errorResponse(w, r, 500, err.Error()) a.errorResponse(w, r, 500, err.Error())

View file

@ -39,6 +39,11 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) {
} }
if iapr.DeleteDevice { if iapr.DeleteDevice {
if dev.GracePeriodExpiresAt.After(time.Now()) {
w.WriteHeader(http.StatusOK)
return
}
accs, err := a.accountRepo.GetByAPNSToken(ctx, apns) accs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
if err != nil { if err != nil {
a.errorResponse(w, r, 500, err.Error()) a.errorResponse(w, r, 500, err.Error())
@ -51,8 +56,8 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) {
_ = a.deviceRepo.Delete(ctx, apns) _ = a.deviceRepo.Delete(ctx, apns)
} else { } else {
dev.ActiveUntil = time.Now().Unix() + domain.DeviceActiveAfterReceitCheckDuration dev.ExpiresAt = time.Now().Add(domain.DeviceActiveAfterReceitCheckDuration)
dev.GracePeriodUntil = dev.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry dev.GracePeriodExpiresAt = dev.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry)
_ = a.deviceRepo.Update(ctx, &dev) _ = a.deviceRepo.Update(ctx, &dev)
} }
} }

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -230,17 +231,17 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
} }
type watcherItem struct { type watcherItem struct {
ID int64 `json:"id"` ID int64 `json:"id"`
CreatedAt float64 `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Type string `json:"type"` Type string `json:"type"`
Label string `json:"label"` Label string `json:"label"`
SourceLabel string `json:"source_label"` SourceLabel string `json:"source_label"`
Upvotes *int64 `json:"upvotes,omitempty"` Upvotes *int64 `json:"upvotes,omitempty"`
Keyword string `json:"keyword,omitempty"` Keyword string `json:"keyword,omitempty"`
Flair string `json:"flair,omitempty"` Flair string `json:"flair,omitempty"`
Domain string `json:"domain,omitempty"` Domain string `json:"domain,omitempty"`
Hits int64 `json:"hits"` Hits int64 `json:"hits"`
Author string `json:"author,omitempty"` Author string `json:"author,omitempty"`
} }
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {

View file

@ -9,27 +9,19 @@ import (
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4" "github.com/adjust/rmq/v4"
"github.com/christianselig/apollo-backend/internal/cmdutil"
"github.com/christianselig/apollo-backend/internal/repository"
"github.com/go-co-op/gocron" "github.com/go-co-op/gocron"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/christianselig/apollo-backend/internal/cmdutil"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/repository"
) )
const ( const batchSize = 250
batchSize = 250
checkTimeout = 60 // how long until we force a check
accountEnqueueInterval = 5 // how frequently we want to check (seconds)
subredditEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
userEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
stuckAccountEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
staleAccountThreshold = 7200 // 2 hours
)
func SchedulerCmd(ctx context.Context) *cobra.Command { func SchedulerCmd(ctx context.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
@ -129,16 +121,16 @@ func evalScript(ctx context.Context, redis *redis.Client) (string, error) {
end end
return retv return retv
`, checkTimeout) `, int64(domain.NotificationCheckTimeout.Seconds()))
return redis.ScriptLoad(ctx, lua).Result() return redis.ScriptLoad(ctx, lua).Result()
} }
func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) { func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) {
before := time.Now().Unix() - staleAccountThreshold expiry := time.Now().Add(-domain.StaleTokenThreshold)
ar := repository.NewPostgresAccount(pool) ar := repository.NewPostgresAccount(pool)
stale, err := ar.PruneStale(ctx, before) stale, err := ar.PruneStale(ctx, expiry)
if err != nil { if err != nil {
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"err": err, "err": err,
@ -158,16 +150,17 @@ func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Poo
if count > 0 { if count > 0 {
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"count": count, "stale": stale,
"orphaned": orphaned,
}).Info("pruned accounts") }).Info("pruned accounts")
} }
} }
func pruneDevices(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) { func pruneDevices(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) {
threshold := time.Now().Unix() now := time.Now()
dr := repository.NewPostgresDevice(pool) dr := repository.NewPostgresDevice(pool)
count, err := dr.PruneStale(ctx, threshold) count, err := dr.PruneStale(ctx, now)
if err != nil { if err != nil {
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"err": err, "err": err,
@ -227,6 +220,8 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie
func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
now := time.Now() now := time.Now()
next := now.Add(domain.NotificationCheckInterval)
ids := []int64{} ids := []int64{}
defer func() { defer func() {
@ -235,21 +230,20 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}() }()
ready := now.Unix() - userEnqueueInterval
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` stmt := `
WITH userb AS ( WITH batch AS (
SELECT id SELECT id
FROM users FROM users
WHERE last_checked_at < $1 WHERE next_check_at < $1
ORDER BY last_checked_at ORDER BY next_check_at
LIMIT 100 LIMIT 100
) )
UPDATE users UPDATE users
SET last_checked_at = $2 SET next_check_at = $2
WHERE users.id IN(SELECT id FROM userb) WHERE users.id IN(SELECT id FROM batch)
RETURNING users.id` RETURNING users.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix()) rows, err := tx.Query(ctx, stmt, now, next)
if err != nil { if err != nil {
return err return err
} }
@ -275,7 +269,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"count": len(ids), "count": len(ids),
"start": ready, "start": now,
}).Debug("enqueueing user batch") }).Debug("enqueueing user batch")
batchIds := make([]string, len(ids)) batchIds := make([]string, len(ids))
@ -292,6 +286,8 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) { func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) {
now := time.Now() now := time.Now()
next := now.Add(domain.SubredditCheckInterval)
ids := []int64{} ids := []int64{}
defer func() { defer func() {
@ -300,21 +296,20 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}() }()
ready := now.Unix() - subredditEnqueueInterval
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` stmt := `
WITH subreddit AS ( WITH batch AS (
SELECT id SELECT id
FROM subreddits FROM subreddits
WHERE last_checked_at < $1 WHERE next_check_at < $1
ORDER BY last_checked_at ORDER BY next_check_at
LIMIT 100 LIMIT 100
) )
UPDATE subreddits UPDATE subreddits
SET last_checked_at = $2 SET next_check_at = $2
WHERE subreddits.id IN(SELECT id FROM subreddit) WHERE subreddits.id IN(SELECT id FROM batch)
RETURNING subreddits.id` RETURNING subreddits.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix()) rows, err := tx.Query(ctx, stmt, now, next)
if err != nil { if err != nil {
return err return err
} }
@ -340,7 +335,7 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"count": len(ids), "count": len(ids),
"start": ready, "start": now,
}).Debug("enqueueing subreddit batch") }).Debug("enqueueing subreddit batch")
batchIds := make([]string, len(ids)) batchIds := make([]string, len(ids))
@ -361,6 +356,8 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) { func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
now := time.Now() now := time.Now()
next := now.Add(domain.StuckNotificationCheckInterval)
ids := []int64{} ids := []int64{}
defer func() { defer func() {
@ -369,22 +366,21 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}() }()
ready := now.Unix() - stuckAccountEnqueueInterval
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` stmt := `
WITH account AS ( WITH batch AS (
SELECT id SELECT id
FROM accounts FROM accounts
WHERE WHERE
last_unstuck_at < $1 next_stuck_notification_check_at < $1
ORDER BY last_unstuck_at ORDER BY next_stuck_notification_check_at
LIMIT 500 LIMIT 500
) )
UPDATE accounts UPDATE accounts
SET last_unstuck_at = $2 SET next_stuck_notification_check_at = $2
WHERE accounts.id IN(SELECT id FROM account) WHERE accounts.id IN(SELECT id FROM batch)
RETURNING accounts.id` RETURNING accounts.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix()) rows, err := tx.Query(ctx, stmt, now, next)
if err != nil { if err != nil {
return err return err
} }
@ -410,7 +406,7 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"count": len(ids), "count": len(ids),
"start": ready, "start": now,
}).Debug("enqueueing stuck account batch") }).Debug("enqueueing stuck account batch")
batchIds := make([]string, len(ids)) batchIds := make([]string, len(ids))
@ -428,6 +424,8 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) { func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
now := time.Now() now := time.Now()
next := now.Add(domain.NotificationCheckInterval)
ids := []int64{} ids := []int64{}
enqueued := 0 enqueued := 0
skipped := 0 skipped := 0
@ -439,29 +437,21 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}() }()
// Start looking for accounts that were last checked at least 5 seconds ago
// and at most 6 seconds ago. Also look for accounts that haven't been checked
// in over a minute.
ts := now.Unix()
ready := ts - accountEnqueueInterval
expired := ts - checkTimeout
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` stmt := `
WITH account AS ( WITH account AS (
SELECT id SELECT id
FROM accounts FROM accounts
WHERE WHERE
last_enqueued_at < $1 next_notification_check_at < $1
OR last_checked_at < $2 ORDER BY next_notification_check_at
ORDER BY last_checked_at
LIMIT 2500 LIMIT 2500
) )
UPDATE accounts UPDATE accounts
SET last_enqueued_at = $3 SET next_notification_check_at = $2
WHERE accounts.id IN(SELECT id FROM account) WHERE accounts.id IN(SELECT id FROM account)
RETURNING accounts.id` RETURNING accounts.id`
rows, err := tx.Query(ctx, stmt, ready, expired, ts) rows, err := tx.Query(ctx, stmt, now, next)
if err != nil { if err != nil {
return err return err
} }
@ -487,7 +477,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"count": len(ids), "count": len(ids),
"start": ready, "start": now,
}).Debug("enqueueing account batch") }).Debug("enqueueing account batch")
// Split ids in batches // Split ids in batches
for i := 0; i < len(ids); i += batchSize { for i := 0; i < len(ids); i += batchSize {
@ -532,7 +522,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"count": enqueued, "count": enqueued,
"skipped": skipped, "skipped": skipped,
"start": ready, "start": now,
}).Debug("done enqueueing account batch") }).Debug("done enqueueing account batch")
} }

View file

@ -3,25 +3,34 @@ package domain
import ( import (
"context" "context"
"strings" "strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
) )
const (
NotificationCheckInterval = 5 * time.Second // time between notification checks
NotificationCheckTimeout = 60 * time.Second // time before we give up an account check lock
StuckNotificationCheckInterval = 2 * time.Minute // time between stuck notification checks
StaleTokenThreshold = 2 * time.Hour // time an oauth token has to be expired for to be stale
)
// Account represents an account we need to periodically check in the notifications worker. // Account represents an account we need to periodically check in the notifications worker.
type Account struct { type Account struct {
ID int64 ID int64
// Reddit information // Reddit information
Username string Username string
AccountID string AccountID string
AccessToken string AccessToken string
RefreshToken string RefreshToken string
ExpiresAt int64 TokenExpiresAt time.Time
// Tracking how far behind we are // Tracking how far behind we are
LastMessageID string LastMessageID string
LastCheckedAt float64 NextNotificationCheckAt time.Time
LastUnstuckAt float64 NextStuckNotificationCheckAt time.Time
CheckCount int64
} }
func (acct *Account) NormalizedUsername() string { func (acct *Account) NormalizedUsername() string {
@ -49,5 +58,5 @@ type AccountRepository interface {
Disassociate(ctx context.Context, acc *Account, dev *Device) error Disassociate(ctx context.Context, acc *Account, dev *Device) error
PruneOrphaned(ctx context.Context) (int64, error) PruneOrphaned(ctx context.Context) (int64, error)
PruneStale(ctx context.Context, before int64) (int64, error) PruneStale(ctx context.Context, expiry time.Time) (int64, error)
} }

View file

@ -2,22 +2,23 @@ package domain
import ( import (
"context" "context"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
) )
const ( const (
DeviceGracePeriodDuration = 3600 // 1 hour DeviceReceiptCheckPeriodDuration = 1 * time.Hour
DeviceActiveAfterReceitCheckDuration = 3600 * 24 * 30 // ~1 month DeviceActiveAfterReceitCheckDuration = 30 * 24 * time.Hour // ~1 month
DeviceGracePeriodAfterReceiptExpiry = 3600 * 24 * 30 // ~1 month DeviceGracePeriodAfterReceiptExpiry = 30 * 24 * time.Hour // ~1 month
) )
type Device struct { type Device struct {
ID int64 ID int64
APNSToken string APNSToken string
Sandbox bool Sandbox bool
ActiveUntil int64 ExpiresAt time.Time
GracePeriodUntil int64 GracePeriodExpiresAt time.Time
} }
func (dev *Device) Validate() error { func (dev *Device) Validate() error {
@ -40,5 +41,5 @@ type DeviceRepository interface {
SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher, global bool) error SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher, global bool) error
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, bool, error) GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, bool, error)
PruneStale(ctx context.Context, before int64) (int64, error) PruneStale(ctx context.Context, expiry time.Time) (int64, error)
} }

View file

@ -3,13 +3,16 @@ package domain
import ( import (
"context" "context"
"strings" "strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
) )
const SubredditCheckInterval = 2 * time.Minute
type Subreddit struct { type Subreddit struct {
ID int64 ID int64
LastCheckedAt float64 NextCheckAt time.Time
// Reddit information // Reddit information
SubredditID string SubredditID string

View file

@ -3,13 +3,16 @@ package domain
import ( import (
"context" "context"
"strings" "strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
) )
const UserRefreshInterval = 2 * time.Minute
type User struct { type User struct {
ID int64 ID int64
LastCheckedAt float64 NextCheckAt time.Time
// Reddit information // Reddit information
UserID string UserID string

View file

@ -2,6 +2,7 @@ package domain
import ( import (
"context" "context"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
) )
@ -29,8 +30,8 @@ func (wt WatcherType) String() string {
type Watcher struct { type Watcher struct {
ID int64 ID int64
CreatedAt float64 CreatedAt time.Time
LastNotifiedAt float64 LastNotifiedAt time.Time
Label string Label string
DeviceID int64 DeviceID int64

View file

@ -3,6 +3,7 @@ package reddit
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
) )
@ -61,24 +62,24 @@ func NewMeResponse(val *fastjson.Value) interface{} {
} }
type Thing struct { type Thing struct {
Kind string `json:"kind"` Kind string `json:"kind"`
ID string `json:"id"` ID string `json:"id"`
Type string `json:"type"` Type string `json:"type"`
Author string `json:"author"` Author string `json:"author"`
Subject string `json:"subject"` Subject string `json:"subject"`
Body string `json:"body"` Body string `json:"body"`
CreatedAt float64 `json:"created_utc"` CreatedAt time.Time `json:"created_utc"`
Context string `json:"context"` Context string `json:"context"`
ParentID string `json:"parent_id"` ParentID string `json:"parent_id"`
LinkTitle string `json:"link_title"` LinkTitle string `json:"link_title"`
Destination string `json:"dest"` Destination string `json:"dest"`
Subreddit string `json:"subreddit"` Subreddit string `json:"subreddit"`
SubredditType string `json:"subreddit_type"` SubredditType string `json:"subreddit_type"`
Score int64 `json:"score"` Score int64 `json:"score"`
SelfText string `json:"selftext"` SelfText string `json:"selftext"`
Title string `json:"title"` Title string `json:"title"`
URL string `json:"url"` URL string `json:"url"`
Flair string `json:"flair"` Flair string `json:"flair"`
} }
func (t *Thing) FullName() string { func (t *Thing) FullName() string {
@ -95,13 +96,14 @@ func NewThing(val *fastjson.Value) *Thing {
t.Kind = string(val.GetStringBytes("kind")) t.Kind = string(val.GetStringBytes("kind"))
data := val.Get("data") data := val.Get("data")
unix := int64(data.GetFloat64("created_utc"))
t.ID = string(data.GetStringBytes("id")) t.ID = string(data.GetStringBytes("id"))
t.Type = string(data.GetStringBytes("type")) t.Type = string(data.GetStringBytes("type"))
t.Author = string(data.GetStringBytes("author")) t.Author = string(data.GetStringBytes("author"))
t.Subject = string(data.GetStringBytes("subject")) t.Subject = string(data.GetStringBytes("subject"))
t.Body = string(data.GetStringBytes("body")) t.Body = string(data.GetStringBytes("body"))
t.CreatedAt = data.GetFloat64("created_utc") t.CreatedAt = time.Unix(unix, 0)
t.Context = string(data.GetStringBytes("context")) t.Context = string(data.GetStringBytes("context"))
t.ParentID = string(data.GetStringBytes("parent_id")) t.ParentID = string(data.GetStringBytes("parent_id"))
t.LinkTitle = string(data.GetStringBytes("link_title")) t.LinkTitle = string(data.GetStringBytes("link_title"))

View file

@ -3,6 +3,7 @@ package reddit
import ( import (
"io/ioutil" "io/ioutil"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
@ -60,13 +61,14 @@ func TestListingResponseParsing(t *testing.T) {
assert.Equal(t, "", l.Before) assert.Equal(t, "", l.Before)
thing := l.Children[0] thing := l.Children[0]
created := time.Time(time.Date(2021, time.July, 14, 13, 56, 35, 0, time.Local))
assert.Equal(t, "t4", thing.Kind) assert.Equal(t, "t4", thing.Kind)
assert.Equal(t, "138z6ke", thing.ID) assert.Equal(t, "138z6ke", thing.ID)
assert.Equal(t, "unknown", thing.Type) assert.Equal(t, "unknown", thing.Type)
assert.Equal(t, "iamthatis", thing.Author) assert.Equal(t, "iamthatis", thing.Author)
assert.Equal(t, "how goes it", thing.Subject) assert.Equal(t, "how goes it", thing.Subject)
assert.Equal(t, "how are you today", thing.Body) assert.Equal(t, "how are you today", thing.Body)
assert.Equal(t, 1626285395.0, thing.CreatedAt) assert.Equal(t, created, thing.CreatedAt)
assert.Equal(t, "hugocat", thing.Destination) assert.Equal(t, "hugocat", thing.Destination)
assert.Equal(t, "t4_138z6ke", thing.FullName()) assert.Equal(t, "t4_138z6ke", thing.FullName())

View file

@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
@ -33,10 +34,11 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
&acc.AccountID, &acc.AccountID,
&acc.AccessToken, &acc.AccessToken,
&acc.RefreshToken, &acc.RefreshToken,
&acc.ExpiresAt, &acc.TokenExpiresAt,
&acc.LastMessageID, &acc.LastMessageID,
&acc.LastCheckedAt, &acc.NextNotificationCheckAt,
&acc.LastUnstuckAt, &acc.NextStuckNotificationCheckAt,
&acc.CheckCount,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -47,7 +49,9 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) { func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) {
query := ` query := `
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at SELECT id, username, account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
check_count
FROM accounts FROM accounts
WHERE id = $1` WHERE id = $1`
@ -64,7 +68,9 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) { func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
query := ` query := `
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at SELECT id, username, account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
check_count
FROM accounts FROM accounts
WHERE account_id = $1` WHERE account_id = $1`
@ -81,12 +87,13 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string
} }
func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error { func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error {
query := ` query := `
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at) INSERT INTO accounts (username, account_id, access_token, refresh_token, token_expires_at,
VALUES ($1, $2, $3, $4, $5, '', 0, 0) last_message_id, next_notification_check_at, next_stuck_notification_check_at)
VALUES ($1, $2, $3, $4, $5, '', NOW(), NOW())
ON CONFLICT(username) DO ON CONFLICT(username) DO
UPDATE SET access_token = $3, UPDATE SET access_token = $3,
refresh_token = $4, refresh_token = $4,
expires_at = $5 token_expires_at = $5
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
@ -96,14 +103,15 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom
acc.AccountID, acc.AccountID,
acc.AccessToken, acc.AccessToken,
acc.RefreshToken, acc.RefreshToken,
acc.ExpiresAt, acc.TokenExpiresAt,
).Scan(&acc.ID) ).Scan(&acc.ID)
} }
func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error { func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error {
query := ` query := `
INSERT INTO accounts INSERT INTO accounts
(username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at) (username, account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id` RETURNING id`
@ -114,10 +122,10 @@ func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Acco
acc.AccountID, acc.AccountID,
acc.AccessToken, acc.AccessToken,
acc.RefreshToken, acc.RefreshToken,
acc.ExpiresAt, acc.TokenExpiresAt,
acc.LastMessageID, acc.LastMessageID,
acc.LastCheckedAt, acc.NextNotificationCheckAt,
acc.LastUnstuckAt, acc.NextStuckNotificationCheckAt,
).Scan(&acc.ID) ).Scan(&acc.ID)
} }
@ -128,10 +136,11 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
account_id = $3, account_id = $3,
access_token = $4, access_token = $4,
refresh_token = $5, refresh_token = $5,
expires_at = $6, token_expires_at = $6,
last_message_id = $7, last_message_id = $7,
last_checked_at = $8, next_notification_check_at = $8,
last_unstuck_at = $9 next_stuck_notification_check_at = $9,
check_count = $10
WHERE id = $1` WHERE id = $1`
res, err := p.pool.Exec( res, err := p.pool.Exec(
@ -142,10 +151,11 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
acc.AccountID, acc.AccountID,
acc.AccessToken, acc.AccessToken,
acc.RefreshToken, acc.RefreshToken,
acc.ExpiresAt, acc.TokenExpiresAt,
acc.LastMessageID, acc.LastMessageID,
acc.LastCheckedAt, acc.NextNotificationCheckAt,
acc.LastUnstuckAt, acc.NextStuckNotificationCheckAt,
acc.CheckCount,
) )
if res.RowsAffected() != 1 { if res.RowsAffected() != 1 {
@ -186,7 +196,9 @@ func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domai
func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) { func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) {
query := ` query := `
SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
check_count
FROM accounts FROM accounts
INNER JOIN devices_accounts ON accounts.id = devices_accounts.account_id INNER JOIN devices_accounts ON accounts.id = devices_accounts.account_id
INNER JOIN devices ON devices.id = devices_accounts.device_id INNER JOIN devices ON devices.id = devices_accounts.device_id
@ -195,12 +207,12 @@ func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token st
return p.fetch(ctx, query, token) return p.fetch(ctx, query, token)
} }
func (p *postgresAccountRepository) PruneStale(ctx context.Context, before int64) (int64, error) { func (p *postgresAccountRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
query := ` query := `
DELETE FROM accounts DELETE FROM accounts
WHERE expires_at < $1` WHERE token_expires_at < $1`
res, err := p.pool.Exec(ctx, query, before) res, err := p.pool.Exec(ctx, query, expiry)
return res.RowsAffected(), err return res.RowsAffected(), err
} }

View file

@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
@ -31,8 +32,8 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
&dev.ID, &dev.ID,
&dev.APNSToken, &dev.APNSToken,
&dev.Sandbox, &dev.Sandbox,
&dev.ActiveUntil, &dev.ExpiresAt,
&dev.GracePeriodUntil, &dev.GracePeriodExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -43,7 +44,7 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domain.Device, error) { func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domain.Device, error) {
query := ` query := `
SELECT id, apns_token, sandbox, active_until, grace_period_until SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices FROM devices
WHERE id = $1` WHERE id = $1`
@ -60,7 +61,7 @@ func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domai
func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) { func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) {
query := ` query := `
SELECT id, apns_token, sandbox, active_until, grace_period_until SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices FROM devices
WHERE apns_token = $1` WHERE apns_token = $1`
@ -77,7 +78,7 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str
func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := ` query := `
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1` WHERE devices_accounts.account_id = $1`
@ -87,7 +88,7 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := ` query := `
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND WHERE devices_accounts.account_id = $1 AND
@ -99,7 +100,7 @@ func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Con
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := ` query := `
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND WHERE devices_accounts.account_id = $1 AND
@ -111,10 +112,10 @@ func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.C
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
query := ` query := `
INSERT INTO devices (apns_token, sandbox, active_until, grace_period_until) INSERT INTO devices (apns_token, sandbox, expires_at, grace_period_expires_at)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
ON CONFLICT(apns_token) DO ON CONFLICT(apns_token) DO
UPDATE SET active_until = $3, grace_period_until = $4 UPDATE SET expires_at = $3, grace_period_expires_at = $4
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
@ -122,8 +123,8 @@ func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *doma
query, query,
dev.APNSToken, dev.APNSToken,
dev.Sandbox, dev.Sandbox,
dev.ActiveUntil, &dev.ExpiresAt,
dev.GracePeriodUntil, &dev.GracePeriodExpiresAt,
).Scan(&dev.ID) ).Scan(&dev.ID)
} }
@ -134,7 +135,7 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
query := ` query := `
INSERT INTO devices INSERT INTO devices
(apns_token, sandbox, active_until, grace_period_until) (apns_token, sandbox, expires_at, grace_period_expires_at)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
RETURNING id` RETURNING id`
@ -143,8 +144,8 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
query, query,
dev.APNSToken, dev.APNSToken,
dev.Sandbox, dev.Sandbox,
dev.ActiveUntil, dev.ExpiresAt,
dev.GracePeriodUntil, dev.GracePeriodExpiresAt,
).Scan(&dev.ID) ).Scan(&dev.ID)
} }
@ -155,10 +156,10 @@ func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Devic
query := ` query := `
UPDATE devices UPDATE devices
SET active_until = $2, grace_period_until = $3 SET expires_at = $2, grace_period_expires_at = $3
WHERE id = $1` WHERE id = $1`
res, err := p.pool.Exec(ctx, query, dev.ID, dev.ActiveUntil, dev.GracePeriodUntil) res, err := p.pool.Exec(ctx, query, dev.ID, dev.ExpiresAt, dev.GracePeriodExpiresAt)
if res.RowsAffected() != 1 { if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -209,10 +210,10 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai
return inbox, watcher, global, nil return inbox, watcher, global, nil
} }
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { func (p *postgresDeviceRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
query := `DELETE FROM devices WHERE grace_period_until < $1` query := `DELETE FROM devices WHERE grace_period_expires_at < $1`
res, err := p.pool.Exec(ctx, query, before) res, err := p.pool.Exec(ctx, query, expiry)
return res.RowsAffected(), err return res.RowsAffected(), err
} }

View file

@ -30,7 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
&sr.ID, &sr.ID,
&sr.SubredditID, &sr.SubredditID,
&sr.Name, &sr.Name,
&sr.LastCheckedAt, &sr.NextCheckAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -41,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) { func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
query := ` query := `
SELECT id, subreddit_id, name, last_checked_at SELECT id, subreddit_id, name, next_check_at
FROM subreddits FROM subreddits
WHERE id = $1` WHERE id = $1`
@ -58,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) { func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
query := ` query := `
SELECT id, subreddit_id, name, last_checked_at SELECT id, subreddit_id, name, next_check_at
FROM subreddits FROM subreddits
WHERE name = $1` WHERE name = $1`
@ -81,8 +81,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
} }
query := ` query := `
INSERT INTO subreddits (subreddit_id, name) INSERT INTO subreddits (subreddit_id, name, next_check_at)
VALUES ($1, $2) VALUES ($1, $2, NOW())
ON CONFLICT(subreddit_id) DO NOTHING ON CONFLICT(subreddit_id) DO NOTHING
RETURNING id` RETURNING id`

View file

@ -32,7 +32,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args .
&u.ID, &u.ID,
&u.UserID, &u.UserID,
&u.Name, &u.Name,
&u.LastCheckedAt, &u.NextCheckAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -43,7 +43,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args .
func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) { func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) {
query := ` query := `
SELECT id, user_id, name, last_checked_at SELECT id, user_id, name, next_check_at
FROM users FROM users
WHERE id = $1` WHERE id = $1`
@ -60,7 +60,7 @@ func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.
func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) { func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) {
query := ` query := `
SELECT id, user_id, name, last_checked_at SELECT id, user_id, name, next_check_at
FROM users FROM users
WHERE name = $1` WHERE name = $1`
@ -83,10 +83,9 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
} }
query := ` query := `
INSERT INTO users (user_id, name) INSERT INTO users (user_id, name, next_check_at)
VALUES ($1, $2) VALUES ($1, $2, NOW())
ON CONFLICT(user_id) DO ON CONFLICT(user_id) DO NOTHING
UPDATE SET last_checked_at = $3
RETURNING id` RETURNING id`
return p.pool.QueryRow( return p.pool.QueryRow(
@ -94,7 +93,7 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
query, query,
u.UserID, u.UserID,
u.NormalizedName(), u.NormalizedName(),
u.LastCheckedAt, u.NextCheckAt,
).Scan(&u.ID) ).Scan(&u.ID)
} }

View file

@ -161,7 +161,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }() defer func() { _ = delivery.Ack() }()
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000 now := time.Now()
account, err := nc.accountRepo.GetByID(ctx, id) account, err := nc.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
@ -171,20 +171,22 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
previousLastCheckedAt := account.LastCheckedAt newAccount := account.CheckCount == 0
newAccount := (previousLastCheckedAt == 0) previousNextCheck := account.NextNotificationCheckAt
account.LastCheckedAt = now
account.CheckCount++
account.NextNotificationCheckAt = time.Now().Add(domain.NotificationCheckInterval)
if err = nc.accountRepo.Update(ctx, &account); err != nil { if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to update last_checked_at for account") }).Error("failed to update next_notification_check_at for account")
return return
} }
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if account.ExpiresAt < int64(now) { if account.TokenExpiresAt.Before(now) {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),
}).Debug("refreshing reddit token") }).Debug("refreshing reddit token")
@ -213,7 +215,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Update account // Update account
account.AccessToken = tokens.AccessToken account.AccessToken = tokens.AccessToken
account.RefreshToken = tokens.RefreshToken account.RefreshToken = tokens.RefreshToken
account.ExpiresAt = int64(now + 3540) account.TokenExpiresAt = now.Add(3600 * time.Second)
// Refresh client // Refresh client
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken) rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
@ -230,8 +232,8 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Only update delay on accounts we can actually check, otherwise it skews // Only update delay on accounts we can actually check, otherwise it skews
// the numbers too much. // the numbers too much.
if !newAccount { if !newAccount {
latency := now - previousLastCheckedAt - float64(backoff) latency := now.Sub(previousNextCheck) - backoff*time.Second
_ = nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate) _ = nc.statsd.Histogram("apollo.queue.delay", float64(latency.Milliseconds()), []string{}, rate)
} }
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{

View file

@ -177,7 +177,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
threshold := float64(time.Now().AddDate(0, 0, -1).UTC().Unix()) threshold := time.Now().Add(-24 * time.Hour)
posts := []*reddit.Thing{} posts := []*reddit.Thing{}
before := "" before := ""
finished := false finished := false
@ -234,7 +234,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
} }
for _, post := range sps.Children { for _, post := range sps.Children {
if post.CreatedAt < threshold { if post.CreatedAt.Before(threshold) {
finished = true finished = true
break break
} }
@ -284,7 +284,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("loaded hot posts") }).Debug("loaded hot posts")
for _, post := range sps.Children { for _, post := range sps.Children {
if post.CreatedAt < threshold { if post.CreatedAt.Before(threshold) {
break break
} }
if _, ok := seenPosts[post.ID]; !ok { if _, ok := seenPosts[post.ID]; !ok {
@ -310,7 +310,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
for _, watcher := range watchers { for _, watcher := range watchers {
// Make sure we only alert on posts created after the search // Make sure we only alert on posts created after the search
if watcher.CreatedAt > post.CreatedAt { if watcher.CreatedAt.After(post.CreatedAt) {
continue continue
} }

View file

@ -235,14 +235,14 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
}).Debug("loaded hot posts") }).Debug("loaded hot posts")
// Trending only counts for posts less than 2 days old // Trending only counts for posts less than 2 days old
threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix()) threshold := time.Now().Add(-24 * time.Hour * 2)
for _, post := range hps.Children { for _, post := range hps.Children {
if post.Score < medianScore { if post.Score < medianScore {
continue continue
} }
if post.CreatedAt < threshold { if post.CreatedAt.Before(threshold) {
break break
} }
@ -251,7 +251,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
notification.Payload = payloadFromTrendingPost(post) notification.Payload = payloadFromTrendingPost(post)
for _, watcher := range watchers { for _, watcher := range watchers {
if watcher.CreatedAt > post.CreatedAt { if watcher.CreatedAt.After(post.CreatedAt) {
continue continue
} }

View file

@ -234,11 +234,11 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
for _, watcher := range watchers { for _, watcher := range watchers {
// Make sure we only alert on activities created after the search // Make sure we only alert on activities created after the search
if watcher.CreatedAt > post.CreatedAt { if watcher.CreatedAt.After(post.CreatedAt) {
continue continue
} }
if watcher.LastNotifiedAt > post.CreatedAt { if watcher.LastNotifiedAt.After(post.CreatedAt) {
continue continue
} }