diff --git a/internal/http/client.go b/internal/http/client.go index dfe43fd..45e293d 100644 --- a/internal/http/client.go +++ b/internal/http/client.go @@ -19,6 +19,15 @@ type OAuthTokenSourceCreator interface { CreateOAuth2TokenSource(ctx context.Context) (oauth2.TokenSource, error) } +// BytesBufferReadCloser allows direct access to the underlying buffer +// so that re-reading ops do not incur a performance penalty. +type BytesBufferReadCloser struct { + *bytes.Buffer +} + +// Close is no-op - avoids http.NewRequest wrapping the body in an io.NopCloser +func (BytesBufferReadCloser) Close() error { return nil } + // retryableAuthTransport wraps a http.RoundTripper and combines it with an OAuthTokenSourceCreator so // that any 401s cause a re-authentication and request retry type retryableAuthTransport struct { @@ -97,16 +106,26 @@ func (t *retryableAuthTransport) RoundTrip(req *http.Request) (*http.Response, e func backupRequestBody(req *http.Request) error { if req.Body != nil && req.GetBody == nil { - bodyBytes, err := io.ReadAll(req.Body) - ios.Close(req.Body) // Ensure the body is always closed - if err != nil { - return err - } + // TODO handle bytes.Reader and strings.Reader + // optimization - we can re-read bytes.Buffer over and over again + if r, ok := req.Body.(BytesBufferReadCloser); ok { + buf := r.Buffer + req.GetBody = func() (io.ReadCloser, error) { + return BytesBufferReadCloser{buf}, nil + } + req.Body, _ = req.GetBody() + } else { + bodyBytes, err := io.ReadAll(req.Body) + ios.Close(req.Body) // Ensure the body is always closed + if err != nil { + return err + } - req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(bytes.NewBuffer(bodyBytes)), nil + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewBuffer(bodyBytes)), nil + } + req.Body, _ = req.GetBody() } - req.Body, _ = req.GetBody() } return nil } diff --git a/internal/http/client_test.go b/internal/http/client_test.go index 48e18f9..41d1098 100644 --- a/internal/http/client_test.go +++ b/internal/http/client_test.go @@ -1,6 +1,7 @@ package http_test import ( + "bytes" "context" "github.com/cloudfoundry-community/go-cfclient/v3/internal/http" "github.com/cloudfoundry-community/go-cfclient/v3/testutil" @@ -32,12 +33,13 @@ func (ts *MockedOAuthTokenSource) Token() (*oauth2.Token, error) { } func TestOAuthSessionManager(t *testing.T) { + g := testutil.NewObjectJSONGenerator(1) serverURL := testutil.SetupMultiple([]testutil.MockRoute{ { - Method: "GET", + Method: "POST", Endpoint: "/v3/organizations", - Output: []string{"organizations[]"}, - Statuses: []int{200}, + Output: []string{"auth error", g.Organization().JSON}, + Statuses: []int{401, 201}, UserAgent: "Go-http-client/1.1", }, { @@ -65,14 +67,15 @@ func TestOAuthSessionManager(t *testing.T) { client, err := http.NewAuthenticatedClient(context.Background(), gohttp.DefaultClient, tokenSrcCreator) require.NoError(t, err) - resp, err := client.Get(serverURL + "/v3/organizations") + body := http.BytesBufferReadCloser{Buffer: bytes.NewBufferString(g.Organization().JSON)} + resp, err := client.Post(serverURL+"/v3/organizations", "application/json", body) require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode) + require.Equal(t, 201, resp.StatusCode) // to the caller the retry is transparent on 401 resp, err = client.Get(serverURL + "/v3/spaces") require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) - tokenSrcCreator.AssertNumberOfCalls(t, "CreateOAuth2TokenSource", 2) + tokenSrcCreator.AssertNumberOfCalls(t, "CreateOAuth2TokenSource", 3) } diff --git a/internal/http/request.go b/internal/http/request.go index 33060fd..2b320e3 100644 --- a/internal/http/request.go +++ b/internal/http/request.go @@ -57,5 +57,7 @@ func EncodeBody(obj any) (io.Reader, error) { if err := json.NewEncoder(buf).Encode(obj); err != nil { return nil, fmt.Errorf("error encoding object to JSON: %w", err) } - return buf, nil + return &BytesBufferReadCloser{ + buf, + }, nil } diff --git a/testutil/api_mock.go b/testutil/api_mock.go index c5bc69d..7e983fa 100644 --- a/testutil/api_mock.go +++ b/testutil/api_mock.go @@ -115,6 +115,7 @@ func SetupMultiple(mockEndpoints []MockRoute, t *testing.T) string { return status, singleOutput }) case "POST": + count := 0 r.Post(endpoint, func(res http.ResponseWriter, req *http.Request) (int, string) { testUserAgent(req.Header.Get("User-Agent"), userAgent, t) testQueryString(req.URL.RawQuery, queryString, t) @@ -122,7 +123,10 @@ func SetupMultiple(mockEndpoints []MockRoute, t *testing.T) string { if redirectLocation != "" { res.Header().Add("Location", redirectLocation) } - return status, output[0] + singleOutput := output[count] + status = statuses[count] + count++ + return status, singleOutput }) case "DELETE": r.Delete(endpoint, func(res http.ResponseWriter, req *http.Request) (int, string) {