Merge pull request #66 from christianselig/chore/schema-changes

Changes to schema
This commit is contained in:
André Medeiros 2022-05-07 11:52:52 -04:00 committed by GitHub
commit a94aa11845
22 changed files with 302 additions and 201 deletions

64
docs/schema.sql Normal file
View file

@ -0,0 +1,64 @@
CREATE TABLE accounts (
id SERIAL PRIMARY KEY,
reddit_account_id character varying(32) DEFAULT ''::character varying,
username character varying(20) DEFAULT ''::character varying UNIQUE,
access_token character varying(64) DEFAULT ''::character varying,
refresh_token character varying(64) DEFAULT ''::character varying,
token_expires_at timestamp without time zone,
last_message_id character varying(32) DEFAULT ''::character varying,
next_notification_check_at timestamp without time zone,
next_stuck_notification_check_at timestamp without time zone,
check_count integer DEFAULT 0
);
CREATE TABLE devices (
id SERIAL PRIMARY KEY,
apns_token character varying(100) UNIQUE,
sandbox boolean,
expires_at timestamp without time zone,
grace_period_expires_at timestamp without time zone
);
CREATE TABLE devices_accounts (
id SERIAL PRIMARY KEY,
account_id integer REFERENCES accounts(id) ON DELETE CASCADE,
device_id integer REFERENCES devices(id) ON DELETE CASCADE,
watcher_notifiable boolean DEFAULT true,
inbox_notifiable boolean DEFAULT true,
global_mute boolean DEFAULT false
);
CREATE UNIQUE INDEX devices_accounts_account_id_device_id_idx ON devices_accounts(account_id int4_ops,device_id int4_ops);
CREATE TABLE subreddits (
id SERIAL PRIMARY KEY,
subreddit_id character varying(32) DEFAULT ''::character varying UNIQUE,
name character varying(32) DEFAULT ''::character varying,
next_check_at timestamp without time zone
);
CREATE TABLE users (
id SERIAL PRIMARY KEY,
user_id character varying(32) DEFAULT ''::character varying UNIQUE,
name character varying(32) DEFAULT ''::character varying,
next_check_at timestamp without time zone
);
CREATE TABLE watchers (
id SERIAL PRIMARY KEY,
created_at timestamp without time zone,
last_notified_at timestamp without time zone,
device_id integer REFERENCES devices(id) ON DELETE CASCADE,
account_id integer REFERENCES accounts(id) ON DELETE CASCADE,
watchee_id integer,
upvotes integer DEFAULT 0,
keyword character varying(32) DEFAULT ''::character varying,
flair character varying(32) DEFAULT ''::character varying,
domain character varying(32) DEFAULT ''::character varying,
hits integer DEFAULT 0,
type integer DEFAULT 0,
label character varying(64) DEFAULT ''::character varying,
author character varying(32) DEFAULT ''::character varying,
subreddit character varying(32) DEFAULT ''::character varying
);

View file

@ -150,7 +150,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
}
// Reset expiration timer
acc.ExpiresAt = time.Now().Unix() + 3540
acc.TokenExpiresAt = time.Now().Add(tokens.Expiry)
acc.RefreshToken = tokens.RefreshToken
acc.AccessToken = tokens.AccessToken
@ -175,7 +175,10 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
return
}
_ = a.accountRepo.Associate(ctx, &acc, &dev)
if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil {
a.errorResponse(w, r, 422, err.Error())
return
}
}
for _, acc := range accsMap {
@ -228,7 +231,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
}
// Reset expiration timer
acct.ExpiresAt = time.Now().Unix() + 3540
acct.TokenExpiresAt = time.Now().Add(tokens.Expiry)
acct.RefreshToken = tokens.RefreshToken
acct.AccessToken = tokens.AccessToken

View file

@ -28,8 +28,8 @@ func (a *api) upsertDeviceHandler(w http.ResponseWriter, r *http.Request) {
return
}
d.ActiveUntil = time.Now().Unix() + domain.DeviceGracePeriodDuration
d.GracePeriodUntil = d.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry
d.ExpiresAt = time.Now().Add(domain.DeviceReceiptCheckPeriodDuration)
d.GracePeriodExpiresAt = d.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry)
if err := a.deviceRepo.CreateOrUpdate(ctx, d); err != nil {
a.errorResponse(w, r, 500, err.Error())

View file

@ -39,6 +39,11 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) {
}
if iapr.DeleteDevice {
if dev.GracePeriodExpiresAt.After(time.Now()) {
w.WriteHeader(http.StatusOK)
return
}
accs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
@ -51,8 +56,8 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) {
_ = a.deviceRepo.Delete(ctx, apns)
} else {
dev.ActiveUntil = time.Now().Unix() + domain.DeviceActiveAfterReceitCheckDuration
dev.GracePeriodUntil = dev.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry
dev.ExpiresAt = time.Now().Add(domain.DeviceActiveAfterReceitCheckDuration)
dev.GracePeriodExpiresAt = dev.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry)
_ = a.deviceRepo.Update(ctx, &dev)
}
}

View file

@ -6,6 +6,7 @@ import (
"net/http"
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/gorilla/mux"
@ -231,7 +232,7 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
type watcherItem struct {
ID int64 `json:"id"`
CreatedAt float64 `json:"created_at"`
CreatedAt time.Time `json:"created_at"`
Type string `json:"type"`
Label string `json:"label"`
SourceLabel string `json:"source_label"`

View file

@ -9,27 +9,19 @@ import (
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/christianselig/apollo-backend/internal/cmdutil"
"github.com/christianselig/apollo-backend/internal/repository"
"github.com/go-co-op/gocron"
"github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/christianselig/apollo-backend/internal/cmdutil"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/repository"
)
const (
batchSize = 250
checkTimeout = 60 // how long until we force a check
accountEnqueueInterval = 5 // how frequently we want to check (seconds)
subredditEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
userEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
stuckAccountEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
staleAccountThreshold = 7200 // 2 hours
)
const batchSize = 250
func SchedulerCmd(ctx context.Context) *cobra.Command {
cmd := &cobra.Command{
@ -129,16 +121,16 @@ func evalScript(ctx context.Context, redis *redis.Client) (string, error) {
end
return retv
`, checkTimeout)
`, int64(domain.NotificationCheckTimeout.Seconds()))
return redis.ScriptLoad(ctx, lua).Result()
}
func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) {
before := time.Now().Unix() - staleAccountThreshold
expiry := time.Now().Add(-domain.StaleTokenThreshold)
ar := repository.NewPostgresAccount(pool)
stale, err := ar.PruneStale(ctx, before)
stale, err := ar.PruneStale(ctx, expiry)
if err != nil {
logger.WithFields(logrus.Fields{
"err": err,
@ -158,16 +150,17 @@ func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Poo
if count > 0 {
logger.WithFields(logrus.Fields{
"count": count,
"stale": stale,
"orphaned": orphaned,
}).Info("pruned accounts")
}
}
func pruneDevices(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) {
threshold := time.Now().Unix()
now := time.Now()
dr := repository.NewPostgresDevice(pool)
count, err := dr.PruneStale(ctx, threshold)
count, err := dr.PruneStale(ctx, now)
if err != nil {
logger.WithFields(logrus.Fields{
"err": err,
@ -227,6 +220,8 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie
func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
now := time.Now()
next := now.Add(domain.NotificationCheckInterval)
ids := []int64{}
defer func() {
@ -235,21 +230,20 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}()
ready := now.Unix() - userEnqueueInterval
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
WITH userb AS (
WITH batch AS (
SELECT id
FROM users
WHERE last_checked_at < $1
ORDER BY last_checked_at
WHERE next_check_at < $1
ORDER BY next_check_at
LIMIT 100
)
UPDATE users
SET last_checked_at = $2
WHERE users.id IN(SELECT id FROM userb)
SET next_check_at = $2
WHERE users.id IN(SELECT id FROM batch)
RETURNING users.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
rows, err := tx.Query(ctx, stmt, now, next)
if err != nil {
return err
}
@ -275,7 +269,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
logger.WithFields(logrus.Fields{
"count": len(ids),
"start": ready,
"start": now,
}).Debug("enqueueing user batch")
batchIds := make([]string, len(ids))
@ -292,6 +286,8 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) {
now := time.Now()
next := now.Add(domain.SubredditCheckInterval)
ids := []int64{}
defer func() {
@ -300,21 +296,20 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}()
ready := now.Unix() - subredditEnqueueInterval
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
WITH subreddit AS (
WITH batch AS (
SELECT id
FROM subreddits
WHERE last_checked_at < $1
ORDER BY last_checked_at
WHERE next_check_at < $1
ORDER BY next_check_at
LIMIT 100
)
UPDATE subreddits
SET last_checked_at = $2
WHERE subreddits.id IN(SELECT id FROM subreddit)
SET next_check_at = $2
WHERE subreddits.id IN(SELECT id FROM batch)
RETURNING subreddits.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
rows, err := tx.Query(ctx, stmt, now, next)
if err != nil {
return err
}
@ -340,7 +335,7 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
logger.WithFields(logrus.Fields{
"count": len(ids),
"start": ready,
"start": now,
}).Debug("enqueueing subreddit batch")
batchIds := make([]string, len(ids))
@ -361,6 +356,8 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
now := time.Now()
next := now.Add(domain.StuckNotificationCheckInterval)
ids := []int64{}
defer func() {
@ -369,22 +366,21 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}()
ready := now.Unix() - stuckAccountEnqueueInterval
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
WITH account AS (
WITH batch AS (
SELECT id
FROM accounts
WHERE
last_unstuck_at < $1
ORDER BY last_unstuck_at
next_stuck_notification_check_at < $1
ORDER BY next_stuck_notification_check_at
LIMIT 500
)
UPDATE accounts
SET last_unstuck_at = $2
WHERE accounts.id IN(SELECT id FROM account)
SET next_stuck_notification_check_at = $2
WHERE accounts.id IN(SELECT id FROM batch)
RETURNING accounts.id`
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
rows, err := tx.Query(ctx, stmt, now, next)
if err != nil {
return err
}
@ -410,7 +406,7 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
logger.WithFields(logrus.Fields{
"count": len(ids),
"start": ready,
"start": now,
}).Debug("enqueueing stuck account batch")
batchIds := make([]string, len(ids))
@ -428,6 +424,8 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
now := time.Now()
next := now.Add(domain.NotificationCheckInterval)
ids := []int64{}
enqueued := 0
skipped := 0
@ -439,29 +437,21 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
}()
// Start looking for accounts that were last checked at least 5 seconds ago
// and at most 6 seconds ago. Also look for accounts that haven't been checked
// in over a minute.
ts := now.Unix()
ready := ts - accountEnqueueInterval
expired := ts - checkTimeout
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := `
WITH account AS (
SELECT id
FROM accounts
WHERE
last_enqueued_at < $1
OR last_checked_at < $2
ORDER BY last_checked_at
next_notification_check_at < $1
ORDER BY next_notification_check_at
LIMIT 2500
)
UPDATE accounts
SET last_enqueued_at = $3
SET next_notification_check_at = $2
WHERE accounts.id IN(SELECT id FROM account)
RETURNING accounts.id`
rows, err := tx.Query(ctx, stmt, ready, expired, ts)
rows, err := tx.Query(ctx, stmt, now, next)
if err != nil {
return err
}
@ -487,7 +477,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
logger.WithFields(logrus.Fields{
"count": len(ids),
"start": ready,
"start": now,
}).Debug("enqueueing account batch")
// Split ids in batches
for i := 0; i < len(ids); i += batchSize {
@ -532,7 +522,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
logger.WithFields(logrus.Fields{
"count": enqueued,
"skipped": skipped,
"start": ready,
"start": now,
}).Debug("done enqueueing account batch")
}

View file

@ -3,10 +3,18 @@ package domain
import (
"context"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
const (
NotificationCheckInterval = 5 * time.Second // time between notification checks
NotificationCheckTimeout = 60 * time.Second // time before we give up an account check lock
StuckNotificationCheckInterval = 2 * time.Minute // time between stuck notification checks
StaleTokenThreshold = 2 * time.Hour // time an oauth token has to be expired for to be stale
)
// Account represents an account we need to periodically check in the notifications worker.
type Account struct {
ID int64
@ -16,12 +24,13 @@ type Account struct {
AccountID string
AccessToken string
RefreshToken string
ExpiresAt int64
TokenExpiresAt time.Time
// Tracking how far behind we are
LastMessageID string
LastCheckedAt float64
LastUnstuckAt float64
NextNotificationCheckAt time.Time
NextStuckNotificationCheckAt time.Time
CheckCount int64
}
func (acct *Account) NormalizedUsername() string {
@ -49,5 +58,5 @@ type AccountRepository interface {
Disassociate(ctx context.Context, acc *Account, dev *Device) error
PruneOrphaned(ctx context.Context) (int64, error)
PruneStale(ctx context.Context, before int64) (int64, error)
PruneStale(ctx context.Context, expiry time.Time) (int64, error)
}

View file

@ -2,22 +2,23 @@ package domain
import (
"context"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
const (
DeviceGracePeriodDuration = 3600 // 1 hour
DeviceActiveAfterReceitCheckDuration = 3600 * 24 * 30 // ~1 month
DeviceGracePeriodAfterReceiptExpiry = 3600 * 24 * 30 // ~1 month
DeviceReceiptCheckPeriodDuration = 1 * time.Hour
DeviceActiveAfterReceitCheckDuration = 30 * 24 * time.Hour // ~1 month
DeviceGracePeriodAfterReceiptExpiry = 30 * 24 * time.Hour // ~1 month
)
type Device struct {
ID int64
APNSToken string
Sandbox bool
ActiveUntil int64
GracePeriodUntil int64
ExpiresAt time.Time
GracePeriodExpiresAt time.Time
}
func (dev *Device) Validate() error {
@ -40,5 +41,5 @@ type DeviceRepository interface {
SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher, global bool) error
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, bool, error)
PruneStale(ctx context.Context, before int64) (int64, error)
PruneStale(ctx context.Context, expiry time.Time) (int64, error)
}

View file

@ -5,13 +5,16 @@ import (
"errors"
"regexp"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
const SubredditCheckInterval = 2 * time.Minute
type Subreddit struct {
ID int64
LastCheckedAt float64
NextCheckAt time.Time
// Reddit information
SubredditID string

View file

@ -3,13 +3,16 @@ package domain
import (
"context"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
const UserRefreshInterval = 2 * time.Minute
type User struct {
ID int64
LastCheckedAt float64
NextCheckAt time.Time
// Reddit information
UserID string

View file

@ -2,6 +2,7 @@ package domain
import (
"context"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
@ -29,8 +30,8 @@ func (wt WatcherType) String() string {
type Watcher struct {
ID int64
CreatedAt float64
LastNotifiedAt float64
CreatedAt time.Time
LastNotifiedAt time.Time
Label string
DeviceID int64

View file

@ -3,6 +3,7 @@ package reddit
import (
"fmt"
"strings"
"time"
"github.com/valyala/fastjson"
)
@ -31,6 +32,7 @@ func NewError(val *fastjson.Value, status int) *Error {
type RefreshTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Expiry time.Duration `json:"expires_in"`
}
func NewRefreshTokenResponse(val *fastjson.Value) interface{} {
@ -38,6 +40,7 @@ func NewRefreshTokenResponse(val *fastjson.Value) interface{} {
rtr.AccessToken = string(val.GetStringBytes("access_token"))
rtr.RefreshToken = string(val.GetStringBytes("refresh_token"))
rtr.Expiry = time.Duration(val.GetInt("expires_in")) * time.Second
return rtr
}
@ -67,7 +70,7 @@ type Thing struct {
Author string `json:"author"`
Subject string `json:"subject"`
Body string `json:"body"`
CreatedAt float64 `json:"created_utc"`
CreatedAt time.Time `json:"created_utc"`
Context string `json:"context"`
ParentID string `json:"parent_id"`
LinkTitle string `json:"link_title"`
@ -96,13 +99,14 @@ func NewThing(val *fastjson.Value) *Thing {
t.Kind = string(val.GetStringBytes("kind"))
data := val.Get("data")
unix := int64(data.GetFloat64("created_utc"))
t.ID = string(data.GetStringBytes("id"))
t.Type = string(data.GetStringBytes("type"))
t.Author = string(data.GetStringBytes("author"))
t.Subject = string(data.GetStringBytes("subject"))
t.Body = string(data.GetStringBytes("body"))
t.CreatedAt = data.GetFloat64("created_utc")
t.CreatedAt = time.Unix(unix, 0).UTC()
t.Context = string(data.GetStringBytes("context"))
t.ParentID = string(data.GetStringBytes("parent_id"))
t.LinkTitle = string(data.GetStringBytes("link_title"))

View file

@ -3,6 +3,7 @@ package reddit
import (
"io/ioutil"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/valyala/fastjson"
@ -40,6 +41,7 @@ func TestRefreshTokenResponseParsing(t *testing.T) {
assert.Equal(t, "***REMOVED***", rtr.AccessToken)
assert.Equal(t, "***REMOVED***", rtr.RefreshToken)
assert.Equal(t, 1*time.Hour, rtr.Expiry)
}
func TestListingResponseParsing(t *testing.T) {
@ -60,13 +62,14 @@ func TestListingResponseParsing(t *testing.T) {
assert.Equal(t, "", l.Before)
thing := l.Children[0]
created := time.Time(time.Date(2021, time.July, 14, 17, 56, 35, 0, time.UTC))
assert.Equal(t, "t4", thing.Kind)
assert.Equal(t, "138z6ke", thing.ID)
assert.Equal(t, "unknown", thing.Type)
assert.Equal(t, "iamthatis", thing.Author)
assert.Equal(t, "how goes it", thing.Subject)
assert.Equal(t, "how are you today", thing.Body)
assert.Equal(t, 1626285395.0, thing.CreatedAt)
assert.Equal(t, created, thing.CreatedAt)
assert.Equal(t, "hugocat", thing.Destination)
assert.Equal(t, "t4_138z6ke", thing.FullName())

View file

@ -3,6 +3,7 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool"
@ -33,10 +34,11 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
&acc.AccountID,
&acc.AccessToken,
&acc.RefreshToken,
&acc.ExpiresAt,
&acc.TokenExpiresAt,
&acc.LastMessageID,
&acc.LastCheckedAt,
&acc.LastUnstuckAt,
&acc.NextNotificationCheckAt,
&acc.NextStuckNotificationCheckAt,
&acc.CheckCount,
); err != nil {
return nil, err
}
@ -47,7 +49,9 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) {
query := `
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
check_count
FROM accounts
WHERE id = $1`
@ -64,9 +68,11 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
query := `
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
check_count
FROM accounts
WHERE account_id = $1`
WHERE reddit_account_id = $1`
accs, err := p.fetch(ctx, query, id)
if err != nil {
@ -81,12 +87,13 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string
}
func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error {
query := `
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at)
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
INSERT INTO accounts (username, reddit_account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at)
VALUES ($1, $2, $3, $4, $5, '', NOW(), NOW())
ON CONFLICT(username) DO
UPDATE SET access_token = $3,
refresh_token = $4,
expires_at = $5
token_expires_at = $5
RETURNING id`
return p.pool.QueryRow(
@ -96,14 +103,15 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom
acc.AccountID,
acc.AccessToken,
acc.RefreshToken,
acc.ExpiresAt,
acc.TokenExpiresAt,
).Scan(&acc.ID)
}
func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error {
query := `
INSERT INTO accounts
(username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at)
(username, reddit_account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id`
@ -114,10 +122,10 @@ func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Acco
acc.AccountID,
acc.AccessToken,
acc.RefreshToken,
acc.ExpiresAt,
acc.TokenExpiresAt,
acc.LastMessageID,
acc.LastCheckedAt,
acc.LastUnstuckAt,
acc.NextNotificationCheckAt,
acc.NextStuckNotificationCheckAt,
).Scan(&acc.ID)
}
@ -125,13 +133,14 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
query := `
UPDATE accounts
SET username = $2,
account_id = $3,
reddit_account_id = $3,
access_token = $4,
refresh_token = $5,
expires_at = $6,
token_expires_at = $6,
last_message_id = $7,
last_checked_at = $8,
last_unstuck_at = $9
next_notification_check_at = $8,
next_stuck_notification_check_at = $9,
check_count = $10
WHERE id = $1`
res, err := p.pool.Exec(
@ -142,10 +151,11 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
acc.AccountID,
acc.AccessToken,
acc.RefreshToken,
acc.ExpiresAt,
acc.TokenExpiresAt,
acc.LastMessageID,
acc.LastCheckedAt,
acc.LastUnstuckAt,
acc.NextNotificationCheckAt,
acc.NextStuckNotificationCheckAt,
acc.CheckCount,
)
if res.RowsAffected() != 1 {
@ -186,7 +196,9 @@ func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domai
func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) {
query := `
SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, accounts.expires_at, last_message_id, last_checked_at, last_unstuck_at
SELECT accounts.id, username, accounts.reddit_account_id, access_token, refresh_token, token_expires_at,
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
check_count
FROM accounts
INNER JOIN devices_accounts ON accounts.id = devices_accounts.account_id
INNER JOIN devices ON devices.id = devices_accounts.device_id
@ -195,12 +207,12 @@ func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token st
return p.fetch(ctx, query, token)
}
func (p *postgresAccountRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
func (p *postgresAccountRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
query := `
DELETE FROM accounts
WHERE expires_at < $1`
WHERE token_expires_at < $1`
res, err := p.pool.Exec(ctx, query, before)
res, err := p.pool.Exec(ctx, query, expiry)
return res.RowsAffected(), err
}

View file

@ -3,6 +3,7 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool"
@ -31,8 +32,8 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
&dev.ID,
&dev.APNSToken,
&dev.Sandbox,
&dev.ActiveUntil,
&dev.GracePeriodUntil,
&dev.ExpiresAt,
&dev.GracePeriodExpiresAt,
); err != nil {
return nil, err
}
@ -43,7 +44,7 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domain.Device, error) {
query := `
SELECT id, apns_token, sandbox, active_until, grace_period_until
SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices
WHERE id = $1`
@ -60,7 +61,7 @@ func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domai
func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) {
query := `
SELECT id, apns_token, sandbox, active_until, grace_period_until
SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices
WHERE apns_token = $1`
@ -77,7 +78,7 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str
func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until
SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1`
@ -87,7 +88,7 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until
SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND
@ -99,7 +100,7 @@ func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Con
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until
SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND
@ -111,10 +112,10 @@ func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.C
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
query := `
INSERT INTO devices (apns_token, sandbox, active_until, grace_period_until)
INSERT INTO devices (apns_token, sandbox, expires_at, grace_period_expires_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT(apns_token) DO
UPDATE SET active_until = $3, grace_period_until = $4
UPDATE SET expires_at = $3, grace_period_expires_at = $4
RETURNING id`
return p.pool.QueryRow(
@ -122,8 +123,8 @@ func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *doma
query,
dev.APNSToken,
dev.Sandbox,
dev.ActiveUntil,
dev.GracePeriodUntil,
&dev.ExpiresAt,
&dev.GracePeriodExpiresAt,
).Scan(&dev.ID)
}
@ -134,7 +135,7 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
query := `
INSERT INTO devices
(apns_token, sandbox, active_until, grace_period_until)
(apns_token, sandbox, expires_at, grace_period_expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id`
@ -143,8 +144,8 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
query,
dev.APNSToken,
dev.Sandbox,
dev.ActiveUntil,
dev.GracePeriodUntil,
dev.ExpiresAt,
dev.GracePeriodExpiresAt,
).Scan(&dev.ID)
}
@ -155,10 +156,10 @@ func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Devic
query := `
UPDATE devices
SET active_until = $2, grace_period_until = $3
SET expires_at = $2, grace_period_expires_at = $3
WHERE id = $1`
res, err := p.pool.Exec(ctx, query, dev.ID, dev.ActiveUntil, dev.GracePeriodUntil)
res, err := p.pool.Exec(ctx, query, dev.ID, dev.ExpiresAt, dev.GracePeriodExpiresAt)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -209,10 +210,10 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai
return inbox, watcher, global, nil
}
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
query := `DELETE FROM devices WHERE grace_period_until < $1`
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
query := `DELETE FROM devices WHERE grace_period_expires_at < $1`
res, err := p.pool.Exec(ctx, query, before)
res, err := p.pool.Exec(ctx, query, expiry)
return res.RowsAffected(), err
}

View file

@ -30,7 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
&sr.ID,
&sr.SubredditID,
&sr.Name,
&sr.LastCheckedAt,
&sr.NextCheckAt,
); err != nil {
return nil, err
}
@ -41,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
query := `
SELECT id, subreddit_id, name, last_checked_at
SELECT id, subreddit_id, name, next_check_at
FROM subreddits
WHERE id = $1`
@ -58,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
query := `
SELECT id, subreddit_id, name, last_checked_at
SELECT id, subreddit_id, name, next_check_at
FROM subreddits
WHERE name = $1`
@ -81,8 +81,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
}
query := `
INSERT INTO subreddits (subreddit_id, name)
VALUES ($1, $2)
INSERT INTO subreddits (subreddit_id, name, next_check_at)
VALUES ($1, $2, NOW())
ON CONFLICT(subreddit_id) DO NOTHING
RETURNING id`

View file

@ -32,7 +32,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args .
&u.ID,
&u.UserID,
&u.Name,
&u.LastCheckedAt,
&u.NextCheckAt,
); err != nil {
return nil, err
}
@ -43,7 +43,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args .
func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) {
query := `
SELECT id, user_id, name, last_checked_at
SELECT id, user_id, name, next_check_at
FROM users
WHERE id = $1`
@ -60,7 +60,7 @@ func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.
func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) {
query := `
SELECT id, user_id, name, last_checked_at
SELECT id, user_id, name, next_check_at
FROM users
WHERE name = $1`
@ -83,10 +83,9 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
}
query := `
INSERT INTO users (user_id, name)
VALUES ($1, $2)
ON CONFLICT(user_id) DO
UPDATE SET last_checked_at = $3
INSERT INTO users (user_id, name, next_check_at)
VALUES ($1, $2, NOW())
ON CONFLICT(user_id) DO NOTHING
RETURNING id`
return p.pool.QueryRow(
@ -94,7 +93,7 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
query,
u.UserID,
u.NormalizedName(),
u.LastCheckedAt,
u.NextCheckAt,
).Scan(&u.ID)
}

View file

@ -93,7 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.account_id,
accounts.reddit_account_id,
accounts.access_token,
accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label,
@ -138,7 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.account_id,
accounts.reddit_account_id,
accounts.access_token,
accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label,
@ -191,7 +191,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
devices.apns_token,
devices.sandbox,
accounts.id,
accounts.account_id,
accounts.reddit_account_id,
accounts.access_token,
accounts.refresh_token,
COALESCE(subreddits.name, '') AS subreddit_label,
@ -203,7 +203,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
WHERE
devices.apns_token = $1 AND
accounts.account_id = $2`
accounts.reddit_account_id = $2`
return p.fetch(ctx, query, apns, rid)
}

View file

@ -166,7 +166,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }()
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000
now := time.Now()
account, err := nc.accountRepo.GetByID(ctx, id)
if err != nil {
@ -177,20 +177,22 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
previousLastCheckedAt := account.LastCheckedAt
newAccount := (previousLastCheckedAt == 0)
account.LastCheckedAt = now
newAccount := account.CheckCount == 0
previousNextCheck := account.NextNotificationCheckAt
account.CheckCount++
account.NextNotificationCheckAt = time.Now().Add(domain.NotificationCheckInterval)
if err = nc.accountRepo.Update(ctx, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
}).Error("failed to update last_checked_at for account")
}).Error("failed to update next_notification_check_at for account")
return
}
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if account.ExpiresAt < int64(now) {
if account.TokenExpiresAt.Before(now) {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
}).Debug("refreshing reddit token")
@ -219,7 +221,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Update account
account.AccessToken = tokens.AccessToken
account.RefreshToken = tokens.RefreshToken
account.ExpiresAt = int64(now + 3540)
account.TokenExpiresAt = now.Add(tokens.Expiry)
// Refresh client
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
@ -236,8 +238,8 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Only update delay on accounts we can actually check, otherwise it skews
// the numbers too much.
if !newAccount {
latency := now - previousLastCheckedAt - float64(backoff)
_ = nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
latency := now.Sub(previousNextCheck) - backoff*time.Second
_ = nc.statsd.Histogram("apollo.queue.delay", float64(latency.Milliseconds()), []string{}, rate)
}
nc.logger.WithFields(logrus.Fields{

View file

@ -180,7 +180,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
return
}
threshold := float64(time.Now().AddDate(0, 0, -1).UTC().Unix())
threshold := time.Now().Add(-24 * time.Hour)
posts := []*reddit.Thing{}
before := ""
finished := false
@ -237,7 +237,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}
for _, post := range sps.Children {
if post.CreatedAt < threshold {
if post.CreatedAt.Before(threshold) {
finished = true
break
}
@ -287,7 +287,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("loaded hot posts")
for _, post := range sps.Children {
if post.CreatedAt < threshold {
if post.CreatedAt.Before(threshold) {
break
}
if _, ok := seenPosts[post.ID]; !ok {
@ -313,7 +313,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
for _, watcher := range watchers {
// Make sure we only alert on posts created after the search
if watcher.CreatedAt > post.CreatedAt {
if watcher.CreatedAt.After(post.CreatedAt) {
continue
}

View file

@ -235,14 +235,14 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
}).Debug("loaded hot posts")
// Trending only counts for posts less than 2 days old
threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix())
threshold := time.Now().Add(-24 * time.Hour * 2)
for _, post := range hps.Children {
if post.Score < medianScore {
continue
}
if post.CreatedAt < threshold {
if post.CreatedAt.Before(threshold) {
break
}
@ -251,7 +251,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
notification.Payload = payloadFromTrendingPost(post)
for _, watcher := range watchers {
if watcher.CreatedAt > post.CreatedAt {
if watcher.CreatedAt.After(post.CreatedAt) {
continue
}

View file

@ -234,11 +234,11 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
for _, watcher := range watchers {
// Make sure we only alert on activities created after the search
if watcher.CreatedAt > post.CreatedAt {
if watcher.CreatedAt.After(post.CreatedAt) {
continue
}
if watcher.LastNotifiedAt > post.CreatedAt {
if watcher.LastNotifiedAt.After(post.CreatedAt) {
continue
}