mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-25 21:27:42 +00:00
change API to pgxpool
This commit is contained in:
parent
49a89d3674
commit
b5438495ec
5 changed files with 119 additions and 75 deletions
|
@ -1,15 +1,17 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/data"
|
"github.com/christianselig/apollo-backend/internal/data"
|
||||||
|
@ -23,12 +25,15 @@ type config struct {
|
||||||
type application struct {
|
type application struct {
|
||||||
cfg config
|
cfg config
|
||||||
logger *logrus.Logger
|
logger *logrus.Logger
|
||||||
db *sql.DB
|
pool *pgxpool.Pool
|
||||||
models *data.Models
|
models *data.Models
|
||||||
client *reddit.Client
|
client *reddit.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
_ = godotenv.Load()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
var logger *logrus.Logger
|
var logger *logrus.Logger
|
||||||
{
|
{
|
||||||
logger = logrus.New()
|
logger = logrus.New()
|
||||||
|
@ -42,22 +47,27 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := godotenv.Load(); err != nil {
|
|
||||||
logger.Printf("Couldn't find .env so I will read from existing ENV.")
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg config
|
var cfg config
|
||||||
|
|
||||||
dburl, ok := os.LookupEnv("DATABASE_CONNECTION_POOL_URL")
|
// Set up Postgres connection
|
||||||
if !ok {
|
var pool *pgxpool.Pool
|
||||||
dburl = os.Getenv("DATABASE_URL")
|
{
|
||||||
|
url := fmt.Sprintf("%s?sslmode=require", os.Getenv("DATABASE_CONNECTION_POOL_URL"))
|
||||||
|
config, err := pgxpool.ParseConfig(url)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("postgres", fmt.Sprintf("%s?binary_parameters=yes", dburl))
|
// Setting the build statement cache to nil helps this work with pgbouncer
|
||||||
|
config.ConnConfig.BuildStatementCache = nil
|
||||||
|
config.ConnConfig.PreferSimpleProtocol = true
|
||||||
|
|
||||||
|
pool, err = pgxpool.ConnectConfig(ctx, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
panic(err)
|
||||||
|
}
|
||||||
|
defer pool.Close()
|
||||||
}
|
}
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
statsd, err := statsd.New("127.0.0.1:8125")
|
statsd, err := statsd.New("127.0.0.1:8125")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -73,8 +83,8 @@ func main() {
|
||||||
app := &application{
|
app := &application{
|
||||||
cfg,
|
cfg,
|
||||||
logger,
|
logger,
|
||||||
db,
|
pool,
|
||||||
data.NewModels(db),
|
data.NewModels(ctx, pool),
|
||||||
rc,
|
rc,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,6 +99,19 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("starting server on %s", srv.Addr)
|
logger.Printf("starting server on %s", srv.Addr)
|
||||||
err = srv.ListenAndServe()
|
go srv.ListenAndServe()
|
||||||
logger.Fatal(err)
|
|
||||||
|
signals := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer signal.Stop(signals)
|
||||||
|
|
||||||
|
<-signals // wait for signal
|
||||||
|
|
||||||
|
srv.Shutdown(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-signals // hard exit on second signal (in case shutdown gets stuck)
|
||||||
|
os.Exit(1)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
package data
|
package data
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Account struct {
|
type Account struct {
|
||||||
|
@ -21,11 +24,13 @@ func (a *Account) NormalizedUsername() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountModel struct {
|
type AccountModel struct {
|
||||||
DB *sql.DB
|
ctx context.Context
|
||||||
|
pool *pgxpool.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AccountModel) Upsert(a *Account) error {
|
func (am *AccountModel) Upsert(a *Account) error {
|
||||||
query := `
|
return am.pool.BeginFunc(am.ctx, func(tx pgx.Tx) error {
|
||||||
|
stmt := `
|
||||||
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
|
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
|
VALUES ($1, $2, $3, $4, $5, '', 0, 0)
|
||||||
ON CONFLICT(username)
|
ON CONFLICT(username)
|
||||||
|
@ -37,9 +42,18 @@ func (am *AccountModel) Upsert(a *Account) error {
|
||||||
last_message_id = $6,
|
last_message_id = $6,
|
||||||
last_checked_at = $7
|
last_checked_at = $7
|
||||||
RETURNING id`
|
RETURNING id`
|
||||||
|
return tx.QueryRow(
|
||||||
args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt}
|
am.ctx,
|
||||||
return am.DB.QueryRow(query, args...).Scan(&a.ID)
|
stmt,
|
||||||
|
a.NormalizedUsername(),
|
||||||
|
a.AccountID,
|
||||||
|
a.AccessToken,
|
||||||
|
a.RefreshToken,
|
||||||
|
a.ExpiresAt,
|
||||||
|
a.LastMessageID,
|
||||||
|
a.LastCheckedAt,
|
||||||
|
).Scan(&a.ID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AccountModel) Delete(id int64) error {
|
func (am *AccountModel) Delete(id int64) error {
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package data
|
package data
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
type DeviceAccount struct {
|
type DeviceAccount struct {
|
||||||
ID int64
|
ID int64
|
||||||
|
@ -9,30 +13,29 @@ type DeviceAccount struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceAccountModel struct {
|
type DeviceAccountModel struct {
|
||||||
DB *sql.DB
|
ctx context.Context
|
||||||
|
pool *pgxpool.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error {
|
func (dam *DeviceAccountModel) Associate(accountID int64, deviceID int64) error {
|
||||||
query := `
|
stmt := `
|
||||||
INSERT INTO devices_accounts (account_id, device_id)
|
INSERT INTO devices_accounts (account_id, device_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
ON CONFLICT (account_id, device_id) DO NOTHING
|
ON CONFLICT (account_id, device_id) DO NOTHING
|
||||||
RETURNING id`
|
RETURNING id`
|
||||||
|
if _, err := dam.pool.Exec(dam.ctx, stmt, accountID, deviceID); err != nil {
|
||||||
args := []interface{}{accountID, deviceID}
|
|
||||||
if err := dam.DB.QueryRow(query, args...).Err(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update account counter
|
// Update account counter
|
||||||
query = `
|
stmt = `
|
||||||
UPDATE accounts
|
UPDATE accounts
|
||||||
SET device_count = (
|
SET device_count = (
|
||||||
SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1
|
SELECT COUNT(*) FROM devices_accounts WHERE account_id = $1
|
||||||
)
|
)
|
||||||
WHERE id = $1`
|
WHERE id = $1`
|
||||||
args = []interface{}{accountID}
|
_, err := dam.pool.Exec(dam.ctx, stmt, accountID)
|
||||||
return dam.DB.QueryRow(query, args...).Err()
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockDeviceAccountModel struct{}
|
type MockDeviceAccountModel struct{}
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package data
|
package data
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"errors"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
|
@ -14,46 +16,46 @@ type Device struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceModel struct {
|
type DeviceModel struct {
|
||||||
DB *sql.DB
|
ctx context.Context
|
||||||
|
pool *pgxpool.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dm *DeviceModel) Upsert(d *Device) error {
|
func (dm *DeviceModel) Upsert(d *Device) error {
|
||||||
d.LastPingedAt = time.Now().Unix()
|
d.LastPingedAt = time.Now().Unix()
|
||||||
|
|
||||||
query := `
|
return dm.pool.BeginFunc(dm.ctx, func(tx pgx.Tx) error {
|
||||||
|
stmt := `
|
||||||
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
|
INSERT INTO devices (apns_token, sandbox, last_pinged_at)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
ON CONFLICT(apns_token)
|
ON CONFLICT(apns_token)
|
||||||
DO
|
DO
|
||||||
UPDATE SET last_pinged_at = $3
|
UPDATE SET last_pinged_at = $3
|
||||||
RETURNING id`
|
RETURNING id`
|
||||||
|
return tx.QueryRow(
|
||||||
args := []interface{}{d.APNSToken, d.Sandbox, d.LastPingedAt}
|
dm.ctx,
|
||||||
return dm.DB.QueryRow(query, args...).Scan(&d.ID)
|
stmt,
|
||||||
|
d.APNSToken,
|
||||||
|
d.Sandbox,
|
||||||
|
d.LastPingedAt,
|
||||||
|
).Scan(&d.ID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) {
|
func (dm *DeviceModel) GetByAPNSToken(token string) (*Device, error) {
|
||||||
query := `
|
device := &Device{}
|
||||||
|
stmt := `
|
||||||
SELECT id, apns_token, sandbox, last_pinged_at
|
SELECT id, apns_token, sandbox, last_pinged_at
|
||||||
FROM devices
|
FROM devices
|
||||||
WHERE apns_token = $1`
|
WHERE apns_token = $1`
|
||||||
|
|
||||||
device := &Device{}
|
if err := dm.pool.QueryRow(dm.ctx, stmt, token).Scan(
|
||||||
err := dm.DB.QueryRow(query, token).Scan(
|
|
||||||
&device.ID,
|
&device.ID,
|
||||||
&device.APNSToken,
|
&device.APNSToken,
|
||||||
&device.Sandbox,
|
&device.Sandbox,
|
||||||
&device.LastPingedAt,
|
&device.LastPingedAt,
|
||||||
)
|
); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
switch {
|
|
||||||
case errors.Is(err, sql.ErrNoRows):
|
|
||||||
return nil, ErrRecordNotFound
|
|
||||||
default:
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return device, nil
|
return device, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package data
|
package data
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -25,11 +27,11 @@ type Models struct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewModels(db *sql.DB) *Models {
|
func NewModels(ctx context.Context, pool *pgxpool.Pool) *Models {
|
||||||
return &Models{
|
return &Models{
|
||||||
Accounts: &AccountModel{DB: db},
|
Accounts: &AccountModel{ctx, pool},
|
||||||
Devices: &DeviceModel{DB: db},
|
Devices: &DeviceModel{ctx, pool},
|
||||||
DevicesAccounts: &DeviceAccountModel{DB: db},
|
DevicesAccounts: &DeviceAccountModel{ctx, pool},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue