Skip to content

Commit

Permalink
feat(jwt): make KeyFunc public in JWT middleware (labstack#1756)
Browse files Browse the repository at this point in the history
* feat(jwt): make KeyFunc public in JWT middleware

It allows a user-defined function to supply the key for a token
verification.
  • Loading branch information
antonindrawan authored May 8, 2021
1 parent 6430665 commit 76f186a
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 23 deletions.
65 changes: 42 additions & 23 deletions middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,19 @@ type (
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext

// Signing key to validate token. Used as fallback if SigningKeys has length 0.
// Required. This or SigningKeys.
// Signing key to validate token.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKeys is provided.
SigningKey interface{}

// Map of signing keys to validate token with kid field usage.
// Required. This or SigningKey.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKey is provided.
SigningKeys map[string]interface{}

// Signing method, used to check token signing method.
// Signing method used to check the token's signing algorithm.
// Optional. Default value HS256.
SigningMethod string

Expand All @@ -64,7 +68,16 @@ type (
// Optional. Default value "Bearer".
AuthScheme string

keyFunc jwt.Keyfunc
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
}

// JWTSuccessHandler defines a function which is executed for a valid token.
Expand Down Expand Up @@ -99,6 +112,7 @@ var (
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
}
)

Expand All @@ -123,7 +137,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil && len(config.SigningKeys) == 0 {
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil {
panic("echo: jwt middleware requires signing key")
}
if config.SigningMethod == "" {
Expand All @@ -141,21 +155,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}

return config.SigningKey, nil
if config.KeyFunc == nil {
config.KeyFunc = config.defaultKeyFunc
}

// Initialize
Expand Down Expand Up @@ -196,11 +197,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
token := new(jwt.Token)
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.keyFunc)
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err == nil && token.Valid {
// Store user information from token into context.
Expand All @@ -225,6 +226,24 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
}

// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}

return config.SigningKey, nil
}

// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
Expand Down
30 changes: 30 additions & 0 deletions middleware/jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"errors"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -220,6 +221,35 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil
},
},
info: "Valid JWT with a valid key using a user-defined KeyFunc",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return invalidKey, nil
},
},
expErrCode: http.StatusUnauthorized,
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return nil, errors.New("faulty KeyFunc")
},
},
expErrCode: http.StatusUnauthorized,
info: "Token verification does not pass using a user-defined KeyFunc",
},
} {
if tc.reqURL == "" {
tc.reqURL = "/"
Expand Down

0 comments on commit 76f186a

Please sign in to comment.