Skip to content

Commit

Permalink
Trailing slash middleware with option to redirect
Browse files Browse the repository at this point in the history
Signed-off-by: Vishal Rana <[email protected]>
  • Loading branch information
vishr committed Apr 13, 2016
1 parent 6c27cff commit b9aa218
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 5 deletions.
2 changes: 1 addition & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ func TestContext(t *testing.T) {
rec = test.NewResponseRecorder()
c = NewContext(rq, rec, e)
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
assert.Equal(t, http.StatusMovedPermanently, rec.Status())
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))

// Error
rec = test.NewResponseRecorder()
Expand Down
3 changes: 3 additions & 0 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ type (
// URI returns the unmodified `Request-URI` sent by the client.
URI() string

// SetURI sets the URI of the request.
SetURI(string)

// URL returns `engine.URL`.
URL() URL

Expand Down
5 changes: 5 additions & 0 deletions engine/fasthttp/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ func (r *Request) URI() string {
return string(r.RequestURI())
}

// SetURI implements `engine.Request#SetURI` function.
func (r *Request) SetURI(uri string) {
r.Request.Header.SetRequestURI(uri)
}

// Body implements `engine.Request#Body` function.
func (r *Request) Body() io.Reader {
return bytes.NewBuffer(r.PostBody())
Expand Down
5 changes: 5 additions & 0 deletions engine/standard/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ func (r *Request) URI() string {
return r.RequestURI
}

// SetURI implements `engine.Request#SetURI` function.
func (r *Request) SetURI(uri string) {
r.RequestURI = uri
}

// Body implements `engine.Request#Body` function.
func (r *Request) Body() io.Reader {
return r.Request.Body
Expand Down
43 changes: 39 additions & 4 deletions middleware/slash.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,39 @@ import (
"github.com/labstack/echo"
)

type (
// TrailingSlashConfig defines the config for TrailingSlash middleware.
TrailingSlashConfig struct {
// RedirectCode is the status code used when redirecting the request.
// Optional but when provided the request is redirected using this code.
RedirectCode int
}
)

// AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`.
//
// Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(TrailingSlashConfig{})
}

// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware from config.
// See `AddTrailingSlash()`.
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
url := c.Request().URL()
rq := c.Request()
url := rq.URL()
path := url.Path()
if path != "/" && path[len(path)-1] != '/' {
url.SetPath(path + "/")
path += "/"
uri := path + "?" + url.QueryString()
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
}
rq.SetURI(uri)
url.SetPath(path)
}
return next(c)
}
Expand All @@ -26,13 +48,26 @@ func AddTrailingSlash() echo.MiddlewareFunc {
//
// Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
}

// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware from config.
// See `RemoveTrailingSlash()`.
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
url := c.Request().URL()
rq := c.Request()
url := rq.URL()
path := url.Path()
l := len(path) - 1
if path != "/" && path[l] == '/' {
url.SetPath(path[:l])
path = path[:l]
uri := path + "?" + url.QueryString()
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
}
rq.SetURI(uri)
url.SetPath(path)
}
return next(c)
}
Expand Down
23 changes: 23 additions & 0 deletions middleware/slash_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"net/http"
"testing"

"github.com/labstack/echo"
Expand All @@ -18,6 +19,17 @@ func TestAddTrailingSlash(t *testing.T) {
})
h(c)
assert.Equal(t, "/add-slash/", rq.URL().Path())

// With config
rq = test.NewRequest(echo.GET, "/add-slash?key=value", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
h = AddTrailingSlashWithConfig(TrailingSlashConfig{RedirectCode: http.StatusMovedPermanently})(func(c echo.Context) error {
return nil
})
h(c)
assert.Equal(t, http.StatusMovedPermanently, rc.Status())
assert.Equal(t, "/add-slash/?key=value", rc.Header().Get(echo.HeaderLocation))
}

func TestRemoveTrailingSlash(t *testing.T) {
Expand All @@ -30,4 +42,15 @@ func TestRemoveTrailingSlash(t *testing.T) {
})
h(c)
assert.Equal(t, "/remove-slash", rq.URL().Path())

// With config
rq = test.NewRequest(echo.GET, "/remove-slash/?key=value", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{RedirectCode: http.StatusMovedPermanently})(func(c echo.Context) error {
return nil
})
h(c)
assert.Equal(t, http.StatusMovedPermanently, rc.Status())
assert.Equal(t, "/remove-slash?key=value", rc.Header().Get(echo.HeaderLocation))
}
5 changes: 5 additions & 0 deletions test/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type (

func NewRequest(method, url string, body io.Reader) engine.Request {
r, _ := http.NewRequest(method, url, body)
r.RequestURI = url
return &Request{
request: r,
url: &URL{url: r.URL},
Expand Down Expand Up @@ -84,6 +85,10 @@ func (r *Request) URI() string {
return r.request.RequestURI
}

func (r *Request) SetURI(uri string) {
r.request.RequestURI = uri
}

func (r *Request) Body() io.Reader {
return r.request.Body
}
Expand Down

0 comments on commit b9aa218

Please sign in to comment.