Skip to content

Commit

Permalink
*: Add privilege checker and unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
shenli committed Oct 20, 2015
1 parent fd9bcc6 commit 138497a
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 42 deletions.
30 changes: 30 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ var Priv2UserCol = map[PrivilegeType]string{
IndexPriv: "Index_priv",
}

// Col2PrivType is the privilege tables column name to privilege type.
var Col2PrivType = map[string]PrivilegeType{
"Create_priv": CreatePriv,
"Select_priv": SelectPriv,
"Insert_priv": InsertPriv,
"Update_priv": UpdatePriv,
"Delete_priv": DeletePriv,
"Show_db_priv": ShowDBPriv,
"Create_user_priv": CreateUserPriv,
"Drop_priv": DropPriv,
"Grant_priv": GrantPriv,
"Alter_priv": AlterPriv,
"Execute_priv": ExecutePriv,
"Index_priv": IndexPriv,
}

// AllGlobalPrivs is all the privileges in global scope.
var AllGlobalPrivs = []PrivilegeType{SelectPriv, InsertPriv, UpdatePriv, DeletePriv, CreatePriv, DropPriv, GrantPriv, AlterPriv, ShowDBPriv, ExecutePriv, IndexPriv, CreateUserPriv}

Expand Down Expand Up @@ -212,6 +228,20 @@ var Priv2SetStr = map[PrivilegeType]string{
IndexPriv: "Index",
}

// SetStr2Priv is the map for privilege set string to privilege type.
var SetStr2Priv = map[string]PrivilegeType{
"Create": CreatePriv,
"Select": SelectPriv,
"Insert": InsertPriv,
"Update": UpdatePriv,
"Delete": DeletePriv,
"Drop": DropPriv,
"Grant": GrantPriv,
"Alter": AlterPriv,
"Execute": ExecutePriv,
"Index": IndexPriv,
}

// AllDBPrivs is all the privileges in database scope.
var AllDBPrivs = []PrivilegeType{SelectPriv, InsertPriv, UpdatePriv, DeletePriv, CreatePriv, DropPriv, GrantPriv, AlterPriv, ExecutePriv, IndexPriv}

Expand Down
18 changes: 10 additions & 8 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,22 @@ func (k keyType) String() string {
return "privilege-key"
}

type PrivilegeChecker interface {
SetUser(user string)
CheckPrivilege(ctx context.Context, db *model.DBInfo, privilege mysql.PrivilegeType) (bool, error)
CheckPrivilege(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error)
// Checker is the interface for check privileges.
type Checker interface {
CheckDBPrivilege(ctx context.Context, db *model.DBInfo, privilege mysql.PrivilegeType) (bool, error)
CheckTablePrivilege(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error)
}

const key keyType = 0

// BindPrivilegeChecker binds domain to context.
func BindPrivilegeChecker(ctx context.Context, pc PrivilegeChecker) {
ctx.SetValue(keyType, pc)
func BindPrivilegeChecker(ctx context.Context, pc Checker) {
ctx.SetValue(key, pc)
}

// GetPrivilegeChecker gets domain from context.
func GetPrivilegeChecker(ctx context.Context) Privilege {
v, ok := ctx.Value(keyType).(PrivilegeChecker)
func GetPrivilegeChecker(ctx context.Context) Checker {
v, ok := ctx.Value(key).(Checker)
if !ok {
return nil
}
Expand Down
204 changes: 170 additions & 34 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,99 +17,235 @@ import (
"fmt"
"strings"

"github.com/pingcap/juju/errors"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/sqlexec"
)

type PrivilegeInfo struct {
mysql.PrivilegeType
Wildcard bool
var _ privilege.Checker = (*PrivilegeCheck)(nil)

type userPrivilege struct {
GlobalPrivs map[mysql.PrivilegeType]bool
DBPrivs map[string]map[mysql.PrivilegeType]bool
TablePrivs map[string]map[string]map[mysql.PrivilegeType]bool
}

// PrivilegeCheck implements privilege.Checker interface.
// This is used to check privilege for the current user.
type PrivilegeCheck struct {
username string
host string
privs map[int]map[mysql.PrivilegeType]bool
privs *userPrivilege
}

func (p *PrivilegeCheck) SetUser(user string) {
strs := strings.Split(user, "@")
p.username, p.host = strs[0], strs[1]
}

func (p *PrivilegeCheck) CheckPrivilege(ctx context.Context, db *model.DBInfo, privilege mysql.PrivilegeType) (bool, error) {
if privs == nil {
// CheckDBPrivilege implements PrivilegeChecker.CheckDBPrivilege interface.
func (p *PrivilegeCheck) CheckDBPrivilege(ctx context.Context, db *model.DBInfo, privilege mysql.PrivilegeType) (bool, error) {
if p.privs == nil {
// Lazy load
err := loadPrivileges()
err := p.loadPrivileges(ctx)
if err != nil {
return false, errors.Trace(err)
}
}
// Check global scope privileges
// Check db scope privileges
// Check global scope privileges.
_, ok := p.privs.GlobalPrivs[privilege]
if ok {
return true, nil
}
// Check db scope privileges.
dbp, ok := p.privs.DBPrivs[db.Name.O]
if !ok {
return false, nil
}
_, ok = dbp[privilege]
return ok, nil
}
func (p *PrivilegeCheck) CheckPrivilege(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error) {
if privs == nil {

// CheckTablePrivilege implements PrivilegeChecker.CheckTablePrivilege interface.
func (p *PrivilegeCheck) CheckTablePrivilege(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error) {
if p.privs == nil {
// Lazy load
err := loadPrivileges()
err := p.loadPrivileges(ctx)
if err != nil {
return false, errors.Trace(err)
}
}
// Check global scope privileges
// Check db scope privileges
// Check table scope privileges
// Check global scope privileges.
_, ok := p.privs.GlobalPrivs[privilege]
if ok {
return true, nil
}
// Check db scope privileges.
dbp, ok := p.privs.DBPrivs[db.Name.O]
if ok {
_, ok = dbp[privilege]
if ok {
return true, nil
}
}
// Check table scope privileges.
tblp, ok := p.privs.TablePrivs[db.Name.O]
if !ok {
return false, nil
}
_, ok = tblp[tbl.Name.O]
return ok, nil
}

func (p *PrivilegeCheck) loadGlobalPrivileges() error {
// TODO: or Host="%"
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, p.username, p.host)
func (p *PrivilegeCheck) loadGlobalPrivileges(ctx context.Context) error {
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, mysql.SystemDB, mysql.UserTable, p.username, p.host)
rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
return errors.Trace(err)
}
defer rs.Close()
ps := make(map[mysql.PrivilegeType]bool)
fs, err := rs.Fields()
if err != nil {
return errors.Trace(err)
}
for {
row, err := rs.Next()
if err != nil {
return errors.Trace(err)
}
fs := rs.GetFields()
if row == nil {
break
}
for i := 3; i < len(fs); i++ {
d := row.Data[i]
ed, ok := d.(mysql.Enum)
if !ok {
return fmt.Errorf("Privilege should be mysql.Enum: %v(%T)", d, d)
}
if ed.String() != "Y" {
continue
}
f := fs[i]
// check each priv field
p, ok := mysql.Col2PrivType[f.Name]
if !ok {
panic("This should be never happened!")
}
ps[p] = true
}
}
p.privs[coldef.GrantLevelGlobal] = ps
p.privs.GlobalPrivs = ps
return nil
}

func (p *PrivilegeCheck) loadDBScopePrivileges() error {
func (p *PrivilegeCheck) loadDBScopePrivileges(ctx context.Context) error {
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, mysql.SystemDB, mysql.DBTable, p.username, p.host)
rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
return errors.Trace(err)
}
defer rs.Close()
ps := make(map[string]map[mysql.PrivilegeType]bool)
fs, err := rs.Fields()
if err != nil {
return errors.Trace(err)
}
for {
row, err := rs.Next()
if err != nil {
return errors.Trace(err)
}
if row == nil {
break
}
db, ok := row.Data[1].(string)
if !ok {
panic("This should be never happened!")
}
ps[db] = make(map[mysql.PrivilegeType]bool)
for i := 3; i < len(fs); i++ {
d := row.Data[i]
ed, ok := d.(mysql.Enum)
if !ok {
return fmt.Errorf("Privilege should be mysql.Enum: %v(%T)", d, d)
}
if ed.String() != "Y" {
continue
}
f := fs[i]
// check each priv field
p, ok := mysql.Col2PrivType[f.Name]
if !ok {
panic("This should be never happened!")
}
ps[db][p] = true
}
}
p.privs.DBPrivs = ps
return nil
}

func (p *PrivilegeCheck) loadTableScopePrivileges() error {
func (p *PrivilegeCheck) loadTableScopePrivileges(ctx context.Context) error {
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, mysql.SystemDB, mysql.TablePrivTable, p.username, p.host)
rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
return errors.Trace(err)
}
defer rs.Close()
ps := make(map[string]map[string]map[mysql.PrivilegeType]bool)
for {
row, err := rs.Next()
if err != nil {
return errors.Trace(err)
}
if row == nil {
break
}
db, ok := row.Data[1].(string)
if !ok {
panic("This should be never happened!")
}
tbl, ok := row.Data[3].(string)
if !ok {
panic("This should be never happened!")
}
_, ok = ps[db]
if !ok {
ps[db] = make(map[string]map[mysql.PrivilegeType]bool)
}
ps[db][tbl] = make(map[mysql.PrivilegeType]bool)
tblPrivs, ok := row.Data[6].(mysql.Set)
if !ok {
panic("This should be never happened!")
}
pvs := strings.Split(tblPrivs.Name, ",")
for _, d := range pvs {
p, ok := mysql.SetStr2Priv[d]
if !ok {
panic("This should be never happened!")
}
ps[db][tbl][p] = true
}
}
p.privs.TablePrivs = ps
return nil
}

func (p *PrivilegeCheck) loadPrivileges() error {
p.privs = make(map[int]map[mysql.PrivilegeType]bool)
func (p *PrivilegeCheck) loadPrivileges(ctx context.Context) error {
user := variable.GetSessionVars(ctx).User
strs := strings.Split(user, "@")
p.username, p.host = strs[0], strs[1]
p.privs = &userPrivilege{}
// Load privileges from mysql.User/DB/Table_privs/Column_privs table
err := p.loadGlobalPrivileges()
err := p.loadGlobalPrivileges(ctx)
if err != nil {
return errors.Trace(err)
}
err = p.loadDBScopePrivileges()
err = p.loadDBScopePrivileges(ctx)
if err != nil {
return errors.Trace(err)
}
err = p.loadTableScopePrivileges()
err = p.loadTableScopePrivileges(ctx)
if err != nil {
return errors.Trace(err)
}
Expand Down
Loading

0 comments on commit 138497a

Please sign in to comment.