diff --git a/cmd/apollo-api/main.go b/cmd/apollo-api/main.go index af343b7..8257e3f 100644 --- a/cmd/apollo-api/main.go +++ b/cmd/apollo-api/main.go @@ -1,15 +1,17 @@ package main import ( - "database/sql" + "context" "fmt" "log" "net/http" "os" + "os/signal" + "syscall" "github.com/DataDog/datadog-go/statsd" + "github.com/jackc/pgx/v4/pgxpool" "github.com/joho/godotenv" - _ "github.com/lib/pq" "github.com/sirupsen/logrus" "github.com/christianselig/apollo-backend/internal/data" @@ -23,12 +25,15 @@ type config struct { type application struct { cfg config logger *logrus.Logger - db *sql.DB + pool *pgxpool.Pool models *data.Models client *reddit.Client } func main() { + _ = godotenv.Load() + ctx, cancel := context.WithCancel(context.Background()) + var logger *logrus.Logger { logger = logrus.New() @@ -42,22 +47,27 @@ func main() { } } - if err := godotenv.Load(); err != nil { - logger.Printf("Couldn't find .env so I will read from existing ENV.") - } - var cfg config - dburl, ok := os.LookupEnv("DATABASE_CONNECTION_POOL_URL") - if !ok { - dburl = os.Getenv("DATABASE_URL") - } + // Set up Postgres connection + var pool *pgxpool.Pool + { + url := fmt.Sprintf("%s?sslmode=require", os.Getenv("DATABASE_CONNECTION_POOL_URL")) + config, err := pgxpool.ParseConfig(url) + if err != nil { + panic(err) + } - db, err := sql.Open("postgres", fmt.Sprintf("%s?binary_parameters=yes", dburl)) - if err != nil { - log.Fatal(err) + // Setting the build statement cache to nil helps this work with pgbouncer + config.ConnConfig.BuildStatementCache = nil + config.ConnConfig.PreferSimpleProtocol = true + + pool, err = pgxpool.ConnectConfig(ctx, config) + if err != nil { + panic(err) + } + defer pool.Close() } - defer db.Close() statsd, err := statsd.New("127.0.0.1:8125") if err != nil { @@ -73,8 +83,8 @@ func main() { app := &application{ cfg, logger, - db, - data.NewModels(db), + pool, + data.NewModels(ctx, pool), rc, } @@ -89,6 +99,19 @@ func main() { } logger.Printf("starting server on %s", srv.Addr) - err = srv.ListenAndServe() - logger.Fatal(err) + go srv.ListenAndServe() + + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(signals) + + <-signals // wait for signal + + srv.Shutdown(ctx) + cancel() + + go func() { + <-signals // hard exit on second signal (in case shutdown gets stuck) + os.Exit(1) + }() } diff --git a/internal/data/accounts.go b/internal/data/accounts.go index 7c514c8..08ba2fb 100644 --- a/internal/data/accounts.go +++ b/internal/data/accounts.go @@ -1,8 +1,11 @@ package data import ( - "database/sql" + "context" "strings" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" ) type Account struct { @@ -21,25 +24,36 @@ func (a *Account) NormalizedUsername() string { } type AccountModel struct { - DB *sql.DB + ctx context.Context + pool *pgxpool.Pool } func (am *AccountModel) Upsert(a *Account) error { - query := ` - INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at) - VALUES ($1, $2, $3, $4, $5, '', 0, 0) - ON CONFLICT(username) - DO - UPDATE SET - access_token = $3, - refresh_token = $4, - expires_at = $5, - last_message_id = $6, - last_checked_at = $7 - RETURNING id` - - args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt} - return am.DB.QueryRow(query, args...).Scan(&a.ID) + return am.pool.BeginFunc(am.ctx, func(tx pgx.Tx) error { + stmt := ` + INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at) + VALUES ($1, $2, $3, $4, $5, '', 0, 0) + ON CONFLICT(username) + DO + UPDATE SET + access_token = $3, + refresh_token = $4, + expires_at = $5, + last_message_id = $6, + last_checked_at = $7 + RETURNING id` + return tx.QueryRow( + am.ctx, + stmt, + a.NormalizedUsername(), + a.AccountID, + a.AccessToken, + a.RefreshToken, + a.ExpiresAt, + a.LastMessageID, + a.LastCheckedAt, + ).Scan(&a.ID) + }) } func (am *AccountModel) Delete(id int64) error { diff --git a/internal/data/device_accounts.go b/internal/data/device_accounts.go index 8cbc67b..8715c4f 100644 --- a/internal/data/device_accounts.go +++ b/internal/data/device_accounts.go @@ -1,6 +1,10 @@ package data -import "database/sql" +import ( + "context" + + "github.com/jackc/pgx/v4/pgxpool" +) type DeviceAccount struct { ID int64 @@ -9,30 +13,29 @@ type DeviceAccount struct { } type DeviceAccountModel struct { - DB *sql.DB + ctx context.Context + pool *pgxpool.Pool } func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error { - query := ` + stmt := ` INSERT INTO devices_accounts (account_id, device_id) VALUES ($1, $2) ON CONFLICT (account_id, device_id) DO NOTHING RETURNING id` - - args := []interface{}{accountID, deviceID} - if err := dam.DB.QueryRow(query, args...).Err(); err != nil { + if _, err := dam.pool.Exec(dam.ctx, stmt, accountID, deviceID); err != nil { return err } // Update account counter - query = ` + stmt = ` UPDATE accounts SET device_count = ( SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1 ) WHERE id = $1` - args = []interface{}{accountID} - return dam.DB.QueryRow(query, args...).Err() + _, err := dam.pool.Exec(dam.ctx, stmt, accountID) + return err } type MockDeviceAccountModel struct{} diff --git a/internal/data/devices.go b/internal/data/devices.go index 925958d..6ab6151 100644 --- a/internal/data/devices.go +++ b/internal/data/devices.go @@ -1,9 +1,11 @@ package data import ( - "database/sql" - "errors" + "context" "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" ) type Device struct { @@ -14,45 +16,45 @@ type Device struct { } type DeviceModel struct { - DB *sql.DB + ctx context.Context + pool *pgxpool.Pool } func (dm *DeviceModel) Upsert(d *Device) error { d.LastPingedAt = time.Now().Unix() - query := ` - INSERT INTO devices (apns_token, sandbox, last_pinged_at) - VALUES ($1, $2, $3) - ON CONFLICT(apns_token) - DO - UPDATE SET last_pinged_at = $3 - RETURNING id` - - args := []interface{}{d.APNSToken, d.Sandbox, d.LastPingedAt} - return dm.DB.QueryRow(query, args...).Scan(&d.ID) + return dm.pool.BeginFunc(dm.ctx, func(tx pgx.Tx) error { + stmt := ` + INSERT INTO devices (apns_token, sandbox, last_pinged_at) + VALUES ($1, $2, $3) + ON CONFLICT(apns_token) + DO + UPDATE SET last_pinged_at = $3 + RETURNING id` + return tx.QueryRow( + dm.ctx, + stmt, + d.APNSToken, + d.Sandbox, + d.LastPingedAt, + ).Scan(&d.ID) + }) } func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) { - query := ` + device := &Device{} + stmt := ` SELECT id, apns_token, sandbox, last_pinged_at FROM devices WHERE apns_token = $1` - device := &Device{} - err := dm.DB.QueryRow(query, token).Scan( + if err := dm.pool.QueryRow(dm.ctx, stmt, token).Scan( &device.ID, &device.APNSToken, &device.Sandbox, &device.LastPingedAt, - ) - - if err != nil { - switch { - case errors.Is(err, sql.ErrNoRows): - return nil, ErrRecordNotFound - default: - return nil, err - } + ); err != nil { + return nil, err } return device, nil } diff --git a/internal/data/models.go b/internal/data/models.go index 1e3dd79..c615802 100644 --- a/internal/data/models.go +++ b/internal/data/models.go @@ -1,8 +1,10 @@ package data import ( - "database/sql" + "context" "errors" + + "github.com/jackc/pgx/v4/pgxpool" ) var ( @@ -25,11 +27,11 @@ type Models struct { } } -func NewModels(db *sql.DB) *Models { +func NewModels(ctx context.Context, pool *pgxpool.Pool) *Models { return &Models{ - Accounts: &AccountModel{DB: db}, - Devices: &DeviceModel{DB: db}, - DevicesAccounts: &DeviceAccountModel{DB: db}, + Accounts: &AccountModel{ctx, pool}, + Devices: &DeviceModel{ctx, pool}, + DevicesAccounts: &DeviceAccountModel{ctx, pool}, } }