Skip to content

Commit

Permalink
*: add GetTypeClass() function for Expression interface (pingcap#3321)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored and hanfei1991 committed May 25, 2017
1 parent a630a9a commit ed2555a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
5 changes: 5 additions & 0 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ func (col *Column) GetType() *types.FieldType {
return col.RetType
}

// GetTypeClass implements Expression interface.
func (col *Column) GetTypeClass() types.TypeClass {
return col.RetType.ToClass()
}

// Eval implements Expression interface.
func (col *Column) Eval(row []types.Datum) (types.Datum, error) {
return row[col.Index], nil
Expand Down
26 changes: 18 additions & 8 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ type Expression interface {
// GetType gets the type that the expression returns.
GetType() *types.FieldType

// GetTypeClass gets the TypeClass that the expression returns.
GetTypeClass() types.TypeClass

// Clone copies an expression totally.
Clone() Expression

Expand Down Expand Up @@ -134,8 +137,7 @@ func evalExprToInt(expr Expression, row []types.Datum, sc *variable.StatementCon
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
}
tc := expr.GetType().ToClass()
if tc == types.ClassInt {
if expr.GetTypeClass() == types.ClassInt {
return val.GetInt64(), false, nil
} else if IsHybridType(expr) {
res, err = val.ToInt64(sc)
Expand All @@ -150,8 +152,7 @@ func evalExprToReal(expr Expression, row []types.Datum, sc *variable.StatementCo
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
}
tc := expr.GetType().ToClass()
if tc == types.ClassReal {
if expr.GetTypeClass() == types.ClassReal {
return val.GetFloat64(), false, nil
} else if IsHybridType(expr) {
res, err = val.ToFloat64(sc)
Expand All @@ -166,8 +167,7 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
}
tc := expr.GetType().ToClass()
if tc == types.ClassDecimal {
if expr.GetTypeClass() == types.ClassDecimal {
return val.GetMysqlDecimal(), false, nil
} else if IsHybridType(expr) {
res, err = val.ToDecimal(sc)
Expand All @@ -182,8 +182,7 @@ func evalExprToString(expr Expression, row []types.Datum, _ *variable.StatementC
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
}
tc := expr.GetType().ToClass()
if tc == types.ClassString {
if expr.GetTypeClass() == types.ClassString {
// We cannot use val.GetString() directly.
// For example, `Bit` is regarded as ClassString,
// while we can not use val.GetString() to get the value of a Bit variable,
Expand All @@ -196,6 +195,9 @@ func evalExprToString(expr Expression, row []types.Datum, _ *variable.StatementC

// evalExprToTime evaluates `expr` to TIME type.
func evalExprToTime(expr Expression, row []types.Datum, _ *variable.StatementContext) (res types.Time, isNull bool, err error) {
if IsHybridType(expr) {
return res, true, nil
}
val, err := expr.Eval(row)
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
Expand All @@ -210,6 +212,9 @@ func evalExprToTime(expr Expression, row []types.Datum, _ *variable.StatementCon

// evalExprToDuration evaluates `expr` to DURATION type.
func evalExprToDuration(expr Expression, row []types.Datum, _ *variable.StatementContext) (res types.Duration, isNull bool, err error) {
if IsHybridType(expr) {
return res, true, nil
}
val, err := expr.Eval(row)
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
Expand Down Expand Up @@ -266,6 +271,11 @@ func (c *Constant) GetType() *types.FieldType {
return c.RetType
}

// GetTypeClass implements Expression interface.
func (c *Constant) GetTypeClass() types.TypeClass {
return c.RetType.ToClass()
}

// Eval implements Expression interface.
func (c *Constant) Eval(_ []types.Datum) (types.Datum, error) {
return c.Value, nil
Expand Down
5 changes: 5 additions & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ func (sf *ScalarFunction) GetType() *types.FieldType {
return sf.RetType
}

// GetTypeClass implements Expression interface.
func (sf *ScalarFunction) GetTypeClass() types.TypeClass {
return sf.RetType.ToClass()
}

// Equal implements Expression interface.
func (sf *ScalarFunction) Equal(e Expression, ctx context.Context) bool {
fun, ok := e.(*ScalarFunction)
Expand Down
2 changes: 1 addition & 1 deletion util/types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const (
UnspecifiedLength int = -1
)

// TypeClass classifies types, used for type inference.
// TypeClass classifies field types, used for type inference.
type TypeClass byte

// TypeClass values.
Expand Down

0 comments on commit ed2555a

Please sign in to comment.