Skip to content

Commit

Permalink
Added panic recover middleware
Browse files Browse the repository at this point in the history
Signed-off-by: Vishal Rana <[email protected]>
  • Loading branch information
vishr committed May 18, 2015
1 parent 609879b commit 73fa05f
Show file tree
Hide file tree
Showing 18 changed files with 166 additions and 99 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,21 @@ func main() {
// Echo instance
e := echo.New()

//------------
// Middleware
//------------

// Recover
e.Use(mw.Recover())

// Logger
e.Use(mw.Logger())

// Routes
e.Get("/", hello)

// Start server
e.Run(":1323)
e.Run(":1323")
}
```

Expand Down
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

type (
// Context represents context for the current request. It holds request and
// response references, path parameters, data and registered handler.
// response objects, path parameters, data and registered handler.
Context struct {
Request *http.Request
Response *Response
Expand Down
53 changes: 23 additions & 30 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ type (
prefix string
middleware []MiddlewareFunc
maxParam byte
notFoundHandler HandlerFunc
httpErrorHandler HTTPErrorHandler
binder BindFunc
renderer Renderer
uris map[Handler]string
pool sync.Pool
debug bool
}
HTTPError struct {
Code int
Expand Down Expand Up @@ -115,8 +115,8 @@ var (
// Errors
//--------

UnsupportedMediaType = errors.New("echo: unsupported media type")
RendererNotRegistered = errors.New("echo: renderer not registered")
UnsupportedMediaType = errors.New("echo unsupported media type")
RendererNotRegistered = errors.New("echo renderer not registered")
)

// New creates an Echo instance.
Expand All @@ -134,19 +134,14 @@ func New() (e *Echo) {
//----------

e.MaxParam(5)
e.NotFoundHandler(func(c *Context) *HTTPError {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
})
e.HTTPErrorHandler(func(he *HTTPError, c *Context) {
if he.Code == 0 {
he.Code = http.StatusInternalServerError
}
if he.Message == "" {
if he.Error != nil {
he.Message = http.StatusText(he.Code)
if e.debug {
he.Message = he.Error.Error()
} else {
he.Message = http.StatusText(he.Code)
}
}
http.Error(c.Response, he.Message, he.Code)
Expand Down Expand Up @@ -185,12 +180,6 @@ func (e *Echo) MaxParam(n uint8) {
e.maxParam = n
}

// NotFoundHandler registers a custom NotFound handler used by router in case it
// doesn't find any registered handler for HTTP method and path.
func (e *Echo) NotFoundHandler(h Handler) {
e.notFoundHandler = wrapHandler(h)
}

// HTTPErrorHandler registers an HTTP error handler.
func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) {
e.httpErrorHandler = h
Expand All @@ -207,6 +196,11 @@ func (e *Echo) Renderer(r Renderer) {
e.renderer = r
}

// Debug runs the application in debug mode.
func (e *Echo) Debug(on bool) {
e.debug = on
}

// Use adds handler to the middleware chain.
func (e *Echo) Use(m ...Middleware) {
for _, h := range m {
Expand Down Expand Up @@ -325,21 +319,20 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if echo != nil {
e = echo
}
if h == nil {
h = e.notFoundHandler
}
c.reset(w, r, e)
if h == nil {
c.Error(&HTTPError{Code: http.StatusNotFound})
} else {
// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
}

// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
}

// Execute chain
if he := h(c); he != nil {
e.httpErrorHandler(he, c)
// Execute chain
if he := h(c); he != nil {
e.httpErrorHandler(he, c)
}
}

e.pool.Put(c)
}

Expand Down Expand Up @@ -394,7 +387,7 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
case func(http.ResponseWriter, *http.Request):
return wrapHTTPHandlerFuncMW(m)
default:
panic("echo: unknown middleware")
panic("echo unknown middleware")
}
}

Expand Down Expand Up @@ -440,7 +433,7 @@ func wrapHandler(h Handler) HandlerFunc {
return nil
}
default:
panic("echo: unknown handler")
panic("echo unknown handler")
}
}

Expand Down
10 changes: 0 additions & 10 deletions echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,6 @@ func TestEchoNotFound(t *testing.T) {
if w.Code != http.StatusNotFound {
t.Errorf("status code should be 404, found %d", w.Code)
}

// Customized NotFound handler
e.NotFoundHandler(func(c *Context) *HTTPError {
return c.String(http.StatusNotFound, "not found")
})
w = httptest.NewRecorder()
e.ServeHTTP(w, r)
if w.Body.String() != "not found" {
t.Errorf("body should be `not found`")
}
}

func verifyUser(u2 *user, t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions examples/crud/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func main() {
e := echo.New()

// Middleware
e.Use(mw.Recover())
e.Use(mw.Logger())

// Routes
Expand Down
7 changes: 7 additions & 0 deletions examples/hello/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ func main() {
// Echo instance
e := echo.New()

//------------
// Middleware
//------------

// Recover
e.Use(mw.Recover())

// Logger
e.Use(mw.Logger())

// Routes
Expand Down
6 changes: 6 additions & 0 deletions examples/middleware/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ func main() {
// Echo instance
e := echo.New()

// Debug mode
e.Debug(true)

//------------
// Middleware
//------------

// Recover
e.Use(mw.Recover())

// Logger
e.Use(mw.Logger())

Expand Down
1 change: 1 addition & 0 deletions examples/web/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func main() {
e := echo.New()

// Middleware
e.Use(mw.Recover())
e.Use(mw.Logger())

//------------------------
Expand Down
2 changes: 1 addition & 1 deletion middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const (
Basic = "Basic"
)

// BasicAuth provides HTTP basic authentication.
// BasicAuth returns an HTTP basic authentication middleware.
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
return func(c *echo.Context) (he *echo.HTTPError) {
auth := c.Request.Header.Get(echo.Authorization)
Expand Down
29 changes: 15 additions & 14 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package middleware

import (
"encoding/base64"
"github.com/labstack/echo"
"net/http"
"net/http/httptest"
"testing"

"github.com/labstack/echo"
)

func TestBasicAuth(t *testing.T) {
req, _ := http.NewRequest(echo.POST, "/", nil)
res := &echo.Response{Writer: httptest.NewRecorder()}
res := &echo.Response{}
c := echo.NewContext(req, res, echo.New())
fn := func(u, p string) bool {
if u == "joe" && p == "secret" {
Expand All @@ -34,7 +34,7 @@ func TestBasicAuth(t *testing.T) {
auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.Authorization, auth)
if ba(c) != nil {
t.Error("expected `pass` with case insensitive header")
t.Error("expected `pass`, with case insensitive header.")
}

//---------------------
Expand All @@ -46,29 +46,30 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with incorrect password")
t.Error("expected `fail`, with incorrect password.")
}

// Empty Authorization header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail`, with empty Authorization header.")
}

// Invalid header
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with invalid auth header")
t.Error("expected `fail`, with invalid Authorization header.")
}

// Invalid scheme
auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with invalid scheme")
t.Error("expected `fail`, with invalid scheme.")
}

// Empty auth header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with empty auth header")
}
}
22 changes: 9 additions & 13 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,21 @@ func (g gzipWriter) Write(b []byte) (int, error) {
return g.Writer.Write(b)
}

// Gzip compresses HTTP response using gzip compression scheme.
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc {
scheme := "gzip"

return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError {
if !strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
return nil
if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
w := gzip.NewWriter(c.Response.Writer)
defer w.Close()
gw := gzipWriter{Writer: w, Response: c.Response}
c.Response.Header().Set(echo.ContentEncoding, scheme)
c.Response = &echo.Response{Writer: gw}
}

w := gzip.NewWriter(c.Response.Writer)
defer w.Close()
gw := gzipWriter{Writer: w, Response: c.Response}
c.Response.Header().Set(echo.ContentEncoding, scheme)
c.Response = &echo.Response{Writer: gw}
if he := h(c); he != nil {
c.Error(he)
}
return nil
return h(c)
}
}
}
32 changes: 21 additions & 11 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,52 @@
package middleware

import (
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"

"compress/gzip"
"github.com/labstack/echo"
"io/ioutil"
)

func TestGzip(t *testing.T) {
// Empty Accept-Encoding header
req, _ := http.NewRequest(echo.GET, "/", nil)
req.Header.Set(echo.AcceptEncoding, "gzip")
w := httptest.NewRecorder()
res := &echo.Response{Writer: w}
c := echo.NewContext(req, res, echo.New())
Gzip()(func(c *echo.Context) *echo.HTTPError {
h := func(c *echo.Context) *echo.HTTPError {
return c.String(http.StatusOK, "test")
})(c)
}
Gzip()(h)(c)
s := w.Body.String()
if s != "test" {
t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s)
}

if w.Header().Get(echo.ContentEncoding) != "gzip" {
t.Errorf("expected Content-Encoding header `gzip`, got %d.", w.Header().Get(echo.ContentEncoding))
// Content-Encoding header
req.Header.Set(echo.AcceptEncoding, "gzip")
w = httptest.NewRecorder()
c.Response = &echo.Response{Writer: w}
Gzip()(h)(c)
ce := w.Header().Get(echo.ContentEncoding)
if ce != "gzip" {
t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce)
}

// Body
r, err := gzip.NewReader(w.Body)
defer r.Close()
if err != nil {
t.Error(err)
}

b, err := ioutil.ReadAll(r)
if err != nil {
t.Error(err)
}
s := string(b)

s = string(b)
if s != "test" {
t.Errorf("expected `test`, got %s.", s)
t.Errorf("expected body `test`, got %s.", s)
}
}
Loading

0 comments on commit 73fa05f

Please sign in to comment.