From 7c7473969f99aa21143bd4a674a0a69de480a12c Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 29 Oct 2015 20:29:02 +0800 Subject: [PATCH] tidb: switch to use ast parser. --- ast/cloner.go | 7 +- ast/dml.go | 8 ++ ast/expressions.go | 5 +- ast/parser/parser.y | 101 +++++++++--------- ast/parser/scanner.l | 13 ++- optimizer/convert_expr.go | 4 +- optimizer/convert_stmt.go | 208 ++++++++++++++++++++++++++++---------- optimizer/optimizer.go | 80 +++++++++------ tidb.go | 23 +++-- tidb_test.go | 1 + 10 files changed, 297 insertions(+), 153 deletions(-) diff --git a/ast/cloner.go b/ast/cloner.go index f08ef0947671c..bf7b3c961d274 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -13,6 +13,8 @@ package ast +import "fmt" + // Cloner is a ast visitor that clones a node. type Cloner struct { } @@ -57,6 +59,9 @@ func cloneStruct(in Node) (out Node) { case *ColumnName: nv := *v out = &nv + case *ColumnNameExpr: + nv := *v + out = &nv case *DefaultExpr: nv := *v out = &nv @@ -162,7 +167,7 @@ func cloneStruct(in Node) (out Node) { default: // We currently only handle expression and select statement. // Will add more when we need to. - panic("unknown ast Node type") + panic("unknown ast Node type " + fmt.Sprintf("%T", v)) } return } diff --git a/ast/dml.go b/ast/dml.go index 0225af046255e..b19e60bfc659e 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -488,6 +488,7 @@ type UnionStmt struct { Distinct bool Selects []*SelectStmt + OrderBy *OrderByClause Limit *Limit } @@ -505,6 +506,13 @@ func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { } nod.Selects[i] = node.(*SelectStmt) } + if nod.OrderBy != nil { + node, ok := nod.OrderBy.Accept(v) + if !ok { + return nod, false + } + nod.OrderBy = node.(*OrderByClause) + } if nod.Limit != nil { node, ok := nod.Limit.Accept(v) if !ok { diff --git a/ast/expressions.go b/ast/expressions.go index 43fc1764989c1..e32b14752a821 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -14,6 +14,7 @@ package ast import ( + "fmt" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" @@ -58,6 +59,8 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Data = value // TODO: make it more precise. switch value.(type) { + case nil: + ve.Type = types.NewFieldType(mysql.TypeNull) case bool, int64: ve.Type = types.NewFieldType(mysql.TypeLonglong) case uint64: @@ -80,7 +83,7 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type.Charset = "binary" ve.Type.Collate = "binary" default: - panic("illegal literal value type") + panic("illegal literal value type:" + fmt.Sprintf("T", value)) } return ve } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 6f096b86873d2..643195111a6ad 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -773,13 +773,9 @@ ColumnNameListOpt: { $$ = []*ast.ColumnName{} } -| '(' ')' - { - $$ = []*ast.ColumnName{} - } -| '(' ColumnNameList ')' +| ColumnNameList { - $$ = $2.([]*ast.ColumnName) + $$ = $1.([]*ast.ColumnName) } CommitStmt: @@ -1144,12 +1140,13 @@ DeleteFromStmt: LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - Order: $8.([]*ast.ByItem), } if $7 != nil { x.Where = $7.(ast.ExprNode) } - + if $8 != nil { + x.Order = $8.([]*ast.ByItem) + } if $9 != nil { x.Limit = $9.(*ast.Limit) } @@ -1501,13 +1498,13 @@ Field: } | Expression FieldAsNameOpt { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: $2.(model.CIStr)} + $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: model.NewCIStr($2.(string))} } FieldAsNameOpt: /* EMPTY */ { - $$ = model.CIStr{} + $$ = "" } | FieldAsName { @@ -2213,7 +2210,7 @@ TrimDirection: FunctionCallAgg: "AVG" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt ExpressionList ')' { @@ -2550,14 +2547,32 @@ RollbackStmt: } SelectStmt: - "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt + "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt + { + st := &ast.SelectStmt { + Distinct: $2.(bool), + Fields: $3.(*ast.FieldList), + LockTp: $5.(ast.SelectLockType), + } + if $4 != nil { + st.Limit = $4.(*ast.Limit) + } + $$ = st + } +| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt { - $$ = &ast.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), Fields: $3.(*ast.FieldList), - From: nil, - LockTp: $6.(ast.SelectLockType), + LockTp: $7.(ast.SelectLockType), } + if $5 != nil { + st.Where = $5.(ast.ExprNode) + } + if $6 != nil { + st.Limit = $6.(*ast.Limit) + } + $$ = st } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional @@ -2594,8 +2609,7 @@ SelectStmt: } FromDual: - /* Empty */ -| "FROM" "DUAL" + "FROM" "DUAL" TableRefsClause: @@ -2836,13 +2850,7 @@ UnionStmt: { union := $1.(*ast.UnionStmt) if $2 != nil { - // push union order by into select statements. - orderBy := $2.(*ast.OrderByClause) - cloner := &ast.Cloner{} - for _, s := range union.Selects { - node, _ := orderBy.Accept(cloner) - s.OrderBy = node.(*ast.OrderByClause) - } + union.OrderBy = $2.(*ast.OrderByClause) } if $3 != nil { union.Limit = $3.(*ast.Limit) @@ -3052,12 +3060,15 @@ ShowStmt: } | "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowTables, DBName: $4.(string), Full: $2.(bool), - Where: $5.(ast.ExprNode), } + if $5 != nil { + stmt.Where = $5.(ast.ExprNode) + } + $$ = stmt } | "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt { @@ -3074,18 +3085,24 @@ ShowStmt: } | "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowVariables, GlobalScope: $2.(bool), - Where: $4.(ast.ExprNode), } + if $4 != nil { + stmt.Where = $4.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "COLLATION" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowCollation, - Where: $3.(ast.ExprNode), } + if $3 != nil { + stmt.Where = $3.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "CREATE" "TABLE" TableName { @@ -3174,6 +3191,7 @@ Statement: | PreparedStmt | RollbackStmt | SelectStmt +| UnionStmt | SetStmt | ShowStmt | TruncateTableStmt @@ -3807,13 +3825,11 @@ StringName: * See: https://dev.mysql.com/doc/refman/5.7/en/update.html ***********************************************************************************/ UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { - // Single-table syntax - join := &ast.Join{Left: &ast.TableSource{Source:$4.(ast.ResultSetNode)}, Right: nil} st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: join}, + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, List: $6.([]*ast.Assignment), } if $7 != nil { @@ -3830,23 +3846,6 @@ UpdateStmt: break } } -| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional - { - // Multiple-table syntax - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, - List: $6.([]*ast.Assignment), - MultipleTable: true, - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - $$ = st - if yylex.(*lexer).root { - break - } - } UseStmt: "USE" DBName diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index a23a68a72ac4a..cdbe765347ac0 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/stringutil" ) type lexer struct { @@ -46,6 +47,7 @@ type lexer struct { val []byte ungetBuf []byte root bool + prepare bool stmtStartPos int stringLit []byte @@ -86,6 +88,14 @@ func (l *lexer) SetInj(inj int) { l.inj = inj } +func (l *lexer) SetPrepare() { + l.prepare = true +} + +func (l *lexer) IsPrepare() bool { + return l.prepare +} + func (l *lexer) Root() bool { return l.root } @@ -1001,7 +1011,8 @@ func (l *lexer) str(lval *yySymType, pref string) int { s = strings.TrimSuffix(s, "'") + "\"" pref = "\"" } - v, err := strconv.Unquote(pref + s) + v := stringutil.RemoveUselessBackslash(pref+s) + v, err := strconv.Unquote(v) if err != nil { v = strings.TrimSuffix(s, pref) } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 41c78a815e04c..7a62c1ee96d4c 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -166,14 +166,14 @@ func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { oldSubquery := &subquery.SubQuery{} switch x := v.Query.(type) { case *ast.SelectStmt: - oldSelect, err := convertSelect(x) + oldSelect, err := convertSelect(c, x) if err != nil { c.err = err return } oldSubquery.Stmt = oldSelect case *ast.UnionStmt: - oldUnion, err := convertUnion(x) + oldUnion, err := convertUnion(c, x) if err != nil { c.err = err return diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 036bfebefdc1e..b58f2d240d8a8 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -16,8 +16,10 @@ package optimizer import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/stmt" @@ -37,17 +39,16 @@ func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expr return oldAssign, nil } -func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { oldInsert := &stmts.InsertIntoStmt{ Priority: v.Priority, Text: v.Text(), } - tableName := v.Table.TableRefs.Left.(*ast.TableName) + tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} for _, val := range v.Columns { oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) } - converter := newExpressionConverter() for _, row := range v.Lists { var oldRow []expression.Expression for _, val := range row { @@ -77,9 +78,9 @@ func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { var err error switch x := v.Select.(type) { case *ast.SelectStmt: - oldInsert.Sel, err = convertSelect(x) + oldInsert.Sel, err = convertSelect(converter, x) case *ast.UnionStmt: - oldInsert.Sel, err = convertUnion(x) + oldInsert.Sel, err = convertUnion(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -88,7 +89,7 @@ func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { return oldInsert, nil } -func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { +func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { oldDelete := &stmts.DeleteStmt{ BeforeFrom: v.BeforeFrom, Ignore: v.Ignore, @@ -97,7 +98,6 @@ func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { Quick: v.Quick, Text: v.Text(), } - converter := newExpressionConverter() oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) if err != nil { return nil, errors.Trace(err) @@ -131,14 +131,13 @@ func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { return oldDelete, nil } -func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { +func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { oldUpdate := &stmts.UpdateStmt{ Ignore: v.Ignore, MultipleTable: v.MultipleTable, LowPriority: v.LowPriority, Text: v.Text(), } - converter := newExpressionConverter() var err error oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) if err != nil { @@ -176,10 +175,7 @@ func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { return oldUpdate, nil } -func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { - converter := &expressionConverter{ - exprMap: map[ast.Node]expression.Expression{}, - } +func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldSelect := &stmts.SelectStmt{} oldSelect.Distinct = s.Distinct oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) @@ -187,9 +183,13 @@ func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldField := &field.Field{} oldField.AsName = val.AsName.O var err error - oldField.Expr, err = convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) + if val.Expr != nil { + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + } else if val.WildCard != nil { + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr("*")} } oldSelect.Fields[i] = oldField } @@ -245,7 +245,7 @@ func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { return oldSelect, nil } -func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { +func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { oldUnion := &stmts.UnionStmt{} oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) oldUnion.Distincts = make([]bool, len(u.Selects)-1) @@ -255,12 +255,19 @@ func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { } } for i, val := range u.Selects { - oldSelect, err := convertSelect(val) + oldSelect, err := convertSelect(converter, val) if err != nil { return nil, errors.Trace(err) } oldUnion.Selects[i] = oldSelect } + if u.OrderBy != nil { + oldOrderBy, err := convertOrderBy(converter, u.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.OrderBy = oldOrderBy + } if u.Limit != nil { if u.Limit.Offset > 0 { oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} @@ -298,7 +305,7 @@ func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRse oldJoin.Left = oldLeft } - switch r := join.Left.(type) { + switch r := join.Right.(type) { case *ast.Join: oldRight, err := convertJoin(converter, r) if err != nil { @@ -331,13 +338,13 @@ func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*r case *ast.TableName: oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} case *ast.SelectStmt: - oldSelect, err := convertSelect(src) + oldSelect, err := convertSelect(converter, src) if err != nil { return nil, errors.Trace(err) } oldTs.Source = oldSelect case *ast.UnionStmt: - oldUnion, err := convertUnion(src) + oldUnion, err := convertUnion(converter, src) if err != nil { return nil, errors.Trace(err) } @@ -384,7 +391,7 @@ func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) return oldOrderBy, nil } -func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { +func convertCreateDatabase(converter *expressionConverter, v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { oldCreateDatabase := &stmts.CreateDatabaseStmt{ IfNotExists: v.IfNotExists, Name: v.Name, @@ -404,7 +411,7 @@ func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt return oldCreateDatabase, nil } -func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { +func convertDropDatabase(converter *expressionConverter, v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { return &stmts.DropDatabaseStmt{ IfExists: v.IfExists, Name: v.Name, @@ -512,12 +519,12 @@ func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*cold return oldConstraint, nil } -func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { +func convertCreateTable(converter *expressionConverter, v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { oldCreateTable := &stmts.CreateTableStmt{ - Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + IfNotExists: v.IfNotExists, + Text: v.Text(), } - converter := newExpressionConverter() for _, val := range v.Cols { oldColDef, err := convertColumnDef(converter, val) if err != nil { @@ -551,7 +558,7 @@ func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) return oldCreateTable, nil } -func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { +func convertDropTable(converter *expressionConverter, v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { oldDropTable := &stmts.DropTableStmt{ IfExists: v.IfExists, Text: v.Text(), @@ -566,7 +573,7 @@ func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { return oldDropTable, nil } -func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { +func convertCreateIndex(converter *expressionConverter, v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { oldCreateIndex := &stmts.CreateIndexStmt{ IndexName: v.IndexName, Unique: v.Unique, @@ -587,7 +594,7 @@ func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) return oldCreateIndex, nil } -func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { +func convertDropIndex(converter *expressionConverter, v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { return &stmts.DropIndexStmt{ IfExists: v.IfExists, IndexName: v.IndexName, @@ -595,14 +602,110 @@ func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { }, nil } -func convertAlterTable(v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { +func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec) (*ddl.AlterSpecification, error) { + oldAlterSpec := &ddl.AlterSpecification{ + Name: v.Name, + } + switch v.Tp { + case ast.AlterTableAddConstraint: + oldAlterSpec.Action = ddl.AlterAddConstr + case ast.AlterTableAddColumn: + oldAlterSpec.Action = ddl.AlterAddColumn + case ast.AlterTableDropColumn: + oldAlterSpec.Action = ddl.AlterDropColumn + case ast.AlterTableDropForeignKey: + oldAlterSpec.Action = ddl.AlterDropForeignKey + case ast.AlterTableDropIndex: + oldAlterSpec.Action = ddl.AlterDropIndex + case ast.AlterTableDropPrimaryKey: + oldAlterSpec.Action = ddl.AlterDropPrimaryKey + case ast.AlterTableOption: + oldAlterSpec.Action = ddl.AlterTableOpt + } + if v.Column != nil { + oldColDef, err := convertColumnDef(converter, v.Column) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Column = oldColDef + } + if v.Position != nil { + oldAlterSpec.Position = &ddl.ColumnPosition{} + switch v.Position.Tp { + case ast.ColumnPositionNone: + oldAlterSpec.Position.Type = ddl.ColumnPositionNone + case ast.ColumnPositionFirst: + oldAlterSpec.Position.Type = ddl.ColumnPositionFirst + case ast.ColumnPositionAfter: + oldAlterSpec.Position.Type = ddl.ColumnPositionAfter + } + if v.ColumnName != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.ColumnName) + } + } + if v.Constraint != nil { + oldConstraint, err := convertConstraint(converter, v.Constraint) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Constraint = oldConstraint + } + for _, val := range v.Options { + oldOpt := &coldef.TableOpt{ + StrValue: val.StrValue, + UintValue: val.UintValue, + } + switch val.Tp { + case ast.TableOptionNone: + oldOpt.Tp = coldef.TblOptNone + case ast.TableOptionEngine: + oldOpt.Tp = coldef.TblOptEngine + case ast.TableOptionCharset: + oldOpt.Tp = coldef.TblOptCharset + case ast.TableOptionCollate: + oldOpt.Tp = coldef.TblOptCollate + case ast.TableOptionAutoIncrement: + oldOpt.Tp = coldef.TblOptAutoIncrement + case ast.TableOptionComment: + oldOpt.Tp = coldef.TblOptComment + case ast.TableOptionAvgRowLength: + oldOpt.Tp = coldef.TblOptAvgRowLength + case ast.TableOptionCheckSum: + oldOpt.Tp = coldef.TblOptCheckSum + case ast.TableOptionCompression: + oldOpt.Tp = coldef.TblOptCompression + case ast.TableOptionConnection: + oldOpt.Tp = coldef.TblOptConnection + case ast.TableOptionPassword: + oldOpt.Tp = coldef.TblOptPassword + case ast.TableOptionKeyBlockSize: + oldOpt.Tp = coldef.TblOptKeyBlockSize + case ast.TableOptionMaxRows: + oldOpt.Tp = coldef.TblOptMaxRows + case ast.TableOptionMinRows: + oldOpt.Tp = coldef.TblOptMinRows + } + oldAlterSpec.TableOpts = append(oldAlterSpec.TableOpts, oldOpt) + } + return oldAlterSpec, nil +} + +func convertAlterTable(converter *expressionConverter, v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { oldAlterTable := &stmts.AlterTableStmt{ - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertAlterTableSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterTable.Specs = append(oldAlterTable.Specs, oldSpec) } return oldAlterTable, nil } -func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { +func convertTruncateTable(converter *expressionConverter, v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { return &stmts.TruncateTableStmt{ TableIdent: table.Ident{ Schema: v.Table.Schema, @@ -612,20 +715,20 @@ func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, e }, nil } -func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { +func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { oldExplain := &stmts.ExplainStmt{ Text: v.Text(), } var err error switch x := v.Stmt.(type) { case *ast.SelectStmt: - oldExplain.S, err = convertSelect(x) + oldExplain.S, err = convertSelect(converter, x) case *ast.UpdateStmt: - oldExplain.S, err = convertUpdate(x) + oldExplain.S, err = convertUpdate(converter, x) case *ast.DeleteStmt: - oldExplain.S, err = convertDelete(x) + oldExplain.S, err = convertDelete(converter, x) case *ast.InsertStmt: - oldExplain.S, err = convertInsert(x) + oldExplain.S, err = convertInsert(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -633,7 +736,7 @@ func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { return oldExplain, nil } -func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { +func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { oldPrepare := &stmts.PreparedStmt{ InPrepare: true, Name: v.Name, @@ -651,7 +754,7 @@ func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { return oldPrepare, nil } -func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { +func convertDeallocate(converter *expressionConverter, v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { return &stmts.DeallocateStmt{ ID: v.ID, Name: v.Name, @@ -659,14 +762,13 @@ func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { }, nil } -func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { +func convertExecute(converter *expressionConverter, v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { oldExec := &stmts.ExecuteStmt{ ID: v.ID, Name: v.Name, Text: v.Text(), } oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) - converter := newExpressionConverter() for i, val := range v.UsingVars { oldVar, err := convertExpr(converter, val) if err != nil { @@ -677,7 +779,7 @@ func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { return oldExec, nil } -func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { +func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowStmt, error) { oldShow := &stmts.ShowStmt{ DBName: v.DBName, Flag: v.Flag, @@ -735,25 +837,25 @@ func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { return oldShow, nil } -func convertBegin(v *ast.BeginStmt) (*stmts.BeginStmt, error) { +func convertBegin(converter *expressionConverter, v *ast.BeginStmt) (*stmts.BeginStmt, error) { return &stmts.BeginStmt{ Text: v.Text(), }, nil } -func convertCommit(v *ast.CommitStmt) (*stmts.CommitStmt, error) { +func convertCommit(converter *expressionConverter, v *ast.CommitStmt) (*stmts.CommitStmt, error) { return &stmts.CommitStmt{ Text: v.Text(), }, nil } -func convertRollback(v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { +func convertRollback(converter *expressionConverter, v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { return &stmts.RollbackStmt{ Text: v.Text(), }, nil } -func convertUse(v *ast.UseStmt) (*stmts.UseStmt, error) { +func convertUse(converter *expressionConverter, v *ast.UseStmt) (*stmts.UseStmt, error) { return &stmts.UseStmt{ DBName: v.DBName, Text: v.Text(), @@ -775,12 +877,11 @@ func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAs }, nil } -func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { +func convertSet(converter *expressionConverter, v *ast.SetStmt) (*stmts.SetStmt, error) { oldSet := &stmts.SetStmt{ Text: v.Text(), Variables: make([]*stmts.VariableAssignment, len(v.Variables)), } - converter := newExpressionConverter() for i, val := range v.Variables { oldAssign, err := convertVariableAssignment(converter, val) if err != nil { @@ -791,7 +892,7 @@ func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { return oldSet, nil } -func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { +func convertSetCharset(converter *expressionConverter, v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { return &stmts.SetCharsetStmt{ Charset: v.Charset, Collate: v.Collate, @@ -799,7 +900,7 @@ func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { }, nil } -func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { +func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { return &stmts.SetPwdStmt{ User: v.User, Password: v.Password, @@ -807,14 +908,13 @@ func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { }, nil } -func convertDo(v *ast.DoStmt) (*stmts.DoStmt, error) { - exprConverter := newExpressionConverter() +func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { oldDo := &stmts.DoStmt{ Text: v.Text(), Exprs: make([]expression.Expression, len(v.Exprs)), } for i, val := range v.Exprs { - oldExpr, err := convertExpr(exprConverter, val) + oldExpr, err := convertExpr(converter, val) if err != nil { return nil, errors.Trace(err) } diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index cec754d164e3e..c4d3966cd8a6d 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -16,79 +16,95 @@ package optimizer import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/stmt" ) +type Compiler struct { + paramMarkers []*expression.ParamMarker +} + // Compile compiles a ast.Node into a executable statement. -func Compile(node ast.Node) (stmt.Statement, error) { +func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { validator := &validator{} if _, ok := node.Accept(validator); !ok { return nil, errors.Trace(validator.err) } - binder := &InfoBinder{} - if _, ok := node.Accept(validator); !ok { - return nil, errors.Trace(binder.Err) - } + // binder := &InfoBinder{} + // if _, ok := node.Accept(validator); !ok { + // return nil, errors.Trace(binder.Err) + // } tpComputer := &typeComputer{} if _, ok := node.Accept(tpComputer); !ok { return nil, errors.Trace(tpComputer.err) } - + c := newExpressionConverter() + defer func() { + for _, v := range c.exprMap { + if x, ok := v.(*expression.ParamMarker); ok { + com.paramMarkers = append(com.paramMarkers, x) + } + } + }() switch v := node.(type) { case *ast.InsertStmt: - return convertInsert(v) + return convertInsert(c, v) case *ast.DeleteStmt: - return convertDelete(v) + return convertDelete(c, v) case *ast.UpdateStmt: - return convertUpdate(v) + return convertUpdate(c, v) case *ast.SelectStmt: - return convertSelect(v) + return convertSelect(c, v) case *ast.UnionStmt: - return convertUnion(v) + return convertUnion(c, v) case *ast.CreateDatabaseStmt: - return convertCreateDatabase(v) + return convertCreateDatabase(c, v) case *ast.DropDatabaseStmt: - return convertDropDatabase(v) + return convertDropDatabase(c, v) case *ast.CreateTableStmt: - return convertCreateTable(v) + return convertCreateTable(c, v) case *ast.DropTableStmt: - return convertDropTable(v) + return convertDropTable(c, v) case *ast.CreateIndexStmt: - return convertCreateIndex(v) + return convertCreateIndex(c, v) case *ast.DropIndexStmt: - return convertDropIndex(v) + return convertDropIndex(c, v) case *ast.AlterTableStmt: - return convertAlterTable(v) + return convertAlterTable(c, v) case *ast.TruncateTableStmt: - return convertTruncateTable(v) + return convertTruncateTable(c, v) case *ast.ExplainStmt: - return convertExplain(v) + return convertExplain(c, v) case *ast.PrepareStmt: - return convertPrepare(v) + return convertPrepare(c, v) case *ast.DeallocateStmt: - return convertDeallocate(v) + return convertDeallocate(c, v) case *ast.ExecuteStmt: - return convertExecute(v) + return convertExecute(c, v) case *ast.ShowStmt: - return convertShow(v) + return convertShow(c, v) case *ast.BeginStmt: - return convertBegin(v) + return convertBegin(c, v) case *ast.CommitStmt: - return convertCommit(v) + return convertCommit(c, v) case *ast.RollbackStmt: - return convertRollback(v) + return convertRollback(c, v) case *ast.UseStmt: - return convertUse(v) + return convertUse(c, v) case *ast.SetStmt: - return convertSet(v) + return convertSet(c, v) case *ast.SetCharsetStmt: - return convertSetCharset(v) + return convertSetCharset(c, v) case *ast.SetPwdStmt: - return convertSetPwd(v) + return convertSetPwd(c, v) case *ast.DoStmt: - return convertDo(v) + return convertDo(c, v) } return nil, nil } + +func (com *Compiler) ParamMarkers() []*expression.ParamMarker { + return com.paramMarkers +} diff --git a/tidb.go b/tidb.go index 55581af7822fb..27fcf5167ab47 100644 --- a/tidb.go +++ b/tidb.go @@ -26,14 +26,13 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/optimizer" - "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" @@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) { rawStmt := l.Stmts() stmts := make([]stmt.Statement, len(rawStmt)) for i, v := range rawStmt { - if node, ok := v.(ast.Node); ok { - stm, err := optimizer.Compile(node) - if err != nil { - return nil, errors.Trace(err) - } - stmts[i] = stm - } else { - stmts[i] = v.(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, err := compiler.Compile(v) + if err != nil { + return nil, errors.Trace(err) } + stmts[i] = stm } return stmts, nil } @@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express return nil, nil, nil } sm := sms[0] - return sm.(stmt.Statement), l.ParamList, nil + compiler := &optimizer.Compiler{} + statement, err := compiler.Compile(sm) + if err != nil { + return nil, nil, errors.Trace(err) + } + return statement, compiler.ParamMarkers(), nil } func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { diff --git a/tidb_test.go b/tidb_test.go index 0cb22ed701200..8249ba21aaa26 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) { // conflict mustExecSQL(c, se1, "begin") rs, err := exec(c, se1, "select * from t where c1=11 for update") + c.Assert(err, IsNil) _, err = rs.Rows(-1, 0) mustExecSQL(c, se2, "begin")