Skip to content

Commit

Permalink
*: load privilege in a goroutine when server initialize (pingcap#2489)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored and hanfei1991 committed Jan 24, 2017
1 parent c6cc2f7 commit 1da75c5
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 1 deletion.
35 changes: 35 additions & 0 deletions domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ import (

"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/perfschema"
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
)
Expand All @@ -34,6 +36,7 @@ import (
type Domain struct {
store kv.Storage
infoHandle *infoschema.Handle
privHandle *privileges.Handle
ddl ddl.DDL
m sync.Mutex
SchemaValidator SchemaValidator
Expand Down Expand Up @@ -363,6 +366,38 @@ func NewDomain(store kv.Storage, lease time.Duration) (d *Domain, err error) {
return d, nil
}

// LoadPrivilegeLoop create a goroutine loads privilege tables in a loop, it
// should be called only once in BootstrapSession.
func (do *Domain) LoadPrivilegeLoop(ctx context.Context) error {
do.privHandle = &privileges.Handle{}
err := do.privHandle.Update(ctx)
if err != nil {
return errors.Trace(err)
}

go func(do *Domain) {
ticker := time.NewTicker(5 * time.Minute)
for {
select {
case <-ticker.C:
err := do.privHandle.Update(ctx)
if err != nil {
log.Error(errors.ErrorStack(err))
}
case <-do.exit:
return
}
}
}(do)

return nil
}

// Privilege returns the MySQLPrivilege.
func (do *Domain) Privilege() *privileges.MySQLPrivilege {
return do.privHandle.Get()
}

// Domain error codes.
const (
codeInfoSchemaExpired terror.ErrCode = 1
Expand Down
142 changes: 142 additions & 0 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package privileges

import (
"strings"
"sync/atomic"
"time"

"github.com/juju/errors"
Expand Down Expand Up @@ -261,6 +262,9 @@ func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row *ast.Row, fs []*ast.Resul

func decodeSetToPrivilege(s types.Set) (mysql.PrivilegeType, error) {
var ret mysql.PrivilegeType
if s.Name == "" {
return ret, nil
}
for _, str := range strings.Split(s.Name, ",") {
priv, ok := mysql.SetStr2Priv[str]
if !ok {
Expand All @@ -270,3 +274,141 @@ func decodeSetToPrivilege(s types.Set) (mysql.PrivilegeType, error) {
}
return ret, nil
}

func (record *userRecord) match(user, host string) bool {
return record.User == user && patternMatch(record.Host, host)
}

func (record *dbRecord) match(user, host, db string) bool {
return record.User == user && record.DB == db &&
patternMatch(record.Host, host)
}

func (record *tablesPrivRecord) match(user, host, db, table string) bool {
return record.User == user && record.DB == db &&
record.TableName == table && patternMatch(record.Host, host)
}

func (record *columnsPrivRecord) match(user, host, db, table, col string) bool {
return record.User == user && record.DB == db &&
record.TableName == table && record.ColumnName == col &&
patternMatch(record.Host, host)
}

// patternMatch matches "%" the same way as ".*" in regular expression, for example,
// "10.0.%" would match "10.0.1" "10.0.1.118" ...
// TODO: patternMatch's behaviour is actual LIKE expression, so we should reuse the code.
func patternMatch(pattern, str string) bool {
for i := 0; i < len(pattern); i++ {
p := pattern[i]
if p == '%' {
return true
}
if i >= len(str) || p != str[i] {
return false
}
}
return len(pattern) == len(str)
}

// ConnectionVerification verifies the connection have access to TiDB server.
func (p *MySQLPrivilege) ConnectionVerification(user, host string) bool {
for _, record := range p.User {
if record.match(user, host) {
return true
}
}
return false
}

func (p *MySQLPrivilege) matchUser(user, host string) *userRecord {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, host) {
return record
}
}
return nil
}

func (p *MySQLPrivilege) matchDB(user, host, db string) *dbRecord {
for i := 0; i < len(p.DB); i++ {
record := &p.DB[i]
if record.match(user, host, db) {
return record
}
}
return nil
}

func (p *MySQLPrivilege) matchTables(user, host, db, table string) *tablesPrivRecord {
for i := 0; i < len(p.TablesPriv); i++ {
record := &p.TablesPriv[i]
if record.match(user, host, db, table) {
return record
}
}
return nil
}

func (p *MySQLPrivilege) matchColumns(user, host, db, table, column string) *columnsPrivRecord {
for i := 0; i < len(p.ColumnsPriv); i++ {
record := &p.ColumnsPriv[i]
if record.match(user, host, db, table, column) {
return record
}
}
return nil
}

// RequestVerification checks whether the user have sufficient privileges to do the operation.
func (p *MySQLPrivilege) RequestVerification(user, host, db, table, column string, priv mysql.PrivilegeType) bool {
record1 := p.matchUser(user, host)
if record1 != nil && record1.Privileges&priv > 0 {
return true
}

record2 := p.matchDB(user, host, db)
if record2 != nil && record2.Privileges&priv > 0 {
return true
}

record3 := p.matchTables(user, host, db, table)
if record3 != nil {
if record3.TablePriv&priv > 0 {
return true
}
if column != "" && record3.ColumnPriv&priv > 0 {
return true
}
}

record4 := p.matchColumns(user, host, db, table, column)
if record4 != nil && record4.ColumnPriv&priv > 0 {
return true
}

return false
}

// Handle wraps MySQLPrivilege providing thread safe access.
type Handle struct {
priv atomic.Value
}

// Get the MySQLPrivilege for read.
func (h *Handle) Get() *MySQLPrivilege {
return h.priv.Load().(*MySQLPrivilege)
}

// Update the MySQLPrivilege.
func (h *Handle) Update(ctx context.Context) error {
var priv MySQLPrivilege
err := priv.LoadAll(ctx)
if err != nil {
return errors.Trace(err)
}

h.priv.Store(&priv)
return nil
}
15 changes: 14 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,13 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool {
// Get user password.
name := strs[0]
host := strs[1]

// TODO: Use the new privilege implementation.
domain := sessionctx.GetDomain(s)
checker := domain.Privilege()
succ := checker.ConnectionVerification(name, host)
log.Debug("RequestVerification result:", succ)

pwd, err := s.getPassword(name, host)
if err != nil {
if terror.ExecResultIsEmpty.Equal(err) {
Expand All @@ -685,6 +692,7 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool {
return false
}
s.sessionVars.User = user

return true
}

Expand Down Expand Up @@ -723,7 +731,12 @@ func BootstrapSession(store kv.Storage) error {
runInBootstrapSession(store, upgrade)
}

_, err := domap.Get(store)
se, err := createSession(store)
if err != nil {
return errors.Trace(err)
}
err = sessionctx.GetDomain(se).LoadPrivilegeLoop(se)

return errors.Trace(err)
}

Expand Down
1 change: 1 addition & 0 deletions tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (dm *domainMap) Get(store kv.Storage) (d *domain.Domain, err error) {
return nil, errors.Trace(err)
}
dm.domains[key] = d

return
}

Expand Down

0 comments on commit 1da75c5

Please sign in to comment.