Skip to content

Commit

Permalink
enforcer invalidate & cache dump API exposed (devtron-labs#2098)
Browse files Browse the repository at this point in the history
* enforcer invalidate & cache dump API exposed

* cache not enabled check
  • Loading branch information
kripanshdevtron authored Jul 28, 2022
1 parent d9d2f92 commit eef3a9c
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 6 deletions.
46 changes: 46 additions & 0 deletions api/user/UserRestHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package user
import (
"encoding/json"
"errors"
"fmt"
"github.com/devtron-labs/devtron/api/restHandler/common"
"github.com/devtron-labs/devtron/pkg/user/casbin"
"net/http"
Expand Down Expand Up @@ -52,6 +53,8 @@ type UserRestHandler interface {
CheckUserRoles(w http.ResponseWriter, r *http.Request)
SyncOrchestratorToCasbin(w http.ResponseWriter, r *http.Request)
UpdateTriggerPolicyForTerminalAccess(w http.ResponseWriter, r *http.Request)
GetRoleCacheDump(w http.ResponseWriter, r *http.Request)
InvalidateRoleCache(w http.ResponseWriter, r *http.Request)
}

type userNamePassword struct {
Expand Down Expand Up @@ -715,3 +718,46 @@ func (handler UserRestHandlerImpl) UpdateTriggerPolicyForTerminalAccess(w http.R
}
common.WriteJsonResp(w, nil, "Trigger policy updated successfully.", http.StatusOK)
}

func (handler UserRestHandlerImpl) GetRoleCacheDump(w http.ResponseWriter, r *http.Request) {

userId, err := handler.userService.GetLoggedInUser(r)
if userId == 0 || err != nil {
handler.logger.Errorw("unauthorized user, GetRoleCacheDump", "userId", userId)
common.WriteJsonResp(w, err, "Unauthorized User", http.StatusUnauthorized)
return
}
isSuperAdmin, err := handler.userService.IsSuperAdmin(int(userId))
if err != nil {
common.WriteJsonResp(w, err, "Failed to check is super admin", http.StatusInternalServerError)
return
}
if !isSuperAdmin {
common.WriteJsonResp(w, fmt.Errorf("unauthorized user"), "Unauthorized User", http.StatusForbidden)
} else {
cacheDump := handler.enforcer.GetCacheDump()
common.WriteJsonResp(w, nil, cacheDump, http.StatusOK)
}
}

func (handler UserRestHandlerImpl) InvalidateRoleCache(w http.ResponseWriter, r *http.Request) {

userId, err := handler.userService.GetLoggedInUser(r)
if userId == 0 || err != nil {
handler.logger.Errorw("unauthorized user, InvalidateRoleCache", "userId", userId)
common.WriteJsonResp(w, err, "Unauthorized User", http.StatusUnauthorized)
return
}
isSuperAdmin, err := handler.userService.IsSuperAdmin(int(userId))
if err != nil {
common.WriteJsonResp(w, err, "Failed to check is super admin", http.StatusInternalServerError)
return
}
if !isSuperAdmin {
common.WriteJsonResp(w, fmt.Errorf("unauthorized user"), "Unauthorized User", http.StatusForbidden)
} else {
handler.enforcer.InvalidateCompleteCache()
common.WriteJsonResp(w, nil, "Cache Cleaned Successfully", http.StatusOK)
}

}
4 changes: 4 additions & 0 deletions api/user/UserRouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,8 @@ func (router UserRouterImpl) InitUserRouter(userAuthRouter *mux.Router) {
HandlerFunc(router.userRestHandler.SyncOrchestratorToCasbin).Methods("GET")
userAuthRouter.Path("/update/trigger/terminal").
HandlerFunc(router.userRestHandler.UpdateTriggerPolicyForTerminalAccess).Methods("PUT")
userAuthRouter.Path("/role/cache").
HandlerFunc(router.userRestHandler.GetRoleCacheDump).Methods("GET")
userAuthRouter.Path("/role/cache/invalidate").
HandlerFunc(router.userRestHandler.InvalidateRoleCache).Methods("GET")
}
1 change: 1 addition & 0 deletions pkg/user/casbin/Adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func Create() *casbin.SyncedEnforcer {
}
e = auth
err = e.LoadPolicy()
log.Println("casbin Policies Loaded Successfully")
if err != nil {
log.Fatal(err)
}
Expand Down
30 changes: 24 additions & 6 deletions pkg/user/casbin/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package casbin

import (
"encoding/json"
"fmt"
"github.com/caarlos0/env"
"github.com/casbin/casbin"
Expand All @@ -43,6 +44,7 @@ type Enforcer interface {
InvalidateCache(emailId string) bool
InvalidateCompleteCache()
ReloadPolicy() error
GetCacheDump() string
}

func NewEnforcerImpl(
Expand Down Expand Up @@ -139,7 +141,10 @@ func (e *EnforcerImpl) enforceByEmailInBatchSync(wg *sync.WaitGroup, mutex *sync
start := time.Now()
batchResult := make(map[string]bool)
for _, resourceItem := range vals {
batchResult[resourceItem] = e.enforcerEnforce(strings.ToLower(emailId), resource, action, resourceItem)
data, err := e.enforcerEnforce(strings.ToLower(emailId), resource, action, resourceItem)
if err == nil {
batchResult[resourceItem] = data
}
}
duration := time.Since(start)
mutex.Lock()
Expand Down Expand Up @@ -333,6 +338,19 @@ func (e *EnforcerImpl) InvalidateCompleteCache() {
}
}

func (e *EnforcerImpl) GetCacheDump() string {
if e.Cache == nil {
return "not-enabled"
}
items := e.Cache.Items()
cacheData, err := json.Marshal(items)
if err != nil {
e.logger.Infow("error occurred while taking cache dump", "reason", err)
return ""
}
return string(cacheData)
}

// enforce is a helper to additionally check a default role and invoke a custom claims enforcement function
func (e *EnforcerImpl) enforce(token string, resource string, action string, resourceItem string) bool {
// check the default role
Expand All @@ -346,15 +364,15 @@ func (e *EnforcerImpl) enforce(token string, resource string, action string, res
func (e *EnforcerImpl) enforceAndUpdateCache(email string, resource string, action string, resourceItem string) bool {
cacheData := e.getEnforcerCacheLock(email)
cacheData.lock.Lock()
enforcedStatus := e.enforcerEnforce(email, resource, action, resourceItem)
defer cacheData.lock.Unlock()
enforcedStatus, err := e.enforcerEnforce(email, resource, action, resourceItem)
returnVal := atomic.AddInt64(&cacheData.enforceReqCounter, -1)
if cacheData.cacheCleaningFlag {
if cacheData.cacheCleaningFlag || err != nil {
if returnVal == 0 {
cacheData.cacheCleaningFlag = false
}
e.logger.Debugw("not updating enforcer status for cache", "email", email, "resource", resource,
"action", action, "resourceItem", resourceItem, "enforceReqCounter", cacheData.enforceReqCounter)
"action", action, "resourceItem", resourceItem, "enforceReqCounter", cacheData.enforceReqCounter, "err", err == nil)
return enforcedStatus
}
enforceData := e.getCacheData(email, resource, action)
Expand All @@ -363,15 +381,15 @@ func (e *EnforcerImpl) enforceAndUpdateCache(email string, resource string, acti
return enforcedStatus
}

func (e *EnforcerImpl) enforcerEnforce(email string, resource string, action string, resourceItem string) bool {
func (e *EnforcerImpl) enforcerEnforce(email string, resource string, action string, resourceItem string) (bool, error) {
//e.enforcerRWLock.RLock()
//defer e.enforcerRWLock.RUnlock()
response, err := e.SyncedEnforcer.EnforceSafe(email, resource, action, resourceItem)
if err != nil {
e.logger.Errorw("error occurred while enforcing safe", "email", email,
"resource", resource, "action", action, "resourceItem", resourceItem, "reason", err)
}
return response
return response, err
}

func (e *EnforcerImpl) verifyTokenAndGetEmail(tokenString string) (string, bool) {
Expand Down
24 changes: 24 additions & 0 deletions pkg/user/casbin/rbac_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package casbin

import (
"encoding/json"
"fmt"
"github.com/patrickmn/go-cache"
"math/rand"
Expand Down Expand Up @@ -32,6 +33,29 @@ func TestEnforcerCache(t *testing.T) {
invalidateCache_123(lock, cache123)
})

t.Run("CacheDump", func(t *testing.T) {
for i := 0; i < 100_000; i++ {
emailId := GetRandomStringOfGivenLength(rand.Intn(50)) + "@yopmail.com"
getAndSet(lock, emailId, cache123)
cache123.GetWithExpiration(emailId)
//result, expiration, b := cache123.GetWithExpiration(emailId)
//fmt.Println("result", result, "expiration", expiration, "found", b)
}
//invalidateCache_123(lock, cache123)

fmt.Println("dump: ", GetCacheDump(cache123))
})

}

func GetCacheDump(cache *cache.Cache) string {
items := cache.Items()
cacheData, err := json.Marshal(items)
if err != nil {
fmt.Println("error occurred while taking cache dump", "reason", err)
return ""
}
return string(cacheData)
}

func GetRandomStringOfGivenLength(length int) string {
Expand Down

0 comments on commit eef3a9c

Please sign in to comment.