diff --git a/internal/cmd/scheduler.go b/internal/cmd/scheduler.go index 1ac55fc..5600450 100644 --- a/internal/cmd/scheduler.go +++ b/internal/cmd/scheduler.go @@ -14,7 +14,6 @@ import ( "github.com/adjust/rmq/v4" "github.com/go-co-op/gocron" "github.com/go-redis/redis/v8" - "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/spf13/cobra" "go.uber.org/zap" @@ -219,36 +218,29 @@ func enqueueUsers(ctx context.Context, logger *zap.Logger, statsd *statsd.Client _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := ` - UPDATE users - SET next_check_at = $2 - WHERE id IN ( - SELECT id - FROM users - WHERE next_check_at < $1 - ORDER BY next_check_at - FOR UPDATE SKIP LOCKED - LIMIT 100 - ) - RETURNING users.id` - rows, err := tx.Query(ctx, stmt, now, next) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var id int64 - _ = rows.Scan(&id) - ids = append(ids, id) - } - return nil - }) - + stmt := ` + UPDATE users + SET next_check_at = $2 + WHERE id IN ( + SELECT id + FROM users + WHERE next_check_at < $1 + ORDER BY next_check_at + FOR UPDATE SKIP LOCKED + LIMIT 100 + ) + RETURNING users.id` + rows, err := pool.Query(ctx, stmt, now, next) if err != nil { logger.Error("failed to fetch batch of users", zap.Error(err)) return } + for rows.Next() { + var id int64 + _ = rows.Scan(&id) + ids = append(ids, id) + } + rows.Close() if len(ids) == 0 { return @@ -278,8 +270,7 @@ func enqueueSubreddits(ctx context.Context, logger *zap.Logger, statsd *statsd.C _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := ` + stmt := ` UPDATE subreddits SET next_check_at = $2 WHERE subreddits.id IN( @@ -291,23 +282,17 @@ func enqueueSubreddits(ctx context.Context, logger *zap.Logger, statsd *statsd.C LIMIT 100 ) RETURNING subreddits.id` - rows, err := tx.Query(ctx, stmt, now, next) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var id int64 - _ = rows.Scan(&id) - ids = append(ids, id) - } - return nil - }) - + rows, err := pool.Query(ctx, stmt, now, next) if err != nil { logger.Error("failed to fetch batch of subreddits", zap.Error(err)) return } + for rows.Next() { + var id int64 + _ = rows.Scan(&id) + ids = append(ids, id) + } + rows.Close() if len(ids) == 0 { return @@ -340,8 +325,7 @@ func enqueueStuckAccounts(ctx context.Context, logger *zap.Logger, statsd *stats _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := ` + stmt := ` UPDATE accounts SET next_stuck_notification_check_at = $2 WHERE accounts.id IN( @@ -353,24 +337,19 @@ func enqueueStuckAccounts(ctx context.Context, logger *zap.Logger, statsd *stats LIMIT 500 ) RETURNING accounts.id` - rows, err := tx.Query(ctx, stmt, now, next) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var id int64 - _ = rows.Scan(&id) - ids = append(ids, id) - } - return nil - }) - + rows, err := pool.Query(ctx, stmt, now, next) if err != nil { logger.Error("failed to fetch accounts", zap.Error(err)) return } + for rows.Next() { + var id int64 + _ = rows.Scan(&id) + ids = append(ids, id) + } + rows.Close() + if len(ids) == 0 { return } @@ -403,8 +382,7 @@ func enqueueAccounts(ctx context.Context, logger *zap.Logger, statsd *statsd.Cli _ = statsd.Histogram("apollo.queue.runtime", float64(time.Since(now).Milliseconds()), tags, 1) }() - err := pool.BeginFunc(ctx, func(tx pgx.Tx) error { - stmt := fmt.Sprintf(` + stmt := fmt.Sprintf(` UPDATE accounts SET next_notification_check_at = $2 WHERE accounts.id IN( @@ -416,22 +394,16 @@ func enqueueAccounts(ctx context.Context, logger *zap.Logger, statsd *statsd.Cli LIMIT %d ) RETURNING accounts.reddit_account_id`, maxNotificationChecks) - rows, err := tx.Query(ctx, stmt, now, next) - if err != nil { - return err - } - defer rows.Close() - for i := 0; rows.Next(); i++ { - _ = rows.Scan(&ids[i]) - idslen = i - } - return nil - }) - + rows, err := pool.Query(ctx, stmt, now, next) if err != nil { logger.Error("failed to fetch batch of accounts", zap.Error(err)) return } + for i := 0; rows.Next(); i++ { + _ = rows.Scan(&ids[i]) + idslen = i + } + rows.Close() if idslen == 0 { return diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 917b174..57376da 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -356,47 +356,56 @@ func (rac *AuthenticatedClient) logRequest() error { return nil } - return rac.client.redis.HIncrBy(context.Background(), "reddit:requests", rac.redditId, 1).Err() + return nil + // return rac.client.redis.HIncrBy(context.Background(), "reddit:requests", rac.redditId, 1).Err() } func (rac *AuthenticatedClient) isRateLimited() bool { - if rac.redditId == SkipRateLimiting { - return false - } + return false - key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) - _, err := rac.client.redis.Get(context.Background(), key).Result() - return err != redis.Nil + /* + if rac.redditId == SkipRateLimiting { + return false + } + + key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) + _, err := rac.client.redis.Get(context.Background(), key).Result() + return err != redis.Nil + */ } func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error { - if rac.redditId == SkipRateLimiting { - return ErrRequiresRedditId - } + return nil - if !rli.Present { - return nil - } - - if rli.Remaining > RequestRemainingBuffer { - return nil - } - - _ = rac.client.statsd.Incr("reddit.api.ratelimit", nil, 1.0) - - key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) - duration := time.Duration(rli.Reset) * time.Second - info := fmt.Sprintf("%+v", *rli) - - if rli.Used > 2000 { - _, err := rac.client.redis.HSet(context.Background(), "reddit:ratelimited:crazy", rac.redditId, info).Result() - if err != nil { - return err + /* + if rac.redditId == SkipRateLimiting { + return ErrRequiresRedditId } - } - _, err := rac.client.redis.SetEX(context.Background(), key, info, duration).Result() - return err + if !rli.Present { + return nil + } + + if rli.Remaining > RequestRemainingBuffer { + return nil + } + + _ = rac.client.statsd.Incr("reddit.api.ratelimit", nil, 1.0) + + key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) + duration := time.Duration(rli.Reset) * time.Second + info := fmt.Sprintf("%+v", *rli) + + if rli.Used > 2000 { + _, err := rac.client.redis.HSet(context.Background(), "reddit:ratelimited:crazy", rac.redditId, info).Result() + if err != nil { + return err + } + } + + _, err := rac.client.redis.SetEX(context.Background(), key, info, duration).Result() + return err + */ } func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) {