This commit is contained in:
Andre Medeiros 2021-07-26 12:34:26 -04:00
parent 5603d79b29
commit 22279185e1
6 changed files with 181 additions and 17 deletions

View file

@ -1,18 +1,20 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"time" "time"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data"
) )
func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
acct := &data.Account{} ctx := context.Background()
var acct domain.Account
if err := json.NewDecoder(r.Body).Decode(acct); err != nil { if err := json.NewDecoder(r.Body).Decode(acct); err != nil {
a.logger.WithFields(logrus.Fields{ a.logger.WithFields(logrus.Fields{
"err": err, "err": err,
@ -60,7 +62,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps ht
acct.AccountID = me.ID acct.AccountID = me.ID
// Associate // Associate
d, err := a.models.Devices.GetByAPNSToken(ps.ByName("apns")) dev, err := a.deviceRepo.GetByAPNSToken(ctx, ps.ByName("apns"))
if err != nil { if err != nil {
a.logger.WithFields(logrus.Fields{ a.logger.WithFields(logrus.Fields{
"err": err, "err": err,
@ -70,7 +72,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps ht
} }
// Upsert account // Upsert account
if err := a.models.Accounts.Upsert(acct); err != nil { if err := a.accountRepo.CreateOrUpdate(ctx, &acct); err != nil {
a.logger.WithFields(logrus.Fields{ a.logger.WithFields(logrus.Fields{
"err": err, "err": err,
}).Info("failed updating account in database") }).Info("failed updating account in database")

View file

@ -11,19 +11,21 @@ import (
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data" "github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/reddit" "github.com/christianselig/apollo-backend/internal/reddit"
"github.com/christianselig/apollo-backend/internal/repository"
) )
type api struct { type api struct {
logger *logrus.Logger logger *logrus.Logger
statsd *statsd.Client statsd *statsd.Client
db *pgxpool.Pool
reddit *reddit.Client reddit *reddit.Client
models *data.Models
accountRepo domain.AccountRepository
deviceRepo domain.DeviceRepository
} }
func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool) *api { func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) *api {
reddit := reddit.NewClient( reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"), os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"), os.Getenv("REDDIT_CLIENT_SECRET"),
@ -31,9 +33,16 @@ func NewAPI(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, d
16, 16,
) )
models := data.NewModels(ctx, db) accountRepo := repository.NewPostgresAccount(pool)
deviceRepo := repository.NewPostgresDevice(pool)
return &api{logger, statsd, db, reddit, models} return &api{
logger: logger,
statsd: statsd,
reddit: reddit,
accountRepo: accountRepo,
deviceRepo: deviceRepo,
}
} }
func (a *api) Server(port int) *http.Server { func (a *api) Server(port int) *http.Server {

View file

@ -23,6 +23,7 @@ type AccountRepository interface {
GetByID(ctx context.Context, id int64) (Account, error) GetByID(ctx context.Context, id int64) (Account, error)
GetByRedditID(ctx context.Context, id string) (Account, error) GetByRedditID(ctx context.Context, id string) (Account, error)
CreateOrUpdate(ctx context.Context, acc *Account) error
Update(ctx context.Context, acc *Account) error Update(ctx context.Context, acc *Account) error
Create(ctx context.Context, acc *Account) error Create(ctx context.Context, acc *Account) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error

View file

@ -11,7 +11,7 @@ type Device struct {
type DeviceRepository interface { type DeviceRepository interface {
GetByAPNSToken(ctx context.Context, token string) (Device, error) GetByAPNSToken(ctx context.Context, token string) (Device, error)
CreateOrUpdate(ctx context.Context, dev *Device) error
Update(ctx context.Context, dev *Device) error Update(ctx context.Context, dev *Device) error
Create(ctx context.Context, dev *Device) error Create(ctx context.Context, dev *Device) error
Delete(ctx context.Context, token string) error Delete(ctx context.Context, token string) error

View file

@ -1,18 +1,19 @@
package postgres package repository
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/christianselig/apollo-backend/internal/domain"
) )
type postgresAccountRepository struct { type postgresAccountRepository struct {
pool *pgxpool.Pool pool *pgxpool.Pool
} }
func NewPostgresAccountRepository(pool *pgxpool.Pool) domain.AccountRepository { func NewPostgresAccount(pool *pgxpool.Pool) domain.AccountRepository {
return &postgresAccountRepository{pool: pool} return &postgresAccountRepository{pool: pool}
} }
@ -40,7 +41,6 @@ func (p *postgresAccountRepository) fetch(ctx context.Context, query string, arg
} }
accs = append(accs, acc) accs = append(accs, acc)
} }
return accs, nil return accs, nil
} }
@ -58,7 +58,6 @@ func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (doma
if len(accs) == 0 { if len(accs) == 0 {
return domain.Account{}, domain.ErrNotFound return domain.Account{}, domain.ErrNotFound
} }
return accs[0], nil return accs[0], nil
} }
@ -79,6 +78,31 @@ func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string
return accs[0], nil return accs[0], nil
} }
func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *domain.Account) error {
query := `
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
ON CONFLICT(username) DO
UPDATE SET access_token = $3,
refresh_token = $4,
expires_at = $5
RETURNING id`
res, err := p.pool.Query(
ctx,
query,
acc.Username,
acc.AccountID,
acc.AccessToken,
acc.RefreshToken,
acc.ExpiresAt,
)
if err != nil {
return err
}
return res.Scan(&acc.ID)
}
func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error { func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error {
query := ` query := `

View file

@ -0,0 +1,128 @@
package repository
import (
"context"
"fmt"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/christianselig/apollo-backend/internal/domain"
)
type postgresDeviceRepository struct {
pool *pgxpool.Pool
}
func NewPostgresDevice(pool *pgxpool.Pool) domain.DeviceRepository {
return &postgresDeviceRepository{pool: pool}
}
func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Device, error) {
rows, err := p.pool.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var devs []domain.Device
for rows.Next() {
var dev domain.Device
if err := rows.Scan(
&dev.ID,
&dev.APNSToken,
&dev.Sandbox,
&dev.LastPingedAt,
); err != nil {
return nil, err
}
devs = append(devs, dev)
}
return devs, nil
}
func (p *postgresDeviceRepository) GetByAPNSToken(ctx context.Context, token string) (domain.Device, error) {
query := `
SELECT id, apns_token, sandbox, last_pinged_at
FROM devices
WHERE id = $1`
devs, err := p.fetch(ctx, query, token)
if err != nil {
return domain.Device{}, err
}
if len(devs) == 0 {
return domain.Device{}, domain.ErrNotFound
}
return devs[0], nil
}
func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *domain.Device) error {
query := `
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
VALUES ($1, $2, $3)
ON CONFLICT(apns_token)
DO
UPDATE SET last_pinged_at = $3
RETURNING id`
res, err := p.pool.Query(
ctx,
query,
dev.APNSToken,
dev.Sandbox,
dev.LastPingedAt,
)
if err != nil {
return err
}
return res.Scan(&dev.ID)
}
func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Device) error {
query := `
INSERT INTO devices
(apns_token, sandbox, last_pinged_at)
VALUES ($1, $2, $3)
RETURNING id`
res, err := p.pool.Query(
ctx,
query,
dev.APNSToken,
dev.Sandbox,
dev.LastPingedAt,
)
if err != nil {
return err
}
return res.Scan(&dev.ID)
}
func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Device) error {
query := `
UPDATE devices
SET last_pinged_at = $2
WHERE id = $1`
res, err := p.pool.Exec(ctx, query, dev.ID, dev.LastPingedAt)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
}
return err
}
func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) error {
query := `DELETE FROM devices WHERE apns_token = $1`
res, err := p.pool.Exec(ctx, query, token)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
}
return err
}