Skip to content

Commit

Permalink
privileges: fix create temporary tables privilege (pingcap#29279)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Nov 1, 2021
1 parent db60f12 commit 8d9647d
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 7 deletions.
2 changes: 1 addition & 1 deletion parser/mysql/privs.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ const (
CreateRolePriv
// DropRolePriv is the privilege to drop a role.
DropRolePriv

// CreateTMPTablePriv is the privilege to create a temporary table.
CreateTMPTablePriv
LockTablesPriv
CreateRoutinePriv
Expand Down
3 changes: 2 additions & 1 deletion planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ func (e *Execute) handleExecuteBuilderOption(sctx sessionctx.Context,
func (e *Execute) checkPreparedPriv(ctx context.Context, sctx sessionctx.Context,
preparedObj *CachedPrepareStmt, is infoschema.InfoSchema) error {
if pm := privilege.GetPrivilegeManager(sctx); pm != nil {
if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, preparedObj.VisitInfos); err != nil {
visitInfo := VisitInfo4PrivCheck(is, preparedObj.PreparedAst.Stmt, preparedObj.VisitInfos)
if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, visitInfo); err != nil {
return err
}
}
Expand Down
71 changes: 71 additions & 0 deletions planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/lock"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/privilege"
Expand Down Expand Up @@ -120,6 +121,76 @@ func CheckPrivilege(activeRoles []*auth.RoleIdentity, pm privilege.Manager, vs [
return nil
}

// VisitInfo4PrivCheck generates privilege check infos because privilege check of local temporary tables is different
// with normal tables. `CREATE` statement needs `CREATE TEMPORARY TABLE` privilege from the database, and subsequent
// statements do not need any privileges.
func VisitInfo4PrivCheck(is infoschema.InfoSchema, node ast.Node, vs []visitInfo) (privVisitInfo []visitInfo) {
if node == nil {
return vs
}

switch stmt := node.(type) {
case *ast.CreateTableStmt:
privVisitInfo = make([]visitInfo, 0, len(vs))
for _, v := range vs {
if v.privilege == mysql.CreatePriv {
if stmt.TemporaryKeyword == ast.TemporaryLocal {
// `CREATE TEMPORARY TABLE` privilege is required from the database, not the table.
newVisitInfo := v
newVisitInfo.privilege = mysql.CreateTMPTablePriv
newVisitInfo.table = ""
privVisitInfo = append(privVisitInfo, newVisitInfo)
} else {
// If both the normal table and temporary table already exist, we need to check the privilege.
privVisitInfo = append(privVisitInfo, v)
}
} else {
// `CREATE TABLE LIKE tmp` or `CREATE TABLE FROM SELECT tmp` in the future.
if needCheckTmpTablePriv(is, v) {
privVisitInfo = append(privVisitInfo, v)
}
}
}
case *ast.DropTableStmt:
// Dropping a local temporary table doesn't need any privileges.
if stmt.IsView {
privVisitInfo = vs
} else {
privVisitInfo = make([]visitInfo, 0, len(vs))
if stmt.TemporaryKeyword != ast.TemporaryLocal {
for _, v := range vs {
if needCheckTmpTablePriv(is, v) {
privVisitInfo = append(privVisitInfo, v)
}
}
}
}
case *ast.GrantStmt, *ast.DropSequenceStmt, *ast.DropPlacementPolicyStmt:
// Some statements ignore local temporary tables, so they should check the privileges on normal tables.
privVisitInfo = vs
default:
privVisitInfo = make([]visitInfo, 0, len(vs))
for _, v := range vs {
if needCheckTmpTablePriv(is, v) {
privVisitInfo = append(privVisitInfo, v)
}
}
}
return
}

func needCheckTmpTablePriv(is infoschema.InfoSchema, v visitInfo) bool {
if v.db != "" && v.table != "" {
// Other statements on local temporary tables except `CREATE` do not check any privileges.
tb, err := is.TableByName(model.NewCIStr(v.db), model.NewCIStr(v.table))
// If the table doesn't exist, we do not report errors to avoid leaking the existence of the table.
if err == nil && tb.Meta().TempTableType == model.TempTableLocal {
return false
}
}
return true
}

// CheckTableLock checks the table lock.
func CheckTableLock(ctx sessionctx.Context, is infoschema.InfoSchema, vs []visitInfo) error {
if !config.TableLockEnabled() {
Expand Down
13 changes: 10 additions & 3 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3868,8 +3868,15 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err
}
}
if b.ctx.GetSessionVars().User != nil {
authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
// This is tricky here: we always need the visitInfo because it's not only used in privilege checks, and we
// must pass the table name. However, the privilege check is towards the database. We'll deal with it later.
if v.TemporaryKeyword == ast.TemporaryLocal {
authErr = ErrDBaccessDenied.GenWithStackByArgs(b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Schema.L)
} else {
authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
}
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreatePriv, v.Table.Schema.L,
v.Table.Name.L, "", authErr)
Expand Down Expand Up @@ -3936,7 +3943,7 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err
"", "", authErr)
case *ast.DropIndexStmt:
if b.ctx.GetSessionVars().User != nil {
authErr = ErrTableaccessDenied.GenWithStackByArgs("INDEx", b.ctx.GetSessionVars().User.AuthUsername,
authErr = ErrTableaccessDenied.GenWithStackByArgs("INDEX", b.ctx.GetSessionVars().User.AuthUsername,
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.IndexPriv, v.Table.Schema.L,
Expand Down
3 changes: 2 additions & 1 deletion planner/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
// we need the table information to check privilege, which is collected
// into the visitInfo in the logical plan builder.
if pm := privilege.GetPrivilegeManager(sctx); pm != nil {
if err := plannercore.CheckPrivilege(activeRoles, pm, builder.GetVisitInfo()); err != nil {
visitInfo := plannercore.VisitInfo4PrivCheck(is, node, builder.GetVisitInfo())
if err := plannercore.CheckPrivilege(activeRoles, pm, visitInfo); err != nil {
return nil, nil, 0, err
}
}
Expand Down
161 changes: 160 additions & 1 deletion privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"
"testing"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/auth"
Expand Down Expand Up @@ -2666,7 +2667,7 @@ func TestGrantCreateTmpTables(t *testing.T) {
tk.MustExec("CREATE TABLE create_tmp_table_table (a int)")
tk.MustExec("GRANT CREATE TEMPORARY TABLES on create_tmp_table_db.* to u1")
tk.MustExec("GRANT CREATE TEMPORARY TABLES on *.* to u1")
// Must set a session user to avoid null pointer dereferencing
// Must set a session user to avoid null pointer dereference
tk.Session().Auth(&auth.UserIdentity{
Username: "root",
Hostname: "localhost",
Expand All @@ -2678,6 +2679,164 @@ func TestGrantCreateTmpTables(t *testing.T) {
tk.MustExec("DROP DATABASE create_tmp_table_db")
}

func TestCreateTmpTablesPriv(t *testing.T) {
t.Parallel()
store, clean := newStore(t)
defer clean()

createStmt := "CREATE TEMPORARY TABLE test.tmp(id int)"
dropStmt := "DROP TEMPORARY TABLE IF EXISTS test.tmp"

tk := testkit.NewTestKit(t, store)
tk.MustExec(dropStmt)
tk.MustExec("CREATE TABLE test.t(id int primary key)")
tk.MustExec("CREATE SEQUENCE test.tmp")
tk.MustExec("CREATE USER vcreate, vcreate_tmp, vcreate_tmp_all")
tk.MustExec("GRANT CREATE, USAGE ON test.* TO vcreate")
tk.MustExec("GRANT CREATE TEMPORARY TABLES, USAGE ON test.* TO vcreate_tmp")
tk.MustExec("GRANT CREATE TEMPORARY TABLES, USAGE ON *.* TO vcreate_tmp_all")

tk.Session().Auth(&auth.UserIdentity{Username: "vcreate", Hostname: "localhost"}, nil, nil)
err := tk.ExecToErr(createStmt)
require.EqualError(t, err, "[planner:1044]Access denied for user 'vcreate'@'%' to database 'test'")
tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp", Hostname: "localhost"}, nil, nil)
tk.MustExec(createStmt)
tk.MustExec(dropStmt)
tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp_all", Hostname: "localhost"}, nil, nil)
// TODO: issue #29280 to be fixed.
//err = tk.ExecToErr(createStmt)
//require.EqualError(t, err, "[planner:1044]Access denied for user 'vcreate_tmp_all'@'%' to database 'test'")

tests := []struct {
sql string
errcode int
}{
{
sql: "create temporary table tmp(id int primary key)",
},
{
sql: "insert into tmp value(1)",
},
{
sql: "insert into tmp value(1) on duplicate key update id=1",
},
{
sql: "replace tmp values(1)",
},
{
sql: "insert into tmp select * from t",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "update tmp set id=1 where id=1",
},
{
sql: "update tmp t1, t t2 set t1.id=t2.id where t1.id=t2.id",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "delete from tmp where id=1",
},
{
sql: "delete t1 from tmp t1 join t t2 where t1.id=t2.id",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "select * from tmp where id=1",
},
{
sql: "select * from tmp where id in (1,2)",
},
{
sql: "select * from tmp",
},
{
sql: "select * from tmp join t where tmp.id=t.id",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "(select * from tmp) union (select * from t)",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "create temporary table tmp1 like t",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "create table tmp(id int primary key)",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "create table t(id int primary key)",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "analyze table tmp",
},
{
sql: "analyze table tmp, t",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "show create table tmp",
},
// TODO: issue #29281 to be fixed.
//{
// sql: "show create table t",
// errcode: mysql.ErrTableaccessDenied,
//},
{
sql: "drop sequence tmp",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "alter table tmp add column c1 char(10)",
errcode: errno.ErrUnsupportedDDLOperation,
},
{
sql: "truncate table tmp",
},
{
sql: "drop temporary table t",
errcode: mysql.ErrBadTable,
},
{
sql: "drop table t",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "drop table t, tmp",
errcode: mysql.ErrTableaccessDenied,
},
{
sql: "drop temporary table tmp",
},
}

tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp", Hostname: "localhost"}, nil, nil)
tk.MustExec("use test")
tk.MustExec(dropStmt)
for _, test := range tests {
if test.errcode == 0 {
tk.MustExec(test.sql)
} else {
tk.MustGetErrCode(test.sql, test.errcode)
}
}

// TODO: issue #29282 to be fixed.
//for i, test := range tests {
// preparedStmt := fmt.Sprintf("prepare stmt%d from '%s'", i, test.sql)
// executeStmt := fmt.Sprintf("execute stmt%d", i)
// tk.MustExec(preparedStmt)
// if test.errcode == 0 {
// tk.MustExec(executeStmt)
// } else {
// tk.MustGetErrCode(executeStmt, test.errcode)
// }
//}
}

func TestRevokeSecondSyntax(t *testing.T) {
t.Parallel()
store, clean := newStore(t)
Expand Down

0 comments on commit 8d9647d

Please sign in to comment.