diff --git a/get_http.go b/get_http.go index 5e8a8cfd7..618a411f9 100644 --- a/get_http.go +++ b/get_http.go @@ -9,7 +9,6 @@ import ( "net/url" "os" "path/filepath" - "strconv" "strings" safetemp "github.com/hashicorp/go-safetemp" @@ -89,7 +88,7 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { } if g.Header != nil { - req.Header = g.Header + req.Header = g.Header.Clone() } resp, err := g.Client.Do(req) @@ -131,6 +130,12 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { return g.getSubdir(ctx, dst, source, subDir) } +// GetFile fetches the file from src and stores it at dst. +// If the server supports Accept-Ranges, HttpGetter will attempt a range +// request. This means it is the caller's responsibility to ensure that an +// older version of the destination file does not exist, else it will be either +// falsely identified as already downloaded, or corrupted with extra bytes +// appended. func (g *HttpGetter) GetFile(dst string, src *url.URL) error { ctx := g.Context() if g.Netrc { @@ -139,7 +144,6 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { return err } } - // Create all the parent directories if needed if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { return err @@ -165,21 +169,20 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { return err } if g.Header != nil { - req.Header = g.Header + req.Header = g.Header.Clone() } headResp, err := g.Client.Do(req) - if err == nil && headResp != nil { + if err == nil { headResp.Body.Close() if headResp.StatusCode == 200 { // If the HEAD request succeeded, then attempt to set the range // query if we can. 
- if headResp.Header.Get("Accept-Ranges") == "bytes" { + if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { if fi, err := f.Stat(); err == nil { - if _, err = f.Seek(0, os.SEEK_END); err == nil { - req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) + if _, err = f.Seek(0, io.SeekEnd); err == nil { currentFileSize = fi.Size() - totalFileSize, _ := strconv.ParseInt(headResp.Header.Get("Content-Length"), 10, 64) - if currentFileSize >= totalFileSize { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) + if currentFileSize >= headResp.ContentLength { // file already present return nil } diff --git a/get_http_test.go b/get_http_test.go index 9424e614f..6d7fb90ac 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -222,6 +222,57 @@ func TestHttpGetter_resume(t *testing.T) { } } +// The server may advertise Accept-Ranges: bytes, but report no size for the requested object +func TestHttpGetter_resumeNoRange(t *testing.T) { + load := []byte(testHttpMetaStr) + sha := sha256.New() + if n, err := sha.Write(load); n != len(load) || err != nil { + t.Fatalf("sha write failed: %d, %s", n, err) + } + checksum := hex.EncodeToString(sha.Sum(nil)) + downloadFrom := len(load) / 2 + + ln := testHttpServer(t) + defer ln.Close() + + dst := tempDir(t) + defer os.RemoveAll(dst) + + dst = filepath.Join(dst, "..", "range") + f, err := os.Create(dst) + if err != nil { + t.Fatalf("create: %v", err) + } + if n, err := f.Write(load[:downloadFrom]); n != downloadFrom || err != nil { + t.Fatalf("partial file write failed: %d, %s", n, err) + } + if err := f.Close(); err != nil { + t.Fatalf("close failed: %s", err) + } + + u := url.URL{ + Scheme: "http", + Host: ln.Addr().String(), + Path: "/no-range", + RawQuery: "checksum=" + checksum, + } + t.Logf("url: %s", u.String()) + + // Finish getting it! 
+ if err := GetFile(dst, u.String()); err != nil { + t.Fatalf("finishing download should not error: %v", err) + } + + b, err := ioutil.ReadFile(dst) + if err != nil { + t.Fatalf("readfile failed: %v", err) + } + + if string(b) != string(load) { + t.Fatalf("file differs: got:\n%s\n expected:\n%s\n", string(b), string(load)) + } +} + func TestHttpGetter_file(t *testing.T) { ln := testHttpServer(t) defer ln.Close() @@ -351,6 +402,7 @@ func testHttpServer(t *testing.T) net.Listener { mux.HandleFunc("/meta-subdir", testHttpHandlerMetaSubdir) mux.HandleFunc("/meta-subdir-glob", testHttpHandlerMetaSubdirGlob) mux.HandleFunc("/range", testHttpHandlerRange) + mux.HandleFunc("/no-range", testHttpHandlerNoRange) var server http.Server server.Handler = mux @@ -428,6 +480,21 @@ func testHttpHandlerRange(w http.ResponseWriter, r *http.Request) { } } +func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) { + load := []byte(testHttpMetaStr) + switch r.Method { + case "HEAD": + // advertise byte-range support, but set no Content-Length: the object size isn't known + w.Header().Add("accept-ranges", "bytes") + default: + if r.Header.Get("Range") != "" { + http.Error(w, "range not supported", http.StatusBadRequest) + return // http.Error does not end the request; stop before writing the payload + } + w.Write(load) + } +} + const testHttpMetaStr = `