fix rate limiter part 2

This commit is contained in:
Andre Medeiros 2022-03-12 13:25:34 -05:00
parent 77162dd513
commit 5ad98494ee

View file

@ -34,6 +34,12 @@ type Client struct {
redis *redis.Client redis *redis.Client
} }
type RateLimitingInfo struct {
Remaining float64
Reset int
Present bool
}
var backoffSchedule = []time.Duration{ var backoffSchedule = []time.Duration{
4 * time.Second, 4 * time.Second,
8 * time.Second, 8 * time.Second,
@ -116,10 +122,10 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken} return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
} }
func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) { func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) {
req, err := r.HTTPRequest() req, err := r.HTTPRequest()
if err != nil { if err != nil {
return nil, 0, 0, err return nil, nil, err
} }
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer)) req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
@ -134,33 +140,30 @@ func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) {
if err != nil { if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
return nil, 0, 0, ErrTimeout return nil, nil, ErrTimeout
} }
return nil, 0, 0, err return nil, nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
remaining, err := strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64) rli := &RateLimitingInfo{Present: false}
if err != nil { if _, ok := resp.Header[RateLimitRemainingHeader]; ok {
remaining = 0 rli.Present = true
} rli.Remaining, _ = strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64)
rli.Reset, _ = strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
reset, err := strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
if err != nil {
reset = 0
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, remaining, reset, ServerError{resp.StatusCode} return nil, rli, ServerError{resp.StatusCode}
} }
bb, err := ioutil.ReadAll(resp.Body) bb, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
return nil, remaining, reset, err return nil, rli, err
} }
return bb, remaining, reset, nil return bb, rli, nil
} }
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
@ -168,11 +171,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
return nil, ErrRateLimited return nil, ErrRateLimited
} }
bb, remaining, reset, err := rac.doRequest(r) bb, rli, err := rac.doRequest(r)
if err == nil && remaining <= RequestRemainingBuffer {
rac.markRateLimited(remaining, time.Duration(reset)*time.Second)
}
if err != nil && r.retry { if err != nil && r.retry {
for _, backoff := range backoffSchedule { for _, backoff := range backoffSchedule {
@ -180,7 +179,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
time.AfterFunc(backoff, func() { time.AfterFunc(backoff, func() {
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1) _ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
bb, remaining, reset, err = rac.doRequest(r) bb, rli, err = rac.doRequest(r)
done <- struct{}{} done <- struct{}{}
}) })
@ -192,6 +191,11 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
} }
} }
if err == nil && rli.Present && rli.Remaining <= RequestRemainingBuffer {
_ = rac.statsd.Incr("reddit.api.ratelimit", r.tags, 0.1)
rac.markRateLimited(rli.Remaining, time.Duration(rli.Reset)*time.Second)
}
if err != nil { if err != nil {
_ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) _ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {