diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b92535..1936daa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,6 +5,7 @@ jobs: runs-on: ubuntu-latest env: DATABASE_URL: postgres://postgres:postgres@localhost/apollo_test + REDIS_URL: redis://localhost:6379 services: postgres: image: postgres @@ -18,12 +19,20 @@ jobs: --health-retries 5 ports: - 5432:5432 - + redis: + image: redis + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 6379:6379 steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 with: - go-version: 1.19.3 + go-version: 1.20.2 - uses: golangci/golangci-lint-action@v3 - run: psql -f docs/schema.sql $DATABASE_URL - run: go test ./... -v -race -timeout 5s diff --git a/.golangci.yml b/.golangci.yml index f74fa4e..5d9c923 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -17,11 +17,5 @@ linters: - thelper # detects golang test helpers without t.Helper() call and checks consistency of test helpers - unconvert # removes unnecessary type conversions - unparam # removes unused function parameters + - paralleltest fast: true - -issues: - exclude-rules: - # False positive: https://github.com/kunwardeep/paralleltest/issues/8. - - linters: - - paralleltest - text: "does not use range value in test Run" diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/apollo-backend.iml b/.idea/apollo-backend.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/apollo-backend.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..a55e7a1 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..fa496e0 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Makefile b/Makefile index 90780ea..e3eac97 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,9 @@ BREW_PREFIX ?= $(shell brew --prefix) DATABASE_URL ?= "postgres://$(USER)@localhost/apollo_test?sslmode=disable" +REDIS_URL ?= "redis://localhost:6379" test: - @DATABASE_URL=$(DATABASE_URL) go test -race -timeout 1s ./... + @DATABASE_URL=$(DATABASE_URL) REDIS_URL=$(REDIS_URL) go test -race -timeout 1s ./... test-setup: $(BREW_PREFIX)/bin/migrate migrate -path migrations/ -database $(DATABASE_URL) up diff --git a/internal/distributedlock/distributed_lock.go b/internal/distributedlock/distributed_lock.go new file mode 100644 index 0000000..b989aea --- /dev/null +++ b/internal/distributedlock/distributed_lock.go @@ -0,0 +1,92 @@ +package distributedlock + +import ( + "context" + "fmt" + "math/rand" + "time" + + "github.com/go-redis/redis/v8" +) + +const ( + lockTopicFormat = "pubsub:locks:%s" + lockReleaseScript = ` + if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("DEL", KEYS[1]) + redis.call("PUBLISH", KEYS[2], KEYS[1]) + return 1 + else + return 0 + end + ` +) + +type DistributedLock struct { + client *redis.Client + sha string + timeout time.Duration +} + +func New(client *redis.Client, timeout time.Duration) (*DistributedLock, error) { + sha, err := client.ScriptLoad(context.Background(), lockReleaseScript).Result() + if err != nil { + return nil, err + } + + return &DistributedLock{ + client: client, + sha: sha, + timeout: timeout, + }, nil +} + +func (d *DistributedLock) setLock(ctx context.Context, key string, uid string) error { + result, err := d.client.SetNX(ctx, key, uid, d.timeout).Result() + if err != nil { + return err + } + + if !result { + return ErrLockAlreadyAcquired + } + + return nil +} + +func (d *DistributedLock) AcquireLock(ctx context.Context, key string) (*Lock, error) { + uid := generateUniqueID() + if err := d.setLock(ctx, key, uid); err != nil { + return nil, err + } + + return NewLock(d, key, uid), nil +} + +func (d *DistributedLock) WaitAcquireLock(ctx context.Context, key string, timeout time.Duration) (*Lock, error) { + uid := generateUniqueID() + if err := d.setLock(ctx, key, uid); err == nil { + return NewLock(d, key, uid), nil + } + + ch := fmt.Sprintf(lockTopicFormat, key) + pubsub := d.client.Subscribe(ctx, ch) + defer func() { _ = pubsub.Close() }() + + select { + case <-time.After(timeout): + return nil, ErrLockAcquisitionTimeout + case <-ctx.Done(): + return nil, ctx.Err() + case <-pubsub.Channel(): + err := d.setLock(ctx, key, uid) + if err != nil { + return nil, err + } + return NewLock(d, key, uid), nil + } +} + +func generateUniqueID() string { + return fmt.Sprintf("%d-%d", time.Now().UnixNano(), rand.Int63()) +} diff --git a/internal/distributedlock/distributed_lock_test.go b/internal/distributedlock/distributed_lock_test.go new file mode 100644 index 0000000..3d8211f --- /dev/null +++ b/internal/distributedlock/distributed_lock_test.go @@ -0,0 +1,84 @@ +package distributedlock_test + +import ( + "context" + "fmt" + "math/rand" + "os" + "testing" + "time" + + "github.com/christianselig/apollo-backend/internal/distributedlock" + + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" +) + +func NewRedisClient(t *testing.T, ctx context.Context) (*redis.Client, func()) { + t.Helper() + + opt, err := redis.ParseURL(os.Getenv("REDIS_URL")) + if err != nil { + panic(err) + } + client := redis.NewClient(opt) + if err := client.Ping(ctx).Err(); err != nil { + panic(err) + } + + return client, func() { + _ = client.Close() + } +} + +func TestDistributedLock_AcquireLock(t *testing.T) { + t.Parallel() + + ctx := context.Background() + key := fmt.Sprintf("key:%d-%d", time.Now().UnixNano(), rand.Int63()) + + client, closer := NewRedisClient(t, ctx) + defer closer() + + d, err := distributedlock.New(client, 10*time.Second) + assert.NoError(t, err) + + lock, err := d.AcquireLock(ctx, key) + assert.NoError(t, err) + + _, err = d.AcquireLock(ctx, key) + assert.Equal(t, distributedlock.ErrLockAlreadyAcquired, err) + + err = lock.Release(ctx) + assert.NoError(t, err) + + _, err = d.AcquireLock(ctx, key) + assert.NoError(t, err) +} + +func TestDistributedLock_WaitAcquireLock(t *testing.T) { + t.Parallel() + + ctx := context.Background() + key := fmt.Sprintf("key:%d-%d", time.Now().UnixNano(), rand.Int63()) + + client, closer := NewRedisClient(t, ctx) + defer closer() + + d, err := distributedlock.New(client, 10*time.Second) + assert.NoError(t, err) + + lock, err := d.AcquireLock(ctx, key) + assert.NoError(t, err) + + go func(l *distributedlock.Lock) { + select { + case <-time.After(100 * time.Millisecond): + _ = l.Release(ctx) + } + }(lock) + + lock, err = d.WaitAcquireLock(ctx, key, 5*time.Second) + assert.NoError(t, err) + assert.NotNil(t, lock) +} diff --git a/internal/distributedlock/errors.go b/internal/distributedlock/errors.go new file mode 100644 index 0000000..7f0bb46 --- /dev/null +++ b/internal/distributedlock/errors.go @@ -0,0 +1,9 @@ +package distributedlock + +import "errors" + +var ( + ErrLockAcquisitionTimeout = errors.New("timed out acquiring lock") + ErrLockAlreadyAcquired = errors.New("lock already acquired") + ErrLockExpired = errors.New("releasing an expired lock") +) diff --git a/internal/distributedlock/lock.go b/internal/distributedlock/lock.go new file mode 100644 index 0000000..005730e --- /dev/null +++ b/internal/distributedlock/lock.go @@ -0,0 +1,36 @@ +package distributedlock + +import ( + "context" + "fmt" +) + +type Lock struct { + distributedLock *DistributedLock + key string + uid string +} + +func NewLock(distributedLock *DistributedLock, key string, uid string) *Lock { + return &Lock{ + distributedLock: distributedLock, + key: key, + uid: uid, + } +} + +func (l *Lock) Release(ctx context.Context) error { + + ch := fmt.Sprintf(lockTopicFormat, l.key) + + result, err := l.distributedLock.client.EvalSha(ctx, l.distributedLock.sha, []string{l.key, ch}, l.uid).Result() + if err != nil { + return err + } + + if result == int64(0) { + return ErrLockExpired + } + + return nil +}