Skip to content

Commit

Permalink
[FIX] Cleanup code (labstack#1061)
Browse files Browse the repository at this point in the history
Code cleanup
  • Loading branch information
im-kulikov authored and vishr committed Feb 21, 2018
1 parent 90d675f commit f49d166
Show file tree
Hide file tree
Showing 19 changed files with 140 additions and 64 deletions.
2 changes: 1 addition & 1 deletion bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
val := reflect.ValueOf(ptr).Elem()

if typ.Kind() != reflect.Struct {
return errors.New("Binding element must be a struct")
return errors.New("binding element must be a struct")
}

for i := 0; i < typ.NumField(); i++ {
Expand Down
3 changes: 1 addition & 2 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ func TestBindForm(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, MIMEApplicationForm)
obj := []struct{ Field string }{}
err := c.Bind(&obj)
err := c.Bind(&[]struct{ Field string }{})
assert.Error(t, err)
}

Expand Down
11 changes: 4 additions & 7 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,18 @@ package echo

import (
"bytes"
"encoding/xml"
"errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"text/template"
"time"

"strings"

"net/url"

"encoding/xml"

"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -217,7 +214,7 @@ func TestContext(t *testing.T) {
c.SetParamNames("foo")
c.SetParamValues("bar")
c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": []string{"baz"}})
c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder())
assert.Equal(t, 0, len(c.ParamValues()))
assert.Equal(t, 0, len(c.ParamNames()))
Expand Down
18 changes: 9 additions & 9 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,10 @@ var (
ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrValidatorNotRegistered = errors.New("Validator not registered")
ErrRendererNotRegistered = errors.New("Renderer not registered")
ErrInvalidRedirectCode = errors.New("Invalid redirect status code")
ErrCookieNotFound = errors.New("Cookie not found")
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
)

// Error handlers
Expand Down Expand Up @@ -530,7 +530,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {

// Routes returns the registered routes.
func (e *Echo) Routes() []*Route {
routes := []*Route{}
routes := make([]*Route, 0, len(e.router.routes))
for _, v := range e.router.routes {
routes = append(routes, v)
}
Expand Down Expand Up @@ -563,11 +563,11 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Middleware
h := func(c Context) error {
method := r.Method
path := r.URL.RawPath
if path == "" {
path = r.URL.Path
rpath := r.URL.RawPath // raw path
if rpath == "" {
rpath = r.URL.Path
}
e.router.Find(method, path, c)
e.router.Find(method, rpath, c)
h := c.Handler()
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
Expand Down
7 changes: 2 additions & 5 deletions echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@ package echo

import (
"bytes"
"errors"
"net/http"
"net/http/httptest"
"testing"

"reflect"
"strings"

"errors"

"testing"
"time"

"github.com/stretchr/testify/assert"
Expand Down
4 changes: 2 additions & 2 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (g *Group) Match(methods []string, path string, handler HandlerFunc, middle

// Group creates a new sub-group with prefix and optional sub-group-level middleware.
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
m := []MiddlewareFunc{}
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
return g.echo.Group(g.prefix+prefix, m...)
Expand All @@ -113,7 +113,7 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd
// Combine into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
m := []MiddlewareFunc{}
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
return g.echo.Add(method, g.prefix+path, handler, m...)
Expand Down
6 changes: 2 additions & 4 deletions middleware/basic_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,8 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
}
}

realm := ""
if config.Realm == defaultRealm {
realm = defaultRealm
} else {
realm := defaultRealm
if config.Realm != defaultRealm {
realm = strconv.Quote(config.Realm)
}

Expand Down
15 changes: 14 additions & 1 deletion middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
Expand All @@ -41,7 +54,7 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
Expand Down
3 changes: 1 addition & 2 deletions middleware/body_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ package middleware
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net"
"net/http"

"io"

"github.com/labstack/echo"
)

Expand Down
56 changes: 56 additions & 0 deletions middleware/body_dump_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -31,10 +32,65 @@ func TestBodyDump(t *testing.T) {
requestBody = string(reqBody)
responseBody = string(resBody)
})

if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String())
}

// Must set default skipper
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})
}

func TestBodyDumpFails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(echo.POST, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
}

requestBody := ""
responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
})

if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}

assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
})

assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})

if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
})
}
2 changes: 1 addition & 1 deletion middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func TestGzip(t *testing.T) {
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body)
if assert.NoError(t, err) {
defer r.Close()
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal(t, "test", buf.String())
}
Expand Down
9 changes: 5 additions & 4 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,11 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {

req := c.Request()
k, err := c.Cookie(config.CookieName)
token := ""

var token string

// Generate token
if err != nil {
// Generate token
token = random.String(config.TokenLength)
} else {
// Reuse token
Expand Down Expand Up @@ -187,7 +188,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("Missing csrf token in the form parameter")
return "", errors.New("missing csrf token in the form parameter")
}
return token, nil
}
Expand All @@ -199,7 +200,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("Missing csrf token in the query string")
return "", errors.New("missing csrf token in the query string")
}
return token, nil
}
Expand Down
2 changes: 1 addition & 1 deletion middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return config.SigningKey, nil
}
Expand Down
8 changes: 4 additions & 4 deletions middleware/key_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("Missing key in request header")
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("Invalid key in the request header")
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
Expand All @@ -132,7 +132,7 @@ func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("Missing key in the query string")
return "", errors.New("missing key in the query string")
}
return key, nil
}
Expand All @@ -143,7 +143,7 @@ func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("Missing key in the form")
return "", errors.New("missing key in the form")
}
return key, nil
}
Expand Down
8 changes: 4 additions & 4 deletions middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type (
// Example "${remote_ip} ${status}"
//
// Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"`
Format string `yaml:"format"`

// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"`
Expand All @@ -70,9 +70,9 @@ var (
`"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
`"latency_human":"${latency_human}","bytes_in":${bytes_in},` +
`"bytes_out":${bytes_out}}` + "\n",
CustomTimeFormat:"2006-01-02 15:04:05.00000",
Output: os.Stdout,
colorer: color.New(),
CustomTimeFormat: "2006-01-02 15:04:05.00000",
Output: os.Stdout,
colorer: color.New(),
}
)

Expand Down
8 changes: 4 additions & 4 deletions middleware/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ package middleware

import (
"bytes"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"unsafe"

"encoding/json"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
"time"
"unsafe"
)

func TestLogger(t *testing.T) {
Expand Down Expand Up @@ -152,7 +152,7 @@ func TestLoggerCustomTimestamp(t *testing.T) {
`"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` +
`"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
CustomTimeFormat: customTimeFormat,
Output: buf,
Output: buf,
}))

e.GET("/", func(c echo.Context) error {
Expand Down
Loading

0 comments on commit f49d166

Please sign in to comment.