Skip to content

Commit

Permalink
*: refine the attribute definition of types.Time and types.Dur… (ping…
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored and alivxxx committed Aug 14, 2019
1 parent 44292d0 commit adb3071
Show file tree
Hide file tree
Showing 33 changed files with 333 additions and 323 deletions.
6 changes: 3 additions & 3 deletions bindinfo/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (s *testSuite) TestGlobalBinding(c *C) {
metrics.BindTotalGauge.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb)
c.Assert(pb.GetGauge().GetValue(), Equals, float64(1))
metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb)
c.Assert(pb.GetGauge().GetValue(), Equals, float64(161))
c.Assert(pb.GetGauge().GetValue(), Equals, float64(121))

sql, hash := parser.NormalizeDigest("select * from t where i > ?")

Expand Down Expand Up @@ -221,7 +221,7 @@ func (s *testSuite) TestGlobalBinding(c *C) {
c.Assert(pb.GetGauge().GetValue(), Equals, float64(0))
metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb)
// From newly created global bind handle.
c.Assert(pb.GetGauge().GetValue(), Equals, float64(161))
c.Assert(pb.GetGauge().GetValue(), Equals, float64(121))

bindHandle = bindinfo.NewBindHandle(tk.Se)
err = bindHandle.Update(true)
Expand Down Expand Up @@ -268,7 +268,7 @@ func (s *testSuite) TestSessionBinding(c *C) {
metrics.BindTotalGauge.WithLabelValues(metrics.ScopeSession, bindinfo.Using).Write(pb)
c.Assert(pb.GetGauge().GetValue(), Equals, float64(1))
metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeSession, bindinfo.Using).Write(pb)
c.Assert(pb.GetGauge().GetValue(), Equals, float64(161))
c.Assert(pb.GetGauge().GetValue(), Equals, float64(121))

handle := tk.Se.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle)
bindData := handle.GetBindRecord("select * from t where i > ?", "test")
Expand Down
8 changes: 4 additions & 4 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value in
if value != nil && ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() &&
ctx.GetSessionVars().SQLMode.HasStrictMode() && types.IsTypeTime(col.Tp) {
if vv, ok := value.(string); ok {
timeValue, err := expression.GetTimeValue(ctx, vv, col.Tp, col.Decimal)
timeValue, err := expression.GetTimeValue(ctx, vv, col.Tp, int8(col.Decimal))
if err != nil {
return hasDefaultValue, value, errors.Trace(err)
}
Expand All @@ -455,7 +455,7 @@ func convertTimestampDefaultValToUTC(ctx sessionctx.Context, defaultVal interfac
}
if vv, ok := defaultVal.(string); ok {
if vv != types.ZeroDatetimeStr && strings.ToUpper(vv) != strings.ToUpper(ast.CurrentTimestamp) {
t, err := types.ParseTime(ctx.GetSessionVars().StmtCtx, vv, col.Tp, col.Decimal)
t, err := types.ParseTime(ctx.GetSessionVars().StmtCtx, vv, col.Tp, int8(col.Decimal))
if err != nil {
return defaultVal, errors.Trace(err)
}
Expand Down Expand Up @@ -636,7 +636,7 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption
}
}
}
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, fsp)
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, int8(fsp))
value := vd.GetValue()
if err != nil {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName)
Expand Down Expand Up @@ -951,7 +951,7 @@ func checkColumnAttributes(colName string, tp *types.FieldType) error {
return types.ErrMBiggerThanD.GenWithStackByArgs(colName)
}
case mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp:
if tp.Decimal != types.UnspecifiedFsp && (tp.Decimal < types.MinFsp || tp.Decimal > types.MaxFsp) {
if tp.Decimal != int(types.UnspecifiedFsp) && (tp.Decimal < int(types.MinFsp) || tp.Decimal > int(types.MaxFsp)) {
return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Decimal, colName, types.MaxFsp)
}
}
Expand Down
2 changes: 1 addition & 1 deletion executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h int64, oldData
// 4. Fill values into on-update-now fields, only if they are really changed.
for i, col := range t.Cols() {
if mysql.HasOnUpdateNowFlag(col.Flag) && !modified[i] && !onUpdateSpecified[i] {
if v, err := expression.GetTimeValue(sctx, strings.ToUpper(ast.CurrentTimestamp), col.Tp, col.Decimal); err == nil {
if v, err := expression.GetTimeValue(sctx, strings.ToUpper(ast.CurrentTimestamp), col.Tp, int8(col.Decimal)); err == nil {
newData[i] = v
modified[i] = true
} else {
Expand Down
6 changes: 3 additions & 3 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,21 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, args []Expression, retType
fieldType = &types.FieldType{
Tp: mysql.TypeDatetime,
Flen: mysql.MaxDatetimeWidthWithFsp,
Decimal: types.MaxFsp,
Decimal: int(types.MaxFsp),
Flag: mysql.BinaryFlag,
}
case types.ETTimestamp:
fieldType = &types.FieldType{
Tp: mysql.TypeTimestamp,
Flen: mysql.MaxDatetimeWidthWithFsp,
Decimal: types.MaxFsp,
Decimal: int(types.MaxFsp),
Flag: mysql.BinaryFlag,
}
case types.ETDuration:
fieldType = &types.FieldType{
Tp: mysql.TypeDuration,
Flen: mysql.MaxDurationWidthWithFsp,
Decimal: types.MaxFsp,
Decimal: int(types.MaxFsp),
Flag: mysql.BinaryFlag,
}
case types.ETJson:
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool) {

func (c *arithmeticDivideFunctionClass) setType4DivDecimal(retTp, a, b *types.FieldType) {
var deca, decb = a.Decimal, b.Decimal
if deca == types.UnspecifiedFsp {
if deca == int(types.UnspecifiedFsp) {
deca = 0
}
if decb == types.UnspecifiedFsp {
if decb == int(types.UnspecifiedFsp) {
decb = 0
}
retTp.Decimal = deca + precIncrement
Expand Down
34 changes: 17 additions & 17 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (b *builtinCastIntAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNul
if isNull || err != nil {
return res, isNull, err
}
res, err = types.ParseTimeFromNum(b.ctx.GetSessionVars().StmtCtx, val, b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTimeFromNum(b.ctx.GetSessionVars().StmtCtx, val, b.tp.Tp, int8(b.tp.Decimal))
if err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
Expand All @@ -566,7 +566,7 @@ func (b *builtinCastIntAsDurationSig) evalDuration(row chunk.Row) (res types.Dur
if isNull || err != nil {
return res, isNull, err
}
dur, err := types.NumberToDuration(val, b.tp.Decimal)
dur, err := types.NumberToDuration(val, int8(b.tp.Decimal))
if err != nil {
if types.ErrOverflow.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err)
Expand Down Expand Up @@ -821,7 +821,7 @@ func (b *builtinCastRealAsTimeSig) evalTime(row chunk.Row) (types.Time, bool, er
return types.Time{}, true, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err := types.ParseTime(sc, strconv.FormatFloat(val, 'f', -1, 64), b.tp.Tp, b.tp.Decimal)
res, err := types.ParseTime(sc, strconv.FormatFloat(val, 'f', -1, 64), b.tp.Tp, int8(b.tp.Decimal))
if err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
Expand All @@ -847,7 +847,7 @@ func (b *builtinCastRealAsDurationSig) evalDuration(row chunk.Row) (res types.Du
if isNull || err != nil {
return res, isNull, err
}
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(val, 'f', -1, 64), b.tp.Decimal)
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(val, 'f', -1, 64), int8(b.tp.Decimal))
return res, false, err
}

Expand Down Expand Up @@ -978,7 +978,7 @@ func (b *builtinCastDecimalAsTimeSig) evalTime(row chunk.Row) (res types.Time, i
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ParseTimeFromFloatString(sc, string(val.ToString()), b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTimeFromFloatString(sc, string(val.ToString()), b.tp.Tp, int8(b.tp.Decimal))
if err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
Expand All @@ -1004,7 +1004,7 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types
if isNull || err != nil {
return res, true, err
}
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.Decimal)
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), int8(b.tp.Decimal))
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
// ZeroDuration of error ErrTruncatedWrongVal needs to be considered NULL.
Expand Down Expand Up @@ -1186,7 +1186,7 @@ func (b *builtinCastStringAsTimeSig) evalTime(row chunk.Row) (res types.Time, is
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ParseTime(sc, val, b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTime(sc, val, b.tp.Tp, int8(b.tp.Decimal))
if err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
Expand All @@ -1212,7 +1212,7 @@ func (b *builtinCastStringAsDurationSig) evalDuration(row chunk.Row) (res types.
if isNull || err != nil {
return res, isNull, err
}
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, val, b.tp.Decimal)
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, val, int8(b.tp.Decimal))
if types.ErrTruncatedWrongVal.Equal(err) {
sc := b.ctx.GetSessionVars().StmtCtx
err = sc.HandleTruncate(err)
Expand Down Expand Up @@ -1244,7 +1244,7 @@ func (b *builtinCastTimeAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNu
if res, err = res.Convert(sc, b.tp.Tp); err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
res, err = res.RoundFrac(sc, b.tp.Decimal)
res, err = res.RoundFrac(sc, int8(b.tp.Decimal))
if b.tp.Tp == mysql.TypeDate {
// Truncate hh:mm:ss part if the type is Date.
res.Time = types.FromDate(res.Time.Year(), res.Time.Month(), res.Time.Day(), 0, 0, 0, 0)
Expand Down Expand Up @@ -1358,7 +1358,7 @@ func (b *builtinCastTimeAsDurationSig) evalDuration(row chunk.Row) (res types.Du
if err != nil {
return res, false, err
}
res, err = res.RoundFrac(b.tp.Decimal)
res, err = res.RoundFrac(int8(b.tp.Decimal))
return res, false, err
}

Expand All @@ -1377,7 +1377,7 @@ func (b *builtinCastDurationAsDurationSig) evalDuration(row chunk.Row) (res type
if isNull || err != nil {
return res, isNull, err
}
res, err = res.RoundFrac(b.tp.Decimal)
res, err = res.RoundFrac(int8(b.tp.Decimal))
return res, false, err
}

Expand Down Expand Up @@ -1505,7 +1505,7 @@ func (b *builtinCastDurationAsTimeSig) evalTime(row chunk.Row) (res types.Time,
if err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
res, err = res.RoundFrac(sc, b.tp.Decimal)
res, err = res.RoundFrac(sc, int8(b.tp.Decimal))
return res, false, err
}

Expand Down Expand Up @@ -1625,7 +1625,7 @@ func (b *builtinCastJSONAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNu
return res, false, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ParseTime(sc, s, b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTime(sc, s, b.tp.Tp, int8(b.tp.Decimal))
if err != nil {
return types.Time{}, true, handleInvalidTimeError(b.ctx, err)
}
Expand Down Expand Up @@ -1655,7 +1655,7 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(row chunk.Row) (res types.Du
if err != nil {
return res, false, err
}
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, b.tp.Decimal)
res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, int8(b.tp.Decimal))
if types.ErrTruncatedWrongVal.Equal(err) {
sc := b.ctx.GetSessionVars().StmtCtx
err = sc.HandleTruncate(err)
Expand Down Expand Up @@ -1793,7 +1793,7 @@ func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression {
// into consideration, so we set `expr.GetType().Flen + 2` as the `argLen`.
// Since the length of float and double is not accurate, we do not handle
// them.
if exprTp.Tp == mysql.TypeNewDecimal && argLen != types.UnspecifiedFsp {
if exprTp.Tp == mysql.TypeNewDecimal && argLen != int(types.UnspecifiedFsp) {
argLen += 2
}
if exprTp.EvalType() == types.ETInt {
Expand All @@ -1818,7 +1818,7 @@ func WrapWithCastAsTime(ctx sessionctx.Context, expr Expression, tp *types.Field
case mysql.TypeDatetime, mysql.TypeTimestamp, mysql.TypeDate, mysql.TypeDuration:
tp.Decimal = x.Decimal
default:
tp.Decimal = types.MaxFsp
tp.Decimal = int(types.MaxFsp)
}
switch tp.Tp {
case mysql.TypeDate:
Expand All @@ -1844,7 +1844,7 @@ func WrapWithCastAsDuration(ctx sessionctx.Context, expr Expression) Expression
case mysql.TypeDatetime, mysql.TypeTimestamp, mysql.TypeDate:
tp.Decimal = x.Decimal
default:
tp.Decimal = types.MaxFsp
tp.Decimal = int(types.MaxFsp)
}
tp.Flen = mysql.MaxDurationWidthNoFsp
if tp.Decimal > 0 {
Expand Down
14 changes: 7 additions & 7 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
var sig builtinFunc

durationColumn := &Column{RetType: types.NewFieldType(mysql.TypeDuration), Index: 0}
durationColumn.RetType.Decimal = types.DefaultFsp
durationColumn.RetType.Decimal = int(types.DefaultFsp)
// Test cast as Decimal.
castToDecCases := []struct {
before *Column
Expand Down Expand Up @@ -805,7 +805,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
for i, t := range castToTimeCases {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeDatetime)
tp.Decimal = types.DefaultFsp
tp.Decimal = int(types.DefaultFsp)
timeFunc := newBaseBuiltinFunc(ctx, args)
timeFunc.tp = tp
switch i {
Expand Down Expand Up @@ -834,7 +834,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
before *Column
after types.Time
row chunk.MutRow
fsp int
fsp int8
tp byte
}{
// cast real as Time(0).
Expand Down Expand Up @@ -889,7 +889,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
for i, t := range castToTimeCases2 {
args := []Expression{t.before}
tp := types.NewFieldType(t.tp)
tp.Decimal = t.fsp
tp.Decimal = int(t.fsp)
timeFunc := newBaseBuiltinFunc(ctx, args)
timeFunc.tp = tp
switch i {
Expand All @@ -912,7 +912,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
resAfter := t.after.String()
if t.fsp > 0 {
resAfter += "."
for i := 0; i < t.fsp; i++ {
for i := 0; i < int(t.fsp); i++ {
resAfter += "0"
}
}
Expand Down Expand Up @@ -970,7 +970,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
for i, t := range castToDurationCases {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeDuration)
tp.Decimal = types.DefaultFsp
tp.Decimal = int(types.DefaultFsp)
durationFunc := newBaseBuiltinFunc(ctx, args)
durationFunc.tp = tp
switch i {
Expand Down Expand Up @@ -1144,7 +1144,7 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) {
ctx := s.ctx

durationColumn0 := &Column{RetType: types.NewFieldType(mysql.TypeDuration), Index: 0}
durationColumn0.RetType.Decimal = types.DefaultFsp
durationColumn0.RetType.Decimal = int(types.DefaultFsp)
durationColumn3 := &Column{RetType: types.NewFieldType(mysql.TypeDuration), Index: 0}
durationColumn3.RetType.Decimal = 3
cases := []struct {
Expand Down
Loading

0 comments on commit adb3071

Please sign in to comment.