diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 959e533..a7091b7 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -97,7 +97,7 @@ func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *Auth return &AuthenticatedClient{rc, refreshToken, accessToken, nil} } -func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) { +func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler) (interface{}, error) { req, err := r.HTTPRequest() if err != nil { return nil, err @@ -135,7 +135,12 @@ func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) { } return nil, NewError(val) } - return parser.ParseBytes(bb) + val, err := parser.ParseBytes(bb) + if err != nil { + return nil, err + } + + return rh(val), nil } func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { @@ -148,12 +153,11 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { WithBasicAuth(rac.id, rac.secret), ) - val, err := rac.request(req) + rtr, err := rac.request(req, NewRefreshTokenResponse) if err != nil { return nil, err } - - return NewRefreshTokenResponse(val), nil + return rtr.(*RefreshTokenResponse), nil } func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) { @@ -165,12 +169,11 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes }, opts...) req := NewRequest(opts...) - val, err := rac.request(req) + lr, err := rac.request(req, NewListingResponse) if err != nil { return nil, err } - - return NewListingResponse(val), nil + return lr.(*ListingResponse), nil } func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) { @@ -183,12 +186,11 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe req := NewRequest(opts...) - val, err := rac.request(req) + lr, err := rac.request(req, NewListingResponse) if err != nil { return nil, err } - - return NewListingResponse(val), nil + return lr.(*ListingResponse), nil } func (rac *AuthenticatedClient) Me() (*MeResponse, error) { @@ -199,10 +201,9 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) { WithURL("https://oauth.reddit.com/api/v1/me"), ) - val, err := rac.request(req) + mr, err := rac.request(req, NewMeResponse) if err != nil { return nil, err } - - return NewMeResponse(val), nil + return mr.(*MeResponse), nil } diff --git a/internal/reddit/types.go b/internal/reddit/types.go index 859feb3..0cd818b 100644 --- a/internal/reddit/types.go +++ b/internal/reddit/types.go @@ -7,6 +7,8 @@ import ( "github.com/valyala/fastjson" ) +type ResponseHandler func(*fastjson.Value) interface{} + type Error struct { Message string `json:"message"` Code int `json:"error"` @@ -30,7 +32,7 @@ type RefreshTokenResponse struct { RefreshToken string `json:"refresh_token"` } -func NewRefreshTokenResponse(val *fastjson.Value) *RefreshTokenResponse { +func NewRefreshTokenResponse(val *fastjson.Value) interface{} { rtr := &RefreshTokenResponse{} rtr.AccessToken = string(val.GetStringBytes("access_token")) @@ -48,7 +50,7 @@ func (mr *MeResponse) NormalizedUsername() string { return strings.ToLower(mr.Name) } -func NewMeResponse(val *fastjson.Value) *MeResponse { +func NewMeResponse(val *fastjson.Value) interface{} { mr := &MeResponse{} mr.ID = string(val.GetStringBytes("id")) @@ -105,7 +107,7 @@ type ListingResponse struct { Before string } -func NewListingResponse(val *fastjson.Value) *ListingResponse { +func NewListingResponse(val *fastjson.Value) interface{} { lr := &ListingResponse{} data := val.Get("data") diff --git a/internal/reddit/types_test.go b/internal/reddit/types_test.go index 10d5aa9..689c2fa 100644 --- a/internal/reddit/types_test.go +++ b/internal/reddit/types_test.go @@ -19,7 +19,8 @@ func TestMeResponseParsing(t *testing.T) { val, err := parser.ParseBytes(bb) assert.NoError(t, err) - me := NewMeResponse(val) + ret := NewMeResponse(val) + me := ret.(*MeResponse) assert.NotNil(t, me) assert.Equal(t, "xgeee", me.ID) @@ -33,7 +34,8 @@ func TestRefreshTokenResponseParsing(t *testing.T) { val, err := parser.ParseBytes(bb) assert.NoError(t, err) - rtr := NewRefreshTokenResponse(val) + ret := NewRefreshTokenResponse(val) + rtr := ret.(*RefreshTokenResponse) assert.NotNil(t, rtr) assert.Equal(t, "***REMOVED***", rtr.AccessToken) @@ -47,7 +49,8 @@ func TestListingResponseParsing(t *testing.T) { val, err := parser.ParseBytes(bb) assert.NoError(t, err) - l := NewListingResponse(val) + ret := NewListingResponse(val) + l := ret.(*ListingResponse) assert.NotNil(t, l) assert.Equal(t, 25, l.Count) @@ -64,6 +67,7 @@ func TestListingResponseParsing(t *testing.T) { assert.Equal(t, "how are you today", thing.Body) assert.Equal(t, 1626285395.0, thing.CreatedAt) assert.Equal(t, "hugocat", thing.Destination) + assert.Equal(t, "t4_138z6ke", thing.FullName()) thing = l.Children[6] assert.Equal(t, "/r/calicosummer/comments/ngcapc/hello_i_am_a_cat/h4q5j98/?context=3", thing.Context) diff --git a/internal/worker/notifications.go b/internal/worker/notifications.go index 491308f..1ac768c 100644 --- a/internal/worker/notifications.go +++ b/internal/worker/notifications.go @@ -284,14 +284,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) { }).Debug("fetched messages") // Set latest message we alerted on - latestMsg := tt[0] - if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { stmt := ` UPDATE accounts SET last_message_id = $1 WHERE id = $2` - _, err := tx.Exec(ctx, stmt, latestMsg.FullName(), account.ID) + _, err := tx.Exec(ctx, stmt, tt[0].FullName(), account.ID) return err }); err != nil { nc.logger.WithFields(logrus.Fields{