apollo-backend/internal/api/accounts.go

288 lines
7 KiB
Go
Raw Normal View History

package api
import (
2022-11-01 02:33:11 +00:00
"context"
"encoding/json"
2021-08-08 18:19:47 +00:00
"fmt"
"net/http"
"time"
2021-08-08 18:19:47 +00:00
"github.com/gorilla/mux"
2022-05-23 18:17:25 +00:00
"go.uber.org/zap"
2021-08-14 17:56:03 +00:00
"github.com/christianselig/apollo-backend/internal/domain"
2022-03-12 17:50:05 +00:00
"github.com/christianselig/apollo-backend/internal/reddit"
)
// accountNotificationsRequest holds the three per-account notification flags
// of a device. It is decoded from the request body when the flags are set and
// encoded as the response body when they are read back; the JSON keys below
// are the wire format.
type accountNotificationsRequest struct {
	InboxNotifications   bool `json:"inbox_notifications"`
	WatcherNotifications bool `json:"watcher_notifications"`
	GlobalMute           bool `json:"global_mute"`
}
// notificationsAccountHandler stores the notification flags (inbox, watcher,
// global mute) for the pair formed by the device named by the "apns" route
// variable and the account named by the "redditID" route variable.
//
// The request body is decoded into accountNotificationsRequest; a body that
// cannot be decoded is answered with 422. Failures looking up the device or
// account, or persisting the flags via deviceRepo.SetNotifiable, are answered
// with 500. On success the response is an empty 200.
func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	anr := &accountNotificationsRequest{}
	if err := json.NewDecoder(r.Body).Decode(anr); err != nil {
		// An unparseable body is the client's fault, not a server error.
		a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
		return
	}

	vars := mux.Vars(r)
	apns := vars["apns"]
	rid := vars["redditID"]

	dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	acct, err := a.accountRepo.GetByRedditID(ctx, rid)
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications, anr.GlobalMute); err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	w.WriteHeader(http.StatusOK)
}
2022-03-12 17:50:05 +00:00
// getNotificationsAccountHandler returns the notification flags stored for
// the pair formed by the device named by the "apns" route variable and the
// account named by the "redditID" route variable.
//
// Lookup failures and deviceRepo.GetNotifiable failures are answered with
// 500. On success it responds 200 with a JSON-encoded
// accountNotificationsRequest body.
func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	vars := mux.Vars(r)
	apns := vars["apns"]
	rid := vars["redditID"]

	dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	acct, err := a.accountRepo.GetByRedditID(ctx, rid)
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	inbox, watchers, global, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct)
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	an := &accountNotificationsRequest{InboxNotifications: inbox, WatcherNotifications: watchers, GlobalMute: global}

	// Headers must be set before WriteHeader; after that they are frozen.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(an); err != nil {
		// The 200 status is already on the wire, so the failure can only be
		// logged, not reported to the client.
		a.logger.Error("failed to encode notification settings", zap.Error(err))
	}
}
2021-08-08 18:19:47 +00:00
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
2022-11-01 02:33:11 +00:00
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
2022-05-07 19:04:35 +00:00
2021-08-08 18:19:47 +00:00
vars := mux.Vars(r)
apns := vars["apns"]
rid := vars["redditID"]
dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
if err != nil {
2022-05-21 14:00:21 +00:00
a.errorResponse(w, r, 500, err)
2021-08-08 18:19:47 +00:00
return
}
acct, err := a.accountRepo.GetByRedditID(ctx, rid)
if err != nil {
2022-05-21 14:00:21 +00:00
a.errorResponse(w, r, 500, err)
2021-08-08 18:19:47 +00:00
return
}
if err := a.accountRepo.Disassociate(ctx, &acct, &dev); err != nil {
2022-05-21 14:00:21 +00:00
a.errorResponse(w, r, 500, err)
2021-08-08 18:19:47 +00:00
return
}
w.WriteHeader(http.StatusOK)
}
// upsertAccountsHandler synchronises the full set of accounts associated with
// the device named by the "apns" route variable.
//
// The request body is a JSON array of accounts. For each submitted account it:
//   - refreshes the supplied tokens,
//   - checks through Reddit's Me call that the tokens belong to the claimed
//     username,
//   - records the newest inbox message ID,
//   - upserts the account and associates it with the device.
//
// Accounts that were associated with the device before the call but are
// missing from the request are then disassociated, best-effort.
//
// Any failure while handling a submitted account aborts the request: 422 for
// most failures, 401 on a username mismatch, and 500 if the inbox cannot be
// fetched. Accounts handled before the failure remain upserted and associated,
// and the disassociation pass does not run. Success is an empty 200.
func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	vars := mux.Vars(r)
	apns := vars["apns"]

	dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
	if err != nil {
		a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
		return
	}

	laccs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
	if err != nil {
		a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
		return
	}

	// Accounts currently linked to the device, keyed by normalized username.
	// Whatever remains here after processing the request is stale.
	accsMap := make(map[string]domain.Account, len(laccs))
	for _, acc := range laccs {
		accsMap[acc.NormalizedUsername()] = acc
	}

	var raccs []domain.Account
	if err := json.NewDecoder(r.Body).Decode(&raccs); err != nil {
		a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
		return
	}

	for _, acc := range raccs {
		// Still wanted by the client: exclude it from the disassociation pass.
		delete(accsMap, acc.NormalizedUsername())

		rac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
		tokens, err := rac.RefreshTokens(ctx)
		if err != nil {
			a.errorResponse(w, r, http.StatusUnprocessableEntity, fmt.Errorf("failed to refresh tokens: %w", err))
			return
		}

		// Reset expiration timer
		acc.TokenExpiresAt = time.Now().Add(tokens.Expiry)
		acc.RefreshToken = tokens.RefreshToken
		acc.AccessToken = tokens.AccessToken

		rac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, tokens.RefreshToken, tokens.AccessToken)
		me, err := rac.Me(ctx)
		if err != nil {
			a.errorResponse(w, r, http.StatusUnprocessableEntity, fmt.Errorf("failed to fetch user info: %w", err))
			return
		}

		if me.NormalizedUsername() != acc.NormalizedUsername() {
			err := fmt.Errorf("wrong user: expected %s, got %s", me.NormalizedUsername(), acc.NormalizedUsername())
			a.errorResponse(w, r, http.StatusUnauthorized, err)
			return
		}

		// Set account ID from Reddit
		acc.AccountID = me.ID

		mi, err := rac.MessageInbox(ctx)
		if err != nil {
			a.errorResponse(w, r, http.StatusInternalServerError, fmt.Errorf("failed to fetch message inbox: %w", err))
			return
		}
		// Record the newest inbox item as the last-seen message. Presumably
		// this keeps already-present messages from being notified — confirm
		// against the notification worker. The length check guards the index
		// in case Count and Children ever disagree.
		if mi.Count > 0 && len(mi.Children) > 0 {
			acc.LastMessageID = mi.Children[0].FullName()
		}

		if err := a.accountRepo.CreateOrUpdate(ctx, &acc); err != nil {
			a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
			return
		}

		if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil {
			a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
			return
		}
	}

	// Best effort: the submitted accounts are already saved, so failing to
	// unlink a stale account is logged rather than failing the request.
	for _, acc := range accsMap {
		if err := a.accountRepo.Disassociate(ctx, &acc, &dev); err != nil {
			a.logger.Error("failed to disassociate stale account",
				zap.String("account", acc.NormalizedUsername()),
				zap.Error(err),
			)
		}
	}

	w.WriteHeader(http.StatusOK)
}
// upsertAccountHandler validates a single account from the request body,
// stores it, and associates it with the device named by the "apns" route
// variable.
//
// Steps and their failure responses (every failure is logged):
//   - decode the body (422),
//   - refresh the supplied tokens (422),
//   - fetch the Reddit user via Me (500) and check that the username matches
//     the claimed one (401),
//   - record the newest inbox message ID (500),
//   - look up the device, upsert the account and associate the two (500).
//
// Success is an empty 200.
func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	vars := mux.Vars(r)

	var acct domain.Account
	if err := json.NewDecoder(r.Body).Decode(&acct); err != nil {
		a.logger.Error("failed to parse request json", zap.Error(err))
		a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
		return
	}

	// Here we check whether the account is supplied with a valid token.
	rac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
	tokens, err := rac.RefreshTokens(ctx)
	if err != nil {
		a.logger.Error("failed to refresh token", zap.Error(err))
		a.errorResponse(w, r, http.StatusUnprocessableEntity, err)
		return
	}

	// Reset expiration timer
	acct.TokenExpiresAt = time.Now().Add(tokens.Expiry)
	acct.RefreshToken = tokens.RefreshToken
	acct.AccessToken = tokens.AccessToken

	rac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
	me, err := rac.Me(ctx)
	if err != nil {
		a.logger.Error("failed to grab user details from reddit", zap.Error(err))
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	if me.NormalizedUsername() != acct.NormalizedUsername() {
		err := fmt.Errorf("wrong user: expected %s, got %s", me.NormalizedUsername(), acct.NormalizedUsername())
		a.logger.Warn("user is not who they say they are", zap.Error(err))
		a.errorResponse(w, r, http.StatusUnauthorized, err)
		return
	}

	// Set account ID from Reddit
	acct.AccountID = me.ID

	mi, err := rac.MessageInbox(ctx)
	if err != nil {
		// Previously the only unlogged failure path in this handler.
		a.logger.Error("failed to fetch message inbox", zap.Error(err))
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}
	// Record the newest inbox item as the last-seen message. The length
	// check guards the index in case Count and Children ever disagree.
	if mi.Count > 0 && len(mi.Children) > 0 {
		acct.LastMessageID = mi.Children[0].FullName()
	}

	// Associate
	dev, err := a.deviceRepo.GetByAPNSToken(ctx, vars["apns"])
	if err != nil {
		a.logger.Error("failed to fetch device from database", zap.Error(err))
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	// Upsert account
	if err := a.accountRepo.CreateOrUpdate(ctx, &acct); err != nil {
		a.logger.Error("failed to update account", zap.Error(err))
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	if err := a.accountRepo.Associate(ctx, &acct, &dev); err != nil {
		a.logger.Error("failed to associate account with device", zap.Error(err))
		a.errorResponse(w, r, http.StatusInternalServerError, err)
		return
	}

	w.WriteHeader(http.StatusOK)
}