diff --git a/internal/domain/user.go b/internal/domain/user.go index c8b6f8f..44dabc0 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -23,4 +23,5 @@ type UserRepository interface { GetByName(context.Context, string) (User, error) CreateOrUpdate(context.Context, *User) error + Delete(context.Context, int64) error } diff --git a/internal/domain/watcher.go b/internal/domain/watcher.go index 99f20b7..0655e4f 100644 --- a/internal/domain/watcher.go +++ b/internal/domain/watcher.go @@ -10,8 +10,9 @@ const ( ) type Watcher struct { - ID int64 - CreatedAt float64 + ID int64 + CreatedAt float64 + LastNotifiedAt float64 DeviceID int64 AccountID int64 @@ -35,4 +36,5 @@ type WatcherRepository interface { Update(ctx context.Context, watcher *Watcher) error IncrementHits(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error + DeleteByTypeAndWatcheeID(context.Context, WatcherType, int64) error } diff --git a/internal/repository/postgres_user.go b/internal/repository/postgres_user.go index 283f940..5f038ad 100644 --- a/internal/repository/postgres_user.go +++ b/internal/repository/postgres_user.go @@ -2,6 +2,7 @@ package repository import ( "context" + "fmt" "strings" "github.com/christianselig/apollo-backend/internal/domain" @@ -91,3 +92,13 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U u.LastCheckedAt, ).Scan(&u.ID) } + +func (p *postgresUserRepository) Delete(ctx context.Context, id int64) error { + query := `DELETE FROM users WHERE id = $1` + res, err := p.pool.Exec(ctx, query, id) + + if res.RowsAffected() != 1 { + return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) + } + return err +} diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index 92d27e1..0dc47c4 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -31,6 +31,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg if err := rows.Scan( &watcher.ID, &watcher.CreatedAt, + &watcher.LastNotifiedAt, &watcher.DeviceID, &watcher.AccountID, &watcher.Type, @@ -50,7 +51,7 @@ func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, arg func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (domain.Watcher, error) { query := ` - SELECT id, created_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits + SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits FROM watchers WHERE id = $1` @@ -67,7 +68,7 @@ func (p *postgresWatcherRepository) GetByID(ctx context.Context, id int64) (doma func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) ([]domain.Watcher, error) { query := ` - SELECT id, created_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits + SELECT id, created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain, hits FROM watchers WHERE type = $1 AND watchee_id = $2` @@ -87,6 +88,7 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c SELECT watchers.id, watchers.created_at, + watchers.last_notified_at watchers.device_id, watchers.account_id, watchers.type, @@ -111,8 +113,8 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain. query := ` INSERT INTO watchers - (created_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + (created_at, last_notified_at, device_id, account_id, type, watchee_id, upvotes, keyword, flair, domain) + VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id` return p.pool.QueryRow( @@ -156,8 +158,9 @@ func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain. } func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error { - query := `UPDATE watchers SET hits = hits + 1 WHERE id = $1` - res, err := p.pool.Exec(ctx, query, id) + now := time.Now().Unix() + query := `UPDATE watchers SET hits = hits + 1, last_notified_at = $2 WHERE id = $1` + res, err := p.pool.Exec(ctx, query, id, now) if res.RowsAffected() != 1 { return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) @@ -174,3 +177,13 @@ func (p *postgresWatcherRepository) Delete(ctx context.Context, id int64) error } return err } + +func (p *postgresWatcherRepository) DeleteByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) error { + query := `DELETE FROM watchers WHERE type = $1 AND watchee_id = $2` + res, err := p.pool.Exec(ctx, query, typ, id) + + if res.RowsAffected() == 0 { + return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) + } + return err +} diff --git a/internal/worker/subreddits.go b/internal/worker/subreddits.go index 42488c9..4ab9cbd 100644 --- a/internal/worker/subreddits.go +++ b/internal/worker/subreddits.go @@ -332,8 +332,6 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { continue } - _ = sc.watcherRepo.IncrementHits(ctx, watcher.ID) - lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID) notified, _ := sc.redis.Get(ctx, lockKey).Bool() @@ -348,6 +346,15 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) { continue } + if err := sc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil { + sc.logger.WithFields(logrus.Fields{ + "subreddit#id": subreddit.ID, + "watcher#id": watcher.ID, + "err": err, + }).Error("could not increment hits") + return + } + sc.logger.WithFields(logrus.Fields{ "subreddit#id": subreddit.ID, "subreddit#name": subreddit.Name, diff --git a/internal/worker/users.go b/internal/worker/users.go index eda8f3a..ddc10eb 100644 --- a/internal/worker/users.go +++ b/internal/worker/users.go @@ -168,7 +168,6 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { if len(watchers) == 0 { uc.logger.WithFields(logrus.Fields{ "user#id": user.ID, - "err": err, }).Info("no watchers for user, skipping") return } @@ -180,6 +179,37 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID) rac := uc.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken) + ru, err := rac.UserAbout(user.Name) + if err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "err": err, + }).Error("failed to fetch user details") + return + } + + if !ru.AcceptFollowers { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + }).Info("user disabled followers, removing") + + if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(ctx, domain.UserWatcher, user.ID); err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "err": err, + }).Error("failed to delete watchers for user who does not allow followers") + return + } + + if err := uc.userRepo.Delete(ctx, user.ID); err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "err": err, + }).Error("failed to delete user") + return + } + } + posts, err := rac.UserPosts(user.Name) if err != nil { uc.logger.WithFields(logrus.Fields{ @@ -190,10 +220,6 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { } for _, post := range posts.Children { - if post.CreatedAt < user.LastCheckedAt { - break - } - notification := &apns2.Notification{} notification.Topic = "com.christianselig.Apollo" notification.Payload = payloadFromUserPost(post) @@ -203,6 +229,20 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) { if watcher.CreatedAt > post.CreatedAt { continue } + + if watcher.LastNotifiedAt > post.CreatedAt { + continue + } + + if err := uc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil { + uc.logger.WithFields(logrus.Fields{ + "user#id": user.ID, + "watcher#id": watcher.ID, + "err": err, + }).Error("could not increment hits") + return + } + device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID) notification.DeviceToken = device.APNSToken