mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-22 11:47:42 +00:00
Merge pull request #2 from christianselig/fix/notification-payload
Fix notification payloads
This commit is contained in:
commit
d0cf4e74d0
6 changed files with 194 additions and 32 deletions
|
@ -2,11 +2,11 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/data"
|
"github.com/christianselig/apollo-backend/internal/data"
|
||||||
)
|
)
|
||||||
|
@ -14,32 +14,54 @@ import (
|
||||||
func (app *application) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
func (app *application) upsertAccountHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||||
a := &data.Account{}
|
a := &data.Account{}
|
||||||
if err := json.NewDecoder(r.Body).Decode(a); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(a); err != nil {
|
||||||
fmt.Println("failing on decoding json")
|
app.logger.WithFields(logrus.Fields{
|
||||||
app.errorResponse(w, r, 500, err.Error())
|
"err": err,
|
||||||
|
}).Info("failed to parse request json")
|
||||||
|
app.errorResponse(w, r, 422, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
a.ExpiresAt = time.Now().Unix() + 3300
|
|
||||||
|
|
||||||
// Here we check whether the account is supplied with a valid token.
|
// Here we check whether the account is supplied with a valid token.
|
||||||
ac := app.client.NewAuthenticatedClient(a.RefreshToken, a.AccessToken)
|
ac := app.client.NewAuthenticatedClient(a.RefreshToken, a.AccessToken)
|
||||||
|
tokens, err := ac.RefreshTokens()
|
||||||
|
if err != nil {
|
||||||
|
app.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Info("failed to refresh token")
|
||||||
|
app.errorResponse(w, r, 422, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset expiration timer
|
||||||
|
a.ExpiresAt = time.Now().Unix() + 3540
|
||||||
|
|
||||||
|
ac = app.client.NewAuthenticatedClient(tokens.RefreshToken, tokens.AccessToken)
|
||||||
me, err := ac.Me()
|
me, err := ac.Me()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("failing on fetching remote user")
|
app.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Info("failed to grab user details")
|
||||||
app.errorResponse(w, r, 500, err.Error())
|
app.errorResponse(w, r, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if me.NormalizedUsername() != a.NormalizedUsername() {
|
if me.NormalizedUsername() != a.NormalizedUsername() {
|
||||||
fmt.Println("failing on account username comparison")
|
app.logger.WithFields(logrus.Fields{
|
||||||
app.errorResponse(w, r, 500, "nice try")
|
"err": err,
|
||||||
|
}).Info("user is not who they say they are")
|
||||||
|
app.errorResponse(w, r, 422, "nice try")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set account ID from Reddit
|
||||||
|
a.AccountID = me.ID
|
||||||
|
|
||||||
// Upsert account
|
// Upsert account
|
||||||
if err := app.models.Accounts.Upsert(a); err != nil {
|
if err := app.models.Accounts.Upsert(a); err != nil {
|
||||||
fmt.Println("failing on account upsert")
|
app.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Info("failed updating account in database")
|
||||||
app.errorResponse(w, r, 500, err.Error())
|
app.errorResponse(w, r, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -47,13 +69,17 @@ func (app *application) upsertAccountHandler(w http.ResponseWriter, r *http.Requ
|
||||||
// Associate
|
// Associate
|
||||||
d, err := app.models.Devices.GetByAPNSToken(ps.ByName("apns"))
|
d, err := app.models.Devices.GetByAPNSToken(ps.ByName("apns"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("failing on apns")
|
app.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Info("failed fetching account devices")
|
||||||
app.errorResponse(w, r, 500, err.Error())
|
app.errorResponse(w, r, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := app.models.DevicesAccounts.Associate(a.ID, d.ID); err != nil {
|
if err := app.models.DevicesAccounts.Associate(a.ID, d.ID); err != nil {
|
||||||
fmt.Println("failing on associate")
|
app.logger.WithFields(logrus.Fields{
|
||||||
|
"err": err,
|
||||||
|
}).Info("failed associating account with device")
|
||||||
app.errorResponse(w, r, 500, err.Error())
|
app.errorResponse(w, r, 500, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/christianselig/apollo-backend/internal/data"
|
"github.com/christianselig/apollo-backend/internal/data"
|
||||||
"github.com/christianselig/apollo-backend/internal/reddit"
|
"github.com/christianselig/apollo-backend/internal/reddit"
|
||||||
|
@ -21,14 +22,25 @@ type config struct {
|
||||||
|
|
||||||
type application struct {
|
type application struct {
|
||||||
cfg config
|
cfg config
|
||||||
logger *log.Logger
|
logger *logrus.Logger
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
models *data.Models
|
models *data.Models
|
||||||
client *reddit.Client
|
client *reddit.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
logger := log.New(os.Stdout, "", log.Ldate|log.Ltime)
|
var logger *logrus.Logger
|
||||||
|
{
|
||||||
|
logger = logrus.New()
|
||||||
|
if os.Getenv("ENV") == "" {
|
||||||
|
logger.SetLevel(logrus.DebugLevel)
|
||||||
|
} else {
|
||||||
|
logger.SetFormatter(&logrus.TextFormatter{
|
||||||
|
DisableColors: true,
|
||||||
|
FullTimestamp: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := godotenv.Load(); err != nil {
|
if err := godotenv.Load(); err != nil {
|
||||||
logger.Printf("Couldn't find .env so I will read from existing ENV.")
|
logger.Printf("Couldn't find .env so I will read from existing ENV.")
|
||||||
|
|
|
@ -222,6 +222,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
|
||||||
stmt := `SELECT
|
stmt := `SELECT
|
||||||
id,
|
id,
|
||||||
username,
|
username,
|
||||||
|
account_id,
|
||||||
access_token,
|
access_token,
|
||||||
refresh_token,
|
refresh_token,
|
||||||
expires_at,
|
expires_at,
|
||||||
|
@ -233,6 +234,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
|
||||||
if err := c.pool.QueryRow(ctx, stmt, id).Scan(
|
if err := c.pool.QueryRow(ctx, stmt, id).Scan(
|
||||||
&account.ID,
|
&account.ID,
|
||||||
&account.Username,
|
&account.Username,
|
||||||
|
&account.AccountID,
|
||||||
&account.AccessToken,
|
&account.AccessToken,
|
||||||
&account.RefreshToken,
|
&account.RefreshToken,
|
||||||
&account.ExpiresAt,
|
&account.ExpiresAt,
|
||||||
|
@ -370,7 +372,7 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
|
||||||
for _, msg := range msgs.MessageListing.Messages {
|
for _, msg := range msgs.MessageListing.Messages {
|
||||||
notification := &apns2.Notification{}
|
notification := &apns2.Notification{}
|
||||||
notification.Topic = "com.christianselig.Apollo"
|
notification.Topic = "com.christianselig.Apollo"
|
||||||
notification.Payload = payload.NewPayload().AlertTitle(msg.Subject).AlertBody(msg.Body)
|
notification.Payload = payloadFromMessage(account, &msg, len(msgs.MessageListing.Messages))
|
||||||
|
|
||||||
for _, device := range devices {
|
for _, device := range devices {
|
||||||
notification.DeviceToken = device.APNSToken
|
notification.DeviceToken = device.APNSToken
|
||||||
|
@ -404,6 +406,83 @@ func (c *Consumer) Consume(delivery rmq.Delivery) {
|
||||||
}).Debug("finishing job")
|
}).Debug("finishing job")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func payloadFromMessage(acct *data.Account, msg *reddit.MessageData, badgeCount int) *payload.Payload {
|
||||||
|
postBody := msg.Body
|
||||||
|
if len(postBody) > 2000 {
|
||||||
|
postBody = msg.Body[:2000]
|
||||||
|
}
|
||||||
|
|
||||||
|
postTitle := msg.LinkTitle
|
||||||
|
if postTitle == "" {
|
||||||
|
postTitle = msg.Subject
|
||||||
|
}
|
||||||
|
if len(postTitle) > 75 {
|
||||||
|
postTitle = fmt.Sprintf("%s…", postTitle[0:75])
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := payload.
|
||||||
|
NewPayload().
|
||||||
|
AlertBody(postBody).
|
||||||
|
AlertSummaryArg(msg.Author).
|
||||||
|
Badge(badgeCount).
|
||||||
|
Custom("account_id", acct.AccountID).
|
||||||
|
Custom("author", msg.Author).
|
||||||
|
Custom("destination_author", msg.Destination).
|
||||||
|
Custom("parent_id", msg.ParentID).
|
||||||
|
Custom("post_title", msg.LinkTitle).
|
||||||
|
Custom("subreddit", msg.Subreddit).
|
||||||
|
MutableContent().
|
||||||
|
Sound("traloop.wav")
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case (msg.Kind == "t1" && msg.Type == "username_mention"):
|
||||||
|
title := fmt.Sprintf(`Mention in “%s”`, postTitle)
|
||||||
|
payload = payload.AlertTitle(title).Custom("type", "username")
|
||||||
|
|
||||||
|
pType, _ := reddit.SplitID(msg.ParentID)
|
||||||
|
if pType == "t1" {
|
||||||
|
payload = payload.Category("inbox-username-mention-context")
|
||||||
|
} else {
|
||||||
|
payload = payload.Category("inbox-username-mention-no-context")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = payload.Custom("subject", "comment").ThreadID("comment")
|
||||||
|
break
|
||||||
|
case (msg.Kind == "t1" && msg.Type == "post_reply"):
|
||||||
|
title := fmt.Sprintf(`%s to “%s”`, msg.Author, postTitle)
|
||||||
|
payload = payload.
|
||||||
|
AlertTitle(title).
|
||||||
|
Category("inbox-post-reply").
|
||||||
|
Custom("post_id", msg.ID).
|
||||||
|
Custom("subject", "comment").
|
||||||
|
Custom("type", "post").
|
||||||
|
ThreadID("comment")
|
||||||
|
break
|
||||||
|
case (msg.Kind == "t1" && msg.Type == "comment_reply"):
|
||||||
|
title := fmt.Sprintf(`%s in “%s”`, msg.Author, postTitle)
|
||||||
|
_, postID := reddit.SplitID(msg.ParentID)
|
||||||
|
payload = payload.
|
||||||
|
AlertTitle(title).
|
||||||
|
Category("inbox-comment-reply").
|
||||||
|
Custom("comment_id", msg.ID).
|
||||||
|
Custom("post_id", postID).
|
||||||
|
Custom("subject", "comment").
|
||||||
|
Custom("type", "comment").
|
||||||
|
ThreadID("comment")
|
||||||
|
break
|
||||||
|
case (msg.Kind == "t4"):
|
||||||
|
title := fmt.Sprintf(`Message from %s`, msg.Author)
|
||||||
|
payload = payload.
|
||||||
|
AlertTitle(title).
|
||||||
|
AlertSubtitle(postTitle).
|
||||||
|
Category("inbox-private-message").
|
||||||
|
Custom("type", "private-message")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
func logErrors(errChan <-chan error) {
|
func logErrors(errChan <-chan error) {
|
||||||
for err := range errChan {
|
for err := range errChan {
|
||||||
log.Print("error: ", err)
|
log.Print("error: ", err)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
type Account struct {
|
type Account struct {
|
||||||
ID int64
|
ID int64
|
||||||
Username string
|
Username string
|
||||||
|
AccountID string
|
||||||
AccessToken string
|
AccessToken string
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
ExpiresAt int64
|
ExpiresAt int64
|
||||||
|
@ -25,14 +26,14 @@ type AccountModel struct {
|
||||||
|
|
||||||
func (am *AccountModel) Upsert(a *Account) error {
|
func (am *AccountModel) Upsert(a *Account) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO accounts (username, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
|
INSERT INTO accounts (username, account_id, access_token, refresh_token, expires_at, last_message_id, device_count, last_checked_at)
|
||||||
VALUES ($1, $2, $3, $4, '', 0, 0)
|
VALUES ($1, $2, $3, $4, '', 0, 0)
|
||||||
ON CONFLICT(username)
|
ON CONFLICT(username)
|
||||||
DO
|
DO
|
||||||
UPDATE SET access_token = $2, refresh_token = $3, expires_at = $4, last_message_id = $5, last_checked_at = $6
|
UPDATE SET access_token = $2, refresh_token = $3, expires_at = $4, last_message_id = $5, last_checked_at = $6
|
||||||
RETURNING id`
|
RETURNING id`
|
||||||
|
|
||||||
args := []interface{}{a.NormalizedUsername(), a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt}
|
args := []interface{}{a.NormalizedUsername(), a.AccountID, a.AccessToken, a.RefreshToken, a.ExpiresAt, a.LastMessageID, a.LastCheckedAt}
|
||||||
return am.DB.QueryRow(query, args...).Scan(&a.ID)
|
return am.DB.QueryRow(query, args...).Scan(&a.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package reddit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
|
@ -25,6 +26,14 @@ type Client struct {
|
||||||
statsd *statsd.Client
|
statsd *statsd.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SplitID(id string) (string, string) {
|
||||||
|
if parts := strings.Split(id, "_"); len(parts) == 2 {
|
||||||
|
return parts[0], parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
func NewClient(id, secret string, statsd *statsd.Client) *Client {
|
func NewClient(id, secret string, statsd *statsd.Client) *Client {
|
||||||
tracer := &httptrace.ClientTrace{
|
tracer := &httptrace.ClientTrace{
|
||||||
GotConn: func(info httptrace.GotConnInfo) {
|
GotConn: func(info httptrace.GotConnInfo) {
|
||||||
|
@ -85,7 +94,23 @@ func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) {
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
return ioutil.ReadAll(resp.Body)
|
bb, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
|
|
||||||
|
// Try to parse a json error. Otherwise we generate a generic one
|
||||||
|
rerr := &Error{}
|
||||||
|
if jerr := json.Unmarshal(bb, rerr); jerr != nil {
|
||||||
|
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, rerr
|
||||||
|
}
|
||||||
|
return bb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||||
|
@ -129,14 +154,6 @@ func (rac *AuthenticatedClient) MessageInbox(from string) (*MessageListingRespon
|
||||||
return mlr, nil
|
return mlr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MeResponse struct {
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mr *MeResponse) NormalizedUsername() string {
|
|
||||||
return strings.ToLower(mr.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
|
func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
|
||||||
req := NewRequest(
|
req := NewRequest(
|
||||||
WithTags([]string{"url:/api/v1/me"}),
|
WithTags([]string{"url:/api/v1/me"}),
|
||||||
|
|
|
@ -1,14 +1,32 @@
|
||||||
package reddit
|
package reddit
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code int `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err *Error) Error() string {
|
||||||
|
return fmt.Sprintf("%s (%d)", err.Message, err.Code)
|
||||||
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Kind string `json:"kind"`
|
Kind string `json:"kind"`
|
||||||
|
Type string `json:"type"`
|
||||||
Author string `json:"author"`
|
Author string `json:"author"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Body string `json:"body"`
|
Body string `json:"body"`
|
||||||
CreatedAt float64 `json:"created_utc"`
|
CreatedAt float64 `json:"created_utc"`
|
||||||
|
Context string `json:"context"`
|
||||||
|
ParentID string `json:"parent_id"`
|
||||||
|
LinkTitle string `json:"link_title"`
|
||||||
|
Destination string `json:"dest"`
|
||||||
|
Subreddit string `json:"subreddit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MessageData struct {
|
type MessageData struct {
|
||||||
|
@ -32,3 +50,12 @@ type RefreshTokenResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MeResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *MeResponse) NormalizedUsername() string {
|
||||||
|
return strings.ToLower(mr.Name)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue