diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b92535..b2e218e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,6 +5,7 @@ jobs: runs-on: ubuntu-latest env: DATABASE_URL: postgres://postgres:postgres@localhost/apollo_test + REDIS_URL: redis://redis:6379 services: postgres: image: postgres @@ -18,7 +19,15 @@ jobs: --health-retries 5 ports: - 5432:5432 - + redis: + image: redis + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 6379:6379 steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 diff --git a/.golangci.yml b/.golangci.yml index f74fa4e..5d9c923 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -17,11 +17,5 @@ linters: - thelper # detects golang test helpers without t.Helper() call and checks consistency of test helpers - unconvert # removes unnecessary type conversions - unparam # removes unused function parameters + - paralleltest fast: true - -issues: - exclude-rules: - # False positive: https://github.com/kunwardeep/paralleltest/issues/8. - - linters: - - paralleltest - text: "does not use range value in test Run" diff --git a/Makefile b/Makefile index 90780ea..e3eac97 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,9 @@ BREW_PREFIX ?= $(shell brew --prefix) DATABASE_URL ?= "postgres://$(USER)@localhost/apollo_test?sslmode=disable" +REDIS_URL ?= "redis://localhost:6379" test: - @DATABASE_URL=$(DATABASE_URL) go test -race -timeout 1s ./... + @DATABASE_URL=$(DATABASE_URL) REDIS_URL=$(REDIS_URL) go test -race -timeout 1s ./... test-setup: $(BREW_PREFIX)/bin/migrate migrate -path migrations/ -database $(DATABASE_URL) up diff --git a/internal/distributedlock/distributed_lock.go b/internal/distributedlock/distributed_lock.go new file mode 100644 index 0000000..2525b29 --- /dev/null +++ b/internal/distributedlock/distributed_lock.go @@ -0,0 +1,72 @@ +package distributedlock + +import ( + "context" + "fmt" + "github.com/go-redis/redis/v8" + "math/rand" + "time" +) + +const lockTopicFormat = "pubsub:locks:%s" + +type DistributedLock struct { + client *redis.Client + timeout time.Duration +} + +func New(client *redis.Client, timeout time.Duration) *DistributedLock { + return &DistributedLock{ + client: client, + timeout: timeout, + } +} + +func (d *DistributedLock) setLock(ctx context.Context, key string, uid string) error { + result, err := d.client.SetNX(ctx, key, uid, d.timeout).Result() + if err != nil { + return err + } + + if !result { + return ErrLockAlreadyAcquired + } + + return nil +} + +func (d *DistributedLock) AcquireLock(ctx context.Context, key string) (*Lock, error) { + uid := generateUniqueID() + if err := d.setLock(ctx, key, uid); err != nil { + return nil, err + } + + return NewLock(d, key, uid), nil +} + +func (d *DistributedLock) WaitAcquireLock(ctx context.Context, key string, timeout time.Duration) (*Lock, error) { + uid := generateUniqueID() + if err := d.setLock(ctx, key, uid); err == nil { + return NewLock(d, key, uid), nil + } + + ch := fmt.Sprintf(lockTopicFormat, key) + pubsub := d.client.Subscribe(ctx, ch) + + select { + case <-time.After(timeout): + return nil, ErrLockAcquisitionTimeout + case <-ctx.Done(): + return nil, ctx.Err() + case <-pubsub.Channel(): + err := d.setLock(ctx, key, uid) + if err != nil { + return nil, err + } + return NewLock(d, key, uid), nil + } +} + +func generateUniqueID() string { + return fmt.Sprintf("%d-%d", time.Now().UnixNano(), rand.Int63()) +} diff --git a/internal/distributedlock/distributed_lock_test.go b/internal/distributedlock/distributed_lock_test.go new file mode 100644 index 0000000..b0bfdfb --- /dev/null +++ b/internal/distributedlock/distributed_lock_test.go @@ -0,0 +1,73 @@ +package distributedlock_test + +import ( + "context" + "fmt" + "github.com/christianselig/apollo-backend/internal/distributedlock" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" + "os" + "testing" + "time" +) + +func NewRedisClient(t *testing.T, ctx context.Context) (*redis.Client, func()) { + t.Helper() + + opt, err := redis.ParseURL(os.Getenv("REDIS_URL")) + if err != nil { + panic(err) + } + client := redis.NewClient(opt) + if err := client.Ping(ctx).Err(); err != nil { + panic(err) + } + + return client, func() { + _ = client.Close() + } +} + +func TestDistributedLock_AcquireLock(t *testing.T) { + ctx := context.Background() + key := fmt.Sprintf("%d", time.Now().UnixNano()) + + client, closer := NewRedisClient(t, ctx) + defer closer() + + d := distributedlock.New(client, 10*time.Second) + lock, err := d.AcquireLock(ctx, key) + assert.NoError(t, err) + + _, err = d.AcquireLock(ctx, key) + assert.Equal(t, distributedlock.ErrLockAlreadyAcquired, err) + + err = lock.Release(ctx) + assert.NoError(t, err) + + _, err = d.AcquireLock(ctx, key) + assert.NoError(t, err) +} + +func TestDistributedLock_WaitAcquireLock(t *testing.T) { + ctx := context.Background() + key := fmt.Sprintf("%d", time.Now().UnixNano()) + + client, closer := NewRedisClient(t, ctx) + defer closer() + + d := distributedlock.New(client, 10*time.Second) + lock, err := d.AcquireLock(ctx, key) + assert.NoError(t, err) + + go func(l *distributedlock.Lock) { + select { + case <-time.After(100 * time.Millisecond): + _ = l.Release(ctx) + } + }(lock) + + lock, err = d.WaitAcquireLock(ctx, key, 5*time.Second) + assert.NoError(t, err) + assert.NotNil(t, lock) +} diff --git a/internal/distributedlock/errors.go b/internal/distributedlock/errors.go new file mode 100644 index 0000000..7f0bb46 --- /dev/null +++ b/internal/distributedlock/errors.go @@ -0,0 +1,9 @@ +package distributedlock + +import "errors" + +var ( + ErrLockAcquisitionTimeout = errors.New("timed out acquiring lock") + ErrLockAlreadyAcquired = errors.New("lock already acquired") + ErrLockExpired = errors.New("releasing an expired lock") +) diff --git a/internal/distributedlock/lock.go b/internal/distributedlock/lock.go new file mode 100644 index 0000000..3a49f39 --- /dev/null +++ b/internal/distributedlock/lock.go @@ -0,0 +1,45 @@ +package distributedlock + +import ( + "context" + "fmt" +) + +type Lock struct { + distributedLock *DistributedLock + key string + uid string +} + +func NewLock(distributedLock *DistributedLock, key string, uid string) *Lock { + return &Lock{ + distributedLock: distributedLock, + key: key, + uid: uid, + } +} + +func (l *Lock) Release(ctx context.Context) error { + script := ` + if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("DEL", KEYS[1]) + redis.call("PUBLISH", KEYS[2], KEYS[1]) + return 1 + else + return 0 + end + ` + + ch := fmt.Sprintf(lockTopicFormat, l.key) + + result, err := l.distributedLock.client.Eval(ctx, script, []string{l.key, ch}, l.uid).Result() + if err != nil { + return err + } + + if result == int64(0) { + return ErrLockExpired + } + + return nil +}