From ec783632528735b71e6ae3898bbb962dc34215a6 Mon Sep 17 00:00:00 2001 From: Andre Medeiros Date: Thu, 15 Jul 2021 10:51:34 -0400 Subject: [PATCH] update notifier logic --- internal/cmd/worker.go | 1 + internal/reddit/client.go | 67 +++++++++------------ internal/reddit/testdata/error.json | 4 ++ internal/reddit/testdata/refresh_token.json | 7 +++ internal/reddit/types.go | 52 +++++++--------- internal/reddit/types_test.go | 14 +++++ internal/worker/notifications.go | 36 ++++++++--- 7 files changed, 106 insertions(+), 75 deletions(-) create mode 100644 internal/reddit/testdata/error.json create mode 100644 internal/reddit/testdata/refresh_token.json diff --git a/internal/cmd/worker.go b/internal/cmd/worker.go index fc13d5a..9323009 100644 --- a/internal/cmd/worker.go +++ b/internal/cmd/worker.go @@ -61,6 +61,7 @@ func WorkerCmd(ctx context.Context) *cobra.Command { } consumers := runtime.NumCPU() * multiplier + //consumers = 1 worker := workerFn(logger, statsd, db, redis, queue, consumers) worker.Start() diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 8449973..959e533 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -1,7 +1,6 @@ package reddit import ( - "encoding/json" "fmt" "io/ioutil" "net/http" @@ -23,7 +22,7 @@ type Client struct { secret string client *http.Client tracer *httptrace.ClientTrace - parser *fastjson.Parser + pool *fastjson.ParserPool statsd *statsd.Client } @@ -74,14 +73,14 @@ func NewClient(id, secret string, statsd *statsd.Client, connLimit int) *Client client := &http.Client{Transport: t} - parser := &fastjson.Parser{} + pool := &fastjson.ParserPool{} return &Client{ id, secret, client, tracer, - parser, + pool, statsd, } } @@ -98,7 +97,7 @@ func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *Auth return &AuthenticatedClient{rc, refreshToken, accessToken, nil} } -func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) { +func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) { req, err := r.HTTPRequest() if err != nil { return nil, err @@ -122,16 +121,21 @@ func (rac *AuthenticatedClient) request(r *Request) ([]byte, error) { rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) return nil, err } + + parser := rac.pool.Get() + defer rac.pool.Put(parser) + if resp.StatusCode != 200 { rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) + // Try to parse a json error. Otherwise we generate a generic one - rerr := &Error{} - if jerr := json.Unmarshal(bb, rerr); jerr != nil { + val, jerr := parser.ParseBytes(bb) + if jerr != nil { return nil, fmt.Errorf("error from reddit: %d", resp.StatusCode) } - return nil, rerr + return nil, NewError(val) } - return bb, nil + return parser.ParseBytes(bb) } func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { @@ -144,33 +148,24 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { WithBasicAuth(rac.id, rac.secret), ) - body, err := rac.request(req) - + val, err := rac.request(req) if err != nil { return nil, err } - rtr := &RefreshTokenResponse{} - json.Unmarshal([]byte(body), rtr) - return rtr, nil + return NewRefreshTokenResponse(val), nil } -func (rac *AuthenticatedClient) MessageInbox(from string) (*ListingResponse, error) { - req := NewRequest( +func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) { + opts = append([]RequestOption{ WithTags([]string{"url:/api/v1/message/inbox"}), WithMethod("GET"), WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/message/inbox.json"), - WithQuery("before", from), - ) + }, opts...) + req := NewRequest(opts...) - body, err := rac.request(req) - - if err != nil { - return nil, err - } - - val, err := rac.parser.ParseBytes(body) + val, err := rac.request(req) if err != nil { return nil, err } @@ -178,24 +173,22 @@ func (rac *AuthenticatedClient) MessageInbox(from string) (*ListingResponse, err return NewListingResponse(val), nil } -func (rac *AuthenticatedClient) MessageUnread(from string) (*MessageListingResponse, error) { - req := NewRequest( +func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) { + opts = append([]RequestOption{ WithTags([]string{"url:/api/v1/message/unread"}), WithMethod("GET"), WithToken(rac.accessToken), WithURL("https://oauth.reddit.com/message/unread.json"), - WithQuery("before", from), - ) + }, opts...) - body, err := rac.request(req) + req := NewRequest(opts...) + val, err := rac.request(req) if err != nil { return nil, err } - mlr := &MessageListingResponse{} - json.Unmarshal([]byte(body), mlr) - return mlr, nil + return NewListingResponse(val), nil } func (rac *AuthenticatedClient) Me() (*MeResponse, error) { @@ -206,14 +199,10 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) { WithURL("https://oauth.reddit.com/api/v1/me"), ) - body, err := rac.request(req) - + val, err := rac.request(req) if err != nil { return nil, err } - mr := &MeResponse{} - err = json.Unmarshal(body, mr) - - return mr, err + return NewMeResponse(val), nil } diff --git a/internal/reddit/testdata/error.json b/internal/reddit/testdata/error.json new file mode 100644 index 0000000..f8f1305 --- /dev/null +++ b/internal/reddit/testdata/error.json @@ -0,0 +1,4 @@ +{ + "message": "Unauthorized", + "error": 401 +} diff --git a/internal/reddit/testdata/refresh_token.json b/internal/reddit/testdata/refresh_token.json new file mode 100644 index 0000000..d1c7ad3 --- /dev/null +++ b/internal/reddit/testdata/refresh_token.json @@ -0,0 +1,7 @@ +{ + "access_token": "***REMOVED***", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "***REMOVED***", + "scope": "account creddits edit flair history identity livemanage modconfig modcontributors modflair modlog modmail modothers modposts modself modtraffic modwiki mysubreddits privatemessages read report save structuredstyles submit subscribe vote wikiedit wikiread" +} diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 05e0c46..0c10480 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -16,36 +16,13 @@ func (err *Error) Error() string { return fmt.Sprintf("%s (%d)", err.Message, err.Code) } -type Message struct { - ID string `json:"id"` - Kind string `json:"kind"` - Type string `json:"type"` - Author string `json:"author"` - Subject string `json:"subject"` - Body string `json:"body"` - CreatedAt float64 `json:"created_utc"` - Context string `json:"context"` - ParentID string `json:"parent_id"` - LinkTitle string `json:"link_title"` - Destination string `json:"dest"` - Subreddit string `json:"subreddit"` -} +func NewError(val *fastjson.Value) *Error { + err := &Error{} -type MessageData struct { - Message `json:"data"` - Kind string `json:"kind"` -} + err.Message = string(val.GetStringBytes("message")) + err.Code = val.GetInt("error") -func (md MessageData) FullName() string { - return fmt.Sprintf("%s_%s", md.Kind, md.ID) -} - -type MessageListing struct { - Messages []MessageData `json:"children"` -} - -type MessageListingResponse struct { - MessageListing MessageListing `json:"data"` + return err } type RefreshTokenResponse struct { @@ -53,6 +30,15 @@ type RefreshTokenResponse struct { RefreshToken string `json:"refresh_token"` } +func NewRefreshTokenResponse(val *fastjson.Value) *RefreshTokenResponse { + rtr := &RefreshTokenResponse{} + + rtr.AccessToken = string(val.GetStringBytes("access_token")) + rtr.RefreshToken = string(val.GetStringBytes("refresh_token")) + + return rtr +} + type MeResponse struct { ID string `json:"id"` Name string @@ -86,6 +72,10 @@ type Thing struct { Subreddit string `json:"subreddit"` } +func (t *Thing) FullName() string { + return fmt.Sprintf("%s_%s", t.Kind, t.ID) +} + func NewThing(val *fastjson.Value) *Thing { t := &Thing{} @@ -122,8 +112,12 @@ func NewListingResponse(val *fastjson.Value) *ListingResponse { lr.After = string(data.GetStringBytes("after")) lr.Before = string(data.GetStringBytes("before")) lr.Count = data.GetInt("dist") - lr.Children = make([]*Thing, lr.Count) + if lr.Count == 0 { + return lr + } + + lr.Children = make([]*Thing, lr.Count) children := data.GetArray("children") for i := 0; i < lr.Count; i++ { t := NewThing(children[i]) diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index 61642e9..10d5aa9 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -26,6 +26,20 @@ func TestMeResponseParsing(t *testing.T) { assert.Equal(t, "hugocat", me.Name) } +func TestRefreshTokenResponseParsing(t *testing.T) { + bb, err := ioutil.ReadFile("testdata/refresh_token.json") + assert.NoError(t, err) + + val, err := parser.ParseBytes(bb) + assert.NoError(t, err) + + rtr := NewRefreshTokenResponse(val) + assert.NotNil(t, rtr) + + assert.Equal(t, "***REMOVED***", rtr.AccessToken) + assert.Equal(t, "***REMOVED***", rtr.RefreshToken) +} + func TestListingResponseParsing(t *testing.T) { bb, err := ioutil.ReadFile("testdata/message_inbox.json") assert.NoError(t, err) diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index b3617ba..df0ab98 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -248,7 +248,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { nc.logger.WithFields(logrus.Fields{ "accountID": id, }).Debug("fetching message inbox") - msgs, err := rac.MessageInbox(account.LastMessageID) + msgs, err := rac.MessageInbox(reddit.WithQuery("limit", "10")) if err != nil { nc.logger.WithFields(logrus.Fields{ @@ -258,12 +258,32 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { return } + // Figure out where we stand + if msgs.Count == 0 || msgs.Children[0].FullName() == account.LastMessageID { + nc.logger.WithFields(logrus.Fields{ + "accountID": id, + }).Debug("no new messages, bailing early") + return + } + + // Find which one is the oldest we haven't notified on + oldest := 0 + for i, t := range msgs.Children { + if t.FullName() == account.LastMessageID { + break + } + + oldest = i + } + + tt := msgs.Children[:oldest] + nc.logger.WithFields(logrus.Fields{ "accountID": id, - "count": len(msgs.MessageListing.Messages), + "count": len(tt), }).Debug("fetched messages") - if len(msgs.MessageListing.Messages) == 0 { + if len(tt) == 0 { nc.logger.WithFields(logrus.Fields{ "accountID": id, }).Debug("no new messages, bailing early") @@ -271,7 +291,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { } // Set latest message we alerted on - latestMsg := msgs.MessageListing.Messages[0] + latestMsg := tt[0] if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { stmt := ` UPDATE accounts @@ -316,10 +336,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { devices = append(devices, device) } - for _, msg := range msgs.MessageListing.Messages { + // Iterate backwards so we notify from older to newer + for i := len(tt) - 1; i >= 0; i-- { + msg := tt[i] notification := &apns2.Notification{} notification.Topic = "com.christianselig.Apollo" - notification.Payload = payloadFromMessage(account, &msg, len(msgs.MessageListing.Messages)) + notification.Payload = payloadFromMessage(account, msg, len(tt)) for _, device := range devices { notification.DeviceToken = device.APNSToken @@ -353,7 +375,7 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { }).Debug("finishing job") } -func payloadFromMessage(acct *data.Account, msg *reddit.MessageData, badgeCount int) *payload.Payload { +func payloadFromMessage(acct *data.Account, msg *reddit.Thing, badgeCount int) *payload.Payload { postBody := msg.Body if len(postBody) > 2000 { postBody = msg.Body[:2000]