Better testing (#62)

* some tests

* more tests

* tidy up go.mod

* more tests

* add postgres

* beep

* again

* Set up schema

* fix device test
This commit is contained in:
André Medeiros 2022-05-07 12:37:21 -04:00 committed by GitHub
parent 17019cecfb
commit f9b9c595cf
57 changed files with 674 additions and 550 deletions

4
.env.test Normal file
View file

@ -0,0 +1,4 @@
DATABASE_URL=postgres://andremedeiros@localhost/apollo_test?sslmode=disable
DATABASE_CONNECTION_POOL_URL=postgres://andremedeiros@localhost/apollo_test?sslmode=disable
REDIS_URL=redis://127.0.0.1:6379
STATSD_URL=127.0.0.1:8125

View file

@ -8,13 +8,35 @@ jobs:
go-version: [1.18]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
env:
DATABASE_URL: postgres://postgres:postgres@localhost/apollo_test
services:
postgres:
image: postgres
env:
POSTGRES_DB: apollo_test
POSTGRES_PASSWORD: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
steps:
- uses: actions/checkout@v2
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Lint
uses: golangci/golangci-lint-action@v2
- name: Setup database schema
run: psql -f docs/schema.sql $DATABASE_URL
- name: Test
run: go test ./... -v -race -timeout 5s

28
.golangci.yml Normal file
View file

@ -0,0 +1,28 @@
linters:
enable:
- bodyclose # checks whether HTTP response body is closed successfully
- errcheck # checks for unchecked errors in go programs
- errname # checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`
- exportloopref # checks for pointers to enclosing loop variables
- gochecknoinits # checks that no init functions are present in Go code
- ifshort # checks that your code uses short syntax for if-statements whenever possible
- importas # enforces consistent import aliases
- ineffassign # detects when assignments to existing variables are not used
- noctx # finds sending http request without context.Context
- paralleltest # detects missing usage of t.Parallel() method in go tests
- prealloc # finds slice declarations that could potentially be preallocated
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- tenv # detects using os.Setenv instead of t.Setenv
- testpackage # makes you use a separate _test package
- thelper # detects golang test helpers without t.Helper() call and checks consistency of test helpers
- unconvert # removes unnecessary type conversions
- unparam # removes unused function parameters
fast: true
issues:
exclude-rules:
# False positive: https://github.com/kunwardeep/paralleltest/issues/8.
- linters:
- paralleltest
text: "does not use range value in test Run"

View file

@ -1,8 +0,0 @@
tap 'heroku/brew'
brew 'heroku'
brew 'foreman'
brew 'golang'
brew 'golang-migrate'
brew 'postgres'
brew 'pgbouncer'

View file

@ -1,218 +0,0 @@
{
"entries": {
"brew": {
"postgres": {
"version": "13.3",
"bottle": {
"rebuild": 0,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_big_sur": {
"cellar": "/opt/homebrew/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/postgresql/blobs/sha256:7c0e1b76d60b428facd521c729323221712d7f6d9954e21da389aeeb2c62348e",
"sha256": "7c0e1b76d60b428facd521c729323221712d7f6d9954e21da389aeeb2c62348e"
},
"big_sur": {
"cellar": "/usr/local/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/postgresql/blobs/sha256:eaf28965ead970ecfb327b121ec6a07f0a4e39865797a1a0383605a17e5911e3",
"sha256": "eaf28965ead970ecfb327b121ec6a07f0a4e39865797a1a0383605a17e5911e3"
},
"catalina": {
"cellar": "/usr/local/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/postgresql/blobs/sha256:74e946503c73cd0efc55ad4b373efbd8f4fb8a9e26a670b878c6db25794aea4a",
"sha256": "74e946503c73cd0efc55ad4b373efbd8f4fb8a9e26a670b878c6db25794aea4a"
},
"mojave": {
"cellar": "/usr/local/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/postgresql/blobs/sha256:36c7bde4788571e5b66ffe05b6174b62c69781d61c53c3ebcd9d278e8f148197",
"sha256": "36c7bde4788571e5b66ffe05b6174b62c69781d61c53c3ebcd9d278e8f148197"
}
}
}
},
"golang": {
"version": "1.16.5",
"bottle": {
"rebuild": 0,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_big_sur": {
"cellar": "/opt/homebrew/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/go/blobs/sha256:dde21eedfa67da23db70cf977ae82c0cadd5acf2a326cb91853ff54d0cf5886f",
"sha256": "dde21eedfa67da23db70cf977ae82c0cadd5acf2a326cb91853ff54d0cf5886f"
},
"big_sur": {
"cellar": "/usr/local/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/go/blobs/sha256:416c5e2b7247c78482a5465f79d83c0240ee0c9098c8c7429f9c7af073402cc9",
"sha256": "416c5e2b7247c78482a5465f79d83c0240ee0c9098c8c7429f9c7af073402cc9"
},
"catalina": {
"cellar": "/usr/local/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/go/blobs/sha256:8a7564fab7f715feed7506e3cc30f20295fd62914418fb636a5a4c4ca1fc7398",
"sha256": "8a7564fab7f715feed7506e3cc30f20295fd62914418fb636a5a4c4ca1fc7398"
},
"mojave": {
"cellar": "/usr/local/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/go/blobs/sha256:a232e1f840525ab1e9411ba4edaa74c2bb73705e8e6feb7506649a7d608f9292",
"sha256": "a232e1f840525ab1e9411ba4edaa74c2bb73705e8e6feb7506649a7d608f9292"
},
"x86_64_linux": {
"cellar": "/home/linuxbrew/.linuxbrew/Cellar",
"url": "https://ghcr.io/v2/homebrew/core/go/blobs/sha256:1434dfa5cbe0fd0edb34eab477e156640f3f07599d33105958fe18b329bcfb7d",
"sha256": "1434dfa5cbe0fd0edb34eab477e156640f3f07599d33105958fe18b329bcfb7d"
}
}
}
},
"foreman": {
"version": "0.87.2",
"bottle": {
"rebuild": 0,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_big_sur": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/foreman/blobs/sha256:575f9fbc16eca16cf479196ce44d87bb817ddb1e2eed59869ffe158d98d08a9f",
"sha256": "575f9fbc16eca16cf479196ce44d87bb817ddb1e2eed59869ffe158d98d08a9f"
},
"big_sur": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/foreman/blobs/sha256:70c762dd642d8f5aa3ca5a28e420b6c9f7befaf7699de073b7d62e174fdee88f",
"sha256": "70c762dd642d8f5aa3ca5a28e420b6c9f7befaf7699de073b7d62e174fdee88f"
},
"catalina": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/foreman/blobs/sha256:5c2b39c1f7e9667b9ebc6b7228b6cf31f06c2261c85019028272cfdda7073ea5",
"sha256": "5c2b39c1f7e9667b9ebc6b7228b6cf31f06c2261c85019028272cfdda7073ea5"
},
"mojave": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/foreman/blobs/sha256:674b5fc005986f47294acedccba6b2a2bcdc1d423e392a356f8d58cc88a2c81a",
"sha256": "674b5fc005986f47294acedccba6b2a2bcdc1d423e392a356f8d58cc88a2c81a"
},
"high_sierra": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/foreman/blobs/sha256:b0d289ff31caf33f3d549af6dd615e37588aadb243355395380c4df5b0e52d63",
"sha256": "b0d289ff31caf33f3d549af6dd615e37588aadb243355395380c4df5b0e52d63"
}
}
}
},
"faktory": {
"version": "1.5.1-1",
"bottle": false
},
"redis": {
"version": "6.2.3",
"bottle": {
"rebuild": 0,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_big_sur": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/redis/blobs/sha256:b2b3cfeca2d5f110507e9e7a7af8918786f2853e39e49b0b39de68762e5b5030",
"sha256": "b2b3cfeca2d5f110507e9e7a7af8918786f2853e39e49b0b39de68762e5b5030"
},
"big_sur": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/redis/blobs/sha256:d891c5b376746c3895098fd384fb4edba972b532848f63cbad5be20e611458ac",
"sha256": "d891c5b376746c3895098fd384fb4edba972b532848f63cbad5be20e611458ac"
},
"catalina": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/redis/blobs/sha256:a269e87b26515775a7034d9d6cb996ed63d783b5a6d681b64bab92ce93bed55b",
"sha256": "a269e87b26515775a7034d9d6cb996ed63d783b5a6d681b64bab92ce93bed55b"
},
"mojave": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/redis/blobs/sha256:3373d834552eef5f6c71889299124693de6b5d5b887e520d6db96ab51da81020",
"sha256": "3373d834552eef5f6c71889299124693de6b5d5b887e520d6db96ab51da81020"
}
}
}
},
"golang-migrate": {
"version": "4.14.1",
"bottle": {
"rebuild": 0,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_big_sur": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/golang-migrate/blobs/sha256:3565f7a03cfd1eeec3110aa8d56f03baa79b0de2718103c0095e51187ecd37ee",
"sha256": "3565f7a03cfd1eeec3110aa8d56f03baa79b0de2718103c0095e51187ecd37ee"
},
"big_sur": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/golang-migrate/blobs/sha256:5c61a106d9970b0f9b14e78e1523894d57b50cd0473f7d5a1fb1a9161dbff159",
"sha256": "5c61a106d9970b0f9b14e78e1523894d57b50cd0473f7d5a1fb1a9161dbff159"
},
"catalina": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/golang-migrate/blobs/sha256:a77af5282af35e0d073e82140b091eedf0b478c19aea36f1b06738690989cebb",
"sha256": "a77af5282af35e0d073e82140b091eedf0b478c19aea36f1b06738690989cebb"
},
"mojave": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/golang-migrate/blobs/sha256:8fa3758e979f09c171388887c831a6518e3f8df67b07668b6c8cebf76b19a653",
"sha256": "8fa3758e979f09c171388887c831a6518e3f8df67b07668b6c8cebf76b19a653"
}
}
}
},
"pgbouncer": {
"version": "1.15.0",
"bottle": {
"rebuild": 1,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_big_sur": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/pgbouncer/blobs/sha256:8107249d240e1a53f6ae84587c08129acf5c294c4022f92d5f1c731ea6956ea3",
"sha256": "8107249d240e1a53f6ae84587c08129acf5c294c4022f92d5f1c731ea6956ea3"
},
"big_sur": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/pgbouncer/blobs/sha256:09f21ff3e7b2c125d793da2ba64110392227650ae8157ef987f041959af8fe7c",
"sha256": "09f21ff3e7b2c125d793da2ba64110392227650ae8157ef987f041959af8fe7c"
},
"catalina": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/pgbouncer/blobs/sha256:fad76f523bac43aaf7859fa0085ab7c6582f9d4aeb682e677db8f5acd9c4159a",
"sha256": "fad76f523bac43aaf7859fa0085ab7c6582f9d4aeb682e677db8f5acd9c4159a"
},
"mojave": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/pgbouncer/blobs/sha256:4187ceded551fad5801a26f790e61dd7d654acc675de73a1b4bf2858920d0734",
"sha256": "4187ceded551fad5801a26f790e61dd7d654acc675de73a1b4bf2858920d0734"
}
}
}
},
"heroku": {
"version": "7.56.0",
"bottle": false
}
},
"tap": {
"contribsys/faktory": {
"revision": "ada2f8b18fe79a40a906a7371d36c1d750422ca3"
},
"heroku/brew": {
"revision": "2a5c611fc8204f37fbe1a8be955c1fc6103edd5a"
}
}
},
"system": {
"macos": {
"big_sur": {
"HOMEBREW_VERSION": "3.2.1",
"HOMEBREW_PREFIX": "/usr/local",
"Homebrew/homebrew-core": "bd8b72f3f6453d9c55c73c36830b61a34b677bb0",
"CLT": "12.5.0.22.11",
"Xcode": "12.5.1",
"macOS": "11.4"
}
}
}
}

19
Makefile Normal file
View file

@ -0,0 +1,19 @@
BREW_PREFIX ?= $(shell brew --prefix)
DATABASE_URL ?= "postgres://$(USER)@localhost/apollo_test?sslmode=disable"
test:
@DATABASE_URL=$(DATABASE_URL) go test -race -v -timeout 1s ./...
test-setup: $(BREW_PREFIX)/bin/migrate
migrate -path migrations/ -database $(DATABASE_URL) up
build:
@go build ./cmd/apollo
lint:
@golangci-lint run
$(BREW_PREFIX)/bin/migrate:
@brew install golang-migrate
.PHONY: all build deps lint test

View file

@ -1,3 +0,0 @@
web: apollo api
scheduler: apollo scheduler
worker-notifications: apollo worker --queue notifications --multiplier 128

View file

@ -1,4 +0,0 @@
DATABASE_URL=postgres://apollo:@localhost/apollo?sslmode=disable
FAKTORY_URL=tcp://localhost:7419
REDDIT_CLIENT_ID=C7MjYkx1czyRDA
REDDIT_CLIENT_SECRET=I2AsVWbrf8h4vdQxVa5Pvf84vScF1w

36
go.mod
View file

@ -1,11 +1,9 @@
module github.com/christianselig/apollo-backend
// +heroku goVersion go1.16
go 1.16
go 1.18
require (
github.com/DataDog/datadog-go v4.8.3+incompatible
github.com/Microsoft/go-winio v0.5.0 // indirect
github.com/adjust/rmq/v4 v4.0.5
github.com/bugsnag/bugsnag-go/v2 v2.1.2
github.com/dustin/go-humanize v1.0.0
@ -15,12 +13,42 @@ require (
github.com/go-redis/redismock/v8 v8.0.6
github.com/gorilla/mux v1.8.0
github.com/heroku/x v0.0.50
github.com/jackc/pgconn v1.12.0
github.com/jackc/pgx/v4 v4.16.0
github.com/joho/godotenv v1.4.0
github.com/sideshow/apns2 v0.23.0
github.com/sirupsen/logrus v1.8.1
github.com/smtp2go-oss/smtp2go-go v1.0.1 // indirect
github.com/smtp2go-oss/smtp2go-go v1.0.1
github.com/spf13/cobra v1.4.0
github.com/stretchr/testify v1.7.1
github.com/valyala/fastjson v1.6.3
)
require (
github.com/Microsoft/go-winio v0.5.0 // indirect
github.com/bugsnag/panicwrap v1.3.4 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gofrs/uuid v4.0.0+incompatible // indirect
github.com/golang-jwt/jwt/v4 v4.4.1 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/pgtype v1.11.0 // indirect
github.com/jackc/puddle v1.2.1 // indirect
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect
golang.org/x/text v0.3.7 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

18
go.sum
View file

@ -32,14 +32,10 @@ github.com/bugsnag/panicwrap v1.3.4 h1:A6sXFtDGsgU/4BLf5JT0o5uYg3EeKgGx3Sfs+/uk3
github.com/bugsnag/panicwrap v1.3.4/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
@ -92,7 +88,6 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-redis/redismock/v8 v8.0.6 h1:rtuijPgGynsRB2Y7KDACm09WvjHWS4RaG44Nm7rcj4Y=
github.com/go-redis/redismock/v8 v8.0.6/go.mod h1:sDIF73OVsmaKzYe/1FJXGiCQ4+oHYbzjpaL9Vor0sS4=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
@ -133,7 +128,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gops v0.3.22/go.mod h1:7diIdLsqpCihPSX3fQagksT/Ku/y4RL9LHTlKyEUDl8=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
@ -150,10 +144,8 @@ github.com/heroku/x v0.0.50 h1:CA0AXkSumucVJD+T+x+6c7X1iDEb+40F8GNgH5UjJwo=
github.com/heroku/x v0.0.50/go.mod h1:vr+jORZ6sG3wgEq2FAS6UbOUrz9/DxpQGN/xPHVgbSM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/hydrogen18/memlistener v0.0.0-20141126152155-54553eb933fb/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@ -174,7 +166,6 @@ github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5W
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
@ -241,23 +232,17 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
github.com/onsi/ginkgo v1.15.0/go.mod h1:hF8qUzuuC8DJGygJH3726JnCZX4MYbRB8yFfISqnKUg=
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
github.com/onsi/gomega v1.10.5/go.mod h1:gza4q3jKQJijlu05nKWRCW/GavJumGt8aNRxWg7mt48=
github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.2-0.20190227000051-27936f6d90f9/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -411,7 +396,6 @@ golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b h1:vI32FkLJNAWtGD4BwkThwEy6XS7ZLLMHkSkYfF8M0W0=
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@ -445,7 +429,6 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -454,7 +437,6 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View file

@ -143,7 +143,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
delete(accsMap, acc.NormalizedUsername())
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
tokens, err := ac.RefreshTokens()
tokens, err := ac.RefreshTokens(ctx)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
return
@ -155,7 +155,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
acc.AccessToken = tokens.AccessToken
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acc.RefreshToken, acc.AccessToken)
me, err := ac.Me()
me, err := ac.Me(ctx)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
@ -186,9 +186,9 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
_ = a.accountRepo.Disassociate(ctx, &acc, &dev)
}
go func(apns string) {
go func(ctx context.Context, apns string) {
url := fmt.Sprintf("https://apollopushserver.xyz/api/new-server-addition?apns_token=%s", apns)
req, err := http.NewRequest("POST", url, nil)
req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
req.Header.Set("Authentication", "Bearer 98g5j89aurqwfcsp9khlnvgd38fa15")
if err != nil {
@ -200,7 +200,7 @@ func (a *api) upsertAccountsHandler(w http.ResponseWriter, r *http.Request) {
resp, _ := a.httpClient.Do(req)
resp.Body.Close()
}(apns)
}(ctx, apns)
w.WriteHeader(http.StatusOK)
}
@ -221,7 +221,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
// Here we check whether the account is supplied with a valid token.
ac := a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
tokens, err := ac.RefreshTokens()
tokens, err := ac.RefreshTokens(ctx)
if err != nil {
a.logger.WithFields(logrus.Fields{
"err": err,
@ -236,7 +236,7 @@ func (a *api) upsertAccountHandler(w http.ResponseWriter, r *http.Request) {
acct.AccessToken = tokens.AccessToken
ac = a.reddit.NewAuthenticatedClient(reddit.SkipRateLimiting, acct.RefreshToken, acct.AccessToken)
me, err := ac.Me()
me, err := ac.Me(ctx)
if err != nil {
a.logger.WithFields(logrus.Fields{

View file

@ -26,8 +26,7 @@ func (a *api) contactHandler(w http.ResponseWriter, r *http.Request) {
TextBody: smr.Body,
}
_, err := smtp2go.Send(msg)
if err != nil {
if _, err := smtp2go.Send(msg); err != nil {
a.errorResponse(w, r, 500, err.Error())
return
}

View file

@ -2,7 +2,7 @@ package api
import "net/http"
func (a *api) errorResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
func (a *api) errorResponse(w http.ResponseWriter, _ *http.Request, status int, message string) {
w.Header().Set("X-Apollo-Error", message)
http.Error(w, message, status)
}

View file

@ -104,7 +104,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
}
if cwr.Type == "subreddit" || cwr.Type == "trending" {
srr, err := ac.SubredditAbout(cwr.Subreddit)
srr, err := ac.SubredditAbout(ctx, cwr.Subreddit)
if err != nil {
a.errorResponse(w, r, 422, err.Error())
return
@ -133,7 +133,7 @@ func (a *api) createWatcherHandler(w http.ResponseWriter, r *http.Request) {
watcher.WatcheeID = sr.ID
} else if cwr.Type == "user" {
urr, err := ac.UserAbout(cwr.User)
urr, err := ac.UserAbout(ctx, cwr.User)
if err != nil {
a.errorResponse(w, r, 500, err.Error())
return

View file

@ -89,9 +89,9 @@ func SchedulerCmd(ctx context.Context) *cobra.Command {
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueAccounts(ctx, logger, statsd, db, redis, luaSha, notifQueue) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueSubreddits(ctx, logger, statsd, db, []rmq.Queue{subredditQueue, trendingQueue}) })
_, _ = s.Every(200).Milliseconds().SingletonMode().Do(func() { enqueueUsers(ctx, logger, statsd, db, userQueue) })
_, _ = s.Every(1).Second().Do(func() { cleanQueues(ctx, logger, queue) })
_, _ = s.Every(1).Second().Do(func() { cleanQueues(logger, queue) })
_, _ = s.Every(1).Second().Do(func() { enqueueStuckAccounts(ctx, logger, statsd, db, stuckNotificationsQueue) })
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db, redis) })
_, _ = s.Every(1).Minute().Do(func() { reportStats(ctx, logger, statsd, db) })
_, _ = s.Every(1).Minute().Do(func() { pruneAccounts(ctx, logger, db) })
_, _ = s.Every(1).Minute().Do(func() { pruneDevices(ctx, logger, db) })
s.StartAsync()
@ -146,9 +146,7 @@ func pruneAccounts(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Poo
return
}
count := stale + orphaned
if count > 0 {
if count := stale + orphaned; count > 0 {
logger.WithFields(logrus.Fields{
"stale": stale,
"orphaned": orphaned,
@ -175,7 +173,7 @@ func pruneDevices(ctx context.Context, logger *logrus.Logger, pool *pgxpool.Pool
}
}
func cleanQueues(ctx context.Context, logger *logrus.Logger, jobsConn rmq.Connection) {
func cleanQueues(logger *logrus.Logger, jobsConn rmq.Connection) {
cleaner := rmq.NewCleaner(jobsConn)
count, err := cleaner.Clean()
if err != nil {
@ -192,7 +190,7 @@ func cleanQueues(ctx context.Context, logger *logrus.Logger, jobsConn rmq.Connec
}
}
func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool, redisConn *redis.Client) {
func reportStats(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, pool *pgxpool.Pool) {
var (
count int64

View file

@ -68,7 +68,7 @@ func WorkerCmd(ctx context.Context) *cobra.Command {
return fmt.Errorf("invalid queue: %s", queueID)
}
worker := workerFn(logger, statsd, db, redis, queue, consumers)
worker := workerFn(ctx, logger, statsd, db, redis, queue, consumers)
if err := worker.Start(); err != nil {
return err
}

View file

@ -1,4 +1,4 @@
package domain
package domain_test
import (
"errors"
@ -6,21 +6,27 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/christianselig/apollo-backend/internal/domain"
)
func TestValidate(t *testing.T) {
t.Parallel()
tests := map[string]struct {
subreddit Subreddit
subreddit domain.Subreddit
err error
}{
"invalid subreddit prefix": {Subreddit{Name: "u_iamthatis"}, errors.New("invalid subreddit format")},
"valid subreddit": {Subreddit{Name: "pics", SubredditID: "abcd"}, nil},
"valid subreddit starting with u": {Subreddit{Name: "urcool", SubredditID: "abcd"}, nil},
"valid subreddit with _": {Subreddit{Name: "p_i_x_a_r", SubredditID: "abcd"}, nil},
"invalid subreddit prefix": {domain.Subreddit{Name: "u_iamthatis"}, errors.New("invalid subreddit format")},
"valid subreddit": {domain.Subreddit{Name: "pics", SubredditID: "abcd"}, nil},
"valid subreddit starting with u": {domain.Subreddit{Name: "urcool", SubredditID: "abcd"}, nil},
"valid subreddit with _": {domain.Subreddit{Name: "p_i_x_a_r", SubredditID: "abcd"}, nil},
}
for scenario, tc := range tests {
t.Run(scenario, func(t *testing.T) {
t.Parallel()
err := tc.subreddit.Validate()
if tc.err == nil {

View file

@ -2,6 +2,7 @@ package itunes
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
@ -168,6 +169,8 @@ type (
)
func NewIAPResponse(receipt string, production bool) (*IAPResponse, error) {
ctx := context.Background()
// Send the receipt data string off to Apple's servers to verify
appleVerificationURL := "https://buy.itunes.apple.com/verifyReceipt"
@ -186,7 +189,7 @@ func NewIAPResponse(receipt string, production bool) (*IAPResponse, error) {
return nil, err
}
request, requestErr := http.NewRequest("POST", appleVerificationURL, bytes.NewBuffer(bb))
request, requestErr := http.NewRequestWithContext(ctx, "POST", appleVerificationURL, bytes.NewBuffer(bb))
if requestErr != nil {
fmt.Println(requestErr)
@ -364,7 +367,7 @@ func (iapr *IAPResponse) handleAppleResponse() {
mostRecentTransactionUnixTimestamp := mostRecentTransaction.ExpiresDateMS / 1000
// Check if it's not active
currentTimeUnixTimestamp := int64(time.Now().Unix())
currentTimeUnixTimestamp := time.Now().Unix()
if mostRecentTransactionUnixTimestamp < currentTimeUnixTimestamp {
if len(iapr.PendingRenewalInfo) > 0 && iapr.PendingRenewalInfo[0].SubscriptionAutoRenewStatus == "0" {

View file

@ -127,17 +127,22 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
}
func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) {
req, err := r.HTTPRequest()
func (rc *Client) doRequest(ctx context.Context, r *Request) ([]byte, *RateLimitingInfo, error) {
req, err := r.HTTPRequest(ctx)
if err != nil {
return nil, nil, err
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
req = req.WithContext(httptrace.WithClientTrace(ctx, rc.tracer))
start := time.Now()
resp, err := rc.client.Do(req)
client := r.client
if client == nil {
client = rc.client
}
resp, err := client.Do(req)
_ = rc.statsd.Incr("reddit.api.calls", r.tags, 0.1)
_ = rc.statsd.Histogram("reddit.api.latency", float64(time.Since(start).Milliseconds()), r.tags, 0.1)
@ -173,7 +178,7 @@ func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) {
return bb, rli, nil
}
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
func (rac *AuthenticatedClient) request(ctx context.Context, r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
if rac.isRateLimited() {
return nil, ErrRateLimited
}
@ -182,7 +187,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
return nil, err
}
bb, rli, err := rac.doRequest(r)
bb, rli, err := rac.doRequest(ctx, r)
if err != nil && r.retry {
for _, backoff := range backoffSchedule {
@ -196,7 +201,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
return
}
bb, rli, err = rac.doRequest(r)
bb, rli, err = rac.doRequest(ctx, r)
done <- struct{}{}
})
@ -281,7 +286,7 @@ func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error {
return err
}
func (rac *AuthenticatedClient) RefreshTokens(opts ...RequestOption) (*RefreshTokenResponse, error) {
func (rac *AuthenticatedClient) RefreshTokens(ctx context.Context, opts ...RequestOption) (*RefreshTokenResponse, error) {
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/access_token"}),
@ -293,7 +298,7 @@ func (rac *AuthenticatedClient) RefreshTokens(opts ...RequestOption) (*RefreshTo
}...)
req := NewRequest(opts...)
rtr, err := rac.request(req, NewRefreshTokenResponse, nil)
rtr, err := rac.request(ctx, req, NewRefreshTokenResponse, nil)
if err != nil {
switch rerr := err.(type) {
case ServerError:
@ -313,7 +318,7 @@ func (rac *AuthenticatedClient) RefreshTokens(opts ...RequestOption) (*RefreshTo
return ret, nil
}
func (rac *AuthenticatedClient) AboutInfo(fullname string, opts ...RequestOption) (*ListingResponse, error) {
func (rac *AuthenticatedClient) AboutInfo(ctx context.Context, fullname string, opts ...RequestOption) (*ListingResponse, error) {
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithMethod("GET"),
@ -323,7 +328,7 @@ func (rac *AuthenticatedClient) AboutInfo(fullname string, opts ...RequestOption
}...)
req := NewRequest(opts...)
lr, err := rac.request(req, NewListingResponse, nil)
lr, err := rac.request(ctx, req, NewListingResponse, nil)
if err != nil {
return nil, err
}
@ -331,7 +336,7 @@ func (rac *AuthenticatedClient) AboutInfo(fullname string, opts ...RequestOption
return lr.(*ListingResponse), nil
}
func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*ListingResponse, error) {
func (rac *AuthenticatedClient) UserPosts(ctx context.Context, user string, opts ...RequestOption) (*ListingResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/u/%s/submitted", user)
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
@ -341,7 +346,7 @@ func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*
}...)
req := NewRequest(opts...)
lr, err := rac.request(req, NewListingResponse, nil)
lr, err := rac.request(ctx, req, NewListingResponse, nil)
if err != nil {
return nil, err
}
@ -349,7 +354,7 @@ func (rac *AuthenticatedClient) UserPosts(user string, opts ...RequestOption) (*
return lr.(*ListingResponse), nil
}
func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*UserResponse, error) {
func (rac *AuthenticatedClient) UserAbout(ctx context.Context, user string, opts ...RequestOption) (*UserResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/u/%s/about", user)
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
@ -358,7 +363,7 @@ func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*
WithURL(url),
}...)
req := NewRequest(opts...)
ur, err := rac.request(req, NewUserResponse, nil)
ur, err := rac.request(ctx, req, NewUserResponse, nil)
if err != nil {
return nil, err
@ -368,7 +373,7 @@ func (rac *AuthenticatedClient) UserAbout(user string, opts ...RequestOption) (*
}
func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
func (rac *AuthenticatedClient) SubredditAbout(ctx context.Context, subreddit string, opts ...RequestOption) (*SubredditResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/about", subreddit)
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
@ -377,7 +382,7 @@ func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...Request
WithURL(url),
}...)
req := NewRequest(opts...)
sr, err := rac.request(req, NewSubredditResponse, nil)
sr, err := rac.request(ctx, req, NewSubredditResponse, nil)
if err != nil {
return nil, err
@ -386,7 +391,7 @@ func (rac *AuthenticatedClient) SubredditAbout(subreddit string, opts ...Request
return sr.(*SubredditResponse), nil
}
func (rac *AuthenticatedClient) subredditPosts(subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
func (rac *AuthenticatedClient) subredditPosts(ctx context.Context, subreddit string, sort string, opts ...RequestOption) (*ListingResponse, error) {
url := fmt.Sprintf("https://oauth.reddit.com/r/%s/%s", subreddit, sort)
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
@ -396,7 +401,7 @@ func (rac *AuthenticatedClient) subredditPosts(subreddit string, sort string, op
}...)
req := NewRequest(opts...)
lr, err := rac.request(req, NewListingResponse, nil)
lr, err := rac.request(ctx, req, NewListingResponse, nil)
if err != nil {
return nil, err
}
@ -404,19 +409,19 @@ func (rac *AuthenticatedClient) subredditPosts(subreddit string, sort string, op
return lr.(*ListingResponse), nil
}
func (rac *AuthenticatedClient) SubredditHot(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "hot", opts...)
func (rac *AuthenticatedClient) SubredditHot(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(ctx, subreddit, "hot", opts...)
}
func (rac *AuthenticatedClient) SubredditTop(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "top", opts...)
func (rac *AuthenticatedClient) SubredditTop(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(ctx, subreddit, "top", opts...)
}
func (rac *AuthenticatedClient) SubredditNew(subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(subreddit, "new", opts...)
func (rac *AuthenticatedClient) SubredditNew(ctx context.Context, subreddit string, opts ...RequestOption) (*ListingResponse, error) {
return rac.subredditPosts(ctx, subreddit, "new", opts...)
}
func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) {
func (rac *AuthenticatedClient) MessageInbox(ctx context.Context, opts ...RequestOption) (*ListingResponse, error) {
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/message/inbox"}),
@ -427,7 +432,7 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes
}...)
req := NewRequest(opts...)
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
lr, err := rac.request(ctx, req, NewListingResponse, EmptyListingResponse)
if err != nil {
switch rerr := err.(type) {
case ServerError:
@ -441,7 +446,7 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes
return lr.(*ListingResponse), nil
}
func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) {
func (rac *AuthenticatedClient) MessageUnread(ctx context.Context, opts ...RequestOption) (*ListingResponse, error) {
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/message/unread"}),
@ -453,7 +458,7 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe
req := NewRequest(opts...)
lr, err := rac.request(req, NewListingResponse, EmptyListingResponse)
lr, err := rac.request(ctx, req, NewListingResponse, EmptyListingResponse)
if err != nil {
switch rerr := err.(type) {
case ServerError:
@ -467,7 +472,7 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe
return lr.(*ListingResponse), nil
}
func (rac *AuthenticatedClient) Me(opts ...RequestOption) (*MeResponse, error) {
func (rac *AuthenticatedClient) Me(ctx context.Context, opts ...RequestOption) (*MeResponse, error) {
opts = append(rac.defaultOpts, opts...)
opts = append(opts, []RequestOption{
WithTags([]string{"url:/api/v1/me"}),
@ -477,7 +482,7 @@ func (rac *AuthenticatedClient) Me(opts ...RequestOption) (*MeResponse, error) {
}...)
req := NewRequest(opts...)
mr, err := rac.request(req, NewMeResponse, nil)
mr, err := rac.request(ctx, req, NewMeResponse, nil)
if err != nil {
switch rerr := err.(type) {
case ServerError:

View file

@ -1,12 +1,14 @@
package reddit
package reddit_test
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"testing"
"github.com/DataDog/datadog-go/statsd"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/go-redis/redismock/v8"
"github.com/stretchr/testify/assert"
)
@ -21,35 +23,37 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
//NewTestClient returns *http.Client with Transport replaced to avoid making real calls
func NewTestClient(fn RoundTripFunc) *http.Client {
return &http.Client{
Transport: RoundTripFunc(fn),
}
return &http.Client{Transport: fn}
}
func TestErrorResponse(t *testing.T) {
t.Parallel()
ctx := context.Background()
db, _ := redismock.NewClientMock()
rc := NewClient("", "", &statsd.NoOpClient{}, db, 1, WithRetry(false))
rac := rc.NewAuthenticatedClient(SkipRateLimiting, "", "")
errortests := []struct {
name string
call func() error
errortests := map[string]struct {
call func(*reddit.AuthenticatedClient) error
status int
body string
err error
}{
{"/api/v1/me 500 returns ServerError", func() error { _, err := rac.Me(); return err }, 500, "", ServerError{500}},
{"/api/v1/access_token 400 returns ErrOauthRevoked", func() error { _, err := rac.RefreshTokens(); return err }, 400, "", ErrOauthRevoked},
{"/api/v1/message/inbox 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageInbox(); return err }, 403, "", ErrOauthRevoked},
{"/api/v1/message/unread 403 returns ErrOauthRevoked", func() error { _, err := rac.MessageUnread(); return err }, 403, "", ErrOauthRevoked},
{"/api/v1/me 403 returns ErrOauthRevoked", func() error { _, err := rac.Me(); return err }, 403, "", ErrOauthRevoked},
"/api/v1/me 500 returns ServerError": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 500, "", reddit.ServerError{500}},
"/api/v1/access_token 400 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.RefreshTokens(ctx); return err }, 400, "", reddit.ErrOauthRevoked},
"/api/v1/message/inbox 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageInbox(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
"/api/v1/message/unread 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.MessageUnread(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
"/api/v1/me 403 returns ErrOauthRevoked": {func(rac *reddit.AuthenticatedClient) error { _, err := rac.Me(ctx); return err }, 403, "", reddit.ErrOauthRevoked},
}
for _, tt := range errortests {
t.Run(tt.name, func(t *testing.T) {
rac.client = NewTestClient(func(req *http.Request) *http.Response {
for scenario, tt := range errortests {
tt := tt
t.Run(scenario, func(t *testing.T) {
t.Parallel()
tc := NewTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: tt.status,
Body: ioutil.NopCloser(bytes.NewBufferString(tt.body)),
@ -57,7 +61,10 @@ func TestErrorResponse(t *testing.T) {
}
})
err := tt.call()
rc := reddit.NewClient("", "", &statsd.NoOpClient{}, db, 1, reddit.WithRetry(false), reddit.WithClient(tc))
rac := rc.NewAuthenticatedClient(reddit.SkipRateLimiting, "", "")
err := tt.call(rac)
assert.ErrorIs(t, err, tt.err)
})

View file

@ -1,6 +1,7 @@
package reddit
import (
"context"
"encoding/base64"
"fmt"
"net/http"
@ -20,6 +21,7 @@ type Request struct {
tags []string
emptyResponseBytes int
retry bool
client *http.Client
}
type RequestOption func(*Request)
@ -38,6 +40,7 @@ func NewRequest(opts ...RequestOption) *Request {
emptyResponseBytes: 0,
retry: true,
client: nil,
}
for _, opt := range opts {
@ -47,8 +50,8 @@ func NewRequest(opts ...RequestOption) *Request {
return req
}
func (r *Request) HTTPRequest() (*http.Request, error) {
req, err := http.NewRequest(r.method, r.url, strings.NewReader(r.body.Encode()))
func (r *Request) HTTPRequest(ctx context.Context) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, r.method, r.url, strings.NewReader(r.body.Encode()))
req.URL.RawQuery = r.query.Encode()
req.Header.Add("Accept", "application/json")
@ -123,3 +126,9 @@ func WithRetry(retry bool) RequestOption {
req.retry = retry
}
}
func WithClient(client *http.Client) RequestOption {
return func(req *Request) {
req.client = client
}
}

View file

@ -1,27 +1,41 @@
package reddit
package reddit_test
import (
"io/ioutil"
"testing"
"time"
"github.com/christianselig/apollo-backend/internal/reddit"
"github.com/stretchr/testify/assert"
"github.com/valyala/fastjson"
)
var (
parser = &fastjson.Parser{}
)
var pool = &fastjson.ParserPool{}
func NewTestParser(t *testing.T) *fastjson.Parser {
t.Helper()
parser := pool.Get()
t.Cleanup(func() {
pool.Put(parser)
})
return parser
}
func TestMeResponseParsing(t *testing.T) {
t.Parallel()
bb, err := ioutil.ReadFile("testdata/me.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewMeResponse(val)
me := ret.(*MeResponse)
ret := reddit.NewMeResponse(val)
me := ret.(*reddit.MeResponse)
assert.NotNil(t, me)
assert.Equal(t, "xgeee", me.ID)
@ -29,14 +43,17 @@ func TestMeResponseParsing(t *testing.T) {
}
func TestRefreshTokenResponseParsing(t *testing.T) {
t.Parallel()
bb, err := ioutil.ReadFile("testdata/refresh_token.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewRefreshTokenResponse(val)
rtr := ret.(*RefreshTokenResponse)
ret := reddit.NewRefreshTokenResponse(val)
rtr := ret.(*reddit.RefreshTokenResponse)
assert.NotNil(t, rtr)
assert.Equal(t, "***REMOVED***", rtr.AccessToken)
@ -45,15 +62,18 @@ func TestRefreshTokenResponseParsing(t *testing.T) {
}
func TestListingResponseParsing(t *testing.T) {
t.Parallel()
// Message list
bb, err := ioutil.ReadFile("testdata/message_inbox.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewListingResponse(val)
l := ret.(*ListingResponse)
ret := reddit.NewListingResponse(val)
l := ret.(*reddit.ListingResponse)
assert.NotNil(t, l)
assert.Equal(t, 25, l.Count)
@ -62,7 +82,7 @@ func TestListingResponseParsing(t *testing.T) {
assert.Equal(t, "", l.Before)
thing := l.Children[0]
created := time.Time(time.Date(2021, time.July, 14, 17, 56, 35, 0, time.UTC))
created := time.Date(2021, time.July, 14, 17, 56, 35, 0, time.UTC)
assert.Equal(t, "t4", thing.Kind)
assert.Equal(t, "138z6ke", thing.ID)
assert.Equal(t, "unknown", thing.Type)
@ -86,8 +106,8 @@ func TestListingResponseParsing(t *testing.T) {
val, err = parser.ParseBytes(bb)
assert.NoError(t, err)
ret = NewListingResponse(val)
l = ret.(*ListingResponse)
ret = reddit.NewListingResponse(val)
l = ret.(*reddit.ListingResponse)
assert.NotNil(t, l)
assert.Equal(t, 100, l.Count)
@ -100,14 +120,17 @@ func TestListingResponseParsing(t *testing.T) {
}
func TestSubredditResponseParsing(t *testing.T) {
t.Parallel()
bb, err := ioutil.ReadFile("testdata/subreddit_about.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewSubredditResponse(val)
s := ret.(*SubredditResponse)
ret := reddit.NewSubredditResponse(val)
s := ret.(*reddit.SubredditResponse)
assert.NotNil(t, s)
assert.Equal(t, "t5", s.Kind)
@ -116,14 +139,17 @@ func TestSubredditResponseParsing(t *testing.T) {
}
func TestUserResponseParsing(t *testing.T) {
t.Parallel()
bb, err := ioutil.ReadFile("testdata/user_about.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewUserResponse(val)
u := ret.(*UserResponse)
ret := reddit.NewUserResponse(val)
u := ret.(*reddit.UserResponse)
assert.NotNil(t, u)
assert.Equal(t, "t2", u.Kind)
@ -133,14 +159,17 @@ func TestUserResponseParsing(t *testing.T) {
}
func TestUserPostsParsing(t *testing.T) {
t.Parallel()
bb, err := ioutil.ReadFile("testdata/user_posts.json")
assert.NoError(t, err)
parser := NewTestParser(t)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
ret := NewListingResponse(val)
ps := ret.(*ListingResponse)
ret := reddit.NewListingResponse(val)
ps := ret.(*reddit.ListingResponse)
assert.NotNil(t, ps)
post := ps.Children[0]

View file

@ -0,0 +1,14 @@
package repository
import (
"context"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
)
type Connection interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}

View file

@ -5,21 +5,19 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/christianselig/apollo-backend/internal/domain"
)
type postgresAccountRepository struct {
pool *pgxpool.Pool
conn Connection
}
func NewPostgresAccount(pool *pgxpool.Pool) domain.AccountRepository {
return &postgresAccountRepository{pool: pool}
func NewPostgresAccount(conn Connection) domain.AccountRepository {
return &postgresAccountRepository{conn: conn}
}
func (p *postgresAccountRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Account, error) {
rows, err := p.pool.Query(ctx, query, args...)
rows, err := p.conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
@ -96,7 +94,7 @@ func (p *postgresAccountRepository) CreateOrUpdate(ctx context.Context, acc *dom
token_expires_at = $5
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
acc.Username,
@ -115,7 +113,7 @@ func (p *postgresAccountRepository) Create(ctx context.Context, acc *domain.Acco
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
acc.Username,
@ -143,7 +141,7 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
check_count = $10
WHERE id = $1`
res, err := p.pool.Exec(
res, err := p.conn.Exec(
ctx,
query,
acc.ID,
@ -166,7 +164,7 @@ func (p *postgresAccountRepository) Update(ctx context.Context, acc *domain.Acco
func (p *postgresAccountRepository) Delete(ctx context.Context, id int64) error {
query := `DELETE FROM accounts WHERE id = $1`
res, err := p.pool.Exec(ctx, query, id)
res, err := p.conn.Exec(ctx, query, id)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -180,13 +178,13 @@ func (p *postgresAccountRepository) Associate(ctx context.Context, acc *domain.A
(account_id, device_id)
VALUES ($1, $2)
ON CONFLICT(account_id, device_id) DO NOTHING`
_, err := p.pool.Exec(ctx, query, acc.ID, dev.ID)
_, err := p.conn.Exec(ctx, query, acc.ID, dev.ID)
return err
}
func (p *postgresAccountRepository) Disassociate(ctx context.Context, acc *domain.Account, dev *domain.Device) error {
query := `DELETE FROM devices_accounts WHERE account_id = $1 AND device_id = $2`
res, err := p.pool.Exec(ctx, query, acc.ID, dev.ID)
res, err := p.conn.Exec(ctx, query, acc.ID, dev.ID)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -212,7 +210,7 @@ func (p *postgresAccountRepository) PruneStale(ctx context.Context, expiry time.
DELETE FROM accounts
WHERE token_expires_at < $1`
res, err := p.pool.Exec(ctx, query, expiry)
res, err := p.conn.Exec(ctx, query, expiry)
return res.RowsAffected(), err
}
@ -231,7 +229,7 @@ func (p *postgresAccountRepository) PruneOrphaned(ctx context.Context) (int64, e
WHERE device_count = 0
)`
res, err := p.pool.Exec(ctx, query)
res, err := p.conn.Exec(ctx, query)
return res.RowsAffected(), err
}

View file

@ -5,21 +5,19 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/christianselig/apollo-backend/internal/domain"
)
type postgresDeviceRepository struct {
pool *pgxpool.Pool
conn Connection
}
func NewPostgresDevice(pool *pgxpool.Pool) domain.DeviceRepository {
return &postgresDeviceRepository{pool: pool}
func NewPostgresDevice(conn Connection) domain.DeviceRepository {
return &postgresDeviceRepository{conn: conn}
}
func (p *postgresDeviceRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Device, error) {
rows, err := p.pool.Query(ctx, query, args...)
rows, err := p.conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
@ -118,7 +116,7 @@ func (p *postgresDeviceRepository) CreateOrUpdate(ctx context.Context, dev *doma
UPDATE SET expires_at = $3, grace_period_expires_at = $4
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
dev.APNSToken,
@ -139,7 +137,7 @@ func (p *postgresDeviceRepository) Create(ctx context.Context, dev *domain.Devic
VALUES ($1, $2, $3, $4)
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
dev.APNSToken,
@ -159,7 +157,7 @@ func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Devic
SET expires_at = $2, grace_period_expires_at = $3
WHERE id = $1`
res, err := p.pool.Exec(ctx, query, dev.ID, dev.ExpiresAt, dev.GracePeriodExpiresAt)
res, err := p.conn.Exec(ctx, query, dev.ID, dev.ExpiresAt, dev.GracePeriodExpiresAt)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -170,7 +168,7 @@ func (p *postgresDeviceRepository) Update(ctx context.Context, dev *domain.Devic
func (p *postgresDeviceRepository) Delete(ctx context.Context, token string) error {
query := `DELETE FROM devices WHERE apns_token = $1`
res, err := p.pool.Exec(ctx, query, token)
res, err := p.conn.Exec(ctx, query, token)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -187,7 +185,7 @@ func (p *postgresDeviceRepository) SetNotifiable(ctx context.Context, dev *domai
global_mute = $3
WHERE device_id = $4 AND account_id = $5`
res, err := p.pool.Exec(ctx, query, inbox, watcher, global, dev.ID, acct.ID)
res, err := p.conn.Exec(ctx, query, inbox, watcher, global, dev.ID, acct.ID)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -203,7 +201,7 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai
WHERE device_id = $1 AND account_id = $2`
var inbox, watcher, global bool
if err := p.pool.QueryRow(ctx, query, dev.ID, acct.ID).Scan(&inbox, &watcher, &global); err != nil {
if err := p.conn.QueryRow(ctx, query, dev.ID, acct.ID).Scan(&inbox, &watcher, &global); err != nil {
return false, false, false, domain.ErrNotFound
}
@ -213,7 +211,7 @@ func (p *postgresDeviceRepository) GetNotifiable(ctx context.Context, dev *domai
func (p *postgresDeviceRepository) PruneStale(ctx context.Context, expiry time.Time) (int64, error) {
query := `DELETE FROM devices WHERE grace_period_expires_at < $1`
res, err := p.pool.Exec(ctx, query, expiry)
res, err := p.conn.Exec(ctx, query, expiry)
return res.RowsAffected(), err
}

View file

@ -0,0 +1,135 @@
package repository_test
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/repository"
"github.com/christianselig/apollo-backend/internal/testhelper"
)
const testToken = "313a182b63224821f5595f42aa019de850a0e7b776253659a9aac8140bb8a3f2"
func NewTestPostgresDevice(t *testing.T) domain.DeviceRepository {
t.Helper()
ctx := context.Background()
conn := testhelper.NewTestPgxConn(t)
tx, err := conn.Begin(ctx)
require.NoError(t, err)
repo := repository.NewPostgresDevice(tx)
t.Cleanup(func() {
_ = tx.Rollback(ctx)
})
return repo
}
func TestPostgresDevice_GetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
repo := NewTestPostgresDevice(t)
dev := &domain.Device{APNSToken: testToken}
require.NoError(t, repo.CreateOrUpdate(ctx, dev))
testCases := map[string]struct {
id int64
want *domain.Device
err error
}{
"valid ID": {dev.ID, dev, nil},
"invalid ID": {0, nil, domain.ErrNotFound},
}
for scenario, tc := range testCases { //nolint:paralleltest
t.Run(scenario, func(t *testing.T) {
dev, err := repo.GetByID(ctx, tc.id)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tc.want, &dev)
})
}
}
func TestPostgresDevice_Create(t *testing.T) {
t.Parallel()
ctx := context.Background()
repo := NewTestPostgresDevice(t)
testCases := map[string]struct {
have *domain.Device
err bool
}{
"valid": {&domain.Device{APNSToken: testToken}, false},
"invalid APNS token": {&domain.Device{APNSToken: "not valid"}, true},
}
for scenario, tc := range testCases { //nolint:paralleltest
t.Run(scenario, func(t *testing.T) {
err := repo.Create(ctx, tc.have)
if tc.err {
assert.Error(t, err)
return
}
assert.NotEqual(t, 0, tc.have.ID)
})
}
}
func TestPostgresDevice_Update(t *testing.T) {
t.Parallel()
ctx := context.Background()
repo := NewTestPostgresDevice(t)
testCases := map[string]struct {
fn func(*domain.Device)
err error
}{
"valid update": {func(dev *domain.Device) { dev.Sandbox = true }, nil},
"empty update": {func(dev *domain.Device) {}, nil},
"update on non existant id": {func(dev *domain.Device) { dev.ID = 0 }, errors.New("weird behaviour, total rows affected: 0")},
}
for scenario, tc := range testCases { //nolint:paralleltest
t.Run(scenario, func(t *testing.T) {
b := make([]byte, 32)
_, err := rand.Read(b)
require.NoError(t, err)
dev := &domain.Device{APNSToken: hex.EncodeToString(b)}
require.NoError(t, repo.Create(ctx, dev))
tc.fn(dev)
err = repo.Update(ctx, dev)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err, err)
return
}
require.NoError(t, err)
})
}
}

View file

@ -5,19 +5,18 @@ import (
"strings"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/jackc/pgx/v4/pgxpool"
)
type postgresSubredditRepository struct {
pool *pgxpool.Pool
conn Connection
}
func NewPostgresSubreddit(pool *pgxpool.Pool) domain.SubredditRepository {
return &postgresSubredditRepository{pool: pool}
func NewPostgresSubreddit(conn Connection) domain.SubredditRepository {
return &postgresSubredditRepository{conn: conn}
}
func (p *postgresSubredditRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Subreddit, error) {
rows, err := p.pool.Query(ctx, query, args...)
rows, err := p.conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
@ -86,7 +85,7 @@ func (p *postgresSubredditRepository) CreateOrUpdate(ctx context.Context, sr *do
ON CONFLICT(subreddit_id) DO NOTHING
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
sr.SubredditID,

View file

@ -5,21 +5,19 @@ import (
"fmt"
"strings"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/christianselig/apollo-backend/internal/domain"
)
type postgresUserRepository struct {
pool *pgxpool.Pool
conn Connection
}
func NewPostgresUser(pool *pgxpool.Pool) domain.UserRepository {
return &postgresUserRepository{pool: pool}
func NewPostgresUser(conn Connection) domain.UserRepository {
return &postgresUserRepository{conn: conn}
}
func (p *postgresUserRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.User, error) {
rows, err := p.pool.Query(ctx, query, args...)
rows, err := p.conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
@ -88,7 +86,7 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
ON CONFLICT(user_id) DO NOTHING
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
u.UserID,
@ -99,7 +97,7 @@ func (p *postgresUserRepository) CreateOrUpdate(ctx context.Context, u *domain.U
func (p *postgresUserRepository) Delete(ctx context.Context, id int64) error {
query := `DELETE FROM users WHERE id = $1`
res, err := p.pool.Exec(ctx, query, id)
res, err := p.conn.Exec(ctx, query, id)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())

View file

@ -5,21 +5,19 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/christianselig/apollo-backend/internal/domain"
)
type postgresWatcherRepository struct {
pool *pgxpool.Pool
conn Connection
}
func NewPostgresWatcher(pool *pgxpool.Pool) domain.WatcherRepository {
return &postgresWatcherRepository{pool: pool}
func NewPostgresWatcher(conn Connection) domain.WatcherRepository {
return &postgresWatcherRepository{conn: conn}
}
func (p *postgresWatcherRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]domain.Watcher, error) {
rows, err := p.pool.Query(ctx, query, args...)
rows, err := p.conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
@ -221,7 +219,7 @@ func (p *postgresWatcherRepository) Create(ctx context.Context, watcher *domain.
VALUES ($1, 0, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING id`
return p.pool.QueryRow(
return p.conn.QueryRow(
ctx,
query,
now,
@ -255,7 +253,7 @@ func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.
label = $8
WHERE id = $1`
res, err := p.pool.Exec(
res, err := p.conn.Exec(
ctx,
query,
watcher.ID,
@ -277,7 +275,7 @@ func (p *postgresWatcherRepository) Update(ctx context.Context, watcher *domain.
func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64) error {
now := time.Now().Unix()
query := `UPDATE watchers SET hits = hits + 1, last_notified_at = $2 WHERE id = $1`
res, err := p.pool.Exec(ctx, query, id, now)
res, err := p.conn.Exec(ctx, query, id, now)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -287,7 +285,7 @@ func (p *postgresWatcherRepository) IncrementHits(ctx context.Context, id int64)
func (p *postgresWatcherRepository) Delete(ctx context.Context, id int64) error {
query := `DELETE FROM watchers WHERE id = $1`
res, err := p.pool.Exec(ctx, query, id)
res, err := p.conn.Exec(ctx, query, id)
if res.RowsAffected() != 1 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())
@ -297,7 +295,7 @@ func (p *postgresWatcherRepository) Delete(ctx context.Context, id int64) error
func (p *postgresWatcherRepository) DeleteByTypeAndWatcheeID(ctx context.Context, typ domain.WatcherType, id int64) error {
query := `DELETE FROM watchers WHERE type = $1 AND watchee_id = $2`
res, err := p.pool.Exec(ctx, query, typ, id)
res, err := p.conn.Exec(ctx, query, typ, id)
if res.RowsAffected() == 0 {
return fmt.Errorf("weird behaviour, total rows affected: %d", res.RowsAffected())

View file

@ -0,0 +1,33 @@
package repository_test
import (
"context"
"testing"
"github.com/christianselig/apollo-backend/internal/domain"
"github.com/christianselig/apollo-backend/internal/repository"
"github.com/christianselig/apollo-backend/internal/testhelper"
"github.com/stretchr/testify/require"
)
func NewTestPostgresWatcher(t *testing.T) domain.WatcherRepository {
t.Helper()
ctx := context.Background()
conn := testhelper.NewTestPgxConn(t)
tx, err := conn.Begin(ctx)
require.NoError(t, err)
repo := repository.NewPostgresWatcher(tx)
t.Cleanup(func() {
_ = tx.Rollback(ctx)
})
return repo
}
func TestPostgresWatcher_GetByID(t *testing.T) {
t.Parallel()
}

View file

@ -0,0 +1,34 @@
package testhelper
import (
"context"
"os"
"testing"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)
func NewTestPgxConn(t *testing.T) *pgx.Conn {
t.Helper()
ctx := context.Background()
connString := os.Getenv("DATABASE_URL")
if connString == "" {
t.Skipf("skipping due to missing environment variable %v", "DATABASE_URL")
}
config, err := pgx.ParseConfig(connString)
require.NoError(t, err)
conn, err := pgx.ConnectConfig(ctx, config)
require.NoError(t, err)
t.Cleanup(func() {
conn.Close(ctx)
})
return conn
}

View file

@ -33,6 +33,8 @@ const (
)
type notificationsWorker struct {
context.Context
logger *logrus.Logger
statsd *statsd.Client
db *pgxpool.Pool
@ -47,7 +49,7 @@ type notificationsWorker struct {
deviceRepo domain.DeviceRepository
}
func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
func NewNotificationsWorker(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
@ -71,6 +73,7 @@ func NewNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pg
}
return &notificationsWorker{
ctx,
logger,
statsd,
db,
@ -137,11 +140,9 @@ func NewNotificationsConsumer(nw *notificationsWorker, tag int) *notificationsCo
}
func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
defer func() {
lockKey := fmt.Sprintf("locks:accounts:%s", delivery.Payload())
if err := nc.redis.Del(ctx, lockKey).Err(); err != nil {
if err := nc.redis.Del(nc, lockKey).Err(); err != nil {
nc.logger.WithFields(logrus.Fields{
"lockKey": lockKey,
"err": err,
@ -168,7 +169,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
now := time.Now()
account, err := nc.accountRepo.GetByID(ctx, id)
account, err := nc.accountRepo.GetByID(nc, id)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#id": id,
@ -183,7 +184,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
account.CheckCount++
account.NextNotificationCheckAt = time.Now().Add(domain.NotificationCheckInterval)
if err = nc.accountRepo.Update(ctx, &account); err != nil {
if err = nc.accountRepo.Update(nc, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
@ -197,7 +198,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
"account#username": account.NormalizedUsername(),
}).Debug("refreshing reddit token")
tokens, err := rac.RefreshTokens()
tokens, err := rac.RefreshTokens(nc)
if err != nil {
if err != reddit.ErrOauthRevoked {
nc.logger.WithFields(logrus.Fields{
@ -207,7 +208,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
err = nc.deleteAccount(ctx, account)
err = nc.deleteAccount(account)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
@ -226,7 +227,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
// Refresh client
rac = nc.reddit.NewAuthenticatedClient(account.AccountID, tokens.RefreshToken, tokens.AccessToken)
if err = nc.accountRepo.Update(ctx, &account); err != nil {
if err = nc.accountRepo.Update(nc, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
@ -250,14 +251,14 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
if account.LastMessageID != "" {
opts = append(opts, reddit.WithQuery("before", account.LastMessageID))
}
msgs, err := rac.MessageInbox(opts...)
msgs, err := rac.MessageInbox(nc, opts...)
if err != nil {
switch err {
case reddit.ErrTimeout: // Don't log timeouts
break
case reddit.ErrOauthRevoked:
err = nc.deleteAccount(ctx, account)
err = nc.deleteAccount(account)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
@ -297,7 +298,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}
}
if err = nc.accountRepo.Update(ctx, &account); err != nil {
if err = nc.accountRepo.Update(nc, &account); err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,
@ -313,7 +314,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(ctx, account.ID)
devices, err := nc.deviceRepo.GetInboxNotifiableByAccountID(nc, account.ID)
if err != nil {
nc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
@ -359,7 +360,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}).Error("failed to send notification")
// Delete device as notifications might have been disabled here
_ = nc.deviceRepo.Delete(ctx, device.APNSToken)
_ = nc.deviceRepo.Delete(nc, device.APNSToken)
} else {
_ = nc.statsd.Incr("apns.notification.sent", []string{}, 1)
nc.logger.WithFields(logrus.Fields{
@ -378,20 +379,20 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("finishing job")
}
func (nc *notificationsConsumer) deleteAccount(ctx context.Context, account domain.Account) error {
func (nc *notificationsConsumer) deleteAccount(account domain.Account) error {
// Disassociate account from devices
devs, err := nc.deviceRepo.GetByAccountID(ctx, account.ID)
devs, err := nc.deviceRepo.GetByAccountID(nc, account.ID)
if err != nil {
return err
}
for _, dev := range devs {
if err := nc.accountRepo.Disassociate(ctx, &account, &dev); err != nil {
if err := nc.accountRepo.Disassociate(nc, &account, &dev); err != nil {
return err
}
}
return nc.accountRepo.Delete(ctx, account.ID)
return nc.accountRepo.Delete(nc, account.ID)
}
func payloadFromMessage(acct domain.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {

View file

@ -18,6 +18,8 @@ import (
)
type stuckNotificationsWorker struct {
context.Context
logger *logrus.Logger
statsd *statsd.Client
db *pgxpool.Pool
@ -30,7 +32,7 @@ type stuckNotificationsWorker struct {
accountRepo domain.AccountRepository
}
func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
func NewStuckNotificationsWorker(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
@ -40,6 +42,7 @@ func NewStuckNotificationsWorker(logger *logrus.Logger, statsd *statsd.Client, d
)
return &stuckNotificationsWorker{
ctx,
logger,
statsd,
db,
@ -99,8 +102,6 @@ func NewStuckNotificationsConsumer(snw *stuckNotificationsWorker, tag int) *stuc
}
func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
snc.logger.WithFields(logrus.Fields{
"account#id": delivery.Payload(),
}).Debug("starting job")
@ -118,7 +119,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }()
account, err := snc.accountRepo.GetByID(ctx, id)
account, err := snc.accountRepo.GetByID(snc, id)
if err != nil {
snc.logger.WithFields(logrus.Fields{
"err": err,
@ -149,7 +150,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
"thing#id": account.LastMessageID,
}).Debug("checking last thing via inbox")
things, err = rac.MessageInbox()
things, err = rac.MessageInbox(snc)
if err != nil {
snc.logger.WithFields(logrus.Fields{
"err": err,
@ -157,7 +158,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
return
}
} else {
things, err = rac.AboutInfo(account.LastMessageID)
things, err = rac.AboutInfo(snc, account.LastMessageID)
if err != nil {
snc.logger.WithFields(logrus.Fields{
"err": err,
@ -192,7 +193,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
"account#username": account.NormalizedUsername(),
}).Debug("getting message inbox to determine last good thing")
things, err = rac.MessageInbox()
things, err = rac.MessageInbox(snc)
if err != nil {
snc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
@ -225,7 +226,7 @@ func (snc *stuckNotificationsConsumer) Consume(delivery rmq.Delivery) {
"thing#id": account.LastMessageID,
}).Debug("updating last good thing")
if err := snc.accountRepo.Update(ctx, &account); err != nil {
if err := snc.accountRepo.Update(snc, &account); err != nil {
snc.logger.WithFields(logrus.Fields{
"account#username": account.NormalizedUsername(),
"err": err,

View file

@ -24,6 +24,8 @@ import (
)
type subredditsWorker struct {
context.Context
logger *logrus.Logger
statsd *statsd.Client
db *pgxpool.Pool
@ -45,7 +47,7 @@ const (
subredditNotificationBodyFormat = "r/%s: \u201c%s\u201d"
)
func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
func NewSubredditsWorker(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
@ -69,6 +71,7 @@ func NewSubredditsWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpo
}
return &subredditsWorker{
ctx,
logger,
statsd,
db,
@ -137,8 +140,6 @@ func NewSubredditsConsumer(sw *subredditsWorker, tag int) *subredditsConsumer {
}
func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
sc.logger.WithFields(logrus.Fields{
"subreddit#id": delivery.Payload(),
}).Debug("starting job")
@ -156,7 +157,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }()
subreddit, err := sc.subredditRepo.GetByID(ctx, id)
subreddit, err := sc.subredditRepo.GetByID(sc, id)
if err != nil {
sc.logger.WithFields(logrus.Fields{
"err": err,
@ -164,7 +165,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
return
}
watchers, err := sc.watcherRepo.GetBySubredditID(ctx, subreddit.ID)
watchers, err := sc.watcherRepo.GetBySubredditID(sc, subreddit.ID)
if err != nil {
sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
@ -202,10 +203,10 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
i := rand.Intn(len(watchers))
watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditNew(
sps, err := rac.SubredditNew(sc,
subreddit.Name,
reddit.WithQuery("before", before),
reddit.WithQuery("limit", "100"),
@ -267,9 +268,9 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
i := rand.Intn(len(watchers))
watcher := watchers[i]
acc, _ := sc.accountRepo.GetByID(ctx, watcher.AccountID)
acc, _ := sc.accountRepo.GetByID(sc, watcher.AccountID)
rac := sc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
sps, err := rac.SubredditHot(
sps, err := rac.SubredditHot(sc,
subreddit.Name,
reddit.WithQuery("limit", "100"),
)
@ -344,7 +345,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
}
lockKey := fmt.Sprintf("watcher:%d:%s", watcher.DeviceID, post.ID)
notified, _ := sc.redis.Get(ctx, lockKey).Bool()
notified, _ := sc.redis.Get(sc, lockKey).Bool()
if notified {
sc.logger.WithFields(logrus.Fields{
@ -357,7 +358,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
continue
}
if err := sc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
if err := sc.watcherRepo.IncrementHits(sc, watcher.ID); err != nil {
sc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"watcher#id": watcher.ID,
@ -373,7 +374,7 @@ func (sc *subredditsConsumer) Consume(delivery rmq.Delivery) {
"post#id": post.ID,
}).Debug("got a hit")
sc.redis.SetEX(ctx, lockKey, true, 24*time.Hour)
sc.redis.SetEX(sc, lockKey, true, 24*time.Hour)
notifs = append(notifs, watcher)
}

View file

@ -23,6 +23,8 @@ import (
)
type trendingWorker struct {
context.Context
logger *logrus.Logger
statsd *statsd.Client
redis *redis.Client
@ -40,7 +42,7 @@ type trendingWorker struct {
const trendingNotificationTitleFormat = "🔥 r/%s Trending"
func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
func NewTrendingWorker(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
@ -64,6 +66,7 @@ func NewTrendingWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool
}
return &trendingWorker{
ctx,
logger,
statsd,
redis,
@ -131,8 +134,6 @@ func NewTrendingConsumer(tw *trendingWorker, tag int) *trendingConsumer {
}
func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
tc.logger.WithFields(logrus.Fields{
"subreddit#id": delivery.Payload(),
}).Debug("starting job")
@ -150,7 +151,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }()
subreddit, err := tc.subredditRepo.GetByID(ctx, id)
subreddit, err := tc.subredditRepo.GetByID(tc, id)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"err": err,
@ -158,7 +159,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
return
}
watchers, err := tc.watcherRepo.GetByTrendingSubredditID(ctx, subreddit.ID)
watchers, err := tc.watcherRepo.GetByTrendingSubredditID(tc, subreddit.ID)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
@ -179,7 +180,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
watcher := watchers[i]
rac := tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
tps, err := rac.SubredditTop(subreddit.Name, reddit.WithQuery("t", "week"))
tps, err := rac.SubredditTop(tc, subreddit.Name, reddit.WithQuery("t", "week"))
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
@ -219,7 +220,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
i = rand.Intn(len(watchers))
watcher = watchers[i]
rac = tc.reddit.NewAuthenticatedClient(watcher.Account.AccountID, watcher.Account.RefreshToken, watcher.Account.AccessToken)
hps, err := rac.SubredditHot(subreddit.Name)
hps, err := rac.SubredditHot(tc, subreddit.Name)
if err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
@ -256,7 +257,7 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
}
lockKey := fmt.Sprintf("watcher:trending:%d:%s", watcher.DeviceID, post.ID)
notified, _ := tc.redis.Get(ctx, lockKey).Bool()
notified, _ := tc.redis.Get(tc, lockKey).Bool()
if notified {
tc.logger.WithFields(logrus.Fields{
@ -268,9 +269,9 @@ func (tc *trendingConsumer) Consume(delivery rmq.Delivery) {
continue
}
tc.redis.SetEX(ctx, lockKey, true, 48*time.Hour)
tc.redis.SetEX(tc, lockKey, true, 48*time.Hour)
if err := tc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
if err := tc.watcherRepo.IncrementHits(tc, watcher.ID); err != nil {
tc.logger.WithFields(logrus.Fields{
"subreddit#id": subreddit.ID,
"watcher#id": watcher.ID,

View file

@ -23,6 +23,8 @@ import (
)
type usersWorker struct {
context.Context
logger *logrus.Logger
statsd *statsd.Client
db *pgxpool.Pool
@ -41,7 +43,7 @@ type usersWorker struct {
const userNotificationTitleFormat = "👨\u200d🚀 %s"
func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
func NewUsersWorker(ctx context.Context, logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker {
reddit := reddit.NewClient(
os.Getenv("REDDIT_CLIENT_ID"),
os.Getenv("REDDIT_CLIENT_SECRET"),
@ -65,6 +67,7 @@ func NewUsersWorker(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Po
}
return &usersWorker{
ctx,
logger,
statsd,
db,
@ -133,8 +136,6 @@ func NewUsersConsumer(uw *usersWorker, tag int) *usersConsumer {
}
func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
ctx := context.Background()
uc.logger.WithFields(logrus.Fields{
"user#id": delivery.Payload(),
}).Debug("starting job")
@ -152,7 +153,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
defer func() { _ = delivery.Ack() }()
user, err := uc.userRepo.GetByID(ctx, id)
user, err := uc.userRepo.GetByID(uc, id)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"err": err,
@ -160,7 +161,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
return
}
watchers, err := uc.watcherRepo.GetByUserID(ctx, user.ID)
watchers, err := uc.watcherRepo.GetByUserID(uc, user.ID)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
@ -180,10 +181,10 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
i := rand.Intn(len(watchers))
watcher := watchers[i]
acc, _ := uc.accountRepo.GetByID(ctx, watcher.AccountID)
acc, _ := uc.accountRepo.GetByID(uc, watcher.AccountID)
rac := uc.reddit.NewAuthenticatedClient(acc.AccountID, acc.RefreshToken, acc.AccessToken)
ru, err := rac.UserAbout(user.Name)
ru, err := rac.UserAbout(uc, user.Name)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
@ -197,7 +198,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
"user#id": user.ID,
}).Info("user disabled followers, removing")
if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(ctx, domain.UserWatcher, user.ID); err != nil {
if err := uc.watcherRepo.DeleteByTypeAndWatcheeID(uc, domain.UserWatcher, user.ID); err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
@ -205,7 +206,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
return
}
if err := uc.userRepo.Delete(ctx, user.ID); err != nil {
if err := uc.userRepo.Delete(uc, user.ID); err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"err": err,
@ -214,7 +215,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
}
}
posts, err := rac.UserPosts(user.Name)
posts, err := rac.UserPosts(uc, user.Name)
if err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
@ -259,7 +260,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
notification.Topic = "com.christianselig.Apollo"
for _, watcher := range notifs {
if err := uc.watcherRepo.IncrementHits(ctx, watcher.ID); err != nil {
if err := uc.watcherRepo.IncrementHits(uc, watcher.ID); err != nil {
uc.logger.WithFields(logrus.Fields{
"user#id": user.ID,
"watcher#id": watcher.ID,
@ -268,7 +269,7 @@ func (uc *usersConsumer) Consume(delivery rmq.Delivery) {
return
}
device, _ := uc.deviceRepo.GetByID(ctx, watcher.DeviceID)
device, _ := uc.deviceRepo.GetByID(uc, watcher.DeviceID)
title := fmt.Sprintf(userNotificationTitleFormat, watcher.Label)
payload.AlertTitle(title)

View file

@ -1,6 +1,8 @@
package worker
import (
"context"
"github.com/DataDog/datadog-go/statsd"
"github.com/adjust/rmq/v4"
"github.com/go-redis/redis/v8"
@ -8,7 +10,7 @@ import (
"github.com/sirupsen/logrus"
)
type NewWorkerFn func(logger *logrus.Logger, statsd *statsd.Client, db *pgxpool.Pool, redis *redis.Client, queue rmq.Connection, consumers int) Worker
type NewWorkerFn func(context.Context, *logrus.Logger, *statsd.Client, *pgxpool.Pool, *redis.Client, rmq.Connection, int) Worker
type Worker interface {
Start() error
Stop()

View file

@ -0,0 +1,20 @@
-- Table Definition ----------------------------------------------
CREATE TABLE accounts (
id SERIAL PRIMARY KEY,
username character varying(20) DEFAULT ''::character varying UNIQUE,
access_token character varying(64) DEFAULT ''::character varying,
refresh_token character varying(64) DEFAULT ''::character varying,
expires_at integer DEFAULT 0,
last_message_id character varying(32) DEFAULT ''::character varying,
device_count integer DEFAULT 0,
last_checked_at double precision DEFAULT '0'::double precision,
last_enqueued_at double precision DEFAULT '0'::double precision,
account_id character varying(32) DEFAULT ''::character varying,
last_unstuck_at double precision DEFAULT '0'::double precision
);
-- Indices -------------------------------------------------------
CREATE INDEX accounts_last_checked_at_idx ON accounts(last_checked_at float8_ops);

View file

@ -0,0 +1,10 @@
-- Table Definition ----------------------------------------------
CREATE TABLE devices (
id SERIAL PRIMARY KEY,
apns_token character varying(100) UNIQUE,
sandbox boolean,
active_until integer,
grace_period_until integer
);

View file

@ -0,0 +1,15 @@
-- Table Definition ----------------------------------------------
CREATE TABLE devices_accounts (
id SERIAL PRIMARY KEY,
account_id integer REFERENCES accounts(id) ON DELETE CASCADE,
device_id integer REFERENCES devices(id) ON DELETE CASCADE,
watcher_notifiable boolean DEFAULT true,
inbox_notifiable boolean DEFAULT true,
global_mute boolean DEFAULT false
);
-- Indices -------------------------------------------------------
CREATE UNIQUE INDEX devices_accounts_account_id_device_id_idx ON devices_accounts(account_id int4_ops,device_id int4_ops);

View file

@ -0,0 +1,9 @@
-- Table Definition ----------------------------------------------
CREATE TABLE subreddits (
id SERIAL PRIMARY KEY,
subreddit_id character varying(32) DEFAULT ''::character varying UNIQUE,
name character varying(32) DEFAULT ''::character varying,
last_checked_at double precision DEFAULT '0'::double precision
);

View file

View file

@ -0,0 +1,9 @@
-- Table Definition ----------------------------------------------
CREATE TABLE users (
id SERIAL PRIMARY KEY,
user_id character varying(32) DEFAULT ''::character varying UNIQUE,
name character varying(32) DEFAULT ''::character varying,
last_checked_at double precision DEFAULT '0'::double precision
);

View file

@ -0,0 +1,24 @@
-- Table Definition ----------------------------------------------
CREATE TABLE watchers (
id SERIAL PRIMARY KEY,
device_id integer REFERENCES devices(id) ON DELETE CASCADE,
watchee_id integer,
upvotes integer DEFAULT 0,
keyword character varying(32) DEFAULT ''::character varying,
flair character varying(32) DEFAULT ''::character varying,
domain character varying(32) DEFAULT ''::character varying,
account_id integer REFERENCES accounts(id) ON DELETE CASCADE,
created_at double precision DEFAULT '0'::double precision,
hits integer DEFAULT 0,
type integer DEFAULT 0,
last_notified_at double precision DEFAULT '0'::double precision,
label character varying(64) DEFAULT ''::character varying,
author character varying(32) DEFAULT ''::character varying,
subreddit character varying(32) DEFAULT ''::character varying
);
-- Indices -------------------------------------------------------
CREATE INDEX watchers_type_watchee_id_idx ON watchers(type int4_ops,watchee_id int4_ops);

View file

@ -1 +0,0 @@
DROP TABLE devices;

View file

@ -1,8 +0,0 @@
CREATE TABLE IF NOT EXISTS devices (
id SERIAL PRIMARY KEY,
apns_token character(100) UNIQUE,
sandbox boolean,
last_pinged_at integer
);
CREATE UNIQUE INDEX IF NOT EXISTS devices_pkey ON devices(id int4_ops);
CREATE UNIQUE INDEX IF NOT EXISTS devices_apns_token_key ON devices(apns_token bpchar_ops);

View file

@ -1 +0,0 @@
DROP TABLE accounts;

View file

@ -1,13 +0,0 @@
CREATE TABLE IF NOT EXISTS accounts (
id SERIAL PRIMARY KEY,
username character varying(20),
access_token character varying(64),
refresh_token character varying(64),
expires_at integer,
last_message_id character varying(32),
device_count integer,
last_checked_at integer
);
CREATE UNIQUE INDEX IF NOT EXISTS accounts_pkey ON accounts(id int4_ops);
CREATE UNIQUE INDEX IF NOT EXISTS accounts_username_key ON accounts(username bpchar_ops);

View file

@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS devices_accounts (
id SERIAL PRIMARY KEY,
account_id integer,
device_id integer
);
CREATE UNIQUE INDEX IF NOT EXISTS devices_accounts_pkey ON devices_accounts(id int4_ops);
CREATE UNIQUE INDEX IF NOT EXISTS devices_accounts_account_id_device_id_idx ON devices_accounts(account_id int4_ops,device_id int4_ops);

View file

@ -1,32 +0,0 @@
#!/bin/sh
set -e
cd "$(dirname "$0")/.."
brew bundle check >/dev/null 2>&1 || {
echo "==> Installing Homebrew dependencies..."
brew bundle
}
[ -d "tmp/postgresql" ] || {
echo "===> Setting up database..."
initdb -D tmp/postgresql -U apollo
postgres -D tmp/postgresql &
echo "===> Waiting for Postgres to finish starting up..."
while ! nc -z localhost 5432; do
sleep 0.1 # wait for 1/10 of the second before check again
done
createdb apollo -U apollo
script/migrate
kill -INT `head -n1 tmp/postgresql/postmaster.pid`
}
go mod verify >/dev/null 2>&1 || {
echo "==> Installing Go dependencies..."
go mod download
}

View file

@ -1,10 +0,0 @@
#!/bin/sh
set -e
cd "$(dirname "$0")/.."
DATABASE_URL=postgres://apollo:@localhost/apollo?sslmode=disable
echo "===> Running migrations..."
migrate -path=./migrations -database=$DATABASE_URL up

View file

@ -1,11 +0,0 @@
#!/bin/sh
set -e
cd "$(dirname "$0")/.."
# ensure everything in the app is up to date.
script/bootstrap
# boot the app and any other necessary processes.
foreman start