Skip to content

Commit

Permalink
expression: rewrite builtin function: NULLIF (pingcap#4170)
Browse files Browse the repository at this point in the history
* rewrite nullif() to if()
  • Loading branch information
breezewish authored and zimulala committed Aug 15, 2017
1 parent 3eb558e commit 6733b6f
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 71 deletions.
3 changes: 1 addition & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ type baseFunctionClass struct {
func (b *baseFunctionClass) verifyArgs(args []Expression) error {
l := len(args)
if l < b.minArgs || (b.maxArgs != -1 && l > b.maxArgs) {
return errIncorrectParameterCount.GenByArgs(b.funcName)
return ErrIncorrectParameterCount.GenByArgs(b.funcName)
}
return nil
}
Expand Down Expand Up @@ -845,7 +845,6 @@ var funcs = map[string]functionClass{
// control functions
ast.If: &ifFunctionClass{baseFunctionClass{ast.If, 3, 3}},
ast.Ifnull: &ifNullFunctionClass{baseFunctionClass{ast.Ifnull, 2, 2}},
ast.Nullif: &nullIfFunctionClass{baseFunctionClass{ast.Nullif, 2, 2}},

// miscellaneous functions
ast.Sleep: &sleepFunctionClass{baseFunctionClass{ast.Sleep, 1, 1}},
Expand Down
42 changes: 0 additions & 42 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ var (
_ functionClass = &caseWhenFunctionClass{}
_ functionClass = &ifFunctionClass{}
_ functionClass = &ifNullFunctionClass{}
_ functionClass = &nullIfFunctionClass{}
)

var (
_ builtinFunc = &builtinCaseWhenSig{}
_ builtinFunc = &builtinNullIfSig{}
_ builtinFunc = &builtinIfNullIntSig{}
_ builtinFunc = &builtinIfNullRealSig{}
_ builtinFunc = &builtinIfNullDecimalSig{}
Expand Down Expand Up @@ -479,43 +477,3 @@ func (b *builtinIfNullDurationSig) evalDuration(row []types.Datum) (types.Durati
arg1, isNull, err := b.args[1].EvalDuration(row, sc)
return arg1, isNull, errors.Trace(err)
}

type nullIfFunctionClass struct {
baseFunctionClass
}

func (c *nullIfFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinNullIfSig{newBaseBuiltinFunc(args, ctx)}
return sig.setSelf(sig), nil
}

type builtinNullIfSig struct {
baseBuiltinFunc
}

// eval evals a builtinNullIfSig.
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_nullif
func (b *builtinNullIfSig) eval(row []types.Datum) (types.Datum, error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
// nullif(expr1, expr2)
// returns null if expr1 = expr2 is true, otherwise returns expr1
v1 := args[0]
v2 := args[1]

if v1.IsNull() || v2.IsNull() {
return v1, nil
}

if n, err1 := v1.CompareDatum(b.ctx.GetSessionVars().StmtCtx, v2); err1 != nil || n == 0 {
d := types.Datum{}
return d, errors.Trace(err1)
}

return v1, nil
}
23 changes: 0 additions & 23 deletions expression/builtin_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,26 +138,3 @@ func (s *testEvaluatorSuite) TestIfNull(c *C) {
f, err = funcs[ast.Ifnull].getFunction([]Expression{Zero}, s.ctx)
c.Assert(err, NotNil)
}

func (s *testEvaluatorSuite) TestNullIf(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Arg1 interface{}
Arg2 interface{}
Ret interface{}
}{
{1, 1, nil},
{nil, 2, nil},
{1, nil, 1},
{1, 2, 1},
}

for _, t := range tbl {
fc := funcs[ast.Nullif]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(t.Arg1, t.Arg2)), s.ctx)
c.Assert(err, IsNil)
d, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(d, testutil.DatumEquals, types.NewDatum(t.Ret))
}
}
2 changes: 1 addition & 1 deletion expression/builtin_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func JSONMerge(args []types.Datum, sc *variable.StatementContext) (d types.Datum
// JSONObject creates a json from an ordered key-value slice. It retrieves 2 arguments at least.
func JSONObject(args []types.Datum, sc *variable.StatementContext) (d types.Datum, err error) {
if len(args)&1 == 1 {
err = errIncorrectParameterCount.GenByArgs(ast.JSONObject)
err = ErrIncorrectParameterCount.GenByArgs(ast.JSONObject)
return
}
var jsonMap = make(map[string]json.JSON, len(args)>>1)
Expand Down
1 change: 0 additions & 1 deletion expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ func (s *testEvaluatorSuite) TestControlFunc2Pb(c *C) {
ast.Case,
ast.If,
ast.Ifnull,
ast.Nullif,
}
for i, funcName := range funcNames {
args := []Expression{dg.genColumn(mysql.TypeLong, 1)}
Expand Down
2 changes: 1 addition & 1 deletion expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ import (
// Error instances.
var (
errInvalidOperation = terror.ClassExpression.New(codeInvalidOperation, "invalid operation")
errIncorrectParameterCount = terror.ClassExpression.New(codeIncorrectParameterCount, "Incorrect parameter count in the call to native function '%s'")
errFunctionNotExists = terror.ClassExpression.New(codeFunctionNotExists, "FUNCTION %s does not exist")
errZlibZData = terror.ClassTypes.New(codeZlibZData, "ZLIB: Input data corrupted")
errIncorrectArgs = terror.ClassExpression.New(codeIncorrectArgs, mysql.MySQLErrName[mysql.ErrWrongArguments])
ErrIncorrectParameterCount = terror.ClassExpression.New(codeIncorrectParameterCount, "Incorrect parameter count in the call to native function '%s'")
)

// Error codes.
Expand Down
22 changes: 22 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,28 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) {
"26 0 0 0 1 <nil> <nil> 0 0 0 0 0 <nil> <nil> 1 0 0 0 <nil>",
"27 0 0 0 1 <nil> <nil> 0 0 0 0 0 <nil> <nil> 0 0 1 0 <nil>",
"28 0 0 0 1 <nil> <nil> 0 0 0 0 0 <nil> <nil> 0 0 0 0 <nil>"))

// nullif
result = tk.MustQuery(`SELECT NULLIF(NULL, 1), NULLIF(1, NULL), NULLIF(1, 1), NULLIF(NULL, NULL);`)
result.Check(testkit.Rows("<nil> 1 <nil> <nil>"))

result = tk.MustQuery(`SELECT NULLIF(1, 1.0), NULLIF(1, "1.0");`)
result.Check(testkit.Rows("<nil> <nil>"))

result = tk.MustQuery(`SELECT NULLIF("abc", 1);`)
result.Check(testkit.Rows("abc"))

result = tk.MustQuery(`SELECT NULLIF(1+2, 1);`)
result.Check(testkit.Rows("3"))

result = tk.MustQuery(`SELECT NULLIF(1, 1+2);`)
result.Check(testkit.Rows("1"))

result = tk.MustQuery(`SELECT NULLIF(2+3, 1+2);`)
result.Check(testkit.Rows("5"))

result = tk.MustQuery(`SELECT HEX(NULLIF("abc", 1));`)
result.Check(testkit.Rows("616263"))
}

func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) {
Expand Down
36 changes: 36 additions & 0 deletions plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1062,13 +1062,49 @@ func (er *expressionRewriter) checkArgsOneColumn(args ...expression.Expression)
}
}

// rewriteFuncCall handles a FuncCallExpr and generates a customized function.
// It should return true if for the given FuncCallExpr a rewrite is performed so that original behavior is skipped.
// Otherwise it should return false to indicate (the caller) that original behavior needs to be performed.
func (er *expressionRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool {
switch v.FnName.L {
case ast.Nullif:
if len(v.Args) != 2 {
er.err = expression.ErrIncorrectParameterCount.GenByArgs(v.FnName.O)
return true
}
stackLen := len(er.ctxStack)
param1 := er.ctxStack[stackLen-2]
param2 := er.ctxStack[stackLen-1]
// param1 = param2
funcCompare, err := er.constructBinaryOpFunction(param1, param2, ast.EQ)
if err != nil {
er.err = err
return true
}
// if(param1 = param2, null, param1)
funcIf, err := expression.NewFunction(er.ctx, ast.If, &v.Type, funcCompare, expression.Null, param1)
if err != nil {
er.err = err
return true
}
er.ctxStack = er.ctxStack[:stackLen-len(v.Args)]
er.ctxStack = append(er.ctxStack, funcIf)
return true
default:
return false
}
}

func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
stackLen := len(er.ctxStack)
args := er.ctxStack[stackLen-len(v.Args):]
er.checkArgsOneColumn(args...)
if er.err != nil {
return
}
if er.rewriteFuncCall(v) {
return
}
var function expression.Expression
function, er.err = expression.NewFunction(er.ctx, v.FnName.L, &v.Type, args...)
er.ctxStack = er.ctxStack[:stackLen-len(v.Args)]
Expand Down
2 changes: 1 addition & 1 deletion plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func (s *testPlanSuite) TestPushDownExpression(c *C) {
// nullif
{
sql: "a = nullif(a, 1)",
cond: "eq(test.t.a, nullif(test.t.a, 1))",
cond: "eq(test.t.a, if(eq(test.t.a, 1), <nil>, test.t.a))",
},
// ifnull
// TODO: ifnull(null, a) will be wrapped with cast which can not be pushed down.
Expand Down
15 changes: 15 additions & 0 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,21 @@ func (s *testPlanSuite) createTestCase4CompareFuncs() []typeInferTestCase {
{"isnull(c_blob )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},

{"nullif(c_int , 123)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, types.UnspecifiedLength}, // TODO: tp should be TypeLonglong, decimal should be 0
{"nullif(c_bigint , 123)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 21, types.UnspecifiedLength}, // TODO: flen should be 20, decimal should be 0
{"nullif(c_float , 123)", mysql.TypeFloat, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, // TODO: tp should be TypeDouble
{"nullif(c_double , 123)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength},
{"nullif(c_decimal , 123)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 3},
{"nullif(c_datetime , 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, 2}, // TODO: tp should be TypeVarString, flen is incorrect, no binary flag
{"nullif(c_time , 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, 0}, // TODO: tp should be TypeVarString, no binary flag
{"nullif(c_timestamp, 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, // TODO: tp should be TypeVarString, flen is incorrect, decimal should be 0, no binary flag
{"nullif(c_char , 123)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"nullif(c_varchar , 123)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_text , 123)", mysql.TypeBlob, charset.CharsetUTF8, 0, types.UnspecifiedLength, types.UnspecifiedLength}, // TODO: tp should be TypeMediumBlob, flen should be 589815
{"nullif(c_binary , 123)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_varbinary, 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_blob , 123)", mysql.TypeBlob, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, // TODO: tp should be TypeVarString, flen should be 65535
}
}

Expand Down

0 comments on commit 6733b6f

Please sign in to comment.