Skip to content

Commit

Permalink
expression: remove the usage of "TypeClass" in "builtin_op.go" (pingc…
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason authored and XuHuaiyu committed Sep 18, 2017
1 parent f00206f commit ae00bc2
Showing 1 changed file with 17 additions and 35 deletions.
52 changes: 17 additions & 35 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx context.Context, args []Exp
return nil, errors.Trace(err)
}

argTp := tpInt
switch args[0].GetTypeClass() {
case types.ClassReal:
argTp = tpReal
case types.ClassDecimal:
argTp = tpDecimal
argTp := fieldTp2EvalTp(args[0].GetType())
if argTp != tpReal && argTp != tpDecimal {
argTp = tpInt
}

bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, argTp)
bf.tp.Flen = 1

Expand Down Expand Up @@ -549,19 +547,16 @@ func (c *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow boo
// typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow,
// typerInfer will infers the return type as tpDecimal, not tpInt.
func (c *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) {
tp := tpInt
switch argExpr.GetTypeClass() {
case types.ClassString, types.ClassReal:
tp := fieldTp2EvalTp(argExpr.GetType())
if tp != tpInt && tp != tpDecimal {
tp = tpReal
case types.ClassDecimal:
tp = tpDecimal
}

sc := ctx.GetSessionVars().StmtCtx
overflow := false
// TODO: Handle float overflow.
if arg, ok := argExpr.(*Constant); sc.InSelectStmt && ok &&
arg.GetTypeClass() == types.ClassInt {
tp == tpInt {
overflow = c.handleIntOverflow(arg)
if overflow {
tp = tpDecimal
Expand All @@ -579,8 +574,8 @@ func (c *unaryMinusFunctionClass) getFunction(ctx context.Context, args []Expres
retTp, intOverflow := c.typeInfer(argExpr, ctx)

var bf baseBuiltinFunc
switch argExpr.GetTypeClass() {
case types.ClassInt:
switch fieldTp2EvalTp(argExprTp) {
case tpInt:
if intOverflow {
bf = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true}
Expand All @@ -591,16 +586,16 @@ func (c *unaryMinusFunctionClass) getFunction(ctx context.Context, args []Expres
sig.setPbCode(tipb.ScalarFuncSig_UnaryMinusInt)
}
bf.tp.Decimal = 0
case types.ClassDecimal:
case tpDecimal:
bf = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
bf.tp.Decimal = argExprTp.Decimal
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
sig.setPbCode(tipb.ScalarFuncSig_UnaryMinusDecimal)
case types.ClassReal:
case tpReal:
bf = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
sig.setPbCode(tipb.ScalarFuncSig_UnaryMinusReal)
case types.ClassString:
default:
tp := argExpr.GetType().Tp
if types.IsTypeTime(tp) || tp == mysql.TypeDuration {
bf = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
Expand Down Expand Up @@ -678,24 +673,11 @@ func (c *isNullFunctionClass) getFunction(ctx context.Context, args []Expression
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
tc := args[0].GetType().ToClass()
var argTp evalTp
switch tc {
case types.ClassInt:
argTp = tpInt
case types.ClassDecimal:
argTp = tpDecimal
case types.ClassReal:
argTp = tpReal
default:
tp := args[0].GetType().Tp
if types.IsTypeTime(tp) {
argTp = tpDatetime
} else if tp == mysql.TypeDuration {
argTp = tpDuration
} else {
argTp = tpString
}
argTp := fieldTp2EvalTp(args[0].GetType())
if argTp == tpTimestamp {
argTp = tpDatetime
} else if argTp == tpJSON {
argTp = tpString
}
bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, argTp)
bf.tp.Flen = 1
Expand Down

0 comments on commit ae00bc2

Please sign in to comment.