Skip to content

Commit

Permalink
feat: improve TokenGenerator func (appleboy#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
appleboy authored May 31, 2018
1 parent 7518a3b commit 3ae2a3e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
11 changes: 7 additions & 4 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
"gopkg.in/dgrijalva/jwt-go.v3"
)

// MapClaims type that uses the map[string]interface{} for JSON decoding
// This is the default claims type if you don't supply one
type MapClaims map[string]interface{}

// GinJWTMiddleware provides a Json-Web-Token authentication implementation. On failure, a 401 HTTP response
// is returned. On success, the wrapped middleware is called, and the userID is made available as
// c.Get("userID").(string).
Expand Down Expand Up @@ -53,7 +57,7 @@ type GinJWTMiddleware struct {
// Note that the payload is not encrypted.
// The attributes mentioned on jwt.io can't be used as keys for the map.
// Optional, by default no additional data will be set.
PayloadFunc func(data interface{}) map[string]interface{}
PayloadFunc func(data interface{}) MapClaims

// User can define own Unauthorized func.
Unauthorized func(*gin.Context, int, string)
Expand Down Expand Up @@ -430,7 +434,6 @@ func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) {

// ExtractClaims help to extract the JWT claims
func ExtractClaims(c *gin.Context) jwt.MapClaims {

claims, exists := c.Get("JWT_PAYLOAD")
if !exists {
return make(jwt.MapClaims)
Expand All @@ -440,12 +443,12 @@ func ExtractClaims(c *gin.Context) jwt.MapClaims {
}

// TokenGenerator method that clients can use to get a jwt token.
func (mw *GinJWTMiddleware) TokenGenerator(userID string) (string, time.Time, error) {
func (mw *GinJWTMiddleware) TokenGenerator(userID string, data MapClaims) (string, time.Time, error) {
token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
claims := token.Claims.(jwt.MapClaims)

if mw.PayloadFunc != nil {
for key, value := range mw.PayloadFunc(userID) {
for key, value := range mw.PayloadFunc(data) {
claims[key] = value
}
}
Expand Down
18 changes: 9 additions & 9 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ func TestLoginHandler(t *testing.T) {
authMiddleware := &GinJWTMiddleware{
Realm: "test zone",
Key: key,
PayloadFunc: func(userId interface{}) map[string]interface{} {
PayloadFunc: func(userId interface{}) MapClaims {
// Set custom claim, to be checked in Authorizator method
return map[string]interface{}{"testkey": "testval", "exp": 0}
return MapClaims{"testkey": "testval", "exp": 0}
},
Authenticator: func(userId string, password string, c *gin.Context) (interface{}, bool) {
if userId == "admin" && password == "admin" {
Expand Down Expand Up @@ -643,7 +643,7 @@ func TestClaimsDuringAuthorization(t *testing.T) {
Key: key,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
PayloadFunc: func(data interface{}) map[string]interface{} {
PayloadFunc: func(data interface{}) MapClaims {
var testkey string
switch data {
case "Administrator":
Expand All @@ -654,7 +654,7 @@ func TestClaimsDuringAuthorization(t *testing.T) {
testkey = ""
}
// Set custom claim, to be checked in Authorizator method
return map[string]interface{}{"testkey": testkey, "exp": 0}
return MapClaims{"testkey": testkey, "exp": 0}
},
Authenticator: func(userId string, password string, c *gin.Context) (interface{}, bool) {
if userId == "admin" && password == "admin" {
Expand Down Expand Up @@ -689,7 +689,7 @@ func TestClaimsDuringAuthorization(t *testing.T) {
r := gofight.New()
handler := ginHandler(authMiddleware)

userToken, _, _ := authMiddleware.TokenGenerator("administrator")
userToken, _, _ := authMiddleware.TokenGenerator("administrator", MapClaims{})

r.GET("/auth/hello").
SetHeader(gofight.H{
Expand Down Expand Up @@ -740,7 +740,7 @@ func TestClaimsDuringAuthorization(t *testing.T) {

func TestEmptyClaims(t *testing.T) {

var jwtClaims map[string]interface{}
var jwtClaims jwt.MapClaims

// the middleware to test
authMiddleware := &GinJWTMiddleware{
Expand Down Expand Up @@ -832,7 +832,7 @@ func TestTokenExpire(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator("admin")
userToken, _, _ := authMiddleware.TokenGenerator("admin", MapClaims{})

r.GET("/auth/refresh_token").
SetHeader(gofight.H{
Expand Down Expand Up @@ -865,7 +865,7 @@ func TestTokenFromQueryString(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator("admin")
userToken, _, _ := authMiddleware.TokenGenerator("admin", MapClaims{})

r.GET("/auth/refresh_token").
SetHeader(gofight.H{
Expand Down Expand Up @@ -906,7 +906,7 @@ func TestTokenFromCookieString(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator("admin")
userToken, _, _ := authMiddleware.TokenGenerator("admin", MapClaims{})

r.GET("/auth/refresh_token").
SetHeader(gofight.H{
Expand Down

0 comments on commit 3ae2a3e

Please sign in to comment.