Skip to content

Commit

Permalink
expression, types, plan: remove the usage of "TypeClass" completely (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason authored and coocood committed Sep 27, 2017
1 parent 5200745 commit baee1bb
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 122 deletions.
26 changes: 12 additions & 14 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,23 +1285,21 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(row []types.Datum) (res type
// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(expr Expression, tp *types.FieldType, ctx context.Context) *ScalarFunction {
var fc functionClass
switch tp.ToClass() {
case types.ClassInt:
switch tp.EvalType() {
case types.ETInt:
fc = &castAsIntFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ClassDecimal:
case types.ETDecimal:
fc = &castAsDecimalFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ClassReal:
case types.ETReal:
fc = &castAsRealFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ClassString:
if types.IsTypeTime(tp.Tp) {
fc = &castAsTimeFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
} else if tp.Tp == mysql.TypeDuration {
fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
} else if tp.Tp == mysql.TypeJSON {
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
} else {
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
}
case types.ETDatetime, types.ETTimestamp:
fc = &castAsTimeFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETDuration:
fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETJson:
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
}
f, _ := fc.getFunction(ctx, []Expression{expr})
return &ScalarFunction{
Expand Down
20 changes: 10 additions & 10 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1089,40 +1089,40 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) {
for _, t := range cases {
// Test wrapping with CastAsInt.
intExpr := WrapWithCastAsInt(t.expr, ctx)
c.Assert(intExpr.GetTypeClass(), Equals, types.ClassInt)
c.Assert(intExpr.GetType().EvalType(), Equals, types.ETInt)
_, ok := intExpr.(*ScalarFunction)
c.Assert(ok, Equals, t.expr.GetTypeClass() != types.ClassInt)
c.Assert(ok, Equals, t.expr.GetType().EvalType() != types.ETInt)
intRes, isNull, err := intExpr.EvalInt(t.row, sc)
c.Assert(err, IsNil)
c.Assert(isNull, Equals, false)
c.Assert(intRes, Equals, t.intRes)

// Test wrapping with CastAsReal.
realExpr := WrapWithCastAsReal(t.expr, ctx)
c.Assert(realExpr.GetTypeClass(), Equals, types.ClassReal)
c.Assert(realExpr.GetType().EvalType(), Equals, types.ETReal)
_, ok = realExpr.(*ScalarFunction)
c.Assert(ok, Equals, t.expr.GetTypeClass() != types.ClassReal)
c.Assert(ok, Equals, t.expr.GetType().EvalType() != types.ETReal)
realRes, isNull, err := realExpr.EvalReal(t.row, sc)
c.Assert(err, IsNil)
c.Assert(isNull, Equals, false)
c.Assert(realRes, Equals, t.realRes)

// Test wrapping with CastAsDecimal.
decExpr := WrapWithCastAsDecimal(t.expr, ctx)
c.Assert(decExpr.GetTypeClass(), Equals, types.ClassDecimal)
c.Assert(decExpr.GetType().EvalType(), Equals, types.ETDecimal)
_, ok = decExpr.(*ScalarFunction)
c.Assert(ok, Equals, t.expr.GetTypeClass() != types.ClassDecimal)
c.Assert(ok, Equals, t.expr.GetType().EvalType() != types.ETDecimal)
decRes, isNull, err := decExpr.EvalDecimal(t.row, sc)
c.Assert(err, IsNil)
c.Assert(isNull, Equals, false)
c.Assert(decRes.Compare(t.decRes), Equals, 0)

// Test wrapping with CastAsString.
strExpr := WrapWithCastAsString(t.expr, ctx)
c.Assert(strExpr.GetTypeClass(), Equals, types.ClassString)
c.Assert(strExpr.GetType().EvalType().IsStringKind(), IsTrue)
_, ok = strExpr.(*ScalarFunction)
exprTp := t.expr.GetType().Tp
c.Assert(ok, Equals, t.expr.GetTypeClass() != types.ClassString || types.IsTypeTime(exprTp) || exprTp == mysql.TypeDuration)
c.Assert(ok, Equals, !t.expr.GetType().EvalType().IsStringKind() || types.IsTypeTime(exprTp) || exprTp == mysql.TypeDuration)
strRes, isNull, err := strExpr.EvalString(t.row, sc)
c.Assert(err, IsNil)
c.Assert(isNull, Equals, false)
Expand All @@ -1133,7 +1133,7 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) {

// test cast unsigned int as string.
strExpr := WrapWithCastAsString(unsignedIntExpr, ctx)
c.Assert(strExpr.GetTypeClass(), Equals, types.ClassString)
c.Assert(strExpr.GetType().EvalType().IsStringKind(), IsTrue)
strRes, isNull, err := strExpr.EvalString([]types.Datum{types.NewUintDatum(math.MaxUint64)}, sc)
c.Assert(err, IsNil)
c.Assert(strRes, Equals, strconv.FormatUint(math.MaxUint64, 10))
Expand All @@ -1146,7 +1146,7 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) {

// test cast unsigned int as decimal.
decExpr := WrapWithCastAsDecimal(unsignedIntExpr, ctx)
c.Assert(decExpr.GetTypeClass(), Equals, types.ClassDecimal)
c.Assert(decExpr.GetType().EvalType(), Equals, types.ETDecimal)
decRes, isNull, err := decExpr.EvalDecimal([]types.Datum{types.NewUintDatum(uint64(1234))}, sc)
c.Assert(err, IsNil)
c.Assert(isNull, Equals, false)
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (c *caseWhenFunctionClass) getFunction(ctx context.Context, args []Expressi
decimal = 0
}
fieldTp.Decimal, fieldTp.Flen = decimal, flen
if fieldTp.ToClass() == types.ClassString && !isBinaryStr {
if fieldTp.EvalType().IsStringKind() && !isBinaryStr {
fieldTp.Charset, fieldTp.Collate = mysql.DefaultCharset, mysql.DefaultCollationName
}
// Set retType to BINARY(0) if all arguments are of type NULL.
Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -3043,12 +3043,12 @@ func (c *timestampFunctionClass) getDefaultFsp(tp *types.FieldType) int {
tp.Tp == mysql.TypeTimestamp || tp.Tp == mysql.TypeNewDate {
return tp.Decimal
}
switch cls := tp.ToClass(); cls {
case types.ClassInt:
switch cls := tp.EvalType(); cls {
case types.ETInt:
return types.MinFsp
case types.ClassReal, types.ClassString:
case types.ETReal, types.ETDatetime, types.ETTimestamp, types.ETDuration, types.ETJson, types.ETString:
return types.MaxFsp
case types.ClassDecimal:
case types.ETDecimal:
if tp.Decimal < types.MaxFsp {
return tp.Decimal
}
Expand Down
5 changes: 0 additions & 5 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,6 @@ func (col *Column) GetType() *types.FieldType {
return col.RetType
}

// GetTypeClass implements Expression interface.
func (col *Column) GetTypeClass() types.TypeClass {
return col.RetType.ToClass()
}

// Eval implements Expression interface.
func (col *Column) Eval(row []types.Datum) (types.Datum, error) {
return row[col.Index], nil
Expand Down
5 changes: 0 additions & 5 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,6 @@ func (c *Constant) GetType() *types.FieldType {
return c.RetType
}

// GetTypeClass implements Expression interface.
func (c *Constant) GetTypeClass() types.TypeClass {
return c.RetType.ToClass()
}

// Eval implements Expression interface.
func (c *Constant) Eval(_ []types.Datum) (types.Datum, error) {
return c.Value, nil
Expand Down
3 changes: 0 additions & 3 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ type Expression interface {
// GetType gets the type that the expression returns.
GetType() *types.FieldType

// GetTypeClass gets the TypeClass that the expression returns.
GetTypeClass() types.TypeClass

// Clone copies an expression totally.
Clone() Expression

Expand Down
33 changes: 12 additions & 21 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,6 @@ func (sf *ScalarFunction) GetType() *types.FieldType {
return sf.RetType
}

// GetTypeClass implements Expression interface.
func (sf *ScalarFunction) GetTypeClass() types.TypeClass {
return sf.RetType.ToClass()
}

// Equal implements Expression interface.
func (sf *ScalarFunction) Equal(e Expression, ctx context.Context) bool {
fun, ok := e.(*ScalarFunction)
Expand Down Expand Up @@ -184,31 +179,27 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) {
res interface{}
isNull bool
)
tp := sf.GetType()
switch sf.GetTypeClass() {
case types.ClassInt:
switch tp, evalType := sf.GetType(), sf.GetType().EvalType(); evalType {
case types.ETInt:
var intRes int64
intRes, isNull, err = sf.EvalInt(row, sc)
if mysql.HasUnsignedFlag(tp.Flag) {
res = uint64(intRes)
} else {
res = intRes
}
case types.ClassReal:
case types.ETReal:
res, isNull, err = sf.EvalReal(row, sc)
case types.ClassDecimal:
case types.ETDecimal:
res, isNull, err = sf.EvalDecimal(row, sc)
case types.ClassString:
switch x := sf.GetType().Tp; x {
case mysql.TypeDatetime, mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeNewDate:
res, isNull, err = sf.EvalTime(row, sc)
case mysql.TypeDuration:
res, isNull, err = sf.EvalDuration(row, sc)
case mysql.TypeJSON:
res, isNull, err = sf.EvalJSON(row, sc)
default:
res, isNull, err = sf.EvalString(row, sc)
}
case types.ETDatetime, types.ETTimestamp:
res, isNull, err = sf.EvalTime(row, sc)
case types.ETDuration:
res, isNull, err = sf.EvalDuration(row, sc)
case types.ETJson:
res, isNull, err = sf.EvalJSON(row, sc)
case types.ETString:
res, isNull, err = sf.EvalString(row, sc)
}

if isNull || err != nil {
Expand Down
3 changes: 1 addition & 2 deletions plan/physical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,8 +794,7 @@ func compareTypeForOrder(lhs *types.FieldType, rhs *types.FieldType) bool {
if lhs.Tp != rhs.Tp {
return false
}
if lhs.ToClass() == types.ClassString &&
(lhs.Charset != rhs.Charset || lhs.Collate != rhs.Collate) {
if lhs.EvalType().IsStringKind() && (lhs.Charset != rhs.Charset || lhs.Collate != rhs.Collate) {
return false
}
return true
Expand Down
57 changes: 0 additions & 57 deletions util/types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,6 @@ const (
UnspecifiedLength int = -1
)

// TypeClass classifies field types, used for type inference.
type TypeClass byte

// TypeClass values.
const (
ClassString TypeClass = 0
ClassReal TypeClass = 1
ClassInt TypeClass = 2
ClassRow TypeClass = 3
ClassDecimal TypeClass = 4
)

// FieldType records field type information.
type FieldType struct {
Tp byte
Expand Down Expand Up @@ -129,20 +117,6 @@ func setTypeFlag(flag *uint, flagItem uint, on bool) {
}
}

// ToClass maps the field type to a type class.
func (ft *FieldType) ToClass() TypeClass {
switch ft.Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeBit:
return ClassInt
case mysql.TypeNewDecimal:
return ClassDecimal
case mysql.TypeFloat, mysql.TypeDouble:
return ClassReal
default:
return ClassString
}
}

// EvalType gets the type in evaluation.
func (ft *FieldType) EvalType() EvalType {
switch ft.Tp {
Expand All @@ -165,37 +139,6 @@ func (ft *FieldType) EvalType() EvalType {
return ETString
}

func (tc TypeClass) String() string {
switch tc {
case ClassString:
return "ClassString"
case ClassReal:
return "ClassReal"
case ClassInt:
return "ClassInt"
case ClassDecimal:
return "ClassDecimal"
default:
return "ClassRow"
}
}

// ToType maps the type class to a type.
func (tc TypeClass) ToType() byte {
switch tc {
case ClassString:
return mysql.TypeVarString
case ClassReal:
return mysql.TypeDouble
case ClassInt:
return mysql.TypeLonglong
case ClassDecimal:
return mysql.TypeNewDecimal
default:
return mysql.TypeUnspecified
}
}

// Init initializes the FieldType data.
func (ft *FieldType) Init(tp byte) {
ft.Tp = tp
Expand Down

0 comments on commit baee1bb

Please sign in to comment.