change API to pgxpool

This commit is contained in:
Andre Medeiros 2021-07-13 10:17:20 -04:00
parent 49a89d3674
commit b5438495ec
5 changed files with 119 additions and 75 deletions

View file

@ -1,15 +1,17 @@
package main package main
import ( import (
"database/sql" "context"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
"os/signal"
"syscall"
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/joho/godotenv" "github.com/joho/godotenv"
_ "github.com/lib/pq"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data" "github.com/christianselig/apollo-backend/internal/data"
@ -23,12 +25,15 @@ type config struct {
type application struct { type application struct {
cfg config cfg config
logger *logrus.Logger logger *logrus.Logger
db *sql.DB pool *pgxpool.Pool
models *data.Models models *data.Models
client *reddit.Client client *reddit.Client
} }
func main() { func main() {
_ = godotenv.Load()
ctx, cancel := context.WithCancel(context.Background())
var logger *logrus.Logger var logger *logrus.Logger
{ {
logger = logrus.New() logger = logrus.New()
@ -42,22 +47,27 @@ func main() {
} }
} }
if err := godotenv.Load(); err != nil {
logger.Printf("Couldn't find .env so I will read from existing ENV.")
}
var cfg config var cfg config
dburl, ok := os.LookupEnv("DATABASE_CONNECTION_POOL_URL") // Set up Postgres connection
if !ok { var pool *pgxpool.Pool
dburl = os.Getenv("DATABASE_URL") {
url := fmt.Sprintf("%s?sslmode=require", os.Getenv("DATABASE_CONNECTION_POOL_URL"))
config, err := pgxpool.ParseConfig(url)
if err != nil {
panic(err)
} }
db, err := sql.Open("postgres", fmt.Sprintf("%s?binary_parameters=yes", dburl)) // Setting the build statement cache to nil helps this work with pgbouncer
config.ConnConfig.BuildStatementCache = nil
config.ConnConfig.PreferSimpleProtocol = true
pool, err = pgxpool.ConnectConfig(ctx, config)
if err != nil { if err != nil {
log.Fatal(err) panic(err)
}
defer pool.Close()
} }
defer db.Close()
statsd, err := statsd.New("127.0.0.1:8125") statsd, err := statsd.New("127.0.0.1:8125")
if err != nil { if err != nil {
@ -73,8 +83,8 @@ func main() {
app := &application{ app := &application{
cfg, cfg,
logger, logger,
db, pool,
data.NewModels(db), data.NewModels(ctx, pool),
rc, rc,
} }
@ -89,6 +99,19 @@ func main() {
} }
logger.Printf("starting server on %s", srv.Addr) logger.Printf("starting server on %s", srv.Addr)
err = srv.ListenAndServe() go srv.ListenAndServe()
logger.Fatal(err)
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(signals)
<-signals // wait for signal
srv.Shutdown(ctx)
cancel()
go func() {
<-signals // hard exit on second signal (in case shutdown gets stuck)
os.Exit(1)
}()
} }

View file

@ -1,8 +1,11 @@
package data package data
import ( import (
"database/sql" "context"
"strings" "strings"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
) )
type Account struct { type Account struct {
@ -21,11 +24,13 @@ func (a *Account) NormalizedUsername() string {
} }
type AccountModel struct { type AccountModel struct {
DB *sql.DB ctx context.Context
pool *pgxpool.Pool
} }
func (am *AccountModel) Upsert(a *Account) error { func (am *AccountModel) Upsert(a *Account) error {
query := ` return am.pool.BeginFunc(am.ctx, func(tx pgx.Tx) error {
stmt := `
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at) INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
VALUES ($1, $2, $3, $4, $5, '', 0, 0) VALUES ($1, $2, $3, $4, $5, '', 0, 0)
ON CONFLICT(username) ON CONFLICT(username)
@ -37,9 +42,18 @@ func (am *AccountModel) Upsert(a *Account) error {
last_message_id = $6, last_message_id = $6,
last_checked_at = $7 last_checked_at = $7
RETURNING id` RETURNING id`
return tx.QueryRow(
args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt} am.ctx,
return am.DB.QueryRow(query, args...).Scan(&a.ID) stmt,
a.NormalizedUsername(),
a.AccountID,
a.AccessToken,
a.RefreshToken,
a.ExpiresAt,
a.LastMessageID,
a.LastCheckedAt,
).Scan(&a.ID)
})
} }
func (am *AccountModel) Delete(id int64) error { func (am *AccountModel) Delete(id int64) error {

View file

@ -1,6 +1,10 @@
package data package data
import "database/sql" import (
"context"
"github.com/jackc/pgx/v4/pgxpool"
)
type DeviceAccount struct { type DeviceAccount struct {
ID int64 ID int64
@ -9,30 +13,29 @@ type DeviceAccount struct {
} }
type DeviceAccountModel struct { type DeviceAccountModel struct {
DB *sql.DB ctx context.Context
pool *pgxpool.Pool
} }
func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error { func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error {
query := ` stmt := `
INSERT INTO devices_accounts (account_id, device_id) INSERT INTO devices_accounts (account_id, device_id)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (account_id, device_id) DO NOTHING ON CONFLICT (account_id, device_id) DO NOTHING
RETURNING id` RETURNING id`
if _, err := dam.pool.Exec(dam.ctx, stmt, accountID, deviceID); err != nil {
args := []interface{}{accountID, deviceID}
if err := dam.DB.QueryRow(query, args...).Err(); err != nil {
return err return err
} }
// Update account counter // Update account counter
query = ` stmt = `
UPDATE accounts UPDATE accounts
SET device_count = ( SET device_count = (
SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1 SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1
) )
WHERE id = $1` WHERE id = $1`
args = []interface{}{accountID} _, err := dam.pool.Exec(dam.ctx, stmt, accountID)
return dam.DB.QueryRow(query, args...).Err() return err
} }
type MockDeviceAccountModel struct{} type MockDeviceAccountModel struct{}

View file

@ -1,9 +1,11 @@
package data package data
import ( import (
"database/sql" "context"
"errors"
"time" "time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
) )
type Device struct { type Device struct {
@ -14,46 +16,46 @@ type Device struct {
} }
type DeviceModel struct { type DeviceModel struct {
DB *sql.DB ctx context.Context
pool *pgxpool.Pool
} }
func (dm *DeviceModel) Upsert(d *Device) error { func (dm *DeviceModel) Upsert(d *Device) error {
d.LastPingedAt = time.Now().Unix() d.LastPingedAt = time.Now().Unix()
query := ` return dm.pool.BeginFunc(dm.ctx, func(tx pgx.Tx) error {
stmt := `
INSERT INTO devices (apns_token, sandbox, last_pinged_at) INSERT INTO devices (apns_token, sandbox, last_pinged_at)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT(apns_token) ON CONFLICT(apns_token)
DO DO
UPDATE SET last_pinged_at = $3 UPDATE SET last_pinged_at = $3
RETURNING id` RETURNING id`
return tx.QueryRow(
args := []interface{}{d.APNSToken, d.Sandbox, d.LastPingedAt} dm.ctx,
return dm.DB.QueryRow(query, args...).Scan(&d.ID) stmt,
d.APNSToken,
d.Sandbox,
d.LastPingedAt,
).Scan(&d.ID)
})
} }
func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) { func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) {
query := ` device := &Device{}
stmt := `
SELECT id, apns_token, sandbox, last_pinged_at SELECT id, apns_token, sandbox, last_pinged_at
FROM devices FROM devices
WHERE apns_token = $1` WHERE apns_token = $1`
device := &Device{} if err := dm.pool.QueryRow(dm.ctx, stmt, token).Scan(
err := dm.DB.QueryRow(query, token).Scan(
&device.ID, &device.ID,
&device.APNSToken, &device.APNSToken,
&device.Sandbox, &device.Sandbox,
&device.LastPingedAt, &device.LastPingedAt,
) ); err != nil {
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err return nil, err
} }
}
return device, nil return device, nil
} }

View file

@ -1,8 +1,10 @@
package data package data
import ( import (
"database/sql" "context"
"errors" "errors"
"github.com/jackc/pgx/v4/pgxpool"
) )
var ( var (
@ -25,11 +27,11 @@ type Models struct {
} }
} }
func NewModels(db *sql.DB) *Models { func NewModels(ctx context.Context, pool *pgxpool.Pool) *Models {
return &Models{ return &Models{
Accounts: &AccountModel{DB: db}, Accounts: &AccountModel{ctx, pool},
Devices: &DeviceModel{DB: db}, Devices: &DeviceModel{ctx, pool},
DevicesAccounts: &DeviceAccountModel{DB: db}, DevicesAccounts: &DeviceAccountModel{ctx, pool},
} }
} }