Skip to content

Commit

Permalink
Refactored Echo.HandlerFunc, added WebSocket support.
Browse files Browse the repository at this point in the history
Signed-off-by: Vishal Rana <[email protected]>
  • Loading branch information
vishr committed May 20, 2015
1 parent 60a377a commit 13ac746
Show file tree
Hide file tree
Showing 16 changed files with 186 additions and 136 deletions.
38 changes: 17 additions & 21 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package echo
import (
"encoding/json"
"net/http"

"golang.org/x/net/websocket"
)

type (
Expand All @@ -11,6 +13,7 @@ type (
Context struct {
Request *http.Request
Response *Response
Socket *websocket.Conn
pnames []string
pvalues []string
store store
Expand Down Expand Up @@ -53,60 +56,53 @@ func (c *Context) Param(name string) (value string) {

// Bind binds the request body into specified type v. Default binder does it
// based on Content-Type header.
func (c *Context) Bind(i interface{}) *HTTPError {
func (c *Context) Bind(i interface{}) error {
return c.echo.binder(c.Request, i)
}

// Render invokes the registered HTML template renderer and sends a text/html
// response with status code.
func (c *Context) Render(code int, name string, data interface{}) *HTTPError {
func (c *Context) Render(code int, name string, data interface{}) error {
if c.echo.renderer == nil {
return &HTTPError{Error: RendererNotRegistered}
return RendererNotRegistered
}
c.Response.Header().Set(ContentType, TextHTML+"; charset=utf-8")
c.Response.WriteHeader(code)
return c.echo.renderer.Render(c.Response, name, data)
}

// JSON sends an application/json response with status code.
func (c *Context) JSON(code int, i interface{}) *HTTPError {
func (c *Context) JSON(code int, i interface{}) error {
c.Response.Header().Set(ContentType, ApplicationJSON+"; charset=utf-8")
c.Response.WriteHeader(code)
if err := json.NewEncoder(c.Response).Encode(i); err != nil {
return &HTTPError{Error: err}
}
return nil
return json.NewEncoder(c.Response).Encode(i)
}

// String sends a text/plain response with status code.
func (c *Context) String(code int, s string) *HTTPError {
func (c *Context) String(code int, s string) error {
c.Response.Header().Set(ContentType, TextPlain+"; charset=utf-8")
c.Response.WriteHeader(code)
if _, err := c.Response.Write([]byte(s)); err != nil {
return &HTTPError{Error: err}
}
return nil
_, err := c.Response.Write([]byte(s))
return err
}

// HTML sends a text/html response with status code.
func (c *Context) HTML(code int, html string) *HTTPError {
func (c *Context) HTML(code int, html string) error {
c.Response.Header().Set(ContentType, TextHTML+"; charset=utf-8")
c.Response.WriteHeader(code)
if _, err := c.Response.Write([]byte(html)); err != nil {
return &HTTPError{Error: err}
}
return nil
_, err := c.Response.Write([]byte(html))
return err
}

// NoContent sends a response with no body and a status code.
func (c *Context) NoContent(code int) *HTTPError {
func (c *Context) NoContent(code int) error {
c.Response.WriteHeader(code)
return nil
}

// Error invokes the registered HTTP error handler.
func (c *Context) Error(he *HTTPError) {
c.echo.httpErrorHandler(he, c)
func (c *Context) Error(err error) {
c.echo.httpErrorHandler(err, c)
}

// Get retrieves data from the context.
Expand Down
9 changes: 3 additions & 6 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@ type (
}
)

func (t *Template) Render(w io.Writer, name string, data interface{}) *HTTPError {
if err := t.templates.ExecuteTemplate(w, name, data); err != nil {
return &HTTPError{Error: err}
}
return nil
func (t *Template) Render(w io.Writer, name string, data interface{}) error {
return t.templates.ExecuteTemplate(w, name, data)
}

func TestContext(t *testing.T) {
b, _ := json.Marshal(u1)
r, _ := http.NewRequest(POST, "/users/1", bytes.NewReader(b))
c := NewContext(r, &Response{Writer: httptest.NewRecorder()}, New())
c := NewContext(r, NewResponse(httptest.NewRecorder()), New())

//------
// Bind
Expand Down
86 changes: 51 additions & 35 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"

"github.com/mattn/go-colorable"
"golang.org/x/net/websocket"
)

type (
Expand All @@ -33,24 +34,23 @@ type (
HTTPError struct {
Code int
Message string
Error error
}
Middleware interface{}
MiddlewareFunc func(HandlerFunc) HandlerFunc
Handler interface{}
HandlerFunc func(*Context) *HTTPError
HandlerFunc func(*Context) error

// HTTPErrorHandler is a centralized HTTP error handler.
HTTPErrorHandler func(*HTTPError, *Context)
HTTPErrorHandler func(error, *Context)

BindFunc func(*http.Request, interface{}) *HTTPError
BindFunc func(*http.Request, interface{}) error

// Renderer is the interface that wraps the Render method.
//
// Render renders the HTML template with given name and specified data.
// It writes the output to w.
Renderer interface {
Render(w io.Writer, name string, data interface{}) *HTTPError
Render(w io.Writer, name string, data interface{}) error
}
)

Expand Down Expand Up @@ -120,6 +120,10 @@ var (
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
)

func (e *HTTPError) Error() string {
return e.Message
}

// New creates an Echo instance.
func New() (e *Echo) {
e = &Echo{
Expand All @@ -135,33 +139,30 @@ func New() (e *Echo) {
//----------

e.SetMaxParam(5)
e.notFoundHandler = func(c *Context) *HTTPError {
e.notFoundHandler = func(c *Context) error {
return &HTTPError{Code: http.StatusNotFound}
}
e.SetHTTPErrorHandler(func(he *HTTPError, c *Context) {
if he.Code == 0 {
he.Code = http.StatusInternalServerError
}
if he.Message == "" {
he.Message = http.StatusText(he.Code)
e.SetHTTPErrorHandler(func(err error, c *Context) {
code := http.StatusInternalServerError
msg := http.StatusText(code)
if he, ok := err.(*HTTPError); ok {
code = he.Code
msg = he.Message
}
if e.debug && he.Error != nil {
he.Message = he.Error.Error()
if e.Debug() {
msg = err.Error()
}
http.Error(c.Response, he.Message, he.Code)
http.Error(c.Response, msg, code)
})
e.SetBinder(func(r *http.Request, v interface{}) *HTTPError {
e.SetBinder(func(r *http.Request, v interface{}) error {
ct := r.Header.Get(ContentType)
err := UnsupportedMediaType
if strings.HasPrefix(ct, ApplicationJSON) {
err = json.NewDecoder(r.Body).Decode(v)
} else if strings.HasPrefix(ct, ApplicationForm) {
err = nil
}
if err != nil {
return &HTTPError{Error: err}
}
return nil
return err
})
return
}
Expand Down Expand Up @@ -261,6 +262,21 @@ func (e *Echo) Trace(path string, h Handler) {
e.add(TRACE, path, h)
}

// WebSocket adds a WebSocket route > handler to the router.
func (e *Echo) WebSocket(path string, h HandlerFunc) {
e.Get(path, func(c *Context) *HTTPError {
wss := websocket.Server{
Handler: func(ws *websocket.Conn) {
c.Socket = ws
c.Response.status = http.StatusSwitchingProtocols
h(c)
},
}
wss.ServeHTTP(c.Response.writer, c.Request)
return nil
})
}

func (e *Echo) add(method, path string, h Handler) {
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
e.uris[key] = path
Expand All @@ -280,15 +296,15 @@ func (e *Echo) Favicon(file string) {
// Static serves static files.
func (e *Echo) Static(path, root string) {
fs := http.StripPrefix(path, http.FileServer(http.Dir(root)))
e.Get(path+"*", func(c *Context) *HTTPError {
e.Get(path+"*", func(c *Context) error {
fs.ServeHTTP(c.Response, c.Request)
return nil
})
}

// ServeFile serves a file.
func (e *Echo) ServeFile(path, file string) {
e.Get(path, func(c *Context) *HTTPError {
e.Get(path, func(c *Context) error {
http.ServeFile(c.Response, c.Request, file)
return nil
})
Expand Down Expand Up @@ -376,16 +392,16 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
return m
case HandlerFunc:
return wrapHandlerFuncMW(m)
case func(*Context) *HTTPError:
case func(*Context) error:
return wrapHandlerFuncMW(m)
case func(http.Handler) http.Handler:
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) (he *HTTPError) {
return func(c *Context) (err error) {
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Response.Writer = w
c.Response.writer = w
c.Request = r
he = h(c)
})).ServeHTTP(c.Response.Writer, c.Request)
err = h(c)
})).ServeHTTP(c.Response.writer, c.Request)
return
}
}
Expand All @@ -403,9 +419,9 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
// Wraps HandlerFunc middleware
func wrapHandlerFuncMW(m HandlerFunc) MiddlewareFunc {
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
if he := m(c); he != nil {
return he
return func(c *Context) error {
if err := m(c); err != nil {
return err
}
return h(c)
}
Expand All @@ -415,9 +431,9 @@ func wrapHandlerFuncMW(m HandlerFunc) MiddlewareFunc {
// Wraps http.HandlerFunc middleware
func wrapHTTPHandlerFuncMW(m http.HandlerFunc) MiddlewareFunc {
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
return func(c *Context) error {
if !c.Response.committed {
m.ServeHTTP(c.Response.Writer, c.Request)
m.ServeHTTP(c.Response.writer, c.Request)
}
return h(c)
}
Expand All @@ -429,15 +445,15 @@ func wrapHandler(h Handler) HandlerFunc {
switch h := h.(type) {
case HandlerFunc:
return h
case func(*Context) *HTTPError:
case func(*Context) error:
return h
case http.Handler, http.HandlerFunc:
return func(c *Context) *HTTPError {
return func(c *Context) error {
h.(http.Handler).ServeHTTP(c.Response, c.Request)
return nil
}
case func(http.ResponseWriter, *http.Request):
return func(c *Context) *HTTPError {
return func(c *Context) error {
h(c.Response, c.Request)
return nil
}
Expand Down
Loading

0 comments on commit 13ac746

Please sign in to comment.