Skip to content

Commit

Permalink
HttpGetter: Add ability to set headers
Browse files Browse the repository at this point in the history
This adds the ability to set headers that will be sent out on every
request of a particular HttpGetter. This is useful in situations where
query parameters are not suitable, such as when headers are explicitly
expected, or when one wants to move information off the query string
that would be at risk of possibly being exposed in logs or error
messages.

Fixes hashicorp#71.
  • Loading branch information
vancluever committed Nov 8, 2018
1 parent 4bda8fa commit cf40a22
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
24 changes: 22 additions & 2 deletions get_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ type HttpGetter struct {
// Client is the http.Client to use for Get requests.
// This defaults to a cleanhttp.DefaultClient if left unset.
Client *http.Client

// Header contains optional request header fields that should be included
// with every HTTP request. Note that the zero value of this field is nil,
// and as such it needs to be initialized before use, via something like
// make(http.Header).
Header http.Header
}

func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) {
Expand Down Expand Up @@ -72,10 +78,17 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
u.RawQuery = q.Encode()

// Get the URL
resp, err := g.Client.Get(u.String())
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return err
}

req.Header = g.Header
resp, err := g.Client.Do(req)
if err != nil {
return err
}

defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("bad response code: %d", resp.StatusCode)
Expand Down Expand Up @@ -118,10 +131,17 @@ func (g *HttpGetter) GetFile(dst string, u *url.URL) error {
g.Client = httpClient
}

resp, err := g.Client.Get(u.String())
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return err
}

req.Header = g.Header
resp, err := g.Client.Do(req)
if err != nil {
return err
}

defer resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("bad response code: %d", resp.StatusCode)
Expand Down
40 changes: 40 additions & 0 deletions get_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,34 @@ func TestHttpGetter_header(t *testing.T) {
}
}

func TestHttpGetter_requestHeader(t *testing.T) {
ln := testHttpServer(t)
defer ln.Close()

g := new(HttpGetter)
g.Header = make(http.Header)
g.Header.Add("X-Foobar", "foobar")
dst := tempDir(t)
defer os.RemoveAll(dst)

var u url.URL
u.Scheme = "http"
u.Host = ln.Addr().String()
u.Path = "/expect-header"
u.RawQuery = "expected=X-Foobar"

// Get it!
if err := g.GetFile(dst, &u); err != nil {
t.Fatalf("err: %s", err)
}

// Verify the main file exists
if _, err := os.Stat(dst); err != nil {
t.Fatalf("err: %s", err)
}
assertContents(t, dst, "Hello\n")
}

func TestHttpGetter_meta(t *testing.T) {
ln := testHttpServer(t)
defer ln.Close()
Expand Down Expand Up @@ -255,6 +283,7 @@ func testHttpServer(t *testing.T) net.Listener {
}

mux := http.NewServeMux()
mux.HandleFunc("/expect-header", testHttpHandlerExpectHeader)
mux.HandleFunc("/file", testHttpHandlerFile)
mux.HandleFunc("/header", testHttpHandlerHeader)
mux.HandleFunc("/meta", testHttpHandlerMeta)
Expand All @@ -269,6 +298,17 @@ func testHttpServer(t *testing.T) net.Listener {
return ln
}

func testHttpHandlerExpectHeader(w http.ResponseWriter, r *http.Request) {
if expected, ok := r.URL.Query()["expected"]; ok {
if r.Header.Get(expected[0]) != "" {
w.Write([]byte("Hello\n"))
return
}
}

w.WriteHeader(400)
}

func testHttpHandlerFile(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello\n"))
}
Expand Down

0 comments on commit cf40a22

Please sign in to comment.