diff --git a/internal/distributedlock/distributed_lock.go b/internal/distributedlock/distributed_lock.go index 2525b29..bae898a 100644 --- a/internal/distributedlock/distributed_lock.go +++ b/internal/distributedlock/distributed_lock.go @@ -3,23 +3,42 @@ package distributedlock import ( "context" "fmt" - "github.com/go-redis/redis/v8" "math/rand" "time" + + "github.com/go-redis/redis/v8" ) -const lockTopicFormat = "pubsub:locks:%s" +const ( + lockTopicFormat = "pubsub:locks:%s" + lockReleaseScript = ` + if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("DEL", KEYS[1]) + redis.call("PUBLISH", KEYS[2], KEYS[1]) + return 1 + else + return 0 + end + ` +) type DistributedLock struct { client *redis.Client + sha string timeout time.Duration } -func New(client *redis.Client, timeout time.Duration) *DistributedLock { +func New(client *redis.Client, timeout time.Duration) (*DistributedLock, error) { + sha, err := client.ScriptLoad(context.Background(), lockReleaseScript).Result() + if err != nil { + return nil, err + } + return &DistributedLock{ client: client, + sha: sha, timeout: timeout, - } + }, nil } func (d *DistributedLock) setLock(ctx context.Context, key string, uid string) error { diff --git a/internal/distributedlock/distributed_lock_test.go b/internal/distributedlock/distributed_lock_test.go index b0bfdfb..3d8211f 100644 --- a/internal/distributedlock/distributed_lock_test.go +++ b/internal/distributedlock/distributed_lock_test.go @@ -3,12 +3,15 @@ package distributedlock_test import ( "context" "fmt" - "github.com/christianselig/apollo-backend/internal/distributedlock" - "github.com/go-redis/redis/v8" - "github.com/stretchr/testify/assert" + "math/rand" "os" "testing" "time" + + "github.com/christianselig/apollo-backend/internal/distributedlock" + + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" ) func NewRedisClient(t *testing.T, ctx context.Context) (*redis.Client, func()) { @@ -29,13 +32,17 @@ func NewRedisClient(t *testing.T, ctx context.Context) (*redis.Client, func()) { } func TestDistributedLock_AcquireLock(t *testing.T) { + t.Parallel() + ctx := context.Background() - key := fmt.Sprintf("%d", time.Now().UnixNano()) + key := fmt.Sprintf("key:%d-%d", time.Now().UnixNano(), rand.Int63()) client, closer := NewRedisClient(t, ctx) defer closer() - d := distributedlock.New(client, 10*time.Second) + d, err := distributedlock.New(client, 10*time.Second) + assert.NoError(t, err) + lock, err := d.AcquireLock(ctx, key) assert.NoError(t, err) @@ -50,13 +57,17 @@ func TestDistributedLock_AcquireLock(t *testing.T) { } func TestDistributedLock_WaitAcquireLock(t *testing.T) { + t.Parallel() + ctx := context.Background() - key := fmt.Sprintf("%d", time.Now().UnixNano()) + key := fmt.Sprintf("key:%d-%d", time.Now().UnixNano(), rand.Int63()) client, closer := NewRedisClient(t, ctx) defer closer() - d := distributedlock.New(client, 10*time.Second) + d, err := distributedlock.New(client, 10*time.Second) + assert.NoError(t, err) + lock, err := d.AcquireLock(ctx, key) assert.NoError(t, err) diff --git a/internal/distributedlock/lock.go b/internal/distributedlock/lock.go index 3a49f39..005730e 100644 --- a/internal/distributedlock/lock.go +++ b/internal/distributedlock/lock.go @@ -20,19 +20,10 @@ func NewLock(distributedLock *DistributedLock, key string, uid string) *Lock { } func (l *Lock) Release(ctx context.Context) error { - script := ` - if redis.call("GET", KEYS[1]) == ARGV[1] then - redis.call("DEL", KEYS[1]) - redis.call("PUBLISH", KEYS[2], KEYS[1]) - return 1 - else - return 0 - end - ` ch := fmt.Sprintf(lockTopicFormat, l.key) - result, err := l.distributedLock.client.Eval(ctx, script, []string{l.key, ch}, l.uid).Result() + result, err := l.distributedLock.client.EvalSha(ctx, l.distributedLock.sha, []string{l.key, ch}, l.uid).Result() if err != nil { return err }