diff --git a/github/github.go b/github/github.go index 2321ee7..9eb99b7 100644 --- a/github/github.go +++ b/github/github.go @@ -68,6 +68,8 @@ const ( type Client struct { // HTTP client used to communicate with the API. client *http.Client + // clientMu protects the client during calls that modify the CheckRedirect func. + clientMu sync.Mutex // Base URL for API requests. Defaults to the public GitHub API, but can be // set to a domain endpoint to use with GitHub Enterprise. BaseURL should diff --git a/github/repos_releases.go b/github/repos_releases.go index 9f6133d..0eac47b 100644 --- a/github/repos_releases.go +++ b/github/repos_releases.go @@ -10,8 +10,10 @@ import ( "fmt" "io" "mime" + "net/http" "os" "path/filepath" + "strings" ) // RepositoryRelease represents a GitHub release in a repository. @@ -213,27 +215,43 @@ func (s *RepositoriesService) GetReleaseAsset(owner, repo string, id int) (*Rele return asset, resp, err } -// DownloadReleaseAsset downloads a release asset. +// DownloadReleaseAsset downloads a release asset or returns a redirect URL. // // DownloadReleaseAsset returns an io.ReadCloser that reads the contents of the // specified release asset. It is the caller's responsibility to close the ReadCloser. +// If a redirect is returned, the redirect URL will be returned as a string instead +// of the io.ReadCloser. Exactly one of rc and redirectURL will be zero. // // GitHub API docs : http://developer.github.com/v3/repos/releases/#get-a-single-release-asset -func (s *RepositoriesService) DownloadReleaseAsset(owner, repo string, id int) (io.ReadCloser, error) { +func (s *RepositoriesService) DownloadReleaseAsset(owner, repo string, id int) (rc io.ReadCloser, redirectURL string, err error) { u := fmt.Sprintf("repos/%s/%s/releases/assets/%d", owner, repo, id) req, err := s.client.NewRequest("GET", u, nil) if err != nil { - return nil, err + return nil, "", err } req.Header.Set("Accept", defaultMediaType) + s.client.clientMu.Lock() + defer s.client.clientMu.Unlock() + + var loc string + saveRedirect := s.client.client.CheckRedirect + s.client.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + loc = req.URL.String() + return errors.New("disable redirect") + } + defer func() { s.client.client.CheckRedirect = saveRedirect }() + resp, err := s.client.client.Do(req) if err != nil { - return nil, err + if !strings.Contains(err.Error(), "disable redirect") { + return nil, "", err + } + return nil, loc, nil } - return resp.Body, nil + return resp.Body, "", nil } // EditReleaseAsset edits a repository release asset. diff --git a/github/repos_releases_test.go b/github/repos_releases_test.go index 5b0f094..532039b 100644 --- a/github/repos_releases_test.go +++ b/github/repos_releases_test.go @@ -13,6 +13,7 @@ import ( "net/http" "os" "reflect" + "strings" "testing" ) @@ -218,7 +219,7 @@ func TestRepositoriesService_DownloadReleaseAsset_Stream(t *testing.T) { fmt.Fprint(w, "Hello World") }) - reader, err := client.Repositories.DownloadReleaseAsset("o", "r", 1) + reader, _, err := client.Repositories.DownloadReleaseAsset("o", "r", 1) if err != nil { t.Errorf("Repositories.DownloadReleaseAsset returned error: %v", err) } @@ -239,28 +240,16 @@ func TestRepositoriesService_DownloadReleaseAsset_Redirect(t *testing.T) { mux.HandleFunc("/repos/o/r/releases/assets/1", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") testHeader(t, r, "Accept", defaultMediaType) - w.Header().Set("Location", server.URL+"/github-cloud/releases/1/hello-world.txt") - w.WriteHeader(http.StatusFound) + http.Redirect(w, r, "/yo", http.StatusFound) }) - mux.HandleFunc("/github-cloud/releases/1/hello-world.txt", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - w.Header().Set("Content-Type", "application/octet-stream") - w.Header().Set("Content-Disposition", "attachment; filename=hello-world.txt") - fmt.Fprint(w, "Hello World") - }) - - reader, err := client.Repositories.DownloadReleaseAsset("o", "r", 1) + _, got, err := client.Repositories.DownloadReleaseAsset("o", "r", 1) if err != nil { t.Errorf("Repositories.DownloadReleaseAsset returned error: %v", err) } - want := []byte("Hello World") - content, err := ioutil.ReadAll(reader) - if err != nil { - t.Errorf("Repositories.DownloadReleaseAsset returned bad reader: %v", err) - } - if !bytes.Equal(want, content) { - t.Errorf("Repositories.DownloadReleaseAsset returned %+v, want %+v", content, want) + want := "/yo" + if !strings.HasSuffix(got, want) { + t.Errorf("Repositories.DownloadReleaseAsset returned %+v, want %+v", got, want) } }