mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-13 07:27:43 +00:00
change API to pgxpool
This commit is contained in:
parent
49a89d3674
commit
b5438495ec
5 changed files with 119 additions and 75 deletions
|
@ -1,15 +1,17 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/DataDog/datadog-go/statsd"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/joho/godotenv"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/christianselig/apollo-backend/internal/data"
|
||||
|
@ -23,12 +25,15 @@ type config struct {
|
|||
type application struct {
|
||||
cfg config
|
||||
logger *logrus.Logger
|
||||
db *sql.DB
|
||||
pool *pgxpool.Pool
|
||||
models *data.Models
|
||||
client *reddit.Client
|
||||
}
|
||||
|
||||
func main() {
|
||||
_ = godotenv.Load()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var logger *logrus.Logger
|
||||
{
|
||||
logger = logrus.New()
|
||||
|
@ -42,22 +47,27 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
if err := godotenv.Load(); err != nil {
|
||||
logger.Printf("Couldn't find .env so I will read from existing ENV.")
|
||||
}
|
||||
|
||||
var cfg config
|
||||
|
||||
dburl, ok := os.LookupEnv("DATABASE_CONNECTION_POOL_URL")
|
||||
if !ok {
|
||||
dburl = os.Getenv("DATABASE_URL")
|
||||
}
|
||||
// Set up Postgres connection
|
||||
var pool *pgxpool.Pool
|
||||
{
|
||||
url := fmt.Sprintf("%s?sslmode=require", os.Getenv("DATABASE_CONNECTION_POOL_URL"))
|
||||
config, err := pgxpool.ParseConfig(url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("%s?binary_parameters=yes", dburl))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
// Setting the build statement cache to nil helps this work with pgbouncer
|
||||
config.ConnConfig.BuildStatementCache = nil
|
||||
config.ConnConfig.PreferSimpleProtocol = true
|
||||
|
||||
pool, err = pgxpool.ConnectConfig(ctx, config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
statsd, err := statsd.New("127.0.0.1:8125")
|
||||
if err != nil {
|
||||
|
@ -73,8 +83,8 @@ func main() {
|
|||
app := &application{
|
||||
cfg,
|
||||
logger,
|
||||
db,
|
||||
data.NewModels(db),
|
||||
pool,
|
||||
data.NewModels(ctx, pool),
|
||||
rc,
|
||||
}
|
||||
|
||||
|
@ -89,6 +99,19 @@ func main() {
|
|||
}
|
||||
|
||||
logger.Printf("starting server on %s", srv.Addr)
|
||||
err = srv.ListenAndServe()
|
||||
logger.Fatal(err)
|
||||
go srv.ListenAndServe()
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(signals)
|
||||
|
||||
<-signals // wait for signal
|
||||
|
||||
srv.Shutdown(ctx)
|
||||
cancel()
|
||||
|
||||
go func() {
|
||||
<-signals // hard exit on second signal (in case shutdown gets stuck)
|
||||
os.Exit(1)
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package data
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
|
@ -21,25 +24,36 @@ func (a *Account) NormalizedUsername() string {
|
|||
}
|
||||
|
||||
type AccountModel struct {
|
||||
DB *sql.DB
|
||||
ctx context.Context
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (am *AccountModel) Upsert(a *Account) error {
|
||||
query := `
|
||||
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
|
||||
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
|
||||
ON CONFLICT(username)
|
||||
DO
|
||||
UPDATE SET
|
||||
access_token = $3,
|
||||
refresh_token = $4,
|
||||
expires_at = $5,
|
||||
last_message_id = $6,
|
||||
last_checked_at = $7
|
||||
RETURNING id`
|
||||
|
||||
args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt}
|
||||
return am.DB.QueryRow(query, args...).Scan(&a.ID)
|
||||
return am.pool.BeginFunc(am.ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
|
||||
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
|
||||
ON CONFLICT(username)
|
||||
DO
|
||||
UPDATE SET
|
||||
access_token = $3,
|
||||
refresh_token = $4,
|
||||
expires_at = $5,
|
||||
last_message_id = $6,
|
||||
last_checked_at = $7
|
||||
RETURNING id`
|
||||
return tx.QueryRow(
|
||||
am.ctx,
|
||||
stmt,
|
||||
a.NormalizedUsername(),
|
||||
a.AccountID,
|
||||
a.AccessToken,
|
||||
a.RefreshToken,
|
||||
a.ExpiresAt,
|
||||
a.LastMessageID,
|
||||
a.LastCheckedAt,
|
||||
).Scan(&a.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func (am *AccountModel) Delete(id int64) error {
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package data
|
||||
|
||||
import "database/sql"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
type DeviceAccount struct {
|
||||
ID int64
|
||||
|
@ -9,30 +13,29 @@ type DeviceAccount struct {
|
|||
}
|
||||
|
||||
type DeviceAccountModel struct {
|
||||
DB *sql.DB
|
||||
ctx context.Context
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error {
|
||||
query := `
|
||||
stmt := `
|
||||
INSERT INTO devices_accounts (account_id, device_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (account_id, device_id) DO NOTHING
|
||||
RETURNING id`
|
||||
|
||||
args := []interface{}{accountID, deviceID}
|
||||
if err := dam.DB.QueryRow(query, args...).Err(); err != nil {
|
||||
if _, err := dam.pool.Exec(dam.ctx, stmt, accountID, deviceID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update account counter
|
||||
query = `
|
||||
stmt = `
|
||||
UPDATE accounts
|
||||
SET device_count = (
|
||||
SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1
|
||||
)
|
||||
WHERE id = $1`
|
||||
args = []interface{}{accountID}
|
||||
return dam.DB.QueryRow(query, args...).Err()
|
||||
_, err := dam.pool.Exec(dam.ctx, stmt, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
type MockDeviceAccountModel struct{}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package data
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
|
@ -14,45 +16,45 @@ type Device struct {
|
|||
}
|
||||
|
||||
type DeviceModel struct {
|
||||
DB *sql.DB
|
||||
ctx context.Context
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (dm *DeviceModel) Upsert(d *Device) error {
|
||||
d.LastPingedAt = time.Now().Unix()
|
||||
|
||||
query := `
|
||||
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT(apns_token)
|
||||
DO
|
||||
UPDATE SET last_pinged_at = $3
|
||||
RETURNING id`
|
||||
|
||||
args := []interface{}{d.APNSToken, d.Sandbox, d.LastPingedAt}
|
||||
return dm.DB.QueryRow(query, args...).Scan(&d.ID)
|
||||
return dm.pool.BeginFunc(dm.ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT(apns_token)
|
||||
DO
|
||||
UPDATE SET last_pinged_at = $3
|
||||
RETURNING id`
|
||||
return tx.QueryRow(
|
||||
dm.ctx,
|
||||
stmt,
|
||||
d.APNSToken,
|
||||
d.Sandbox,
|
||||
d.LastPingedAt,
|
||||
).Scan(&d.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) {
|
||||
query := `
|
||||
device := &Device{}
|
||||
stmt := `
|
||||
SELECT id, apns_token, sandbox, last_pinged_at
|
||||
FROM devices
|
||||
WHERE apns_token = $1`
|
||||
|
||||
device := &Device{}
|
||||
err := dm.DB.QueryRow(query, token).Scan(
|
||||
if err := dm.pool.QueryRow(dm.ctx, stmt, token).Scan(
|
||||
&device.ID,
|
||||
&device.APNSToken,
|
||||
&device.Sandbox,
|
||||
&device.LastPingedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return nil, ErrRecordNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return device, nil
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package data
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -25,11 +27,11 @@ type Models struct {
|
|||
}
|
||||
}
|
||||
|
||||
func NewModels(db *sql.DB) *Models {
|
||||
func NewModels(ctx context.Context, pool *pgxpool.Pool) *Models {
|
||||
return &Models{
|
||||
Accounts: &AccountModel{DB: db},
|
||||
Devices: &DeviceModel{DB: db},
|
||||
DevicesAccounts: &DeviceAccountModel{DB: db},
|
||||
Accounts: &AccountModel{ctx, pool},
|
||||
Devices: &DeviceModel{ctx, pool},
|
||||
DevicesAccounts: &DeviceAccountModel{ctx, pool},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue