watchers check subreddit info with authentication

This commit is contained in:
Andre Medeiros 2022-06-30 15:24:58 -04:00
parent d5affe36c1
commit 655300682f
3 changed files with 17 additions and 5 deletions

View file

@ -111,8 +111,12 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
} }
if cwr.Type == "subreddit" || cwr.Type == "trending" { if cwr.Type == "subreddit" || cwr.Type == "trending" {
srr, err := a.reddit.SubredditAbout(ctx, cwr.Subreddit) ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if err != nil { srr, err := ac.SubredditAbout(ctx, cwr.Subreddit)
if !srr.Public {
a.errorResponse(w, r, 403, reddit.ErrSubredditIsPrivate)
return
} else if err != nil {
switch err { switch err {
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined: case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
err = fmt.Errorf("error watching %s: %w", cwr.Subreddit, err) err = fmt.Errorf("error watching %s: %w", cwr.Subreddit, err)
@ -229,6 +233,7 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
if watcher.Type == domain.SubredditWatcher { if watcher.Type == domain.SubredditWatcher {
lsr := strings.ToLower(watcher.Subreddit) lsr := strings.ToLower(watcher.Subreddit)
if watcher.WatcheeLabel != lsr { if watcher.WatcheeLabel != lsr {
var account domain.Account
accs, err := a.accountRepo.GetByAPNSToken(ctx, apns) accs, err := a.accountRepo.GetByAPNSToken(ctx, apns)
if err != nil { if err != nil {
a.errorResponse(w, r, 422, err) a.errorResponse(w, r, 422, err)
@ -244,6 +249,7 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
found := false found := false
for _, acc := range accs { for _, acc := range accs {
if acc.AccountID == rid { if acc.AccountID == rid {
account = acc
found = true found = true
} }
} }
@ -254,8 +260,12 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
srr, err := a.reddit.SubredditAbout(ctx, lsr) ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken)
if err != nil { srr, err := ac.SubredditAbout(ctx, lsr)
if !srr.Public {
a.errorResponse(w, r, 403, reddit.ErrSubredditIsPrivate)
return
} else if err != nil {
switch err { switch err {
case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined: case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined:
err = fmt.Errorf("error watching %s: %w", lsr, err) err = fmt.Errorf("error watching %s: %w", lsr, err)

View file

@ -160,6 +160,7 @@ type SubredditResponse struct {
Name string Name string
Quarantined bool Quarantined bool
Public bool
} }
func NewSubredditResponse(val *fastjson.Value) interface{} { func NewSubredditResponse(val *fastjson.Value) interface{} {
@ -171,7 +172,7 @@ func NewSubredditResponse(val *fastjson.Value) interface{} {
sr.ID = string(data.GetStringBytes("id")) sr.ID = string(data.GetStringBytes("id"))
sr.Name = string(data.GetStringBytes("display_name")) sr.Name = string(data.GetStringBytes("display_name"))
sr.Quarantined = data.GetBool("quarantine") sr.Quarantined = data.GetBool("quarantine")
sr.Public = string(data.GetStringBytes("subreddit_type")) == "public"
return sr return sr
} }

View file

@ -137,6 +137,7 @@ func TestSubredditResponseParsing(t *testing.T) {
assert.Equal(t, "2vq0w", s.ID) assert.Equal(t, "2vq0w", s.ID)
assert.Equal(t, "DestinyTheGame", s.Name) assert.Equal(t, "DestinyTheGame", s.Name)
assert.Equal(t, false, s.Quarantined) assert.Equal(t, false, s.Quarantined)
assert.Equal(t, true, s.Public)
} }
func TestUserResponseParsing(t *testing.T) { func TestUserResponseParsing(t *testing.T) {