From f49d166e6ff98636dda92ff82b90dcb3054fc2cf Mon Sep 17 00:00:00 2001 From: Evgeniy Kulikov Date: Wed, 21 Feb 2018 21:44:17 +0300 Subject: [PATCH] [FIX] Cleanup code (#1061) Code cleanup --- bind.go | 2 +- bind_test.go | 3 +- context_test.go | 11 +++---- echo.go | 18 +++++------ echo_test.go | 7 ++--- group.go | 4 +-- middleware/basic_auth.go | 6 ++-- middleware/basic_auth_test.go | 15 +++++++++- middleware/body_dump.go | 3 +- middleware/body_dump_test.go | 56 +++++++++++++++++++++++++++++++++++ middleware/compress_test.go | 2 +- middleware/csrf.go | 9 +++--- middleware/jwt.go | 2 +- middleware/key_auth.go | 8 ++--- middleware/logger.go | 8 ++--- middleware/logger_test.go | 8 ++--- middleware/proxy.go | 8 ++--- middleware/proxy_test.go | 30 ++++++++++++++----- router.go | 4 +-- 19 files changed, 140 insertions(+), 64 deletions(-) diff --git a/bind.go b/bind.go index 186bd83d9..38e071504 100644 --- a/bind.go +++ b/bind.go @@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag val := reflect.ValueOf(ptr).Elem() if typ.Kind() != reflect.Struct { - return errors.New("Binding element must be a struct") + return errors.New("binding element must be a struct") } for i := 0; i < typ.NumField(); i++ { diff --git a/bind_test.go b/bind_test.go index 1f1fa4868..2fe59573c 100644 --- a/bind_test.go +++ b/bind_test.go @@ -135,8 +135,7 @@ func TestBindForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, MIMEApplicationForm) - obj := []struct{ Field string }{} - err := c.Bind(&obj) + err := c.Bind(&[]struct{ Field string }{}) assert.Error(t, err) } diff --git a/context_test.go b/context_test.go index 1b281ef44..bcbf9751a 100644 --- a/context_test.go +++ b/context_test.go @@ -2,21 +2,18 @@ package echo import ( "bytes" + "encoding/xml" "errors" "io" "mime/multipart" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "text/template" "time" - "strings" - - "net/url" - - "encoding/xml" - "github.com/stretchr/testify/assert" ) @@ -217,7 +214,7 @@ func TestContext(t *testing.T) { c.SetParamNames("foo") c.SetParamValues("bar") c.Set("foe", "ban") - c.query = url.Values(map[string][]string{"fon": []string{"baz"}}) + c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.Reset(req, httptest.NewRecorder()) assert.Equal(t, 0, len(c.ParamValues())) assert.Equal(t, 0, len(c.ParamNames())) diff --git a/echo.go b/echo.go index 4a54b31ab..cbdd473d3 100644 --- a/echo.go +++ b/echo.go @@ -251,10 +251,10 @@ var ( ErrForbidden = NewHTTPError(http.StatusForbidden) ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrValidatorNotRegistered = errors.New("Validator not registered") - ErrRendererNotRegistered = errors.New("Renderer not registered") - ErrInvalidRedirectCode = errors.New("Invalid redirect status code") - ErrCookieNotFound = errors.New("Cookie not found") + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") ) // Error handlers @@ -530,7 +530,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string { // Routes returns the registered routes. func (e *Echo) Routes() []*Route { - routes := []*Route{} + routes := make([]*Route, 0, len(e.router.routes)) for _, v := range e.router.routes { routes = append(routes, v) } @@ -563,11 +563,11 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Middleware h := func(c Context) error { method := r.Method - path := r.URL.RawPath - if path == "" { - path = r.URL.Path + rpath := r.URL.RawPath // raw path + if rpath == "" { + rpath = r.URL.Path } - e.router.Find(method, path, c) + e.router.Find(method, rpath, c) h := c.Handler() for i := len(e.middleware) - 1; i >= 0; i-- { h = e.middleware[i](h) diff --git a/echo_test.go b/echo_test.go index a84d310d7..6fdaca5aa 100644 --- a/echo_test.go +++ b/echo_test.go @@ -2,15 +2,12 @@ package echo import ( "bytes" + "errors" "net/http" "net/http/httptest" - "testing" - "reflect" "strings" - - "errors" - + "testing" "time" "github.com/stretchr/testify/assert" diff --git a/group.go b/group.go index f7e61a2e1..5257e83ca 100644 --- a/group.go +++ b/group.go @@ -92,7 +92,7 @@ func (g *Group) Match(methods []string, path string, handler HandlerFunc, middle // Group creates a new sub-group with prefix and optional sub-group-level middleware. func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group { - m := []MiddlewareFunc{} + m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) return g.echo.Group(g.prefix+prefix, m...) @@ -113,7 +113,7 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd // Combine into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := []MiddlewareFunc{} + m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) return g.echo.Add(method, g.prefix+path, handler, m...) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 6d6a37b45..e6c963245 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -93,10 +93,8 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { } } - realm := "" - if config.Realm == defaultRealm { - realm = defaultRealm - } else { + realm := defaultRealm + if config.Realm != defaultRealm { realm = strconv.Quote(config.Realm) } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index c1efb306c..93023228b 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -31,6 +31,19 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, auth) assert.NoError(t, h(c)) + h = BasicAuthWithConfig(BasicAuthConfig{ + Skipper: nil, + Validator: f, + Realm: "someRealm", + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Valid credentials + auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) + req.Header.Set(echo.HeaderAuthorization, auth) + assert.NoError(t, h(c)) + // Case-insensitive header scheme auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, auth) @@ -41,7 +54,7 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, auth) he := h(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) + assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) // Missing Authorization header req.Header.Del(echo.HeaderAuthorization) diff --git a/middleware/body_dump.go b/middleware/body_dump.go index 14cf33d12..e64e5e112 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,12 +3,11 @@ package middleware import ( "bufio" "bytes" + "io" "io/ioutil" "net" "net/http" - "io" - "github.com/labstack/echo" ) diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index 1da8b5ad2..188ed4f9e 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "io/ioutil" "net/http" "net/http/httptest" @@ -31,10 +32,65 @@ func TestBodyDump(t *testing.T) { requestBody = string(reqBody) responseBody = string(resBody) }) + if assert.NoError(t, mw(h)(c)) { assert.Equal(t, requestBody, hw) assert.Equal(t, responseBody, hw) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, hw, rec.Body.String()) } + + // Must set default skipper + BodyDumpWithConfig(BodyDumpConfig{ + Skipper: nil, + Handler: func(c echo.Context, reqBody, resBody []byte) { + requestBody = string(reqBody) + responseBody = string(resBody) + }, + }) +} + +func TestBodyDumpFails(t *testing.T) { + e := echo.New() + hw := "Hello, World!" + req := httptest.NewRequest(echo.POST, "/", strings.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + return errors.New("some error") + } + + requestBody := "" + responseBody := "" + mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + requestBody = string(reqBody) + responseBody = string(resBody) + }) + + if !assert.Error(t, mw(h)(c)) { + t.FailNow() + } + + assert.Panics(t, func() { + mw = BodyDumpWithConfig(BodyDumpConfig{ + Skipper: nil, + Handler: nil, + }) + }) + + assert.NotPanics(t, func() { + mw = BodyDumpWithConfig(BodyDumpConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Handler: func(c echo.Context, reqBody, resBody []byte) { + requestBody = string(reqBody) + responseBody = string(resBody) + }, + }) + + if !assert.Error(t, mw(h)(c)) { + t.FailNow() + } + }) } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 7302de979..ece13bd6d 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -36,8 +36,8 @@ func TestGzip(t *testing.T) { assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) r, err := gzip.NewReader(rec.Body) if assert.NoError(t, err) { - defer r.Close() buf := new(bytes.Buffer) + defer r.Close() buf.ReadFrom(r) assert.Equal(t, "test", buf.String()) } diff --git a/middleware/csrf.go b/middleware/csrf.go index 0d2b7fd6f..dbf1cc627 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -124,10 +124,11 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { req := c.Request() k, err := c.Cookie(config.CookieName) - token := "" + var token string + + // Generate token if err != nil { - // Generate token token = random.String(config.TokenLength) } else { // Reuse token @@ -187,7 +188,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor { return func(c echo.Context) (string, error) { token := c.FormValue(param) if token == "" { - return "", errors.New("Missing csrf token in the form parameter") + return "", errors.New("missing csrf token in the form parameter") } return token, nil } @@ -199,7 +200,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor { return func(c echo.Context) (string, error) { token := c.QueryParam(param) if token == "" { - return "", errors.New("Missing csrf token in the query string") + return "", errors.New("missing csrf token in the query string") } return token, nil } diff --git a/middleware/jwt.go b/middleware/jwt.go index 47d885b0a..722cade9e 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -116,7 +116,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { config.keyFunc = func(t *jwt.Token) (interface{}, error) { // Check the signing method if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"]) + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) } return config.SigningKey, nil } diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 4990afd97..c12f4ca9d 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -114,14 +114,14 @@ func keyFromHeader(header string, authScheme string) keyExtractor { return func(c echo.Context) (string, error) { auth := c.Request().Header.Get(header) if auth == "" { - return "", errors.New("Missing key in request header") + return "", errors.New("missing key in request header") } if header == echo.HeaderAuthorization { l := len(authScheme) if len(auth) > l+1 && auth[:l] == authScheme { return auth[l+1:], nil } - return "", errors.New("Invalid key in the request header") + return "", errors.New("invalid key in the request header") } return auth, nil } @@ -132,7 +132,7 @@ func keyFromQuery(param string) keyExtractor { return func(c echo.Context) (string, error) { key := c.QueryParam(param) if key == "" { - return "", errors.New("Missing key in the query string") + return "", errors.New("missing key in the query string") } return key, nil } @@ -143,7 +143,7 @@ func keyFromForm(param string) keyExtractor { return func(c echo.Context) (string, error) { key := c.FormValue(param) if key == "" { - return "", errors.New("Missing key in the form") + return "", errors.New("missing key in the form") } return key, nil } diff --git a/middleware/logger.go b/middleware/logger.go index c7b80f8c3..87af575ff 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -47,7 +47,7 @@ type ( // Example "${remote_ip} ${status}" // // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` + Format string `yaml:"format"` // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. CustomTimeFormat string `yaml:"custom_time_format"` @@ -70,9 +70,9 @@ var ( `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + `"latency_human":"${latency_human}","bytes_in":${bytes_in},` + `"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat:"2006-01-02 15:04:05.00000", - Output: os.Stdout, - colorer: color.New(), + CustomTimeFormat: "2006-01-02 15:04:05.00000", + Output: os.Stdout, + colorer: color.New(), } ) diff --git a/middleware/logger_test.go b/middleware/logger_test.go index b869bd8a0..5aa988fe9 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -2,18 +2,18 @@ package middleware import ( "bytes" + "encoding/json" "errors" "net/http" "net/http/httptest" "net/url" "strings" "testing" + "time" + "unsafe" - "encoding/json" "github.com/labstack/echo" "github.com/stretchr/testify/assert" - "time" - "unsafe" ) func TestLogger(t *testing.T) { @@ -152,7 +152,7 @@ func TestLoggerCustomTimestamp(t *testing.T) { `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` + `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", CustomTimeFormat: customTimeFormat, - Output: buf, + Output: buf, })) e.GET("/", func(c echo.Context) error { diff --git a/middleware/proxy.go b/middleware/proxy.go index ae3ff527d..f61477372 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -108,15 +108,15 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return } - errc := make(chan error, 2) + errCh := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err + _, err = io.Copy(dst, src) + errCh <- err } go cp(out, in) go cp(in, out) - err = <-errc + err = <-errCh if err != nil && err != io.EOF { c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err) } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 06d93166a..48164baa1 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -4,9 +4,8 @@ import ( "fmt" "net/http" "net/http/httptest" - "testing" - "net/url" + "testing" "github.com/labstack/echo" "github.com/stretchr/testify/assert" @@ -48,14 +47,25 @@ func TestProxy(t *testing.T) { url2, _ := url.Parse(t2.URL) targets := []*ProxyTarget{ - &ProxyTarget{ - URL: url1, + { + Name: "target 1", + URL: url1, }, - &ProxyTarget{ - URL: url2, + { + Name: "target 2", + URL: url2, }, } - rb := NewRandomBalancer(targets) + rb := NewRandomBalancer(nil) + // must add targets: + for _, target := range targets { + assert.True(t, rb.AddTarget(target)) + } + + // must ignore duplicates: + for _, target := range targets { + assert.False(t, rb.AddTarget(target)) + } // Random e := echo.New() @@ -72,6 +82,12 @@ func TestProxy(t *testing.T) { return expected[body] }) + for _, target := range targets { + assert.True(t, rb.RemoveTarget(target.Name)) + } + + assert.False(t, rb.RemoveTarget("unknown target")) + // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() diff --git a/router.go b/router.go index 3af4be0b4..ce81a58f2 100644 --- a/router.go +++ b/router.go @@ -59,8 +59,8 @@ func (r *Router) Add(method, path string, h HandlerFunc) { if path[0] != '/' { path = "/" + path } - ppath := path // Pristine path - pnames := []string{} // Param names + ppath := path // Pristine path + var pnames []string // Param names for i, l := 0, len(path); i < l; i++ { if path[i] == ':' {