update notifier logic

This commit is contained in:
Andre Medeiros 2021-07-15 10:51:34 -04:00
parent ad929b98d9
commit ec78363252
7 changed files with 106 additions and 75 deletions

View file

@ -61,6 +61,7 @@ func WorkerCmd(ctx context.Context) *cobra.Command {
} }
consumers := runtime.NumCPU() * multiplier consumers := runtime.NumCPU() * multiplier
//consumers = 1
worker := workerFn(logger, statsd, db, redis, queue, consumers) worker := workerFn(logger, statsd, db, redis, queue, consumers)
worker.Start() worker.Start()

View file

@ -1,7 +1,6 @@
package reddit package reddit
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -23,7 +22,7 @@ type Client struct {
secret string secret string
client *http.Client client *http.Client
tracer *httptrace.ClientTrace tracer *httptrace.ClientTrace
parser *fastjson.Parser pool *fastjson.ParserPool
statsd *statsd.Client statsd *statsd.Client
} }
@ -74,14 +73,14 @@ func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client
client := &http.Client{Transport: t} client := &http.Client{Transport: t}
parser := &fastjson.Parser{} pool := &fastjson.ParserPool{}
return &Client{ return &Client{
id, id,
secret, secret,
client, client,
tracer, tracer,
parser, pool,
statsd, statsd,
} }
} }
@ -98,7 +97,7 @@ func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *Auth
return &AuthenticatedClient{rc, refreshToken, accessToken, nil} return &AuthenticatedClient{rc, refreshToken, accessToken, nil}
} }
func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) { func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) {
req, err := r.HTTPRequest() req, err := r.HTTPRequest()
if err != nil { if err != nil {
return nil, err return nil, err
@ -122,16 +121,21 @@ func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) {
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, err return nil, err
} }
parser := rac.pool.Get()
defer rac.pool.Put(parser)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
// Try to parse a json error. Otherwise we generate a generic one // Try to parse a json error. Otherwise we generate a generic one
rerr := &Error{} val, jerr := parser.ParseBytes(bb)
if jerr := json.Unmarshal(bb, rerr); jerr != nil { if jerr != nil {
return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode) return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode)
} }
return nil, rerr return nil, NewError(val)
} }
return bb, nil return parser.ParseBytes(bb)
} }
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
@ -144,33 +148,24 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
WithBasicAuth(rac.id, rac.secret), WithBasicAuth(rac.id, rac.secret),
) )
body, err := rac.request(req) val, err := rac.request(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rtr := &RefreshTokenResponse{} return NewRefreshTokenResponse(val), nil
json.Unmarshal([]byte(body), rtr)
return rtr, nil
} }
func (rac *AuthenticatedClient) MessageInbox(from string) (*ListingResponse, error) { func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) {
req := NewRequest( opts = append([]RequestOption{
WithTags([]string{"url:/api/v1/message/inbox"}), WithTags([]string{"url:/api/v1/message/inbox"}),
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.accessToken), WithToken(rac.accessToken),
WithURL("https://oauth.reddit.com/message/inbox.json"), WithURL("https://oauth.reddit.com/message/inbox.json"),
WithQuery("before", from), }, opts...)
) req := NewRequest(opts...)
body, err := rac.request(req) val, err := rac.request(req)
if err != nil {
return nil, err
}
val, err := rac.parser.ParseBytes(body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,24 +173,22 @@ func (rac *AuthenticatedClient) MessageInbox(from string) (*ListingResponse, err
return NewListingResponse(val), nil return NewListingResponse(val), nil
} }
func (rac *AuthenticatedClient) MessageUnread(from string) (*MessageListingResponse, error) { func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) {
req := NewRequest( opts = append([]RequestOption{
WithTags([]string{"url:/api/v1/message/unread"}), WithTags([]string{"url:/api/v1/message/unread"}),
WithMethod("GET"), WithMethod("GET"),
WithToken(rac.accessToken), WithToken(rac.accessToken),
WithURL("https://oauth.reddit.com/message/unread.json"), WithURL("https://oauth.reddit.com/message/unread.json"),
WithQuery("before", from), }, opts...)
)
body, err := rac.request(req) req := NewRequest(opts...)
val, err := rac.request(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mlr := &MessageListingResponse{} return NewListingResponse(val), nil
json.Unmarshal([]byte(body), mlr)
return mlr, nil
} }
func (rac *AuthenticatedClient) Me() (*MeResponse, error) { func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
@ -206,14 +199,10 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
WithURL("https://oauth.reddit.com/api/v1/me"), WithURL("https://oauth.reddit.com/api/v1/me"),
) )
body, err := rac.request(req) val, err := rac.request(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mr := &MeResponse{} return NewMeResponse(val), nil
err = json.Unmarshal(body, mr)
return mr, err
} }

4
internal/reddit/testdata/error.json vendored Normal file
View file

@ -0,0 +1,4 @@
{
"message": "Unauthorized",
"error": 401
}

View file

@ -0,0 +1,7 @@
{
"access_token": "***REMOVED***",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "***REMOVED***",
"scope": "account creddits edit flair history identity livemanage modconfig modcontributors modflair modlog modmail modothers modposts modself modtraffic modwiki mysubreddits privatemessages read report save structuredstyles submit subscribe vote wikiedit wikiread"
}

View file

@ -16,36 +16,13 @@ func (err *Error) Error() string {
return fmt.Sprintf("%s (%d)", err.Message, err.Code) return fmt.Sprintf("%s (%d)", err.Message, err.Code)
} }
type Message struct { func NewError(val *fastjson.Value) *Error {
ID string `json:"id"` err := &Error{}
Kind string `json:"kind"`
Type string `json:"type"`
Author string `json:"author"`
Subject string `json:"subject"`
Body string `json:"body"`
CreatedAt float64 `json:"created_utc"`
Context string `json:"context"`
ParentID string `json:"parent_id"`
LinkTitle string `json:"link_title"`
Destination string `json:"dest"`
Subreddit string `json:"subreddit"`
}
type MessageData struct { err.Message = string(val.GetStringBytes("message"))
Message `json:"data"` err.Code = val.GetInt("error")
Kind string `json:"kind"`
}
func (md MessageData) FullName() string { return err
return fmt.Sprintf("%s_%s", md.Kind, md.ID)
}
type MessageListing struct {
Messages []MessageData `json:"children"`
}
type MessageListingResponse struct {
MessageListing MessageListing `json:"data"`
} }
type RefreshTokenResponse struct { type RefreshTokenResponse struct {
@ -53,6 +30,15 @@ type RefreshTokenResponse struct {
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
func NewRefreshTokenResponse(val *fastjson.Value) *RefreshTokenResponse {
rtr := &RefreshTokenResponse{}
rtr.AccessToken = string(val.GetStringBytes("access_token"))
rtr.RefreshToken = string(val.GetStringBytes("refresh_token"))
return rtr
}
type MeResponse struct { type MeResponse struct {
ID string `json:"id"` ID string `json:"id"`
Name string Name string
@ -86,6 +72,10 @@ type Thing struct {
Subreddit string `json:"subreddit"` Subreddit string `json:"subreddit"`
} }
func (t *Thing) FullName() string {
return fmt.Sprintf("%s_%s", t.Kind, t.ID)
}
func NewThing(val *fastjson.Value) *Thing { func NewThing(val *fastjson.Value) *Thing {
t := &Thing{} t := &Thing{}
@ -122,8 +112,12 @@ func NewListingResponse(val *fastjson.Value) *ListingResponse {
lr.After = string(data.GetStringBytes("after")) lr.After = string(data.GetStringBytes("after"))
lr.Before = string(data.GetStringBytes("before")) lr.Before = string(data.GetStringBytes("before"))
lr.Count = data.GetInt("dist") lr.Count = data.GetInt("dist")
lr.Children = make([]*Thing, lr.Count)
if lr.Count == 0 {
return lr
}
lr.Children = make([]*Thing, lr.Count)
children := data.GetArray("children") children := data.GetArray("children")
for i := 0; i < lr.Count; i++ { for i := 0; i < lr.Count; i++ {
t := NewThing(children[i]) t := NewThing(children[i])

View file

@ -26,6 +26,20 @@ func TestMeResponseParsing(t *testing.T) {
assert.Equal(t, "hugocat", me.Name) assert.Equal(t, "hugocat", me.Name)
} }
func TestRefreshTokenResponseParsing(t *testing.T) {
bb, err := ioutil.ReadFile("testdata/refresh_token.json")
assert.NoError(t, err)
val, err := parser.ParseBytes(bb)
assert.NoError(t, err)
rtr := NewRefreshTokenResponse(val)
assert.NotNil(t, rtr)
assert.Equal(t, "***REMOVED***", rtr.AccessToken)
assert.Equal(t, "***REMOVED***", rtr.RefreshToken)
}
func TestListingResponseParsing(t *testing.T) { func TestListingResponseParsing(t *testing.T) {
bb, err := ioutil.ReadFile("testdata/message_inbox.json") bb, err := ioutil.ReadFile("testdata/message_inbox.json")
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -248,7 +248,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "accountID": id,
}).Debug("fetching message inbox") }).Debug("fetching message inbox")
msgs, err := rac.MessageInbox(account.LastMessageID) msgs, err := rac.MessageInbox(reddit.WithQuery("limit", "10"))
if err != nil { if err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
@ -258,12 +258,32 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
return return
} }
// Figure out where we stand
if msgs.Count == 0 || msgs.Children[0].FullName() == account.LastMessageID {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "accountID": id,
"count": len(msgs.MessageListing.Messages), }).Debug("no new messages, bailing early")
return
}
// Find which one is the oldest we haven't notified on
oldest := 0
for i, t := range msgs.Children {
if t.FullName() == account.LastMessageID {
break
}
oldest = i
}
tt := msgs.Children[:oldest]
nc.logger.WithFields(logrus.Fields{
"accountID": id,
"count": len(tt),
}).Debug("fetched messages") }).Debug("fetched messages")
if len(msgs.MessageListing.Messages) == 0 { if len(tt) == 0 {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{
"accountID": id, "accountID": id,
}).Debug("no new messages, bailing early") }).Debug("no new messages, bailing early")
@ -271,7 +291,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
} }
// Set latest message we alerted on // Set latest message we alerted on
latestMsg := msgs.MessageListing.Messages[0] latestMsg := tt[0]
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` stmt := `
UPDATE accounts UPDATE accounts
@ -316,10 +336,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
devices = append(devices, device) devices = append(devices, device)
} }
for _, msg := range msgs.MessageListing.Messages { // Iterate backwards so we notify from older to newer
for i := len(tt) - 1; i >= 0; i-- {
msg := tt[i]
notification := &apns2.Notification{} notification := &apns2.Notification{}
notification.Topic = "com.christianselig.Apollo" notification.Topic = "com.christianselig.Apollo"
notification.Payload = payloadFromMessage(account, &msg, len(msgs.MessageListing.Messages)) notification.Payload = payloadFromMessage(account, msg, len(tt))
for _, device := range devices { for _, device := range devices {
notification.DeviceToken = device.APNSToken notification.DeviceToken = device.APNSToken
@ -353,7 +375,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("finishing job") }).Debug("finishing job")
} }
func payloadFromMessage(acct *data.Account, msg *reddit.MessageData, badgeCount int) *payload.Payload { func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload {
postBody := msg.Body postBody := msg.Body
if len(postBody) > 2000 { if len(postBody) > 2000 {
postBody = msg.Body[:2000] postBody = msg.Body[:2000]