diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 40b6080..5b3c047 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -19,6 +19,9 @@ import ( const ( SkipRateLimiting = "" RequestRemainingBuffer = 50 + + RateLimitRemainingHeader = "X-Ratelimit-Remaining" + RateLimitResetHeader = "X-Ratelimit-Reset" ) type Client struct { @@ -113,7 +116,7 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} } -func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) { +func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) { req, err := r.HTTPRequest() if err != nil { return nil, 0, 0, err @@ -137,12 +140,12 @@ func (rc *Client) doRequest(r *Request) ([]byte, int, int, error) { } defer resp.Body.Close() - remaining, err := strconv.Atoi(resp.Header.Get("x-ratelimit-remaining")) + remaining, err := strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64) if err != nil { remaining = 0 } - reset, err := strconv.Atoi(resp.Header.Get("x-ratelimit-reset")) + reset, err := strconv.Atoi(resp.Header.Get(RateLimitResetHeader)) if err != nil { reset = 0 } @@ -167,7 +170,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in bb, remaining, reset, err := rac.doRequest(r) - if remaining <= RequestRemainingBuffer { + if err == nil && remaining <= RequestRemainingBuffer { rac.markRateLimited(remaining, time.Duration(reset)*time.Second) } @@ -229,7 +232,7 @@ func (rac *AuthenticatedClient) isRateLimited() (bool, error) { } } -func (rac *AuthenticatedClient) markRateLimited(remaining int, duration time.Duration) error { +func (rac *AuthenticatedClient) markRateLimited(remaining float64, duration time.Duration) error { if rac.redditId == SkipRateLimiting { return ErrRequiresRedditId }