Skip to content

Commit

Permalink
Added body limit middleware
Browse files Browse the repository at this point in the history
Signed-off-by: Vishal Rana <[email protected]>
  • Loading branch information
vishr committed May 1, 2016
1 parent 4fd9f14 commit 0edb17e
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 19 deletions.
13 changes: 7 additions & 6 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,13 @@ var (

// Errors
var (
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
ErrNotFound = NewHTTPError(http.StatusNotFound)
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
ErrNotFound = NewHTTPError(http.StatusNotFound)
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
)

// Error handlers
Expand Down
3 changes: 3 additions & 0 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ type (
// Body returns request's body.
Body() io.Reader

// Body sets request's body.
SetBody(io.Reader)

// FormValue returns the form field value for the provided name.
FormValue(string) string

Expand Down
6 changes: 6 additions & 0 deletions engine/fasthttp/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"mime/multipart"

"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/labstack/gommon/log"
"github.com/valyala/fasthttp"
Expand Down Expand Up @@ -97,6 +98,11 @@ func (r *Request) Body() io.Reader {
return bytes.NewBuffer(r.PostBody())
}

// SetBody implements `engine.Request#SetBody` function.
func (r *Request) SetBody(reader io.Reader) {
r.SetBodyStream(reader, r.header.Get(echo.HeaderContentType))
}

// FormValue implements `engine.Request#FormValue` function.
func (r *Request) FormValue(name string) string {
return string(r.RequestCtx.FormValue(name))
Expand Down
6 changes: 6 additions & 0 deletions engine/standard/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package standard

import (
"io"
"io/ioutil"
"mime/multipart"
"net/http"

Expand Down Expand Up @@ -109,6 +110,11 @@ func (r *Request) Body() io.Reader {
return r.Request.Body
}

// SetBody implements `engine.Request#SetBody` function.
func (r *Request) SetBody(reader io.Reader) {
r.Request.Body = ioutil.NopCloser(reader)
}

// FormValue implements `engine.Request#FormValue` function.
func (r *Request) FormValue(name string) string {
return r.Request.FormValue(name)
Expand Down
10 changes: 5 additions & 5 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions middleware/basic_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,13 @@ const (
basic = "Basic"
)

var (
// DefaultBasicAuthConfig is the default basic auth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{}
)

// BasicAuth returns an HTTP basic auth middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
return BasicAuthWithConfig(BasicAuthConfig{fn})
}

// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
Expand Down
84 changes: 84 additions & 0 deletions middleware/body_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package middleware

import (
"fmt"
"io"
"net/http"
"sync"

"github.com/labstack/echo"
"github.com/labstack/gommon/bytes"
)

type (
// BodyLimitConfig defines the config for body limit middleware.
BodyLimitConfig struct {
Limit string `json:"limit"`
limit int
}

limitedReader struct {
BodyLimitConfig
reader io.Reader
read int
context echo.Context
}
)

// BodyLimit returns a body limit middleware.
//
// BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The body limit is determined based on the actually read and not `Content-Length`
// request header, which makes it super secure.
// Limit can be specifed as `4x` or `4xB`, where x is one of the multple from K, M,
// G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc {
return BodyLimitWithConfig(BodyLimitConfig{Limit: limit})
}

// BodyLimitWithConfig returns a body limit middleware from config.
// See `BodyLimit()`.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
limit, err := bytes.Parse(config.Limit)
if err != nil {
panic(fmt.Errorf("invalid body-limit=%s", config.Limit))
}
config.limit = limit
pool := limitedReaderPool(config)

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
r := pool.Get().(*limitedReader)
r.Reset(req.Body(), c)
defer pool.Put(r)
req.SetBody(r)
return next(c)
}
}
}

func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += n
if r.read > r.limit {
s := http.StatusRequestEntityTooLarge
r.context.String(s, http.StatusText(s))
return n, echo.ErrStatusRequestEntityTooLarge
}
return
}

func (r *limitedReader) Reset(reader io.Reader, context echo.Context) {
r.reader = reader
r.context = context
}

func limitedReaderPool(c BodyLimitConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: c}
},
}
}
33 changes: 33 additions & 0 deletions middleware/body_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package middleware

import (
"io/ioutil"
"net/http"
"testing"

"bytes"

"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
)

func TestBodyLimit(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
rec := test.NewResponseRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, _ := ioutil.ReadAll(c.Request().Body())
return c.String(http.StatusOK, string(body))
}

// Within limit
BodyLimit("2M")(h)(c)
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "Hello, World!", rec.Body.String())

// Overlimit
// BodyLimit("2B")(h)(c)
// assert.Equal(t, "Hello, World!", rec.Body.String())
}
5 changes: 5 additions & 0 deletions test/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package test

import (
"io"
"io/ioutil"
"mime/multipart"
"net/http"

Expand Down Expand Up @@ -93,6 +94,10 @@ func (r *Request) Body() io.Reader {
return r.request.Body
}

func (r *Request) SetBody(reader io.Reader) {
r.request.Body = ioutil.NopCloser(reader)
}

func (r *Request) FormValue(name string) string {
return r.request.FormValue(name)
}
Expand Down

0 comments on commit 0edb17e

Please sign in to comment.