Skip to content

Commit

Permalink
expression: remove type assertion on types.DatumRow. (pingcap#5005)
Browse files Browse the repository at this point in the history
  • Loading branch information
coocood authored and zz-jason committed Nov 7, 2017
1 parent 5a55d8b commit 1aa2a69
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 57 deletions.
2 changes: 1 addition & 1 deletion executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols []*expre
}

// See http://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
e.ctx.GetSessionVars().CurrInsertValues = row
e.ctx.GetSessionVars().CurrInsertValues = types.DatumRow(row)

// evaluate assignment
assignFlag := make([]bool, len(e.Table.WritableCols()))
Expand Down
72 changes: 35 additions & 37 deletions expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,12 @@ func (b *builtinValuesIntSig) evalInt(_ types.Row) (int64, bool, error) {
if values == nil {
return 0, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetInt64(), row[b.offset].IsNull(), nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetInt64(b.offset)
return val, isNull, nil
}
return 0, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return 0, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type builtinValuesRealSig struct {
Expand All @@ -478,11 +479,12 @@ func (b *builtinValuesRealSig) evalReal(_ types.Row) (float64, bool, error) {
if values == nil {
return 0, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetFloat64(), row[b.offset].IsNull(), nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetFloat64(b.offset)
return val, isNull, nil
}
return 0, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return 0, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type builtinValuesDecimalSig struct {
Expand All @@ -498,14 +500,12 @@ func (b *builtinValuesDecimalSig) evalDecimal(_ types.Row) (*types.MyDecimal, bo
if values == nil {
return nil, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
if row[b.offset].IsNull() {
return nil, true, nil
}
return row[b.offset].GetMysqlDecimal(), false, nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetMyDecimal(b.offset)
return val, isNull, nil
}
return nil, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return nil, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type builtinValuesStringSig struct {
Expand All @@ -521,11 +521,12 @@ func (b *builtinValuesStringSig) evalString(_ types.Row) (string, bool, error) {
if values == nil {
return "", true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetString(), row[b.offset].IsNull(), nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetString(b.offset)
return val, isNull, nil
}
return "", true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return "", true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type builtinValuesTimeSig struct {
Expand All @@ -541,14 +542,12 @@ func (b *builtinValuesTimeSig) evalTime(_ types.Row) (types.Time, bool, error) {
if values == nil {
return types.Time{}, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
if row[b.offset].IsNull() {
return types.Time{}, true, nil
}
return row[b.offset].GetMysqlTime(), false, nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetTime(b.offset)
return val, isNull, nil
}
return types.Time{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return types.Time{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type builtinValuesDurationSig struct {
Expand All @@ -564,11 +563,12 @@ func (b *builtinValuesDurationSig) evalDuration(_ types.Row) (types.Duration, bo
if values == nil {
return types.Duration{}, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetMysqlDuration(), row[b.offset].IsNull(), nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetDuration(b.offset)
return val, isNull, nil
}
return types.Duration{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return types.Duration{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type builtinValuesJSONSig struct {
Expand All @@ -584,14 +584,12 @@ func (b *builtinValuesJSONSig) evalJSON(_ types.Row) (json.JSON, bool, error) {
if values == nil {
return json.JSON{}, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
if row[b.offset].IsNull() {
return json.JSON{}, true, nil
}
return row[b.offset].GetMysqlJSON(), false, nil
row := values.(types.Row)
if b.offset < row.Len() {
val, isNull := row.GetJSON(b.offset)
return val, isNull, nil
}
return json.JSON{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return json.JSON{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset)
}

type bitCountFunctionClass struct {
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ func (s *testEvaluatorSuite) TestValues(c *C) {
c.Assert(err, IsNil)
_, err = evalBuiltinFunc(sig, nil)
c.Assert(err.Error(), Equals, "Session current insert values is nil")
s.ctx.GetSessionVars().CurrInsertValues = types.MakeDatums("1")
s.ctx.GetSessionVars().CurrInsertValues = types.DatumRow(types.MakeDatums("1"))
_, err = evalBuiltinFunc(sig, nil)
c.Assert(err.Error(), Equals, fmt.Sprintf("Session current insert values len %d and column's offset %v don't match", 1, 1))
currInsertValues := types.MakeDatums("1", "2")
s.ctx.GetSessionVars().CurrInsertValues = currInsertValues
s.ctx.GetSessionVars().CurrInsertValues = types.DatumRow(currInsertValues)
ret, err := evalBuiltinFunc(sig, nil)
c.Assert(err, IsNil)
cmp, err := ret.CompareDatum(nil, &currInsertValues[1])
Expand Down
45 changes: 28 additions & 17 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -194,48 +195,58 @@ func (col *Column) GetType() *types.FieldType {

// Eval implements Expression interface.
func (col *Column) Eval(row types.Row) (types.Datum, error) {
return row.(types.DatumRow)[col.Index], nil
return row.GetDatum(col.Index, col.RetType), nil
}

// EvalInt returns int representation of Column.
func (col *Column) EvalInt(row types.Row, sc *variable.StatementContext) (int64, bool, error) {
val := &row.(types.DatumRow)[col.Index]
if val.IsNull() {
return 0, true, nil
}
if col.GetType().Hybrid() {
val := row.GetDatum(col.Index, col.RetType)
if val.IsNull() {
return 0, true, nil
}
res, err := val.ToInt64(sc)
return res, err != nil, errors.Trace(err)
}
return val.GetInt64(), false, nil
val, isNull := row.GetInt64(col.Index)
return val, isNull, nil
}

// EvalReal returns real representation of Column.
func (col *Column) EvalReal(row types.Row, sc *variable.StatementContext) (float64, bool, error) {
val := &row.(types.DatumRow)[col.Index]
if val.IsNull() {
return 0, true, nil
}
if col.GetType().Hybrid() {
val := row.GetDatum(col.Index, col.RetType)
if val.IsNull() {
return 0, true, nil
}
res, err := val.ToFloat64(sc)
return res, err != nil, errors.Trace(err)
}
return val.GetFloat64(), false, nil
if col.GetType().Tp == mysql.TypeFloat {
val, isNull := row.GetFloat32(col.Index)
return float64(val), isNull, nil
}
val, isNull := row.GetFloat64(col.Index)
return val, isNull, nil
}

// EvalString returns string representation of Column.
func (col *Column) EvalString(row types.Row, sc *variable.StatementContext) (string, bool, error) {
val := &row.(types.DatumRow)[col.Index]
if val.IsNull() {
return "", true, nil
if col.GetType().Hybrid() {
val := row.GetDatum(col.Index, col.RetType)
if val.IsNull() {
return "", true, nil
}
res, err := val.ToString()
return res, err != nil, errors.Trace(err)
}
res, err := val.ToString()
return res, err != nil, errors.Trace(err)
val, isNull := row.GetString(col.Index)
return val, isNull, nil
}

// EvalDecimal returns decimal representation of Column.
func (col *Column) EvalDecimal(row types.Row, sc *variable.StatementContext) (*types.MyDecimal, bool, error) {
val := &row.(types.DatumRow)[col.Index]
val := row.GetDatum(col.Index, col.RetType)
if val.IsNull() {
return nil, true, nil
}
Expand Down
16 changes: 16 additions & 0 deletions types/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import "github.com/pingcap/tidb/types/json"

// Row is an interface to read columns values.
type Row interface {
// Len returns the number of values in the row.
Len() int

// GetInt64 returns the int64 value and isNull with the colIdx.
GetInt64(colIdx int) (val int64, isNull bool)
Expand Down Expand Up @@ -53,13 +55,22 @@ type Row interface {

// GetJSON returns the JSON value and isNull with the colIdx.
GetJSON(colIdx int) (json.JSON, bool)

// GetDatum returns a Datum with the colIdx and field type.
// This method is provided for convenience, direct type methods are preferred for better performance.
GetDatum(colIdx int, tp *FieldType) Datum
}

var _ Row = DatumRow{}

// DatumRow is a slice of Datum, implements Row interface.
type DatumRow []Datum

// Len implements Row interface.
func (dr DatumRow) Len() int {
return len(dr)
}

// GetInt64 implements Row interface.
func (dr DatumRow) GetInt64(colIdx int) (int64, bool) {
datum := dr[colIdx]
Expand Down Expand Up @@ -167,3 +178,8 @@ func (dr DatumRow) GetJSON(colIdx int) (json.JSON, bool) {
}
return dr[colIdx].GetMysqlJSON(), false
}

// GetDatum implements Row interface.
func (dr DatumRow) GetDatum(colIdx int, tp *FieldType) Datum {
return dr[colIdx]
}
74 changes: 74 additions & 0 deletions util/chunk/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ package chunk
import (
"unsafe"

"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/hack"
)

var _ types.Row = Row{}

// Chunk stores multiple rows of data in Apache Arrow format.
// See https://arrow.apache.org/docs/memory_layout.html
// Values are appended in compact format and can be directly accessed without decoding.
Expand Down Expand Up @@ -297,6 +300,11 @@ type Row struct {
idx int
}

// Len returns the number of values in the row.
func (r Row) Len() int {
return r.c.NumCols()
}

// GetInt64 returns the int64 value and isNull with the colIdx.
func (r Row) GetInt64(colIdx int) (int64, bool) {
col := r.c.columns[colIdx]
Expand Down Expand Up @@ -383,3 +391,69 @@ func (r Row) GetJSON(colIdx int) (json.JSON, bool) {
j, ok := col.ifaces[r.idx].(json.JSON)
return j, !ok
}

// GetDatum implements the types.Row interface.
func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum {
var d types.Datum
switch tp.Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
if mysql.HasUnsignedFlag(tp.Flag) {
val, isNull := r.GetUint64(colIdx)
if !isNull {
d.SetUint64(val)
}
} else {
val, isNull := r.GetInt64(colIdx)
if !isNull {
d.SetInt64(val)
}
}
case mysql.TypeFloat:
val, isNull := r.GetFloat32(colIdx)
if !isNull {
d.SetFloat32(val)
}
case mysql.TypeDouble:
val, isNull := r.GetFloat64(colIdx)
if !isNull {
d.SetFloat64(val)
}
case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString,
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
val, isNull := r.GetBytes(colIdx)
if !isNull {
d.SetBytes(val)
}
case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
val, isNull := r.GetTime(colIdx)
if !isNull {
d.SetMysqlTime(val)
}
case mysql.TypeDuration:
val, isNull := r.GetDuration(colIdx)
if !isNull {
d.SetMysqlDuration(val)
}
case mysql.TypeNewDecimal:
val, IsNull := r.GetMyDecimal(colIdx)
if !IsNull {
d.SetMysqlDecimal(val)
}
case mysql.TypeEnum:
val, isNull := r.GetEnum(colIdx)
if !isNull {
d.SetMysqlEnum(val)
}
case mysql.TypeSet:
val, isNull := r.GetSet(colIdx)
if !isNull {
d.SetMysqlSet(val)
}
case mysql.TypeJSON:
val, isNull := r.GetJSON(colIdx)
if !isNull {
d.SetMysqlJSON(val)
}
}
return d
}

0 comments on commit 1aa2a69

Please sign in to comment.