Skip to content

Commit

Permalink
allow to track download progress by passing options
Browse files Browse the repository at this point in the history
Also set default Getters, Decompressors & Detectors in Configure func with the idea to add a validate func that will validate if src url will work in the future
  • Loading branch information
azr committed Dec 14, 2018
1 parent 17c7d12 commit 572fb75
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 54 deletions.
30 changes: 10 additions & 20 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ type Client struct {
//
// WARNING: deprecated. If Mode is set, that will take precedence.
Dir bool

// ProgressListener allows to track file downloads.
// By default a no op progress listener is used.
ProgressListener ProgressTracker

Options []ClientOption
}

// Get downloads the configured source to the destination.
Expand All @@ -76,18 +82,7 @@ func (c *Client) Get() error {
}
}

// Default decompressor value
decompressors := c.Decompressors
if decompressors == nil {
decompressors = Decompressors
}

// Detect the URL. This is safe if it is already detected.
detectors := c.Detectors
if detectors == nil {
detectors = Detectors
}
src, err := Detect(c.Src, c.Pwd, detectors)
src, err := Detect(c.Src, c.Pwd, c.Detectors)
if err != nil {
return err
}
Expand Down Expand Up @@ -119,12 +114,7 @@ func (c *Client) Get() error {
force = u.Scheme
}

getters := c.Getters
if getters == nil {
getters = Getters
}

g, ok := getters[force]
g, ok := c.Getters[force]
if !ok {
return fmt.Errorf(
"download not supported for scheme '%s'", force)
Expand All @@ -150,7 +140,7 @@ func (c *Client) Get() error {
if archiveV == "" {
// We don't appear to... but is it part of the filename?
matchingLen := 0
for k, _ := range decompressors {
for k := range c.Decompressors {
if strings.HasSuffix(u.Path, "."+k) && len(k) > matchingLen {
archiveV = k
matchingLen = len(k)
Expand All @@ -163,7 +153,7 @@ func (c *Client) Get() error {
// real path.
var decompressDst string
var decompressDir bool
decompressor := decompressors[archiveV]
decompressor := c.Decompressors[archiveV]
if decompressor != nil {
// Create a temporary directory to store our archive. We delete
// this at the end of everything.
Expand Down
18 changes: 18 additions & 0 deletions client_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,29 @@ type ClientOption func(*Client) error

// Configure configures a client with options.
func (c *Client) Configure(opts ...ClientOption) error {
c.Options = opts
c.ProgressListener = noopProgressListener
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
// Default decompressor values
if c.Decompressors == nil {
c.Decompressors = Decompressors
}
// Default detector values
if c.Detectors == nil {
c.Detectors = Detectors
}
// Default getter values
if c.Getters == nil {
c.Getters = Getters
}

for _, getter := range c.Getters {
getter.SetClient(c)
}
return nil
}
46 changes: 46 additions & 0 deletions client_option_progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package getter

import (
"io"
)

// WithProgress allows for a user to track
// the progress of a download.
// For example by displaying a progress bar with
// current download.
// Not all getters have progress support yet.
func WithProgress(pl ProgressTracker) func(*Client) error {
return func(c *Client) error {
c.ProgressListener = pl
return nil
}
}

// ProgressTracker allows to track the progress of downloads.
type ProgressTracker interface {
// TrackProgress should be called when
// a new object is being downloaded.
// src is the location the file is
// downloaded from.
// size is the total size in bytes,
// size can be zero if the file size
// is not known.
// stream is the file being downloaded, every
// written byte will add up to processed size.
//
// TrackProgress returns a ReadCloser that wraps the
// download in progress ( stream ).
// When the download is finished, body shall be closed.
TrackProgress(src string, size int64, stream io.ReadCloser) (body io.ReadCloser)
}

// NoopProgressListener is a progress listener
// that has no effect.
type NoopProgressListener struct{}

var noopProgressListener ProgressTracker = &NoopProgressListener{}

// TrackProgress is a no op
func (*NoopProgressListener) TrackProgress(_ string, _ int64, stream io.ReadCloser) io.ReadCloser {
return stream
}
5 changes: 4 additions & 1 deletion cmd/go-getter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ func main() {
Mode: mode,
}

if err := client.Configure(); err != nil {
log.Fatalf("Configure: %s", err)
}

if err := client.Get(); err != nil {
log.Fatalf("Error downloading: %s", err)
os.Exit(1)
}

log.Println("Success!")
Expand Down
26 changes: 14 additions & 12 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ type Getter interface {
// ClientMode returns the mode based on the given URL. This is used to
// allow clients to let the getters decide which mode to use.
ClientMode(*url.URL) (ClientMode, error)

// SetClient allows a getter to know it's client
// in order to access client's Get functions or
// progress tracking.
SetClient(*Client)
}

// Getters is the mapping of scheme to the Getter implementation that will
Expand Down Expand Up @@ -76,10 +81,9 @@ func init() {
// folder doesn't need to exist. It will be created if it doesn't exist.
func Get(dst, src string, opts ...ClientOption) error {
c := &Client{
Src: src,
Dst: dst,
Dir: true,
Getters: Getters,
Src: src,
Dst: dst,
Dir: true,
}
if err := c.Configure(opts...); err != nil {
return err
Expand All @@ -95,10 +99,9 @@ func Get(dst, src string, opts ...ClientOption) error {
// archive, it will be unpacked directly into dst.
func GetAny(dst, src string, opts ...ClientOption) error {
c := &Client{
Src: src,
Dst: dst,
Mode: ClientModeAny,
Getters: Getters,
Src: src,
Dst: dst,
Mode: ClientModeAny,
}
if err := c.Configure(opts...); err != nil {
return err
Expand All @@ -110,10 +113,9 @@ func GetAny(dst, src string, opts ...ClientOption) error {
// dst.
func GetFile(dst, src string, opts ...ClientOption) error {
c := &Client{
Src: src,
Dst: dst,
Dir: false,
Getters: Getters,
Src: src,
Dst: dst,
Dir: false,
}
if err := c.Configure(opts...); err != nil {
return err
Expand Down
9 changes: 9 additions & 0 deletions get_base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package getter

// getter is our base getter; it regroups
// fields all getters have in common.
type getter struct {
client *Client
}

func (g *getter) SetClient(c *Client) { g.client = c }
2 changes: 2 additions & 0 deletions get_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
// FileGetter is a Getter implementation that will download a module from
// a file scheme.
type FileGetter struct {
getter

// Copy, if set to true, will copy data instead of using a symlink
Copy bool
}
Expand Down
4 changes: 3 additions & 1 deletion get_git.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (

// GitGetter is a Getter implementation that will download a module from
// a git repository.
type GitGetter struct{}
type GitGetter struct {
getter
}

func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) {
return ClientModeDir, nil
Expand Down
4 changes: 3 additions & 1 deletion get_hg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (

// HgGetter is a Getter implementation that will download a module from
// a Mercurial repository.
type HgGetter struct{}
type HgGetter struct {
getter
}

func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) {
return ClientModeDir, nil
Expand Down
68 changes: 50 additions & 18 deletions get_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
//
// For file downloads, HTTP is used directly.
//
// The protocol for downloading a directory from an HTTP endpoing is as follows:
// The protocol for downloading a directory from an HTTP endpoint is as follows:
//
// An HTTP GET request is made to the URL with the additional GET parameter
// "terraform-get=1". This lets you handle that scenario specially if you
Expand All @@ -34,6 +34,8 @@ import (
// formed URL. The shorthand syntax of "github.com/foo/bar" or relative
// paths are not allowed.
type HttpGetter struct {
getter

// Netrc, if true, will lookup and use auth information found
// in the user's netrc file if available.
Netrc bool
Expand Down Expand Up @@ -112,52 +114,82 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
// into a temporary directory, then copy over the proper subdir.
source, subDir := SourceDirSubdir(source)
if subDir == "" {
return Get(dst, source)
return Get(dst, source, g.client.Options...)
}

// We have a subdir, time to jump some hoops
return g.getSubdir(dst, source, subDir)
}

func (g *HttpGetter) GetFile(dst string, u *url.URL) error {
func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
if g.Netrc {
// Add auth from netrc if we can
if err := addAuthFromNetrc(u); err != nil {
if err := addAuthFromNetrc(src); err != nil {
return err
}
}

// Create all the parent directories if needed
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
return err
}

f, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
if err != nil {
return err
}

if g.Client == nil {
g.Client = httpClient
}

req, err := http.NewRequest("GET", u.String(), nil)
var current int64

// We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't
// support HEAD requests, we just fall back to GET.
req, err := http.NewRequest("HEAD", src.String(), nil)
if err != nil {
return err
}
if g.Header != nil {
req.Header = g.Header
}
headResp, err := g.Client.Do(req)
if err == nil && headResp != nil {
if headResp.StatusCode == 200 {
// If the HEAD request succeeded, then attempt to set the range
// query if we can.
if headResp.Header.Get("Accept-Ranges") == "bytes" {
if fi, err := f.Stat(); err == nil {
if _, err = f.Seek(0, os.SEEK_END); err == nil {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
current = fi.Size()
}
}
}
}
headResp.Body.Close()
}
req.Method = "GET"

req.Header = g.Header
resp, err := g.Client.Do(req)
if err != nil {
return err
}

defer resp.Body.Close()
if resp.StatusCode != 200 {
resp.Body.Close()
return fmt.Errorf("bad response code: %d", resp.StatusCode)
}

// Create all the parent directories
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
return err
}

f, err := os.Create(dst)
if err != nil {
return err
}
println(g.client)
println(src.String())
// track download
body := g.client.ProgressListener.TrackProgress(src.String(), current+resp.ContentLength, resp.Body)
defer body.Close()

n, err := io.Copy(f, resp.Body)
n, err := io.Copy(f, body)
if err == nil && n < resp.ContentLength {
err = io.ErrShortWrite
}
Expand All @@ -179,7 +211,7 @@ func (g *HttpGetter) getSubdir(dst, source, subDir string) error {
defer tdcloser.Close()

// Download that into the given directory
if err := Get(td, source); err != nil {
if err := Get(td, source, g.client.Options...); err != nil {
return err
}

Expand Down
2 changes: 2 additions & 0 deletions get_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (

// MockGetter is an implementation of Getter that can be used for tests.
type MockGetter struct {
getter

// Proxy, if set, will be called after recording the calls below.
// If it isn't set, then the *Err values will be returned.
Proxy Getter
Expand Down
4 changes: 3 additions & 1 deletion get_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (

// S3Getter is a Getter implementation that will download a module from
// a S3 bucket.
type S3Getter struct{}
type S3Getter struct {
getter
}

func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
// Parse URL
Expand Down

0 comments on commit 572fb75

Please sign in to comment.