don't let weird watchers through

This commit is contained in:
Andre Medeiros 2021-10-14 00:25:29 -04:00
parent 95b58b2c01
commit 06b715fa6a
2 changed files with 21 additions and 1 deletions

View file

@ -1,6 +1,10 @@
package domain
import "context"
import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
type WatcherType int64
@ -47,6 +51,14 @@ type Watcher struct {
Account Account
}
func (w *Watcher) Validate() error {
return validation.ValidateStruct(w,
validation.Field(&w.Label, validation.Required, validation.Length(1, 64)),
validation.Field(&w.Type, validation.Required, validation.In(SubredditWatcher, UserWatcher, TrendingWatcher)),
validation.Field(&w.WatcheeID, validation.Required, validation.Min(1)),
)
}
type WatcherRepository interface {
GetByID(ctx context.Context, id int64) (Watcher, error)
GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error)

View file

@ -177,6 +177,10 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c
}
func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.Watcher) error {
if err := watcher.Validate(); err != nil {
return err
}
now := float64(time.Now().UTC().Unix())
query := `
@ -204,6 +208,10 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
}
func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error {
if err := watcher.Validate(); err != nil {
return err
}
query := `
UPDATE watchers
SET author = $2,