Merge pull request #2 from christianselig/fix/notification-payload

Fix notification payloads
This commit is contained in:
André Medeiros 2021-07-12 15:02:47 -04:00 committed by GitHub
commit d0cf4e74d0
6 changed files with 194 additions and 32 deletions

View file

@ -2,11 +2,11 @@ package main
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data" "github.com/christianselig/apollo-backend/internal/data"
) )
@ -14,32 +14,54 @@ import (
func (app *application) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (app *application) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
a := &data.Account{} a := &data.Account{}
if err := json.NewDecoder(r.Body).Decode(a); err != nil { if err := json.NewDecoder(r.Body).Decode(a); err != nil {
fmt.Println("failing on decoding json") app.logger.WithFields(logrus.Fields{
app.errorResponse(w, r, 500, err.Error()) "err": err,
}).Info("failed to parse request json")
app.errorResponse(w, r, 422, err.Error())
return return
} }
a.ExpiresAt = time.Now().Unix() + 3300
// Here we check whether the account is supplied with a valid token. // Here we check whether the account is supplied with a valid token.
ac := app.client.NewAuthenticatedClient(a.RefreshToken, a.AccessToken) ac := app.client.NewAuthenticatedClient(a.RefreshToken, a.AccessToken)
tokens, err := ac.RefreshTokens()
if err != nil {
app.logger.WithFields(logrus.Fields{
"err": err,
}).Info("failed to refresh token")
app.errorResponse(w, r, 422, err.Error())
return
}
// Reset expiration timer
a.ExpiresAt = time.Now().Unix() + 3540
ac = app.client.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
me, err := ac.Me() me, err := ac.Me()
if err != nil { if err != nil {
fmt.Println("failing on fetching remote user") app.logger.WithFields(logrus.Fields{
"err": err,
}).Info("failed to grab user details")
app.errorResponse(w, r, 500, err.Error()) app.errorResponse(w, r, 500, err.Error())
return return
} }
if me.NormalizedUsername() != a.NormalizedUsername() { if me.NormalizedUsername() != a.NormalizedUsername() {
fmt.Println("failing on account username comparison") app.logger.WithFields(logrus.Fields{
app.errorResponse(w, r, 500, "nice try") "err": err,
}).Info("user is not who they say they are")
app.errorResponse(w, r, 422, "nice try")
return return
} }
// Set account ID from Reddit
a.AccountID = me.ID
// Upsert account // Upsert account
if err := app.models.Accounts.Upsert(a); err != nil { if err := app.models.Accounts.Upsert(a); err != nil {
fmt.Println("failing on account upsert") app.logger.WithFields(logrus.Fields{
"err": err,
}).Info("failed updating account in database")
app.errorResponse(w, r, 500, err.Error()) app.errorResponse(w, r, 500, err.Error())
return return
} }
@ -47,13 +69,17 @@ func (app *application) upsertAccountHandler(w http.ResponseWriter, r *http.Requ
// Associate // Associate
d, err := app.models.Devices.GetByAPNSToken(ps.ByName("apns")) d, err := app.models.Devices.GetByAPNSToken(ps.ByName("apns"))
if err != nil { if err != nil {
fmt.Println("failing on apns") app.logger.WithFields(logrus.Fields{
"err": err,
}).Info("failed fetching account devices")
app.errorResponse(w, r, 500, err.Error()) app.errorResponse(w, r, 500, err.Error())
return return
} }
if err := app.models.DevicesAccounts.Associate(a.ID, d.ID); err != nil { if err := app.models.DevicesAccounts.Associate(a.ID, d.ID); err != nil {
fmt.Println("failing on associate") app.logger.WithFields(logrus.Fields{
"err": err,
}).Info("failed associating account with device")
app.errorResponse(w, r, 500, err.Error()) app.errorResponse(w, r, 500, err.Error())
return return
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/joho/godotenv" "github.com/joho/godotenv"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/sirupsen/logrus"
"github.com/christianselig/apollo-backend/internal/data" "github.com/christianselig/apollo-backend/internal/data"
"github.com/christianselig/apollo-backend/internal/reddit" "github.com/christianselig/apollo-backend/internal/reddit"
@ -21,14 +22,25 @@ type config struct {
type application struct { type application struct {
cfg config cfg config
logger *log.Logger logger *logrus.Logger
db *sql.DB db *sql.DB
models *data.Models models *data.Models
client *reddit.Client client *reddit.Client
} }
func main() { func main() {
logger := log.New(os.Stdout, "", log.Ldate|log.Ltime) var logger *logrus.Logger
{
logger = logrus.New()
if os.Getenv("ENV") == "" {
logger.SetLevel(logrus.DebugLevel)
} else {
logger.SetFormatter(&logrus.TextFormatter{
DisableColors: true,
FullTimestamp: true,
})
}
}
if err := godotenv.Load(); err != nil { if err := godotenv.Load(); err != nil {
logger.Printf("Couldn't find .env so I will read from existing ENV.") logger.Printf("Couldn't find .env so I will read from existing ENV.")

View file

@ -222,6 +222,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
stmt := `SELECT stmt := `SELECT
id, id,
username, username,
account_id,
access_token, access_token,
refresh_token, refresh_token,
expires_at, expires_at,
@ -233,6 +234,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
if err := c.pool.QueryRow(ctx, stmt, id).Scan( if err := c.pool.QueryRow(ctx, stmt, id).Scan(
&account.ID, &account.ID,
&account.Username, &account.Username,
&account.AccountID,
&account.AccessToken, &account.AccessToken,
&account.RefreshToken, &account.RefreshToken,
&account.ExpiresAt, &account.ExpiresAt,
@ -370,7 +372,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
for _, msg := range msgs.MessageListing.Messages { for _, msg := range msgs.MessageListing.Messages {
notification := &apns2.Notification{} notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo" notification.Topic = "com.christianselig.Apollo"
notification.Payload = payload.NewPayload().AlertTitle(msg.Subject).AlertBody(msg.Body) notification.Payload = payloadFromMessage(account, &msg, len(msgs.MessageListing.Messages))
for _, device := range devices { for _, device := range devices {
notification.DeviceToken = device.APNSToken notification.DeviceToken = device.APNSToken
@ -404,6 +406,83 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
}).Debug("finishing job") }).Debug("finishing job")
} }
func payloadFromMessage(acct *data.Account, msg *reddit.MessageData, badgeCount int) *payload.Payload {
postBody := msg.Body
if len(postBody) > 2000 {
postBody = msg.Body[:2000]
}
postTitle := msg.LinkTitle
if postTitle == "" {
postTitle = msg.Subject
}
if len(postTitle) > 75 {
postTitle = fmt.Sprintf("%s…", postTitle[0:75])
}
payload := payload.
NewPayload().
AlertBody(postBody).
AlertSummaryArg(msg.Author).
Badge(badgeCount).
Custom("account_id", acct.AccountID).
Custom("author", msg.Author).
Custom("destination_author", msg.Destination).
Custom("parent_id", msg.ParentID).
Custom("post_title", msg.LinkTitle).
Custom("subreddit", msg.Subreddit).
MutableContent().
Sound("traloop.wav")
switch {
case (msg.Kind == "t1" && msg.Type == "username_mention"):
title := fmt.Sprintf(`Mention in “%s”`, postTitle)
payload = payload.AlertTitle(title).Custom("type", "username")
pType, _ := reddit.SplitID(msg.ParentID)
if pType == "t1" {
payload = payload.Category("inbox-username-mention-context")
} else {
payload = payload.Category("inbox-username-mention-no-context")
}
payload = payload.Custom("subject", "comment").ThreadID("comment")
break
case (msg.Kind == "t1" && msg.Type == "post_reply"):
title := fmt.Sprintf(`%s to “%s”`, msg.Author, postTitle)
payload = payload.
AlertTitle(title).
Category("inbox-post-reply").
Custom("post_id", msg.ID).
Custom("subject", "comment").
Custom("type", "post").
ThreadID("comment")
break
case (msg.Kind == "t1" && msg.Type == "comment_reply"):
title := fmt.Sprintf(`%s in “%s”`, msg.Author, postTitle)
_, postID := reddit.SplitID(msg.ParentID)
payload = payload.
AlertTitle(title).
Category("inbox-comment-reply").
Custom("comment_id", msg.ID).
Custom("post_id", postID).
Custom("subject", "comment").
Custom("type", "comment").
ThreadID("comment")
break
case (msg.Kind == "t4"):
title := fmt.Sprintf(`Message from %s`, msg.Author)
payload = payload.
AlertTitle(title).
AlertSubtitle(postTitle).
Category("inbox-private-message").
Custom("type", "private-message")
break
}
return payload
}
func logErrors(errChan <-chan error) { func logErrors(errChan <-chan error) {
for err := range errChan { for err := range errChan {
log.Print("error: ", err) log.Print("error: ", err)

View file

@ -8,6 +8,7 @@ import (
type Account struct { type Account struct {
ID int64 ID int64
Username string Username string
AccountID string
AccessToken string AccessToken string
RefreshToken string RefreshToken string
ExpiresAt int64 ExpiresAt int64
@ -25,14 +26,14 @@ type AccountModel struct {
func (am *AccountModel) Upsert(a *Account) error { func (am *AccountModel) Upsert(a *Account) error {
query := ` query := `
INSERT INTO accounts (username, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at) INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
VALUES ($1, $2, $3, $4, '', 0, 0) VALUES ($1, $2, $3, $4, '', 0, 0)
ON CONFLICT(username) ON CONFLICT(username)
DO DO
UPDATE SET access_token = $2, refresh_token = $3, expires_at = $4, last_message_id = $5, last_checked_at = $6 UPDATE SET access_token = $2, refresh_token = $3, expires_at = $4, last_message_id = $5, last_checked_at = $6
RETURNING id` RETURNING id`
args := []interface{}{a.NormalizedUsername(), a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt} args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt}
return am.DB.QueryRow(query, args...).Scan(&a.ID) return am.DB.QueryRow(query, args...).Scan(&a.ID)
} }

View file

@ -2,6 +2,7 @@ package reddit
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -25,6 +26,14 @@ type Client struct {
statsd *statsd.Client statsd *statsd.Client
} }
func SplitID(id string) (string, string) {
if parts := strings.Split(id, "_"); len(parts) == 2 {
return parts[0], parts[1]
}
return "", ""
}
func NewClient(id, secret string, statsd *statsd.Client) *Client { func NewClient(id, secret string, statsd *statsd.Client) *Client {
tracer := &httptrace.ClientTrace{ tracer := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) { GotConn: func(info httptrace.GotConnInfo) {
@ -85,7 +94,23 @@ func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
return ioutil.ReadAll(resp.Body) bb, err := ioutil.ReadAll(resp.Body)
if err != nil {
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, err
}
if resp.StatusCode != 200 {
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
// Try to parse a json error. Otherwise we generate a generic one
rerr := &Error{}
if jerr := json.Unmarshal(bb, rerr); jerr != nil {
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode)
}
return nil, rerr
}
return bb, nil
} }
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
@ -129,14 +154,6 @@ func (rac *AuthenticatedClient) MessageInbox(from string) (*MessageListingRespon
return mlr, nil return mlr, nil
} }
type MeResponse struct {
Name string
}
func (mr *MeResponse) NormalizedUsername() string {
return strings.ToLower(mr.Name)
}
func (rac *AuthenticatedClient) Me() (*MeResponse, error) { func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
req := NewRequest( req := NewRequest(
WithTags([]string{"url:/api/v1/me"}), WithTags([]string{"url:/api/v1/me"}),

View file

@ -1,14 +1,32 @@
package reddit package reddit
import "fmt" import (
"fmt"
"strings"
)
type Error struct {
Message string `json:"message"`
Code int `json:"error"`
}
func (err *Error) Error() string {
return fmt.Sprintf("%s (%d)", err.Message, err.Code)
}
type Message struct { type Message struct {
ID string `json:"id"` ID string `json:"id"`
Kind string `json:"kind"` Kind string `json:"kind"`
Author string `json:"author"` Type string `json:"type"`
Subject string `json:"subject"` Author string `json:"author"`
Body string `json:"body"` Subject string `json:"subject"`
CreatedAt float64 `json:"created_utc"` Body string `json:"body"`
CreatedAt float64 `json:"created_utc"`
Context string `json:"context"`
ParentID string `json:"parent_id"`
LinkTitle string `json:"link_title"`
Destination string `json:"dest"`
Subreddit string `json:"subreddit"`
} }
type MessageData struct { type MessageData struct {
@ -32,3 +50,12 @@ type RefreshTokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
type MeResponse struct {
ID string `json:"id"`
Name string
}
func (mr *MeResponse) NormalizedUsername() string {
return strings.ToLower(mr.Name)
}