mirror of
https://github.com/christianselig/apollo-backend
synced 2024-12-22 22:35:27 +00:00
enable/disable notifications per account/device pair
This commit is contained in:
parent
70d73eab4c
commit
815b577bf5
6 changed files with 69 additions and 3 deletions
|
@ -13,6 +13,43 @@ import (
|
|||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
)
|
||||
|
||||
type accountNotificationsRequest struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
anr := &accountNotificationsRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(anr); err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
apns := vars["apns"]
|
||||
rid := vars["redditID"]
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
|
||||
if err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil {
|
||||
a.errorResponse(w, r, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
apns := vars["apns"]
|
||||
|
|
|
@ -92,6 +92,7 @@ func (a *api) Routes() *mux.Router {
|
|||
r.HandleFunc("/v1/device/{apns}/account", a.upsertAccountHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
|
||||
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
||||
|
|
|
@ -27,12 +27,14 @@ func (dev *Device) Validate() error {
|
|||
type DeviceRepository interface {
|
||||
GetByID(ctx context.Context, id int64) (Device, error)
|
||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||
|
||||
CreateOrUpdate(ctx context.Context, dev *Device) error
|
||||
Update(ctx context.Context, dev *Device) error
|
||||
Create(ctx context.Context, dev *Device) error
|
||||
Delete(ctx context.Context, token string) error
|
||||
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error
|
||||
|
||||
PruneStale(ctx context.Context, before int64) (int64, error)
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
|
|||
|
||||
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at
|
||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
|
||||
FROM accounts
|
||||
WHERE account_id = $1`
|
||||
|
||||
|
|
|
@ -84,6 +84,16 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
|
|||
return p.fetch(ctx, query, id)
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||
query := `
|
||||
SELECT devices.id, apns_token, sandbox, active_until
|
||||
FROM devices
|
||||
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE`
|
||||
|
||||
return p.fetch(ctx, query, id)
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
|
||||
query := `
|
||||
INSERT INTO devices (apns_token, sandbox, active_until)
|
||||
|
@ -150,6 +160,21 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
|
|||
return err
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error {
|
||||
query := `
|
||||
UPDATE devices_accounts
|
||||
SET notifiable = $1
|
||||
WHERE device_id = $2 AND account_id = $3`
|
||||
|
||||
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
||||
query := `DELETE FROM devices WHERE active_until < $1`
|
||||
|
||||
|
|
|
@ -204,8 +204,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
"account#username": account.NormalizedUsername(),
|
||||
"err": err,
|
||||
}).Error("failed to remove revoked account")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Update account
|
||||
|
@ -303,7 +304,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
|
||||
devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"account#username": account.NormalizedUsername(),
|
||||
|
|
Loading…
Reference in a new issue