Skip to content

Commit 217adae

Browse files
committed
RateLimitRoundtripper: Fix mutex leak and not respecting context cancellation
1 parent d63ad78 commit 217adae

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

github/transport.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package github
22

33
import (
44
"bytes"
5+
"context"
56
"io"
67
"log"
78
"net/http"
@@ -65,7 +66,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
6566
// for read and write requests. See isWriteMethod for the distinction between them.
6667
if rlt.nextRequestDelay > 0 {
6768
log.Printf("[DEBUG] Sleeping %s between operations", rlt.nextRequestDelay)
68-
time.Sleep(rlt.nextRequestDelay)
69+
sleep(req.Context(), rlt.nextRequestDelay)
6970
}
7071

7172
rlt.nextRequestDelay = rlt.calculateNextDelay(req.Method)
@@ -81,6 +82,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
8182
// See https://github.com/google/go-github/pull/986
8283
r1, r2, err := drainBody(resp.Body)
8384
if err != nil {
85+
rlt.smartLock(false)
8486
return nil, err
8587
}
8688
resp.Body = r1
@@ -93,7 +95,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
9395
retryAfter := arlErr.GetRetryAfter()
9496
log.Printf("[DEBUG] Abuse detection mechanism triggered, sleeping for %s before retrying",
9597
retryAfter)
96-
time.Sleep(retryAfter)
98+
sleep(req.Context(), retryAfter)
9799
rlt.smartLock(false)
98100
return rlt.RoundTrip(req)
99101
}
@@ -103,7 +105,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
103105
retryAfter := time.Until(rlErr.Rate.Reset.Time)
104106
log.Printf("[DEBUG] Rate limit %d reached, sleeping for %s (until %s) before retrying",
105107
rlErr.Rate.Limit, retryAfter, time.Now().Add(retryAfter))
106-
time.Sleep(retryAfter)
108+
sleep(req.Context(), retryAfter)
107109
rlt.smartLock(false)
108110
return rlt.RoundTrip(req)
109111
}
@@ -113,6 +115,16 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
113115
return resp, nil
114116
}
115117

118+
func sleep(ctx context.Context, dur time.Duration) {
119+
t := time.NewTimer(dur)
120+
defer t.Stop()
121+
122+
select {
123+
case <-t.C:
124+
case <-ctx.Done():
125+
}
126+
}
127+
116128
// smartLock wraps the mutex locking system and performs its operation via a boolean input for locking and unlocking.
117129
// It also skips the locking when parallelRequests is set to true since, in this case, the lock is not needed.
118130
func (rlt *RateLimitTransport) smartLock(lock bool) {

github/transport_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package github
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"log"
@@ -159,6 +160,43 @@ func TestRateLimitTransport_abuseLimit_get(t *testing.T) {
159160
}
160161
}
161162

163+
func TestRateLimitTransport_abuseLimit_get_cancelled(t *testing.T) {
164+
ts := githubApiMock([]*mockResponse{
165+
{
166+
ExpectedUri: "/repos/test/blah",
167+
ResponseBody: `{
168+
"message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.",
169+
"documentation_url": "https://developer.github.com/v3/#abuse-rate-limits"
170+
}`,
171+
StatusCode: 403,
172+
ResponseHeaders: map[string]string{
173+
"Retry-After": "10",
174+
},
175+
},
176+
})
177+
defer ts.Close()
178+
179+
httpClient := http.DefaultClient
180+
httpClient.Transport = NewRateLimitTransport(http.DefaultTransport)
181+
182+
client := github.NewClient(httpClient)
183+
u, _ := url.Parse(ts.URL + "/")
184+
client.BaseURL = u
185+
186+
ctx := context.WithValue(context.Background(), ctxId, t.Name())
187+
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
188+
defer cancel()
189+
190+
start := time.Now()
191+
_, _, err := client.Repositories.Get(ctx, "test", "blah")
192+
if !errors.Is(err, context.DeadlineExceeded) {
193+
t.Fatalf("Expected context deadline exceeded, got: %v", err)
194+
}
195+
if time.Since(start) > time.Second {
196+
t.Fatalf("Waited for longer than expected: %s", time.Since(start))
197+
}
198+
}
199+
162200
func TestRateLimitTransport_abuseLimit_post(t *testing.T) {
163201
ts := githubApiMock([]*mockResponse{
164202
{

0 commit comments

Comments
 (0)