mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-22 03:37:43 +00:00
update notifier logic
This commit is contained in:
parent
ad929b98d9
commit
ec78363252
7 changed files with 106 additions and 75 deletions
|
@ -61,6 +61,7 @@ func WorkerCmd(ctx context.Context) *cobra.Command {
|
|||
}
|
||||
|
||||
consumers := runtime.NumCPU() * multiplier
|
||||
//consumers = 1
|
||||
|
||||
worker := workerFn(logger, statsd, db, redis, queue, consumers)
|
||||
worker.Start()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package reddit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -23,7 +22,7 @@ type Client struct {
|
|||
secret string
|
||||
client *http.Client
|
||||
tracer *httptrace.ClientTrace
|
||||
parser *fastjson.Parser
|
||||
pool *fastjson.ParserPool
|
||||
statsd *statsd.Client
|
||||
}
|
||||
|
||||
|
@ -74,14 +73,14 @@ func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client
|
|||
|
||||
client := &http.Client{Transport: t}
|
||||
|
||||
parser := &fastjson.Parser{}
|
||||
pool := &fastjson.ParserPool{}
|
||||
|
||||
return &Client{
|
||||
id,
|
||||
secret,
|
||||
client,
|
||||
tracer,
|
||||
parser,
|
||||
pool,
|
||||
statsd,
|
||||
}
|
||||
}
|
||||
|
@ -98,7 +97,7 @@ func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *Auth
|
|||
return &AuthenticatedClient{rc, refreshToken, accessToken, nil}
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) {
|
||||
func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) {
|
||||
req, err := r.HTTPRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -122,16 +121,21 @@ func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) {
|
|||
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parser := rac.pool.Get()
|
||||
defer rac.pool.Put(parser)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
|
||||
// Try to parse a json error. Otherwise we generate a generic one
|
||||
rerr := &Error{}
|
||||
if jerr := json.Unmarshal(bb, rerr); jerr != nil {
|
||||
val, jerr := parser.ParseBytes(bb)
|
||||
if jerr != nil {
|
||||
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode)
|
||||
}
|
||||
return nil, rerr
|
||||
return nil, NewError(val)
|
||||
}
|
||||
return bb, nil
|
||||
return parser.ParseBytes(bb)
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
||||
|
@ -144,33 +148,24 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
|
|||
WithBasicAuth(rac.id, rac.secret),
|
||||
)
|
||||
|
||||
body, err := rac.request(req)
|
||||
|
||||
val, err := rac.request(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rtr := &RefreshTokenResponse{}
|
||||
json.Unmarshal([]byte(body), rtr)
|
||||
return rtr, nil
|
||||
return NewRefreshTokenResponse(val), nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) MessageInbox(from string) (*ListingResponse, error) {
|
||||
req := NewRequest(
|
||||
func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) {
|
||||
opts = append([]RequestOption{
|
||||
WithTags([]string{"url:/api/v1/message/inbox"}),
|
||||
WithMethod("GET"),
|
||||
WithToken(rac.accessToken),
|
||||
WithURL("https://oauth.reddit.com/message/inbox.json"),
|
||||
WithQuery("before", from),
|
||||
)
|
||||
}, opts...)
|
||||
req := NewRequest(opts...)
|
||||
|
||||
body, err := rac.request(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
val, err := rac.parser.ParseBytes(body)
|
||||
val, err := rac.request(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -178,24 +173,22 @@ func (rac *AuthenticatedClient) MessageInbox(from string) (*ListingResponse, err
|
|||
return NewListingResponse(val), nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) MessageUnread(from string) (*MessageListingResponse, error) {
|
||||
req := NewRequest(
|
||||
func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) {
|
||||
opts = append([]RequestOption{
|
||||
WithTags([]string{"url:/api/v1/message/unread"}),
|
||||
WithMethod("GET"),
|
||||
WithToken(rac.accessToken),
|
||||
WithURL("https://oauth.reddit.com/message/unread.json"),
|
||||
WithQuery("before", from),
|
||||
)
|
||||
}, opts...)
|
||||
|
||||
body, err := rac.request(req)
|
||||
req := NewRequest(opts...)
|
||||
|
||||
val, err := rac.request(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mlr := &MessageListingResponse{}
|
||||
json.Unmarshal([]byte(body), mlr)
|
||||
return mlr, nil
|
||||
return NewListingResponse(val), nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
|
||||
|
@ -206,14 +199,10 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
|
|||
WithURL("https://oauth.reddit.com/api/v1/me"),
|
||||
)
|
||||
|
||||
body, err := rac.request(req)
|
||||
|
||||
val, err := rac.request(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mr := &MeResponse{}
|
||||
err = json.Unmarshal(body, mr)
|
||||
|
||||
return mr, err
|
||||
return NewMeResponse(val), nil
|
||||
}
|
||||
|
|
4
internal/reddit/testdata/error.json
vendored
Normal file
4
internal/reddit/testdata/error.json
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"message": "Unauthorized",
|
||||
"error": 401
|
||||
}
|
7
internal/reddit/testdata/refresh_token.json
vendored
Normal file
7
internal/reddit/testdata/refresh_token.json
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"access_token": "***REMOVED***",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "***REMOVED***",
|
||||
"scope": "account creddits edit flair history identity livemanage modconfig modcontributors modflair modlog modmail modothers modposts modself modtraffic modwiki mysubreddits privatemessages read report save structuredstyles submit subscribe vote wikiedit wikiread"
|
||||
}
|
|
@ -16,36 +16,13 @@ func (err *Error) Error() string {
|
|||
return fmt.Sprintf("%s (%d)", err.Message, err.Code)
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Kind string `json:"kind"`
|
||||
Type string `json:"type"`
|
||||
Author string `json:"author"`
|
||||
Subject string `json:"subject"`
|
||||
Body string `json:"body"`
|
||||
CreatedAt float64 `json:"created_utc"`
|
||||
Context string `json:"context"`
|
||||
ParentID string `json:"parent_id"`
|
||||
LinkTitle string `json:"link_title"`
|
||||
Destination string `json:"dest"`
|
||||
Subreddit string `json:"subreddit"`
|
||||
}
|
||||
func NewError(val *fastjson.Value) *Error {
|
||||
err := &Error{}
|
||||
|
||||
type MessageData struct {
|
||||
Message `json:"data"`
|
||||
Kind string `json:"kind"`
|
||||
}
|
||||
err.Message = string(val.GetStringBytes("message"))
|
||||
err.Code = val.GetInt("error")
|
||||
|
||||
func (md MessageData) FullName() string {
|
||||
return fmt.Sprintf("%s_%s", md.Kind, md.ID)
|
||||
}
|
||||
|
||||
type MessageListing struct {
|
||||
Messages []MessageData `json:"children"`
|
||||
}
|
||||
|
||||
type MessageListingResponse struct {
|
||||
MessageListing MessageListing `json:"data"`
|
||||
return err
|
||||
}
|
||||
|
||||
type RefreshTokenResponse struct {
|
||||
|
@ -53,6 +30,15 @@ type RefreshTokenResponse struct {
|
|||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
func NewRefreshTokenResponse(val *fastjson.Value) *RefreshTokenResponse {
|
||||
rtr := &RefreshTokenResponse{}
|
||||
|
||||
rtr.AccessToken = string(val.GetStringBytes("access_token"))
|
||||
rtr.RefreshToken = string(val.GetStringBytes("refresh_token"))
|
||||
|
||||
return rtr
|
||||
}
|
||||
|
||||
type MeResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string
|
||||
|
@ -86,6 +72,10 @@ type Thing struct {
|
|||
Subreddit string `json:"subreddit"`
|
||||
}
|
||||
|
||||
func (t *Thing) FullName() string {
|
||||
return fmt.Sprintf("%s_%s", t.Kind, t.ID)
|
||||
}
|
||||
|
||||
func NewThing(val *fastjson.Value) *Thing {
|
||||
t := &Thing{}
|
||||
|
||||
|
@ -122,8 +112,12 @@ func NewListingResponse(val *fastjson.Value) *ListingResponse {
|
|||
lr.After = string(data.GetStringBytes("after"))
|
||||
lr.Before = string(data.GetStringBytes("before"))
|
||||
lr.Count = data.GetInt("dist")
|
||||
lr.Children = make([]*Thing, lr.Count)
|
||||
|
||||
if lr.Count == 0 {
|
||||
return lr
|
||||
}
|
||||
|
||||
lr.Children = make([]*Thing, lr.Count)
|
||||
children := data.GetArray("children")
|
||||
for i := 0; i < lr.Count; i++ {
|
||||
t := NewThing(children[i])
|
||||
|
|
|
@ -26,6 +26,20 @@ func TestMeResponseParsing(t *testing.T) {
|
|||
assert.Equal(t, "hugocat", me.Name)
|
||||
}
|
||||
|
||||
func TestRefreshTokenResponseParsing(t *testing.T) {
|
||||
bb, err := ioutil.ReadFile("testdata/refresh_token.json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
val, err := parser.ParseBytes(bb)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rtr := NewRefreshTokenResponse(val)
|
||||
assert.NotNil(t, rtr)
|
||||
|
||||
assert.Equal(t, "***REMOVED***", rtr.AccessToken)
|
||||
assert.Equal(t, "***REMOVED***", rtr.RefreshToken)
|
||||
}
|
||||
|
||||
func TestListingResponseParsing(t *testing.T) {
|
||||
bb, err := ioutil.ReadFile("testdata/message_inbox.json")
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -248,7 +248,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
nc.logger.WithFields(logrus.Fields{
|
||||
"accountID": id,
|
||||
}).Debug("fetching message inbox")
|
||||
msgs, err := rac.MessageInbox(account.LastMessageID)
|
||||
msgs, err := rac.MessageInbox(reddit.WithQuery("limit", "10"))
|
||||
|
||||
if err != nil {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
|
@ -258,12 +258,32 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
return
|
||||
}
|
||||
|
||||
// Figure out where we stand
|
||||
if msgs.Count == 0 || msgs.Children[0].FullName() == account.LastMessageID {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"accountID": id,
|
||||
"count": len(msgs.MessageListing.Messages),
|
||||
}).Debug("no new messages, bailing early")
|
||||
return
|
||||
}
|
||||
|
||||
// Find which one is the oldest we haven't notified on
|
||||
oldest := 0
|
||||
for i, t := range msgs.Children {
|
||||
if t.FullName() == account.LastMessageID {
|
||||
break
|
||||
}
|
||||
|
||||
oldest = i
|
||||
}
|
||||
|
||||
tt := msgs.Children[:oldest]
|
||||
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"accountID": id,
|
||||
"count": len(tt),
|
||||
}).Debug("fetched messages")
|
||||
|
||||
if len(msgs.MessageListing.Messages) == 0 {
|
||||
if len(tt) == 0 {
|
||||
nc.logger.WithFields(logrus.Fields{
|
||||
"accountID": id,
|
||||
}).Debug("no new messages, bailing early")
|
||||
|
@ -271,7 +291,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
}
|
||||
|
||||
// Set latest message we alerted on
|
||||
latestMsg := msgs.MessageListing.Messages[0]
|
||||
latestMsg := tt[0]
|
||||
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||
stmt := `
|
||||
UPDATE accounts
|
||||
|
@ -316,10 +336,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
for _, msg := range msgs.MessageListing.Messages {
|
||||
// Iterate backwards so we notify from older to newer
|
||||
for i := len(tt) - 1; i >= 0; i-- {
|
||||
msg := tt[i]
|
||||
notification := &apns2.Notification{}
|
||||
notification.Topic = "com.christianselig.Apollo"
|
||||
notification.Payload = payloadFromMessage(account, &msg, len(msgs.MessageListing.Messages))
|
||||
notification.Payload = payloadFromMessage(account, msg, len(tt))
|
||||
|
||||
for _, device := range devices {
|
||||
notification.DeviceToken = device.APNSToken
|
||||
|
@ -353,7 +375,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
|
|||
}).Debug("finishing job")
|
||||
}
|
||||
|
||||
func payloadFromMessage(acct *data.Account, msg *reddit.MessageData, badgeCount int) *payload.Payload {
|
||||
func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
|
||||
postBody := msg.Body
|
||||
if len(postBody) > 2000 {
|
||||
postBody = msg.Body[:2000]
|
||||
|
|
Loading…
Reference in a new issue