mirror of
https://github.com/christianselig/apollo-backend
synced 2024-11-22 19:57:43 +00:00
fix rate limiter part 2
This commit is contained in:
parent
77162dd513
commit
5ad98494ee
1 changed files with 25 additions and 21 deletions
|
@ -34,6 +34,12 @@ type Client struct {
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RateLimitingInfo struct {
|
||||||
|
Remaining float64
|
||||||
|
Reset int
|
||||||
|
Present bool
|
||||||
|
}
|
||||||
|
|
||||||
var backoffSchedule = []time.Duration{
|
var backoffSchedule = []time.Duration{
|
||||||
4 * time.Second,
|
4 * time.Second,
|
||||||
8 * time.Second,
|
8 * time.Second,
|
||||||
|
@ -116,10 +122,10 @@ func (rc *Client) NewAuthenticatedClient(redditId, refreshToken, accessToken str
|
||||||
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
|
return &AuthenticatedClient{rc, redditId, refreshToken, accessToken}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) {
|
func (rc *Client) doRequest(r *Request) ([]byte, *RateLimitingInfo, error) {
|
||||||
req, err := r.HTTPRequest()
|
req, err := r.HTTPRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
|
req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer))
|
||||||
|
@ -134,33 +140,30 @@ func (rc *Client) doRequest(r *Request) ([]byte, float64, int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||||
return nil, 0, 0, ErrTimeout
|
return nil, nil, ErrTimeout
|
||||||
}
|
}
|
||||||
return nil, 0, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
remaining, err := strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64)
|
rli := &RateLimitingInfo{Present: false}
|
||||||
if err != nil {
|
if _, ok := resp.Header[RateLimitRemainingHeader]; ok {
|
||||||
remaining = 0
|
rli.Present = true
|
||||||
}
|
rli.Remaining, _ = strconv.ParseFloat(resp.Header.Get(RateLimitRemainingHeader), 64)
|
||||||
|
rli.Reset, _ = strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
|
||||||
reset, err := strconv.Atoi(resp.Header.Get(RateLimitResetHeader))
|
|
||||||
if err != nil {
|
|
||||||
reset = 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
return nil, remaining, reset, ServerError{resp.StatusCode}
|
return nil, rli, ServerError{resp.StatusCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
bb, err := ioutil.ReadAll(resp.Body)
|
bb, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
return nil, remaining, reset, err
|
return nil, rli, err
|
||||||
}
|
}
|
||||||
return bb, remaining, reset, nil
|
return bb, rli, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) {
|
||||||
|
@ -168,11 +171,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
||||||
return nil, ErrRateLimited
|
return nil, ErrRateLimited
|
||||||
}
|
}
|
||||||
|
|
||||||
bb, remaining, reset, err := rac.doRequest(r)
|
bb, rli, err := rac.doRequest(r)
|
||||||
|
|
||||||
if err == nil && remaining <= RequestRemainingBuffer {
|
|
||||||
rac.markRateLimited(remaining, time.Duration(reset)*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && r.retry {
|
if err != nil && r.retry {
|
||||||
for _, backoff := range backoffSchedule {
|
for _, backoff := range backoffSchedule {
|
||||||
|
@ -180,7 +179,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
||||||
|
|
||||||
time.AfterFunc(backoff, func() {
|
time.AfterFunc(backoff, func() {
|
||||||
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
_ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1)
|
||||||
bb, remaining, reset, err = rac.doRequest(r)
|
bb, rli, err = rac.doRequest(r)
|
||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -192,6 +191,11 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err == nil && rli.Present && rli.Remaining <= RequestRemainingBuffer {
|
||||||
|
_ = rac.statsd.Incr("reddit.api.ratelimit", r.tags, 0.1)
|
||||||
|
rac.markRateLimited(rli.Remaining, time.Duration(rli.Reset)*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
_ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1)
|
||||||
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
if strings.Contains(err.Error(), "http2: timeout awaiting response headers") {
|
||||||
|
|
Loading…
Reference in a new issue