enable/disable notifications per account/device pair

This commit is contained in:
Andre Medeiros 2021-10-17 12:41:12 -04:00
parent 70d73eab4c
commit 815b577bf5
6 changed files with 69 additions and 3 deletions

View file

@ -13,6 +13,43 @@ import (
"github.com/christianselig/apollo-backend/internal/domain" "github.com/christianselig/apollo-backend/internal/domain"
) )
type accountNotificationsRequest struct {
Enabled bool
}
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
anr := &accountNotificationsRequest{}
if err := json.NewDecoder(r.Body).Decode(anr); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
vars := mux.Vars(r)
apns := vars["apns"]
rid := vars["redditID"]
ctx := context.Background()
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}
w.WriteHeader(http.StatusOK)
}
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) { func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
apns := vars["apns"] apns := vars["apns"]

View file

@ -92,6 +92,7 @@ func (a *api) Routes() *mux.Router {
r.HandleFunc("/v1/device/{apns}/account", a.upsertAccountHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account", a.upsertAccountHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE") r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")

View file

@ -27,12 +27,14 @@ func (dev *Device) Validate() error {
type DeviceRepository interface { type DeviceRepository interface {
GetByID(ctx context.Context, id int64) (Device, error) GetByID(ctx context.Context, id int64) (Device, error)
GetByAPNSToken(ctx context.Context, token string) (Device, error) GetByAPNSToken(ctx context.Context, token string) (Device, error)
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
GetByAccountID(ctx context.Context, id int64) ([]Device, error) GetByAccountID(ctx context.Context, id int64) ([]Device, error)
CreateOrUpdate(ctx context.Context, dev *Device) error CreateOrUpdate(ctx context.Context, dev *Device) error
Update(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error
Create(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error
Delete(ctx context.Context, token string) error Delete(ctx context.Context, token string) error
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error
PruneStale(ctx context.Context, before int64) (int64, error) PruneStale(ctx context.Context, before int64) (int64, error)
} }

View file

@ -64,7 +64,7 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) { func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
query := ` query := `
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
FROM accounts FROM accounts
WHERE account_id = $1` WHERE account_id = $1`

View file

@ -84,6 +84,16 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
return p.fetch(ctx, query, id) return p.fetch(ctx, query, id)
} }
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
query := `
SELECT devices.id, apns_token, sandbox, active_until
FROM devices
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE`
return p.fetch(ctx, query, id)
}
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
query := ` query := `
INSERT INTO devices (apns_token, sandbox, active_until) INSERT INTO devices (apns_token, sandbox, active_until)
@ -150,6 +160,21 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
return err return err
} }
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error {
query := `
UPDATE devices_accounts
SET notifiable = $1
WHERE device_id = $2 AND account_id = $3`
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
}
return err
}
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
query := `DELETE FROM devices WHERE active_until < $1` query := `DELETE FROM devices WHERE active_until < $1`

View file

@ -204,8 +204,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),
"err": err, "err": err,
}).Error("failed to remove revoked account") }).Error("failed to remove revoked account")
return
} }
return
} }
// Update account // Update account
@ -303,7 +304,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID) devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID)
if err != nil { if err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(), "account#username": account.NormalizedUsername(),