mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-13 07:27:43 +00:00
Merge pull request #66 from christianselig/chore/schema-changes
Changes to schema
This commit is contained in:
commit
a94aa11845
22 changed files with 302 additions and 201 deletions
64
docs/schema.sql
Normal file
64
docs/schema.sql
Normal file
|
@ -0,0 +1,64 @@
|
|||
CREATE TABLE accounts (
|
||||
id SERIAL PRIMARY KEY,
|
||||
reddit_account_id character varying(32) DEFAULT ''::character varying,
|
||||
username character varying(20) DEFAULT ''::character varying UNIQUE,
|
||||
access_token character varying(64) DEFAULT ''::character varying,
|
||||
refresh_token character varying(64) DEFAULT ''::character varying,
|
||||
token_expires_at timestamp without time zone,
|
||||
last_message_id character varying(32) DEFAULT ''::character varying,
|
||||
next_notification_check_at timestamp without time zone,
|
||||
next_stuck_notification_check_at timestamp without time zone,
|
||||
check_count integer DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE devices (
|
||||
id SERIAL PRIMARY KEY,
|
||||
apns_token character varying(100) UNIQUE,
|
||||
sandbox boolean,
|
||||
expires_at timestamp without time zone,
|
||||
grace_period_expires_at timestamp without time zone
|
||||
);
|
||||
|
||||
CREATE TABLE devices_accounts (
|
||||
id SERIAL PRIMARY KEY,
|
||||
account_id integer REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
device_id integer REFERENCES devices(id) ON DELETE CASCADE,
|
||||
watcher_notifiable boolean DEFAULT true,
|
||||
inbox_notifiable boolean DEFAULT true,
|
||||
global_mute boolean DEFAULT false
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX devices_accounts_account_id_device_id_idx ON devices_accounts(account_id int4_ops,device_id int4_ops);
|
||||
|
||||
CREATE TABLE subreddits (
|
||||
id SERIAL PRIMARY KEY,
|
||||
subreddit_id character varying(32) DEFAULT ''::character varying UNIQUE,
|
||||
name character varying(32) DEFAULT ''::character varying,
|
||||
next_check_at timestamp without time zone
|
||||
);
|
||||
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id character varying(32) DEFAULT ''::character varying UNIQUE,
|
||||
name character varying(32) DEFAULT ''::character varying,
|
||||
next_check_at timestamp without time zone
|
||||
);
|
||||
|
||||
CREATE TABLE watchers (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at timestamp without time zone,
|
||||
last_notified_at timestamp without time zone,
|
||||
device_id integer REFERENCES devices(id) ON DELETE CASCADE,
|
||||
account_id integer REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
watchee_id integer,
|
||||
upvotes integer DEFAULT 0,
|
||||
keyword character varying(32) DEFAULT ''::character varying,
|
||||
flair character varying(32) DEFAULT ''::character varying,
|
||||
domain character varying(32) DEFAULT ''::character varying,
|
||||
hits integer DEFAULT 0,
|
||||
type integer DEFAULT 0,
|
||||
label character varying(64) DEFAULT ''::character varying,
|
||||
author character varying(32) DEFAULT ''::character varying,
|
||||
subreddit character varying(32) DEFAULT ''::character varying
|
||||
);
|
||||
|
|
@ -150,7 +150,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Reset expiration timer
|
||||
acc.ExpiresAt = time.Now().Unix() + 3540
|
||||
acc.TokenExpiresAt = time.Now().Add(tokens.Expiry)
|
||||
acc.RefreshToken = tokens.RefreshToken
|
||||
acc.AccessToken = tokens.AccessToken
|
||||
|
||||
|
@ -175,7 +175,10 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
_ = a.accountRepo.Associate(ctx, &acc, &dev)
|
||||
if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil {
|
||||
a.errorResponse(w, r, 422, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, acc := range accsMap {
|
||||
|
@ -228,7 +231,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Reset expiration timer
|
||||
acct.ExpiresAt = time.Now().Unix() + 3540
|
||||
acct.TokenExpiresAt = time.Now().Add(tokens.Expiry)
|
||||
acct.RefreshToken = tokens.RefreshToken
|
||||
acct.AccessToken = tokens.AccessToken
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@ func (a *api) upsertDeviceHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
d.ActiveUntil = time.Now().Unix() + domain.DeviceGracePeriodDuration
|
||||
d.GracePeriodUntil = d.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry
|
||||
d.ExpiresAt = time.Now().Add(domain.DeviceReceiptCheckPeriodDuration)
|
||||
d.GracePeriodExpiresAt = d.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry)
|
||||
|
||||
if err := a.deviceRepo.CreateOrUpdate(ctx, d); err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
|
|
|
@ -39,6 +39,11 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if iapr.DeleteDevice {
|
||||
if dev.GracePeriodExpiresAt.After(time.Now()) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
accs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
|
@ -51,8 +56,8 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
_ = a.deviceRepo.Delete(ctx, apns)
|
||||
} else {
|
||||
dev.ActiveUntil = time.Now().Unix() + domain.DeviceActiveAfterReceitCheckDuration
|
||||
dev.GracePeriodUntil = dev.ActiveUntil + domain.DeviceGracePeriodAfterReceiptExpiry
|
||||
dev.ExpiresAt = time.Now().Add(domain.DeviceActiveAfterReceitCheckDuration)
|
||||
dev.GracePeriodExpiresAt = dev.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry)
|
||||
_ = a.deviceRepo.Update(ctx, &dev)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/gorilla/mux"
|
||||
|
@ -230,17 +231,17 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
type watcherItem struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt float64 `json:"created_at"`
|
||||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
SourceLabel string `json:"source_label"`
|
||||
Upvotes *int64 `json:"upvotes,omitempty"`
|
||||
Keyword string `json:"keyword,omitempty"`
|
||||
Flair string `json:"flair,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Hits int64 `json:"hits"`
|
||||
Author string `json:"author,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
SourceLabel string `json:"source_label"`
|
||||
Upvotes *int64 `json:"upvotes,omitempty"`
|
||||
Keyword string `json:"keyword,omitempty"`
|
||||
Flair string `json:"flair,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Hits int64 `json:"hits"`
|
||||
Author string `json:"author,omitempty"`
|
||||
}
|
||||
|
||||
func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -9,27 +9,19 @@ import (
|
|||
|
||||
"github.com/DataDog/datadog-go/statsd"
|
||||
"github.com/adjust/rmq/v4"
|
||||
"github.com/christianselig/apollo-backend/internal/cmdutil"
|
||||
"github.com/christianselig/apollo-backend/internal/repository"
|
||||
"github.com/go-co-op/gocron"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/cmdutil"
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
"github.com/christianselig/apollo-backend/internal/repository"
|
||||
)
|
||||
|
||||
const (
|
||||
batchSize = 250
|
||||
checkTimeout = 60 // how long until we force a check
|
||||
|
||||
accountEnqueueInterval = 5 // how frequently we want to check (seconds)
|
||||
subredditEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
|
||||
userEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
|
||||
stuckAccountEnqueueInterval = 2 * 60 // how frequently we want to check (seconds)
|
||||
|
||||
staleAccountThreshold = 7200 // 2 hours
|
||||
)
|
||||
const batchSize = 250
|
||||
|
||||
func SchedulerCmd(ctx context.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
|
@ -129,16 +121,16 @@ func evalScript(ctx context.Context, redis *redis.Client) (string, error) {
|
|||
end
|
||||
|
||||
return retv
|
||||
`, checkTimeout)
|
||||
`, int64(domain.NotificationCheckTimeout.Seconds()))
|
||||
|
||||
return redis.ScriptLoad(ctx, lua).Result()
|
||||
}
|
||||
|
||||
func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) {
|
||||
before := time.Now().Unix() - staleAccountThreshold
|
||||
expiry := time.Now().Add(-domain.StaleTokenThreshold)
|
||||
ar := repository.NewPostgresAccount(pool)
|
||||
|
||||
stale, err := ar.PruneStale(ctx, before)
|
||||
stale, err := ar.PruneStale(ctx, expiry)
|
||||
if err != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"err": err,
|
||||
|
@ -158,16 +150,17 @@ func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Poo
|
|||
|
||||
if count > 0 {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"count": count,
|
||||
"stale": stale,
|
||||
"orphaned": orphaned,
|
||||
}).Info("pruned accounts")
|
||||
}
|
||||
}
|
||||
|
||||
func pruneDevices(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool) {
|
||||
threshold := time.Now().Unix()
|
||||
now := time.Now()
|
||||
dr := repository.NewPostgresDevice(pool)
|
||||
|
||||
count, err := dr.PruneStale(ctx, threshold)
|
||||
count, err := dr.PruneStale(ctx, now)
|
||||
if err != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"err": err,
|
||||
|
@ -227,6 +220,8 @@ func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Clie
|
|||
|
||||
func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
||||
now := time.Now()
|
||||
next := now.Add(domain.NotificationCheckInterval)
|
||||
|
||||
ids := []int64{}
|
||||
|
||||
defer func() {
|
||||
|
@ -235,21 +230,20 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
|
|||
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
|
||||
}()
|
||||
|
||||
ready := now.Unix() - userEnqueueInterval
|
||||
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
WITH userb AS (
|
||||
WITH batch AS (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE last_checked_at < $1
|
||||
ORDER BY last_checked_at
|
||||
WHERE next_check_at < $1
|
||||
ORDER BY next_check_at
|
||||
LIMIT 100
|
||||
)
|
||||
UPDATE users
|
||||
SET last_checked_at = $2
|
||||
WHERE users.id IN(SELECT id FROM userb)
|
||||
SET next_check_at = $2
|
||||
WHERE users.id IN(SELECT id FROM batch)
|
||||
RETURNING users.id`
|
||||
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
|
||||
rows, err := tx.Query(ctx, stmt, now, next)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -275,7 +269,7 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
|
|||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"count": len(ids),
|
||||
"start": ready,
|
||||
"start": now,
|
||||
}).Debug("enqueueing user batch")
|
||||
|
||||
batchIds := make([]string, len(ids))
|
||||
|
@ -292,6 +286,8 @@ func enqueueUsers(ctx context.Context, logger *logrus.Logger, statsd *statsd.Cli
|
|||
|
||||
func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queues []rmq.Queue) {
|
||||
now := time.Now()
|
||||
next := now.Add(domain.SubredditCheckInterval)
|
||||
|
||||
ids := []int64{}
|
||||
|
||||
defer func() {
|
||||
|
@ -300,21 +296,20 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
|
|||
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
|
||||
}()
|
||||
|
||||
ready := now.Unix() - subredditEnqueueInterval
|
||||
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
WITH subreddit AS (
|
||||
WITH batch AS (
|
||||
SELECT id
|
||||
FROM subreddits
|
||||
WHERE last_checked_at < $1
|
||||
ORDER BY last_checked_at
|
||||
WHERE next_check_at < $1
|
||||
ORDER BY next_check_at
|
||||
LIMIT 100
|
||||
)
|
||||
UPDATE subreddits
|
||||
SET last_checked_at = $2
|
||||
WHERE subreddits.id IN(SELECT id FROM subreddit)
|
||||
SET next_check_at = $2
|
||||
WHERE subreddits.id IN(SELECT id FROM batch)
|
||||
RETURNING subreddits.id`
|
||||
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
|
||||
rows, err := tx.Query(ctx, stmt, now, next)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -340,7 +335,7 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
|
|||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"count": len(ids),
|
||||
"start": ready,
|
||||
"start": now,
|
||||
}).Debug("enqueueing subreddit batch")
|
||||
|
||||
batchIds := make([]string, len(ids))
|
||||
|
@ -361,6 +356,8 @@ func enqueueSubreddits(ctx context.Context, logger *logrus.Logger, statsd *stats
|
|||
|
||||
func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, queue rmq.Queue) {
|
||||
now := time.Now()
|
||||
next := now.Add(domain.StuckNotificationCheckInterval)
|
||||
|
||||
ids := []int64{}
|
||||
|
||||
defer func() {
|
||||
|
@ -369,22 +366,21 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
|
|||
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
|
||||
}()
|
||||
|
||||
ready := now.Unix() - stuckAccountEnqueueInterval
|
||||
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
WITH account AS (
|
||||
WITH batch AS (
|
||||
SELECT id
|
||||
FROM accounts
|
||||
WHERE
|
||||
last_unstuck_at < $1
|
||||
ORDER BY last_unstuck_at
|
||||
next_stuck_notification_check_at < $1
|
||||
ORDER BY next_stuck_notification_check_at
|
||||
LIMIT 500
|
||||
)
|
||||
UPDATE accounts
|
||||
SET last_unstuck_at = $2
|
||||
WHERE accounts.id IN(SELECT id FROM account)
|
||||
SET next_stuck_notification_check_at = $2
|
||||
WHERE accounts.id IN(SELECT id FROM batch)
|
||||
RETURNING accounts.id`
|
||||
rows, err := tx.Query(ctx, stmt, ready, now.Unix())
|
||||
rows, err := tx.Query(ctx, stmt, now, next)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -410,7 +406,7 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
|
|||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"count": len(ids),
|
||||
"start": ready,
|
||||
"start": now,
|
||||
}).Debug("enqueueing stuck account batch")
|
||||
|
||||
batchIds := make([]string, len(ids))
|
||||
|
@ -428,6 +424,8 @@ func enqueueStuckAccounts(ctx context.Context, logger *logrus.Logger, statsd *st
|
|||
|
||||
func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client, luaSha string, queue rmq.Queue) {
|
||||
now := time.Now()
|
||||
next := now.Add(domain.NotificationCheckInterval)
|
||||
|
||||
ids := []int64{}
|
||||
enqueued := 0
|
||||
skipped := 0
|
||||
|
@ -439,29 +437,21 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
|
|||
_ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1)
|
||||
}()
|
||||
|
||||
// Start looking for accounts that were last checked at least 5 seconds ago
|
||||
// and at most 6 seconds ago. Also look for accounts that haven't been checked
|
||||
// in over a minute.
|
||||
ts := now.Unix()
|
||||
ready := ts - accountEnqueueInterval
|
||||
expired := ts - checkTimeout
|
||||
|
||||
err := pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
WITH account AS (
|
||||
SELECT id
|
||||
FROM accounts
|
||||
WHERE
|
||||
last_enqueued_at < $1
|
||||
OR last_checked_at < $2
|
||||
ORDER BY last_checked_at
|
||||
next_notification_check_at < $1
|
||||
ORDER BY next_notification_check_at
|
||||
LIMIT 2500
|
||||
)
|
||||
UPDATE accounts
|
||||
SET last_enqueued_at = $3
|
||||
SET next_notification_check_at = $2
|
||||
WHERE accounts.id IN(SELECT id FROM account)
|
||||
RETURNING accounts.id`
|
||||
rows, err := tx.Query(ctx, stmt, ready, expired, ts)
|
||||
rows, err := tx.Query(ctx, stmt, now, next)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -487,7 +477,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
|
|||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"count": len(ids),
|
||||
"start": ready,
|
||||
"start": now,
|
||||
}).Debug("enqueueing account batch")
|
||||
// Split ids in batches
|
||||
for i := 0; i < len(ids); i += batchSize {
|
||||
|
@ -532,7 +522,7 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
|
|||
logger.WithFields(logrus.Fields{
|
||||
"count": enqueued,
|
||||
"skipped": skipped,
|
||||
"start": ready,
|
||||
"start": now,
|
||||
}).Debug("done enqueueing account batch")
|
||||
}
|
||||
|
||||
|
|
|
@ -3,25 +3,34 @@ package domain
|
|||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
NotificationCheckInterval = 5 * time.Second // time between notification checks
|
||||
NotificationCheckTimeout = 60 * time.Second // time before we give up an account check lock
|
||||
StuckNotificationCheckInterval = 2 * time.Minute // time between stuck notification checks
|
||||
StaleTokenThreshold = 2 * time.Hour // time an oauth token has to be expired for to be stale
|
||||
)
|
||||
|
||||
// Account represents an account we need to periodically check in the notifications worker.
|
||||
type Account struct {
|
||||
ID int64
|
||||
|
||||
// Reddit information
|
||||
Username string
|
||||
AccountID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt int64
|
||||
Username string
|
||||
AccountID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
TokenExpiresAt time.Time
|
||||
|
||||
// Tracking how far behind we are
|
||||
LastMessageID string
|
||||
LastCheckedAt float64
|
||||
LastUnstuckAt float64
|
||||
LastMessageID string
|
||||
NextNotificationCheckAt time.Time
|
||||
NextStuckNotificationCheckAt time.Time
|
||||
CheckCount int64
|
||||
}
|
||||
|
||||
func (acct *Account) NormalizedUsername() string {
|
||||
|
@ -49,5 +58,5 @@ type AccountRepository interface {
|
|||
Disassociate(ctx context.Context, acc *Account, dev *Device) error
|
||||
|
||||
PruneOrphaned(ctx context.Context) (int64, error)
|
||||
PruneStale(ctx context.Context, before int64) (int64, error)
|
||||
PruneStale(ctx context.Context, expiry time.Time) (int64, error)
|
||||
}
|
||||
|
|
|
@ -2,22 +2,23 @@ package domain
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
DeviceGracePeriodDuration = 3600 // 1 hour
|
||||
DeviceActiveAfterReceitCheckDuration = 3600 * 24 * 30 // ~1 month
|
||||
DeviceGracePeriodAfterReceiptExpiry = 3600 * 24 * 30 // ~1 month
|
||||
DeviceReceiptCheckPeriodDuration = 1 * time.Hour
|
||||
DeviceActiveAfterReceitCheckDuration = 30 * 24 * time.Hour // ~1 month
|
||||
DeviceGracePeriodAfterReceiptExpiry = 30 * 24 * time.Hour // ~1 month
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
ID int64
|
||||
APNSToken string
|
||||
Sandbox bool
|
||||
ActiveUntil int64
|
||||
GracePeriodUntil int64
|
||||
ID int64
|
||||
APNSToken string
|
||||
Sandbox bool
|
||||
ExpiresAt time.Time
|
||||
GracePeriodExpiresAt time.Time
|
||||
}
|
||||
|
||||
func (dev *Device) Validate() error {
|
||||
|
@ -40,5 +41,5 @@ type DeviceRepository interface {
|
|||
SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher, global bool) error
|
||||
GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, bool, error)
|
||||
|
||||
PruneStale(ctx context.Context, before int64) (int64, error)
|
||||
PruneStale(ctx context.Context, expiry time.Time) (int64, error)
|
||||
}
|
||||
|
|
|
@ -5,13 +5,16 @@ import (
|
|||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
||||
const SubredditCheckInterval = 2 * time.Minute
|
||||
|
||||
type Subreddit struct {
|
||||
ID int64
|
||||
LastCheckedAt float64
|
||||
ID int64
|
||||
NextCheckAt time.Time
|
||||
|
||||
// Reddit information
|
||||
SubredditID string
|
||||
|
|
|
@ -3,13 +3,16 @@ package domain
|
|||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
||||
const UserRefreshInterval = 2 * time.Minute
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
LastCheckedAt float64
|
||||
ID int64
|
||||
NextCheckAt time.Time
|
||||
|
||||
// Reddit information
|
||||
UserID string
|
||||
|
|
|
@ -2,6 +2,7 @@ package domain
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
@ -29,8 +30,8 @@ func (wt WatcherType) String() string {
|
|||
|
||||
type Watcher struct {
|
||||
ID int64
|
||||
CreatedAt float64
|
||||
LastNotifiedAt float64
|
||||
CreatedAt time.Time
|
||||
LastNotifiedAt time.Time
|
||||
Label string
|
||||
|
||||
DeviceID int64
|
||||
|
|
|
@ -3,6 +3,7 @@ package reddit
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fastjson"
|
||||
)
|
||||
|
@ -29,8 +30,9 @@ func NewError(val *fastjson.Value, status int) *Error {
|
|||
}
|
||||
|
||||
type RefreshTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Expiry time.Duration `json:"expires_in"`
|
||||
}
|
||||
|
||||
func NewRefreshTokenResponse(val *fastjson.Value) interface{} {
|
||||
|
@ -38,6 +40,7 @@ func NewRefreshTokenResponse(val *fastjson.Value) interface{} {
|
|||
|
||||
rtr.AccessToken = string(val.GetStringBytes("access_token"))
|
||||
rtr.RefreshToken = string(val.GetStringBytes("refresh_token"))
|
||||
rtr.Expiry = time.Duration(val.GetInt("expires_in")) * time.Second
|
||||
|
||||
return rtr
|
||||
}
|
||||
|
@ -61,25 +64,25 @@ func NewMeResponse(val *fastjson.Value) interface{} {
|
|||
}
|
||||
|
||||
type Thing struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Author string `json:"author"`
|
||||
Subject string `json:"subject"`
|
||||
Body string `json:"body"`
|
||||
CreatedAt float64 `json:"created_utc"`
|
||||
Context string `json:"context"`
|
||||
ParentID string `json:"parent_id"`
|
||||
LinkTitle string `json:"link_title"`
|
||||
Destination string `json:"dest"`
|
||||
Subreddit string `json:"subreddit"`
|
||||
SubredditType string `json:"subreddit_type"`
|
||||
Score int64 `json:"score"`
|
||||
SelfText string `json:"selftext"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Flair string `json:"flair"`
|
||||
Thumbnail string `json:"thumbnail"`
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Author string `json:"author"`
|
||||
Subject string `json:"subject"`
|
||||
Body string `json:"body"`
|
||||
CreatedAt time.Time `json:"created_utc"`
|
||||
Context string `json:"context"`
|
||||
ParentID string `json:"parent_id"`
|
||||
LinkTitle string `json:"link_title"`
|
||||
Destination string `json:"dest"`
|
||||
Subreddit string `json:"subreddit"`
|
||||
SubredditType string `json:"subreddit_type"`
|
||||
Score int64 `json:"score"`
|
||||
SelfText string `json:"selftext"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Flair string `json:"flair"`
|
||||
Thumbnail string `json:"thumbnail"`
|
||||
}
|
||||
|
||||
func (t *Thing) FullName() string {
|
||||
|
@ -96,13 +99,14 @@ func NewThing(val *fastjson.Value) *Thing {
|
|||
t.Kind = string(val.GetStringBytes("kind"))
|
||||
|
||||
data := val.Get("data")
|
||||
unix := int64(data.GetFloat64("created_utc"))
|
||||
|
||||
t.ID = string(data.GetStringBytes("id"))
|
||||
t.Type = string(data.GetStringBytes("type"))
|
||||
t.Author = string(data.GetStringBytes("author"))
|
||||
t.Subject = string(data.GetStringBytes("subject"))
|
||||
t.Body = string(data.GetStringBytes("body"))
|
||||
t.CreatedAt = data.GetFloat64("created_utc")
|
||||
t.CreatedAt = time.Unix(unix, 0).UTC()
|
||||
t.Context = string(data.GetStringBytes("context"))
|
||||
t.ParentID = string(data.GetStringBytes("parent_id"))
|
||||
t.LinkTitle = string(data.GetStringBytes("link_title"))
|
||||
|
|
|
@ -3,6 +3,7 @@ package reddit
|
|||
import (
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/valyala/fastjson"
|
||||
|
@ -40,6 +41,7 @@ func TestRefreshTokenResponseParsing(t *testing.T) {
|
|||
|
||||
assert.Equal(t, "***REMOVED***", rtr.AccessToken)
|
||||
assert.Equal(t, "***REMOVED***", rtr.RefreshToken)
|
||||
assert.Equal(t, 1*time.Hour, rtr.Expiry)
|
||||
}
|
||||
|
||||
func TestListingResponseParsing(t *testing.T) {
|
||||
|
@ -60,13 +62,14 @@ func TestListingResponseParsing(t *testing.T) {
|
|||
assert.Equal(t, "", l.Before)
|
||||
|
||||
thing := l.Children[0]
|
||||
created := time.Time(time.Date(2021, time.July, 14, 17, 56, 35, 0, time.UTC))
|
||||
assert.Equal(t, "t4", thing.Kind)
|
||||
assert.Equal(t, "138z6ke", thing.ID)
|
||||
assert.Equal(t, "unknown", thing.Type)
|
||||
assert.Equal(t, "iamthatis", thing.Author)
|
||||
assert.Equal(t, "how goes it", thing.Subject)
|
||||
assert.Equal(t, "how are you today", thing.Body)
|
||||
assert.Equal(t, 1626285395.0, thing.CreatedAt)
|
||||
assert.Equal(t, created, thing.CreatedAt)
|
||||
assert.Equal(t, "hugocat", thing.Destination)
|
||||
assert.Equal(t, "t4_138z6ke", thing.FullName())
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package repository
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
|
||||
|
@ -33,10 +34,11 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
|
|||
&acc.AccountID,
|
||||
&acc.AccessToken,
|
||||
&acc.RefreshToken,
|
||||
&acc.ExpiresAt,
|
||||
&acc.TokenExpiresAt,
|
||||
&acc.LastMessageID,
|
||||
&acc.LastCheckedAt,
|
||||
&acc.LastUnstuckAt,
|
||||
&acc.NextNotificationCheckAt,
|
||||
&acc.NextStuckNotificationCheckAt,
|
||||
&acc.CheckCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -47,7 +49,9 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
|
|||
|
||||
func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
|
||||
SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
||||
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
|
||||
check_count
|
||||
FROM accounts
|
||||
WHERE id = $1`
|
||||
|
||||
|
@ -64,9 +68,11 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
|
|||
|
||||
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
|
||||
SELECT id, username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
||||
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
|
||||
check_count
|
||||
FROM accounts
|
||||
WHERE account_id = $1`
|
||||
WHERE reddit_account_id = $1`
|
||||
|
||||
accs, err := p.fetch(ctx, query, id)
|
||||
if err != nil {
|
||||
|
@ -81,12 +87,13 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string
|
|||
}
|
||||
func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error {
|
||||
query := `
|
||||
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at)
|
||||
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
|
||||
INSERT INTO accounts (username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
||||
last_message_id, next_notification_check_at, next_stuck_notification_check_at)
|
||||
VALUES ($1, $2, $3, $4, $5, '', NOW(), NOW())
|
||||
ON CONFLICT(username) DO
|
||||
UPDATE SET access_token = $3,
|
||||
refresh_token = $4,
|
||||
expires_at = $5
|
||||
token_expires_at = $5
|
||||
RETURNING id`
|
||||
|
||||
return p.pool.QueryRow(
|
||||
|
@ -96,14 +103,15 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom
|
|||
acc.AccountID,
|
||||
acc.AccessToken,
|
||||
acc.RefreshToken,
|
||||
acc.ExpiresAt,
|
||||
acc.TokenExpiresAt,
|
||||
).Scan(&acc.ID)
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error {
|
||||
query := `
|
||||
INSERT INTO accounts
|
||||
(username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at)
|
||||
(username, reddit_account_id, access_token, refresh_token, token_expires_at,
|
||||
last_message_id, next_notification_check_at, next_stuck_notification_check_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id`
|
||||
|
||||
|
@ -114,10 +122,10 @@ func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Acco
|
|||
acc.AccountID,
|
||||
acc.AccessToken,
|
||||
acc.RefreshToken,
|
||||
acc.ExpiresAt,
|
||||
acc.TokenExpiresAt,
|
||||
acc.LastMessageID,
|
||||
acc.LastCheckedAt,
|
||||
acc.LastUnstuckAt,
|
||||
acc.NextNotificationCheckAt,
|
||||
acc.NextStuckNotificationCheckAt,
|
||||
).Scan(&acc.ID)
|
||||
}
|
||||
|
||||
|
@ -125,13 +133,14 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
|
|||
query := `
|
||||
UPDATE accounts
|
||||
SET username = $2,
|
||||
account_id = $3,
|
||||
reddit_account_id = $3,
|
||||
access_token = $4,
|
||||
refresh_token = $5,
|
||||
expires_at = $6,
|
||||
token_expires_at = $6,
|
||||
last_message_id = $7,
|
||||
last_checked_at = $8,
|
||||
last_unstuck_at = $9
|
||||
next_notification_check_at = $8,
|
||||
next_stuck_notification_check_at = $9,
|
||||
check_count = $10
|
||||
WHERE id = $1`
|
||||
|
||||
res, err := p.pool.Exec(
|
||||
|
@ -142,10 +151,11 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
|
|||
acc.AccountID,
|
||||
acc.AccessToken,
|
||||
acc.RefreshToken,
|
||||
acc.ExpiresAt,
|
||||
acc.TokenExpiresAt,
|
||||
acc.LastMessageID,
|
||||
acc.LastCheckedAt,
|
||||
acc.LastUnstuckAt,
|
||||
acc.NextNotificationCheckAt,
|
||||
acc.NextStuckNotificationCheckAt,
|
||||
acc.CheckCount,
|
||||
)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
|
@ -186,7 +196,9 @@ func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domai
|
|||
|
||||
func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token string) ([]domain.Account, error) {
|
||||
query := `
|
||||
SELECT accounts.id, username, accounts.account_id, access_token, refresh_token, accounts.expires_at, last_message_id, last_checked_at, last_unstuck_at
|
||||
SELECT accounts.id, username, accounts.reddit_account_id, access_token, refresh_token, token_expires_at,
|
||||
last_message_id, next_notification_check_at, next_stuck_notification_check_at,
|
||||
check_count
|
||||
FROM accounts
|
||||
INNER JOIN devices_accounts ON accounts.id = devices_accounts.account_id
|
||||
INNER JOIN devices ON devices.id = devices_accounts.device_id
|
||||
|
@ -195,12 +207,12 @@ func (p *postgresAccountRepository) GetByAPNSToken(ctx context.Context, token st
|
|||
return p.fetch(ctx, query, token)
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
||||
func (p *postgresAccountRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
|
||||
query := `
|
||||
DELETE FROM accounts
|
||||
WHERE expires_at < $1`
|
||||
WHERE token_expires_at < $1`
|
||||
|
||||
res, err := p.pool.Exec(ctx, query, before)
|
||||
res, err := p.pool.Exec(ctx, query, expiry)
|
||||
|
||||
return res.RowsAffected(), err
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package repository
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
|
||||
|
@ -31,8 +32,8 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
|
|||
&dev.ID,
|
||||
&dev.APNSToken,
|
||||
&dev.Sandbox,
|
||||
&dev.ActiveUntil,
|
||||
&dev.GracePeriodUntil,
|
||||
&dev.ExpiresAt,
|
||||
&dev.GracePeriodExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -43,7 +44,7 @@ func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args
|
|||
|
||||
func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domain.Device, error) {
|
||||
query := `
|
||||
SELECT id, apns_token, sandbox, active_until, grace_period_until
|
||||
SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at
|
||||
FROM devices
|
||||
WHERE id = $1`
|
||||
|
||||
|
@ -60,7 +61,7 @@ func (p *postgresDeviceRepository) GetByID(ctx context.Context, id int64) (domai
|
|||
|
||||
func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) {
|
||||
query := `
|
||||
SELECT id, apns_token, sandbox, active_until, grace_period_until
|
||||
SELECT id, apns_token, sandbox, expires_at, grace_period_expires_at
|
||||
FROM devices
|
||||
WHERE apns_token = $1`
|
||||
|
||||
|
@ -77,7 +78,7 @@ func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token str
|
|||
|
||||
func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
query := `
|
||||
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until
|
||||
SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
|
||||
FROM devices
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||
WHERE devices_accounts.account_id = $1`
|
||||
|
@ -87,7 +88,7 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
|
|||
|
||||
func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
query := `
|
||||
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until
|
||||
SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
|
||||
FROM devices
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||
WHERE devices_accounts.account_id = $1 AND
|
||||
|
@ -99,7 +100,7 @@ func (p *postgresDeviceRepository) GetInboxNotifiableByAccountID(ctx context.Con
|
|||
|
||||
func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
query := `
|
||||
SELECT devices.id, apns_token, sandbox, active_until, grace_period_until
|
||||
SELECT devices.id, apns_token, sandbox, expires_at, grace_period_expires_at
|
||||
FROM devices
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||
WHERE devices_accounts.account_id = $1 AND
|
||||
|
@ -111,10 +112,10 @@ func (p *postgresDeviceRepository) GetWatcherNotifiableByAccountID(ctx context.C
|
|||
|
||||
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
|
||||
query := `
|
||||
INSERT INTO devices (apns_token, sandbox, active_until, grace_period_until)
|
||||
INSERT INTO devices (apns_token, sandbox, expires_at, grace_period_expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT(apns_token) DO
|
||||
UPDATE SET active_until = $3, grace_period_until = $4
|
||||
UPDATE SET expires_at = $3, grace_period_expires_at = $4
|
||||
RETURNING id`
|
||||
|
||||
return p.pool.QueryRow(
|
||||
|
@ -122,8 +123,8 @@ func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *doma
|
|||
query,
|
||||
dev.APNSToken,
|
||||
dev.Sandbox,
|
||||
dev.ActiveUntil,
|
||||
dev.GracePeriodUntil,
|
||||
&dev.ExpiresAt,
|
||||
&dev.GracePeriodExpiresAt,
|
||||
).Scan(&dev.ID)
|
||||
}
|
||||
|
||||
|
@ -134,7 +135,7 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
|
|||
|
||||
query := `
|
||||
INSERT INTO devices
|
||||
(apns_token, sandbox, active_until, grace_period_until)
|
||||
(apns_token, sandbox, expires_at, grace_period_expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id`
|
||||
|
||||
|
@ -143,8 +144,8 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
|
|||
query,
|
||||
dev.APNSToken,
|
||||
dev.Sandbox,
|
||||
dev.ActiveUntil,
|
||||
dev.GracePeriodUntil,
|
||||
dev.ExpiresAt,
|
||||
dev.GracePeriodExpiresAt,
|
||||
).Scan(&dev.ID)
|
||||
}
|
||||
|
||||
|
@ -155,10 +156,10 @@ func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Devic
|
|||
|
||||
query := `
|
||||
UPDATE devices
|
||||
SET active_until = $2, grace_period_until = $3
|
||||
SET expires_at = $2, grace_period_expires_at = $3
|
||||
WHERE id = $1`
|
||||
|
||||
res, err := p.pool.Exec(ctx, query, dev.ID, dev.ActiveUntil, dev.GracePeriodUntil)
|
||||
res, err := p.pool.Exec(ctx, query, dev.ID, dev.ExpiresAt, dev.GracePeriodExpiresAt)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||
|
@ -209,10 +210,10 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai
|
|||
return inbox, watcher, global, nil
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
||||
query := `DELETE FROM devices WHERE grace_period_until < $1`
|
||||
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
|
||||
query := `DELETE FROM devices WHERE grace_period_expires_at < $1`
|
||||
|
||||
res, err := p.pool.Exec(ctx, query, before)
|
||||
res, err := p.pool.Exec(ctx, query, expiry)
|
||||
|
||||
return res.RowsAffected(), err
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
|
|||
&sr.ID,
|
||||
&sr.SubredditID,
|
||||
&sr.Name,
|
||||
&sr.LastCheckedAt,
|
||||
&sr.NextCheckAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, a
|
|||
|
||||
func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (domain.Subreddit, error) {
|
||||
query := `
|
||||
SELECT id, subreddit_id, name, last_checked_at
|
||||
SELECT id, subreddit_id, name, next_check_at
|
||||
FROM subreddits
|
||||
WHERE id = $1`
|
||||
|
||||
|
@ -58,7 +58,7 @@ func (p *postgresSubredditRepository) GetByID(ctx context.Context, id int64) (do
|
|||
|
||||
func (p *postgresSubredditRepository) GetByName(ctx context.Context, name string) (domain.Subreddit, error) {
|
||||
query := `
|
||||
SELECT id, subreddit_id, name, last_checked_at
|
||||
SELECT id, subreddit_id, name, next_check_at
|
||||
FROM subreddits
|
||||
WHERE name = $1`
|
||||
|
||||
|
@ -81,8 +81,8 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
|
|||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO subreddits (subreddit_id, name)
|
||||
VALUES ($1, $2)
|
||||
INSERT INTO subreddits (subreddit_id, name, next_check_at)
|
||||
VALUES ($1, $2, NOW())
|
||||
ON CONFLICT(subreddit_id) DO NOTHING
|
||||
RETURNING id`
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args .
|
|||
&u.ID,
|
||||
&u.UserID,
|
||||
&u.Name,
|
||||
&u.LastCheckedAt,
|
||||
&u.NextCheckAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func (p *postgresUserRepository) fetch(ctx context.Context, query string, args .
|
|||
|
||||
func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) {
|
||||
query := `
|
||||
SELECT id, user_id, name, last_checked_at
|
||||
SELECT id, user_id, name, next_check_at
|
||||
FROM users
|
||||
WHERE id = $1`
|
||||
|
||||
|
@ -60,7 +60,7 @@ func (p *postgresUserRepository) GetByID(ctx context.Context, id int64) (domain.
|
|||
|
||||
func (p *postgresUserRepository) GetByName(ctx context.Context, name string) (domain.User, error) {
|
||||
query := `
|
||||
SELECT id, user_id, name, last_checked_at
|
||||
SELECT id, user_id, name, next_check_at
|
||||
FROM users
|
||||
WHERE name = $1`
|
||||
|
||||
|
@ -83,10 +83,9 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
|
|||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO users (user_id, name)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT(user_id) DO
|
||||
UPDATE SET last_checked_at = $3
|
||||
INSERT INTO users (user_id, name, next_check_at)
|
||||
VALUES ($1, $2, NOW())
|
||||
ON CONFLICT(user_id) DO NOTHING
|
||||
RETURNING id`
|
||||
|
||||
return p.pool.QueryRow(
|
||||
|
@ -94,7 +93,7 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
|
|||
query,
|
||||
u.UserID,
|
||||
u.NormalizedName(),
|
||||
u.LastCheckedAt,
|
||||
u.NextCheckAt,
|
||||
).Scan(&u.ID)
|
||||
}
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma
|
|||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.account_id,
|
||||
accounts.reddit_account_id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token,
|
||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||
|
@ -138,7 +138,7 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t
|
|||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.account_id,
|
||||
accounts.reddit_account_id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token,
|
||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||
|
@ -191,7 +191,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
|
|||
devices.apns_token,
|
||||
devices.sandbox,
|
||||
accounts.id,
|
||||
accounts.account_id,
|
||||
accounts.reddit_account_id,
|
||||
accounts.access_token,
|
||||
accounts.refresh_token,
|
||||
COALESCE(subreddits.name, '') AS subreddit_label,
|
||||
|
@ -203,7 +203,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
|
|||
LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id
|
||||
WHERE
|
||||
devices.apns_token = $1 AND
|
||||
accounts.account_id = $2`
|
||||
accounts.reddit_account_id = $2`
|
||||
|
||||
return p.fetch(ctx, query, apns, rid)
|
||||
}
|
||||
|
|
|
@ -166,7 +166,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
|
||||
defer func() { _ = delivery.Ack() }()
|
||||
|
||||
now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000
|
||||
now := time.Now()
|
||||
|
||||
account, err := nc.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
|
@ -177,20 +177,22 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
previousLastCheckedAt := account.LastCheckedAt
|
||||
newAccount := (previousLastCheckedAt == 0)
|
||||
account.LastCheckedAt = now
|
||||
newAccount := account.CheckCount == 0
|
||||
previousNextCheck := account.NextNotificationCheckAt
|
||||
|
||||
account.CheckCount++
|
||||
account.NextNotificationCheckAt = time.Now().Add(domain.NotificationCheckInterval)
|
||||
|
||||
if err = nc.accountRepo.Update(ctx, &account); err != nil {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"account#username": account.NormalizedUsername(),
|
||||
"err": err,
|
||||
}).Error("failed to update last_checked_at for account")
|
||||
}).Error("failed to update next_notification_check_at for account")
|
||||
return
|
||||
}
|
||||
|
||||
rac := nc.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
|
||||
if account.ExpiresAt < int64(now) {
|
||||
if account.TokenExpiresAt.Before(now) {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"account#username": account.NormalizedUsername(),
|
||||
}).Debug("refreshing reddit token")
|
||||
|
@ -219,7 +221,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
// Update account
|
||||
account.AccessToken = tokens.AccessToken
|
||||
account.RefreshToken = tokens.RefreshToken
|
||||
account.ExpiresAt = int64(now + 3540)
|
||||
account.TokenExpiresAt = now.Add(tokens.Expiry)
|
||||
|
||||
// Refresh client
|
||||
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
|
||||
|
@ -236,8 +238,8 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
// Only update delay on accounts we can actually check, otherwise it skews
|
||||
// the numbers too much.
|
||||
if !newAccount {
|
||||
latency := now - previousLastCheckedAt - float64(backoff)
|
||||
_ = nc.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
|
||||
latency := now.Sub(previousNextCheck) - backoff*time.Second
|
||||
_ = nc.statsd.Histogram("apollo.queue.delay", float64(latency.Milliseconds()), []string{}, rate)
|
||||
}
|
||||
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
|
|
|
@ -180,7 +180,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
threshold := float64(time.Now().AddDate(0, 0, -1).UTC().Unix())
|
||||
threshold := time.Now().Add(-24 * time.Hour)
|
||||
posts := []*reddit.Thing{}
|
||||
before := ""
|
||||
finished := false
|
||||
|
@ -237,7 +237,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
}
|
||||
|
||||
for _, post := range sps.Children {
|
||||
if post.CreatedAt < threshold {
|
||||
if post.CreatedAt.Before(threshold) {
|
||||
finished = true
|
||||
break
|
||||
}
|
||||
|
@ -287,7 +287,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
}).Debug("loaded hot posts")
|
||||
|
||||
for _, post := range sps.Children {
|
||||
if post.CreatedAt < threshold {
|
||||
if post.CreatedAt.Before(threshold) {
|
||||
break
|
||||
}
|
||||
if _, ok := seenPosts[post.ID]; !ok {
|
||||
|
@ -313,7 +313,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
|
|||
|
||||
for _, watcher := range watchers {
|
||||
// Make sure we only alert on posts created after the search
|
||||
if watcher.CreatedAt > post.CreatedAt {
|
||||
if watcher.CreatedAt.After(post.CreatedAt) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -235,14 +235,14 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
|||
}).Debug("loaded hot posts")
|
||||
|
||||
// Trending only counts for posts less than 2 days old
|
||||
threshold := float64(time.Now().AddDate(0, 0, -2).UTC().Unix())
|
||||
threshold := time.Now().Add(-24 * time.Hour * 2)
|
||||
|
||||
for _, post := range hps.Children {
|
||||
if post.Score < medianScore {
|
||||
continue
|
||||
}
|
||||
|
||||
if post.CreatedAt < threshold {
|
||||
if post.CreatedAt.Before(threshold) {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -251,7 +251,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
|
|||
notification.Payload = payloadFromTrendingPost(post)
|
||||
|
||||
for _, watcher := range watchers {
|
||||
if watcher.CreatedAt > post.CreatedAt {
|
||||
if watcher.CreatedAt.After(post.CreatedAt) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -234,11 +234,11 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
|
|||
|
||||
for _, watcher := range watchers {
|
||||
// Make sure we only alert on activities created after the search
|
||||
if watcher.CreatedAt > post.CreatedAt {
|
||||
if watcher.CreatedAt.After(post.CreatedAt) {
|
||||
continue
|
||||
}
|
||||
|
||||
if watcher.LastNotifiedAt > post.CreatedAt {
|
||||
if watcher.LastNotifiedAt.After(post.CreatedAt) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue