diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 85f1341..5b19018 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -17,6 +17,7 @@ import ( type accountNotificationsRequest struct { InboxNotifications bool `json:"inbox_notifications"` WatcherNotifications bool `json:"watcher_notifications"` + GlobalMute bool `json:"global_mute"` } func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) { @@ -44,7 +45,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request return } - if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications); err != nil { + if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications, anr.GlobalMute); err != nil { a.errorResponse(w, r, 500, err.Error()) return } @@ -71,7 +72,7 @@ func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Requ return } - inbox, watchers, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct) + inbox, watchers, global, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct) if err != nil { a.errorResponse(w, r, 500, err.Error()) return @@ -79,7 +80,7 @@ func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Requ w.WriteHeader(http.StatusOK) - an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers} + an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers, GlobalMute: global} _ = json.NewEncoder(w).Encode(an) } diff --git a/internal/domain/device.go b/internal/domain/device.go index 62c3420..a002ba4 100644 --- a/internal/domain/device.go +++ b/internal/domain/device.go @@ -35,8 +35,8 @@ type DeviceRepository interface { Update(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error Delete(ctx context.Context, token string) error - SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher bool) error - GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, error) + SetNotifiable(ctx context.Context, dev *Device, acct *Account, inbox, watcher, global bool) error + GetNotifiable(ctx context.Context, dev *Device, acct *Account) (bool, bool, bool, error) PruneStale(ctx context.Context, before int64) (int64, error) } diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go index 4e2ccdb..2da3a1d 100644 --- a/internal/repository/postgres_device.go +++ b/internal/repository/postgres_device.go @@ -170,15 +170,16 @@ func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) err return err } -func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher bool) error { +func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account, inbox, watcher, global bool) error { query := ` UPDATE devices_accounts SET inbox_notifiable = $1, - watcher_notifiable = $2 - WHERE device_id = $3 AND account_id = $4` + watcher_notifiable = $2, + global_mute = $3 + WHERE device_id = $4 AND account_id = $5` - res, err := p.pool.Exec(ctx, query, inbox, watcher, dev.ID, acct.ID) + res, err := p.pool.Exec(ctx, query, inbox, watcher, global, dev.ID, acct.ID) if res.RowsAffected() != 1 { return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) @@ -187,26 +188,26 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai } -func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, error) { +func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domain.Device, acct *domain.Account) (bool, bool, bool, error) { query := ` - SELECT inbox_notifiable, watcher_notifiable + SELECT inbox_notifiable, watcher_notifiable, global_mute FROM devices_accounts WHERE device_id = $1 AND account_id = $2` rows, err := p.pool.Query(ctx, query, dev.ID, acct.ID) if err != nil { - return false, false, err + return false, false, false, err } defer rows.Close() for rows.Next() { - var inbox, watcher bool - if err := rows.Scan(&inbox, &watcher); err != nil { - return false, false, err + var inbox, watcher, global bool + if err := rows.Scan(&inbox, &watcher, &global); err != nil { + return false, false, false, err } - return inbox, watcher, nil + return inbox, watcher, global, nil } - return false, false, domain.ErrNotFound + return false, false, false, domain.ErrNotFound } func (p *postgresDeviceRepository) PruneStale(ctx context.Context, before int64) (int64, error) { diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index a841e77..a9a6f79 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -149,7 +149,10 @@ func (p *postgresWatcherRepository) GetByTypeAndWatcheeID(ctx context.Context, t INNER JOIN devices_accounts ON devices.id = devices_accounts.device_id AND accounts.id = devices_accounts.account_id LEFT JOIN subreddits ON watchers.type IN(0,2) AND watchers.watchee_id = subreddits.id LEFT JOIN users ON watchers.type = 1 AND watchers.watchee_id = users.id - WHERE watchers.type = $1 AND watchers.watchee_id = $2 AND devices_accounts.watcher_notifiable = TRUE` + WHERE watchers.type = $1 AND + watchers.watchee_id = $2 AND + devices_accounts.watcher_notifiable = TRUE AND + devices_accounts.global_mute = FALSE` return p.fetch(ctx, query, typ, id) }