Skip to content

Commit

Permalink
Middleware interface
Browse files Browse the repository at this point in the history
Signed-off-by: Vishal Rana <[email protected]>
  • Loading branch information
vishr committed Feb 8, 2016
1 parent f27de9a commit 65fcca2
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 54 deletions.
8 changes: 4 additions & 4 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 22 additions & 20 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,40 @@ type (
)

const (
Basic = "Basic"
basic = "Basic"
)

// BasicAuth returns an HTTP basic authentication middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Skip WebSocket
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
return nil
}
func BasicAuth(fn BasicValidateFunc) MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Skip WebSocket
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
return nil
}

auth := c.Request().Header().Get(echo.Authorization)
l := len(Basic)
auth := c.Request().Header().Get(echo.Authorization)
l := len(basic)

if len(auth) > l+1 && auth[:l] == Basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err == nil {
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if fn(cred[:i], cred[i+1:]) {
return nil
if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err == nil {
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if fn(cred[:i], cred[i+1:]) {
return nil
}
}
}
}
}
c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted")
return echo.NewHTTPError(http.StatusUnauthorized)
}
c.Response().Header().Set(echo.WWWAuthenticate, Basic+" realm=Restricted")
return echo.NewHTTPError(http.StatusUnauthorized)
}
}
25 changes: 14 additions & 11 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,41 @@ func TestBasicAuth(t *testing.T) {
}
return false
}
ba := BasicAuth(fn)
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
mw := BasicAuth(fn)(h)

// Valid credentials
auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.Authorization, auth)
assert.NoError(t, ba(c))
assert.NoError(t, mw(c))

//---------------------
// Invalid credentials
//---------------------

// Incorrect password
auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.Authorization, auth)
he := ba(c).(*echo.HTTPError)
he := mw(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))

// Empty Authorization header
req.Header().Set(echo.Authorization, "")
he = ba(c).(*echo.HTTPError)
he = mw(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.Authorization, auth)
he = ba(c).(*echo.HTTPError)
he = mw(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))

// WebSocket
c.Request().Header().Set(echo.Upgrade, echo.WebSocket)
assert.NoError(t, ba(c))
assert.NoError(t, mw(c))
}
6 changes: 3 additions & 3 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ var writerPool = sync.Pool{

// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc {
scheme := "gzip"

func Gzip() MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
scheme := "gzip"

return func(c echo.Context) error {
c.Response().Header().Add(echo.Vary, echo.AcceptEncoding)
if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) {
Expand Down
2 changes: 1 addition & 1 deletion middleware/logger.go → middleware/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/labstack/gommon/color"
)

func Logger() echo.MiddlewareFunc {
func Log() MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
Expand Down
26 changes: 13 additions & 13 deletions middleware/logger_test.go → middleware/log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,35 @@ import (
"github.com/stretchr/testify/assert"
)

func TestLogger(t *testing.T) {
func TestLog(t *testing.T) {
// Note: Just for the test coverage, not a real test.
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
c := echo.NewContext(req, rec, e)

// Status 2xx
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
Logger()(h)(c)
mw := Log()(h)

// Status 2xx
mw(c)

// Status 3xx
rec = test.NewResponseRecorder()
c = echo.NewContext(req, rec, e)
h = func(c echo.Context) error {
return c.String(http.StatusTemporaryRedirect, "test")
}
Logger()(h)(c)
mw(c)

// Status 4xx
rec = test.NewResponseRecorder()
c = echo.NewContext(req, rec, e)
h = func(c echo.Context) error {
return c.String(http.StatusNotFound, "test")
}
Logger()(h)(c)
mw(c)

// Status 5xx with empty path
req = test.NewRequest(echo.GET, "", nil)
Expand All @@ -48,10 +49,10 @@ func TestLogger(t *testing.T) {
h = func(c echo.Context) error {
return errors.New("error")
}
Logger()(h)(c)
mw(c)
}

func TestLoggerIPAddress(t *testing.T) {
func TestLogIPAddress(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
Expand All @@ -62,23 +63,22 @@ func TestLoggerIPAddress(t *testing.T) {
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}

mw := Logger()
mw := Log()(h)

// With X-Real-IP
req.Header().Add(echo.XRealIP, ip)
mw(h)(c)
mw(c)
assert.Contains(t, buf.String(), ip)

// With X-Forwarded-For
buf.Reset()
req.Header().Del(echo.XRealIP)
req.Header().Add(echo.XForwardedFor, ip)
mw(h)(c)
mw(c)
assert.Contains(t, buf.String(), ip)

// with req.RemoteAddr
buf.Reset()
mw(h)(c)
mw(c)
assert.Contains(t, buf.String(), ip)
}
15 changes: 15 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package middleware

import "github.com/labstack/echo"

type (
Middleware interface {
Process(echo.HandlerFunc) echo.HandlerFunc
}

MiddlewareFunc func(echo.HandlerFunc) echo.HandlerFunc
)

func (f MiddlewareFunc) Process(h echo.HandlerFunc) echo.HandlerFunc {
return f(h)
}
4 changes: 2 additions & 2 deletions middleware/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (

// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
func Recover() MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
return func(c echo.Context) error {
defer func() {
if err := recover(); err != nil {
Expand Down

0 comments on commit 65fcca2

Please sign in to comment.