mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-25 13:17:42 +00:00
add account repository
This commit is contained in:
parent
2a5ad833eb
commit
5603d79b29
4 changed files with 182 additions and 3 deletions
145
internal/account/repository/postgres/postgres_account.go
Normal file
145
internal/account/repository/postgres/postgres_account.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/domain"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
type postgresAccountRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPostgresAccountRepository(pool *pgxpool.Pool) domain.AccountRepository {
|
||||
return &postgresAccountRepository{pool: pool}
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Account, error) {
|
||||
rows, err := p.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accs []domain.Account
|
||||
for rows.Next() {
|
||||
var acc domain.Account
|
||||
if err := rows.Scan(
|
||||
&acc.ID,
|
||||
&acc.Username,
|
||||
&acc.AccountID,
|
||||
&acc.AccessToken,
|
||||
&acc.RefreshToken,
|
||||
&acc.ExpiresAt,
|
||||
&acc.LastMessageID,
|
||||
&acc.LastCheckedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accs = append(accs, acc)
|
||||
}
|
||||
|
||||
return accs, nil
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) GetByID(ctx context.Context, id int64) (domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at
|
||||
FROM accounts
|
||||
WHERE id = $1`
|
||||
|
||||
accs, err := p.fetch(ctx, query, id)
|
||||
if err != nil {
|
||||
return domain.Account{}, err
|
||||
}
|
||||
|
||||
if len(accs) == 0 {
|
||||
return domain.Account{}, domain.ErrNotFound
|
||||
}
|
||||
|
||||
return accs[0], nil
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) GetByRedditID(ctx context.Context, id string) (domain.Account, error) {
|
||||
query := `
|
||||
SELECT id, username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at
|
||||
FROM accounts
|
||||
WHERE account_id = $1`
|
||||
|
||||
accs, err := p.fetch(ctx, query, id)
|
||||
if err != nil {
|
||||
return domain.Account{}, err
|
||||
}
|
||||
|
||||
if len(accs) == 0 {
|
||||
return domain.Account{}, domain.ErrNotFound
|
||||
}
|
||||
|
||||
return accs[0], nil
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Account) error {
|
||||
query := `
|
||||
INSERT INTO accounts
|
||||
(username, account_id, access_token, refresh_token, expires_at, last_message_id, last_checked_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id`
|
||||
|
||||
res, err := p.pool.Query(
|
||||
ctx,
|
||||
query,
|
||||
acc.Username,
|
||||
acc.AccountID,
|
||||
acc.AccessToken,
|
||||
acc.RefreshToken,
|
||||
acc.ExpiresAt,
|
||||
acc.LastMessageID,
|
||||
acc.LastCheckedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return res.Scan(&acc.ID)
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Account) error {
|
||||
query := `
|
||||
UPDATE accounts
|
||||
SET username = $2,
|
||||
account_id = $3,
|
||||
access_token = $4,
|
||||
refresh_token = $5,
|
||||
expires_at = $6,
|
||||
last_message_id = $7,
|
||||
last_checked_at = $8
|
||||
WHERE id = $1`
|
||||
|
||||
res, err := p.pool.Exec(
|
||||
ctx,
|
||||
query,
|
||||
acc.AccountID,
|
||||
acc.AccessToken,
|
||||
acc.RefreshToken,
|
||||
acc.ExpiresAt,
|
||||
acc.LastMessageID,
|
||||
acc.LastCheckedAt,
|
||||
)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *postgresAccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
query := `DELETE FROM accounts WHERE id = $1`
|
||||
res, err := p.pool.Exec(ctx, query, id)
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -23,8 +23,8 @@ type AccountRepository interface {
|
|||
GetByID(ctx context.Context, id int64) (Account, error)
|
||||
GetByRedditID(ctx context.Context, id string) (Account, error)
|
||||
|
||||
Update(ctx context.Context, ac *Account) error
|
||||
Create(ctx context.Context, ac *Account) error
|
||||
Update(ctx context.Context, acc *Account) error
|
||||
Create(ctx context.Context, acc *Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,6 @@ type AccountRepository interface {
|
|||
type AccountUsecase interface {
|
||||
GetByID(ctx context.Context, id int64) (Account, error)
|
||||
GetByRedditID(ctx context.Context, id string) (Account, error)
|
||||
CreateOrUpdate(ctx context.Context, ac *Account) error
|
||||
CreateOrUpdate(ctx context.Context, acc *Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
|
24
internal/domain/device.go
Normal file
24
internal/domain/device.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package domain
|
||||
|
||||
import "context"
|
||||
|
||||
type Device struct {
|
||||
ID int64
|
||||
APNSToken string
|
||||
Sandbox bool
|
||||
LastPingedAt int64
|
||||
}
|
||||
|
||||
type DeviceRepository interface {
|
||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||
|
||||
Update(ctx context.Context, dev *Device) error
|
||||
Create(ctx context.Context, dev *Device) error
|
||||
Delete(ctx context.Context, token string) error
|
||||
}
|
||||
|
||||
type DeviceUsecase interface {
|
||||
GetByAPNSToken(ctx context.Context, token string) (Device, error)
|
||||
CreateOrUpdate(ctx context.Context, dev *Device) error
|
||||
Delete(ctx context.Context, token string) error
|
||||
}
|
10
internal/domain/errors.go
Normal file
10
internal/domain/errors.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package domain
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNotFound will be returned if the requested item is not found
|
||||
ErrNotFound = errors.New("requested item was not found")
|
||||
// ErrConflict will be returned if the item being persisted already exists
|
||||
ErrConflict = errors.New("item already exists")
|
||||
)
|
Loading…
Reference in a new issue