diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 84499d8..7dbcd04 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -13,6 +13,43 @@ import ( "github.com/christianselig/apollo-backend/internal/domain" ) +type accountNotificationsRequest struct { + Enabled bool +} + +func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) { + anr := &accountNotificationsRequest{} + if err := json.NewDecoder(r.Body).Decode(anr); err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + vars := mux.Vars(r) + apns := vars["apns"] + rid := vars["redditID"] + + ctx := context.Background() + + dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) + if err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + acct, err := a.accountRepo.GetByRedditID(ctx, rid) + if err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil { + a.errorResponse(w, r, 500, err.Error()) + return + } + + w.WriteHeader(http.StatusOK) +} + func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) apns := vars["apns"] diff --git a/internal/api/api.go b/internal/api/api.go index f7a7ec3..c505fc9 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -92,6 +92,7 @@ func (a *api) Routes() *mux.Router { r.HandleFunc("/v1/device/{apns}/account", a.upsertAccountHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE") + r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE") diff --git a/internal/domain/device.go b/internal/domain/device.go index d7f2e11..f37cd3d 100644 --- a/internal/domain/device.go +++ b/internal/domain/device.go @@ -27,12 +27,14 @@ func (dev *Device) Validate() error { type DeviceRepository interface { GetByID(ctx context.Context, id int64) (Device, error) GetByAPNSToken(ctx context.Context, token string) (Device, error) + GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error) GetByAccountID(ctx context.Context, id int64) ([]Device, error) CreateOrUpdate(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error Delete(ctx context.Context, token string) error + SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error PruneStale(ctx context.Context, before int64) (int64, error) } diff --git a/internal/repository/postgres_account.go b/internal/repository/postgres_account.go index 9db52e6..5321328 100644 --- a/internal/repository/postgres_account.go +++ b/internal/repository/postgres_account.go @@ -64,7 +64,7 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) { query := ` - SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at + SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at FROM accounts WHERE account_id = $1` diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go index 5f3252e..99f0d07 100644 --- a/internal/repository/postgres_device.go +++ b/internal/repository/postgres_device.go @@ -84,6 +84,16 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64) return p.fetch(ctx, query, id) } +func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) { + query := ` + SELECT devices.id, apns_token, sandbox, active_until + FROM devices + INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id + WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE` + + return p.fetch(ctx, query, id) +} + func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { query := ` INSERT INTO devices (apns_token, sandbox, active_until) @@ -150,6 +160,21 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err return err } +func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error { + query := ` + UPDATE devices_accounts + SET notifiable = $1 + WHERE device_id = $2 AND account_id = $3` + + res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID) + + if res.RowsAffected() != 1 { + return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) + } + return err + +} + func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { query := `DELETE FROM devices WHERE active_until < $1` diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index 79e2b18..d5e7c3b 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -204,8 +204,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { "account#username": account.NormalizedUsername(), "err": err, }).Error("failed to remove revoked account") - return } + + return } // Update account @@ -303,7 +304,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { return } - devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID) + devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID) if err != nil { nc.logger.WithFields(logrus.Fields{ "account#username": account.NormalizedUsername(),