diff --git a/internal/reddit/client.go b/internal/reddit/client.go index eec5d2b..a6d7527 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -147,7 +147,7 @@ func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) { defer resp.Body.Close() rli := &RateLimitingInfo{Present: false} - if _, ok := resp.Header[RateLimitRemainingHeader]; ok { + if resp.Header.Get(RateLimitRemainingHeader) != "" { rli.Present = true rli.Remaining, _ = strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64) rli.Reset, _ = strconv.Atoi(resp.Header.Get(RateLimitResetHeader)) @@ -167,7 +167,7 @@ func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) { } func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { - if rl, err := rac.isRateLimited(); rl || err != nil { + if rac.isRateLimited() { return nil, ErrRateLimited } @@ -191,17 +191,14 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in } } - if err == nil && rli.Present && rli.Remaining <= RequestRemainingBuffer { - _ = rac.statsd.Incr("reddit.api.ratelimit", r.tags, 0.1) - rac.markRateLimited(rli.Remaining, time.Duration(rli.Reset)*time.Second) - } - if err != nil { _ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { return nil, ErrTimeout } return nil, err + } else { + rac.markRateLimited(rli) } if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes { @@ -219,30 +216,34 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in return rh(val), nil } -func (rac *AuthenticatedClient) isRateLimited() (bool, error) { +func (rac *AuthenticatedClient) isRateLimited() bool { if rac.redditId == SkipRateLimiting { - return false, nil + return false } key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) _, err := rac.redis.Get(context.Background(), key).Result() - - if err == redis.Nil { - return false, nil - } else if err == nil { - return true, nil - } else { - return false, err - } + return err != redis.Nil } -func (rac *AuthenticatedClient) markRateLimited(remaining float64, duration time.Duration) error { +func (rac *AuthenticatedClient) markRateLimited(rli *RateLimitingInfo) error { if rac.redditId == SkipRateLimiting { return ErrRequiresRedditId } + if !rli.Present { + return nil + } + + if rli.Remaining > RequestRemainingBuffer { + return nil + } + + _ = rac.statsd.Incr("reddit.api.ratelimit", nil, 0.1) + key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) - _, err := rac.redis.SetEX(context.Background(), key, remaining, duration).Result() + duration := time.Duration(rli.Reset) * time.Second + _, err := rac.redis.SetEX(context.Background(), key, rli.Remaining, duration).Result() return err }