diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1b92535..1936daa 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -5,6 +5,7 @@ jobs:
runs-on: ubuntu-latest
env:
DATABASE_URL: postgres://postgres:postgres@localhost/apollo_test
+ REDIS_URL: redis://localhost:6379
services:
postgres:
image: postgres
@@ -18,12 +19,20 @@ jobs:
--health-retries 5
ports:
- 5432:5432
-
+ redis:
+ image: redis
+ options: >-
+ --health-cmd "redis-cli ping"
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+ ports:
+ - 6379:6379
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
- go-version: 1.19.3
+ go-version: 1.20.2
- uses: golangci/golangci-lint-action@v3
- run: psql -f docs/schema.sql $DATABASE_URL
- run: go test ./... -v -race -timeout 5s
diff --git a/.golangci.yml b/.golangci.yml
index f74fa4e..5d9c923 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -17,11 +17,5 @@ linters:
- thelper # detects golang test helpers without t.Helper() call and checks consistency of test helpers
- unconvert # removes unnecessary type conversions
- unparam # removes unused function parameters
+ - paralleltest
fast: true
-
-issues:
- exclude-rules:
- # False positive: https://github.com/kunwardeep/paralleltest/issues/8.
- - linters:
- - paralleltest
- text: "does not use range value in test Run"
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/apollo-backend.iml b/.idea/apollo-backend.iml
new file mode 100644
index 0000000..5e764c4
--- /dev/null
+++ b/.idea/apollo-backend.iml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml
new file mode 100644
index 0000000..a55e7a1
--- /dev/null
+++ b/.idea/codeStyles/codeStyleConfig.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..fa496e0
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 90780ea..e3eac97 100644
--- a/Makefile
+++ b/Makefile
@@ -1,8 +1,9 @@
BREW_PREFIX ?= $(shell brew --prefix)
DATABASE_URL ?= "postgres://$(USER)@localhost/apollo_test?sslmode=disable"
+REDIS_URL ?= "redis://localhost:6379"
test:
- @DATABASE_URL=$(DATABASE_URL) go test -race -timeout 1s ./...
+ @DATABASE_URL=$(DATABASE_URL) REDIS_URL=$(REDIS_URL) go test -race -timeout 1s ./...
test-setup: $(BREW_PREFIX)/bin/migrate
migrate -path migrations/ -database $(DATABASE_URL) up
diff --git a/internal/distributedlock/distributed_lock.go b/internal/distributedlock/distributed_lock.go
new file mode 100644
index 0000000..b989aea
--- /dev/null
+++ b/internal/distributedlock/distributed_lock.go
@@ -0,0 +1,92 @@
+package distributedlock
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "time"
+
+ "github.com/go-redis/redis/v8"
+)
+
+const (
+ lockTopicFormat = "pubsub:locks:%s"
+ lockReleaseScript = `
+ if redis.call("GET", KEYS[1]) == ARGV[1] then
+ redis.call("DEL", KEYS[1])
+ redis.call("PUBLISH", KEYS[2], KEYS[1])
+ return 1
+ else
+ return 0
+ end
+ `
+)
+
+type DistributedLock struct {
+ client *redis.Client
+ sha string
+ timeout time.Duration
+}
+
+func New(client *redis.Client, timeout time.Duration) (*DistributedLock, error) {
+ sha, err := client.ScriptLoad(context.Background(), lockReleaseScript).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ return &DistributedLock{
+ client: client,
+ sha: sha,
+ timeout: timeout,
+ }, nil
+}
+
+func (d *DistributedLock) setLock(ctx context.Context, key string, uid string) error {
+ result, err := d.client.SetNX(ctx, key, uid, d.timeout).Result()
+ if err != nil {
+ return err
+ }
+
+ if !result {
+ return ErrLockAlreadyAcquired
+ }
+
+ return nil
+}
+
+func (d *DistributedLock) AcquireLock(ctx context.Context, key string) (*Lock, error) {
+ uid := generateUniqueID()
+ if err := d.setLock(ctx, key, uid); err != nil {
+ return nil, err
+ }
+
+ return NewLock(d, key, uid), nil
+}
+
+func (d *DistributedLock) WaitAcquireLock(ctx context.Context, key string, timeout time.Duration) (*Lock, error) {
+ uid := generateUniqueID()
+ if err := d.setLock(ctx, key, uid); err == nil {
+ return NewLock(d, key, uid), nil
+ }
+
+ ch := fmt.Sprintf(lockTopicFormat, key)
+ pubsub := d.client.Subscribe(ctx, ch)
+ defer func() { _ = pubsub.Close() }()
+
+ select {
+ case <-time.After(timeout):
+ return nil, ErrLockAcquisitionTimeout
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-pubsub.Channel():
+ err := d.setLock(ctx, key, uid)
+ if err != nil {
+ return nil, err
+ }
+ return NewLock(d, key, uid), nil
+ }
+}
+
+func generateUniqueID() string {
+ return fmt.Sprintf("%d-%d", time.Now().UnixNano(), rand.Int63())
+}
diff --git a/internal/distributedlock/distributed_lock_test.go b/internal/distributedlock/distributed_lock_test.go
new file mode 100644
index 0000000..3d8211f
--- /dev/null
+++ b/internal/distributedlock/distributed_lock_test.go
@@ -0,0 +1,84 @@
+package distributedlock_test
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/christianselig/apollo-backend/internal/distributedlock"
+
+ "github.com/go-redis/redis/v8"
+ "github.com/stretchr/testify/assert"
+)
+
+func NewRedisClient(t *testing.T, ctx context.Context) (*redis.Client, func()) {
+ t.Helper()
+
+ opt, err := redis.ParseURL(os.Getenv("REDIS_URL"))
+ if err != nil {
+ panic(err)
+ }
+ client := redis.NewClient(opt)
+ if err := client.Ping(ctx).Err(); err != nil {
+ panic(err)
+ }
+
+ return client, func() {
+ _ = client.Close()
+ }
+}
+
+func TestDistributedLock_AcquireLock(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ key := fmt.Sprintf("key:%d-%d", time.Now().UnixNano(), rand.Int63())
+
+ client, closer := NewRedisClient(t, ctx)
+ defer closer()
+
+ d, err := distributedlock.New(client, 10*time.Second)
+ assert.NoError(t, err)
+
+ lock, err := d.AcquireLock(ctx, key)
+ assert.NoError(t, err)
+
+ _, err = d.AcquireLock(ctx, key)
+ assert.Equal(t, distributedlock.ErrLockAlreadyAcquired, err)
+
+ err = lock.Release(ctx)
+ assert.NoError(t, err)
+
+ _, err = d.AcquireLock(ctx, key)
+ assert.NoError(t, err)
+}
+
+func TestDistributedLock_WaitAcquireLock(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ key := fmt.Sprintf("key:%d-%d", time.Now().UnixNano(), rand.Int63())
+
+ client, closer := NewRedisClient(t, ctx)
+ defer closer()
+
+ d, err := distributedlock.New(client, 10*time.Second)
+ assert.NoError(t, err)
+
+ lock, err := d.AcquireLock(ctx, key)
+ assert.NoError(t, err)
+
+ go func(l *distributedlock.Lock) {
+ select {
+ case <-time.After(100 * time.Millisecond):
+ _ = l.Release(ctx)
+ }
+ }(lock)
+
+ lock, err = d.WaitAcquireLock(ctx, key, 5*time.Second)
+ assert.NoError(t, err)
+ assert.NotNil(t, lock)
+}
diff --git a/internal/distributedlock/errors.go b/internal/distributedlock/errors.go
new file mode 100644
index 0000000..7f0bb46
--- /dev/null
+++ b/internal/distributedlock/errors.go
@@ -0,0 +1,9 @@
+package distributedlock
+
+import "errors"
+
+var (
+ ErrLockAcquisitionTimeout = errors.New("timed out acquiring lock")
+ ErrLockAlreadyAcquired = errors.New("lock already acquired")
+ ErrLockExpired = errors.New("releasing an expired lock")
+)
diff --git a/internal/distributedlock/lock.go b/internal/distributedlock/lock.go
new file mode 100644
index 0000000..005730e
--- /dev/null
+++ b/internal/distributedlock/lock.go
@@ -0,0 +1,36 @@
+package distributedlock
+
+import (
+ "context"
+ "fmt"
+)
+
+type Lock struct {
+ distributedLock *DistributedLock
+ key string
+ uid string
+}
+
+func NewLock(distributedLock *DistributedLock, key string, uid string) *Lock {
+ return &Lock{
+ distributedLock: distributedLock,
+ key: key,
+ uid: uid,
+ }
+}
+
+func (l *Lock) Release(ctx context.Context) error {
+
+ ch := fmt.Sprintf(lockTopicFormat, l.key)
+
+ result, err := l.distributedLock.client.EvalSha(ctx, l.distributedLock.sha, []string{l.key, ch}, l.uid).Result()
+ if err != nil {
+ return err
+ }
+
+ if result == int64(0) {
+ return ErrLockExpired
+ }
+
+ return nil
+}