fix concurrency with fastjson

This commit is contained in:
Andre Medeiros 2021-07-15 11:51:04 -04:00
parent 8c1ed47d1a
commit 20fe4d7ddb
4 changed files with 28 additions and 23 deletions

View file

@ -97,7 +97,7 @@ func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *Auth
return &AuthenticatedClient{rc, refreshToken, accessToken, nil} return &AuthenticatedClient{rc, refreshToken, accessToken, nil}
} }
func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) { func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler) (interface{}, error) {
req, err := r.HTTPRequest() req, err := r.HTTPRequest()
if err != nil { if err != nil {
return nil, err return nil, err
@ -135,7 +135,12 @@ func (rac *AuthenticatedClient) request(r *Request) (*fastjson.Value, error) {
} }
return nil, NewError(val) return nil, NewError(val)
} }
return parser.ParseBytes(bb) val, err := parser.ParseBytes(bb)
if err != nil {
return nil, err
}
return rh(val), nil
} }
func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) { func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
@ -148,12 +153,11 @@ func (rac *AuthenticatedClient) RefreshTokens() (*RefreshTokenResponse, error) {
WithBasicAuth(rac.id, rac.secret), WithBasicAuth(rac.id, rac.secret),
) )
val, err := rac.request(req) rtr, err := rac.request(req, NewRefreshTokenResponse)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return rtr.(*RefreshTokenResponse), nil
return NewRefreshTokenResponse(val), nil
} }
func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) { func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingResponse, error) {
@ -165,12 +169,11 @@ func (rac *AuthenticatedClient) MessageInbox(opts ...RequestOption) (*ListingRes
}, opts...) }, opts...)
req := NewRequest(opts...) req := NewRequest(opts...)
val, err := rac.request(req) lr, err := rac.request(req, NewListingResponse)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return lr.(*ListingResponse), nil
return NewListingResponse(val), nil
} }
func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) { func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingResponse, error) {
@ -183,12 +186,11 @@ func (rac *AuthenticatedClient) MessageUnread(opts ...RequestOption) (*ListingRe
req := NewRequest(opts...) req := NewRequest(opts...)
val, err := rac.request(req) lr, err := rac.request(req, NewListingResponse)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return lr.(*ListingResponse), nil
return NewListingResponse(val), nil
} }
func (rac *AuthenticatedClient) Me() (*MeResponse, error) { func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
@ -199,10 +201,9 @@ func (rac *AuthenticatedClient) Me() (*MeResponse, error) {
WithURL("https://oauth.reddit.com/api/v1/me"), WithURL("https://oauth.reddit.com/api/v1/me"),
) )
val, err := rac.request(req) mr, err := rac.request(req, NewMeResponse)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return mr.(*MeResponse), nil
return NewMeResponse(val), nil
} }

View file

@ -7,6 +7,8 @@ import (
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
) )
type ResponseHandler func(*fastjson.Value) interface{}
type Error struct { type Error struct {
Message string `json:"message"` Message string `json:"message"`
Code int `json:"error"` Code int `json:"error"`
@ -30,7 +32,7 @@ type RefreshTokenResponse struct {
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
func NewRefreshTokenResponse(val *fastjson.Value) *RefreshTokenResponse { func NewRefreshTokenResponse(val *fastjson.Value) interface{} {
rtr := &RefreshTokenResponse{} rtr := &RefreshTokenResponse{}
rtr.AccessToken = string(val.GetStringBytes("access_token")) rtr.AccessToken = string(val.GetStringBytes("access_token"))
@ -48,7 +50,7 @@ func (mr *MeResponse) NormalizedUsername() string {
return strings.ToLower(mr.Name) return strings.ToLower(mr.Name)
} }
func NewMeResponse(val *fastjson.Value) *MeResponse { func NewMeResponse(val *fastjson.Value) interface{} {
mr := &MeResponse{} mr := &MeResponse{}
mr.ID = string(val.GetStringBytes("id")) mr.ID = string(val.GetStringBytes("id"))
@ -105,7 +107,7 @@ type ListingResponse struct {
Before string Before string
} }
func NewListingResponse(val *fastjson.Value) *ListingResponse { func NewListingResponse(val *fastjson.Value) interface{} {
lr := &ListingResponse{} lr := &ListingResponse{}
data := val.Get("data") data := val.Get("data")

View file

@ -19,7 +19,8 @@ func TestMeResponseParsing(t *testing.T) {
val, err := parser.ParseBytes(bb) val, err := parser.ParseBytes(bb)
assert.NoError(t, err) assert.NoError(t, err)
me := NewMeResponse(val) ret := NewMeResponse(val)
me := ret.(*MeResponse)
assert.NotNil(t, me) assert.NotNil(t, me)
assert.Equal(t, "xgeee", me.ID) assert.Equal(t, "xgeee", me.ID)
@ -33,7 +34,8 @@ func TestRefreshTokenResponseParsing(t *testing.T) {
val, err := parser.ParseBytes(bb) val, err := parser.ParseBytes(bb)
assert.NoError(t, err) assert.NoError(t, err)
rtr := NewRefreshTokenResponse(val) ret := NewRefreshTokenResponse(val)
rtr := ret.(*RefreshTokenResponse)
assert.NotNil(t, rtr) assert.NotNil(t, rtr)
assert.Equal(t, "***REMOVED***", rtr.AccessToken) assert.Equal(t, "***REMOVED***", rtr.AccessToken)
@ -47,7 +49,8 @@ func TestListingResponseParsing(t *testing.T) {
val, err := parser.ParseBytes(bb) val, err := parser.ParseBytes(bb)
assert.NoError(t, err) assert.NoError(t, err)
l := NewListingResponse(val) ret := NewListingResponse(val)
l := ret.(*ListingResponse)
assert.NotNil(t, l) assert.NotNil(t, l)
assert.Equal(t, 25, l.Count) assert.Equal(t, 25, l.Count)
@ -64,6 +67,7 @@ func TestListingResponseParsing(t *testing.T) {
assert.Equal(t, "how are you today", thing.Body) assert.Equal(t, "how are you today", thing.Body)
assert.Equal(t, 1626285395.0, thing.CreatedAt) assert.Equal(t, 1626285395.0, thing.CreatedAt)
assert.Equal(t, "hugocat", thing.Destination) assert.Equal(t, "hugocat", thing.Destination)
assert.Equal(t, "t4_138z6ke", thing.FullName())
thing = l.Children[6] thing = l.Children[6]
assert.Equal(t, "/r/calicosummer/comments/ngcapc/hello_i_am_a_cat/h4q5j98/?context=3", thing.Context) assert.Equal(t, "/r/calicosummer/comments/ngcapc/hello_i_am_a_cat/h4q5j98/?context=3", thing.Context)

View file

@ -284,14 +284,12 @@ func (nc *notificationsConsumer) Consume(delivery rmq.Delivery) {
}).Debug("fetched messages") }).Debug("fetched messages")
// Set latest message we alerted on // Set latest message we alerted on
latestMsg := tt[0]
if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error { if err = nc.db.BeginFunc(ctx, func(tx pgx.Tx) error {
stmt := ` stmt := `
UPDATE accounts UPDATE accounts
SET last_message_id = $1 SET last_message_id = $1
WHERE id = $2` WHERE id = $2`
_, err := tx.Exec(ctx, stmt, latestMsg.FullName(), account.ID) _, err := tx.Exec(ctx, stmt, tt[0].FullName(), account.ID)
return err return err
}); err != nil { }); err != nil {
nc.logger.WithFields(logrus.Fields{ nc.logger.WithFields(logrus.Fields{