fix rate limiting parsing

This commit is contained in:
Andre Medeiros 2022-03-12 13:15:59 -05:00
parent 7a6955212b
commit 77162dd513

View file

@ -19,6 +19,9 @@ import (
const ( const (
SkipRateLimiting = "<SKIP_RATE_LIMITING>" SkipRateLimiting = "<SKIP_RATE_LIMITING>"
RequestRemainingBuffer = 50 RequestRemainingBuffer = 50
RateLimitRemainingHeader = "X-Ratelimit-Remaining"
RateLimitResetHeader = "X-Ratelimit-Reset"
) )
type Client struct { type Client struct {
@ -113,7 +116,7 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
} }
func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) { func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) {
req, err := r.HTTPRequest() req, err := r.HTTPRequest()
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
@ -137,12 +140,12 @@ func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining")) remaining, err := strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64)
if err != nil { if err != nil {
remaining = 0 remaining = 0
} }
reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset")) reset, err := strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
if err != nil { if err != nil {
reset = 0 reset = 0
} }
@ -167,7 +170,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
bb, remaining, reset, err := rac.doRequest(r) bb, remaining, reset, err := rac.doRequest(r)
if remaining <= RequestRemainingBuffer { if err == nil && remaining <= RequestRemainingBuffer {
rac.markRateLimited(remaining, time.Duration(reset)*time.Second) rac.markRateLimited(remaining, time.Duration(reset)*time.Second)
} }
@ -229,7 +232,7 @@ func (rac *AuthenticatedClient) isRateLimited() (bool, error) {
} }
} }
func (rac *AuthenticatedClient) markRateLimited(remaining int, duration time.Duration) error { func (rac *AuthenticatedClient) markRateLimited(remaining float64, duration time.Duration) error {
if rac.redditId == SkipRateLimiting { if rac.redditId == SkipRateLimiting {
return ErrRequiresRedditId return ErrRequiresRedditId
} }