diff --git a/github/github.go b/github/github.go index 424ed1d..e9404b5 100644 --- a/github/github.go +++ b/github/github.go @@ -86,8 +86,9 @@ type Client struct { // User agent used when communicating with the GitHub API. UserAgent string - rateMu sync.Mutex - rate Rate // Rate limit for the client as determined by the most recent API call. + rateMu sync.Mutex + rateLimits [categories]Rate // Rate limits for the client as determined by the most recent API calls. + mostRecent rateLimitCategory // Services used for talking to different parts of the GitHub API. Activity *ActivityService @@ -323,7 +324,7 @@ func parseRate(r *http.Response) Rate { // current rate. func (c *Client) Rate() Rate { c.rateMu.Lock() - rate := c.rate + rate := c.rateLimits[c.mostRecent] c.rateMu.Unlock() return rate } @@ -334,6 +335,8 @@ func (c *Client) Rate() Rate { // interface, the raw response body will be written to v, without attempting to // first decode it. func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { + rateLimitCategory := category(req.URL.Path) + resp, err := c.client.Do(req) if err != nil { return nil, err @@ -348,7 +351,8 @@ func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { response := newResponse(resp) c.rateMu.Lock() - c.rate = response.Rate + c.rateLimits[rateLimitCategory] = response.Rate + c.mostRecent = rateLimitCategory c.rateMu.Unlock() err = CheckResponse(resp) @@ -528,6 +532,8 @@ type RateLimits struct { // The rate limit for non-search API requests. Unauthenticated // requests are limited to 60 per hour. Authenticated requests are // limited to 5,000 per hour. + // + // GitHub API docs: https://developer.github.com/v3/#rate-limiting Core *Rate `json:"core"` // The rate limit for search API requests. Unauthenticated requests @@ -542,6 +548,25 @@ func (r RateLimits) String() string { return Stringify(r) } +type rateLimitCategory uint8 + +const ( + coreCategory rateLimitCategory = iota + searchCategory + + categories // An array of this length will be able to contain all rate limit categories. +) + +// category returns the rate limit category of the endpoint, determined by Request.URL.Path. +func category(path string) rateLimitCategory { + switch { + default: + return coreCategory + case strings.HasPrefix(path, "/search/"): + return searchCategory + } +} + // Deprecated: RateLimit is deprecated, use RateLimits instead. func (c *Client) RateLimit() (*Rate, *Response, error) { limits, resp, err := c.RateLimits() @@ -567,6 +592,17 @@ func (c *Client) RateLimits() (*RateLimits, *Response, error) { return nil, nil, err } + if response.Resources != nil { + c.rateMu.Lock() + if response.Resources.Core != nil { + c.rateLimits[coreCategory] = *response.Resources.Core + } + if response.Resources.Search != nil { + c.rateLimits[searchCategory] = *response.Resources.Search + } + c.rateMu.Unlock() + } + return response.Resources, resp, err } diff --git a/github/github_test.go b/github/github_test.go index 06abcc5..867d23a 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -163,6 +163,13 @@ func TestNewClient(t *testing.T) { } } +// Ensure that length of Client.rateLimits is the same as number of fields in RateLimits struct. +func TestClient_rateLimits(t *testing.T) { + if got, want := len(Client{}.rateLimits), reflect.TypeOf(RateLimits{}).NumField(); got != want { + t.Errorf("len(Client{}.rateLimits) is %v, want %v", got, want) + } +} + func TestNewRequest(t *testing.T) { c := NewClient(nil)