From cf40a2256bd3186122d45694416b2c509caef36d Mon Sep 17 00:00:00 2001 From: Chris Marchesi Date: Wed, 7 Nov 2018 17:40:26 -0800 Subject: [PATCH] HttpGetter: Add ability to set headers This adds the ability to set headers that will be sent out on every request of a particular HttpGetter. This is useful in situations where query parameters are not suitable, such as when headers are explicitly expected, or when one wants to move information off the query string that would be at risk of possibly being exposed in logs or error messages. Fixes #71. --- get_http.go | 24 ++++++++++++++++++++++-- get_http_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/get_http.go b/get_http.go index d2e28796d..f66f4e59f 100644 --- a/get_http.go +++ b/get_http.go @@ -41,6 +41,12 @@ type HttpGetter struct { // Client is the http.Client to use for Get requests. // This defaults to a cleanhttp.DefaultClient if left unset. Client *http.Client + + // Header contains optional request header fields that should be included + // with every HTTP request. Note that the zero value of this field is nil, + // and as such it needs to be initialized before use, via something like + // make(http.Header). + Header http.Header } func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { @@ -72,10 +78,17 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { u.RawQuery = q.Encode() // Get the URL - resp, err := g.Client.Get(u.String()) + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return err + } + + req.Header = g.Header + resp, err := g.Client.Do(req) if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("bad response code: %d", resp.StatusCode) @@ -118,10 +131,17 @@ func (g *HttpGetter) GetFile(dst string, u *url.URL) error { g.Client = httpClient } - resp, err := g.Client.Get(u.String()) + req, err := http.NewRequest("GET", u.String(), nil) if err != nil { return err } + + req.Header = g.Header + resp, err := g.Client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() if resp.StatusCode != 200 { return fmt.Errorf("bad response code: %d", resp.StatusCode) diff --git a/get_http_test.go b/get_http_test.go index c9ccb77ed..0643932ae 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -40,6 +40,34 @@ func TestHttpGetter_header(t *testing.T) { } } +func TestHttpGetter_requestHeader(t *testing.T) { + ln := testHttpServer(t) + defer ln.Close() + + g := new(HttpGetter) + g.Header = make(http.Header) + g.Header.Add("X-Foobar", "foobar") + dst := tempDir(t) + defer os.RemoveAll(dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/expect-header" + u.RawQuery = "expected=X-Foobar" + + // Get it! + if err := g.GetFile(dst, &u); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + if _, err := os.Stat(dst); err != nil { + t.Fatalf("err: %s", err) + } + assertContents(t, dst, "Hello\n") +} + func TestHttpGetter_meta(t *testing.T) { ln := testHttpServer(t) defer ln.Close() @@ -255,6 +283,7 @@ func testHttpServer(t *testing.T) net.Listener { } mux := http.NewServeMux() + mux.HandleFunc("/expect-header", testHttpHandlerExpectHeader) mux.HandleFunc("/file", testHttpHandlerFile) mux.HandleFunc("/header", testHttpHandlerHeader) mux.HandleFunc("/meta", testHttpHandlerMeta) @@ -269,6 +298,17 @@ func testHttpServer(t *testing.T) net.Listener { return ln } +func testHttpHandlerExpectHeader(w http.ResponseWriter, r *http.Request) { + if expected, ok := r.URL.Query()["expected"]; ok { + if r.Header.Get(expected[0]) != "" { + w.Write([]byte("Hello\n")) + return + } + } + + w.WriteHeader(400) +} + func testHttpHandlerFile(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello\n")) }