diff --git a/get_http.go b/get_http.go index d2e28796d..f66f4e59f 100644 --- a/get_http.go +++ b/get_http.go @@ -41,6 +41,12 @@ type HttpGetter struct { // Client is the http.Client to use for Get requests. // This defaults to a cleanhttp.DefaultClient if left unset. Client *http.Client + + // Header contains optional request header fields that should be included + // with every HTTP request. Note that the zero value of this field is nil, + // and as such it needs to be initialized before use, via something like + // make(http.Header). + Header http.Header } func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { @@ -72,10 +78,17 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { u.RawQuery = q.Encode() // Get the URL - resp, err := g.Client.Get(u.String()) + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return err + } + + req.Header = g.Header + resp, err := g.Client.Do(req) if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("bad response code: %d", resp.StatusCode) @@ -118,10 +131,17 @@ func (g *HttpGetter) GetFile(dst string, u *url.URL) error { g.Client = httpClient } - resp, err := g.Client.Get(u.String()) + req, err := http.NewRequest("GET", u.String(), nil) if err != nil { return err } + + req.Header = g.Header + resp, err := g.Client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() if resp.StatusCode != 200 { return fmt.Errorf("bad response code: %d", resp.StatusCode) diff --git a/get_http_test.go b/get_http_test.go index c9ccb77ed..0643932ae 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -40,6 +40,34 @@ func TestHttpGetter_header(t *testing.T) { } } +func TestHttpGetter_requestHeader(t *testing.T) { + ln := testHttpServer(t) + defer ln.Close() + + g := new(HttpGetter) + g.Header = make(http.Header) + g.Header.Add("X-Foobar", "foobar") + dst := tempDir(t) + defer os.RemoveAll(dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/expect-header" + u.RawQuery = "expected=X-Foobar" + + // Get it! + if err := g.GetFile(dst, &u); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + if _, err := os.Stat(dst); err != nil { + t.Fatalf("err: %s", err) + } + assertContents(t, dst, "Hello\n") +} + func TestHttpGetter_meta(t *testing.T) { ln := testHttpServer(t) defer ln.Close() @@ -255,6 +283,7 @@ func testHttpServer(t *testing.T) net.Listener { } mux := http.NewServeMux() + mux.HandleFunc("/expect-header", testHttpHandlerExpectHeader) mux.HandleFunc("/file", testHttpHandlerFile) mux.HandleFunc("/header", testHttpHandlerHeader) mux.HandleFunc("/meta", testHttpHandlerMeta) @@ -269,6 +298,17 @@ func testHttpServer(t *testing.T) net.Listener { return ln } +func testHttpHandlerExpectHeader(w http.ResponseWriter, r *http.Request) { + if expected, ok := r.URL.Query()["expected"]; ok { + if r.Header.Get(expected[0]) != "" { + w.Write([]byte("Hello\n")) + return + } + } + + w.WriteHeader(400) +} + func testHttpHandlerFile(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello\n")) }