mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-25 13:17:42 +00:00
fix rate limiter part 2
This commit is contained in:
parent
77162dd513
commit
5ad98494ee
1 changed files with 25 additions and 21 deletions
|
@ -34,6 +34,12 @@ type Client struct {
|
|||
redis *redis.Client
|
||||
}
|
||||
|
||||
type RateLimitingInfo struct {
|
||||
Remaining float64
|
||||
Reset int
|
||||
Present bool
|
||||
}
|
||||
|
||||
var backoffSchedule = []time.Duration{
|
||||
4 * time.Second,
|
||||
8 * time.Second,
|
||||
|
@ -116,10 +122,10 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
|
|||
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
|
||||
}
|
||||
|
||||
func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) {
|
||||
func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) {
|
||||
req, err := r.HTTPRequest()
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
|
||||
|
@ -134,33 +140,30 @@ func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) {
|
|||
if err != nil {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||
return nil, 0, 0, ErrTimeout
|
||||
return nil, nil, ErrTimeout
|
||||
}
|
||||
return nil, 0, 0, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
remaining, err := strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64)
|
||||
if err != nil {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
reset, err := strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
|
||||
if err != nil {
|
||||
reset = 0
|
||||
rli := &RateLimitingInfo{Present: false}
|
||||
if _, ok := resp.Header[RateLimitRemainingHeader]; ok {
|
||||
rli.Present = true
|
||||
rli.Remaining, _ = strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64)
|
||||
rli.Reset, _ = strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
return nil, remaining, reset, ServerError{resp.StatusCode}
|
||||
return nil, rli, ServerError{resp.StatusCode}
|
||||
}
|
||||
|
||||
bb, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
return nil, remaining, reset, err
|
||||
return nil, rli, err
|
||||
}
|
||||
return bb, remaining, reset, nil
|
||||
return bb, rli, nil
|
||||
}
|
||||
|
||||
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||
|
@ -168,11 +171,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
|||
return nil, ErrRateLimited
|
||||
}
|
||||
|
||||
bb, remaining, reset, err := rac.doRequest(r)
|
||||
|
||||
if err == nil && remaining <= RequestRemainingBuffer {
|
||||
rac.markRateLimited(remaining, time.Duration(reset)*time.Second)
|
||||
}
|
||||
bb, rli, err := rac.doRequest(r)
|
||||
|
||||
if err != nil && r.retry {
|
||||
for _, backoff := range backoffSchedule {
|
||||
|
@ -180,7 +179,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
|||
|
||||
time.AfterFunc(backoff, func() {
|
||||
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
||||
bb, remaining, reset, err = rac.doRequest(r)
|
||||
bb, rli, err = rac.doRequest(r)
|
||||
done <- struct{}{}
|
||||
})
|
||||
|
||||
|
@ -192,6 +191,11 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
|||
}
|
||||
}
|
||||
|
||||
if err == nil && rli.Present && rli.Remaining <= RequestRemainingBuffer {
|
||||
_ = rac.statsd.Incr("reddit.api.ratelimit", r.tags, 0.1)
|
||||
rac.markRateLimited(rli.Remaining, time.Duration(rli.Reset)*time.Second)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||
|
|
Loading…
Reference in a new issue