diff --git a/go.mod b/go.mod index bb04aa083..27bf55173 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kyma-project/cli.v3 go 1.21.7 require ( + github.com/gboddin/go-www-authenticate-parser v0.0.0-20230926203616-ec0b649bb077 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 8236a263c..21a1ebaab 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gboddin/go-www-authenticate-parser v0.0.0-20230926203616-ec0b649bb077 h1:JvEO7eltd2aCHF+ABLquTUziO7hzC6G7H3tgENYkDBc= +github.com/gboddin/go-www-authenticate-parser v0.0.0-20230926203616-ec0b649bb077/go.mod h1:RlYuEjNYq/NkhOCSkZGPKxP3dgZOBH94UwsQraDng8s= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/btp/client.go b/internal/btp/client.go new file mode 100644 index 000000000..b1d05c427 --- /dev/null +++ b/internal/btp/client.go @@ -0,0 +1,139 @@ +package btp + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + wwwAuthParser "github.com/gboddin/go-www-authenticate-parser" +) + +type LocalClient struct { + credentials *CISCredentials + cis *httpClient +} + +func NewLocalClient(credentials *CISCredentials, token *XSUAAToken) *LocalClient { + return &LocalClient{ + credentials: credentials, + cis: newHttpClient(token), + } +} + +type oauthTransport struct { + token *XSUAAToken +} + +func (t *oauthTransport) RoundTrip(r *http.Request) (*http.Response, error) { + r.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token.AccessToken)) + + return http.DefaultTransport.RoundTrip(r) +} + +type cisError struct { + Code int `json:"code"` + Message string `json:"message"` + Target string `json:"target"` + CorrelationID string `json:"correlationID"` +} + +type cisErrorResponse struct { + Error cisError `json:"error"` +} + +type requestOptions struct { + Body io.Reader + Headers map[string]string + Query map[string]string +} + +type httpClient struct { + client *http.Client +} + +func newHttpClient(token *XSUAAToken) *httpClient { + return &httpClient{ + client: &http.Client{ + Transport: &oauthTransport{ + token: token, + }, + }, + } +} + +func (c *httpClient) get(url string, options requestOptions) (*http.Response, error) { + return c.genericRequest(http.MethodGet, url, options) +} + +func (c *httpClient) post(url string, options requestOptions) (*http.Response, error) { + return c.genericRequest(http.MethodPost, url, options) +} + +func (c *httpClient) put(url string, options requestOptions) (*http.Response, error) { + return c.genericRequest(http.MethodPut, url, options) +} + +func (c *httpClient) patch(url string, options requestOptions) (*http.Response, error) { + return c.genericRequest(http.MethodPatch, url, options) +} + +func (c *httpClient) genericRequest(method string, url string, options requestOptions) (*http.Response, error) { + request, err := http.NewRequest(method, url, options.Body) + if err != nil { + return nil, fmt.Errorf("failed to build request: %s", err.Error()) + } + + for key, header := range options.Headers { + request.Header.Add(key, header) + } + + if len(options.Query) > 0 { + q := request.URL.Query() + for key, header := range options.Query { + q.Add(key, header) + } + request.URL.RawQuery = q.Encode() + } + + response, err := c.client.Do(request) + if err != nil { + return nil, fmt.Errorf("failed to get data from server: %s", err.Error()) + } + + if response.StatusCode >= 400 { + // error from response (status code higher or equal 400) + return nil, c.buildResponseError(response) + } + + return response, nil +} + +func (c *httpClient) buildResponseError(response *http.Response) error { + errorData := cisErrorResponse{} + err := json.NewDecoder(response.Body).Decode(&errorData) + if err == io.EOF { + // error is possibly located in headers + return c.buildErrorFromHeaders(response) + } + if err != nil { + return fmt.Errorf("failed to decode error response with status '%s': %s", response.Status, err.Error()) + } + + return c.buildErrorFromBody(&errorData) +} + +func (c *httpClient) buildErrorFromBody(errorData *cisErrorResponse) error { + return errors.New(errorData.Error.Message) +} + +func (c *httpClient) buildErrorFromHeaders(response *http.Response) error { + wwwAuthHeaderString := response.Header.Get("Www-Authenticate") + if wwwAuthHeaderString == "" { + return fmt.Errorf("failed to parse http error for status: %s", response.Status) + } + + wwwAuthHeader := wwwAuthParser.Parse(wwwAuthHeaderString) + return fmt.Errorf("%s: %s", wwwAuthHeader.Params["error"], wwwAuthHeader.Params["error_description"]) +} diff --git a/internal/btp/client_test.go b/internal/btp/client_test.go new file mode 100644 index 000000000..9948317ab --- /dev/null +++ b/internal/btp/client_test.go @@ -0,0 +1,207 @@ +package btp + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_oauthTransport_RoundTrip(t *testing.T) { + + t.Parallel() + + t.Run("client bearer authorization", func(t *testing.T) { + svr := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + require.Equal(t, "Bearer token", r.Header.Get("Authorization")) + })) + defer svr.Close() + + req, err := http.NewRequest(http.MethodPost, svr.URL, nil) + require.NoError(t, err) + clientTransport := oauthTransport{ + token: &XSUAAToken{ + AccessToken: "token", + }, + } + _, err = clientTransport.RoundTrip(req) + require.NoError(t, err) + }) +} + +func Test_GenericRequest(t *testing.T) { + + t.Parallel() + + testEmptyServer := httptest.NewServer(http.HandlerFunc(fixGenericRequestHandler(t, requestOptions{}))) + defer testEmptyServer.Close() + + testServer := httptest.NewServer(http.HandlerFunc(fixGenericRequestHandler(t, fixRequestOptions()))) + defer testServer.Close() + + testErrorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(415) + })) + defer testErrorServer.Close() + + t.Run("simple GET request", func(t *testing.T) { + c := httpClient{client: http.DefaultClient} + + response, err := c.genericRequest(http.MethodGet, testEmptyServer.URL, requestOptions{}) + + require.NoError(t, err) + require.NotNil(t, response) + require.Equal(t, 200, response.StatusCode) + + _ = response.Body.Close() + }) + + t.Run("simple POST request with additional data", func(t *testing.T) { + c := httpClient{client: http.DefaultClient} + + response, err := c.genericRequest(http.MethodPost, testServer.URL, fixRequestOptions()) + + require.NoError(t, err) + require.NotNil(t, response) + require.Equal(t, 200, response.StatusCode) + + _ = response.Body.Close() + }) + + t.Run("build request error becuse of wrong method name", func(t *testing.T) { + c := httpClient{client: http.DefaultClient} + + response, err := c.genericRequest("DoEsNoTeXiSt)", testServer.URL, requestOptions{}) + + require.Equal(t, errors.New("failed to build request: net/http: invalid method \"DoEsNoTeXiSt)\""), err) + require.Nil(t, response) + }) + + t.Run("cant reach server by URL error", func(t *testing.T) { + c := httpClient{client: http.DefaultClient} + + response, err := c.genericRequest(http.MethodGet, "http://does-not-exist", requestOptions{}) + + require.Equal(t, errors.New("failed to get data from server: Get \"http://does-not-exist\": dial tcp: lookup does-not-exist: no such host"), err) + require.Nil(t, response) + }) + + t.Run("handle 415 response status", func(t *testing.T) { + c := httpClient{client: http.DefaultClient} + + response, err := c.genericRequest(http.MethodGet, testErrorServer.URL, requestOptions{}) + + require.Equal(t, errors.New("failed to parse http error for status: 415 Unsupported Media Type"), err) + require.Nil(t, response) + }) +} + +func Test_httpClient_buildResponseError(t *testing.T) { + + t.Parallel() + + tests := []struct { + name string + response *http.Response + expectedErr error + }{ + { + name: "build error from status", + response: &http.Response{ + Status: "Unauthorized", + Body: io.NopCloser(strings.NewReader("")), + }, + expectedErr: errors.New("failed to parse http error for status: Unauthorized"), + }, + { + name: "build error from header", + response: &http.Response{ + Status: "Unauthorized", + Header: http.Header{ + "Www-Authenticate": []string{"Bearer error=\"error\",error_description=\"description\""}, + }, + Body: io.NopCloser(strings.NewReader("")), + }, + expectedErr: errors.New("error: description"), + }, + { + name: "build error from body", + response: &http.Response{ + Status: "Unauthorized", + Body: io.NopCloser(strings.NewReader(`{ + "error": { + "code": 123, + "message": "message", + "target": "target", + "correlationID": "correlationID" + } + }`)), + }, + expectedErr: errors.New("message"), + }, + { + name: "decode response error", + response: &http.Response{ + Status: "Unauthorized", + Body: io.NopCloser(strings.NewReader("[test=value]")), + }, + expectedErr: errors.New("failed to decode error response with status 'Unauthorized': invalid character 'e' in literal true (expecting 'r')"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &httpClient{} + + err := c.buildResponseError(tt.response) + + require.Equal(t, tt.expectedErr, err) + }) + } +} + +func fixRequestOptions() requestOptions { + return requestOptions{ + Body: strings.NewReader("test data"), + Headers: map[string]string{ + "Test-Header": "test-header-value", + }, + Query: map[string]string{ + "test-query": "test-query-value", + }, + } +} + +func fixGenericRequestHandler(t *testing.T, expectedOptions requestOptions) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + for key, expectedValue := range expectedOptions.Query { + value, ok := r.URL.Query()[key] + require.True(t, ok) + require.Equal(t, expectedValue, value[0]) + } + + for key, expectedValue := range expectedOptions.Headers { + value, ok := r.Header[key] + require.True(t, ok) + require.Equal(t, expectedValue, value[0]) + } + + data := make([]byte, 0) + expectedData := make([]byte, 0) + var err error + if r.Body != nil { + data, err = io.ReadAll(r.Body) + require.NoError(t, err) + } + if expectedOptions.Body != nil { + expectedData, err = io.ReadAll(expectedOptions.Body) + require.NoError(t, err) + } + require.Equal(t, expectedData, data) + + w.WriteHeader(200) + } +}