Skip to content

Commit

Permalink
Adding gh release download for .zip and .tar.gz
Browse files Browse the repository at this point in the history
Co-authored-by: Mislav Marohnić <[email protected]>
  • Loading branch information
lpessoa and mislav committed Nov 30, 2021
1 parent c987c57 commit 8058c4e
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 16 deletions.
90 changes: 76 additions & 14 deletions pkg/cmd/release/download/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package download

import (
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
Expand All @@ -27,6 +29,8 @@ type DownloadOptions struct {

// maximum number of simultaneous downloads
Concurrency int

ArchiveType string
}

func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobra.Command {
Expand All @@ -47,12 +51,15 @@ func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobr
Example: heredoc.Doc(`
# download all assets from a specific release
$ gh release download v1.2.3
# download only Debian packages for the latest release
$ gh release download --pattern '*.deb'
# specify multiple file patterns
$ gh release download -p '*.deb' -p '*.rpm'
# download the archive of the source code for a release
$ gh release download v1.2.3 --archive=zip
`),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
Expand All @@ -67,6 +74,11 @@ func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobr
opts.TagName = args[0]
}

// check archive type option validity
if err := checkArchiveTypeOption(opts); err != nil {
return err
}

opts.Concurrency = 5

if runF != nil {
Expand All @@ -78,10 +90,30 @@ func NewCmdDownload(f *cmdutil.Factory, runF func(*DownloadOptions) error) *cobr

cmd.Flags().StringVarP(&opts.Destination, "dir", "D", ".", "The directory to download files into")
cmd.Flags().StringArrayVarP(&opts.FilePatterns, "pattern", "p", nil, "Download only assets that match a glob pattern")
cmd.Flags().StringVarP(&opts.ArchiveType, "archive", "A", "", "Download the source code archive in the specified `format` (zip or tar.gz)")

return cmd
}

func checkArchiveTypeOption(opts *DownloadOptions) error {
if len(opts.ArchiveType) == 0 {
return nil
}

if err := cmdutil.MutuallyExclusive(
"specify only one of '--pattern' or '--archive'",
true, // ArchiveType len > 0
len(opts.FilePatterns) > 0,
); err != nil {
return err
}

if opts.ArchiveType != "zip" && opts.ArchiveType != "tar.gz" {
return cmdutil.FlagErrorf("the value for `--archive` must be one of \"zip\" or \"tar.gz\"")
}
return nil
}

func downloadRun(opts *DownloadOptions) error {
httpClient, err := opts.HttpClient()
if err != nil {
Expand All @@ -93,8 +125,10 @@ func downloadRun(opts *DownloadOptions) error {
return err
}

var release *shared.Release
opts.IO.StartProgressIndicator()
defer opts.IO.StopProgressIndicator()

var release *shared.Release
if opts.TagName == "" {
release, err = shared.FetchLatestRelease(httpClient, baseRepo)
if err != nil {
Expand All @@ -108,11 +142,22 @@ func downloadRun(opts *DownloadOptions) error {
}

var toDownload []shared.ReleaseAsset
for _, a := range release.Assets {
if len(opts.FilePatterns) > 0 && !matchAny(opts.FilePatterns, a.Name) {
continue
isArchive := false
if opts.ArchiveType != "" {
var archiveURL = release.ZipballURL
if opts.ArchiveType == "tar.gz" {
archiveURL = release.TarballURL
}
// create pseudo-Asset with no name and pointing to ZipBallURL or TarBallURL
toDownload = append(toDownload, shared.ReleaseAsset{APIURL: archiveURL})
isArchive = true
} else {
for _, a := range release.Assets {
if len(opts.FilePatterns) > 0 && !matchAny(opts.FilePatterns, a.Name) {
continue
}
toDownload = append(toDownload, a)
}
toDownload = append(toDownload, a)
}

if len(toDownload) == 0 {
Expand All @@ -129,10 +174,7 @@ func downloadRun(opts *DownloadOptions) error {
}
}

opts.IO.StartProgressIndicator()
err = downloadAssets(httpClient, toDownload, opts.Destination, opts.Concurrency)
opts.IO.StopProgressIndicator()
return err
return downloadAssets(httpClient, toDownload, opts.Destination, opts.Concurrency, isArchive)
}

func matchAny(patterns []string, name string) bool {
Expand All @@ -144,7 +186,7 @@ func matchAny(patterns []string, name string) bool {
return false
}

func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, destDir string, numWorkers int) error {
func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, destDir string, numWorkers int, isArchive bool) error {
if numWorkers == 0 {
return errors.New("the number of concurrent workers needs to be greater than 0")
}
Expand All @@ -159,7 +201,7 @@ func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, d
for w := 1; w <= numWorkers; w++ {
go func() {
for a := range jobs {
results <- downloadAsset(httpClient, a.APIURL, filepath.Join(destDir, a.Name))
results <- downloadAsset(httpClient, a.APIURL, destDir, a.Name, isArchive)
}
}()
}
Expand All @@ -179,13 +221,17 @@ func downloadAssets(httpClient *http.Client, toDownload []shared.ReleaseAsset, d
return downloadError
}

func downloadAsset(httpClient *http.Client, assetURL, destinationPath string) error {
func downloadAsset(httpClient *http.Client, assetURL, destinationDir string, fileName string, isArchive bool) error {
req, err := http.NewRequest("GET", assetURL, nil)
if err != nil {
return err
}

req.Header.Set("Accept", "application/octet-stream")
// adding application/json to Accept header due to a bug in the zipball/tarball API endpoint that makes it mandatory
if isArchive {
req.Header.Set("Accept", "application/octet-stream, application/json")
}

resp, err := httpClient.Do(req)
if err != nil {
Expand All @@ -197,6 +243,22 @@ func downloadAsset(httpClient *http.Client, assetURL, destinationPath string) er
return api.HandleHTTPError(resp)
}

var destinationPath = filepath.Join(destinationDir, fileName)

if len(fileName) == 0 {
contentDisposition := resp.Header.Get("Content-Disposition")

_, params, err := mime.ParseMediaType(contentDisposition)
if err != nil {
return fmt.Errorf("unable to parse file name of archive: %w", err)
}
if serverFileName, ok := params["filename"]; ok {
destinationPath = filepath.Join(destinationDir, serverFileName)
} else {
return errors.New("unable to determine file name of archive")
}
}

f, err := os.OpenFile(destinationPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
return err
Expand Down
85 changes: 83 additions & 2 deletions pkg/cmd/release/download/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,36 @@ func Test_NewCmdDownload(t *testing.T) {
Concurrency: 5,
},
},
{
name: "download archive with valid option",
args: "v1.2.3 -A zip",
isTTY: true,
want: DownloadOptions{
TagName: "v1.2.3",
FilePatterns: []string(nil),
Destination: ".",
ArchiveType: "zip",
Concurrency: 5,
},
},
{
name: "no arguments",
args: "",
isTTY: true,
wantErr: "the '--pattern' flag is required when downloading the latest release",
},
{
name: "simultaneous pattern and archive arguments",
args: "-p * -A zip",
isTTY: true,
wantErr: "specify only one of '--pattern' or '--archive'",
},
{
name: "invalid archive argument",
args: "v1.2.3 -A abc",
isTTY: true,
wantErr: "the value for `--archive` must be one of \"zip\" or \"tar.gz\"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -184,6 +208,36 @@ func Test_downloadRun(t *testing.T) {
wantStderr: ``,
wantErr: "no assets match the file pattern",
},
{
name: "download archive in zip format into destination directory",
isTTY: true,
opts: DownloadOptions{
TagName: "v1.2.3",
ArchiveType: "zip",
Destination: "tmp/packages",
Concurrency: 2,
},
wantStdout: ``,
wantStderr: ``,
wantFiles: []string{
"tmp/packages/zipball.zip",
},
},
{
name: "download archive in `tar.gz` format into destination directory",
isTTY: true,
opts: DownloadOptions{
TagName: "v1.2.3",
ArchiveType: "tar.gz",
Destination: "tmp/packages",
Concurrency: 2,
},
wantStdout: ``,
wantStderr: ``,
wantFiles: []string{
"tmp/packages/tarball.tgz",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -204,12 +258,34 @@ func Test_downloadRun(t *testing.T) {
"url": "https://api.github.com/assets/3456" },
{ "name": "linux.tgz", "size": 56,
"url": "https://api.github.com/assets/5678" }
]
],
"tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3",
"zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3"
}`))
fakeHTTP.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`))
fakeHTTP.Register(httpmock.REST("GET", "assets/3456"), httpmock.StringResponse(`3456`))
fakeHTTP.Register(httpmock.REST("GET", "assets/5678"), httpmock.StringResponse(`5678`))

fakeHTTP.Register(
httpmock.REST(
"GET",
"repos/OWNER/REPO/tarball/v1.2.3",
),
httpmock.WithHeader(
httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=tarball.tgz",
),
)

fakeHTTP.Register(
httpmock.REST(
"GET",
"repos/OWNER/REPO/zipball/v1.2.3",
),
httpmock.WithHeader(
httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=zipball.zip",
),
)

tt.opts.IO = io
tt.opts.HttpClient = func() (*http.Client, error) {
return &http.Client{Transport: fakeHTTP}, nil
Expand All @@ -226,7 +302,12 @@ func Test_downloadRun(t *testing.T) {
require.NoError(t, err)
}

assert.Equal(t, "application/octet-stream", fakeHTTP.Requests[1].Header.Get("Accept"))
var expectedAcceptHeader = "application/octet-stream"
if len(tt.opts.ArchiveType) > 0 {
expectedAcceptHeader = "application/octet-stream, application/json"
}

assert.Equal(t, expectedAcceptHeader, fakeHTTP.Requests[1].Header.Get("Accept"))

assert.Equal(t, tt.wantStdout, stdout.String())
assert.Equal(t, tt.wantStderr, stderr.String())
Expand Down
11 changes: 11 additions & 0 deletions pkg/httpmock/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ func StringResponse(body string) Responder {
}
}

func WithHeader(responder Responder, header string, value string) Responder {
return func(req *http.Request) (*http.Response, error) {
resp, _ := responder(req)
if resp.Header == nil {
resp.Header = make(http.Header)
}
resp.Header.Set(header, value)
return resp, nil
}
}

func StatusStringResponse(status int, body string) Responder {
return func(req *http.Request) (*http.Response, error) {
return httpResponse(status, req, bytes.NewBufferString(body)), nil
Expand Down

0 comments on commit 8058c4e

Please sign in to comment.