package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"os/signal"
	"runtime"
	"strconv"
	"syscall"
	"time"

	"github.com/DataDog/datadog-go/statsd"
	"github.com/adjust/rmq/v4"
	"github.com/go-redis/redis/v8"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/pgxpool"
	"github.com/joho/godotenv"
	"github.com/sideshow/apns2"
	"github.com/sideshow/apns2/payload"
	"github.com/sideshow/apns2/token"
	"github.com/sirupsen/logrus"

	"github.com/christianselig/apollo-backend/internal/data"
	"github.com/christianselig/apollo-backend/internal/reddit"
)
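
// pollDuration is how often rmq polls Redis for new deliveries; backoff is the
// number of seconds subtracted from the time since an account was last checked
// when measuring queue delay; rate is the statsd sample rate for that metric.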
const (
	pollDuration = 10 * time.Millisecond
	backoff      = 5
	rate         = 0.1
)

func main() {
	_ = godotenv.Load()

	errChan := make(chan error, 10)
	go logErrors(errChan)

	ctx, cancel := context.WithCancel(context.Background())

	var logger *logrus.Logger
	{
		logger = logrus.New()
		if os.Getenv("ENV") == "" {
			logger.SetLevel(logrus.DebugLevel)
		} else {
			logger.SetFormatter(&logrus.TextFormatter{
				DisableColors: true,
				FullTimestamp: true,
			})
		}
	}

	// Set up Postgres connection
	var pool *pgxpool.Pool
	{
		config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_CONNECTION_POOL_URL"))
		if err != nil {
			panic(err)
		}

		// Setting the build statement cache to nil helps this work with pgbouncer
		config.ConnConfig.BuildStatementCache = nil
		config.ConnConfig.PreferSimpleProtocol = true

		pool, err = pgxpool.ConnectConfig(ctx, config)
		if err != nil {
			panic(err)
		}
		defer pool.Close()
	}

	statsd, err := statsd.New("127.0.0.1:8125")
	if err != nil {
		log.Fatal(err)
	}

	rc := reddit.NewClient(
		os.Getenv("REDDIT_CLIENT_ID"),
		os.Getenv("REDDIT_CLIENT_SECRET"),
		statsd,
	)

	var apnsToken *token.Token
	{
		authKey, err := token.AuthKeyFromBytes([]byte(os.Getenv("APPLE_KEY_PKEY")))
		if err != nil {
			panic(err)
		}

		apnsToken = &token.Token{
			AuthKey: authKey,
			KeyID:   os.Getenv("APPLE_KEY_ID"),
			TeamID:  os.Getenv("APPLE_TEAM_ID"),
		}
	}

	// Set up Redis connection
	var redisConn *redis.Client
	{
		opt, err := redis.ParseURL(os.Getenv("REDISCLOUD_URL"))
		if err != nil {
			panic(err)
		}

		redisConn = redis.NewClient(opt)
		if err := redisConn.Ping(ctx).Err(); err != nil {
			panic(err)
		}
	}

	connection, err := rmq.OpenConnectionWithRedisClient("consumer", redisConn, errChan)
	if err != nil {
		panic(err)
	}

	queue, err := connection.OpenQueue("notifications")
	if err != nil {
		panic(err)
	}
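
	// Consumers spend most of their time waiting on network I/O (Reddit, Postgres,
	// Redis, APNs), so run many of them per CPU core and let rmq prefetch a batch
	// of deliveries for each one.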
	numConsumers := runtime.NumCPU() * 12
	prefetchLimit := int64(numConsumers * 32)

	runtime.GOMAXPROCS(numConsumers)

	logger.WithFields(logrus.Fields{
		"numConsumers": numConsumers,
	}).Info("starting up notifications worker")

	if err := queue.StartConsuming(prefetchLimit, pollDuration); err != nil {
		panic(err)
	}

	host, _ := os.Hostname()

	for i := 0; i < numConsumers; i++ {
		name := fmt.Sprintf("consumer %s-%d", host, i)

		consumer := NewConsumer(i, logger, statsd, redisConn, pool, rc, apnsToken)
		if _, err := queue.AddConsumer(name, consumer); err != nil {
			panic(err)
		}
	}

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(signals)

	<-signals // wait for signal
	cancel()

	go func() {
		<-signals // hard exit on second signal (in case shutdown gets stuck)
		os.Exit(1)
	}()

	<-connection.StopAllConsuming() // wait for all Consume() calls to finish
}
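
// Consumer holds everything a single queue consumer needs: structured logging,
// statsd metrics, Redis, the Postgres pool, the Reddit API client, and APNs
// clients for both the sandbox and production environments.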
type Consumer struct {
	tag            int
	logger         *logrus.Logger
	statsd         *statsd.Client
	redis          *redis.Client
	pool           *pgxpool.Pool
	reddit         *reddit.Client
	apnsSandbox    *apns2.Client
	apnsProduction *apns2.Client
}
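
// NewConsumer builds a Consumer, creating separate APNs token clients for the
// sandbox and production environments from the same signing token.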
func NewConsumer(tag int, logger *logrus.Logger, statsd *statsd.Client, redis *redis.Client, pool *pgxpool.Pool, rc *reddit.Client, apnsToken *token.Token) *Consumer {
	return &Consumer{
		tag,
		logger,
		statsd,
		redis,
		pool,
		rc,
		apns2.NewTokenClient(apnsToken),
		apns2.NewTokenClient(apnsToken).Production(),
	}
}
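
// Consume handles one queued delivery whose payload is an account ID: it loads
// the account, refreshes its Reddit tokens if they have expired, fetches new
// inbox messages, records progress in Postgres, and pushes an APNs notification
// for each new message to every device registered to the account.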
func (c *Consumer) Consume(delivery rmq.Delivery) {
	ctx := context.Background()

	// Release the per-account lock once this job finishes, whatever the outcome.
	defer func() {
		lockKey := fmt.Sprintf("locks:accounts:%s", delivery.Payload())
		if err := c.redis.Del(ctx, lockKey).Err(); err != nil {
			c.logger.WithFields(logrus.Fields{
				"lockKey": lockKey,
				"err":     err,
			}).Error("failed to remove lock")
		}
	}()

	c.logger.WithFields(logrus.Fields{
		"accountID": delivery.Payload(),
	}).Debug("starting job")

	id, err := strconv.ParseInt(delivery.Payload(), 10, 64)
	if err != nil {
		c.logger.WithFields(logrus.Fields{
			"accountID": delivery.Payload(),
			"err":       err,
		}).Error("failed to parse account ID")

		delivery.Reject()
		return
	}

	defer delivery.Ack()

	// Current time as Unix seconds, with millisecond precision.
	now := float64(time.Now().UnixNano()/int64(time.Millisecond)) / 1000

	stmt := `SELECT
			id,
			username,
			account_id,
			access_token,
			refresh_token,
			expires_at,
			last_message_id,
			last_checked_at
		FROM accounts
		WHERE id = $1`
	account := &data.Account{}
	if err := c.pool.QueryRow(ctx, stmt, id).Scan(
		&account.ID,
		&account.Username,
		&account.AccountID,
		&account.AccessToken,
		&account.RefreshToken,
		&account.ExpiresAt,
		&account.LastMessageID,
		&account.LastCheckedAt,
	); err != nil {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
			"err":       err,
		}).Error("failed to fetch account from database")
		return
	}

	if account.LastCheckedAt > 0 {
		// Report queue delay: time since the account was last checked, less the backoff constant.
		latency := now - account.LastCheckedAt - float64(backoff)
		c.statsd.Histogram("apollo.queue.delay", latency, []string{}, rate)
	}

	rac := c.reddit.NewAuthenticatedClient(account.RefreshToken, account.AccessToken)
	if account.ExpiresAt < int64(now) {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
		}).Debug("refreshing reddit token")

		tokens, err := rac.RefreshTokens()
		if err != nil {
			c.logger.WithFields(logrus.Fields{
				"accountID": id,
				"err":       err,
			}).Error("failed to refresh reddit tokens")
			return
		}

		// Persist the refreshed tokens, marking them as expiring 59 minutes from now.
		err = c.pool.BeginFunc(ctx, func(tx pgx.Tx) error {
			stmt := `
				UPDATE accounts
				SET access_token = $1, refresh_token = $2, expires_at = $3 WHERE id = $4`
			_, err := tx.Exec(ctx, stmt, tokens.AccessToken, tokens.RefreshToken, int64(now+3540), account.ID)
			return err
		})
		if err != nil {
			c.logger.WithFields(logrus.Fields{
				"accountID": id,
				"err":       err,
			}).Error("failed to update reddit tokens for account")
			return
		}
	}

	c.logger.WithFields(logrus.Fields{
		"accountID": id,
	}).Debug("fetching message inbox")

	msgs, err := rac.MessageInbox(account.LastMessageID)
	if err != nil {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
			"err":       err,
		}).Error("failed to fetch message inbox")
		return
	}

	c.logger.WithFields(logrus.Fields{
		"accountID": id,
		"count":     len(msgs.MessageListing.Messages),
	}).Debug("fetched messages")

	if err = c.pool.BeginFunc(ctx, func(tx pgx.Tx) error {
		stmt := `
			UPDATE accounts
			SET last_checked_at = $1
			WHERE id = $2`
		_, err := tx.Exec(ctx, stmt, now, account.ID)
		return err
	}); err != nil {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
			"err":       err,
		}).Error("failed to update last_checked_at for account")
		return
	}

	if len(msgs.MessageListing.Messages) == 0 {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
		}).Debug("no new messages, bailing early")
		return
	}

	// Set latest message we alerted on
	latestMsg := msgs.MessageListing.Messages[0]
	if err = c.pool.BeginFunc(ctx, func(tx pgx.Tx) error {
		stmt := `
			UPDATE accounts
			SET last_message_id = $1
			WHERE id = $2`
		_, err := tx.Exec(ctx, stmt, latestMsg.FullName(), account.ID)
		return err
	}); err != nil {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
			"err":       err,
		}).Error("failed to update last_message_id for account")
		return
	}

	// Let's populate this with the latest message so we don't flood users with stuff
	if account.LastMessageID == "" {
		c.logger.WithFields(logrus.Fields{
			"accountID": delivery.Payload(),
		}).Debug("populating first message ID to prevent spamming")
		return
	}

	devices := []data.Device{}
	stmt = `
		SELECT apns_token, sandbox
		FROM devices
		LEFT JOIN devices_accounts ON devices.id = devices_accounts.device_id
		WHERE devices_accounts.account_id = $1`
	rows, err := c.pool.Query(ctx, stmt, account.ID)
	if err != nil {
		c.logger.WithFields(logrus.Fields{
			"accountID": id,
			"err":       err,
		}).Error("failed to fetch account devices")
		return
	}
	defer rows.Close()
	for rows.Next() {
		var device data.Device
		if err := rows.Scan(&device.APNSToken, &device.Sandbox); err != nil {
			c.logger.WithFields(logrus.Fields{
				"accountID": id,
				"err":       err,
			}).Error("failed to scan device row")
			continue
		}
		devices = append(devices, device)
	}

	for _, msg := range msgs.MessageListing.Messages {
		notification := &apns2.Notification{}
		notification.Topic = "com.christianselig.Apollo"
		notification.Payload = payloadFromMessage(account, &msg, len(msgs.MessageListing.Messages))

		for _, device := range devices {
			notification.DeviceToken = device.APNSToken
			client := c.apnsProduction
			if device.Sandbox {
				client = c.apnsSandbox
			}

			res, err := client.Push(notification)
			if err != nil {
				c.statsd.Incr("apns.notification.errors", []string{}, 1)
				fields := logrus.Fields{
					"accountID": id,
					"err":       err,
				}
				// res can be nil when the push fails before a response is received.
				if res != nil {
					fields["status"] = res.StatusCode
					fields["reason"] = res.Reason
				}
				c.logger.WithFields(fields).Error("failed to send notification")
			} else {
				c.statsd.Incr("apns.notification.sent", []string{}, 1)
				c.logger.WithFields(logrus.Fields{
					"accountID":  delivery.Payload(),
					"token":      device.APNSToken,
					"redditUser": account.Username,
				}).Info("sent notification")
			}
		}
	}

	c.logger.WithFields(logrus.Fields{
		"accountID": delivery.Payload(),
	}).Debug("finishing job")
}
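
// payloadFromMessage builds the APNs payload for a single inbox message,
// truncating long bodies and titles and choosing the alert title, category and
// custom fields based on the message kind (t1 = comment, t4 = private message)
// and type.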
func payloadFromMessage(acct *data.Account, msg *reddit.MessageData, badgeCount int) *payload.Payload {
	postBody := msg.Body
	if len(postBody) > 2000 {
		postBody = msg.Body[:2000]
	}

	postTitle := msg.LinkTitle
	if postTitle == "" {
		postTitle = msg.Subject
	}
	if len(postTitle) > 75 {
		postTitle = fmt.Sprintf("%s…", postTitle[0:75])
	}

	payload := payload.
		NewPayload().
		AlertBody(postBody).
		AlertSummaryArg(msg.Author).
		Badge(badgeCount).
		Custom("account_id", acct.AccountID).
		Custom("author", msg.Author).
		Custom("destination_author", msg.Destination).
		Custom("parent_id", msg.ParentID).
		Custom("post_title", msg.LinkTitle).
		Custom("subreddit", msg.Subreddit).
		MutableContent().
		Sound("traloop.wav")

	switch {
	case (msg.Kind == "t1" && msg.Type == "username_mention"):
		title := fmt.Sprintf(`Mention in “%s”`, postTitle)
		payload = payload.AlertTitle(title).Custom("type", "username")

		pType, _ := reddit.SplitID(msg.ParentID)
		if pType == "t1" {
			payload = payload.Category("inbox-username-mention-context")
		} else {
			payload = payload.Category("inbox-username-mention-no-context")
		}

		payload = payload.Custom("subject", "comment").ThreadID("comment")
	case (msg.Kind == "t1" && msg.Type == "post_reply"):
		title := fmt.Sprintf(`%s to “%s”`, msg.Author, postTitle)
		payload = payload.
			AlertTitle(title).
			Category("inbox-post-reply").
			Custom("post_id", msg.ID).
			Custom("subject", "comment").
			Custom("type", "post").
			ThreadID("comment")
	case (msg.Kind == "t1" && msg.Type == "comment_reply"):
		title := fmt.Sprintf(`%s in “%s”`, msg.Author, postTitle)
		_, postID := reddit.SplitID(msg.ParentID)
		payload = payload.
			AlertTitle(title).
			Category("inbox-comment-reply").
			Custom("comment_id", msg.ID).
			Custom("post_id", postID).
			Custom("subject", "comment").
			Custom("type", "comment").
			ThreadID("comment")
	case (msg.Kind == "t4"):
		title := fmt.Sprintf(`Message from %s`, msg.Author)
		payload = payload.
			AlertTitle(title).
			AlertSubtitle(postTitle).
			Category("inbox-private-message").
			Custom("type", "private-message")
	}

	return payload
}
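
// logErrors drains the rmq error channel and logs anything the queue library reports.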
func logErrors(errChan <-chan error) {
	for err := range errChan {
		log.Print("error: ", err)
	}
}