diff --git a/.travis.yml b/.travis.yml index 4a79893ff..74bf30330 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,8 +10,6 @@ before_install: script: - go test -coverprofile=echo.coverprofile - go test -coverprofile=middleware.coverprofile ./middleware - - go test -coverprofile=engine_standatd.coverprofile ./engine/standard - - go test -coverprofile=engine_fasthttp.coverprofile ./engine/fasthttp - $HOME/gopath/bin/gover - $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci matrix: diff --git a/binder.go b/binder.go index 59a3afec6..4df466ce0 100644 --- a/binder.go +++ b/binder.go @@ -22,43 +22,47 @@ type ( func (b *binder) Bind(i interface{}, c Context) (err error) { req := c.Request() - if req.Method() == GET { + if req.Method == GET { if err = b.bindData(i, c.QueryParams()); err != nil { - err = NewHTTPError(http.StatusBadRequest, err.Error()) + return NewHTTPError(http.StatusBadRequest, err.Error()) } return } - ctype := req.Header().Get(HeaderContentType) - if req.Body() == nil { - err = NewHTTPError(http.StatusBadRequest, "request body can't be empty") - return + ctype := req.Header.Get(HeaderContentType) + if req.ContentLength == 0 { + return NewHTTPError(http.StatusBadRequest, "request body can't be empty") } - err = ErrUnsupportedMediaType switch { case strings.HasPrefix(ctype, MIMEApplicationJSON): - if err = json.NewDecoder(req.Body()).Decode(i); err != nil { + if err = json.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*json.UnmarshalTypeError); ok { - err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset)) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset)) } else if se, ok := err.(*json.SyntaxError); ok { - err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: offset=%v, error=%v", se.Offset, se.Error())) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: offset=%v, error=%v", se.Offset, se.Error())) } else { - err = NewHTTPError(http.StatusBadRequest, err.Error()) + return NewHTTPError(http.StatusBadRequest, err.Error()) } } case strings.HasPrefix(ctype, MIMEApplicationXML): - if err = xml.NewDecoder(req.Body()).Decode(i); err != nil { + if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { - err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) } else if se, ok := err.(*xml.SyntaxError); ok { - err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: line=%v, error=%v", se.Line, se.Error())) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: line=%v, error=%v", se.Line, se.Error())) } else { - err = NewHTTPError(http.StatusBadRequest, err.Error()) + return NewHTTPError(http.StatusBadRequest, err.Error()) } } case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - if err = b.bindData(i, req.FormParams()); err != nil { - err = NewHTTPError(http.StatusBadRequest, err.Error()) + params, err := c.FormParams() + if err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()) + } + if err = b.bindData(i, params); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()) } + default: + return ErrUnsupportedMediaType } return } @@ -100,8 +104,8 @@ func (b *binder) bindData(ptr interface{}, data map[string][]string) error { if structFieldKind == reflect.Slice && numElems > 0 { sliceOf := structField.Type().Elem().Kind() slice := reflect.MakeSlice(structField.Type(), numElems, numElems) - for i := 0; i < numElems; i++ { - if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { + for j := 0; j < numElems; j++ { + if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { return err } } diff --git a/binder_test.go b/binder_test.go index 39117c9a6..662782be0 100644 --- a/binder_test.go +++ b/binder_test.go @@ -5,11 +5,11 @@ import ( "io" "mime/multipart" "net/http" + "net/http/httptest" "reflect" "strings" "testing" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) @@ -70,19 +70,19 @@ func TestBinderForm(t *testing.T) { testBinderOkay(t, strings.NewReader(userForm), MIMEApplicationForm) testBinderError(t, nil, MIMEApplicationForm) e := New() - req := test.NewRequest(POST, "/", strings.NewReader(userForm)) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(POST, "/", strings.NewReader(userForm)) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) - req.Header().Set(HeaderContentType, MIMEApplicationForm) - var obj = make([]struct{ Field string }, 0) + req.Header.Set(HeaderContentType, MIMEApplicationForm) + obj := []struct{ Field string }{} err := c.Bind(&obj) assert.Error(t, err) } func TestBinderQueryParams(t *testing.T) { e := New() - req := test.NewRequest(GET, "/?id=1&name=Jon Snow", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(GET, "/?id=1&name=Jon Snow", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := c.Bind(u) @@ -105,11 +105,6 @@ func TestBinderUnsupportedMediaType(t *testing.T) { testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON) } -// func assertCustomer(t *testing.T, c *user) { -// assert.Equal(t, 1, c.ID) -// assert.Equal(t, "Joe", c.Name) -// } - func TestBinderbindForm(t *testing.T) { ts := new(binderTestStruct) b := new(binder) @@ -201,10 +196,10 @@ func assertBinderTestStruct(t *testing.T, ts *binderTestStruct) { func testBinderOkay(t *testing.T, r io.Reader, ctype string) { e := New() - req := test.NewRequest(POST, "/", r) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(POST, "/", r) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) - req.Header().Set(HeaderContentType, ctype) + req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) if assert.NoError(t, err) { @@ -215,10 +210,10 @@ func testBinderOkay(t *testing.T, r io.Reader, ctype string) { func testBinderError(t *testing.T, r io.Reader, ctype string) { e := New() - req := test.NewRequest(POST, "/", r) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(POST, "/", r) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) - req.Header().Set(HeaderContentType, ctype) + req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) diff --git a/context.go b/context.go index ebc2ac2e2..54c2fd294 100644 --- a/context.go +++ b/context.go @@ -7,12 +7,14 @@ import ( "io" "mime" "mime/multipart" + "net" "net/http" + "net/url" "os" "path/filepath" + "strings" "time" - "github.com/labstack/echo/engine" "github.com/labstack/echo/log" "bytes" @@ -32,11 +34,21 @@ type ( // SetStdContext sets `context.Context`. SetStdContext(context.Context) - // Request returns `engine.Request` interface. - Request() engine.Request + // Request returns `*http.Request`. + Request() *http.Request - // Request returns `engine.Response` interface. - Response() engine.Response + // Request returns `*Response`. + Response() *Response + + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool + + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string + + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + RealIP() string // Path returns the registered path for the handler. Path() string @@ -62,41 +74,35 @@ type ( // SetParamValues sets path parameter values. SetParamValues(...string) - // QueryParam returns the query param for the provided name. It is an alias - // for `engine.URL#QueryParam()`. + // QueryParam returns the query param for the provided name. QueryParam(string) string - // QueryParams returns the query parameters as map. - // It is an alias for `engine.URL#QueryParams()`. - QueryParams() map[string][]string + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // FormValue returns the form field value for the provided name. It is an - // alias for `engine.Request#FormValue()`. + // QueryString returns the URL query string. + QueryString() string + + // FormValue returns the form field value for the provided name. FormValue(string) string - // FormParams returns the form parameters as map. - // It is an alias for `engine.Request#FormParams()`. - FormParams() map[string][]string + // FormParams returns the form parameters as `url.Values`. + FormParams() (url.Values, error) - // FormFile returns the multipart form file for the provided name. It is an - // alias for `engine.Request#FormFile()`. + // FormFile returns the multipart form file for the provided name. FormFile(string) (*multipart.FileHeader, error) // MultipartForm returns the multipart form. - // It is an alias for `engine.Request#MultipartForm()`. MultipartForm() (*multipart.Form, error) // Cookie returns the named cookie provided in the request. - // It is an alias for `engine.Request#Cookie()`. - Cookie(string) (engine.Cookie, error) + Cookie(string) (*http.Cookie, error) // SetCookie adds a `Set-Cookie` header in HTTP response. - // It is an alias for `engine.Response#SetCookie()`. - SetCookie(engine.Cookie) + SetCookie(*http.Cookie) // Cookies returns the HTTP cookies sent with the request. - // It is an alias for `engine.Request#Cookies()`. - Cookies() []engine.Cookie + Cookies() []*http.Cookie // Get retrieves data from the context. Get(string) interface{} @@ -184,23 +190,25 @@ type ( // Reset resets the context after request completes. It must be called along // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. // See `Echo#ServeHTTP()` - Reset(engine.Request, engine.Response) + Reset(*http.Request, http.ResponseWriter) } echoContext struct { context context.Context - request engine.Request - response engine.Response + request *http.Request + response *Response path string pnames []string pvalues []string + query url.Values handler HandlerFunc echo *Echo } ) const ( - indexPage = "index.html" + defaultMemory = 32 << 20 // 32 MB + indexPage = "index.html" ) func (c *echoContext) StdContext() context.Context { @@ -227,14 +235,39 @@ func (c *echoContext) Value(key interface{}) interface{} { return c.context.Value(key) } -func (c *echoContext) Request() engine.Request { +func (c *echoContext) Request() *http.Request { return c.request } -func (c *echoContext) Response() engine.Response { +func (c *echoContext) Response() *Response { return c.response } +func (c *echoContext) IsTLS() bool { + return c.request.TLS != nil +} + +func (c *echoContext) Scheme() string { + // Can't use `r.Request.URL.Scheme` + // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 + if c.IsTLS() { + return "https" + } + return "http" +} + +func (c *echoContext) RealIP() string { + ra := c.request.RemoteAddr + if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { + ra = ip + } else if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { + ra = ip + } else { + ra, _, _ = net.SplitHostPort(ra) + } + return ra +} + func (c *echoContext) Path() string { return c.path } @@ -279,38 +312,59 @@ func (c *echoContext) SetParamValues(values ...string) { } func (c *echoContext) QueryParam(name string) string { - return c.request.URL().QueryParam(name) + if c.query == nil { + c.query = c.request.URL.Query() + } + return c.query.Get(name) } -func (c *echoContext) QueryParams() map[string][]string { - return c.request.URL().QueryParams() +func (c *echoContext) QueryParams() url.Values { + if c.query == nil { + c.query = c.request.URL.Query() + } + return c.query +} + +func (c *echoContext) QueryString() string { + return c.request.URL.RawQuery } func (c *echoContext) FormValue(name string) string { return c.request.FormValue(name) } -func (c *echoContext) FormParams() map[string][]string { - return c.request.FormParams() +func (c *echoContext) FormParams() (url.Values, error) { + if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { + if err := c.request.ParseMultipartForm(defaultMemory); err != nil { + return nil, err + } + } else { + if err := c.request.ParseForm(); err != nil { + return nil, err + } + } + return c.request.Form, nil } func (c *echoContext) FormFile(name string) (*multipart.FileHeader, error) { - return c.request.FormFile(name) + _, fh, err := c.request.FormFile(name) + return fh, err } func (c *echoContext) MultipartForm() (*multipart.Form, error) { - return c.request.MultipartForm() + err := c.request.ParseMultipartForm(defaultMemory) + return c.request.MultipartForm, err } -func (c *echoContext) Cookie(name string) (engine.Cookie, error) { +func (c *echoContext) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -func (c *echoContext) SetCookie(cookie engine.Cookie) { - c.response.SetCookie(cookie) +func (c *echoContext) SetCookie(cookie *http.Cookie) { + http.SetCookie(c.Response(), cookie) } -func (c *echoContext) Cookies() []engine.Cookie { +func (c *echoContext) Cookies() []*http.Cookie { return c.request.Cookies() } @@ -323,15 +377,15 @@ func (c *echoContext) Get(key string) interface{} { } func (c *echoContext) Bind(i interface{}) error { - return c.echo.binder.Bind(i, c) + return c.echo.Binder.Bind(i, c) } func (c *echoContext) Render(code int, name string, data interface{}) (err error) { - if c.echo.renderer == nil { + if c.echo.Renderer == nil { return ErrRendererNotRegistered } buf := new(bytes.Buffer) - if err = c.echo.renderer.Render(buf, name, data, c); err != nil { + if err = c.echo.Renderer.Render(buf, name, data, c); err != nil { return } c.response.Header().Set(HeaderContentType, MIMETextHTMLCharsetUTF8) @@ -356,7 +410,7 @@ func (c *echoContext) String(code int, s string) (err error) { func (c *echoContext) JSON(code int, i interface{}) (err error) { b, err := json.Marshal(i) - if c.echo.Debug() { + if c.echo.Debug { b, err = json.MarshalIndent(i, "", " ") } if err != nil { @@ -392,7 +446,7 @@ func (c *echoContext) JSONPBlob(code int, callback string, b []byte) (err error) func (c *echoContext) XML(code int, i interface{}) (err error) { b, err := xml.Marshal(i) - if c.echo.Debug() { + if c.echo.Debug { b, err = xml.MarshalIndent(i, "", " ") } if err != nil { @@ -474,7 +528,7 @@ func (c *echoContext) Redirect(code int, url string) error { } func (c *echoContext) Error(err error) { - c.echo.httpErrorHandler(err, c) + c.echo.HTTPErrorHandler(err, c) } func (c *echoContext) Echo() *Echo { @@ -490,14 +544,14 @@ func (c *echoContext) SetHandler(h HandlerFunc) { } func (c *echoContext) Logger() log.Logger { - return c.echo.logger + return c.echo.Logger } func (c *echoContext) ServeContent(content io.ReadSeeker, name string, modtime time.Time) error { req := c.Request() res := c.Response() - if t, err := time.Parse(http.TimeFormat, req.Header().Get(HeaderIfModifiedSince)); err == nil && modtime.Before(t.Add(1*time.Second)) { + if t, err := time.Parse(http.TimeFormat, req.Header.Get(HeaderIfModifiedSince)); err == nil && modtime.Before(t.Add(1*time.Second)) { res.Header().Del(HeaderContentType) res.Header().Del(HeaderContentLength) return c.NoContent(http.StatusNotModified) @@ -520,9 +574,10 @@ func ContentTypeByExtension(name string) (t string) { return } -func (c *echoContext) Reset(req engine.Request, res engine.Response) { +func (c *echoContext) Reset(r *http.Request, w http.ResponseWriter) { + // c.query = nil c.context = context.Background() - c.request = req - c.response = res + c.request = r + c.response.reset(w) c.handler = NotFoundHandler } diff --git a/context_test.go b/context_test.go index f7f8fe789..6d4f4db25 100644 --- a/context_test.go +++ b/context_test.go @@ -6,6 +6,7 @@ import ( "io" "mime/multipart" "net/http" + "net/http/httptest" "os" "testing" "text/template" @@ -19,7 +20,6 @@ import ( "encoding/xml" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) @@ -35,185 +35,182 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) func TestContext(t *testing.T) { e := New() - req := test.NewRequest(POST, "/", strings.NewReader(userJSON)) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(POST, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*echoContext) // Echo assert.Equal(t, e, c.Echo()) // Request - assert.Equal(t, req, c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.Equal(t, rec, c.Response()) - - // Logger - assert.Equal(t, e.logger, c.Logger()) + assert.NotNil(t, c.Response()) //-------- // Render //-------- - tpl := &Template{ + tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } - c.echo.SetRenderer(tpl) + c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } - c.echo.renderer = nil + c.echo.Renderer = nil err = c.Render(http.StatusOK, "hello", "Jon Snow") assert.Error(t, err) // JSON - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON, rec.Body.String()) } // JSON (error) - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) err = c.JSON(http.StatusOK, make(chan bool)) assert.Error(t, err) // JSONP - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) callback := "callback" err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } // XML - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // XML (error) - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) err = c.XML(http.StatusOK, make(chan bool)) assert.Error(t, err) // String - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) err = c.String(http.StatusOK, "Hello, World!") if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, "Hello, World!", rec.Body.String()) } // HTML - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) err = c.HTML(http.StatusOK, "Hello, World!") if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, "Hello, World!", rec.Body.String()) } // Stream - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) r := strings.NewReader("response from a stream") err = c.Stream(http.StatusOK, "application/octet-stream", r) if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) assert.Equal(t, "response from a stream", rec.Body.String()) } // Attachment - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) file, err := os.Open("_fixture/images/walle.png") if assert.NoError(t, err) { err = c.Attachment(file, "walle.png") if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, 219885, rec.Body.Len()) } } // Inline - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) file, err = os.Open("_fixture/images/walle.png") if assert.NoError(t, err) { err = c.Inline(file, "walle.png") if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "inline; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, 219885, rec.Body.Len()) } } // NoContent - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) c.NoContent(http.StatusOK) - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) // Error - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*echoContext) c.Error(errors.New("error")) - assert.Equal(t, http.StatusInternalServerError, rec.Status()) + assert.Equal(t, http.StatusInternalServerError, rec.Code) // Reset - c.Reset(req, test.NewResponseRecorder()) + c.Reset(req, httptest.NewRecorder()) } func TestContextCookie(t *testing.T) { e := New() - req := test.NewRequest(GET, "/", nil) + req, _ := http.NewRequest(GET, "/", nil) theme := "theme=light" user := "user=Jon Snow" - req.Header().Add(HeaderCookie, theme) - req.Header().Add(HeaderCookie, user) - rec := test.NewResponseRecorder() + req.Header.Add(HeaderCookie, theme) + req.Header.Add(HeaderCookie, user) + rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*echoContext) // Read single cookie, err := c.Cookie("theme") if assert.NoError(t, err) { - assert.Equal(t, "theme", cookie.Name()) - assert.Equal(t, "light", cookie.Value()) + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { - switch cookie.Name() { + switch cookie.Name { case "theme": - assert.Equal(t, "light", cookie.Value()) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal(t, "Jon Snow", cookie.Value()) + assert.Equal(t, "Jon Snow", cookie.Value) } } // Write - cookie = &test.Cookie{Cookie: &http.Cookie{ + cookie = &http.Cookie{ Name: "SSID", Value: "Ap4PGTEq", Domain: "labstack.com", @@ -221,7 +218,7 @@ func TestContextCookie(t *testing.T) { Expires: time.Now(), Secure: true, HttpOnly: true, - }} + } c.SetCookie(cookie) assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") @@ -247,7 +244,7 @@ func TestContextPath(t *testing.T) { func TestContextPathParam(t *testing.T) { e := New() - req := test.NewRequest(GET, "/", nil) + req, _ := http.NewRequest(GET, "/", nil) c := e.NewContext(req, nil) // ParamNames @@ -271,8 +268,8 @@ func TestContextFormValue(t *testing.T) { f.Set("email", "jon@labstack.com") e := New() - req := test.NewRequest(POST, "/", strings.NewReader(f.Encode())) - req.Header().Add(HeaderContentType, MIMEApplicationForm) + req, _ := http.NewRequest(POST, "/", strings.NewReader(f.Encode())) + req.Header.Add(HeaderContentType, MIMEApplicationForm) c := e.NewContext(req, nil) // FormValue @@ -280,17 +277,20 @@ func TestContextFormValue(t *testing.T) { assert.Equal(t, "jon@labstack.com", c.FormValue("email")) // FormParams - assert.Equal(t, map[string][]string{ - "name": []string{"Jon Snow"}, - "email": []string{"jon@labstack.com"}, - }, c.FormParams()) + params, err := c.FormParams() + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ + "name": []string{"Jon Snow"}, + "email": []string{"jon@labstack.com"}, + }, params) + } } func TestContextQueryParam(t *testing.T) { q := make(url.Values) q.Set("name", "Jon Snow") q.Set("email", "jon@labstack.com") - req := test.NewRequest(GET, "/?"+q.Encode(), nil) + req, _ := http.NewRequest(GET, "/?"+q.Encode(), nil) e := New() c := e.NewContext(req, nil) @@ -299,7 +299,7 @@ func TestContextQueryParam(t *testing.T) { assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) // QueryParams - assert.Equal(t, map[string][]string{ + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, c.QueryParams()) @@ -314,9 +314,9 @@ func TestContextFormFile(t *testing.T) { w.Write([]byte("test")) } mr.Close() - req := test.NewRequest(POST, "/", buf) - req.Header().Set(HeaderContentType, mr.FormDataContentType()) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(POST, "/", buf) + req.Header.Set(HeaderContentType, mr.FormDataContentType()) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") if assert.NoError(t, err) { @@ -330,9 +330,9 @@ func TestContextMultipartForm(t *testing.T) { mw := multipart.NewWriter(buf) mw.WriteField("name", "Jon Snow") mw.Close() - req := test.NewRequest(POST, "/", buf) - req.Header().Set(HeaderContentType, mw.FormDataContentType()) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(POST, "/", buf) + req.Header.Set(HeaderContentType, mw.FormDataContentType()) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() if assert.NoError(t, err) { @@ -342,11 +342,11 @@ func TestContextMultipartForm(t *testing.T) { func TestContextRedirect(t *testing.T) { e := New() - req := test.NewRequest(GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - assert.Equal(t, http.StatusMovedPermanently, rec.Status()) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } @@ -374,8 +374,8 @@ func TestContextStore(t *testing.T) { func TestContextServeContent(t *testing.T) { e := New() - req := test.NewRequest(GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) fs := http.Dir("_fixture/images") @@ -385,15 +385,15 @@ func TestContextServeContent(t *testing.T) { if assert.NoError(t, err) { // Not cached if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) } // Cached - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header().Set(HeaderIfModifiedSince, fi.ModTime().UTC().Format(http.TimeFormat)) + req.Header.Set(HeaderIfModifiedSince, fi.ModTime().UTC().Format(http.TimeFormat)) if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) { - assert.Equal(t, http.StatusNotModified, rec.Status()) + assert.Equal(t, http.StatusNotModified, rec.Code) } } } diff --git a/echo.go b/echo.go index a8e9c31c8..c9667de2c 100644 --- a/echo.go +++ b/echo.go @@ -39,18 +39,20 @@ package echo import ( "bytes" + "crypto/tls" "errors" "fmt" "io" + "net" "net/http" "path" "reflect" "runtime" "sync" + "time" "golang.org/x/net/context" - "github.com/labstack/echo/engine" "github.com/labstack/echo/log" glog "github.com/labstack/gommon/log" ) @@ -58,18 +60,24 @@ import ( type ( // Echo is the top-level framework instance. Echo struct { - server engine.Server - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - notFoundHandler HandlerFunc - httpErrorHandler HTTPErrorHandler - binder Binder - renderer Renderer - pool sync.Pool - debug bool - router *Router - logger log.Logger + Server *http.Server + TLSCertFile string + TLSKeyFile string + Listener net.Listener + // DisableHTTP2 disables HTTP2 + DisableHTTP2 bool + // Debug mode + Debug bool + HTTPErrorHandler + Binder Binder + Renderer Renderer + Logger log.Logger + premiddleware []MiddlewareFunc + middleware []MiddlewareFunc + maxParam *int + notFoundHandler HandlerFunc + pool sync.Pool + router *Router } // Route contains a handler and information for matching against requests. @@ -226,22 +234,21 @@ func New() (e *Echo) { return e.NewContext(nil, nil) } e.router = NewRouter(e) - // Defaults - e.SetHTTPErrorHandler(e.DefaultHTTPErrorHandler) - e.SetBinder(&binder{}) + e.HTTPErrorHandler = e.DefaultHTTPErrorHandler + e.Binder = &binder{} l := glog.New("echo") l.SetLevel(glog.OFF) - e.SetLogger(l) + e.Logger = l return } // NewContext returns a Context instance. -func (e *Echo) NewContext(req engine.Request, res engine.Response) Context { +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { return &echoContext{ context: context.Background(), - request: req, - response: res, + request: r, + response: NewResponse(w, e), echo: e, pvalues: make([]string, *e.maxParam), handler: NotFoundHandler, @@ -253,26 +260,6 @@ func (e *Echo) Router() *Router { return e.router } -// Logger returns the logger instance. -func (e *Echo) Logger() log.Logger { - return e.logger -} - -// SetLogger defines a custom logger. -func (e *Echo) SetLogger(l log.Logger) { - e.logger = l -} - -// SetLogOutput sets the output destination for the logger. Default value is `os.Std*` -func (e *Echo) SetLogOutput(w io.Writer) { - e.logger.SetOutput(w) -} - -// SetLogLevel sets the log level for the logger. Default value ERROR. -func (e *Echo) SetLogLevel(l glog.Lvl) { - e.logger.SetLevel(l) -} - // DefaultHTTPErrorHandler invokes the default HTTP error handler. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { code := http.StatusInternalServerError @@ -281,47 +268,17 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { code = he.Code msg = he.Message } - if e.debug { + if e.Debug { msg = err.Error() } - if !c.Response().Committed() { - if c.Request().Method() == HEAD { // Issue #608 + if !c.Response().Committed { + if c.Request().Method == HEAD { // Issue #608 c.NoContent(code) } else { c.String(code, msg) } } - e.logger.Error(err) -} - -// SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler. -func (e *Echo) SetHTTPErrorHandler(h HTTPErrorHandler) { - e.httpErrorHandler = h -} - -// SetBinder registers a custom binder. It's invoked by `Context#Bind()`. -func (e *Echo) SetBinder(b Binder) { - e.binder = b -} - -// Binder returns the binder instance. -func (e *Echo) Binder() Binder { - return e.binder -} - -// SetRenderer registers an HTML template renderer. It's invoked by `Context#Render()`. -func (e *Echo) SetRenderer(r Renderer) { - e.renderer = r -} - -// SetDebug enables/disables debug mode. -func (e *Echo) SetDebug(on bool) { - e.debug = on -} - -// Debug returns debug mode (enabled or disabled). -func (e *Echo) Debug() bool { - return e.debug + e.Logger.Error(err) } // Pre adds middleware to the chain which is run before router. @@ -340,99 +297,54 @@ func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(CONNECT, path, h, m...) } -// Connect is deprecated, use `CONNECT()` instead. -func (e *Echo) Connect(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.CONNECT(path, h, m...) -} - // DELETE registers a new DELETE route for a path with matching handler in the router // with optional route-level middleware. func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(DELETE, path, h, m...) } -// Delete is deprecated, use `DELETE()` instead. -func (e *Echo) Delete(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.DELETE(path, h, m...) -} - // GET registers a new GET route for a path with matching handler in the router // with optional route-level middleware. func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(GET, path, h, m...) } -// Get is deprecated, use `GET()` instead. -func (e *Echo) Get(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.GET(path, h, m...) -} - // HEAD registers a new HEAD route for a path with matching handler in the // router with optional route-level middleware. func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(HEAD, path, h, m...) } -// Head is deprecated, use `HEAD()` instead. -func (e *Echo) Head(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.HEAD(path, h, m...) -} - // OPTIONS registers a new OPTIONS route for a path with matching handler in the // router with optional route-level middleware. func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(OPTIONS, path, h, m...) } -// Options is deprecated, use `OPTIONS()` instead. -func (e *Echo) Options(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.OPTIONS(path, h, m...) -} - // PATCH registers a new PATCH route for a path with matching handler in the // router with optional route-level middleware. func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(PATCH, path, h, m...) } -// Patch is deprecated, use `PATCH()` instead. -func (e *Echo) Patch(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.PATCH(path, h, m...) -} - // POST registers a new POST route for a path with matching handler in the // router with optional route-level middleware. func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(POST, path, h, m...) } -// Post is deprecated, use `POST()` instead. -func (e *Echo) Post(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.POST(path, h, m...) -} - // PUT registers a new PUT route for a path with matching handler in the // router with optional route-level middleware. func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(PUT, path, h, m...) } -// Put is deprecated, use `PUT()` instead. -func (e *Echo) Put(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.PUT(path, h, m...) -} - // TRACE registers a new TRACE route for a path with matching handler in the // router with optional route-level middleware. func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) { e.add(TRACE, path, h, m...) } -// Trace is deprecated, use `TRACE()` instead. -func (e *Echo) Trace(path string, h HandlerFunc, m ...MiddlewareFunc) { - e.TRACE(path, h, m...) -} - // Any registers a new route for all HTTP methods and path with matching handler // in the router with optional route-level middleware. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) { @@ -540,14 +452,15 @@ func (e *Echo) ReleaseContext(c Context) { e.pool.Put(c) } -func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { +// ServeHTTP implements `http.Handler` interface, which serves HTTP requests. +func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { c := e.pool.Get().(*echoContext) - c.Reset(req, res) + c.Reset(r, w) // Middleware h := func(Context) error { - method := req.Method() - path := req.URL().Path() + method := r.Method + path := r.URL.Path e.router.Find(method, path, c) h := c.handler for i := len(e.middleware) - 1; i >= 0; i-- { @@ -563,27 +476,64 @@ func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { // Execute chain if err := h(c); err != nil { - e.httpErrorHandler(err, c) + e.HTTPErrorHandler(err, c) } e.pool.Put(c) } // Run starts the HTTP server. -func (e *Echo) Run(s engine.Server) error { - e.server = s - s.SetHandler(e) - s.SetLogger(e.logger) - if e.Debug() { - e.SetLogLevel(glog.DEBUG) - e.logger.Debug("running in debug mode") +func (e *Echo) Run(address string) (err error) { + if e.Server == nil { + e.Server = &http.Server{Handler: e} } - return s.Start() + if e.Listener == nil { + e.Listener, err = net.Listen("tcp", address) + if err != nil { + return + } + if e.TLSCertFile != "" && e.TLSKeyFile != "" { + // TODO: https://github.com/golang/go/commit/d24f446a90ea94b87591bf16228d7d871fec3d92 + config := &tls.Config{ + NextProtos: []string{"http/1.1"}, + } + if !e.DisableHTTP2 { + config.NextProtos = append(config.NextProtos, "h2") + } + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(e.TLSCertFile, e.TLSKeyFile) + if err != nil { + return + } + e.Listener = tls.NewListener(tcpKeepAliveListener{e.Listener.(*net.TCPListener)}, config) + } else { + e.Listener = tcpKeepAliveListener{e.Listener.(*net.TCPListener)} + } + } + return e.Server.Serve(e.Listener) } -// Stop stops the HTTP server. +// Stop stops the HTTP server func (e *Echo) Stop() error { - return e.server.Stop() + return e.Listener.Close() +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil } // NewHTTPError creates a new HTTPError instance. diff --git a/echo_test.go b/echo_test.go index 677fd17b2..9aade226d 100644 --- a/echo_test.go +++ b/echo_test.go @@ -2,9 +2,8 @@ package echo import ( "bytes" - "fmt" - "io/ioutil" "net/http" + "net/http/httptest" "testing" "reflect" @@ -12,8 +11,6 @@ import ( "errors" - "github.com/labstack/echo/test" - "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) @@ -33,20 +30,16 @@ const ( func TestEcho(t *testing.T) { e := New() - req := test.NewRequest(GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Router assert.NotNil(t, e.Router()) - // Debug - e.SetDebug(true) - assert.True(t, e.debug) - // DefaultHTTPErrorHandler e.DefaultHTTPErrorHandler(errors.New("error"), c) - assert.Equal(t, http.StatusInternalServerError, rec.Status()) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } func TestEchoStatic(t *testing.T) { @@ -306,10 +299,10 @@ func TestEchoGroup(t *testing.T) { func TestEchoNotFound(t *testing.T) { e := New() - req := test.NewRequest(GET, "/files", nil) - rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) - assert.Equal(t, http.StatusNotFound, rec.Status()) + req, _ := http.NewRequest(GET, "/files", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) } func TestEchoMethodNotAllowed(t *testing.T) { @@ -317,10 +310,10 @@ func TestEchoMethodNotAllowed(t *testing.T) { e.GET("/", func(c Context) error { return c.String(http.StatusOK, "Echo!") }) - req := test.NewRequest(POST, "/", nil) - rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) - assert.Equal(t, http.StatusMethodNotAllowed, rec.Status()) + req, _ := http.NewRequest(POST, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) } func TestEchoHTTPError(t *testing.T) { @@ -337,25 +330,13 @@ func TestEchoContext(t *testing.T) { e.ReleaseContext(c) } -func TestEchoLogger(t *testing.T) { - e := New() - l := log.New("test") - e.SetLogger(l) - assert.Equal(t, l, e.Logger()) - e.SetLogOutput(ioutil.Discard) - assert.Equal(t, l.Output(), ioutil.Discard) - e.SetLogLevel(log.OFF) - assert.Equal(t, l.Level(), log.OFF) -} - func testMethod(t *testing.T, method, path string, e *Echo) { - m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) p := reflect.ValueOf(path) h := reflect.ValueOf(func(c Context) error { return c.String(http.StatusOK, method) }) i := interface{}(e) - reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h}) + reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) _, body := request(method, path, e) if body != method { t.Errorf("expected body `%s`, got %s.", method, body) @@ -363,15 +344,8 @@ func testMethod(t *testing.T, method, path string, e *Echo) { } func request(method, path string, e *Echo) (int, string) { - req := test.NewRequest(method, path, nil) - rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) - return rec.Status(), rec.Body.String() -} - -func TestEchoBinder(t *testing.T) { - e := New() - b := &binder{} - e.SetBinder(b) - assert.Equal(t, b, e.Binder()) + req, _ := http.NewRequest(method, path, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec.Code, rec.Body.String() } diff --git a/engine/engine.go b/engine/engine.go deleted file mode 100644 index 0a8f2408e..000000000 --- a/engine/engine.go +++ /dev/null @@ -1,233 +0,0 @@ -package engine - -import ( - "io" - "mime/multipart" - "time" - - "net" - - "github.com/labstack/echo/log" -) - -type ( - // Server defines the interface for HTTP server. - Server interface { - // SetHandler sets the handler for the HTTP server. - SetHandler(Handler) - - // SetLogger sets the logger for the HTTP server. - SetLogger(log.Logger) - - // Start starts the HTTP server. - Start() error - - // Stop stops the HTTP server by closing underlying TCP connection. - Stop() error - } - - // Request defines the interface for HTTP request. - Request interface { - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool - - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string - - // Host returns HTTP request host. Per RFC 2616, this is either the value of - // the `Host` header or the host name given in the URL itself. - Host() string - - // SetHost sets the host of the request. - SetHost(string) - - // URI returns the unmodified `Request-URI` sent by the client. - URI() string - - // SetURI sets the URI of the request. - SetURI(string) - - // URL returns `engine.URL`. - URL() URL - - // Header returns `engine.Header`. - Header() Header - - // Referer returns the referring URL, if sent in the request. - Referer() string - - // Protocol returns the protocol version string of the HTTP request. - // Protocol() string - - // ProtocolMajor returns the major protocol version of the HTTP request. - // ProtocolMajor() int - - // ProtocolMinor returns the minor protocol version of the HTTP request. - // ProtocolMinor() int - - // ContentLength returns the size of request's body. - ContentLength() int64 - - // UserAgent returns the client's `User-Agent`. - UserAgent() string - - // RemoteAddress returns the client's network address. - RemoteAddress() string - - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - RealIP() string - - // Method returns the request's HTTP function. - Method() string - - // SetMethod sets the HTTP method of the request. - SetMethod(string) - - // Body returns request's body. - Body() io.Reader - - // Body sets request's body. - SetBody(io.Reader) - - // FormValue returns the form field value for the provided name. - FormValue(string) string - - // FormParams returns the form parameters. - FormParams() map[string][]string - - // FormFile returns the multipart form file for the provided name. - FormFile(string) (*multipart.FileHeader, error) - - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) - - // Cookie returns the named cookie provided in the request. - Cookie(string) (Cookie, error) - - // Cookies returns the HTTP cookies sent with the request. - Cookies() []Cookie - } - - // Response defines the interface for HTTP response. - Response interface { - // Header returns `engine.Header` - Header() Header - - // WriteHeader sends an HTTP response header with status code. - WriteHeader(int) - - // Write writes the data to the connection as part of an HTTP reply. - Write(b []byte) (int, error) - - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(Cookie) - - // Status returns the HTTP response status. - Status() int - - // Size returns the number of bytes written to HTTP response. - Size() int64 - - // Committed returns true if HTTP response header is written, otherwise false. - Committed() bool - - // Write returns the HTTP response writer. - Writer() io.Writer - - // SetWriter sets the HTTP response writer. - SetWriter(io.Writer) - } - - // Header defines the interface for HTTP header. - Header interface { - // Add adds the key, value pair to the header. It appends to any existing values - // associated with key. - Add(string, string) - - // Del deletes the values associated with key. - Del(string) - - // Set sets the header entries associated with key to the single element value. - // It replaces any existing values associated with key. - Set(string, string) - - // Get gets the first value associated with the given key. If there are - // no values associated with the key, Get returns "". - Get(string) string - - // Keys returns the header keys. - Keys() []string - - // Contains checks if the header is set. - Contains(string) bool - } - - // URL defines the interface for HTTP request url. - URL interface { - // Path returns the request URL path. - Path() string - - // SetPath sets the request URL path. - SetPath(string) - - // QueryParam returns the query param for the provided name. - QueryParam(string) string - - // QueryParam returns the query parameters as map. - QueryParams() map[string][]string - - // QueryString returns the URL query string. - QueryString() string - } - - // Cookie defines the interface for HTTP cookie. - Cookie interface { - // Name returns the name of the cookie. - Name() string - - // Value returns the value of the cookie. - Value() string - - // Path returns the path of the cookie. - Path() string - - // Domain returns the domain of the cookie. - Domain() string - - // Expires returns the expiry time of the cookie. - Expires() time.Time - - // Secure indicates if cookie is secured. - Secure() bool - - // HTTPOnly indicate if cookies is HTTP only. - HTTPOnly() bool - } - - // Config defines engine config. - Config struct { - Address string // TCP address to listen on. - Listener net.Listener // Custom `net.Listener`. If set, server accepts connections on it. - TLSCertFile string // TLS certificate file path. - TLSKeyFile string // TLS key file path. - DisableHTTP2 bool // Disables HTTP/2. - ReadTimeout time.Duration // Maximum duration before timing out read of the request. - WriteTimeout time.Duration // Maximum duration before timing out write of the response. - } - - // Handler defines an interface to server HTTP requests via `ServeHTTP(Request, Response)` - // function. - Handler interface { - ServeHTTP(Request, Response) - } - - // HandlerFunc is an adapter to allow the use of `func(Request, Response)` as - // an HTTP handler. - HandlerFunc func(Request, Response) -) - -// ServeHTTP serves HTTP request. -func (h HandlerFunc) ServeHTTP(req Request, res Response) { - h(req, res) -} diff --git a/engine/fasthttp/cookie.go b/engine/fasthttp/cookie.go deleted file mode 100644 index 2ebe4a15f..000000000 --- a/engine/fasthttp/cookie.go +++ /dev/null @@ -1,49 +0,0 @@ -package fasthttp - -import ( - "time" - - "github.com/valyala/fasthttp" -) - -type ( - // Cookie implements `engine.Cookie`. - Cookie struct { - *fasthttp.Cookie - } -) - -// Name implements `engine.Cookie#Name` function. -func (c *Cookie) Name() string { - return string(c.Cookie.Key()) -} - -// Value implements `engine.Cookie#Value` function. -func (c *Cookie) Value() string { - return string(c.Cookie.Value()) -} - -// Path implements `engine.Cookie#Path` function. -func (c *Cookie) Path() string { - return string(c.Cookie.Path()) -} - -// Domain implements `engine.Cookie#Domain` function. -func (c *Cookie) Domain() string { - return string(c.Cookie.Domain()) -} - -// Expires implements `engine.Cookie#Expires` function. -func (c *Cookie) Expires() time.Time { - return c.Cookie.Expire() -} - -// Secure implements `engine.Cookie#Secure` function. -func (c *Cookie) Secure() bool { - return c.Cookie.Secure() -} - -// HTTPOnly implements `engine.Cookie#HTTPOnly` function. -func (c *Cookie) HTTPOnly() bool { - return c.Cookie.HTTPOnly() -} diff --git a/engine/fasthttp/cookie_test.go b/engine/fasthttp/cookie_test.go deleted file mode 100644 index fb20bb591..000000000 --- a/engine/fasthttp/cookie_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package fasthttp - -import ( - "github.com/labstack/echo/engine/test" - fast "github.com/valyala/fasthttp" - "testing" - "time" -) - -func TestCookie(t *testing.T) { - fCookie := &fast.Cookie{} - fCookie.SetKey("session") - fCookie.SetValue("securetoken") - fCookie.SetPath("/") - fCookie.SetDomain("github.com") - fCookie.SetExpire(time.Date(2016, time.January, 1, 0, 0, 0, 0, time.UTC)) - fCookie.SetSecure(true) - fCookie.SetHTTPOnly(true) - - cookie := &Cookie{ - fCookie, - } - test.CookieTest(t, cookie) -} diff --git a/engine/fasthttp/header.go b/engine/fasthttp/header.go deleted file mode 100644 index 4f8a28991..000000000 --- a/engine/fasthttp/header.go +++ /dev/null @@ -1,97 +0,0 @@ -// +build !appengine - -package fasthttp - -import "github.com/valyala/fasthttp" - -type ( - // RequestHeader holds `fasthttp.RequestHeader`. - RequestHeader struct { - *fasthttp.RequestHeader - } - - // ResponseHeader holds `fasthttp.ResponseHeader`. - ResponseHeader struct { - *fasthttp.ResponseHeader - } -) - -// Add implements `engine.Header#Add` function. -func (h *RequestHeader) Add(key, val string) { - h.RequestHeader.Add(key, val) -} - -// Del implements `engine.Header#Del` function. -func (h *RequestHeader) Del(key string) { - h.RequestHeader.Del(key) -} - -// Set implements `engine.Header#Set` function. -func (h *RequestHeader) Set(key, val string) { - h.RequestHeader.Set(key, val) -} - -// Get implements `engine.Header#Get` function. -func (h *RequestHeader) Get(key string) string { - return string(h.Peek(key)) -} - -// Keys implements `engine.Header#Keys` function. -func (h *RequestHeader) Keys() (keys []string) { - keys = make([]string, h.Len()) - i := 0 - h.VisitAll(func(k, v []byte) { - keys[i] = string(k) - i++ - }) - return -} - -// Contains implements `engine.Header#Contains` function. -func (h *RequestHeader) Contains(key string) bool { - return h.Peek(key) != nil -} - -func (h *RequestHeader) reset(hdr *fasthttp.RequestHeader) { - h.RequestHeader = hdr -} - -// Add implements `engine.Header#Add` function. -func (h *ResponseHeader) Add(key, val string) { - h.ResponseHeader.Add(key, val) -} - -// Del implements `engine.Header#Del` function. -func (h *ResponseHeader) Del(key string) { - h.ResponseHeader.Del(key) -} - -// Get implements `engine.Header#Get` function. -func (h *ResponseHeader) Get(key string) string { - return string(h.Peek(key)) -} - -// Set implements `engine.Header#Set` function. -func (h *ResponseHeader) Set(key, val string) { - h.ResponseHeader.Set(key, val) -} - -// Keys implements `engine.Header#Keys` function. -func (h *ResponseHeader) Keys() (keys []string) { - keys = make([]string, h.Len()) - i := 0 - h.VisitAll(func(k, v []byte) { - keys[i] = string(k) - i++ - }) - return -} - -// Contains implements `engine.Header#Contains` function. -func (h *ResponseHeader) Contains(key string) bool { - return h.Peek(key) != nil -} - -func (h *ResponseHeader) reset(hdr *fasthttp.ResponseHeader) { - h.ResponseHeader = hdr -} diff --git a/engine/fasthttp/header_test.go b/engine/fasthttp/header_test.go deleted file mode 100644 index 7e0fc0c85..000000000 --- a/engine/fasthttp/header_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package fasthttp - -import ( - "github.com/labstack/echo/engine/test" - "github.com/stretchr/testify/assert" - fast "github.com/valyala/fasthttp" - "testing" -) - -func TestRequestHeader(t *testing.T) { - header := &RequestHeader{&fast.RequestHeader{}} - test.HeaderTest(t, header) - - header.reset(&fast.RequestHeader{}) - assert.Len(t, header.Keys(), 0) -} - -func TestResponseHeader(t *testing.T) { - header := &ResponseHeader{&fast.ResponseHeader{}} - test.HeaderTest(t, header) - - header.reset(&fast.ResponseHeader{}) - assert.Len(t, header.Keys(), 1) -} diff --git a/engine/fasthttp/request.go b/engine/fasthttp/request.go deleted file mode 100644 index a14ac9628..000000000 --- a/engine/fasthttp/request.go +++ /dev/null @@ -1,198 +0,0 @@ -// +build !appengine - -package fasthttp - -import ( - "bytes" - "io" - "mime/multipart" - "net" - - "github.com/labstack/echo" - "github.com/labstack/echo/engine" - "github.com/labstack/echo/log" - "github.com/valyala/fasthttp" -) - -type ( - // Request implements `engine.Request`. - Request struct { - *fasthttp.RequestCtx - header engine.Header - url engine.URL - logger log.Logger - } -) - -// NewRequest returns `Request` instance. -func NewRequest(c *fasthttp.RequestCtx, l log.Logger) *Request { - return &Request{ - RequestCtx: c, - url: &URL{URI: c.URI()}, - header: &RequestHeader{RequestHeader: &c.Request.Header}, - logger: l, - } -} - -// IsTLS implements `engine.Request#TLS` function. -func (r *Request) IsTLS() bool { - return r.RequestCtx.IsTLS() -} - -// Scheme implements `engine.Request#Scheme` function. -func (r *Request) Scheme() string { - return string(r.RequestCtx.URI().Scheme()) -} - -// Host implements `engine.Request#Host` function. -func (r *Request) Host() string { - return string(r.RequestCtx.Host()) -} - -// SetHost implements `engine.Request#SetHost` function. -func (r *Request) SetHost(host string) { - r.RequestCtx.Request.SetHost(host) -} - -// URL implements `engine.Request#URL` function. -func (r *Request) URL() engine.URL { - return r.url -} - -// Header implements `engine.Request#Header` function. -func (r *Request) Header() engine.Header { - return r.header -} - -// Referer implements `engine.Request#Referer` function. -func (r *Request) Referer() string { - return string(r.Request.Header.Referer()) -} - -// ContentLength implements `engine.Request#ContentLength` function. -func (r *Request) ContentLength() int64 { - return int64(r.Request.Header.ContentLength()) -} - -// UserAgent implements `engine.Request#UserAgent` function. -func (r *Request) UserAgent() string { - return string(r.RequestCtx.UserAgent()) -} - -// RemoteAddress implements `engine.Request#RemoteAddress` function. -func (r *Request) RemoteAddress() string { - return r.RemoteAddr().String() -} - -// RealIP implements `engine.Request#RealIP` function. -func (r *Request) RealIP() string { - ra := r.RemoteAddress() - if ip := r.Header().Get(echo.HeaderXForwardedFor); ip != "" { - ra = ip - } else if ip := r.Header().Get(echo.HeaderXRealIP); ip != "" { - ra = ip - } else { - ra, _, _ = net.SplitHostPort(ra) - } - return ra -} - -// Method implements `engine.Request#Method` function. -func (r *Request) Method() string { - return string(r.RequestCtx.Method()) -} - -// SetMethod implements `engine.Request#SetMethod` function. -func (r *Request) SetMethod(method string) { - r.Request.Header.SetMethodBytes([]byte(method)) -} - -// URI implements `engine.Request#URI` function. -func (r *Request) URI() string { - return string(r.RequestURI()) -} - -// SetURI implements `engine.Request#SetURI` function. -func (r *Request) SetURI(uri string) { - r.Request.Header.SetRequestURI(uri) -} - -// Body implements `engine.Request#Body` function. -func (r *Request) Body() io.Reader { - return bytes.NewBuffer(r.Request.Body()) -} - -// SetBody implements `engine.Request#SetBody` function. -func (r *Request) SetBody(reader io.Reader) { - r.Request.SetBodyStream(reader, 0) -} - -// FormValue implements `engine.Request#FormValue` function. -func (r *Request) FormValue(name string) string { - return string(r.RequestCtx.FormValue(name)) -} - -// FormParams implements `engine.Request#FormParams` function. -func (r *Request) FormParams() (params map[string][]string) { - params = make(map[string][]string) - mf, err := r.RequestCtx.MultipartForm() - - if err == fasthttp.ErrNoMultipartForm { - r.PostArgs().VisitAll(func(k, v []byte) { - key := string(k) - if _, ok := params[key]; ok { - params[key] = append(params[key], string(v)) - } else { - params[string(k)] = []string{string(v)} - } - }) - } else if err == nil { - for k, v := range mf.Value { - if len(v) > 0 { - params[k] = v - } - } - } - - return -} - -// FormFile implements `engine.Request#FormFile` function. -func (r *Request) FormFile(name string) (*multipart.FileHeader, error) { - return r.RequestCtx.FormFile(name) -} - -// MultipartForm implements `engine.Request#MultipartForm` function. -func (r *Request) MultipartForm() (*multipart.Form, error) { - return r.RequestCtx.MultipartForm() -} - -// Cookie implements `engine.Request#Cookie` function. -func (r *Request) Cookie(name string) (engine.Cookie, error) { - c := new(fasthttp.Cookie) - b := r.Request.Header.Cookie(name) - if b == nil { - return nil, echo.ErrCookieNotFound - } - c.SetKey(name) - c.SetValueBytes(b) - return &Cookie{c}, nil -} - -// Cookies implements `engine.Request#Cookies` function. -func (r *Request) Cookies() []engine.Cookie { - cookies := []engine.Cookie{} - r.Request.Header.VisitAllCookie(func(name, value []byte) { - c := new(fasthttp.Cookie) - c.SetKeyBytes(name) - c.SetValueBytes(value) - cookies = append(cookies, &Cookie{c}) - }) - return cookies -} - -func (r *Request) reset(c *fasthttp.RequestCtx, h engine.Header, u engine.URL) { - r.RequestCtx = c - r.header = h - r.url = u -} diff --git a/engine/fasthttp/request_test.go b/engine/fasthttp/request_test.go deleted file mode 100644 index 429865d63..000000000 --- a/engine/fasthttp/request_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package fasthttp - -import ( - "bufio" - "bytes" - "net" - "net/url" - "testing" - - "github.com/labstack/echo/engine/test" - "github.com/labstack/gommon/log" - fast "github.com/valyala/fasthttp" -) - -type fakeAddr struct { - addr string - net.Addr -} - -func (a fakeAddr) String() string { - return a.addr -} - -func TestRequest(t *testing.T) { - ctx := new(fast.RequestCtx) - url, _ := url.Parse("http://github.com/labstack/echo") - ctx.Init(&fast.Request{}, fakeAddr{addr: "127.0.0.1"}, nil) - ctx.Request.Read(bufio.NewReader(bytes.NewBufferString(test.MultipartRequest))) - ctx.Request.SetRequestURI(url.String()) - test.RequestTest(t, NewRequest(ctx, log.New("echo"))) -} diff --git a/engine/fasthttp/response.go b/engine/fasthttp/response.go deleted file mode 100644 index 457a3a456..000000000 --- a/engine/fasthttp/response.go +++ /dev/null @@ -1,108 +0,0 @@ -// +build !appengine - -package fasthttp - -import ( - "io" - "net/http" - - "github.com/labstack/echo/engine" - "github.com/labstack/echo/log" - "github.com/valyala/fasthttp" -) - -type ( - // Response implements `engine.Response`. - Response struct { - *fasthttp.RequestCtx - header engine.Header - status int - size int64 - committed bool - writer io.Writer - logger log.Logger - } -) - -// NewResponse returns `Response` instance. -func NewResponse(c *fasthttp.RequestCtx, l log.Logger) *Response { - return &Response{ - RequestCtx: c, - header: &ResponseHeader{ResponseHeader: &c.Response.Header}, - writer: c, - logger: l, - } -} - -// Header implements `engine.Response#Header` function. -func (r *Response) Header() engine.Header { - return r.header -} - -// WriteHeader implements `engine.Response#WriteHeader` function. -func (r *Response) WriteHeader(code int) { - if r.committed { - r.logger.Warn("response already committed") - return - } - r.status = code - r.SetStatusCode(code) - r.committed = true -} - -// Write implements `engine.Response#Write` function. -func (r *Response) Write(b []byte) (n int, err error) { - if !r.committed { - r.WriteHeader(http.StatusOK) - } - n, err = r.writer.Write(b) - r.size += int64(n) - return -} - -// SetCookie implements `engine.Response#SetCookie` function. -func (r *Response) SetCookie(c engine.Cookie) { - cookie := new(fasthttp.Cookie) - cookie.SetKey(c.Name()) - cookie.SetValue(c.Value()) - cookie.SetPath(c.Path()) - cookie.SetDomain(c.Domain()) - cookie.SetExpire(c.Expires()) - cookie.SetSecure(c.Secure()) - cookie.SetHTTPOnly(c.HTTPOnly()) - r.Response.Header.SetCookie(cookie) -} - -// Status implements `engine.Response#Status` function. -func (r *Response) Status() int { - return r.status -} - -// Size implements `engine.Response#Size` function. -func (r *Response) Size() int64 { - return r.size -} - -// Committed implements `engine.Response#Committed` function. -func (r *Response) Committed() bool { - return r.committed -} - -// Writer implements `engine.Response#Writer` function. -func (r *Response) Writer() io.Writer { - return r.writer -} - -// SetWriter implements `engine.Response#SetWriter` function. -func (r *Response) SetWriter(w io.Writer) { - r.writer = w -} - -func (r *Response) reset(c *fasthttp.RequestCtx, h engine.Header) { - r.RequestCtx = c - r.header = h - r.status = http.StatusOK - r.size = 0 - r.committed = false - r.writer = c -} diff --git a/engine/fasthttp/response_test.go b/engine/fasthttp/response_test.go deleted file mode 100644 index b162d08b1..000000000 --- a/engine/fasthttp/response_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package fasthttp - -import ( - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/valyala/fasthttp" - - "github.com/labstack/gommon/log" -) - -func TestResponseWriteHeader(t *testing.T) { - c := new(fasthttp.RequestCtx) - res := NewResponse(c, log.New("test")) - res.WriteHeader(http.StatusOK) - assert.True(t, res.Committed()) - assert.Equal(t, http.StatusOK, res.Status()) -} - -func TestResponseWrite(t *testing.T) { - c := new(fasthttp.RequestCtx) - res := NewResponse(c, log.New("test")) - res.Write([]byte("test")) - assert.Equal(t, int64(4), res.Size()) - assert.Equal(t, "test", string(c.Response.Body())) -} - -func TestResponseSetCookie(t *testing.T) { - c := new(fasthttp.RequestCtx) - res := NewResponse(c, log.New("test")) - cookie := new(fasthttp.Cookie) - cookie.SetKey("name") - cookie.SetValue("Jon Snow") - res.SetCookie(&Cookie{cookie}) - c.Response.Header.SetCookie(cookie) - ck := new(fasthttp.Cookie) - ck.SetKey("name") - assert.True(t, c.Response.Header.Cookie(ck)) - assert.Equal(t, "Jon Snow", string(ck.Value())) -} diff --git a/engine/fasthttp/server.go b/engine/fasthttp/server.go deleted file mode 100644 index 37d738fc4..000000000 --- a/engine/fasthttp/server.go +++ /dev/null @@ -1,188 +0,0 @@ -// +build !appengine - -package fasthttp - -import ( - "sync" - - "github.com/labstack/echo" - "github.com/labstack/echo/engine" - "github.com/labstack/echo/log" - glog "github.com/labstack/gommon/log" - "github.com/valyala/fasthttp" -) - -type ( - // Server implements `engine.Server`. - Server struct { - *fasthttp.Server - config engine.Config - handler engine.Handler - logger log.Logger - pool *pool - } - - pool struct { - request sync.Pool - response sync.Pool - requestHeader sync.Pool - responseHeader sync.Pool - url sync.Pool - } -) - -// New returns `Server` with provided listen address. -func New(addr string) *Server { - c := engine.Config{Address: addr} - return WithConfig(c) -} - -// WithTLS returns `Server` with provided TLS config. -func WithTLS(addr, certFile, keyFile string) *Server { - c := engine.Config{ - Address: addr, - TLSCertFile: certFile, - TLSKeyFile: keyFile, - } - return WithConfig(c) -} - -// WithConfig returns `Server` with provided config. -func WithConfig(c engine.Config) (s *Server) { - s = &Server{ - Server: new(fasthttp.Server), - config: c, - pool: &pool{ - request: sync.Pool{ - New: func() interface{} { - return &Request{logger: s.logger} - }, - }, - response: sync.Pool{ - New: func() interface{} { - return &Response{logger: s.logger} - }, - }, - requestHeader: sync.Pool{ - New: func() interface{} { - return &RequestHeader{} - }, - }, - responseHeader: sync.Pool{ - New: func() interface{} { - return &ResponseHeader{} - }, - }, - url: sync.Pool{ - New: func() interface{} { - return &URL{} - }, - }, - }, - handler: engine.HandlerFunc(func(req engine.Request, res engine.Response) { - panic("echo: handler not set, use `Server#SetHandler()` to set it.") - }), - logger: glog.New("echo"), - } - s.ReadTimeout = c.ReadTimeout - s.WriteTimeout = c.WriteTimeout - s.Handler = s.ServeHTTP - return -} - -// SetHandler implements `engine.Server#SetHandler` function. -func (s *Server) SetHandler(h engine.Handler) { - s.handler = h -} - -// SetLogger implements `engine.Server#SetLogger` function. -func (s *Server) SetLogger(l log.Logger) { - s.logger = l -} - -// Start implements `engine.Server#Start` function. -func (s *Server) Start() error { - if s.config.Listener == nil { - return s.startDefaultListener() - } - return s.startCustomListener() - -} - -// Stop implements `engine.Server#Stop` function. -func (s *Server) Stop() error { - // TODO: implement `engine.Server#Stop` function - return nil -} - -func (s *Server) startDefaultListener() error { - c := s.config - if c.TLSCertFile != "" && c.TLSKeyFile != "" { - return s.ListenAndServeTLS(c.Address, c.TLSCertFile, c.TLSKeyFile) - } - return s.ListenAndServe(c.Address) -} - -func (s *Server) startCustomListener() error { - c := s.config - if c.TLSCertFile != "" && c.TLSKeyFile != "" { - return s.ServeTLS(c.Listener, c.TLSCertFile, c.TLSKeyFile) - } - return s.Serve(c.Listener) -} - -func (s *Server) ServeHTTP(c *fasthttp.RequestCtx) { - // Request - req := s.pool.request.Get().(*Request) - reqHdr := s.pool.requestHeader.Get().(*RequestHeader) - reqURL := s.pool.url.Get().(*URL) - reqHdr.reset(&c.Request.Header) - reqURL.reset(c.URI()) - req.reset(c, reqHdr, reqURL) - - // Response - res := s.pool.response.Get().(*Response) - resHdr := s.pool.responseHeader.Get().(*ResponseHeader) - resHdr.reset(&c.Response.Header) - res.reset(c, resHdr) - - s.handler.ServeHTTP(req, res) - - // Return to pool - s.pool.request.Put(req) - s.pool.requestHeader.Put(reqHdr) - s.pool.url.Put(reqURL) - s.pool.response.Put(res) - s.pool.responseHeader.Put(resHdr) -} - -// WrapHandler wraps `fasthttp.RequestHandler` into `echo.HandlerFunc`. -func WrapHandler(h fasthttp.RequestHandler) echo.HandlerFunc { - return func(c echo.Context) error { - req := c.Request().(*Request) - res := c.Response().(*Response) - ctx := req.RequestCtx - h(ctx) - res.status = ctx.Response.StatusCode() - res.size = int64(ctx.Response.Header.ContentLength()) - return nil - } -} - -// WrapMiddleware wraps `func(fasthttp.RequestHandler) fasthttp.RequestHandler` -// into `echo.MiddlewareFunc` -func WrapMiddleware(m func(fasthttp.RequestHandler) fasthttp.RequestHandler) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { - req := c.Request().(*Request) - res := c.Response().(*Response) - ctx := req.RequestCtx - m(func(ctx *fasthttp.RequestCtx) { - next(c) - })(ctx) - res.status = ctx.Response.StatusCode() - res.size = int64(ctx.Response.Header.ContentLength()) - return - } - } -} diff --git a/engine/fasthttp/server_test.go b/engine/fasthttp/server_test.go deleted file mode 100644 index 70802f7e1..000000000 --- a/engine/fasthttp/server_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package fasthttp - -import ( - "bytes" - "net/http" - "testing" - - "github.com/labstack/echo" - "github.com/labstack/echo/engine" - "github.com/stretchr/testify/assert" - "github.com/valyala/fasthttp" -) - -// TODO: Fix me -func TestServer(t *testing.T) { - s := New("") - s.SetHandler(engine.HandlerFunc(func(req engine.Request, res engine.Response) { - })) - ctx := new(fasthttp.RequestCtx) - s.ServeHTTP(ctx) -} - -func TestServerWrapHandler(t *testing.T) { - e := echo.New() - ctx := new(fasthttp.RequestCtx) - req := NewRequest(ctx, nil) - res := NewResponse(ctx, nil) - c := e.NewContext(req, res) - h := WrapHandler(func(ctx *fasthttp.RequestCtx) { - ctx.Write([]byte("test")) - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, ctx.Response.StatusCode()) - assert.Equal(t, "test", string(ctx.Response.Body())) - } -} - -func TestServerWrapMiddleware(t *testing.T) { - e := echo.New() - ctx := new(fasthttp.RequestCtx) - req := NewRequest(ctx, nil) - res := NewResponse(ctx, nil) - c := e.NewContext(req, res) - buf := new(bytes.Buffer) - mw := WrapMiddleware(func(h fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - buf.Write([]byte("mw")) - h(ctx) - } - }) - h := mw(func(c echo.Context) error { - return c.String(http.StatusOK, "OK") - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, "mw", buf.String()) - assert.Equal(t, http.StatusOK, ctx.Response.StatusCode()) - assert.Equal(t, "OK", string(ctx.Response.Body())) - } -} diff --git a/engine/fasthttp/url.go b/engine/fasthttp/url.go deleted file mode 100644 index 177b08aa1..000000000 --- a/engine/fasthttp/url.go +++ /dev/null @@ -1,49 +0,0 @@ -// +build !appengine - -package fasthttp - -import "github.com/valyala/fasthttp" - -type ( - // URL implements `engine.URL`. - URL struct { - *fasthttp.URI - } -) - -// Path implements `engine.URL#Path` function. -func (u *URL) Path() string { - return string(u.URI.PathOriginal()) -} - -// SetPath implements `engine.URL#SetPath` function. -func (u *URL) SetPath(path string) { - u.URI.SetPath(path) -} - -// QueryParam implements `engine.URL#QueryParam` function. -func (u *URL) QueryParam(name string) string { - return string(u.QueryArgs().Peek(name)) -} - -// QueryParams implements `engine.URL#QueryParams` function. -func (u *URL) QueryParams() (params map[string][]string) { - params = make(map[string][]string) - u.QueryArgs().VisitAll(func(k, v []byte) { - _, ok := params[string(k)] - if !ok { - params[string(k)] = make([]string, 0) - } - params[string(k)] = append(params[string(k)], string(v)) - }) - return -} - -// QueryString implements `engine.URL#QueryString` function. -func (u *URL) QueryString() string { - return string(u.URI.QueryString()) -} - -func (u *URL) reset(uri *fasthttp.URI) { - u.URI = uri -} diff --git a/engine/fasthttp/url_test.go b/engine/fasthttp/url_test.go deleted file mode 100644 index 64524d4cc..000000000 --- a/engine/fasthttp/url_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package fasthttp - -import ( - "github.com/labstack/echo/engine/test" - "github.com/stretchr/testify/assert" - fast "github.com/valyala/fasthttp" - "testing" -) - -func TestURL(t *testing.T) { - uri := &fast.URI{} - uri.Parse([]byte("github.com"), []byte("/labstack/echo?param1=value1¶m1=value2¶m2=value3")) - mUrl := &URL{uri} - test.URLTest(t, mUrl) - - mUrl.reset(&fast.URI{}) - assert.Equal(t, "", string(mUrl.Host())) -} diff --git a/engine/standard/cookie.go b/engine/standard/cookie.go deleted file mode 100644 index 45bc4851c..000000000 --- a/engine/standard/cookie.go +++ /dev/null @@ -1,48 +0,0 @@ -package standard - -import ( - "net/http" - "time" -) - -type ( - // Cookie implements `engine.Cookie`. - Cookie struct { - *http.Cookie - } -) - -// Name implements `engine.Cookie#Name` function. -func (c *Cookie) Name() string { - return c.Cookie.Name -} - -// Value implements `engine.Cookie#Value` function. -func (c *Cookie) Value() string { - return c.Cookie.Value -} - -// Path implements `engine.Cookie#Path` function. -func (c *Cookie) Path() string { - return c.Cookie.Path -} - -// Domain implements `engine.Cookie#Domain` function. -func (c *Cookie) Domain() string { - return c.Cookie.Domain -} - -// Expires implements `engine.Cookie#Expires` function. -func (c *Cookie) Expires() time.Time { - return c.Cookie.Expires -} - -// Secure implements `engine.Cookie#Secure` function. -func (c *Cookie) Secure() bool { - return c.Cookie.Secure -} - -// HTTPOnly implements `engine.Cookie#HTTPOnly` function. -func (c *Cookie) HTTPOnly() bool { - return c.Cookie.HttpOnly -} diff --git a/engine/standard/cookie_test.go b/engine/standard/cookie_test.go deleted file mode 100644 index af839dd6d..000000000 --- a/engine/standard/cookie_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package standard - -import ( - "github.com/labstack/echo/engine/test" - "net/http" - "testing" - "time" -) - -func TestCookie(t *testing.T) { - cookie := &Cookie{&http.Cookie{ - Name: "session", - Value: "securetoken", - Path: "/", - Domain: "github.com", - Expires: time.Date(2016, time.January, 1, 0, 0, 0, 0, time.UTC), - Secure: true, - HttpOnly: true, - }} - test.CookieTest(t, cookie) -} diff --git a/engine/standard/header.go b/engine/standard/header.go deleted file mode 100644 index 001849f44..000000000 --- a/engine/standard/header.go +++ /dev/null @@ -1,51 +0,0 @@ -package standard - -import "net/http" - -type ( - // Header implements `engine.Header`. - Header struct { - http.Header - } -) - -// Add implements `engine.Header#Add` function. -func (h *Header) Add(key, val string) { - h.Header.Add(key, val) -} - -// Del implements `engine.Header#Del` function. -func (h *Header) Del(key string) { - h.Header.Del(key) -} - -// Set implements `engine.Header#Set` function. -func (h *Header) Set(key, val string) { - h.Header.Set(key, val) -} - -// Get implements `engine.Header#Get` function. -func (h *Header) Get(key string) string { - return h.Header.Get(key) -} - -// Keys implements `engine.Header#Keys` function. -func (h *Header) Keys() (keys []string) { - keys = make([]string, len(h.Header)) - i := 0 - for k := range h.Header { - keys[i] = k - i++ - } - return -} - -// Contains implements `engine.Header#Contains` function. -func (h *Header) Contains(key string) bool { - _, ok := h.Header[key] - return ok -} - -func (h *Header) reset(hdr http.Header) { - h.Header = hdr -} diff --git a/engine/standard/header_test.go b/engine/standard/header_test.go deleted file mode 100644 index dd1deb166..000000000 --- a/engine/standard/header_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package standard - -import ( - "github.com/labstack/echo/engine/test" - "github.com/stretchr/testify/assert" - "net/http" - "testing" -) - -func TestHeader(t *testing.T) { - header := &Header{http.Header{}} - test.HeaderTest(t, header) - - header.reset(http.Header{}) - assert.Len(t, header.Keys(), 0) -} diff --git a/engine/standard/request.go b/engine/standard/request.go deleted file mode 100644 index 3448d3c47..000000000 --- a/engine/standard/request.go +++ /dev/null @@ -1,205 +0,0 @@ -package standard - -import ( - "fmt" - "io" - "io/ioutil" - "mime/multipart" - "net" - "net/http" - "strings" - - "github.com/labstack/echo" - "github.com/labstack/echo/engine" - "github.com/labstack/echo/log" -) - -type ( - // Request implements `engine.Request`. - Request struct { - *http.Request - header engine.Header - url engine.URL - logger log.Logger - } -) - -const ( - defaultMemory = 32 << 20 // 32 MB -) - -// NewRequest returns `Request` instance. -func NewRequest(r *http.Request, l log.Logger) *Request { - return &Request{ - Request: r, - url: &URL{URL: r.URL}, - header: &Header{Header: r.Header}, - logger: l, - } -} - -// IsTLS implements `engine.Request#TLS` function. -func (r *Request) IsTLS() bool { - return r.Request.TLS != nil -} - -// Scheme implements `engine.Request#Scheme` function. -func (r *Request) Scheme() string { - // Can't use `r.Request.URL.Scheme` - // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 - if r.IsTLS() { - return "https" - } - return "http" -} - -// Host implements `engine.Request#Host` function. -func (r *Request) Host() string { - return r.Request.Host -} - -// SetHost implements `engine.Request#SetHost` function. -func (r *Request) SetHost(host string) { - r.Request.Host = host -} - -// URL implements `engine.Request#URL` function. -func (r *Request) URL() engine.URL { - return r.url -} - -// Header implements `engine.Request#Header` function. -func (r *Request) Header() engine.Header { - return r.header -} - -// Referer implements `engine.Request#Referer` function. -func (r *Request) Referer() string { - return r.Request.Referer() -} - -// func Proto() string { -// return r.request.Proto() -// } -// -// func ProtoMajor() int { -// return r.request.ProtoMajor() -// } -// -// func ProtoMinor() int { -// return r.request.ProtoMinor() -// } - -// ContentLength implements `engine.Request#ContentLength` function. -func (r *Request) ContentLength() int64 { - return r.Request.ContentLength -} - -// UserAgent implements `engine.Request#UserAgent` function. -func (r *Request) UserAgent() string { - return r.Request.UserAgent() -} - -// RemoteAddress implements `engine.Request#RemoteAddress` function. -func (r *Request) RemoteAddress() string { - return r.RemoteAddr -} - -// RealIP implements `engine.Request#RealIP` function. -func (r *Request) RealIP() string { - ra := r.RemoteAddress() - if ip := r.Header().Get(echo.HeaderXForwardedFor); ip != "" { - ra = ip - } else if ip := r.Header().Get(echo.HeaderXRealIP); ip != "" { - ra = ip - } else { - ra, _, _ = net.SplitHostPort(ra) - } - return ra -} - -// Method implements `engine.Request#Method` function. -func (r *Request) Method() string { - return r.Request.Method -} - -// SetMethod implements `engine.Request#SetMethod` function. -func (r *Request) SetMethod(method string) { - r.Request.Method = method -} - -// URI implements `engine.Request#URI` function. -func (r *Request) URI() string { - return r.RequestURI -} - -// SetURI implements `engine.Request#SetURI` function. -func (r *Request) SetURI(uri string) { - r.RequestURI = uri -} - -// Body implements `engine.Request#Body` function. -func (r *Request) Body() io.Reader { - return r.Request.Body -} - -// SetBody implements `engine.Request#SetBody` function. -func (r *Request) SetBody(reader io.Reader) { - r.Request.Body = ioutil.NopCloser(reader) -} - -// FormValue implements `engine.Request#FormValue` function. -func (r *Request) FormValue(name string) string { - return r.Request.FormValue(name) -} - -// FormParams implements `engine.Request#FormParams` function. -func (r *Request) FormParams() map[string][]string { - if strings.HasPrefix(r.header.Get(echo.HeaderContentType), echo.MIMEMultipartForm) { - if err := r.ParseMultipartForm(defaultMemory); err != nil { - panic(fmt.Sprintf("echo: %v", err)) - } - } else { - if err := r.ParseForm(); err != nil { - panic(fmt.Sprintf("echo: %v", err)) - } - } - return map[string][]string(r.Request.Form) -} - -// FormFile implements `engine.Request#FormFile` function. -func (r *Request) FormFile(name string) (*multipart.FileHeader, error) { - _, fh, err := r.Request.FormFile(name) - return fh, err -} - -// MultipartForm implements `engine.Request#MultipartForm` function. -func (r *Request) MultipartForm() (*multipart.Form, error) { - err := r.ParseMultipartForm(defaultMemory) - return r.Request.MultipartForm, err -} - -// Cookie implements `engine.Request#Cookie` function. -func (r *Request) Cookie(name string) (engine.Cookie, error) { - c, err := r.Request.Cookie(name) - if err != nil { - return nil, echo.ErrCookieNotFound - } - return &Cookie{c}, nil -} - -// Cookies implements `engine.Request#Cookies` function. -func (r *Request) Cookies() []engine.Cookie { - cs := r.Request.Cookies() - cookies := make([]engine.Cookie, len(cs)) - for i, c := range cs { - cookies[i] = &Cookie{c} - } - return cookies -} - -func (r *Request) reset(req *http.Request, h engine.Header, u engine.URL) { - r.Request = req - r.header = h - r.url = u -} diff --git a/engine/standard/request_test.go b/engine/standard/request_test.go deleted file mode 100644 index ce10f134c..000000000 --- a/engine/standard/request_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package standard - -import ( - "bufio" - "net/http" - "net/url" - "strings" - "testing" - - "github.com/labstack/echo/engine/test" - "github.com/labstack/gommon/log" - "github.com/stretchr/testify/assert" -) - -func TestRequest(t *testing.T) { - httpReq, _ := http.ReadRequest(bufio.NewReader(strings.NewReader(test.MultipartRequest))) - url, _ := url.Parse("http://github.com/labstack/echo") - httpReq.URL = url - httpReq.RemoteAddr = "127.0.0.1" - req := NewRequest(httpReq, log.New("echo")) - test.RequestTest(t, req) - nr, _ := http.NewRequest("GET", "/", nil) - req.reset(nr, nil, nil) - assert.Equal(t, "", req.Host()) -} diff --git a/engine/standard/response.go b/engine/standard/response.go deleted file mode 100644 index 8b24d5eb6..000000000 --- a/engine/standard/response.go +++ /dev/null @@ -1,146 +0,0 @@ -package standard - -import ( - "bufio" - "io" - "net" - "net/http" - - "github.com/labstack/echo/engine" - "github.com/labstack/echo/log" -) - -type ( - // Response implements `engine.Response`. - Response struct { - http.ResponseWriter - adapter *responseAdapter - header engine.Header - status int - size int64 - committed bool - writer io.Writer - logger log.Logger - } - - responseAdapter struct { - *Response - } -) - -// NewResponse returns `Response` instance. -func NewResponse(w http.ResponseWriter, l log.Logger) (r *Response) { - r = &Response{ - ResponseWriter: w, - header: &Header{Header: w.Header()}, - writer: w, - logger: l, - } - r.adapter = &responseAdapter{Response: r} - return -} - -// Header implements `engine.Response#Header` function. -func (r *Response) Header() engine.Header { - return r.header -} - -// WriteHeader implements `engine.Response#WriteHeader` function. -func (r *Response) WriteHeader(code int) { - if r.committed { - r.logger.Warn("response already committed") - return - } - r.status = code - r.ResponseWriter.WriteHeader(code) - r.committed = true -} - -// Write implements `engine.Response#Write` function. -func (r *Response) Write(b []byte) (n int, err error) { - if !r.committed { - r.WriteHeader(http.StatusOK) - } - n, err = r.writer.Write(b) - r.size += int64(n) - return -} - -// SetCookie implements `engine.Response#SetCookie` function. -func (r *Response) SetCookie(c engine.Cookie) { - http.SetCookie(r.ResponseWriter, &http.Cookie{ - Name: c.Name(), - Value: c.Value(), - Path: c.Path(), - Domain: c.Domain(), - Expires: c.Expires(), - Secure: c.Secure(), - HttpOnly: c.HTTPOnly(), - }) -} - -// Status implements `engine.Response#Status` function. -func (r *Response) Status() int { - return r.status -} - -// Size implements `engine.Response#Size` function. -func (r *Response) Size() int64 { - return r.size -} - -// Committed implements `engine.Response#Committed` function. -func (r *Response) Committed() bool { - return r.committed -} - -// Writer implements `engine.Response#Writer` function. -func (r *Response) Writer() io.Writer { - return r.writer -} - -// SetWriter implements `engine.Response#SetWriter` function. -func (r *Response) SetWriter(w io.Writer) { - r.writer = w -} - -// Flush implements the http.Flusher interface to allow an HTTP handler to flush -// buffered data to the client. -// See https://golang.org/pkg/net/http/#Flusher -func (r *Response) Flush() { - r.ResponseWriter.(http.Flusher).Flush() -} - -// Hijack implements the http.Hijacker interface to allow an HTTP handler to -// take over the connection. -// See https://golang.org/pkg/net/http/#Hijacker -func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return r.ResponseWriter.(http.Hijacker).Hijack() -} - -// CloseNotify implements the http.CloseNotifier interface to allow detecting -// when the underlying connection has gone away. -// This mechanism can be used to cancel long operations on the server if the -// client has disconnected before the response is ready. -// See https://golang.org/pkg/net/http/#CloseNotifier -func (r *Response) CloseNotify() <-chan bool { - return r.ResponseWriter.(http.CloseNotifier).CloseNotify() -} - -func (r *Response) reset(w http.ResponseWriter, a *responseAdapter, h engine.Header) { - r.ResponseWriter = w - r.adapter = a - r.header = h - r.status = http.StatusOK - r.size = 0 - r.committed = false - r.writer = w -} - -func (r *responseAdapter) Header() http.Header { - return r.ResponseWriter.Header() -} - -func (r *responseAdapter) reset(res *Response) { - r.Response = res -} diff --git a/engine/standard/response_test.go b/engine/standard/response_test.go deleted file mode 100644 index 8c7142a51..000000000 --- a/engine/standard/response_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package standard - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/labstack/gommon/log" - "github.com/stretchr/testify/assert" -) - -func TestResponseWriteHeader(t *testing.T) { - rec := httptest.NewRecorder() - res := NewResponse(rec, log.New("test")) - res.WriteHeader(http.StatusOK) - assert.True(t, res.Committed()) - assert.Equal(t, http.StatusOK, res.Status()) -} - -func TestResponseWrite(t *testing.T) { - rec := httptest.NewRecorder() - res := NewResponse(rec, log.New("test")) - res.Write([]byte("test")) - assert.Equal(t, int64(4), res.Size()) - assert.Equal(t, "test", rec.Body.String()) - res.Flush() - assert.True(t, rec.Flushed) -} - -func TestResponseSetCookie(t *testing.T) { - rec := httptest.NewRecorder() - res := NewResponse(rec, log.New("test")) - res.SetCookie(&Cookie{&http.Cookie{ - Name: "name", - Value: "Jon Snow", - }}) - assert.Equal(t, "name=Jon Snow", rec.Header().Get("Set-Cookie")) -} diff --git a/engine/standard/server.go b/engine/standard/server.go deleted file mode 100644 index b681dc2dd..000000000 --- a/engine/standard/server.go +++ /dev/null @@ -1,208 +0,0 @@ -package standard - -import ( - "crypto/tls" - "net" - "net/http" - "sync" - "time" - - "github.com/labstack/echo" - "github.com/labstack/echo/engine" - "github.com/labstack/echo/log" - glog "github.com/labstack/gommon/log" -) - -type ( - // Server implements `engine.Server`. - Server struct { - *http.Server - config engine.Config - handler engine.Handler - logger log.Logger - pool *pool - } - - pool struct { - request sync.Pool - response sync.Pool - responseAdapter sync.Pool - header sync.Pool - url sync.Pool - } -) - -// New returns `Server` instance with provided listen address. -func New(addr string) *Server { - c := engine.Config{Address: addr} - return WithConfig(c) -} - -// WithTLS returns `Server` instance with provided TLS config. -func WithTLS(addr, certFile, keyFile string) *Server { - c := engine.Config{ - Address: addr, - TLSCertFile: certFile, - TLSKeyFile: keyFile, - } - return WithConfig(c) -} - -// WithConfig returns `Server` instance with provided config. -func WithConfig(c engine.Config) (s *Server) { - s = &Server{ - Server: new(http.Server), - config: c, - pool: &pool{ - request: sync.Pool{ - New: func() interface{} { - return &Request{logger: s.logger} - }, - }, - response: sync.Pool{ - New: func() interface{} { - return &Response{logger: s.logger} - }, - }, - responseAdapter: sync.Pool{ - New: func() interface{} { - return &responseAdapter{} - }, - }, - header: sync.Pool{ - New: func() interface{} { - return &Header{} - }, - }, - url: sync.Pool{ - New: func() interface{} { - return &URL{} - }, - }, - }, - handler: engine.HandlerFunc(func(req engine.Request, res engine.Response) { - panic("echo: handler not set, use `Server#SetHandler()` to set it.") - }), - logger: glog.New("echo"), - } - s.ReadTimeout = c.ReadTimeout - s.WriteTimeout = c.WriteTimeout - s.Addr = c.Address - s.Handler = s - return -} - -// SetHandler implements `engine.Server#SetHandler` function. -func (s *Server) SetHandler(h engine.Handler) { - s.handler = h -} - -// SetLogger implements `engine.Server#SetLogger` function. -func (s *Server) SetLogger(l log.Logger) { - s.logger = l -} - -// Start implements `engine.Server#Start` function. -func (s *Server) Start() error { - if s.config.Listener == nil { - ln, err := net.Listen("tcp", s.config.Address) - if err != nil { - return err - } - - if s.config.TLSCertFile != "" && s.config.TLSKeyFile != "" { - // TODO: https://github.com/golang/go/commit/d24f446a90ea94b87591bf16228d7d871fec3d92 - config := &tls.Config{ - NextProtos: []string{"http/1.1"}, - } - if !s.config.DisableHTTP2 { - config.NextProtos = append(config.NextProtos, "h2") - } - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(s.config.TLSCertFile, s.config.TLSKeyFile) - if err != nil { - return err - } - s.config.Listener = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) - } else { - s.config.Listener = tcpKeepAliveListener{ln.(*net.TCPListener)} - } - } - - return s.Serve(s.config.Listener) -} - -// Stop implements `engine.Server#Stop` function. -func (s *Server) Stop() error { - return s.config.Listener.Close() -} - -// ServeHTTP implements `http.Handler` interface. -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Request - req := s.pool.request.Get().(*Request) - reqHdr := s.pool.header.Get().(*Header) - reqURL := s.pool.url.Get().(*URL) - reqHdr.reset(r.Header) - reqURL.reset(r.URL) - req.reset(r, reqHdr, reqURL) - - // Response - res := s.pool.response.Get().(*Response) - resAdpt := s.pool.responseAdapter.Get().(*responseAdapter) - resAdpt.reset(res) - resHdr := s.pool.header.Get().(*Header) - resHdr.reset(w.Header()) - res.reset(w, resAdpt, resHdr) - - s.handler.ServeHTTP(req, res) - - // Return to pool - s.pool.request.Put(req) - s.pool.header.Put(reqHdr) - s.pool.url.Put(reqURL) - s.pool.response.Put(res) - s.pool.header.Put(resHdr) -} - -// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. -func WrapHandler(h http.Handler) echo.HandlerFunc { - return func(c echo.Context) error { - req := c.Request().(*Request) - res := c.Response().(*Response) - h.ServeHTTP(res.adapter, req.Request) - return nil - } -} - -// WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc` -func WrapMiddleware(m func(http.Handler) http.Handler) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { - req := c.Request().(*Request) - res := c.Response().(*Response) - m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err = next(c) - })).ServeHTTP(res.adapter, req.Request) - return - } - } -} - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} diff --git a/engine/standard/server_test.go b/engine/standard/server_test.go deleted file mode 100644 index 3eec170c8..000000000 --- a/engine/standard/server_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package standard - -import ( - "bytes" - "net/http" - "net/http/httptest" - "testing" - - "github.com/labstack/echo" - "github.com/labstack/echo/engine" - "github.com/stretchr/testify/assert" -) - -// TODO: Fix me -func TestServer(t *testing.T) { - s := New("") - s.SetHandler(engine.HandlerFunc(func(req engine.Request, res engine.Response) { - })) - rec := httptest.NewRecorder() - req := new(http.Request) - s.ServeHTTP(rec, req) -} - -func TestServerWrapHandler(t *testing.T) { - e := echo.New() - req := NewRequest(new(http.Request), nil) - rec := httptest.NewRecorder() - res := NewResponse(rec, nil) - c := e.NewContext(req, res) - h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("test")) - })) - if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "test", rec.Body.String()) - } -} - -func TestServerWrapMiddleware(t *testing.T) { - e := echo.New() - req := NewRequest(new(http.Request), nil) - rec := httptest.NewRecorder() - res := NewResponse(rec, nil) - c := e.NewContext(req, res) - buf := new(bytes.Buffer) - mw := WrapMiddleware(func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - buf.Write([]byte("mw")) - h.ServeHTTP(w, r) - }) - }) - h := mw(func(c echo.Context) error { - return c.String(http.StatusOK, "OK") - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, "mw", buf.String()) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "OK", rec.Body.String()) - } -} diff --git a/engine/standard/url.go b/engine/standard/url.go deleted file mode 100644 index 9c436e95f..000000000 --- a/engine/standard/url.go +++ /dev/null @@ -1,47 +0,0 @@ -package standard - -import "net/url" - -type ( - // URL implements `engine.URL`. - URL struct { - *url.URL - query url.Values - } -) - -// Path implements `engine.URL#Path` function. -func (u *URL) Path() string { - return u.URL.EscapedPath() -} - -// SetPath implements `engine.URL#SetPath` function. -func (u *URL) SetPath(path string) { - u.URL.Path = path -} - -// QueryParam implements `engine.URL#QueryParam` function. -func (u *URL) QueryParam(name string) string { - if u.query == nil { - u.query = u.Query() - } - return u.query.Get(name) -} - -// QueryParams implements `engine.URL#QueryParams` function. -func (u *URL) QueryParams() map[string][]string { - if u.query == nil { - u.query = u.Query() - } - return map[string][]string(u.query) -} - -// QueryString implements `engine.URL#QueryString` function. -func (u *URL) QueryString() string { - return u.URL.RawQuery -} - -func (u *URL) reset(url *url.URL) { - u.URL = url - u.query = nil -} diff --git a/engine/standard/url_test.go b/engine/standard/url_test.go deleted file mode 100644 index 33d06c30c..000000000 --- a/engine/standard/url_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package standard - -import ( - "github.com/labstack/echo/engine/test" - "github.com/stretchr/testify/assert" - "net/url" - "testing" -) - -func TestURL(t *testing.T) { - u, _ := url.Parse("https://github.com/labstack/echo?param1=value1¶m1=value2¶m2=value3") - mUrl := &URL{u, nil} - test.URLTest(t, mUrl) - - mUrl.reset(&url.URL{}) - assert.Equal(t, "", mUrl.Host) -} diff --git a/engine/test/test_helpers.go b/engine/test/test_helpers.go deleted file mode 100644 index f7d40dcfb..000000000 --- a/engine/test/test_helpers.go +++ /dev/null @@ -1,62 +0,0 @@ -package test - -import ( - "testing" - "time" - - "github.com/labstack/echo/engine" - "github.com/stretchr/testify/assert" -) - -func HeaderTest(t *testing.T, header engine.Header) { - h := "X-My-Header" - v := "value" - nv := "new value" - h1 := "X-Another-Header" - - header.Add(h, v) - assert.Equal(t, v, header.Get(h)) - - header.Set(h, nv) - assert.Equal(t, nv, header.Get(h)) - - assert.True(t, header.Contains(h)) - - header.Del(h) - assert.False(t, header.Contains(h)) - - header.Add(h, v) - header.Add(h1, v) - - for _, expected := range []string{h, h1} { - found := false - for _, actual := range header.Keys() { - if actual == expected { - found = true - break - } - } - if !found { - t.Errorf("Header %s not found", expected) - } - } -} - -func URLTest(t *testing.T, url engine.URL) { - path := "/echo/test" - url.SetPath(path) - assert.Equal(t, path, url.Path()) - assert.Equal(t, map[string][]string{"param1": []string{"value1", "value2"}, "param2": []string{"value3"}}, url.QueryParams()) - assert.Equal(t, "value1", url.QueryParam("param1")) - assert.Equal(t, "param1=value1¶m1=value2¶m2=value3", url.QueryString()) -} - -func CookieTest(t *testing.T, cookie engine.Cookie) { - assert.Equal(t, "github.com", cookie.Domain()) - assert.Equal(t, time.Date(2016, time.January, 1, 0, 0, 0, 0, time.UTC), cookie.Expires()) - assert.True(t, cookie.HTTPOnly()) - assert.True(t, cookie.Secure()) - assert.Equal(t, "session", cookie.Name()) - assert.Equal(t, "/", cookie.Path()) - assert.Equal(t, "securetoken", cookie.Value()) -} diff --git a/engine/test/test_request.go b/engine/test/test_request.go deleted file mode 100644 index 3edebac6f..000000000 --- a/engine/test/test_request.go +++ /dev/null @@ -1,97 +0,0 @@ -package test - -import ( - "io/ioutil" - "strings" - "testing" - - "github.com/labstack/echo/engine" - "github.com/stretchr/testify/assert" -) - -const MultipartRequest = `POST /labstack/echo HTTP/1.1 -Host: github.com -Connection: close -User-Agent: Mozilla/5.0 (Macintosh; U; Intel Mac OS X; de-de) AppleWebKit/523.10.3 (KHTML, like Gecko) Version/3.0.4 Safari/523.10 -Content-Type: multipart/form-data; boundary=Asrf456BGe4h -Content-Length: 261 -Accept-Encoding: gzip -Accept-Charset: ISO-8859-1,UTF-8;q=0.7,*;q=0.7 -Cache-Control: no-cache -Accept-Language: de,en;q=0.7,en-us;q=0.3 -Referer: https://github.com/ -Cookie: session=securetoken; user=123 -X-Real-IP: 192.168.1.1 - ---Asrf456BGe4h -Content-Disposition: form-data; name="foo" - -bar ---Asrf456BGe4h -Content-Disposition: form-data; name="baz" - -bat ---Asrf456BGe4h -Content-Disposition: form-data; name="note"; filename="note.txt" -Content-Type: text/plain - -Hello world! ---Asrf456BGe4h-- -` - -func RequestTest(t *testing.T, request engine.Request) { - assert.Equal(t, "github.com", request.Host()) - request.SetHost("labstack.com") - assert.Equal(t, "labstack.com", request.Host()) - request.SetURI("/labstack/echo?token=54321") - assert.Equal(t, "/labstack/echo?token=54321", request.URI()) - assert.Equal(t, "/labstack/echo", request.URL().Path()) - assert.Equal(t, "https://github.com/", request.Referer()) - assert.Equal(t, "192.168.1.1", request.Header().Get("X-Real-IP")) - assert.Equal(t, "http", request.Scheme()) - assert.Equal(t, "Mozilla/5.0 (Macintosh; U; Intel Mac OS X; de-de) AppleWebKit/523.10.3 (KHTML, like Gecko) Version/3.0.4 Safari/523.10", request.UserAgent()) - assert.Equal(t, "127.0.0.1", request.RemoteAddress()) - assert.Equal(t, "192.168.1.1", request.RealIP()) - assert.Equal(t, "POST", request.Method()) - assert.Equal(t, int64(261), request.ContentLength()) - assert.Equal(t, "bar", request.FormValue("foo")) - - if fHeader, err := request.FormFile("note"); assert.NoError(t, err) { - if file, err := fHeader.Open(); assert.NoError(t, err) { - text, _ := ioutil.ReadAll(file) - assert.Equal(t, "Hello world!", string(text)) - } - } - - assert.Equal(t, map[string][]string{"baz": []string{"bat"}, "foo": []string{"bar"}}, request.FormParams()) - - if form, err := request.MultipartForm(); assert.NoError(t, err) { - _, ok := form.File["note"] - assert.True(t, ok) - } - - request.SetMethod("PUT") - assert.Equal(t, "PUT", request.Method()) - - request.SetBody(strings.NewReader("Hello")) - if body, err := ioutil.ReadAll(request.Body()); assert.NoError(t, err) { - assert.Equal(t, "Hello", string(body)) - } - - if cookie, err := request.Cookie("session"); assert.NoError(t, err) { - assert.Equal(t, "session", cookie.Name()) - assert.Equal(t, "securetoken", cookie.Value()) - } - - _, err := request.Cookie("foo") - assert.Error(t, err) - - // Cookies - cs := request.Cookies() - if assert.Len(t, cs, 2) { - assert.Equal(t, "session", cs[0].Name()) - assert.Equal(t, "securetoken", cs[0].Value()) - assert.Equal(t, "user", cs[1].Name()) - assert.Equal(t, "123", cs[1].Value()) - } -} diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index b355f3507..f24fdc62a 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -60,7 +60,7 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header().Get(echo.HeaderAuthorization) + auth := c.Request().Header.Get(echo.HeaderAuthorization) l := len(basic) if len(auth) > l+1 && auth[:l] == basic { diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index fac0cbda4..85c00db86 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -3,17 +3,17 @@ package middleware import ( "encoding/base64" "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - res := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + res := httptest.NewRecorder() c := e.NewContext(req, res) f := func(u, p string) bool { if u == "joe" && p == "secret" { @@ -27,24 +27,24 @@ func TestBasicAuth(t *testing.T) { // Valid credentials auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header().Set(echo.HeaderAuthorization, auth) + req.Header.Set(echo.HeaderAuthorization, auth) assert.NoError(t, h(c)) // Incorrect password auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) - req.Header().Set(echo.HeaderAuthorization, auth) + req.Header.Set(echo.HeaderAuthorization, auth) he := h(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) // Empty Authorization header - req.Header().Set(echo.HeaderAuthorization, "") + req.Header.Set(echo.HeaderAuthorization, "") he = h(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header().Set(echo.HeaderAuthorization, auth) + req.Header.Set(echo.HeaderAuthorization, auth) he = h(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code) } diff --git a/middleware/body_limit.go b/middleware/body_limit.go index 6bcf5223d..84654bbc7 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -23,7 +23,7 @@ type ( limitedReader struct { BodyLimitConfig - reader io.Reader + reader io.ReadCloser read int64 context echo.Context } @@ -74,15 +74,15 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { req := c.Request() // Based on content length - if req.ContentLength() > config.limit { + if req.ContentLength > config.limit { return echo.ErrStatusRequestEntityTooLarge } // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body(), c) + r.Reset(req.Body, c) defer pool.Put(r) - req.SetBody(r) + req.Body = r return next(c) } @@ -98,7 +98,11 @@ func (r *limitedReader) Read(b []byte) (n int, err error) { return } -func (r *limitedReader) Reset(reader io.Reader, context echo.Context) { +func (r *limitedReader) Close() error { + return r.reader.Close() +} + +func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { r.reader = reader r.context = context } diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index edd9cda6f..17fdff2fe 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -4,21 +4,22 @@ import ( "bytes" "io/ioutil" "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestBodyLimit(t *testing.T) { + return e := echo.New() hw := []byte("Hello, World!") - req := test.NewRequest(echo.POST, "/", bytes.NewReader(hw)) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.POST, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body()) + body, err := ioutil.ReadAll(c.Request().Body) if err != nil { return err } @@ -27,8 +28,8 @@ func TestBodyLimit(t *testing.T) { // Based on content length (within limit) if assert.NoError(t, BodyLimit("2M")(h)(c)) { - assert.Equal(t, http.StatusOK, rec.Status()) - assert.Equal(t, hw, rec.Body.Bytes()) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes) } // Based on content read (overlimit) @@ -36,17 +37,17 @@ func TestBodyLimit(t *testing.T) { assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) - req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw)) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.POST, "/", bytes.NewReader(hw)) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) if assert.NoError(t, BodyLimit("2M")(h)(c)) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "Hello, World!", rec.Body.String()) } // Based on content read (overlimit) - req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw)) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.POST, "/", bytes.NewReader(hw)) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) he = BodyLimit("2B")(h)(c).(*echo.HTTPError) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) diff --git a/middleware/compress.go b/middleware/compress.go index 78e67c973..d5aa22bac 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -1,15 +1,16 @@ package middleware import ( + "bufio" "compress/gzip" "io" "io/ioutil" + "net" "net/http" "strings" "sync" "github.com/labstack/echo" - "github.com/labstack/echo/engine" ) type ( @@ -24,8 +25,8 @@ type ( } gzipResponseWriter struct { - engine.Response io.Writer + http.ResponseWriter } ) @@ -65,36 +66,51 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) - if strings.Contains(c.Request().Header().Get(echo.HeaderAcceptEncoding), scheme) { + if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), scheme) { rw := res.Writer() - gw := pool.Get().(*gzip.Writer) - gw.Reset(rw) + w := pool.Get().(*gzip.Writer) + w.Reset(c.Response().Writer()) + // rw := res.Writer() + // gw := pool.Get().(*gzip.Writer) + // gw.Reset(rw) defer func() { - if res.Size() == 0 { + if res.Size == 0 { // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. res.SetWriter(rw) res.Header().Del(echo.HeaderContentEncoding) - gw.Reset(ioutil.Discard) + w.Reset(ioutil.Discard) } - gw.Close() - pool.Put(gw) + w.Close() + pool.Put(w) }() - g := gzipResponseWriter{Response: res, Writer: gw} + grw := gzipResponseWriter{Writer: w, ResponseWriter: res.Writer()} res.Header().Set(echo.HeaderContentEncoding, scheme) - res.SetWriter(g) + res.SetWriter(grw) } return next(c) } } } -func (g gzipResponseWriter) Write(b []byte) (int, error) { - if g.Header().Get(echo.HeaderContentType) == "" { - g.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) +func (w gzipResponseWriter) Write(b []byte) (int, error) { + if w.Header().Get(echo.HeaderContentType) == "" { + w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } - return g.Writer.Write(b) + return w.Writer.Write(b) +} + +func (w gzipResponseWriter) Flush() error { + return w.Writer.(*gzip.Writer).Flush() +} + +func (w gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return w.ResponseWriter.(http.Hijacker).Hijack() +} + +func (w *gzipResponseWriter) CloseNotify() <-chan bool { + return w.ResponseWriter.(http.CloseNotifier).CloseNotify() } func gzipPool(config GzipConfig) sync.Pool { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 7bf0519bc..683757454 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -4,17 +4,17 @@ import ( "bytes" "compress/gzip" "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestGzip(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Skip if no Accept-Encoding header @@ -25,9 +25,9 @@ func TestGzip(t *testing.T) { h(c) assert.Equal(t, "test", rec.Body.String()) - req = test.NewRequest(echo.GET, "/", nil) - req.Header().Set(echo.HeaderAcceptEncoding, "gzip") - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, "gzip") + rec = httptest.NewRecorder() c = e.NewContext(req, rec) // Gzip @@ -45,8 +45,8 @@ func TestGzip(t *testing.T) { func TestGzipNoContent(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Gzip()(func(c echo.Context) error { return c.NoContent(http.StatusOK) @@ -64,9 +64,9 @@ func TestGzipErrorReturned(t *testing.T) { e.GET("/", func(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "error") }) - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Equal(t, "error", rec.Body.String()) } diff --git a/middleware/cors.go b/middleware/cors.go index 69300c4f3..e68d77aaa 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -88,8 +88,8 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() - origin := req.Header().Get(echo.HeaderOrigin) - originSet := req.Header().Contains(echo.HeaderOrigin) // Issue #517 + origin := req.Header.Get(echo.HeaderOrigin) + _, originSet := req.Header[echo.HeaderOrigin] // Check allowed origins allowedOrigin := "" @@ -101,7 +101,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } // Simple request - if req.Method() != echo.OPTIONS { + if req.Method != echo.OPTIONS { res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) if !originSet || allowedOrigin == "" { return next(c) @@ -131,7 +131,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if allowHeaders != "" { res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) } else { - h := req.Header().Get(echo.HeaderAccessControlRequestHeaders) + h := req.Header.Get(echo.HeaderAccessControlRequestHeaders) if h != "" { res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 846c4b340..46efe6868 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -2,17 +2,17 @@ package middleware import ( "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestCORS(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) cors := CORSWithConfig(CORSConfig{ AllowCredentials: true, @@ -26,26 +26,26 @@ func TestCORS(t *testing.T) { assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) // Empty origin header - req = test.NewRequest(echo.GET, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header().Set(echo.HeaderOrigin, "") + req.Header.Set(echo.HeaderOrigin, "") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) // Wildcard origin - req = test.NewRequest(echo.GET, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header().Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) // Simple request - req = test.NewRequest(echo.GET, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header().Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderOrigin, "localhost") cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"localhost"}, AllowCredentials: true, @@ -58,11 +58,11 @@ func TestCORS(t *testing.T) { assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) // Preflight request - req = test.NewRequest(echo.OPTIONS, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.OPTIONS, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header().Set(echo.HeaderOrigin, "localhost") - req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) diff --git a/middleware/csrf.go b/middleware/csrf.go index 543948f37..6d9b18ed0 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -131,10 +131,10 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token = random.String(config.TokenLength) } else { // Reuse token - token = k.Value() + token = k.Value } - switch req.Method() { + switch req.Method { case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 @@ -148,18 +148,18 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } // Set CSRF cookie - cookie := new(echo.Cookie) - cookie.SetName(config.CookieName) - cookie.SetValue(token) + cookie := new(http.Cookie) + cookie.Name = config.CookieName + cookie.Value = token if config.CookiePath != "" { - cookie.SetPath(config.CookiePath) + cookie.Path = config.CookiePath } if config.CookieDomain != "" { - cookie.SetDomain(config.CookieDomain) + cookie.Domain = config.CookieDomain } - cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)) - cookie.SetSecure(config.CookieSecure) - cookie.SetHTTPOnly(config.CookieHTTPOnly) + cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) + cookie.Secure = config.CookieSecure + cookie.HttpOnly = config.CookieHTTPOnly c.SetCookie(cookie) // Store token in the context @@ -177,7 +177,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { // provided request header. func csrfTokenFromHeader(header string) csrfTokenExtractor { return func(c echo.Context) (string, error) { - return c.Request().Header().Get(header), nil + return c.Request().Header.Get(header), nil } } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index a367642a6..7f7d82ec7 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -2,20 +2,20 @@ package middleware import ( "net/http" + "net/http/httptest" "net/url" "strings" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" ) func TestCSRF(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ TokenLength: 16, @@ -29,24 +29,24 @@ func TestCSRF(t *testing.T) { assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") // Without CSRF cookie - req = test.NewRequest(echo.POST, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.POST, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) assert.Error(t, h(c)) // Empty/invalid CSRF token - req = test.NewRequest(echo.POST, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.POST, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) - req.Header().Set(echo.HeaderXCSRFToken, "") + req.Header.Set(echo.HeaderXCSRFToken, "") assert.Error(t, h(c)) // Valid CSRF token token := random.String(16) - req.Header().Set(echo.HeaderCookie, "_csrf="+token) - req.Header().Set(echo.HeaderXCSRFToken, token) + req.Header.Set(echo.HeaderCookie, "_csrf="+token) + req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) } } @@ -54,8 +54,8 @@ func TestCSRFTokenFromForm(t *testing.T) { f := make(url.Values) f.Set("csrf", "token") e := echo.New() - req := test.NewRequest(echo.POST, "/", strings.NewReader(f.Encode())) - req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) + req, _ := http.NewRequest(echo.POST, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) c := e.NewContext(req, nil) token, err := csrfTokenFromForm("csrf")(c) if assert.NoError(t, err) { @@ -69,8 +69,8 @@ func TestCSRFTokenFromQuery(t *testing.T) { q := make(url.Values) q.Set("csrf", "token") e := echo.New() - req := test.NewRequest(echo.GET, "/?"+q.Encode(), nil) - req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) + req, _ := http.NewRequest(echo.GET, "/?"+q.Encode(), nil) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) c := e.NewContext(req, nil) token, err := csrfTokenFromQuery("csrf")(c) if assert.NoError(t, err) { diff --git a/middleware/jwt.go b/middleware/jwt.go index a323935dc..e04cf34f2 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -153,7 +153,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // jwtFromHeader returns a `jwtExtractor` that extracts token from request header. func jwtFromHeader(header string) jwtExtractor { return func(c echo.Context) (string, error) { - auth := c.Request().Header().Get(header) + auth := c.Request().Header.Get(header) l := len(bearer) if len(auth) > l+1 && auth[:l] == bearer { return auth[l+1:], nil @@ -181,6 +181,6 @@ func jwtFromCookie(name string) jwtExtractor { if err != nil { return "", errors.New("empty jwt in cookie") } - return cookie.Value(), nil + return cookie.Value, nil } } diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index d976a8c59..a7241c87d 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -2,11 +2,11 @@ package middleware import ( "net/http" + "net/http/httptest" "testing" "github.com/dgrijalva/jwt-go" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) @@ -148,10 +148,10 @@ func TestJWT(t *testing.T) { tc.reqURL = "/" } - req := test.NewRequest(echo.GET, tc.reqURL, nil) - res := test.NewResponseRecorder() - req.Header().Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header().Set(echo.HeaderCookie, tc.hdrCookie) + req, _ := http.NewRequest(echo.GET, tc.reqURL, nil) + res := httptest.NewRecorder() + req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) + req.Header.Set(echo.HeaderCookie, tc.hdrCookie) c := e.NewContext(req, res) if tc.expPanic { diff --git a/middleware/logger.go b/middleware/logger.go index d588f811d..f2de87bb2 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -117,16 +117,16 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "time_rfc3339": return w.Write([]byte(time.Now().Format(time.RFC3339))) case "remote_ip": - ra := req.RealIP() + ra := c.RealIP() return w.Write([]byte(ra)) case "host": - return w.Write([]byte(req.Host())) + return w.Write([]byte(req.Host)) case "uri": - return w.Write([]byte(req.URI())) + return w.Write([]byte(req.RequestURI)) case "method": - return w.Write([]byte(req.Method())) + return w.Write([]byte(req.Method)) case "path": - p := req.URL().Path() + p := req.URL.Path if p == "" { p = "/" } @@ -136,7 +136,7 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "user_agent": return w.Write([]byte(req.UserAgent())) case "status": - n := res.Status() + n := res.Status s := config.color.Green(n) switch { case n >= 500: @@ -153,13 +153,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "latency_human": return w.Write([]byte(stop.Sub(start).String())) case "bytes_in": - b := req.Header().Get(echo.HeaderContentLength) + b := req.Header.Get(echo.HeaderContentLength) if b == "" { b = "0" } return w.Write([]byte(b)) case "bytes_out": - return w.Write([]byte(strconv.FormatInt(res.Size(), 10))) + return w.Write([]byte(strconv.FormatInt(res.Size, 10))) } return 0, nil }) diff --git a/middleware/logger_test.go b/middleware/logger_test.go index d8e308323..0e7614b88 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -4,18 +4,18 @@ import ( "bytes" "errors" "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestLogger(t *testing.T) { // Note: Just for the test coverage, not a real test. e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Logger()(func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -25,7 +25,7 @@ func TestLogger(t *testing.T) { h(c) // Status 3xx - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = Logger()(func(c echo.Context) error { return c.String(http.StatusTemporaryRedirect, "test") @@ -33,7 +33,7 @@ func TestLogger(t *testing.T) { h(c) // Status 4xx - rec = test.NewResponseRecorder() + rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = Logger()(func(c echo.Context) error { return c.String(http.StatusNotFound, "test") @@ -41,8 +41,8 @@ func TestLogger(t *testing.T) { h(c) // Status 5xx with empty path - req = test.NewRequest(echo.GET, "", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = Logger()(func(c echo.Context) error { return errors.New("error") @@ -52,25 +52,25 @@ func TestLogger(t *testing.T) { func TestLoggerIPAddress(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) buf := new(bytes.Buffer) - e.Logger().SetOutput(buf) + e.Logger.SetOutput(buf) ip := "127.0.0.1" h := Logger()(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) // With X-Real-IP - req.Header().Add(echo.HeaderXRealIP, ip) + req.Header.Add(echo.HeaderXRealIP, ip) h(c) assert.Contains(t, ip, buf.String()) // With X-Forwarded-For buf.Reset() - req.Header().Del(echo.HeaderXRealIP) - req.Header().Add(echo.HeaderXForwardedFor, ip) + req.Header.Del(echo.HeaderXRealIP) + req.Header.Add(echo.HeaderXForwardedFor, ip) h(c) assert.Contains(t, ip, buf.String()) diff --git a/middleware/method_override.go b/middleware/method_override.go index 156fac5bf..dcc94b0f8 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -52,10 +52,10 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } req := c.Request() - if req.Method() == echo.POST { + if req.Method == echo.POST { m := config.Getter(c) if m != "" { - req.SetMethod(m) + req.Method = m } } return next(c) @@ -67,7 +67,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { // the request header. func MethodFromHeader(header string) MethodOverrideGetter { return func(c echo.Context) string { - return c.Request().Header().Get(header) + return c.Request().Header.Get(header) } } diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 964ed1a7b..c5a2a74c4 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -3,10 +3,10 @@ package middleware import ( "bytes" "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) @@ -18,32 +18,32 @@ func TestMethodOverride(t *testing.T) { } // Override with http header - req := test.NewRequest(echo.POST, "/", nil) - rec := test.NewResponseRecorder() - req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) + req, _ := http.NewRequest(echo.POST, "/", nil) + rec := httptest.NewRecorder() + req.Header.Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) c := e.NewContext(req, rec) m(h)(c) - assert.Equal(t, echo.DELETE, req.Method()) + assert.Equal(t, echo.DELETE, req.Method) // Override with form parameter m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE))) - rec = test.NewResponseRecorder() - req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm) + req, _ = http.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE))) + rec = httptest.NewRecorder() + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) c = e.NewContext(req, rec) m(h)(c) - assert.Equal(t, echo.DELETE, req.Method()) + assert.Equal(t, echo.DELETE, req.Method) // Override with query paramter m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) m(h)(c) - assert.Equal(t, echo.DELETE, req.Method()) + assert.Equal(t, echo.DELETE, req.Method) // Ignore `GET` - req = test.NewRequest(echo.GET, "/", nil) - req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) - assert.Equal(t, echo.GET, req.Method()) + req, _ = http.NewRequest(echo.GET, "/", nil) + req.Header.Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) + assert.Equal(t, echo.GET, req.Method) } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 878c4140e..6a46cc21b 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -3,24 +3,24 @@ package middleware import ( "bytes" "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.SetLogOutput(buf) - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + e.Logger.SetOutput(buf) + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Recover()(echo.HandlerFunc(func(c echo.Context) error { panic("test") })) h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Status()) + assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, buf.String(), "PANIC RECOVER") } diff --git a/middleware/redirect.go b/middleware/redirect.go index 43863c724..a41d752ce 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -52,9 +52,10 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { } req := c.Request() - host := req.Host() - uri := req.URI() - if !req.IsTLS() { + host := req.Host + uri := req.RequestURI + println(uri) + if !c.IsTLS() { return c.Redirect(config.Code, "https://"+host+uri) } return next(c) @@ -88,9 +89,9 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { } req := c.Request() - host := req.Host() - uri := req.URI() - if !req.IsTLS() && host[:3] != "www" { + host := req.Host + uri := req.RequestURI + if !c.IsTLS() && host[:3] != "www" { return c.Redirect(http.StatusMovedPermanently, "https://www."+host+uri) } return next(c) @@ -124,10 +125,10 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { } req := c.Request() - scheme := req.Scheme() - host := req.Host() + scheme := c.Scheme() + host := req.Host if host[:3] != "www" { - uri := req.URI() + uri := req.RequestURI return c.Redirect(http.StatusMovedPermanently, scheme+"://www."+host+uri) } return next(c) @@ -160,10 +161,10 @@ func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { } req := c.Request() - scheme := req.Scheme() - host := req.Host() + scheme := c.Scheme() + host := req.Host if host[:3] == "www" { - uri := req.URI() + uri := req.RequestURI return c.Redirect(http.StatusMovedPermanently, scheme+"://"+host[4:]+uri) } return next(c) diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 1da64fd6c..b5a36804c 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -2,61 +2,61 @@ package middleware import ( "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) -func TestHTTPSRedirect(t *testing.T) { +func TestRedirectHTTPSRedirect(t *testing.T) { e := echo.New() next := func(c echo.Context) (err error) { return c.NoContent(http.StatusOK) } - req := test.NewRequest(echo.GET, "http://labstack.com", nil) - res := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "http://labstack.com", nil) + res := httptest.NewRecorder() c := e.NewContext(req, res) HTTPSRedirect()(next)(c) - assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, http.StatusMovedPermanently, res.Code) assert.Equal(t, "https://labstack.com", res.Header().Get(echo.HeaderLocation)) } -func TestHTTPSWWWRedirect(t *testing.T) { +func TestRedirectHTTPSWWWRedirect(t *testing.T) { e := echo.New() next := func(c echo.Context) (err error) { return c.NoContent(http.StatusOK) } - req := test.NewRequest(echo.GET, "http://labstack.com", nil) - res := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "http://labstack.com", nil) + res := httptest.NewRecorder() c := e.NewContext(req, res) HTTPSWWWRedirect()(next)(c) - assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, http.StatusMovedPermanently, res.Code) assert.Equal(t, "https://www.labstack.com", res.Header().Get(echo.HeaderLocation)) } -func TestWWWRedirect(t *testing.T) { +func TestRedirectWWWRedirect(t *testing.T) { e := echo.New() next := func(c echo.Context) (err error) { return c.NoContent(http.StatusOK) } - req := test.NewRequest(echo.GET, "http://labstack.com", nil) - res := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "http://labstack.com", nil) + res := httptest.NewRecorder() c := e.NewContext(req, res) WWWRedirect()(next)(c) - assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, http.StatusMovedPermanently, res.Code) assert.Equal(t, "http://www.labstack.com", res.Header().Get(echo.HeaderLocation)) } -func TestNonWWWRedirect(t *testing.T) { +func TestRedirectNonWWWRedirect(t *testing.T) { e := echo.New() next := func(c echo.Context) (err error) { return c.NoContent(http.StatusOK) } - req := test.NewRequest(echo.GET, "http://www.labstack.com", nil) - res := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "http://www.labstack.com", nil) + res := httptest.NewRecorder() c := e.NewContext(req, res) NonWWWRedirect()(next)(c) - assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, http.StatusMovedPermanently, res.Code) assert.Equal(t, "http://labstack.com", res.Header().Get(echo.HeaderLocation)) } diff --git a/middleware/secure.go b/middleware/secure.go index 0dc42aa21..725f8f61a 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -100,7 +100,7 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc { if config.XFrameOptions != "" { res.Header().Set(echo.HeaderXFrameOptions, config.XFrameOptions) } - if (req.IsTLS() || (req.Header().Get(echo.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 { + if (c.IsTLS() || (req.Header.Get(echo.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 { subdomains := "" if !config.HSTSExcludeSubdomains { subdomains = "; includeSubdomains" diff --git a/middleware/secure_test.go b/middleware/secure_test.go index dabdb651b..d202a682d 100644 --- a/middleware/secure_test.go +++ b/middleware/secure_test.go @@ -2,17 +2,17 @@ package middleware import ( "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestSecure(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -27,8 +27,8 @@ func TestSecure(t *testing.T) { assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy)) // Custom - req.Header().Set(echo.HeaderXForwardedProto, "https") - rec = test.NewResponseRecorder() + req.Header.Set(echo.HeaderXForwardedProto, "https") + rec = httptest.NewRecorder() c = e.NewContext(req, rec) SecureWithConfig(SecureConfig{ XSSProtection: "", diff --git a/middleware/slash.go b/middleware/slash.go index ccd9498c8..540bae3f7 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -46,9 +46,9 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc } req := c.Request() - url := req.URL() - path := url.Path() - qs := url.QueryString() + url := req.URL + path := url.Path + qs := c.QueryString() if path != "/" && path[len(path)-1] != '/' { path += "/" uri := path @@ -62,8 +62,8 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc } // Forward - req.SetURI(uri) - url.SetPath(path) + req.RequestURI = uri + url.Path = path } return next(c) } @@ -93,9 +93,9 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu } req := c.Request() - url := req.URL() - path := url.Path() - qs := url.QueryString() + url := req.URL + path := url.Path + qs := c.QueryString() l := len(path) - 1 if l >= 0 && path != "/" && path[l] == '/' { path = path[:l] @@ -110,8 +110,8 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu } // Forward - req.SetURI(uri) - url.SetPath(path) + req.RequestURI = uri + url.Path = path } return next(c) } diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 084703d0a..48a25cecb 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -2,28 +2,28 @@ package middleware import ( "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestAddTrailingSlash(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/add-slash", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/add-slash", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := AddTrailingSlash()(func(c echo.Context) error { return nil }) h(c) - assert.Equal(t, "/add-slash/", req.URL().Path()) - assert.Equal(t, "/add-slash/", req.URI()) + assert.Equal(t, "/add-slash/", req.URL.Path) + assert.Equal(t, "/add-slash/", req.RequestURI) // With config - req = test.NewRequest(echo.GET, "/add-slash?key=value", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/add-slash?key=value", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = AddTrailingSlashWithConfig(TrailingSlashConfig{ RedirectCode: http.StatusMovedPermanently, @@ -31,25 +31,25 @@ func TestAddTrailingSlash(t *testing.T) { return nil }) h(c) - assert.Equal(t, http.StatusMovedPermanently, rec.Status()) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) assert.Equal(t, "/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation)) } func TestRemoveTrailingSlash(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/remove-slash/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/remove-slash/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := RemoveTrailingSlash()(func(c echo.Context) error { return nil }) h(c) - assert.Equal(t, "/remove-slash", req.URL().Path()) - assert.Equal(t, "/remove-slash", req.URI()) + assert.Equal(t, "/remove-slash", req.URL.Path) + assert.Equal(t, "/remove-slash", req.RequestURI) // With config - req = test.NewRequest(echo.GET, "/remove-slash/?key=value", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/remove-slash/?key=value", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{ RedirectCode: http.StatusMovedPermanently, @@ -57,16 +57,16 @@ func TestRemoveTrailingSlash(t *testing.T) { return nil }) h(c) - assert.Equal(t, http.StatusMovedPermanently, rec.Status()) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) assert.Equal(t, "/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation)) // With bare URL - req = test.NewRequest(echo.GET, "http://localhost", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "http://localhost", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = RemoveTrailingSlash()(func(c echo.Context) error { return nil }) h(c) - assert.Equal(t, "", req.URL().Path()) + assert.Equal(t, "", req.URL.Path) } diff --git a/middleware/static.go b/middleware/static.go index 1aed32da5..8155a71f8 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -68,7 +68,7 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } fs := http.Dir(config.Root) - p := c.Request().URL().Path() + p := c.Request().URL.Path if strings.Contains(c.Path(), "*") { // If serving from a group, e.g. `/static*`. p = c.P(0) } diff --git a/middleware/static_test.go b/middleware/static_test.go index b1ce6649e..488053eab 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -2,17 +2,17 @@ package middleware import ( "net/http" + "net/http/httptest" "testing" "github.com/labstack/echo" - "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) func TestStatic(t *testing.T) { e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - rec := test.NewResponseRecorder() + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Static("../_fixture")(func(c echo.Context) error { return echo.ErrNotFound @@ -24,8 +24,8 @@ func TestStatic(t *testing.T) { } // HTML5 mode - req = test.NewRequest(echo.GET, "/client", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/client", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) static := StaticWithConfig(StaticConfig{ Root: "../_fixture", @@ -35,12 +35,12 @@ func TestStatic(t *testing.T) { return echo.ErrNotFound }) if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, http.StatusOK, rec.Code) } // Browse - req = test.NewRequest(echo.GET, "/", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) static = StaticWithConfig(StaticConfig{ Root: "../_fixture/images", @@ -54,8 +54,8 @@ func TestStatic(t *testing.T) { } // Not found - req = test.NewRequest(echo.GET, "/not-found", nil) - rec = test.NewResponseRecorder() + req, _ = http.NewRequest(echo.GET, "/not-found", nil) + rec = httptest.NewRecorder() c = e.NewContext(req, rec) static = StaticWithConfig(StaticConfig{ Root: "../_fixture/images", diff --git a/response.go b/response.go new file mode 100644 index 000000000..a44f3ff03 --- /dev/null +++ b/response.go @@ -0,0 +1,99 @@ +package echo + +import ( + "bufio" + "net" + "net/http" +) + +type ( + // Response wraps an http.ResponseWriter and implements its interface to be used + // by an HTTP handler to construct an HTTP response. + // See: https://golang.org/pkg/net/http/#ResponseWriter + Response struct { + writer http.ResponseWriter + Status int + Size int64 + Committed bool + echo *Echo + } +) + +// NewResponse creates a new instance of Response. +func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) { + return &Response{writer: w, echo: e} +} + +// SetWriter sets the http.ResponseWriter instance for this Response. +func (r *Response) SetWriter(w http.ResponseWriter) { + r.writer = w +} + +// Writer returns the http.ResponseWriter instance for this Response. +func (r *Response) Writer() http.ResponseWriter { + return r.writer +} + +// Header returns the header map for the writer that will be sent by +// WriteHeader. Changing the header after a call to WriteHeader (or Write) has +// no effect unless the modified headers were declared as trailers by setting +// the "Trailer" header before the call to WriteHeader (see example) +// To suppress implicit response headers, set their value to nil. +// Example: https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +func (r *Response) Header() http.Header { + return r.writer.Header() +} + +// WriteHeader sends an HTTP response header with status code. If WriteHeader is +// not called explicitly, the first call to Write will trigger an implicit +// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly +// used to send error codes. +func (r *Response) WriteHeader(code int) { + if r.Committed { + r.echo.Logger.Warn("response already committed") + return + } + r.Status = code + r.writer.WriteHeader(code) + r.Committed = true +} + +// Write writes the data to the connection as part of an HTTP reply. +func (r *Response) Write(b []byte) (n int, err error) { + if !r.Committed { + r.WriteHeader(http.StatusOK) + } + n, err = r.writer.Write(b) + r.Size += int64(n) + return +} + +// Flush implements the http.Flusher interface to allow an HTTP handler to flush +// buffered data to the client. +// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher) +func (r *Response) Flush() { + r.writer.(http.Flusher).Flush() +} + +// Hijack implements the http.Hijacker interface to allow an HTTP handler to +// take over the connection. +// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker) +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return r.writer.(http.Hijacker).Hijack() +} + +// CloseNotify implements the http.CloseNotifier interface to allow detecting +// when the underlying connection has gone away. +// This mechanism can be used to cancel long operations on the server if the +// client has disconnected before the response is ready. +// See [http.CloseNotifier](https://golang.org/pkg/net/http/#CloseNotifier) +func (r *Response) CloseNotify() <-chan bool { + return r.writer.(http.CloseNotifier).CloseNotify() +} + +func (r *Response) reset(w http.ResponseWriter) { + r.writer = w + r.Size = 0 + r.Status = http.StatusOK + r.Committed = false +} diff --git a/test/cookie.go b/test/cookie.go deleted file mode 100644 index 551ec4df4..000000000 --- a/test/cookie.go +++ /dev/null @@ -1,48 +0,0 @@ -package test - -import ( - "net/http" - "time" -) - -type ( - // Cookie implements `engine.Cookie`. - Cookie struct { - *http.Cookie - } -) - -// Name implements `engine.Cookie#Name` function. -func (c *Cookie) Name() string { - return c.Cookie.Name -} - -// Value implements `engine.Cookie#Value` function. -func (c *Cookie) Value() string { - return c.Cookie.Value -} - -// Path implements `engine.Cookie#Path` function. -func (c *Cookie) Path() string { - return c.Cookie.Path -} - -// Domain implements `engine.Cookie#Domain` function. -func (c *Cookie) Domain() string { - return c.Cookie.Domain -} - -// Expires implements `engine.Cookie#Expires` function. -func (c *Cookie) Expires() time.Time { - return c.Cookie.Expires -} - -// Secure implements `engine.Cookie#Secure` function. -func (c *Cookie) Secure() bool { - return c.Cookie.Secure -} - -// HTTPOnly implements `engine.Cookie#HTTPOnly` function. -func (c *Cookie) HTTPOnly() bool { - return c.Cookie.HttpOnly -} diff --git a/test/header.go b/test/header.go deleted file mode 100644 index 57e81fc64..000000000 --- a/test/header.go +++ /dev/null @@ -1,44 +0,0 @@ -package test - -import "net/http" - -type ( - Header struct { - header http.Header - } -) - -func (h *Header) Add(key, val string) { - h.header.Add(key, val) -} - -func (h *Header) Del(key string) { - h.header.Del(key) -} - -func (h *Header) Get(key string) string { - return h.header.Get(key) -} - -func (h *Header) Set(key, val string) { - h.header.Set(key, val) -} - -func (h *Header) Keys() (keys []string) { - keys = make([]string, len(h.header)) - i := 0 - for k := range h.header { - keys[i] = k - i++ - } - return -} - -func (h *Header) Contains(key string) bool { - _, ok := h.header[key] - return ok -} - -func (h *Header) reset(hdr http.Header) { - h.header = hdr -} diff --git a/test/request.go b/test/request.go deleted file mode 100644 index dd5571cc5..000000000 --- a/test/request.go +++ /dev/null @@ -1,176 +0,0 @@ -package test - -import ( - "errors" - "io" - "io/ioutil" - "mime/multipart" - "net" - "net/http" - "strings" - - "github.com/labstack/echo/engine" -) - -type ( - Request struct { - request *http.Request - url engine.URL - header engine.Header - } -) - -const ( - defaultMemory = 32 << 20 // 32 MB -) - -func NewRequest(method, url string, body io.Reader) engine.Request { - r, _ := http.NewRequest(method, url, body) - return &Request{ - request: r, - url: &URL{url: r.URL}, - header: &Header{r.Header}, - } -} - -func (r *Request) IsTLS() bool { - return r.request.TLS != nil -} - -func (r *Request) Scheme() string { - if r.IsTLS() { - return "https" - } - return "http" -} - -func (r *Request) Host() string { - return r.request.Host -} - -func (r *Request) SetHost(host string) { - r.request.Host = host -} - -func (r *Request) URL() engine.URL { - return r.url -} - -func (r *Request) Header() engine.Header { - return r.header -} - -func (r *Request) Referer() string { - return r.request.Referer() -} - -// func Proto() string { -// return r.request.Proto() -// } -// -// func ProtoMajor() int { -// return r.request.ProtoMajor() -// } -// -// func ProtoMinor() int { -// return r.request.ProtoMinor() -// } - -func (r *Request) ContentLength() int64 { - return r.request.ContentLength -} - -func (r *Request) UserAgent() string { - return r.request.UserAgent() -} - -func (r *Request) RemoteAddress() string { - return r.request.RemoteAddr -} - -func (r *Request) RealIP() string { - ra := r.RemoteAddress() - if ip := r.Header().Get("X-Forwarded-For"); ip != "" { - ra = ip - } else if ip := r.Header().Get("X-Real-IP"); ip != "" { - ra = ip - } else { - ra, _, _ = net.SplitHostPort(ra) - } - return ra -} - -func (r *Request) Method() string { - return r.request.Method -} - -func (r *Request) SetMethod(method string) { - r.request.Method = method -} - -func (r *Request) URI() string { - return r.request.RequestURI -} - -func (r *Request) SetURI(uri string) { - r.request.RequestURI = uri -} - -func (r *Request) Body() io.Reader { - return r.request.Body -} - -func (r *Request) SetBody(reader io.Reader) { - r.request.Body = ioutil.NopCloser(reader) -} - -func (r *Request) FormValue(name string) string { - return r.request.FormValue(name) -} - -func (r *Request) FormParams() map[string][]string { - if strings.HasPrefix(r.header.Get("Content-Type"), "multipart/form-data") { - if err := r.request.ParseMultipartForm(defaultMemory); err != nil { - panic(err) - } - } else { - if err := r.request.ParseForm(); err != nil { - panic(err) - } - } - return map[string][]string(r.request.Form) -} - -func (r *Request) FormFile(name string) (*multipart.FileHeader, error) { - _, fh, err := r.request.FormFile(name) - return fh, err -} - -func (r *Request) MultipartForm() (*multipart.Form, error) { - err := r.request.ParseMultipartForm(defaultMemory) - return r.request.MultipartForm, err -} - -func (r *Request) Cookie(name string) (engine.Cookie, error) { - c, err := r.request.Cookie(name) - if err != nil { - return nil, errors.New("cookie not found") - } - return &Cookie{c}, nil -} - -// Cookies implements `engine.Request#Cookies` function. -func (r *Request) Cookies() []engine.Cookie { - cs := r.request.Cookies() - cookies := make([]engine.Cookie, len(cs)) - for i, c := range cs { - cookies[i] = &Cookie{c} - } - return cookies -} - -func (r *Request) reset(req *http.Request, h engine.Header, u engine.URL) { - r.request = req - r.header = h - r.url = u -} diff --git a/test/response.go b/test/response.go deleted file mode 100644 index dc6146932..000000000 --- a/test/response.go +++ /dev/null @@ -1,103 +0,0 @@ -package test - -import ( - "bytes" - "io" - "net/http" - "net/http/httptest" - - "github.com/labstack/echo/engine" - "github.com/labstack/gommon/log" -) - -type ( - Response struct { - response http.ResponseWriter - header engine.Header - status int - size int64 - committed bool - writer io.Writer - logger *log.Logger - } - - ResponseRecorder struct { - engine.Response - Body *bytes.Buffer - } -) - -func NewResponseRecorder() *ResponseRecorder { - rec := httptest.NewRecorder() - return &ResponseRecorder{ - Response: &Response{ - response: rec, - header: &Header{rec.Header()}, - writer: rec, - logger: log.New("test"), - }, - Body: rec.Body, - } -} - -func (r *Response) Header() engine.Header { - return r.header -} - -func (r *Response) WriteHeader(code int) { - if r.committed { - r.logger.Warn("response already committed") - return - } - r.status = code - r.response.WriteHeader(code) - r.committed = true -} - -func (r *Response) Write(b []byte) (n int, err error) { - n, err = r.writer.Write(b) - r.size += int64(n) - return -} - -// SetCookie implements `engine.Response#SetCookie` function. -func (r *Response) SetCookie(c engine.Cookie) { - http.SetCookie(r.response, &http.Cookie{ - Name: c.Name(), - Value: c.Value(), - Path: c.Path(), - Domain: c.Domain(), - Expires: c.Expires(), - Secure: c.Secure(), - HttpOnly: c.HTTPOnly(), - }) -} - -func (r *Response) Status() int { - return r.status -} - -func (r *Response) Size() int64 { - return r.size -} - -func (r *Response) Committed() bool { - return r.committed -} - -func (r *Response) SetWriter(w io.Writer) { - r.writer = w -} - -func (r *Response) Writer() io.Writer { - return r.writer -} - -func (r *Response) reset(w http.ResponseWriter, h engine.Header) { - r.response = w - r.header = h - r.status = http.StatusOK - r.size = 0 - r.committed = false - r.writer = w -} diff --git a/test/server.go b/test/server.go deleted file mode 100644 index b6f2815f1..000000000 --- a/test/server.go +++ /dev/null @@ -1,129 +0,0 @@ -package test - -import ( - "net/http" - "sync" - - "github.com/labstack/echo/engine" - "github.com/labstack/gommon/log" -) - -type ( - Server struct { - *http.Server - config *engine.Config - handler engine.Handler - pool *Pool - logger *log.Logger - } - - Pool struct { - request sync.Pool - response sync.Pool - header sync.Pool - url sync.Pool - } -) - -func New(addr string) *Server { - c := &engine.Config{Address: addr} - return NewConfig(c) -} - -func NewTLS(addr, certFile, keyFile string) *Server { - c := &engine.Config{ - Address: addr, - TLSCertFile: certFile, - TLSKeyFile: keyFile, - } - return NewConfig(c) -} - -func NewConfig(c *engine.Config) (s *Server) { - s = &Server{ - Server: new(http.Server), - config: c, - pool: &Pool{ - request: sync.Pool{ - New: func() interface{} { - return &Request{} - }, - }, - response: sync.Pool{ - New: func() interface{} { - return &Response{logger: s.logger} - }, - }, - header: sync.Pool{ - New: func() interface{} { - return &Header{} - }, - }, - url: sync.Pool{ - New: func() interface{} { - return &URL{} - }, - }, - }, - handler: engine.HandlerFunc(func(req engine.Request, res engine.Response) { - panic("echo: handler not set, use `Server#SetHandler()` to set it.") - }), - logger: log.New("echo"), - } - return -} - -func (s *Server) SetHandler(h engine.Handler) { - s.handler = h -} - -func (s *Server) SetLogger(l *log.Logger) { - s.logger = l -} - -func (s *Server) Start() error { - if s.config.Listener == nil { - return s.startDefaultListener() - } - return s.startCustomListener() -} - -func (s *Server) Stop() error { - return nil -} - -func (s *Server) startDefaultListener() error { - c := s.config - if c.TLSCertFile != "" && c.TLSKeyFile != "" { - return s.ListenAndServeTLS(c.TLSCertFile, c.TLSKeyFile) - } - return s.ListenAndServe() -} - -func (s *Server) startCustomListener() error { - return s.Serve(s.config.Listener) -} - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Request - req := s.pool.request.Get().(*Request) - reqHdr := s.pool.header.Get().(*Header) - reqURL := s.pool.url.Get().(*URL) - reqHdr.reset(r.Header) - reqURL.reset(r.URL) - req.reset(r, reqHdr, reqURL) - - // Response - res := s.pool.response.Get().(*Response) - resHdr := s.pool.header.Get().(*Header) - resHdr.reset(w.Header()) - res.reset(w, resHdr) - - s.handler.ServeHTTP(req, res) - - s.pool.request.Put(req) - s.pool.header.Put(reqHdr) - s.pool.url.Put(reqURL) - s.pool.response.Put(res) - s.pool.header.Put(resHdr) -} diff --git a/test/url.go b/test/url.go deleted file mode 100644 index 664fac305..000000000 --- a/test/url.go +++ /dev/null @@ -1,44 +0,0 @@ -package test - -import "net/url" - -type ( - URL struct { - url *url.URL - query url.Values - } -) - -func (u *URL) URL() *url.URL { - return u.url -} - -func (u *URL) SetPath(path string) { - u.url.Path = path -} - -func (u *URL) Path() string { - return u.url.Path -} - -func (u *URL) QueryParam(name string) string { - if u.query == nil { - u.query = u.url.Query() - } - return u.query.Get(name) -} - -func (u *URL) QueryParams() map[string][]string { - if u.query == nil { - u.query = u.url.Query() - } - return map[string][]string(u.query) -} - -func (u *URL) QueryString() string { - return u.url.RawQuery -} - -func (u *URL) reset(url *url.URL) { - u.url = url -}