diff --git a/internal/reddit/client.go b/internal/reddit/client.go index 8accdff..cbb792b 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -22,6 +22,12 @@ type Client struct { statsd statsd.ClientInterface } +var backoffSchedule = []time.Duration{ + 4 * time.Second, + 8 * time.Second, + 16 * time.Second, +} + func SplitID(id string) (string, string) { if parts := strings.Split(id, "_"); len(parts) == 2 { return parts[0], parts[1] @@ -93,21 +99,23 @@ func (rc *Client) NewAuthenticatedClient(refreshToken, accessToken string) *Auth return &AuthenticatedClient{rc, refreshToken, accessToken, nil} } -func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { +func (rc *Client) doRequest(r *Request) ([]byte, error) { req, err := r.HTTPRequest() if err != nil { return nil, err } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), rac.tracer)) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), rc.tracer)) start := time.Now() - resp, err := rac.client.Do(req) - _ = rac.statsd.Incr("reddit.api.calls", r.tags, 0.1) - _ = rac.statsd.Histogram("reddit.api.latency", float64(time.Since(start).Milliseconds()), r.tags, 0.1) + + resp, err := rc.client.Do(req) + + _ = rc.statsd.Incr("reddit.api.calls", r.tags, 0.1) + _ = rc.statsd.Histogram("reddit.api.latency", float64(time.Since(start).Milliseconds()), r.tags, 0.1) if err != nil { - _ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) + _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { return nil, ErrTimeout } @@ -115,15 +123,45 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in } defer resp.Body.Close() - bb, err := ioutil.ReadAll(resp.Body) - if err != nil { - _ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, err + if resp.StatusCode != 200 { + _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) + return nil, ServerError{resp.StatusCode} } - if resp.StatusCode != 200 { + bb, err := ioutil.ReadAll(resp.Body) + if err != nil { + _ = rc.statsd.Incr("reddit.api.errors", r.tags, 0.1) + return nil, err + } + return bb, nil +} + +func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty interface{}) (interface{}, error) { + bb, err := rac.doRequest(r) + if err != nil && r.retry { + for _, backoff := range backoffSchedule { + done := make(chan struct{}) + + time.AfterFunc(backoff, func() { + _ = rac.statsd.Incr("reddit.api.retries", r.tags, 0.1) + bb, err = rac.doRequest(r) + done <- struct{}{} + }) + + <-done + + if err == nil { + break + } + } + } + + if err != nil { _ = rac.statsd.Incr("reddit.api.errors", r.tags, 0.1) - return nil, ServerError{resp.StatusCode} + if strings.Contains(err.Error(), "http2: timeout awaiting response headers") { + return nil, ErrTimeout + } + return nil, err } if r.emptyResponseBytes > 0 && len(bb) == r.emptyResponseBytes { diff --git a/internal/reddit/request.go b/internal/reddit/request.go index 0e233b5..584bee9 100644 --- a/internal/reddit/request.go +++ b/internal/reddit/request.go @@ -19,12 +19,27 @@ type Request struct { auth string tags []string emptyResponseBytes int + retry bool } type RequestOption func(*Request) func NewRequest(opts ...RequestOption) *Request { - req := &Request{url.Values{}, url.Values{}, "GET", "", "", "", nil, 0} + req := &Request{ + body: url.Values{}, + query: url.Values{}, + method: "GET", + url: "", + + token: "", + auth: "", + + tags: nil, + + emptyResponseBytes: 0, + retry: true, + } + for _, opt := range opts { opt(req) } @@ -101,3 +116,9 @@ func WithEmptyResponseBytes(bytes int) RequestOption { req.emptyResponseBytes = bytes } } + +func WithRetry(retry bool) RequestOption { + return func(req *Request) { + req.retry = retry + } +}