Skip to content

Commit

Permalink
tidb: switch to use ast parser.
Browse files Browse the repository at this point in the history
  • Loading branch information
coocood committed Oct 30, 2015
1 parent edcbbc8 commit 7c74739
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 153 deletions.
7 changes: 6 additions & 1 deletion ast/cloner.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

package ast

import "fmt"

// Cloner is a ast visitor that clones a node.
type Cloner struct {
}
Expand Down Expand Up @@ -57,6 +59,9 @@ func cloneStruct(in Node) (out Node) {
case *ColumnName:
nv := *v
out = &nv
case *ColumnNameExpr:
nv := *v
out = &nv
case *DefaultExpr:
nv := *v
out = &nv
Expand Down Expand Up @@ -162,7 +167,7 @@ func cloneStruct(in Node) (out Node) {
default:
// We currently only handle expression and select statement.
// Will add more when we need to.
panic("unknown ast Node type")
panic("unknown ast Node type " + fmt.Sprintf("%T", v))
}
return
}
8 changes: 8 additions & 0 deletions ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ type UnionStmt struct {

Distinct bool
Selects []*SelectStmt
OrderBy *OrderByClause
Limit *Limit
}

Expand All @@ -505,6 +506,13 @@ func (nod *UnionStmt) Accept(v Visitor) (Node, bool) {
}
nod.Selects[i] = node.(*SelectStmt)
}
if nod.OrderBy != nil {
node, ok := nod.OrderBy.Accept(v)
if !ok {
return nod, false
}
nod.OrderBy = node.(*OrderByClause)
}
if nod.Limit != nil {
node, ok := nod.Limit.Accept(v)
if !ok {
Expand Down
5 changes: 4 additions & 1 deletion ast/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package ast

import (
"fmt"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
Expand Down Expand Up @@ -58,6 +59,8 @@ func NewValueExpr(value interface{}) *ValueExpr {
ve.Data = value
// TODO: make it more precise.
switch value.(type) {
case nil:
ve.Type = types.NewFieldType(mysql.TypeNull)
case bool, int64:
ve.Type = types.NewFieldType(mysql.TypeLonglong)
case uint64:
Expand All @@ -80,7 +83,7 @@ func NewValueExpr(value interface{}) *ValueExpr {
ve.Type.Charset = "binary"
ve.Type.Collate = "binary"
default:
panic("illegal literal value type")
panic("illegal literal value type:" + fmt.Sprintf("T", value))
}
return ve
}
Expand Down
101 changes: 50 additions & 51 deletions ast/parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -773,13 +773,9 @@ ColumnNameListOpt:
{
$$ = []*ast.ColumnName{}
}
| '(' ')'
{
$$ = []*ast.ColumnName{}
}
| '(' ColumnNameList ')'
| ColumnNameList
{
$$ = $2.([]*ast.ColumnName)
$$ = $1.([]*ast.ColumnName)
}

CommitStmt:
Expand Down Expand Up @@ -1144,12 +1140,13 @@ DeleteFromStmt:
LowPriority: $2.(bool),
Quick: $3.(bool),
Ignore: $4.(bool),
Order: $8.([]*ast.ByItem),
}
if $7 != nil {
x.Where = $7.(ast.ExprNode)
}

if $8 != nil {
x.Order = $8.([]*ast.ByItem)
}
if $9 != nil {
x.Limit = $9.(*ast.Limit)
}
Expand Down Expand Up @@ -1501,13 +1498,13 @@ Field:
}
| Expression FieldAsNameOpt
{
$$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: $2.(model.CIStr)}
$$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: model.NewCIStr($2.(string))}
}

FieldAsNameOpt:
/* EMPTY */
{
$$ = model.CIStr{}
$$ = ""
}
| FieldAsName
{
Expand Down Expand Up @@ -2213,7 +2210,7 @@ TrimDirection:
FunctionCallAgg:
"AVG" '(' DistinctOpt ExpressionList ')'
{
$$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)}
$$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)}
}
| "COUNT" '(' DistinctOpt ExpressionList ')'
{
Expand Down Expand Up @@ -2550,14 +2547,32 @@ RollbackStmt:
}

SelectStmt:
"SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt
"SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt
{
st := &ast.SelectStmt {
Distinct: $2.(bool),
Fields: $3.(*ast.FieldList),
LockTp: $5.(ast.SelectLockType),
}
if $4 != nil {
st.Limit = $4.(*ast.Limit)
}
$$ = st
}
| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt
{
$$ = &ast.SelectStmt {
st := &ast.SelectStmt {
Distinct: $2.(bool),
Fields: $3.(*ast.FieldList),
From: nil,
LockTp: $6.(ast.SelectLockType),
LockTp: $7.(ast.SelectLockType),
}
if $5 != nil {
st.Where = $5.(ast.ExprNode)
}
if $6 != nil {
st.Limit = $6.(*ast.Limit)
}
$$ = st
}
| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM"
TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional
Expand Down Expand Up @@ -2594,8 +2609,7 @@ SelectStmt:
}

FromDual:
/* Empty */
| "FROM" "DUAL"
"FROM" "DUAL"


TableRefsClause:
Expand Down Expand Up @@ -2836,13 +2850,7 @@ UnionStmt:
{
union := $1.(*ast.UnionStmt)
if $2 != nil {
// push union order by into select statements.
orderBy := $2.(*ast.OrderByClause)
cloner := &ast.Cloner{}
for _, s := range union.Selects {
node, _ := orderBy.Accept(cloner)
s.OrderBy = node.(*ast.OrderByClause)
}
union.OrderBy = $2.(*ast.OrderByClause)
}
if $3 != nil {
union.Limit = $3.(*ast.Limit)
Expand Down Expand Up @@ -3052,12 +3060,15 @@ ShowStmt:
}
| "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt
{
$$ = &ast.ShowStmt{
stmt := &ast.ShowStmt{
Tp: ast.ShowTables,
DBName: $4.(string),
Full: $2.(bool),
Where: $5.(ast.ExprNode),
}
if $5 != nil {
stmt.Where = $5.(ast.ExprNode)
}
$$ = stmt
}
| "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt
{
Expand All @@ -3074,18 +3085,24 @@ ShowStmt:
}
| "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt
{
$$ = &ast.ShowStmt{
stmt := &ast.ShowStmt{
Tp: ast.ShowVariables,
GlobalScope: $2.(bool),
Where: $4.(ast.ExprNode),
}
if $4 != nil {
stmt.Where = $4.(ast.ExprNode)
}
$$ = stmt
}
| "SHOW" "COLLATION" ShowLikeOrWhereOpt
{
$$ = &ast.ShowStmt{
stmt := &ast.ShowStmt{
Tp: ast.ShowCollation,
Where: $3.(ast.ExprNode),
}
if $3 != nil {
stmt.Where = $3.(ast.ExprNode)
}
$$ = stmt
}
| "SHOW" "CREATE" "TABLE" TableName
{
Expand Down Expand Up @@ -3174,6 +3191,7 @@ Statement:
| PreparedStmt
| RollbackStmt
| SelectStmt
| UnionStmt
| SetStmt
| ShowStmt
| TruncateTableStmt
Expand Down Expand Up @@ -3807,13 +3825,11 @@ StringName:
* See: https://dev.mysql.com/doc/refman/5.7/en/update.html
***********************************************************************************/
UpdateStmt:
"UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause
"UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause
{
// Single-table syntax
join := &ast.Join{Left: &ast.TableSource{Source:$4.(ast.ResultSetNode)}, Right: nil}
st := &ast.UpdateStmt{
LowPriority: $2.(bool),
TableRefs: &ast.TableRefsClause{TableRefs: join},
TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)},
List: $6.([]*ast.Assignment),
}
if $7 != nil {
Expand All @@ -3830,23 +3846,6 @@ UpdateStmt:
break
}
}
| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional
{
// Multiple-table syntax
st := &ast.UpdateStmt{
LowPriority: $2.(bool),
TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)},
List: $6.([]*ast.Assignment),
MultipleTable: true,
}
if $7 != nil {
st.Where = $7.(ast.ExprNode)
}
$$ = st
if yylex.(*lexer).root {
break
}
}

UseStmt:
"USE" DBName
Expand Down
13 changes: 12 additions & 1 deletion ast/parser/scanner.l
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/stringutil"
)

type lexer struct {
Expand All @@ -46,6 +47,7 @@ type lexer struct {
val []byte
ungetBuf []byte
root bool
prepare bool
stmtStartPos int
stringLit []byte

Expand Down Expand Up @@ -86,6 +88,14 @@ func (l *lexer) SetInj(inj int) {
l.inj = inj
}

func (l *lexer) SetPrepare() {
l.prepare = true
}

func (l *lexer) IsPrepare() bool {
return l.prepare
}

func (l *lexer) Root() bool {
return l.root
}
Expand Down Expand Up @@ -1001,7 +1011,8 @@ func (l *lexer) str(lval *yySymType, pref string) int {
s = strings.TrimSuffix(s, "'") + "\""
pref = "\""
}
v, err := strconv.Unquote(pref + s)
v := stringutil.RemoveUselessBackslash(pref+s)
v, err := strconv.Unquote(v)
if err != nil {
v = strings.TrimSuffix(s, pref)
}
Expand Down
4 changes: 2 additions & 2 deletions optimizer/convert_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,14 @@ func (c *expressionConverter) subquery(v *ast.SubqueryExpr) {
oldSubquery := &subquery.SubQuery{}
switch x := v.Query.(type) {
case *ast.SelectStmt:
oldSelect, err := convertSelect(x)
oldSelect, err := convertSelect(c, x)
if err != nil {
c.err = err
return
}
oldSubquery.Stmt = oldSelect
case *ast.UnionStmt:
oldUnion, err := convertUnion(x)
oldUnion, err := convertUnion(c, x)
if err != nil {
c.err = err
return
Expand Down
Loading

0 comments on commit 7c74739

Please sign in to comment.