diff --git a/internal/domain/watcher.go b/internal/domain/watcher.go index bd1128d..3012396 100644 --- a/internal/domain/watcher.go +++ b/internal/domain/watcher.go @@ -1,6 +1,10 @@ package domain -import "context" +import ( + "context" + + validation "github.com/go-ozzo/ozzo-validation/v4" +) type WatcherType int64 @@ -47,6 +51,14 @@ type Watcher struct { Account Account } +func (w *Watcher) Validate() error { + return validation.ValidateStruct(w, + validation.Field(&w.Label, validation.Required, validation.Length(1, 64)), + validation.Field(&w.Type, validation.Required, validation.In(SubredditWatcher, UserWatcher, TrendingWatcher)), + validation.Field(&w.WatcheeID, validation.Required, validation.Min(1)), + ) +} + type WatcherRepository interface { GetByID(ctx context.Context, id int64) (Watcher, error) GetBySubredditID(ctx context.Context, id int64) ([]Watcher, error) diff --git a/internal/repository/postgres_watcher.go b/internal/repository/postgres_watcher.go index dad7b2e..3cd4ab5 100644 --- a/internal/repository/postgres_watcher.go +++ b/internal/repository/postgres_watcher.go @@ -177,6 +177,10 @@ func (p *postgresWatcherRepository) GetByDeviceAPNSTokenAndAccountRedditID(ctx c } func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.Watcher) error { + if err := watcher.Validate(); err != nil { + return err + } + now := float64(time.Now().UTC().Unix()) query := ` @@ -204,6 +208,10 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain. } func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.Watcher) error { + if err := watcher.Validate(); err != nil { + return err + } + query := ` UPDATE watchers SET author = $2,