Skip to content

Commit

Permalink
feat: add UpdateFilteredPolicies method
Browse files Browse the repository at this point in the history
Signed-off-by: tangyang9464 <[email protected]>
  • Loading branch information
tangyang9464 committed Aug 29, 2021
1 parent 6fc8c43 commit 5678dab
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 1 deletion.
116 changes: 116 additions & 0 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,119 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules []

return session.Commit()
}

func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
// UpdateFilteredPolicies deletes old rules and adds new rules.
line := &CasbinRule{}

line.PType = ptype
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
line.V0 = fieldValues[0-fieldIndex]
}
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
line.V1 = fieldValues[1-fieldIndex]
}
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
line.V2 = fieldValues[2-fieldIndex]
}
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
line.V3 = fieldValues[3-fieldIndex]
}
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
line.V4 = fieldValues[4-fieldIndex]
}
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}

newP := make([]CasbinRule, 0, len(newPolicies))
oldP := make([]CasbinRule, 0)
for _, newRule := range newPolicies {
newP = append(newP, *a.genPolicyLine(ptype, newRule))
}
tx := a.engine.NewSession()
defer tx.Close()

if err := tx.Begin(); err != nil {
return nil, err
}

for i := range newP {
str, args := line.queryString()
if err := tx.Where(str, args...).Find(&oldP); err != nil {
return nil, tx.Rollback()
}
if _, err := tx.Where(str.(string), args...).Delete(CasbinRule{}); err != nil {
return nil, tx.Rollback()
}
if _, err := tx.Insert(&newP[i]); err != nil {
return nil, tx.Rollback()
}
}

// return deleted rulues
oldPolicies := make([][]string, 0)
for _, v := range oldP {
oldPolicy := v.toStringPolicy()
oldPolicies = append(oldPolicies, oldPolicy)
}
return oldPolicies, tx.Commit()
}

func (c *CasbinRule) toStringPolicy() []string {
policy := make([]string, 0)
if c.PType != "" {
policy = append(policy, c.PType)
}
if c.V0 != "" {
policy = append(policy, c.V0)
}
if c.V1 != "" {
policy = append(policy, c.V1)
}
if c.V2 != "" {
policy = append(policy, c.V2)
}
if c.V3 != "" {
policy = append(policy, c.V3)
}
if c.V4 != "" {
policy = append(policy, c.V4)
}
if c.V5 != "" {
policy = append(policy, c.V5)
}
return policy
}

func (c *CasbinRule) queryString() (interface{}, []interface{}) {
queryArgs := []interface{}{c.PType}

queryStr := "p_type = ?"
if c.V0 != "" {
queryStr += " and v0 = ?"
queryArgs = append(queryArgs, c.V0)
}
if c.V1 != "" {
queryStr += " and v1 = ?"
queryArgs = append(queryArgs, c.V1)
}
if c.V2 != "" {
queryStr += " and v2 = ?"
queryArgs = append(queryArgs, c.V2)
}
if c.V3 != "" {
queryStr += " and v3 = ?"
queryArgs = append(queryArgs, c.V3)
}
if c.V4 != "" {
queryStr += " and v4 = ?"
queryArgs = append(queryArgs, c.V4)
}
if c.V5 != "" {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, c.V5)
}

return queryStr, queryArgs
}
69 changes: 69 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package xormadapter

import (
"github.com/casbin/casbin/v2/util"
"log"
"strings"
"testing"
Expand Down Expand Up @@ -300,6 +301,71 @@ func testUpdatePolicies(t *testing.T, driverName string, dataSourceName string,
testGetPolicy(t, e, [][]string{{"bob", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}

func testUpdateFilteredPolicies(t *testing.T, driverName string, dataSourceName string, dbSpecified ...bool) {
// Initialize some policy in DB.
initPolicy(t, driverName, dataSourceName, dbSpecified...)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.

// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := NewAdapter(driverName, dataSourceName, dbSpecified...)
e, _ := casbin.NewEnforcer("examples/rbac_model.conf")

// Now set the adapter
e.SetAdapter(a)

e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
e.LoadPolicy()
testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}

func testGetPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
myRes := e.GetPolicy()
log.Print("Policy: ", myRes)

if !arrayEqualsWithoutOrder(myRes, res) {
t.Error("Policy: ", myRes, ", supposed to be ", res)
}
}

func arrayEqualsWithoutOrder(a [][]string, b [][]string) bool {
if len(a) != len(b) {
return false
}

mapA := make(map[int]string)
mapB := make(map[int]string)
order := make(map[int]struct{})
l := len(a)

for i := 0; i < l; i++ {
mapA[i] = util.ArrayToString(a[i])
mapB[i] = util.ArrayToString(b[i])
}

for i := 0; i < l; i++ {
for j := 0; j < l; j++ {
if _, ok := order[j]; ok {
if j == l-1 {
return false
} else {
continue
}
}
if mapA[i] == mapB[j] {
order[j] = struct{}{}
break
} else if j == l-1 {
return false
}
}
}
return true
}

func TestAdapters(t *testing.T) {
// You can also use the following way to use an existing DB "abc":
// testSaveLoad(t, "mysql", "root:@tcp(127.0.0.1:3306)/abc", true)
Expand All @@ -320,4 +386,7 @@ func TestAdapters(t *testing.T) {

testUpdatePolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testUpdatePolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")

testUpdateFilteredPolicies(t, "mysql", "root:@tcp(127.0.0.1:3306)/")
testUpdateFilteredPolicies(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/casbin/xorm-adapter/v2
go 1.12

require (
github.com/casbin/casbin/v2 v2.25.5
github.com/casbin/casbin/v2 v2.28.3
github.com/go-sql-driver/mysql v1.5.0
github.com/lib/pq v1.8.0
xorm.io/xorm v1.0.3
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBK
github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
github.com/casbin/casbin/v2 v2.25.5 h1:TPKaoGu1gqAVJtQ2MaTfdHn2zgnCaulLylbNXbY6TYo=
github.com/casbin/casbin/v2 v2.25.5/go.mod h1:wUgota0cQbTXE6Vd+KWpg41726jFRi7upxio0sR+Xd0=
github.com/casbin/casbin/v2 v2.28.3 h1:iHxxEsNHwSciRoYh+54etVUA8AXKS9OKzNy6/39UWvY=
github.com/casbin/casbin/v2 v2.28.3/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
Expand Down

0 comments on commit 5678dab

Please sign in to comment.