diff --git a/internal/api/accounts.go b/internal/api/accounts.go index 168c707..938003d 100644 --- a/internal/api/accounts.go +++ b/internal/api/accounts.go @@ -1,18 +1,20 @@ package api import ( + "context" "encoding/json" "net/http" "time" + "github.com/christianselig/apollo-backend/internal/domain" "github.com/julienschmidt/httprouter" "github.com/sirupsen/logrus" - - "github.com/christianselig/apollo-backend/internal/data" ) func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - acct := &data.Account{} + ctx := context.Background() + var acct domain.Account + if err := json.NewDecoder(r.Body).Decode(acct); err != nil { a.logger.WithFields(logrus.Fields{ "err": err, @@ -60,7 +62,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps ht acct.AccountID = me.ID // Associate - d, err := a.models.Devices.GetByAPNSToken(ps.ByName("apns")) + dev, err := a.deviceRepo.GetByAPNSToken(ctx, ps.ByName("apns")) if err != nil { a.logger.WithFields(logrus.Fields{ "err": err, @@ -70,7 +72,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps ht } // Upsert account - if err := a.models.Accounts.Upsert(acct); err != nil { + if err := a.accountRepo.CreateOrUpdate(ctx, &acct); err != nil { a.logger.WithFields(logrus.Fields{ "err": err, }).Info("failed updating account in database") diff --git a/internal/api/api.go b/internal/api/api.go index d4f83df..bb2f61d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -11,19 +11,21 @@ import ( "github.com/julienschmidt/httprouter" "github.com/sirupsen/logrus" - "github.com/christianselig/apollo-backend/internal/data" + "github.com/christianselig/apollo-backend/internal/domain" "github.com/christianselig/apollo-backend/internal/reddit" + "github.com/christianselig/apollo-backend/internal/repository" ) type api struct { logger *logrus.Logger statsd *statsd.Client - db *pgxpool.Pool reddit *reddit.Client - models *data.Models + + accountRepo domain.AccountRepository + deviceRepo domain.DeviceRepository } -func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool) *api { +func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api { reddit := reddit.NewClient( os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_SECRET"), @@ -31,9 +33,16 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, d 16, ) - models := data.NewModels(ctx, db) + accountRepo := repository.NewPostgresAccount(pool) + deviceRepo := repository.NewPostgresDevice(pool) - return &api{logger, statsd, db, reddit, models} + return &api{ + logger: logger, + statsd: statsd, + reddit: reddit, + accountRepo: accountRepo, + deviceRepo: deviceRepo, + } } func (a *api) Server(port int) *http.Server { diff --git a/internal/domain/account.go b/internal/domain/account.go index 2c98c5c..4a767da 100644 --- a/internal/domain/account.go +++ b/internal/domain/account.go @@ -23,6 +23,7 @@ type AccountRepository interface { GetByID(ctx context.Context, id int64) (Account, error) GetByRedditID(ctx context.Context, id string) (Account, error) + CreateOrUpdate(ctx context.Context, acc *Account) error Update(ctx context.Context, acc *Account) error Create(ctx context.Context, acc *Account) error Delete(ctx context.Context, id int64) error diff --git a/internal/domain/device.go b/internal/domain/device.go index 70f4cad..5fad5d3 100644 --- a/internal/domain/device.go +++ b/internal/domain/device.go @@ -11,7 +11,7 @@ type Device struct { type DeviceRepository interface { GetByAPNSToken(ctx context.Context, token string) (Device, error) - + CreateOrUpdate(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error Delete(ctx context.Context, token string) error diff --git a/internal/account/repository/postgres/postgres_account.go b/internal/repository/postgres_account.go similarity index 81% rename from internal/account/repository/postgres/postgres_account.go rename to internal/repository/postgres_account.go index c6709c4..e0c648f 100644 --- a/internal/account/repository/postgres/postgres_account.go +++ b/internal/repository/postgres_account.go @@ -1,18 +1,19 @@ -package postgres +package repository import ( "context" "fmt" - "github.com/christianselig/apollo-backend/internal/domain" "github.com/jackc/pgx/v4/pgxpool" + + "github.com/christianselig/apollo-backend/internal/domain" ) type postgresAccountRepository struct { pool *pgxpool.Pool } -func NewPostgresAccountRepository(pool *pgxpool.Pool) domain.AccountRepository { +func NewPostgresAccount(pool *pgxpool.Pool) domain.AccountRepository { return &postgresAccountRepository{pool: pool} } @@ -40,7 +41,6 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg } accs = append(accs, acc) } - return accs, nil } @@ -58,7 +58,6 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma if len(accs) == 0 { return domain.Account{}, domain.ErrNotFound } - return accs[0], nil } @@ -79,6 +78,31 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string return accs[0], nil } +func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error { + query := ` + INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at) + VALUES ($1, $2, $3, $4, $5, '', 0, 0) + ON CONFLICT(username) DO + UPDATE SET access_token = $3, + refresh_token = $4, + expires_at = $5 + RETURNING id` + + res, err := p.pool.Query( + ctx, + query, + acc.Username, + acc.AccountID, + acc.AccessToken, + acc.RefreshToken, + acc.ExpiresAt, + ) + if err != nil { + return err + } + + return res.Scan(&acc.ID) +} func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error { query := ` diff --git a/internal/repository/postgres_device.go b/internal/repository/postgres_device.go new file mode 100644 index 0000000..648de3d --- /dev/null +++ b/internal/repository/postgres_device.go @@ -0,0 +1,128 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v4/pgxpool" + + "github.com/christianselig/apollo-backend/internal/domain" +) + +type postgresDeviceRepository struct { + pool *pgxpool.Pool +} + +func NewPostgresDevice(pool *pgxpool.Pool) domain.DeviceRepository { + return &postgresDeviceRepository{pool: pool} +} + +func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Device, error) { + rows, err := p.pool.Query(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var devs []domain.Device + for rows.Next() { + var dev domain.Device + if err := rows.Scan( + &dev.ID, + &dev.APNSToken, + &dev.Sandbox, + &dev.LastPingedAt, + ); err != nil { + return nil, err + } + devs = append(devs, dev) + } + return devs, nil +} + +func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) { + query := ` + SELECT id, apns_token, sandbox, last_pinged_at + FROM devices + WHERE id = $1` + + devs, err := p.fetch(ctx, query, token) + + if err != nil { + return domain.Device{}, err + } + if len(devs) == 0 { + return domain.Device{}, domain.ErrNotFound + } + return devs[0], nil +} + +func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error { + query := ` + INSERT INTO devices (apns_token, sandbox, last_pinged_at) + VALUES ($1, $2, $3) + ON CONFLICT(apns_token) + DO + UPDATE SET last_pinged_at = $3 + RETURNING id` + + res, err := p.pool.Query( + ctx, + query, + dev.APNSToken, + dev.Sandbox, + dev.LastPingedAt, + ) + + if err != nil { + return err + } + return res.Scan(&dev.ID) + +} + +func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Device) error { + query := ` + INSERT INTO devices + (apns_token, sandbox, last_pinged_at) + VALUES ($1, $2, $3) + RETURNING id` + + res, err := p.pool.Query( + ctx, + query, + dev.APNSToken, + dev.Sandbox, + dev.LastPingedAt, + ) + + if err != nil { + return err + } + return res.Scan(&dev.ID) +} + +func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Device) error { + query := ` + UPDATE devices + SET last_pinged_at = $2 + WHERE id = $1` + + res, err := p.pool.Exec(ctx, query, dev.ID, dev.LastPingedAt) + + if res.RowsAffected() != 1 { + return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) + } + return err +} + +func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) error { + query := `DELETE FROM devices WHERE apns_token = $1` + + res, err := p.pool.Exec(ctx, query, token) + + if res.RowsAffected() != 1 { + return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected()) + } + return err +}