do batches in redis

This commit is contained in:
Andre Medeiros 2021-07-08 23:12:50 -04:00
parent b7aea89cfc
commit 3ce858927b
2 changed files with 58 additions and 26 deletions

View file

@ -2,6 +2,7 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
@ -19,6 +20,10 @@ import (
"github.com/sirupsen/logrus"
)
const (
batchSize = 100
)
func main() {
_ = godotenv.Load()
@ -172,36 +177,56 @@ func enqueueAccounts(ctx context.Context, logger *logrus.Logger, statsd *statsd.
enqueued := 0
skipped := 0
failed := 0
for _, id := range ids {
// Split ids in batches
for i := 0; i < len(ids); i += batchSize {
j := i + batchSize
if j > len(ids) {
j = len(ids)
}
batch := Int64Slice(ids[i:j])
logger.WithFields(logrus.Fields{
"len": len(batch),
}).Debug("enqueueing batch")
lua := `
local retv={}
local ids=cjson.decode(ARGV[1])
for i=1, #ids do
local key = "locks:accounts:" .. ids[i]
if redis.call("exists", key) == 0 then
retv[#retv + 1] = ids[i]
end
redis.call("setex", key, 60, 1)
end
return retv
`
res, err := redisConn.Eval(ctx, lua, []string{}, batch).Result()
if err != nil {
logger.WithFields(logrus.Fields{
"err": err,
}).Error("failed to check for locked accounts")
}
vals := res.([]interface{})
enqueued += len(vals)
skipped += len(batch) - len(vals)
for _, val := range vals {
id := val.(int64)
payload := fmt.Sprintf("%d", id)
lockKey := fmt.Sprintf("locks:accounts:%s", payload)
_, err := redisConn.Get(ctx, lockKey).Result()
if err == nil {
skipped++
continue
} else if err != redis.Nil {
logger.WithFields(logrus.Fields{
"lockKey": lockKey,
"err": err,
}).Error("failed to check for account lock")
}
if err := redisConn.SetEX(ctx, lockKey, true, 60*time.Second).Err(); err != nil {
logger.WithFields(logrus.Fields{
"lockKey": lockKey,
"err": err,
}).Error("failed to lock account")
}
if err = queue.Publish(payload); err != nil {
logger.WithFields(logrus.Fields{
"accountID": payload,
"err": err,
}).Error("failed to enqueue account")
failed++
} else {
enqueued++
}
}
}
@ -224,3 +249,10 @@ func logErrors(errChan <-chan error) {
log.Print("error: ", err)
}
}
type Int64Slice []int64
func (ii Int64Slice) MarshalBinary() (data []byte, err error) {
bytes, err := json.Marshal(ii)
return bytes, err
}

View file

@ -404,7 +404,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
"accountID": delivery.Payload(),
"token": device.APNSToken,
"redditUser": account.Username,
}).Debug("sent notification")
}).Info("sent notification")
}
}
}