Skip to content

Commit

Permalink
expression: separated arithmeticModIntSig (pingcap#22137)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tjianke authored Mar 29, 2021
1 parent 647ba7f commit 32a9c3d
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 54 deletions.
148 changes: 122 additions & 26 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ var (
_ builtinFunc = &builtinArithmeticMultiplyIntSig{}
_ builtinFunc = &builtinArithmeticIntDivideIntSig{}
_ builtinFunc = &builtinArithmeticIntDivideDecimalSig{}
_ builtinFunc = &builtinArithmeticModIntSig{}
_ builtinFunc = &builtinArithmeticModIntUnsignedUnsignedSig{}
_ builtinFunc = &builtinArithmeticModIntUnsignedSignedSig{}
_ builtinFunc = &builtinArithmeticModIntSignedUnsignedSig{}
_ builtinFunc = &builtinArithmeticModIntSignedSignedSig{}
_ builtinFunc = &builtinArithmeticModRealSig{}
_ builtinFunc = &builtinArithmeticModDecimalSig{}
)
Expand Down Expand Up @@ -931,9 +934,27 @@ func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args []
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModInt)
return sig, nil
isLHSUnsigned := mysql.HasUnsignedFlag(args[0].GetType().Flag)
isRHSUnsigned := mysql.HasUnsignedFlag(args[1].GetType().Flag)

switch {
case isLHSUnsigned && isRHSUnsigned:
sig := &builtinArithmeticModIntUnsignedUnsignedSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModIntUnsignedUnsigned)
return sig, nil
case isLHSUnsigned && !isRHSUnsigned:
sig := &builtinArithmeticModIntUnsignedSignedSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModIntUnsignedSigned)
return sig, nil
case !isLHSUnsigned && isRHSUnsigned:
sig := &builtinArithmeticModIntSignedUnsignedSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModIntSignedUnsigned)
return sig, nil
default:
sig := &builtinArithmeticModIntSignedSignedSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModIntSignedSigned)
return sig, nil
}
}
}

Expand Down Expand Up @@ -992,17 +1013,17 @@ func (s *builtinArithmeticModDecimalSig) evalDecimal(row chunk.Row) (*types.MyDe
return c, err != nil, err
}

type builtinArithmeticModIntSig struct {
type builtinArithmeticModIntUnsignedUnsignedSig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticModIntSig) Clone() builtinFunc {
newSig := &builtinArithmeticModIntSig{}
func (s *builtinArithmeticModIntUnsignedUnsignedSig) Clone() builtinFunc {
newSig := &builtinArithmeticModIntUnsignedUnsignedSig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticModIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
func (s *builtinArithmeticModIntUnsignedUnsignedSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
Expand All @@ -1017,28 +1038,103 @@ func (s *builtinArithmeticModIntSig) evalInt(row chunk.Row) (val int64, isNull b
return 0, isNull, err
}

ret := int64(uint64(a) % uint64(b))

return ret, false, nil
}

type builtinArithmeticModIntUnsignedSignedSig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticModIntUnsignedSignedSig) Clone() builtinFunc {
newSig := &builtinArithmeticModIntUnsignedSignedSig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticModIntUnsignedSignedSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}
if b == 0 {
return 0, true, handleDivisionByZeroError(s.ctx)
}
a, isNull, err := s.args[0].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

var ret int64
isLHSUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
isRHSUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)
if b < 0 {
ret = int64(uint64(a) % uint64(-b))
} else {
ret = int64(uint64(a) % uint64(b))
}

switch {
case isLHSUnsigned && isRHSUnsigned:
return ret, false, nil
}

type builtinArithmeticModIntSignedUnsignedSig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticModIntSignedUnsignedSig) Clone() builtinFunc {
newSig := &builtinArithmeticModIntSignedUnsignedSig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticModIntSignedUnsignedSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(s.ctx)
}

a, isNull, err := s.args[0].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

var ret int64
if a < 0 {
ret = -int64(uint64(-a) % uint64(b))
} else {
ret = int64(uint64(a) % uint64(b))
case isLHSUnsigned && !isRHSUnsigned:
if b < 0 {
ret = int64(uint64(a) % uint64(-b))
} else {
ret = int64(uint64(a) % uint64(b))
}
case !isLHSUnsigned && isRHSUnsigned:
if a < 0 {
ret = -int64(uint64(-a) % uint64(b))
} else {
ret = int64(uint64(a) % uint64(b))
}
case !isLHSUnsigned && !isRHSUnsigned:
ret = a % b
}

return ret, false, nil
}

type builtinArithmeticModIntSignedSignedSig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticModIntSignedSignedSig) Clone() builtinFunc {
newSig := &builtinArithmeticModIntSignedSignedSig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticModIntSignedSignedSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) {
b, isNull, err := s.args[1].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

if b == 0 {
return 0, true, handleDivisionByZeroError(s.ctx)
}

a, isNull, err := s.args[0].EvalInt(s.ctx, row)
if isNull || err != nil {
return 0, isNull, err
}

return a % b, false, nil
}
39 changes: 37 additions & 2 deletions expression/builtin_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package expression

import (
"math"
"time"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -500,6 +501,34 @@ func (s *testEvaluatorSuite) TestArithmeticMod(c *C) {
args: []interface{}{int64(13), int64(11)},
expect: int64(2),
},
{
args: []interface{}{int64(13), int64(11)},
expect: int64(2),
},
{
args: []interface{}{int64(13), int64(0)},
expect: nil,
},
{
args: []interface{}{uint64(13), int64(0)},
expect: nil,
},
{
args: []interface{}{int64(13), uint64(0)},
expect: nil,
},
{
args: []interface{}{uint64(math.MaxInt64 + 1), int64(math.MinInt64)},
expect: int64(0),
},
{
args: []interface{}{int64(-22), uint64(10)},
expect: int64(-2),
},
{
args: []interface{}{int64(math.MinInt64), uint64(3)},
expect: int64(-2),
},
{
args: []interface{}{int64(-13), int64(11)},
expect: int64(-2),
Expand Down Expand Up @@ -598,8 +627,14 @@ func (s *testEvaluatorSuite) TestArithmeticMod(c *C) {
switch sig.(type) {
case *builtinArithmeticModRealSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModReal)
case *builtinArithmeticModIntSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModInt)
case *builtinArithmeticModIntUnsignedUnsignedSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModIntUnsignedUnsigned)
case *builtinArithmeticModIntUnsignedSignedSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModIntUnsignedSigned)
case *builtinArithmeticModIntSignedUnsignedSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModIntSignedUnsigned)
case *builtinArithmeticModIntSignedSignedSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModIntSignedSigned)
case *builtinArithmeticModDecimalSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModDecimal)
}
Expand Down
91 changes: 69 additions & 22 deletions expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,17 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r
return nil
}

func (b *builtinArithmeticModIntSig) vectorized() bool {
func (b *builtinArithmeticModIntUnsignedUnsignedSig) vectorized() bool {
return true
}

func (b *builtinArithmeticModIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
func (b *builtinArithmeticModIntUnsignedUnsignedSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
lh, err := b.bufAllocator.get(types.ETInt, n)
if err != nil {
return err
}
defer b.bufAllocator.put(lh)

if err := b.args[0].VecEvalInt(b.ctx, input, lh); err != nil {
return err
}
Expand All @@ -130,23 +129,8 @@ func (b *builtinArithmeticModIntSig) vecEvalInt(input *chunk.Chunk, result *chun
return err
}

isLHSUnsigned := mysql.HasUnsignedFlag(b.args[0].GetType().Flag)
isRHSUnsigned := mysql.HasUnsignedFlag(b.args[1].GetType().Flag)

rh := result
switch {
case isLHSUnsigned && isRHSUnsigned:
err = b.modUU(lh, rh)
case isLHSUnsigned && !isRHSUnsigned:
err = b.modUS(lh, rh)
case !isLHSUnsigned && isRHSUnsigned:
err = b.modSU(lh, rh)
case !isLHSUnsigned && !isRHSUnsigned:
err = b.modSS(lh, rh)
}
return err
}
func (b *builtinArithmeticModIntSig) modUU(lh, rh *chunk.Column) error {

lhi64s := lh.Int64s()
rhi64s := rh.Int64s()

Expand All @@ -170,7 +154,28 @@ func (b *builtinArithmeticModIntSig) modUU(lh, rh *chunk.Column) error {
}
return nil
}
func (b *builtinArithmeticModIntSig) modUS(lh, rh *chunk.Column) error {

func (b *builtinArithmeticModIntUnsignedSignedSig) vectorized() bool {
return true
}

func (b *builtinArithmeticModIntUnsignedSignedSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
lh, err := b.bufAllocator.get(types.ETInt, n)
if err != nil {
return err
}
defer b.bufAllocator.put(lh)

if err := b.args[0].VecEvalInt(b.ctx, input, lh); err != nil {
return err
}
// reuse result as rh to avoid buf allocate
if err := b.args[1].VecEvalInt(b.ctx, input, result); err != nil {
return err
}
rh := result

lhi64s := lh.Int64s()
rhi64s := rh.Int64s()

Expand Down Expand Up @@ -198,7 +203,28 @@ func (b *builtinArithmeticModIntSig) modUS(lh, rh *chunk.Column) error {
}
return nil
}
func (b *builtinArithmeticModIntSig) modSU(lh, rh *chunk.Column) error {

func (b *builtinArithmeticModIntSignedUnsignedSig) vectorized() bool {
return true
}

func (b *builtinArithmeticModIntSignedUnsignedSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
lh, err := b.bufAllocator.get(types.ETInt, n)
if err != nil {
return err
}
defer b.bufAllocator.put(lh)

if err := b.args[0].VecEvalInt(b.ctx, input, lh); err != nil {
return err
}
// reuse result as rh to avoid buf allocate
if err := b.args[1].VecEvalInt(b.ctx, input, result); err != nil {
return err
}
rh := result

lhi64s := lh.Int64s()
rhi64s := rh.Int64s()

Expand Down Expand Up @@ -226,7 +252,28 @@ func (b *builtinArithmeticModIntSig) modSU(lh, rh *chunk.Column) error {
}
return nil
}
func (b *builtinArithmeticModIntSig) modSS(lh, rh *chunk.Column) error {

func (b *builtinArithmeticModIntSignedSignedSig) vectorized() bool {
return true
}

func (b *builtinArithmeticModIntSignedSignedSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
lh, err := b.bufAllocator.get(types.ETInt, n)
if err != nil {
return err
}
defer b.bufAllocator.put(lh)

if err := b.args[0].VecEvalInt(b.ctx, input, lh); err != nil {
return err
}
// reuse result as rh to avoid buf allocate
if err := b.args[1].VecEvalInt(b.ctx, input, result); err != nil {
return err
}
rh := result

lhi64s := lh.Int64s()
rhi64s := rh.Int64s()

Expand Down
10 changes: 8 additions & 2 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,14 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
f = &builtinArithmeticModRealSig{base}
case tipb.ScalarFuncSig_ModDecimal:
f = &builtinArithmeticModDecimalSig{base}
case tipb.ScalarFuncSig_ModInt:
f = &builtinArithmeticModIntSig{base}
case tipb.ScalarFuncSig_ModIntUnsignedUnsigned:
f = &builtinArithmeticModIntUnsignedUnsignedSig{base}
case tipb.ScalarFuncSig_ModIntUnsignedSigned:
f = &builtinArithmeticModIntUnsignedSignedSig{base}
case tipb.ScalarFuncSig_ModIntSignedUnsigned:
f = &builtinArithmeticModIntSignedUnsignedSig{base}
case tipb.ScalarFuncSig_ModIntSignedSigned:
f = &builtinArithmeticModIntSignedSignedSig{base}
case tipb.ScalarFuncSig_MultiplyIntUnsigned:
f = &builtinArithmeticMultiplyIntUnsignedSig{base}
case tipb.ScalarFuncSig_AbsInt:
Expand Down
Loading

0 comments on commit 32a9c3d

Please sign in to comment.