Skip to content

Commit

Permalink
executor: fix explicitly insert null value into timestamp column (pin…
Browse files Browse the repository at this point in the history
…gcap#3646)

create table t (ts timestamp);
insert into t values (null);
This should insert a null rather than default value.
  • Loading branch information
tiancaiamao authored Jul 12, 2017
1 parent a0dfa04 commit 1e1d01a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 53 deletions.
120 changes: 68 additions & 52 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,15 +855,13 @@ func (e *InsertValues) getRowsSelect(cols []*table.Column) ([][]types.Datum, err

func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ignoreErr bool) ([]types.Datum, error) {
row := make([]types.Datum, len(e.Table.Cols()))
marked := make(map[int]struct{}, len(vals))
hasValue := make([]bool, len(e.Table.Cols()))
for i, v := range vals {
offset := cols[i].Offset
row[offset] = v
if !ignoreErr {
marked[offset] = struct{}{}
}
hasValue[offset] = true
}
err := e.initDefaultValues(row, marked, ignoreErr)
err := e.initDefaultValues(row, hasValue, ignoreErr)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -889,64 +887,36 @@ func (e *InsertValues) filterErr(err error, ignoreErr bool) error {
return nil
}

func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struct{}, ignoreErr bool) error {
func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool, ignoreErr bool) error {
var defaultValueCols []*table.Column
sc := e.ctx.GetSessionVars().StmtCtx
for i, c := range e.Table.Cols() {
// It's used for retry.
if mysql.HasAutoIncrementFlag(c.Flag) && row[i].IsNull() &&
e.ctx.GetSessionVars().RetryInfo.Retrying {
id, err := e.ctx.GetSessionVars().RetryInfo.GetCurrAutoIncrementID()
if err != nil {
return errors.Trace(err)
}
row[i].SetInt64(id)
}
if !row[i].IsNull() {
// Column value isn't nil and column isn't auto-increment, continue.
if !mysql.HasAutoIncrementFlag(c.Flag) {
continue
}
val, err := row[i].ToInt64(sc)
if e.filterErr(errors.Trace(err), ignoreErr) != nil {
return errors.Trace(err)
}
row[i].SetInt64(val)
if val != 0 || (e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero) > 0 {
e.ctx.GetSessionVars().InsertID = uint64(val)
e.Table.RebaseAutoID(val, true)
continue
}
}
strictSQL := e.ctx.GetSessionVars().StrictSQLMode

// If the nil value is evaluated in insert list, we will use nil except auto increment column.
if _, ok := marked[i]; ok && !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) {
continue
for i, c := range e.Table.Cols() {
var needDefaultValue bool
if !hasValue[i] {
needDefaultValue = true
} else if mysql.HasNotNullFlag(c.Flag) && row[i].IsNull() && !strictSQL {
needDefaultValue = true
// TODO: Append Warning ErrColumnCantNull.
}

if mysql.HasAutoIncrementFlag(c.Flag) {
recordID, err := e.Table.AllocAutoID()
if err != nil {
return errors.Trace(err)
}
row[i].SetInt64(recordID)
// It's compatible with mysql. So it sets last insert id to the first row.
if e.currRow == 0 {
e.lastInsertID = uint64(recordID)
}
// It's used for retry.
if !e.ctx.GetSessionVars().RetryInfo.Retrying {
e.ctx.GetSessionVars().RetryInfo.AddAutoIncrementID(recordID)
}
} else {
needDefaultValue = false
}
if needDefaultValue {
var err error
row[i], err = table.GetColDefaultValue(e.ctx, c.ToInfo())
if e.filterErr(err, ignoreErr) != nil {
return errors.Trace(err)
}
defaultValueCols = append(defaultValueCols, c)
}

defaultValueCols = append(defaultValueCols, c)
// Adjust the value if this column has auto increment flag.
if mysql.HasAutoIncrementFlag(c.Flag) {
if err := e.adjustAutoIncrementDatum(row, i, c, ignoreErr); err != nil {
return errors.Trace(err)
}
}
}
if err := table.CastValues(e.ctx, row, defaultValueCols, ignoreErr); err != nil {
return errors.Trace(err)
Expand All @@ -955,6 +925,52 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struc
return nil
}

func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *table.Column, ignoreErr bool) error {
retryInfo := e.ctx.GetSessionVars().RetryInfo
if retryInfo.Retrying {
id, err := retryInfo.GetCurrAutoIncrementID()
if err != nil {
return errors.Trace(err)
}
row[i].SetInt64(id)
return nil
}

var err error
var recordID int64
if !row[i].IsNull() {
recordID, err = row[i].ToInt64(e.ctx.GetSessionVars().StmtCtx)
if e.filterErr(errors.Trace(err), ignoreErr) != nil {
return errors.Trace(err)
}
}
// Use the value if it's not null and not 0.
if recordID != 0 {
e.Table.RebaseAutoID(recordID, true)
e.ctx.GetSessionVars().InsertID = uint64(recordID)
row[i].SetInt64(recordID)
retryInfo.AddAutoIncrementID(recordID)
return nil
}

// Change NULL to auto id.
// Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set.
if row[i].IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 {
recordID, err = e.Table.AllocAutoID()
if e.filterErr(errors.Trace(err), ignoreErr) != nil {
return errors.Trace(err)
}
// It's compatible with mysql. So it sets last insert id to the first row.
if e.currRow == 0 {
e.lastInsertID = uint64(recordID)
}
}

row[i].SetInt64(recordID)
retryInfo.AddAutoIncrementID(recordID)
return nil
}

// onDuplicateUpdate updates the duplicate row.
// TODO: Report rows affected and last insert id.
func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols []*expression.Assignment) error {
Expand Down
16 changes: 16 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,3 +1006,19 @@ func (s *testSuite) TestBatchInsert(c *C) {
r = tk.MustQuery("select count(*) from batch_insert;")
r.Check(testkit.Rows("320"))
}

func (s *testSuite) TestNullDefault(c *C) {
defer func() {
s.cleanEnv(c)
testleak.AfterTest(c)()
}()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test; drop table if exists test_null_default;")
tk.MustExec("set timestamp = 1234")
tk.MustExec("set time_zone = '+08:00'")
tk.MustExec("create table test_null_default (ts timestamp null default current_timestamp)")
tk.MustExec("insert into test_null_default values (null)")
tk.MustQuery("select * from test_null_default").Check(testkit.Rows("<nil>"))
tk.MustExec("insert into test_null_default values ()")
tk.MustQuery("select * from test_null_default").Check(testkit.Rows("<nil>", "1970-01-01 08:20:34"))
}
2 changes: 1 addition & 1 deletion table/tables/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func countEntriesWithPrefix(ctx context.Context, prefix []byte) (int, error) {
}

func (ts *testSuite) TestTypes(c *C) {
_, err := ts.se.Execute("CREATE TABLE test.t (c1 tinyint, c2 smallint, c3 int, c4 bigint, c5 text, c6 blob, c7 varchar(64), c8 time, c9 timestamp not null default CURRENT_TIMESTAMP, c10 decimal(10,1))")
_, err := ts.se.Execute("CREATE TABLE test.t (c1 tinyint, c2 smallint, c3 int, c4 bigint, c5 text, c6 blob, c7 varchar(64), c8 time, c9 timestamp null default CURRENT_TIMESTAMP, c10 decimal(10,1))")
c.Assert(err, IsNil)
ctx := ts.se.(context.Context)
dom := sessionctx.GetDomain(ctx)
Expand Down

0 comments on commit 1e1d01a

Please sign in to comment.