Skip to content

Commit

Permalink
parser: remove YYParse function, use Parse and ParseOne instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
coocood committed Dec 23, 2015
1 parent 56ae5f3 commit bfd188d
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 71 deletions.
10 changes: 5 additions & 5 deletions ast/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ func (ts *testFlagSuite) TestFlag(c *C) {
},
}
for _, ca := range cases {
lexer := parser.NewLexer("select " + ca.expr)
parser.YYParse(lexer)
stmt := lexer.Stmts()[0].(*ast.SelectStmt)
ast.SetFlag(stmt)
expr := stmt.Fields.Fields[0].Expr
stmt, err := parser.ParseOne("select "+ca.expr, "", "")
c.Assert(err, IsNil)
selectStmt := stmt.(*ast.SelectStmt)
ast.SetFlag(selectStmt)
expr := selectStmt.Fields.Fields[0].Expr
c.Assert(expr.GetFlag(), Equals, ca.flag, Commentf("For %s", ca.expr))
}
}
5 changes: 2 additions & 3 deletions ddl/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,9 @@ func (ts *testSuite) TestAlterTableColumn(c *C) {

func statement(ctx context.Context, sql string) stmt.Statement {
log.Debug("[ddl] Compile", sql)
lexer := parser.NewLexer(sql)
parser.YYParse(lexer)
s, _ := parser.ParseOne(sql, "", "")
compiler := &executor.Compiler{}
stm, _ := compiler.Compile(ctx, lexer.Stmts()[0])
stm, _ := compiler.Compile(ctx, s)
return stm
}

Expand Down
12 changes: 6 additions & 6 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,17 @@ func (e *PrepareExec) DoPrepare() {
return
}
}
l := parser.NewLexer(e.SQLText)
l.SetCharsetInfo(variable.GetCharsetInfo(e.Ctx))
if parser.YYParse(l) != 0 {
e.Err = errors.Trace(l.Errors()[0])
charset, collation := variable.GetCharsetInfo(e.Ctx)
stmts, err := parser.Parse(e.SQLText, charset, collation)
if err != nil {
e.Err = errors.Trace(err)
return
}
if len(l.Stmts()) != 1 {
if len(stmts) != 1 {
e.Err = ErrPrepareMulti
return
}
stmt := l.Stmts()[0]
stmt := stmts[0]
var extractor paramMarkerExtractor
stmt.Accept(&extractor)

Expand Down
9 changes: 3 additions & 6 deletions optimizer/evaluator/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,9 @@ type testEvaluatorSuite struct {
}

func parseExpr(c *C, expr string) ast.ExprNode {
lexer := parser.NewLexer("select " + expr)
parser.YYParse(lexer)
if parser.YYParse(lexer) != 0 || len(lexer.Errors()) != 0 {
c.Fatal(lexer.Errors()[0], expr)
}
stmt := lexer.Stmts()[0].(*ast.SelectStmt)
s, err := parser.ParseOne("select "+expr, "", "")
c.Assert(err, IsNil)
stmt := s.(*ast.SelectStmt)
return stmt.Fields.Fields[0].Expr
}

Expand Down
22 changes: 9 additions & 13 deletions optimizer/plan/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,9 @@ func (s *testPlanSuite) TestRangeBuilder(c *C) {

for _, ca := range cases {
sql := "select 1 from dual where " + ca.exprStr
lexer := parser.NewLexer(sql)

rc := parser.YYParse(lexer)
c.Assert(rc, Equals, 0, Commentf("error %v, for expr %s", lexer.Errors(), ca.exprStr))
stmt := lexer.Stmts()[0].(*ast.SelectStmt)
stmts, err := parser.Parse(sql, "", "")
c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, ca.exprStr))
stmt := stmts[0].(*ast.SelectStmt)
result := rb.build(stmt.Where)
c.Assert(rb.err, IsNil)
got := fmt.Sprintf("%v", result)
Expand Down Expand Up @@ -233,10 +231,9 @@ func (s *testPlanSuite) TestBuilder(c *C) {
},
}
for _, ca := range cases {
lexer := parser.NewLexer(ca.sqlStr)
rc := parser.YYParse(lexer)
c.Assert(rc, Equals, 0, Commentf("error %v for expr %s", lexer.Errors(), ca.sqlStr))
stmt := lexer.Stmts()[0].(*ast.SelectStmt)
s, err := parser.ParseOne(ca.sqlStr, "", "")
c.Assert(err, IsNil, Commentf("for expr %s", ca.sqlStr))
stmt := s.(*ast.SelectStmt)
mockResolve(stmt)
p, err := BuildPlan(stmt)
c.Assert(err, IsNil)
Expand Down Expand Up @@ -293,10 +290,9 @@ func (s *testPlanSuite) TestBestPlan(c *C) {
},
}
for _, ca := range cases {
lexer := parser.NewLexer(ca.sql)
rc := parser.YYParse(lexer)
c.Assert(rc, Equals, 0, Commentf("error %v for sql %s", lexer.Errors(), ca.sql))
stmt := lexer.Stmts()[0].(*ast.SelectStmt)
s, err := parser.ParseOne(ca.sql, "", "")
c.Assert(err, IsNil, Commentf("for expr %s", ca.sql))
stmt := s.(*ast.SelectStmt)
ast.SetFlag(stmt)
mockResolve(stmt)
p, err := BuildPlan(stmt)
Expand Down
7 changes: 2 additions & 5 deletions optimizer/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,8 @@ func (ts *testNameResolverSuite) TestNameResolver(c *C) {
domain := sessionctx.GetDomain(ctx)
db.BindCurrentSchema(ctx, "test")
for _, tc := range resolverTestCases {
l := parser.NewLexer(tc.src)
c.Assert(parser.YYParse(l), Equals, 0)
stmts := l.Stmts()
c.Assert(len(stmts), Equals, 1)
node := stmts[0]
node, err := parser.ParseOne(tc.src, "", "")
c.Assert(err, IsNil)
resolveErr := optimizer.ResolveName(node, domain.InfoSchema(), ctx)
if tc.valid {
c.Assert(resolveErr, IsNil)
Expand Down
40 changes: 17 additions & 23 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,25 @@ func (s *testParserSuite) TestSimple(c *C) {
}
for _, kw := range unreservedKws {
src := fmt.Sprintf("SELECT %s FROM tbl;", kw)
l := NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
c.Assert(l.errs, HasLen, 0, Commentf("source %s", src))
_, err := ParseOne(src, "", "")
c.Assert(err, IsNil, Commentf("source %s", src))
}

// Testcase for prepared statement
src := "SELECT id+?, id+? from t;"
l := NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
c.Assert(len(l.Stmts()), Equals, 1)
_, err := ParseOne(src, "", "")
c.Assert(err, IsNil)

// Testcase for -- Comment and unary -- operator
src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;"
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
c.Assert(len(l.Stmts()), Equals, 2)
stmts, err := Parse(src, "", "")
c.Assert(err, IsNil)
c.Assert(stmts, HasLen, 2)

// Testcase for CONVERT(expr,type)
src = "SELECT CONVERT('111', SIGNED);"
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
st := l.Stmts()[0]
st, err := ParseOne(src, "", "")
c.Assert(err, IsNil)
ss, ok := st.(*ast.SelectStmt)
c.Assert(ok, IsTrue)
c.Assert(len(ss.Fields.Fields), Equals, 1)
Expand All @@ -83,9 +80,8 @@ func (s *testParserSuite) TestSimple(c *C) {
"SELECT CONVERT('111', SIGNED) /*comment*/;",
}
for _, src := range srcs {
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
st = l.Stmts()[0]
st, err = ParseOne(src, "", "")
c.Assert(err, IsNil)
ss, ok = st.(*ast.SelectStmt)
c.Assert(ok, IsTrue)
}
Expand All @@ -98,14 +94,12 @@ type testCase struct {

func (s *testParserSuite) RunTest(c *C, table []testCase) {
for _, t := range table {
l := NewLexer(t.src)
ok := yyParse(l) == 0
c.Assert(ok, Equals, t.ok, Commentf("source %v %v", t.src, l.errs))
switch ok {
case true:
c.Assert(l.errs, HasLen, 0, Commentf("src: %s", t.src))
case false:
c.Assert(len(l.errs), Not(Equals), 0, Commentf("src: %s", t.src))
_, err := Parse(t.src, "", "")
comment := Commentf("source %v", t.src)
if t.ok {
c.Assert(err, IsNil, comment)
} else {
c.Assert(err, NotNil, comment)
}
}
}
Expand Down
39 changes: 35 additions & 4 deletions parser/yy_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

package parser

import "github.com/pingcap/tidb/terror"
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
)

// Error instances.
var (
Expand All @@ -25,7 +30,33 @@ const (
CodeSyntaxErr terror.ErrCode = iota + 1
)

// YYParse is an wrapper of `yyParse` to make it exported.
func YYParse(yylex yyLexer) int {
return yyParse(yylex)
// Parse parses a query string to raw ast.StmtNode.
// If charset and collation is "", default charset and collation will be used.
func Parse(sql, charset, collation string) ([]ast.StmtNode, error) {
if charset == "" {
charset = mysql.DefaultCharset
}
if collation == "" {
collation = mysql.DefaultCollationName
}
l := NewLexer(sql)
l.SetCharsetInfo(charset, collation)
yyParse(l)
if len(l.Errors()) != 0 {
return nil, errors.Trace(l.Errors()[0])
}
return l.Stmts(), nil
}

// ParseOne parses a query and return the ast.StmtNode.
// The query must has one statement, otherwise ErrSyntax is returned.
func ParseOne(sql, charset, collation string) (ast.StmtNode, error) {
stmts, err := Parse(sql, charset, collation)
if err != nil {
return nil, errors.Trace(err)
}
if len(stmts) != 1 {
return nil, ErrSyntax
}
return stmts[0], nil
}
12 changes: 6 additions & 6 deletions tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ func getCtxCharsetInfo(ctx context.Context) (string, string) {

// Parse parses a query string to raw ast.StmtNode.
func Parse(ctx context.Context, src string) ([]ast.StmtNode, error) {
l := parser.NewLexer(src)
l.SetCharsetInfo(getCtxCharsetInfo(ctx))
if parser.YYParse(l) != 0 {
log.Warnf("compiling %s, error: %v", src, l.Errors()[0])
return nil, errors.Trace(l.Errors()[0])
charset, collation := getCtxCharsetInfo(ctx)
stmts, err := parser.Parse(src, charset, collation)
if err != nil {
log.Warnf("compiling %s, error: %v", src, err)
return nil, errors.Trace(err)
}
return l.Stmts(), nil
return stmts, nil
}

// Compile is safe for concurrent use by multiple goroutines.
Expand Down

0 comments on commit bfd188d

Please sign in to comment.