diff --git a/internal/reddit/client.go b/internal/reddit/client.go index f10e949..40b6080 100644 --- a/internal/reddit/client.go +++ b/internal/reddit/client.go @@ -168,7 +168,7 @@ func (rac *AuthenticatedClient) request(r *Request, rh ResponseHandler, empty in bb, remaining, reset, err := rac.doRequest(r) if remaining <= RequestRemainingBuffer { - rac.markRateLimited(time.Duration(reset) * time.Second) + rac.markRateLimited(remaining, time.Duration(reset)*time.Second) } if err != nil && r.retry { @@ -218,22 +218,24 @@ func (rac *AuthenticatedClient) isRateLimited() (bool, error) { } key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) - res, err := rac.redis.Exists(context.Background(), key).Result() + _, err := rac.redis.Get(context.Background(), key).Result() - if err != nil { + if err == redis.Nil { + return false, nil + } else if err == nil { + return true, nil + } else { return false, err } - - return res > 0, nil } -func (rac *AuthenticatedClient) markRateLimited(duration time.Duration) error { +func (rac *AuthenticatedClient) markRateLimited(remaining int, duration time.Duration) error { if rac.redditId == SkipRateLimiting { return ErrRequiresRedditId } key := fmt.Sprintf("reddit:%s:ratelimited", rac.redditId) - _, err := rac.redis.SetEX(context.Background(), key, true, duration).Result() + _, err := rac.redis.SetEX(context.Background(), key, remaining, duration).Result() return err }