diff --git a/internal/api/watcher.go b/internal/api/watcher.go index 054769f..0cebeba 100644 --- a/internal/api/watcher.go +++ b/internal/api/watcher.go @@ -111,8 +111,12 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) { } if cwr.Type == "subreddit" || cwr.Type == "trending" { - srr, err := a.reddit.SubredditAbout(ctx, cwr.Subreddit) - if err != nil { + ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) + srr, err := ac.SubredditAbout(ctx, cwr.Subreddit) + if !srr.Public { + a.errorResponse(w, r, 403, reddit.ErrSubredditIsPrivate) + return + } else if err != nil { switch err { case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined: err = fmt.Errorf("error watching %s: %w", cwr.Subreddit, err) @@ -229,6 +233,7 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { if watcher.Type == domain.SubredditWatcher { lsr := strings.ToLower(watcher.Subreddit) if watcher.WatcheeLabel != lsr { + var account domain.Account accs, err := a.accountRepo.GetByAPNSToken(ctx, apns) if err != nil { a.errorResponse(w, r, 422, err) @@ -244,6 +249,7 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { found := false for _, acc := range accs { if acc.AccountID == rid { + account = acc found = true } } @@ -254,8 +260,12 @@ func (a *api) editWatcherHandler(w http.ResponseWriter, r *http.Request) { return } - srr, err := a.reddit.SubredditAbout(ctx, lsr) - if err != nil { + ac := a.reddit.NewAuthenticatedClient(account.AccountID, account.RefreshToken, account.AccessToken) + srr, err := ac.SubredditAbout(ctx, lsr) + if !srr.Public { + a.errorResponse(w, r, 403, reddit.ErrSubredditIsPrivate) + return + } else if err != nil { switch err { case reddit.ErrSubredditIsPrivate, reddit.ErrSubredditIsQuarantined: err = fmt.Errorf("error watching %s: %w", lsr, err) diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 39c885d..715f3fd 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -160,6 +160,7 @@ type SubredditResponse struct { Name string Quarantined bool + Public bool } func NewSubredditResponse(val *fastjson.Value) interface{} { @@ -171,7 +172,7 @@ func NewSubredditResponse(val *fastjson.Value) interface{} { sr.ID = string(data.GetStringBytes("id")) sr.Name = string(data.GetStringBytes("display_name")) sr.Quarantined = data.GetBool("quarantine") - + sr.Public = string(data.GetStringBytes("subreddit_type")) == "public" return sr } diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index 916cb56..5cae384 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -137,6 +137,7 @@ func TestSubredditResponseParsing(t *testing.T) { assert.Equal(t, "2vq0w", s.ID) assert.Equal(t, "DestinyTheGame", s.Name) assert.Equal(t, false, s.Quarantined) + assert.Equal(t, true, s.Public) } func TestUserResponseParsing(t *testing.T) {