apollo-backend/internal/distributedlock/distributed_lock.go

92 lines
1.8 KiB
Go
Raw Normal View History

2023-04-01 15:57:28 +00:00
package distributedlock
import (
"context"
"fmt"
"math/rand"
"time"
2023-04-01 16:07:48 +00:00
"github.com/go-redis/redis/v8"
2023-04-01 15:57:28 +00:00
)
2023-04-01 16:07:48 +00:00
const (
// lockTopicFormat is the fmt pattern for the Redis pub/sub channel name
// associated with a lock key; the %s verb is replaced by the lock key.
lockTopicFormat = "pubsub:locks:%s"
// lockReleaseScript is a Lua script, executed inside Redis, that releases a
// lock only when the caller still owns it:
//   - KEYS[1] is the lock key and ARGV[1] is the owner token expected there.
//   - When the stored value equals ARGV[1] the key is deleted, KEYS[1] is
//     published on channel KEYS[2], and the script returns 1.
//   - Otherwise nothing is changed and the script returns 0.
//
// NOTE(review): KEYS[2] is presumably the channel built from lockTopicFormat;
// the caller that supplies it is not in this file — confirm.
lockReleaseScript = `
if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("DEL", KEYS[1])
redis.call("PUBLISH", KEYS[2], KEYS[1])
return 1
else
return 0
end
`
)
2023-04-01 15:57:28 +00:00
type DistributedLock struct {
client *redis.Client
2023-04-01 16:07:48 +00:00
sha string
2023-04-01 15:57:28 +00:00
timeout time.Duration
}
2023-04-01 16:07:48 +00:00
func New(client *redis.Client, timeout time.Duration) (*DistributedLock, error) {
sha, err := client.ScriptLoad(context.Background(), lockReleaseScript).Result()
if err != nil {
return nil, err
}
2023-04-01 15:57:28 +00:00
return &DistributedLock{
client: client,
2023-04-01 16:07:48 +00:00
sha: sha,
2023-04-01 15:57:28 +00:00
timeout: timeout,
2023-04-01 16:07:48 +00:00
}, nil
2023-04-01 15:57:28 +00:00
}
// setLock attempts to claim key for owner token uid with a single SETNX that
// expires after d.timeout.
//
// It returns nil when the key was written, ErrLockAlreadyAcquired when the key
// already existed, and the Redis error unchanged when the command failed.
func (d *DistributedLock) setLock(ctx context.Context, key string, uid string) error {
	acquired, err := d.client.SetNX(ctx, key, uid, d.timeout).Result()
	switch {
	case err != nil:
		return err
	case acquired:
		return nil
	default:
		return ErrLockAlreadyAcquired
	}
}
// AcquireLock makes one non-blocking attempt to take the lock for key.
//
// A fresh owner token is generated for the attempt. On success the returned
// Lock carries that token; on failure the error from setLock is returned
// as-is (ErrLockAlreadyAcquired when the key is held).
func (d *DistributedLock) AcquireLock(ctx context.Context, key string) (*Lock, error) {
	token := generateUniqueID()

	err := d.setLock(ctx, key, token)
	if err != nil {
		return nil, err
	}

	held := NewLock(d, key, token)
	return held, nil
}
// WaitAcquireLock takes the lock for key, waiting up to timeout for the
// current holder to release it.
//
// The lock is tried once immediately. If it is held, the call subscribes to
// the key's release channel (lockTopicFormat) and retries every time a release
// is announced, until it succeeds, timeout elapses, or ctx is done.
//
// Returns:
//   - the acquired *Lock on success;
//   - ErrLockAcquisitionTimeout when timeout elapses first;
//   - ctx.Err() when ctx is cancelled or its deadline passes first;
//   - any other Redis error unchanged.
//
// Waiters are woken only by messages on the release channel. As far as this
// file shows, a lock that vanishes through key expiry publishes nothing, so a
// waiter will not notice it before timeout.
func (d *DistributedLock) WaitAcquireLock(ctx context.Context, key string, timeout time.Duration) (*Lock, error) {
	uid := generateUniqueID()

	// Fast path: the lock is free, or Redis failed in a way waiting cannot fix.
	err := d.setLock(ctx, key, uid)
	if err == nil {
		return NewLock(d, key, uid), nil
	}
	if err != ErrLockAlreadyAcquired {
		return nil, err
	}

	// One deadline bounds the whole wait, including the subscription handshake.
	waitCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// waitErr maps a failure that happened while waiting onto the errors
	// callers expect: caller cancellation first, then our own timeout.
	waitErr := func(fallback error) error {
		if cerr := ctx.Err(); cerr != nil {
			return cerr
		}
		if waitCtx.Err() != nil {
			return ErrLockAcquisitionTimeout
		}
		return fallback
	}

	pubsub := d.client.Subscribe(ctx, fmt.Sprintf(lockTopicFormat, key))
	// Always release the subscription's connection and receive goroutine;
	// the Close error carries nothing actionable for the caller.
	defer pubsub.Close()

	// Subscribe does not wait for Redis to acknowledge. Block until the
	// subscription is confirmed so that no release published from here on
	// can be missed.
	if _, err := pubsub.Receive(waitCtx); err != nil {
		return nil, waitErr(err)
	}
	releases := pubsub.Channel()

	for {
		// First pass: the holder may have released between the fast-path
		// attempt and the subscription becoming active. Later passes: a
		// release was just announced.
		err := d.setLock(ctx, key, uid)
		if err == nil {
			return NewLock(d, key, uid), nil
		}
		if err != ErrLockAlreadyAcquired {
			return nil, err
		}

		select {
		case <-waitCtx.Done():
			return nil, waitErr(waitCtx.Err())
		case _, open := <-releases:
			if !open {
				// A closed channel would spin the loop; a nil channel
				// blocks forever, leaving only the deadline case.
				releases = nil
			}
		}
	}
}
// generateUniqueID builds the owner token stored as a lock's value, in the
// form "<unix-nanoseconds>-<random int63>".
//
// The random part comes from math/rand, so the token is meant to be distinct
// between lock attempts, not to be unguessable.
func generateUniqueID() string {
	nanos := time.Now().UnixNano()
	nonce := rand.Int63()
	return fmt.Sprintf("%d-%d", nanos, nonce)
}