Skip to content

Commit

Permalink
expression: fix wrong result type for greatest/least (pingcap#29408)
Browse files Browse the repository at this point in the history
  • Loading branch information
guo-shaoge authored Nov 7, 2021
1 parent 2a6d43a commit 4b11003
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 46 deletions.
169 changes: 135 additions & 34 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@ var (
_ builtinFunc = &builtinGreatestRealSig{}
_ builtinFunc = &builtinGreatestDecimalSig{}
_ builtinFunc = &builtinGreatestStringSig{}
_ builtinFunc = &builtinGreatestDurationSig{}
_ builtinFunc = &builtinGreatestTimeSig{}
_ builtinFunc = &builtinGreatestCmpStringAsTimeSig{}
_ builtinFunc = &builtinLeastIntSig{}
_ builtinFunc = &builtinLeastRealSig{}
_ builtinFunc = &builtinLeastDecimalSig{}
_ builtinFunc = &builtinLeastStringSig{}
_ builtinFunc = &builtinLeastTimeSig{}
_ builtinFunc = &builtinLeastDurationSig{}
_ builtinFunc = &builtinLeastCmpStringAsTimeSig{}
_ builtinFunc = &builtinIntervalIntSig{}
_ builtinFunc = &builtinIntervalRealSig{}

Expand Down Expand Up @@ -412,7 +416,7 @@ func ResolveType4Between(args [3]Expression) types.EvalType {
}

// resolveType4Extremum gets compare type for GREATEST and LEAST and BETWEEN (mainly for datetime).
func resolveType4Extremum(args []Expression) types.EvalType {
func resolveType4Extremum(args []Expression) (_ types.EvalType, cmpStringAsDatetime bool) {
aggType := aggregateType(args)

var temporalItem *types.FieldType
Expand All @@ -429,10 +433,11 @@ func resolveType4Extremum(args []Expression) types.EvalType {

if !types.IsTypeTemporal(aggType.Tp) && temporalItem != nil {
aggType.Tp = temporalItem.Tp
cmpStringAsDatetime = true
}
// TODO: String charset, collation checking are needed.
}
return aggType.EvalType()
return aggType.EvalType(), cmpStringAsDatetime
}

// unsupportedJSONComparison reports warnings while there is a JSON type in least/greatest function's arguments
Expand All @@ -454,12 +459,9 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETDuration {
tp, cmpStringAsDatetime := resolveType4Extremum(args)
if cmpStringAsDatetime {
// Args are temporal and string mixed, we cast all args as string and parse it to temporal mannualy to compare.
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
Expand All @@ -473,9 +475,6 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
if cmpAsDatetime {
tp = types.ETDatetime
}
switch tp {
case types.ETInt:
sig = &builtinGreatestIntSig{bf}
Expand All @@ -487,8 +486,16 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
sig = &builtinGreatestDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestDecimal)
case types.ETString:
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
if cmpStringAsDatetime {
sig = &builtinGreatestCmpStringAsTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestCmpStringAsTime)
} else {
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
}
case types.ETDuration:
sig = &builtinGreatestDurationSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestDuration)
case types.ETDatetime, types.ETTimestamp:
sig = &builtinGreatestTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
Expand Down Expand Up @@ -630,19 +637,19 @@ func (b *builtinGreatestStringSig) evalString(row chunk.Row) (max string, isNull
return
}

type builtinGreatestTimeSig struct {
type builtinGreatestCmpStringAsTimeSig struct {
baseBuiltinFunc
}

func (b *builtinGreatestTimeSig) Clone() builtinFunc {
newSig := &builtinGreatestTimeSig{}
func (b *builtinGreatestCmpStringAsTimeSig) Clone() builtinFunc {
newSig := &builtinGreatestCmpStringAsTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals a builtinGreatestTimeSig.
// evalString evals a builtinGreatestCmpStringAsTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_greatest
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (strRes string, isNull bool, err error) {
func (b *builtinGreatestCmpStringAsTimeSig) evalString(row chunk.Row) (strRes string, isNull bool, err error) {
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalString(b.ctx, row)
Expand All @@ -665,6 +672,52 @@ func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (strRes string, isNul
return strRes, false, nil
}

type builtinGreatestTimeSig struct {
baseBuiltinFunc
}

func (b *builtinGreatestTimeSig) Clone() builtinFunc {
newSig := &builtinGreatestTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGreatestTimeSig) evalTime(row chunk.Row) (res types.Time, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalTime(b.ctx, row)
if isNull || err != nil {
return types.ZeroTime, true, err
}
if i == 0 || v.Compare(res) > 0 {
res = v
}
}
return res, false, nil
}

type builtinGreatestDurationSig struct {
baseBuiltinFunc
}

func (b *builtinGreatestDurationSig) Clone() builtinFunc {
newSig := &builtinGreatestDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGreatestDurationSig) evalDuration(row chunk.Row) (res types.Duration, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalDuration(b.ctx, row)
if isNull || err != nil {
return types.Duration{}, true, err
}
if i == 0 || v.Compare(res) > 0 {
res = v
}
}
return res, false, nil
}

type leastFunctionClass struct {
baseFunctionClass
}
Expand All @@ -673,12 +726,9 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETDuration {
tp, cmpStringAsDatetime := resolveType4Extremum(args)
if cmpStringAsDatetime {
// Args are temporal and string mixed, we cast all args as string and parse it to temporal mannualy to compare.
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
Expand All @@ -692,9 +742,6 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err != nil {
return nil, err
}
if cmpAsDatetime {
tp = types.ETDatetime
}
switch tp {
case types.ETInt:
sig = &builtinLeastIntSig{bf}
Expand All @@ -706,8 +753,16 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
sig = &builtinLeastDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastDecimal)
case types.ETString:
sig = &builtinLeastStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastString)
if cmpStringAsDatetime {
sig = &builtinLeastCmpStringAsTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastCmpStringAsTime)
} else {
sig = &builtinLeastStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastString)
}
case types.ETDuration:
sig = &builtinLeastDurationSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastDuration)
case types.ETDatetime, types.ETTimestamp:
sig = &builtinLeastTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastTime)
Expand Down Expand Up @@ -836,19 +891,19 @@ func (b *builtinLeastStringSig) evalString(row chunk.Row) (min string, isNull bo
return
}

type builtinLeastTimeSig struct {
type builtinLeastCmpStringAsTimeSig struct {
baseBuiltinFunc
}

func (b *builtinLeastTimeSig) Clone() builtinFunc {
newSig := &builtinLeastTimeSig{}
func (b *builtinLeastCmpStringAsTimeSig) Clone() builtinFunc {
newSig := &builtinLeastCmpStringAsTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals a builtinLeastTimeSig.
// evalString evals a builtinLeastCmpStringAsTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#functionleast
func (b *builtinLeastTimeSig) evalString(row chunk.Row) (strRes string, isNull bool, err error) {
func (b *builtinLeastCmpStringAsTimeSig) evalString(row chunk.Row) (strRes string, isNull bool, err error) {
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalString(b.ctx, row)
Expand All @@ -871,6 +926,52 @@ func (b *builtinLeastTimeSig) evalString(row chunk.Row) (strRes string, isNull b
return strRes, false, nil
}

type builtinLeastTimeSig struct {
baseBuiltinFunc
}

func (b *builtinLeastTimeSig) Clone() builtinFunc {
newSig := &builtinLeastTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinLeastTimeSig) evalTime(row chunk.Row) (res types.Time, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalTime(b.ctx, row)
if isNull || err != nil {
return types.ZeroTime, true, err
}
if i == 0 || v.Compare(res) < 0 {
res = v
}
}
return res, false, nil
}

type builtinLeastDurationSig struct {
baseBuiltinFunc
}

func (b *builtinLeastDurationSig) Clone() builtinFunc {
newSig := &builtinLeastDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinLeastDurationSig) evalDuration(row chunk.Row) (res types.Duration, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalDuration(b.ctx, row)
if isNull || err != nil {
return types.Duration{}, true, err
}
if i == 0 || v.Compare(res) < 0 {
res = v
}
}
return res, false, nil
}

type intervalFunctionClass struct {
baseFunctionClass
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func TestGreatestLeastFunc(t *testing.T) {
},
{
[]interface{}{duration, duration},
"12:59:59", "12:59:59", false, false,
duration, duration, false, false,
},
{
[]interface{}{"123", nil, "123"},
Expand Down
Loading

0 comments on commit 4b11003

Please sign in to comment.