change API to pgxpool

This commit is contained in:
Andre Medeiros 2021-07-13 10:17:20 -04:00
parent 49a89d3674
commit b5438495ec
5 changed files with 119 additions and 75 deletions

View file

@ -1,15 +1,17 @@
package main
import (
"database/sql"
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/DataDog/datadog-go/statsd"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/joho/godotenv"
_ "github.com/lib/pq"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data"
@ -23,12 +25,15 @@ type config struct {
type application struct {
cfg config
logger *logrus.Logger
db *sql.DB
pool *pgxpool.Pool
models *data.Models
client *reddit.Client
}
func main() {
_ = godotenv.Load()
ctx, cancel := context.WithCancel(context.Background())
var logger *logrus.Logger
{
logger = logrus.New()
@ -42,22 +47,27 @@ func main() {
}
}
if err := godotenv.Load(); err != nil {
logger.Printf("Couldn't find .env so I will read from existing ENV.")
}
var cfg config
dburl, ok := os.LookupEnv("DATABASE_CONNECTION_POOL_URL")
if !ok {
dburl = os.Getenv("DATABASE_URL")
}
// Set up Postgres connection
var pool *pgxpool.Pool
{
url := fmt.Sprintf("%s?sslmode=require", os.Getenv("DATABASE_CONNECTION_POOL_URL"))
config, err := pgxpool.ParseConfig(url)
if err != nil {
panic(err)
}
db, err := sql.Open("postgres", fmt.Sprintf("%s?binary_parameters=yes", dburl))
if err != nil {
log.Fatal(err)
// Setting the build statement cache to nil helps this work with pgbouncer
config.ConnConfig.BuildStatementCache = nil
config.ConnConfig.PreferSimpleProtocol = true
pool, err = pgxpool.ConnectConfig(ctx, config)
if err != nil {
panic(err)
}
defer pool.Close()
}
defer db.Close()
statsd, err := statsd.New("127.0.0.1:8125")
if err != nil {
@ -73,8 +83,8 @@ func main() {
app := &application{
cfg,
logger,
db,
data.NewModels(db),
pool,
data.NewModels(ctx, pool),
rc,
}
@ -89,6 +99,19 @@ func main() {
}
logger.Printf("starting server on %s", srv.Addr)
err = srv.ListenAndServe()
logger.Fatal(err)
go srv.ListenAndServe()
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(signals)
<-signals // wait for signal
srv.Shutdown(ctx)
cancel()
go func() {
<-signals // hard exit on second signal (in case shutdown gets stuck)
os.Exit(1)
}()
}

View file

@ -1,8 +1,11 @@
package data
import (
"database/sql"
"context"
"strings"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
type Account struct {
@ -21,25 +24,36 @@ func (a *Account) NormalizedUsername() string {
}
type AccountModel struct {
DB *sql.DB
ctx context.Context
pool *pgxpool.Pool
}
func (am *AccountModel) Upsert(a *Account) error {
query := `
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
ON CONFLICT(username)
DO
UPDATE SET
access_token = $3,
refresh_token = $4,
expires_at = $5,
last_message_id = $6,
last_checked_at = $7
RETURNING id`
args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt}
return am.DB.QueryRow(query, args...).Scan(&a.ID)
return am.pool.BeginFunc(am.ctx, func(tx pgx.Tx) error {
stmt := `
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
ON CONFLICT(username)
DO
UPDATE SET
access_token = $3,
refresh_token = $4,
expires_at = $5,
last_message_id = $6,
last_checked_at = $7
RETURNING id`
return tx.QueryRow(
am.ctx,
stmt,
a.NormalizedUsername(),
a.AccountID,
a.AccessToken,
a.RefreshToken,
a.ExpiresAt,
a.LastMessageID,
a.LastCheckedAt,
).Scan(&a.ID)
})
}
func (am *AccountModel) Delete(id int64) error {

View file

@ -1,6 +1,10 @@
package data
import "database/sql"
import (
"context"
"github.com/jackc/pgx/v4/pgxpool"
)
type DeviceAccount struct {
ID int64
@ -9,30 +13,29 @@ type DeviceAccount struct {
}
type DeviceAccountModel struct {
DB *sql.DB
ctx context.Context
pool *pgxpool.Pool
}
func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error {
query := `
stmt := `
INSERT INTO devices_accounts (account_id, device_id)
VALUES ($1, $2)
ON CONFLICT (account_id, device_id) DO NOTHING
RETURNING id`
args := []interface{}{accountID, deviceID}
if err := dam.DB.QueryRow(query, args...).Err(); err != nil {
if _, err := dam.pool.Exec(dam.ctx, stmt, accountID, deviceID); err != nil {
return err
}
// Update account counter
query = `
stmt = `
UPDATE accounts
SET device_count = (
SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1
)
WHERE id = $1`
args = []interface{}{accountID}
return dam.DB.QueryRow(query, args...).Err()
_, err := dam.pool.Exec(dam.ctx, stmt, accountID)
return err
}
type MockDeviceAccountModel struct{}

View file

@ -1,9 +1,11 @@
package data
import (
"database/sql"
"errors"
"context"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
type Device struct {
@ -14,45 +16,45 @@ type Device struct {
}
type DeviceModel struct {
DB *sql.DB
ctx context.Context
pool *pgxpool.Pool
}
func (dm *DeviceModel) Upsert(d *Device) error {
d.LastPingedAt = time.Now().Unix()
query := `
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
VALUES ($1, $2, $3)
ON CONFLICT(apns_token)
DO
UPDATE SET last_pinged_at = $3
RETURNING id`
args := []interface{}{d.APNSToken, d.Sandbox, d.LastPingedAt}
return dm.DB.QueryRow(query, args...).Scan(&d.ID)
return dm.pool.BeginFunc(dm.ctx, func(tx pgx.Tx) error {
stmt := `
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
VALUES ($1, $2, $3)
ON CONFLICT(apns_token)
DO
UPDATE SET last_pinged_at = $3
RETURNING id`
return tx.QueryRow(
dm.ctx,
stmt,
d.APNSToken,
d.Sandbox,
d.LastPingedAt,
).Scan(&d.ID)
})
}
func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) {
query := `
device := &Device{}
stmt := `
SELECT id, apns_token, sandbox, last_pinged_at
FROM devices
WHERE apns_token = $1`
device := &Device{}
err := dm.DB.QueryRow(query, token).Scan(
if err := dm.pool.QueryRow(dm.ctx, stmt, token).Scan(
&device.ID,
&device.APNSToken,
&device.Sandbox,
&device.LastPingedAt,
)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
); err != nil {
return nil, err
}
return device, nil
}

View file

@ -1,8 +1,10 @@
package data
import (
"database/sql"
"context"
"errors"
"github.com/jackc/pgx/v4/pgxpool"
)
var (
@ -25,11 +27,11 @@ type Models struct {
}
}
func NewModels(db *sql.DB) *Models {
func NewModels(ctx context.Context, pool *pgxpool.Pool) *Models {
return &Models{
Accounts: &AccountModel{DB: db},
Devices: &DeviceModel{DB: db},
DevicesAccounts: &DeviceAccountModel{DB: db},
Accounts: &AccountModel{ctx, pool},
Devices: &DeviceModel{ctx, pool},
DevicesAccounts: &DeviceAccountModel{ctx, pool},
}
}