Skip to content

Commit

Permalink
Add HTTPStatusMessageFunc to allow users to override status messages (a…
Browse files Browse the repository at this point in the history
…ppleboy#96)

Add HTTPStatusMessageFunc to allow users to override HTTP status messages triggered by the middleware
  • Loading branch information
jbfm authored and appleboy committed Dec 5, 2017
1 parent 4ef6a7a commit d16975b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 22 deletions.
83 changes: 67 additions & 16 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,53 @@ type GinJWTMiddleware struct {

// TimeFunc provides the current time. You can override it to use another time value. This is useful for testing or if your server uses a different time zone than your tokens.
TimeFunc func() time.Time

// HTTP Status messages for when something in the JWT middleware fails.
// Check error (e) to determine the appropriate error message.
HTTPStatusMessageFunc func(e error, c *gin.Context) string
}

var (
// ErrMissingRealm indicates Realm name is required
ErrMissingRealm = errors.New("realm is missing")

// ErrMissingSecretKey indicates Secret key is required
ErrMissingSecretKey = errors.New("secret key is required")

// ErrForbidden when HTTP status 403 is given
ErrForbidden = errors.New("you don't have permission to access this resource")

// ErrMissingAuthenticatorFunc indicates Authenticator is required
ErrMissingAuthenticatorFunc = errors.New("ginJWTMiddleware.Authenticator func is undefined")

// ErrMissingLoginValues indicates a user tried to authenticate without username or password
ErrMissingLoginValues = errors.New("missing Username or Password")

// ErrFailedAuthentication indicates authentication failed, could be faulty username or password
ErrFailedAuthentication = errors.New("incorrect Username or Password")

// ErrFailedTokenCreation indicates JWT Token failed to create, reason unknown
ErrFailedTokenCreation = errors.New("failed to create JWT Token")

// ErrExpiredToken indicates JWT token has expired. Can't refresh.
ErrExpiredToken = errors.New("token is expired")

// ErrEmptyAuthHeader can be thrown if authing with a HTTP header, the Auth header needs to be set
ErrEmptyAuthHeader = errors.New("auth header is empty")

// ErrInvalidAuthHeader indicates auth header is invalid, could for example have the wrong Realm name
ErrInvalidAuthHeader = errors.New("auth header is invalid")

// ErrEmptyQueryToken can be thrown if authing with URL Query, the query token variable is empty
ErrEmptyQueryToken = errors.New("query token is empty")

// ErrEmptyCookieToken can be thrown if authing with a cookie, the token cokie is empty
ErrEmptyCookieToken = errors.New("cookie token is empty")

// ErrInvalidSigningAlgorithm indicates signing algorithm is invalid, needs to be HS256, HS384, or HS512
ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm")
)

// Login form structure.
type Login struct {
Username string `form:"username" json:"username" binding:"required"`
Expand Down Expand Up @@ -127,12 +172,18 @@ func (mw *GinJWTMiddleware) MiddlewareInit() error {
}
}

if mw.HTTPStatusMessageFunc == nil {
mw.HTTPStatusMessageFunc = func(e error, c *gin.Context) string {
return e.Error()
}
}

if mw.Realm == "" {
return errors.New("realm is required")
return ErrMissingRealm
}

if mw.Key == nil {
return errors.New("secret key is required")
return ErrMissingSecretKey
}

return nil
Expand All @@ -142,7 +193,7 @@ func (mw *GinJWTMiddleware) MiddlewareInit() error {
func (mw *GinJWTMiddleware) MiddlewareFunc() gin.HandlerFunc {
if err := mw.MiddlewareInit(); err != nil {
return func(c *gin.Context) {
mw.unauthorized(c, http.StatusInternalServerError, err.Error())
mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(err, nil))
return
}
}
Expand All @@ -157,7 +208,7 @@ func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) {
token, err := mw.parseToken(c)

if err != nil {
mw.unauthorized(c, http.StatusUnauthorized, err.Error())
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c))
return
}

Expand All @@ -168,7 +219,7 @@ func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) {
c.Set("userID", id)

if !mw.Authorizator(id, c) {
mw.unauthorized(c, http.StatusForbidden, "You don't have permission to access.")
mw.unauthorized(c, http.StatusForbidden, mw.HTTPStatusMessageFunc(ErrForbidden, c))
return
}

Expand All @@ -186,19 +237,19 @@ func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) {
var loginVals Login

if c.ShouldBindWith(&loginVals, binding.JSON) != nil {
mw.unauthorized(c, http.StatusBadRequest, "Missing Username or Password")
mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingLoginValues, c))
return
}

if mw.Authenticator == nil {
mw.unauthorized(c, http.StatusInternalServerError, "Missing define authenticator func")
mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(ErrMissingAuthenticatorFunc, c))
return
}

userID, ok := mw.Authenticator(loginVals.Username, loginVals.Password, c)

if !ok {
mw.unauthorized(c, http.StatusUnauthorized, "Incorrect Username / Password")
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrFailedAuthentication, c))
return
}

Expand All @@ -224,7 +275,7 @@ func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) {
tokenString, err := token.SignedString(mw.Key)

if err != nil {
mw.unauthorized(c, http.StatusUnauthorized, "Create JWT Token faild")
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrFailedTokenCreation, c))
return
}

Expand All @@ -244,7 +295,7 @@ func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) {
origIat := int64(claims["orig_iat"].(float64))

if origIat < mw.TimeFunc().Add(-mw.MaxRefresh).Unix() {
mw.unauthorized(c, http.StatusUnauthorized, "Token is expired.")
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, c))
return
}

Expand All @@ -264,7 +315,7 @@ func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) {
tokenString, err := newToken.SignedString(mw.Key)

if err != nil {
mw.unauthorized(c, http.StatusUnauthorized, "Create JWT Token faild")
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrFailedTokenCreation, c))
return
}

Expand Down Expand Up @@ -316,12 +367,12 @@ func (mw *GinJWTMiddleware) jwtFromHeader(c *gin.Context, key string) (string, e
authHeader := c.Request.Header.Get(key)

if authHeader == "" {
return "", errors.New("auth header empty")
return "", ErrEmptyAuthHeader
}

parts := strings.SplitN(authHeader, " ", 2)
if !(len(parts) == 2 && parts[0] == mw.TokenHeadName) {
return "", errors.New("invalid auth header")
return "", ErrInvalidAuthHeader
}

return parts[1], nil
Expand All @@ -331,7 +382,7 @@ func (mw *GinJWTMiddleware) jwtFromQuery(c *gin.Context, key string) (string, er
token := c.Query(key)

if token == "" {
return "", errors.New("Query token empty")
return "", ErrEmptyQueryToken
}

return token, nil
Expand All @@ -341,7 +392,7 @@ func (mw *GinJWTMiddleware) jwtFromCookie(c *gin.Context, key string) (string, e
cookie, _ := c.Cookie(key)

if cookie == "" {
return "", errors.New("Cookie token empty")
return "", ErrEmptyCookieToken
}

return cookie, nil
Expand All @@ -367,7 +418,7 @@ func (mw *GinJWTMiddleware) parseToken(c *gin.Context) (*jwt.Token, error) {

return jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
if jwt.GetSigningMethod(mw.SigningAlgorithm) != token.Method {
return nil, errors.New("invalid signing algorithm")
return nil, ErrInvalidSigningAlgorithm
}

return mw.Key, nil
Expand Down
46 changes: 40 additions & 6 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt

import (
"errors"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -51,7 +52,7 @@ func TestMissingRealm(t *testing.T) {
err := authMiddleware.MiddlewareInit()

assert.Error(t, err)
assert.Equal(t, "realm is required", err.Error())
assert.Equal(t, ErrMissingRealm, err)
}

func TestMissingKey(t *testing.T) {
Expand All @@ -72,7 +73,7 @@ func TestMissingKey(t *testing.T) {
err := authMiddleware.MiddlewareInit()

assert.Error(t, err)
assert.Equal(t, "secret key is required", err.Error())
assert.Equal(t, ErrMissingSecretKey, err)
}

func TestMissingTimeOut(t *testing.T) {
Expand Down Expand Up @@ -149,7 +150,7 @@ func TestInternalServerError(t *testing.T) {

message, _ := jsonparser.GetString(data, "message")

assert.Equal(t, "realm is required", message)
assert.Equal(t, ErrMissingRealm.Error(), message)
assert.Equal(t, http.StatusInternalServerError, r.Code)
})
}
Expand All @@ -175,7 +176,7 @@ func TestMissingAuthenticatorForLoginHandler(t *testing.T) {
data := []byte(r.Body.String())
message, _ := jsonparser.GetString(data, "message")

assert.Equal(t, "Missing define authenticator func", message)
assert.Equal(t, ErrMissingAuthenticatorFunc.Error(), message)
assert.Equal(t, http.StatusInternalServerError, r.Code)
})
}
Expand Down Expand Up @@ -215,7 +216,7 @@ func TestLoginHandler(t *testing.T) {

message, _ := jsonparser.GetString(data, "message")

assert.Equal(t, "Missing Username or Password", message)
assert.Equal(t, ErrMissingLoginValues.Error(), message)
assert.Equal(t, http.StatusBadRequest, r.Code)
assert.Equal(t, "application/json; charset=utf-8", r.HeaderMap.Get("Content-Type"))
})
Expand All @@ -230,7 +231,7 @@ func TestLoginHandler(t *testing.T) {

message, _ := jsonparser.GetString(data, "message")

assert.Equal(t, "Incorrect Username / Password", message)
assert.Equal(t, ErrFailedAuthentication.Error(), message)
assert.Equal(t, http.StatusUnauthorized, r.Code)
})

Expand Down Expand Up @@ -755,3 +756,36 @@ func TestDefineTokenHeadName(t *testing.T) {
assert.Equal(t, http.StatusOK, r.Code)
})
}

func TestHTTPStatusMessageFunc(t *testing.T) {
var successError = errors.New("Successful test error")
var failedError = errors.New("Failed test error")
var successMessage = "Overwrite error message."

authMiddleware := &GinJWTMiddleware{
Key: key,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
Authenticator: func(userId string, password string, c *gin.Context) (string, bool) {
if userId == "admin" && password == "admin" {
return "", true
}

return "", false
},

HTTPStatusMessageFunc: func(e error, c *gin.Context) string {
if e == successError {
return successMessage
}

return e.Error()
},
}

successString := authMiddleware.HTTPStatusMessageFunc(successError, nil)
failedString := authMiddleware.HTTPStatusMessageFunc(failedError, nil)

assert.Equal(t, successMessage, successString)
assert.NotEqual(t, successMessage, failedString)
}

0 comments on commit d16975b

Please sign in to comment.