diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 6e3a580..f3adb47 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -25,7 +25,7 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request anr := &accountNotificationsRequest{} if err := json.NewDecoder(r.Body).Decode(anr); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -35,18 +35,18 @@ func (a *api) notificationsAccountHandler(w http.ResponseWriter, r *http.Request dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } acct, err := a.accountRepo.GetByRedditID(ctx, rid) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } if err := a.deviceRepo.SetNotifiable(ctx, &dev, &acct, anr.InboxNotifications, anr.WatcherNotifications, anr.GlobalMute); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -62,19 +62,19 @@ func (a *api) getNotificationsAccountHandler(w http.ResponseWriter, r *http.Requ dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } acct, err := a.accountRepo.GetByRedditID(ctx, rid) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } inbox, watchers, global, err := a.deviceRepo.GetNotifiable(ctx, &dev, &acct) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -93,18 +93,18 @@ func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } acct, err := a.accountRepo.GetByRedditID(ctx, rid) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } if err := a.accountRepo.Disassociate(ctx, &acct, &dev); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -119,13 +119,13 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } laccs, err := a.accountRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } @@ -136,7 +136,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { var raccs []domain.Account if err := json.NewDecoder(r.Body).Decode(&raccs); err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } for _, acc := range raccs { @@ -145,7 +145,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken) tokens, err := ac.RefreshTokens(ctx) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } @@ -158,12 +158,13 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { me, err := ac.Me(ctx) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } if me.NormalizedUsername() != acc.NormalizedUsername() { - a.errorResponse(w, r, 422, "nice try") + err := fmt.Errorf("wrong user: expected %s, got %s", me.NormalizedUsername(), acc.NormalizedUsername()) + a.errorResponse(w, r, 401, err) return } @@ -171,12 +172,12 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) { acc.AccountID = me.ID if err := a.accountRepo.CreateOrUpdate(ctx, &acc); err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } } @@ -218,7 +219,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed to parse request json") - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } @@ -229,7 +230,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed to refresh token") - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } @@ -245,15 +246,14 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed to grab user details from Reddit") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } if me.NormalizedUsername() != acct.NormalizedUsername() { - a.logger.WithFields(logrus.Fields{ - "err": err, - }).Info("user is not who they say they are") - a.errorResponse(w, r, 422, "nice try") + err := fmt.Errorf("wrong user: expected %s, got %s", me.NormalizedUsername(), acct.NormalizedUsername()) + a.logger.WithFields(logrus.Fields{"err": err}).Warn("user is not who they say they are") + a.errorResponse(w, r, 401, err) return } @@ -266,7 +266,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed fetching device from database") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -275,7 +275,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed updating account in database") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -283,7 +283,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed associating account with device") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } diff --git a/internal/api/api.go b/internal/api/api.go index d073211..7a3628c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -128,7 +128,7 @@ func (a *api) Routes() *mux.Router { func (a *api) testBugsnagHandler(w http.ResponseWriter, r *http.Request) { if err := bugsnag.Notify(fmt.Errorf("Test error")); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } w.WriteHeader(http.StatusOK) diff --git a/internal/api/contact.go b/internal/api/contact.go index 82f5e73..eb0fb50 100644 --- a/internal/api/contact.go +++ b/internal/api/contact.go @@ -15,7 +15,7 @@ type sendMessageRequest struct { func (a *api) contactHandler(w http.ResponseWriter, r *http.Request) { smr := &sendMessageRequest{} if err := json.NewDecoder(r.Body).Decode(smr); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -27,7 +27,7 @@ func (a *api) contactHandler(w http.ResponseWriter, r *http.Request) { } if _, err := smtp2go.Send(msg); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } w.WriteHeader(http.StatusOK) diff --git a/internal/api/devices.go b/internal/api/devices.go index 791f02e..dc7c498 100644 --- a/internal/api/devices.go +++ b/internal/api/devices.go @@ -23,7 +23,7 @@ func (a *api) upsertDeviceHandler(w http.ResponseWriter, r *http.Request) { d := &domain.Device{} if err := json.NewDecoder(r.Body).Decode(d); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -31,7 +31,7 @@ func (a *api) upsertDeviceHandler(w http.ResponseWriter, r *http.Request) { d.GracePeriodExpiresAt = d.ExpiresAt.Add(domain.DeviceGracePeriodAfterReceiptExpiry) if err := a.deviceRepo.CreateOrUpdate(ctx, d); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -49,13 +49,13 @@ func (a *api) testDeviceHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed fetching device from database") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } accs, err := a.accountRepo.GetByAPNSToken(ctx, tok) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -86,7 +86,7 @@ func (a *api) testDeviceHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed to send test notification") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } w.WriteHeader(http.StatusOK) @@ -99,13 +99,13 @@ func (a *api) deleteDeviceHandler(w http.ResponseWriter, r *http.Request) { dev, err := a.deviceRepo.GetByAPNSToken(ctx, vars["apns"]) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } accs, err := a.accountRepo.GetByAPNSToken(ctx, vars["apns"]) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } diff --git a/internal/api/errors.go b/internal/api/errors.go index 83815b5..045feae 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -2,7 +2,7 @@ package api import "net/http" -func (a *api) errorResponse(w http.ResponseWriter, _ *http.Request, status int, message string) { - w.Header().Set("X-Apollo-Error", message) - http.Error(w, message, status) +func (a *api) errorResponse(w http.ResponseWriter, _ *http.Request, status int, err error) { + w.Header().Set("X-Apollo-Error", err.Error()) + http.Error(w, err.Error(), status) } diff --git a/internal/api/notifications.go b/internal/api/notifications.go index e92ef1f..e6df177 100644 --- a/internal/api/notifications.go +++ b/internal/api/notifications.go @@ -34,7 +34,7 @@ func generateNotificationTester(a *api, fun notificationGenerator) func(w http.R a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed fetching device from database") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -59,7 +59,7 @@ func generateNotificationTester(a *api, fun notificationGenerator) func(w http.R a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed to send test notification") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } w.WriteHeader(http.StatusOK) diff --git a/internal/api/receipt.go b/internal/api/receipt.go index 903298a..0426aa8 100644 --- a/internal/api/receipt.go +++ b/internal/api/receipt.go @@ -26,14 +26,14 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed verifying receipt") - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } if apns != "" { dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -45,7 +45,7 @@ func (a *api) checkReceiptHandler(w http.ResponseWriter, r *http.Request) { accs, err := a.accountRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } diff --git a/internal/api/watcher.go b/internal/api/watcher.go index e843871..23c21b6 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -2,6 +2,8 @@ package api import ( "encoding/json" + "errors" + "fmt" "net/http" "strconv" "strings" @@ -53,29 +55,30 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { Criteria: watcherCriteria{}, } if err := json.NewDecoder(r.Body).Decode(cwr); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } if err := cwr.Validate(); err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } accs, err := a.accountRepo.GetByAPNSToken(ctx, apns) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } if len(accs) == 0 { - a.errorResponse(w, r, 422, "can't create watchers without accounts") + err := errors.New("cannot create watchers without account") + a.errorResponse(w, r, 422, err) return } @@ -89,7 +92,8 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { } if !found { - a.errorResponse(w, r, 422, "yeah nice try") + err := errors.New("account not associated with device") + a.errorResponse(w, r, 401, err) return } @@ -110,7 +114,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { if cwr.Type == "subreddit" || cwr.Type == "trending" { srr, err := ac.SubredditAbout(ctx, cwr.Subreddit) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } @@ -122,7 +126,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { sr = domain.Subreddit{SubredditID: srr.ID, Name: srr.Name} _ = a.subredditRepo.CreateOrUpdate(ctx, &sr) default: - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } } @@ -139,12 +143,13 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { } else if cwr.Type == "user" { urr, err := ac.UserAbout(ctx, cwr.User) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } if !urr.AcceptFollowers { - a.errorResponse(w, r, 422, "no followers accepted") + err := errors.New("user has followers disabled") + a.errorResponse(w, r, 403, err) return } @@ -152,19 +157,20 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { err = a.userRepo.CreateOrUpdate(ctx, &u) if err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } watcher.Type = domain.UserWatcher watcher.WatcheeID = u.ID } else { - a.errorResponse(w, r, 422, "unknown watcher type") + err := fmt.Errorf("unknown watcher type: %s", cwr.Type) + a.errorResponse(w, r, 422, err) return } if err := a.watcherRepo.Create(ctx, &watcher); err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } @@ -179,13 +185,17 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, err := strconv.ParseInt(vars["watcherID"], 10, 64) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } watcher, err := a.watcherRepo.GetByID(ctx, id) - if err != nil || watcher.Device.APNSToken != vars["apns"] { - a.errorResponse(w, r, 422, "nice try") + if err != nil { + a.errorResponse(w, r, 422, err) + return + } else if watcher.Device.APNSToken != vars["apns"] { + err := fmt.Errorf("wrong device for watcher %d", watcher.ID) + a.errorResponse(w, r, 422, err) return } @@ -194,7 +204,7 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { } if err := json.NewDecoder(r.Body).Decode(ewr); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -207,7 +217,7 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { watcher.Domain = strings.ToLower(ewr.Criteria.Domain) if err := a.watcherRepo.Update(ctx, &watcher); err != nil { - a.errorResponse(w, r, 500, err.Error()) + a.errorResponse(w, r, 500, err) return } @@ -220,13 +230,17 @@ func (a *api) deleteWatcherHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, err := strconv.ParseInt(vars["watcherID"], 10, 64) if err != nil { - a.errorResponse(w, r, 422, err.Error()) + a.errorResponse(w, r, 422, err) return } watcher, err := a.watcherRepo.GetByID(ctx, id) - if err != nil || watcher.Device.APNSToken != vars["apns"] { - a.errorResponse(w, r, 422, "nice try") + if err != nil { + a.errorResponse(w, r, 422, err) + return + } else if watcher.Device.APNSToken != vars["apns"] { + err := fmt.Errorf("wrong device for watcher %d", watcher.ID) + a.errorResponse(w, r, 422, err) return } @@ -257,7 +271,7 @@ func (a *api) listWatchersHandler(w http.ResponseWriter, r *http.Request) { watchers, err := a.watcherRepo.GetByDeviceAPNSTokenAndAccountRedditID(ctx, apns, redditID) if err != nil { - a.errorResponse(w, r, 400, err.Error()) + a.errorResponse(w, r, 400, err) return }