diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 5b3c047..eec5d2b 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -34,6 +34,12 @@ type Client struct { redis *redis.Client } +type RateLimitingInfo struct { + Remaining float64 + Reset int + Present bool +} + var backoffSchedule = []time.Duration{ 4 * time.Second, 8 * time.Second, @@ -116,10 +122,10 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} } -func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) { +func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) { req, err := r.HTTPRequest() if err != nil { - return nil, 0, 0, err + return nil, nil, err } req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer)) @@ -134,33 +140,30 @@ func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) { if err != nil { _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { - return nil, 0, 0, ErrTimeout + return nil, nil, ErrTimeout } - return nil, 0, 0, err + return nil, nil, err } defer resp.Body.Close() - remaining, err := strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64) - if err != nil { - remaining = 0 - } - - reset, err := strconv.Atoi(resp.Header.Get(RateLimitResetHeader)) - if err != nil { - reset = 0 + rli := &RateLimitingInfo{Present: false} + if _, ok := resp.Header[RateLimitRemainingHeader]; ok { + rli.Present = true + rli.Remaining, _ = strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64) + rli.Reset, _ = strconv.Atoi(resp.Header.Get(RateLimitResetHeader)) } if resp.StatusCode != 200 { _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, remaining, reset, ServerError{resp.StatusCode} + return nil, rli, ServerError{resp.StatusCode} } bb, err := ioutil.ReadAll(resp.Body) if err != nil { _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, remaining, reset, err + return nil, rli, err } - return bb, remaining, reset, nil + return bb, rli, nil } func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { @@ -168,11 +171,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in return nil, ErrRateLimited } - bb, remaining, reset, err := rac.doRequest(r) - - if err == nil && remaining <= RequestRemainingBuffer { - rac.markRateLimited(remaining, time.Duration(reset)*time.Second) - } + bb, rli, err := rac.doRequest(r) if err != nil && r.retry { for _, backoff := range backoffSchedule { @@ -180,7 +179,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in time.AfterFunc(backoff, func() { _ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1) - bb, remaining, reset, err = rac.doRequest(r) + bb, rli, err = rac.doRequest(r) done <- struct{}{} }) @@ -192,6 +191,11 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in } } + if err == nil && rli.Present && rli.Remaining <= RequestRemainingBuffer { + _ = rac.statsd.Incr("reddit.api.ratelimit", r.tags, 0.1) + rac.markRateLimited(rli.Remaining, time.Duration(rli.Reset)*time.Second) + } + if err != nil { _ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {