apollo-backend/internal/api/accounts.go
2021-08-14 13:56:03 -04:00

194 lines
4.5 KiB
Go

package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/domain"
)
// disassociateAccountHandler unlinks one Reddit account from one device.
//
// The device is looked up by the {apns} route variable and the account by the
// {redditID} route variable. A failure in either lookup, or in the unlink
// itself, is reported to the client as HTTP 500 with the error text. On
// success an empty 200 response is written.
func (a *api) disassociateAccountHandler(w http.ResponseWriter, r *http.Request) {
	params := mux.Vars(r)
	ctx := context.Background()

	device, err := a.deviceRepo.GetByAPNSToken(ctx, params["apns"])
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err.Error())
		return
	}

	account, err := a.accountRepo.GetByRedditID(ctx, params["redditID"])
	if err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err.Error())
		return
	}

	if err = a.accountRepo.Disassociate(ctx, &account, &device); err != nil {
		a.errorResponse(w, r, http.StatusInternalServerError, err.Error())
		return
	}

	w.WriteHeader(http.StatusOK)
}
// upsertAccountsHandler replaces the set of Reddit accounts linked to the
// device identified by the {apns} route variable.
//
// The request body is a JSON array of accounts. For each account, in order, the
// handler:
//   - refreshes its tokens with Reddit;
//   - checks that the refreshed tokens resolve to the claimed username;
//   - stores the account;
//   - links it to the device.
//
// Accounts returned by accountRepo.GetByAPNSToken for this token (presumably
// the ones currently linked to the device) whose normalized username is absent
// from the body are then unlinked. An empty array therefore unlinks all of
// them.
//
// Any failure while loading the device or its accounts, decoding the body, or
// processing an account aborts the request with HTTP 422. No rollback is
// attempted: accounts processed before the failure remain stored and linked,
// and no stale account is unlinked. Unlinking stale accounts is best-effort:
// failures are logged and the request still succeeds with an empty 200.
func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	apns := vars["apns"]
	// NOTE: context.Background() means repository calls are not cancelled if
	// the client disconnects mid-request.
	ctx := context.Background()

	dev, err := a.deviceRepo.GetByAPNSToken(ctx, apns)
	if err != nil {
		a.errorResponse(w, r, 422, err.Error())
		return
	}

	laccs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
	if err != nil {
		a.logger.WithError(err).Info("failed fetching accounts by apns token")
		a.errorResponse(w, r, 422, err.Error())
		return
	}

	// Accounts currently known for this device, keyed by normalized username.
	// Entries are removed as the request body mentions them; whatever is left
	// afterwards is stale and gets unlinked.
	accsMap := make(map[string]domain.Account, len(laccs))
	for _, acc := range laccs {
		accsMap[acc.NormalizedUsername()] = acc
	}

	var raccs []domain.Account
	if err := json.NewDecoder(r.Body).Decode(&raccs); err != nil {
		a.errorResponse(w, r, 422, err.Error())
		return
	}

	for _, acc := range raccs {
		delete(accsMap, acc.NormalizedUsername())

		// Refreshing proves the client-supplied tokens are actually valid.
		ac := a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
		tokens, err := ac.RefreshTokens()
		if err != nil {
			a.errorResponse(w, r, 422, err.Error())
			return
		}

		// Reset expiration timer: 3540s (59 minutes) from now.
		acc.ExpiresAt = time.Now().Unix() + 3540
		acc.RefreshToken = tokens.RefreshToken
		acc.AccessToken = tokens.AccessToken

		ac = a.reddit.NewAuthenticatedClient(acc.RefreshToken, acc.AccessToken)
		me, err := ac.Me()
		if err != nil {
			// `me` must not be used when Me fails. The request carries several
			// accounts, so name the one that failed.
			a.errorResponse(w, r, 422, fmt.Sprintf("fetching reddit identity for %s: %v", acc.NormalizedUsername(), err))
			return
		}

		// Reject tokens that belong to someone other than the claimed user.
		if me.NormalizedUsername() != acc.NormalizedUsername() {
			a.errorResponse(w, r, 422, "nice try")
			return
		}

		// Set account ID from Reddit
		acc.AccountID = me.ID

		if err := a.accountRepo.CreateOrUpdate(ctx, &acc); err != nil {
			a.errorResponse(w, r, 422, err.Error())
			return
		}

		if err := a.accountRepo.Associate(ctx, &acc, &dev); err != nil {
			a.errorResponse(w, r, 422, err.Error())
			return
		}
	}

	// Best-effort cleanup of accounts the client no longer reports.
	for _, acc := range accsMap {
		if err := a.accountRepo.Disassociate(ctx, &acc, &dev); err != nil {
			a.logger.WithError(err).
				WithField("account", acc.NormalizedUsername()).
				Info("failed disassociating stale account from device")
		}
	}

	w.WriteHeader(http.StatusOK)
}
// upsertAccountHandler stores a single Reddit account and links it to the
// device identified by the {apns} route variable.
//
// The request body is one JSON account. Its tokens are refreshed with Reddit
// to prove they are valid. The identity they resolve to must match the claimed
// username before anything is written. Responses:
//   - 422 when the body cannot be decoded, the token refresh fails, or the
//     tokens belong to a different user;
//   - 500 when the Reddit user lookup, the device lookup, storing the
//     account, or linking it to the device fails;
//   - an empty 200 on success.
//
// Every failure is also logged through a.logger.
func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	// NOTE: context.Background() means repository calls are not cancelled if
	// the client disconnects mid-request.
	ctx := context.Background()

	var acct domain.Account
	if err := json.NewDecoder(r.Body).Decode(&acct); err != nil {
		a.logger.WithFields(logrus.Fields{
			"err": err,
		}).Info("failed to parse request json")
		a.errorResponse(w, r, 422, err.Error())
		return
	}

	// Here we check whether the account is supplied with a valid token.
	ac := a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
	tokens, err := ac.RefreshTokens()
	if err != nil {
		a.logger.WithFields(logrus.Fields{
			"err": err,
		}).Info("failed to refresh token")
		a.errorResponse(w, r, 422, err.Error())
		return
	}

	// Reset expiration timer: 3540s (59 minutes) from now.
	acct.ExpiresAt = time.Now().Unix() + 3540
	acct.RefreshToken = tokens.RefreshToken
	acct.AccessToken = tokens.AccessToken

	ac = a.reddit.NewAuthenticatedClient(acct.RefreshToken, acct.AccessToken)
	me, err := ac.Me()
	if err != nil {
		a.logger.WithFields(logrus.Fields{
			"err": err,
		}).Info("failed to grab user details from Reddit")
		a.errorResponse(w, r, 500, err.Error())
		return
	}

	if me.NormalizedUsername() != acct.NormalizedUsername() {
		// err is always nil here (checked just above). Log the two identities
		// instead, since they are what explains the rejection.
		a.logger.WithFields(logrus.Fields{
			"claimed": acct.NormalizedUsername(),
			"actual":  me.NormalizedUsername(),
		}).Info("user is not who they say they are")
		a.errorResponse(w, r, 422, "nice try")
		return
	}

	// Set account ID from Reddit
	acct.AccountID = me.ID

	// Associate
	dev, err := a.deviceRepo.GetByAPNSToken(ctx, vars["apns"])
	if err != nil {
		a.logger.WithFields(logrus.Fields{
			"err": err,
		}).Info("failed fetching device from database")
		a.errorResponse(w, r, 500, err.Error())
		return
	}

	// Upsert account
	if err := a.accountRepo.CreateOrUpdate(ctx, &acct); err != nil {
		a.logger.WithFields(logrus.Fields{
			"err": err,
		}).Info("failed updating account in database")
		a.errorResponse(w, r, 500, err.Error())
		return
	}

	if err := a.accountRepo.Associate(ctx, &acct, &dev); err != nil {
		a.logger.WithFields(logrus.Fields{
			"err": err,
		}).Info("failed associating account with device")
		a.errorResponse(w, r, 500, err.Error())
		return
	}

	w.WriteHeader(http.StatusOK)
}