mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-26 05:37:42 +00:00
enable/disable notifications per account/device pair
This commit is contained in:
parent
70d73eab4c
commit
815b577bf5
6 changed files with 69 additions and 3 deletions
|
@ -13,6 +13,43 @@ import (
|
||||||
"github.com/christianselig/apollo-backend/internal/domain"
|
"github.com/christianselig/apollo-backend/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type accountNotificationsRequest struct {
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
anr := &accountNotificationsRequest{}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(anr); err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
apns := vars["apns"]
|
||||||
|
rid := vars["redditID"]
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
|
||||||
|
if err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.Enabled); err != nil {
|
||||||
|
a.errorResponse(w, r, 500, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
apns := vars["apns"]
|
apns := vars["apns"]
|
||||||
|
|
|
@ -92,6 +92,7 @@ func (a *api) Routes() *mux.Router {
|
||||||
r.HandleFunc("/v1/device/{apns}/account", a.upsertAccountHandler).Methods("POST")
|
r.HandleFunc("/v1/device/{apns}/account", a.upsertAccountHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
r.HandleFunc("/v1/device/{apns}/accounts", a.upsertAccountsHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}", a.disassociateAccountHandler).Methods("DELETE")
|
||||||
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/notifications", a.notificationsAccountHandler).Methods("PATCH")
|
||||||
|
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher", a.createWatcherHandler).Methods("POST")
|
||||||
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
r.HandleFunc("/v1/device/{apns}/account/{redditID}/watcher/{watcherID}", a.deleteWatcherHandler).Methods("DELETE")
|
||||||
|
|
|
@ -27,12 +27,14 @@ func (dev *Device) Validate() error {
|
||||||
type DeviceRepository interface {
|
type DeviceRepository interface {
|
||||||
GetByID(ctx context.Context, id int64) (Device, error)
|
GetByID(ctx context.Context, id int64) (Device, error)
|
||||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||||
|
GetNotifiableByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
GetByAccountID(ctx context.Context, id int64) ([]Device, error)
|
||||||
|
|
||||||
CreateOrUpdate(ctx context.Context, dev *Device) error
|
CreateOrUpdate(ctx context.Context, dev *Device) error
|
||||||
Update(ctx context.Context, dev *Device) error
|
Update(ctx context.Context, dev *Device) error
|
||||||
Create(ctx context.Context, dev *Device) error
|
Create(ctx context.Context, dev *Device) error
|
||||||
Delete(ctx context.Context, token string) error
|
Delete(ctx context.Context, token string) error
|
||||||
|
SetNotifiable(ctx context.Context, dev *Device, acct *Account, notifiable bool) error
|
||||||
|
|
||||||
PruneStale(ctx context.Context, before int64) (int64, error)
|
PruneStale(ctx context.Context, before int64) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
|
||||||
|
|
||||||
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
|
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at
|
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at, last_unstuck_at
|
||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE account_id = $1`
|
WHERE account_id = $1`
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,16 @@ func (p *postgresDeviceRepository) GetByAccountID(ctx context.Context, id int64)
|
||||||
return p.fetch(ctx, query, id)
|
return p.fetch(ctx, query, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postgresDeviceRepository) GetNotifiableByAccountID(ctx context.Context, id int64) ([]domain.Device, error) {
|
||||||
|
query := `
|
||||||
|
SELECT devices.id, apns_token, sandbox, active_until
|
||||||
|
FROM devices
|
||||||
|
INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id
|
||||||
|
WHERE devices_accounts.account_id = $1 AND devices_accounts.notifiable = TRUE`
|
||||||
|
|
||||||
|
return p.fetch(ctx, query, id)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
|
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO devices (apns_token, sandbox, active_until)
|
INSERT INTO devices (apns_token, sandbox, active_until)
|
||||||
|
@ -150,6 +160,21 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, notifiable bool) error {
|
||||||
|
query := `
|
||||||
|
UPDATE devices_accounts
|
||||||
|
SET notifiable = $1
|
||||||
|
WHERE device_id = $2 AND account_id = $3`
|
||||||
|
|
||||||
|
res, err := p.pool.Exec(ctx, query, notifiable, dev.ID, acct.ID)
|
||||||
|
|
||||||
|
if res.RowsAffected() != 1 {
|
||||||
|
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) {
|
||||||
query := `DELETE FROM devices WHERE active_until < $1`
|
query := `DELETE FROM devices WHERE active_until < $1`
|
||||||
|
|
||||||
|
|
|
@ -204,8 +204,9 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
"account#username": account.NormalizedUsername(),
|
"account#username": account.NormalizedUsername(),
|
||||||
"err": err,
|
"err": err,
|
||||||
}).Error("failed to remove revoked account")
|
}).Error("failed to remove revoked account")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update account
|
// Update account
|
||||||
|
@ -303,7 +304,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
devices, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
|
devices, err := nc.deviceRepo.GetNotifiableByAccountID(ctx, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nc.logger.WithFields(logrus.Fields{
|
nc.logger.WithFields(logrus.Fields{
|
||||||
"account#username": account.NormalizedUsername(),
|
"account#username": account.NormalizedUsername(),
|
||||||
|
|
Loading…
Reference in a new issue