Skip to content

Commit

Permalink
expression, sessionctx: support rand_seed1 and rand_seed2 sysvar (pin…
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 authored Nov 24, 2021
1 parent 791f59d commit 45836a6
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 24 deletions.
19 changes: 8 additions & 11 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ import (
"math"
"strconv"
"strings"
"sync"

"github.com/cznic/mathutil"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
utilMath "github.com/pingcap/tidb/util/math"
"github.com/pingcap/tipb/go-tipb"
)

Expand Down Expand Up @@ -1023,7 +1023,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
}
bt := bf
if len(args) == 0 {
sig = &builtinRandSig{bt, &sync.Mutex{}, NewWithTime()}
sig = &builtinRandSig{bt, ctx.GetSessionVars().Rng}
sig.setPbCode(tipb.ScalarFuncSig_Rand)
} else if _, isConstant := args[0].(*Constant); isConstant {
// According to MySQL manual:
Expand All @@ -1039,7 +1039,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
// The behavior same as MySQL.
seed = 0
}
sig = &builtinRandSig{bt, &sync.Mutex{}, NewWithSeed(seed)}
sig = &builtinRandSig{bt, utilMath.NewWithSeed(seed)}
sig.setPbCode(tipb.ScalarFuncSig_Rand)
} else {
sig = &builtinRandWithSeedFirstGenSig{bt}
Expand All @@ -1050,22 +1050,19 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio

type builtinRandSig struct {
baseBuiltinFunc
mu *sync.Mutex
mysqlRng *MysqlRng
mysqlRng *utilMath.MysqlRng
}

func (b *builtinRandSig) Clone() builtinFunc {
newSig := &builtinRandSig{mysqlRng: b.mysqlRng, mu: b.mu}
newSig := &builtinRandSig{mysqlRng: b.mysqlRng}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalReal evals RAND().
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand
func (b *builtinRandSig) evalReal(row chunk.Row) (float64, bool, error) {
b.mu.Lock()
res := b.mysqlRng.Gen()
b.mu.Unlock()
return res, false, nil
}

Expand All @@ -1089,11 +1086,11 @@ func (b *builtinRandWithSeedFirstGenSig) evalReal(row chunk.Row) (float64, bool,
// b.args[0] is promised to be a non-constant(such as a column name) in
// builtinRandWithSeedFirstGenSig, the seed is initialized with the value for each
// invocation of RAND().
var rng *MysqlRng
var rng *utilMath.MysqlRng
if !isNull {
rng = NewWithSeed(seed)
rng = utilMath.NewWithSeed(seed)
} else {
rng = NewWithSeed(0)
rng = utilMath.NewWithSeed(0)
}
return rng.Gen(), false, nil
}
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/testkit/trequire"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
utilMath "github.com/pingcap/tidb/util/math"
"github.com/pingcap/tipb/go-tipb"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -382,7 +383,7 @@ func TestRand(t *testing.T) {
// issue 3211
f2, err := fc.getFunction(ctx, []Expression{&Constant{Value: types.NewIntDatum(20160101), RetType: types.NewFieldType(mysql.TypeLonglong)}})
require.NoError(t, err)
randGen := NewWithSeed(20160101)
randGen := utilMath.NewWithSeed(20160101)
for i := 0; i < 3; i++ {
v, err = evalBuiltinFunc(f2, chunk.Row{})
require.NoError(t, err)
Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_math_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"math"
"strconv"

utilMath "github.com/pingcap/tidb/util/math"

"github.com/cznic/mathutil"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -709,11 +711,9 @@ func (b *builtinRandSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) e
n := input.NumRows()
result.ResizeFloat64(n, false)
f64s := result.Float64s()
b.mu.Lock()
for i := range f64s {
f64s[i] = b.mysqlRng.Gen()
}
b.mu.Unlock()
return nil
}

Expand All @@ -738,9 +738,9 @@ func (b *builtinRandWithSeedFirstGenSig) vecEvalReal(input *chunk.Chunk, result
for i := 0; i < n; i++ {
// When the seed is null we need to use 0 as the seed.
// The behavior same as MySQL.
rng := NewWithSeed(0)
rng := utilMath.NewWithSeed(0)
if !buf.IsNull(i) {
rng = NewWithSeed(i64s[i])
rng = utilMath.NewWithSeed(i64s[i])
}
f64s[i] = rng.Gen()
}
Expand Down
4 changes: 4 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,10 @@ func (s *testIntegrationSuite2) TestMathBuiltin(c *C) {
tk.MustQuery("select rand(1) from t").Sort().Check(testkit.Rows("0.1418603212962489", "0.40540353712197724", "0.8716141803857071"))
tk.MustQuery("select rand(a) from t").Check(testkit.Rows("0.40540353712197724", "0.6555866465490187", "0.9057697559760601"))
tk.MustQuery("select rand(1), rand(2), rand(3)").Check(testkit.Rows("0.40540353712197724 0.6555866465490187 0.9057697559760601"))
tk.MustQuery("set @@rand_seed1=10000000,@@rand_seed2=1000000")
tk.MustQuery("select rand()").Check(testkit.Rows("0.028870999839968048"))
tk.MustQuery("select rand(1)").Check(testkit.Rows("0.40540353712197724"))
tk.MustQuery("select rand()").Check(testkit.Rows("0.11641535266900002"))
}

func (s *testIntegrationSuite2) TestStringBuiltin(c *C) {
Expand Down
2 changes: 0 additions & 2 deletions sessionctx/variable/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ var noopSysVars = []*SysVar{
{Scope: ScopeGlobal | ScopeSession, Name: BigTables, Value: Off, Type: TypeBool},
{Scope: ScopeNone, Name: "skip_external_locking", Value: "1"},
{Scope: ScopeNone, Name: "innodb_sync_array_size", Value: "1"},
{Scope: ScopeSession, Name: "rand_seed2", Value: ""},
{Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: Off, Type: TypeBool},
{Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64},
{Scope: ScopeSession, Name: "gtid_next", Value: ""},
Expand Down Expand Up @@ -275,7 +274,6 @@ var noopSysVars = []*SysVar{
{Scope: ScopeNone, Name: "binlog_gtid_simple_recovery", Value: "1"},
{Scope: ScopeNone, Name: "performance_schema_digests_size", Value: "10000"},
{Scope: ScopeGlobal | ScopeSession, Name: Profiling, Value: Off, Type: TypeBool},
{Scope: ScopeSession, Name: "rand_seed1", Value: ""},
{Scope: ScopeGlobal, Name: "sha256_password_proxy_users", Value: ""},
{Scope: ScopeGlobal | ScopeSession, Name: SQLQuoteShowCreate, Value: On, Type: TypeBool},
{Scope: ScopeGlobal | ScopeSession, Name: "binlogging_impossible_mode", Value: "IGNORE_ERROR"},
Expand Down
6 changes: 6 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"sync/atomic"
"time"

utilMath "github.com/pingcap/tidb/util/math"

"github.com/pingcap/errors"
pumpcli "github.com/pingcap/tidb-tools/tidb-binlog/pump_client"
"github.com/pingcap/tidb/config"
Expand Down Expand Up @@ -955,6 +957,9 @@ type SessionVars struct {
curr int8
data [2]stmtctx.StatementContext
}

// Rng stores the rand_seed1 and rand_seed2 for Rand() function
Rng *utilMath.MysqlRng
}

// InitStatementContext initializes a StatementContext, the object is reused to reduce allocation.
Expand Down Expand Up @@ -1188,6 +1193,7 @@ func NewSessionVars() *SessionVars {
MPPStoreLastFailTime: make(map[string]time.Time),
MPPStoreFailTTL: DefTiDBMPPStoreFailTTL,
EnablePlacementChecks: DefEnablePlacementCheck,
Rng: utilMath.NewWithTime(),
}
vars.KVVars = tikvstore.NewVariables(&vars.Killed)
vars.Concurrency = Concurrency{
Expand Down
16 changes: 16 additions & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -1861,6 +1861,18 @@ var defaultSysVars = []*SysVar{
{Scope: ScopeNone, Name: "version_compile_os", Value: runtime.GOOS},
{Scope: ScopeNone, Name: "version_compile_machine", Value: runtime.GOARCH},
{Scope: ScopeNone, Name: TiDBAllowFunctionForExpressionIndex, ReadOnly: true, Value: collectAllowFuncName4ExpressionIndex()},
{Scope: ScopeSession, Name: RandSeed1, Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error {
s.Rng.SetSeed1(uint32(tidbOptPositiveInt32(val, 0)))
return nil
}, GetSession: func(s *SessionVars) (string, error) {
return "0", nil
}},
{Scope: ScopeSession, Name: RandSeed2, Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error {
s.Rng.SetSeed2(uint32(tidbOptPositiveInt32(val, 0)))
return nil
}, GetSession: func(s *SessionVars) (string, error) {
return "0", nil
}},
}

func collectAllowFuncName4ExpressionIndex() string {
Expand Down Expand Up @@ -2183,6 +2195,10 @@ const (
Identity = "identity"
// TiDBAllowFunctionForExpressionIndex is the name of `TiDBAllowFunctionForExpressionIndex` system variable.
TiDBAllowFunctionForExpressionIndex = "tidb_allow_function_for_expression_index"
// RandSeed1 is the name of 'rand_seed1' system variable.
RandSeed1 = "rand_seed1"
// RandSeed2 is the name of 'rand_seed2' system variable.
RandSeed2 = "rand_seed2"
)

// GlobalVarAccessor is the interface for accessing global scope system and status variables.
Expand Down
30 changes: 27 additions & 3 deletions expression/rand.go → util/math/rand.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package expression
package math

import "time"
import (
"sync"
"time"
)

const maxRandValue = 0x3FFFFFFF

Expand All @@ -23,13 +26,18 @@ const maxRandValue = 0x3FFFFFFF
type MysqlRng struct {
seed1 uint32
seed2 uint32
mu *sync.Mutex
}

// NewWithSeed create a rng with random seed.
func NewWithSeed(seed int64) *MysqlRng {
seed1 := uint32(seed*0x10001+55555555) % maxRandValue
seed2 := uint32(seed*0x10000001) % maxRandValue
return &MysqlRng{seed1: seed1, seed2: seed2}
return &MysqlRng{
seed1: seed1,
seed2: seed2,
mu: &sync.Mutex{},
}
}

// NewWithTime create a rng with time stamp.
Expand All @@ -39,7 +47,23 @@ func NewWithTime() *MysqlRng {

// Gen will generate random number.
func (rng *MysqlRng) Gen() float64 {
rng.mu.Lock()
defer rng.mu.Unlock()
rng.seed1 = (rng.seed1*3 + rng.seed2) % maxRandValue
rng.seed2 = (rng.seed1 + rng.seed2 + 33) % maxRandValue
return float64(rng.seed1) / float64(maxRandValue)
}

// SetSeed1 is a interface to set seed1
func (rng *MysqlRng) SetSeed1(seed uint32) {
rng.mu.Lock()
defer rng.mu.Unlock()
rng.seed1 = seed
}

// SetSeed2 is a interface to set seed2
func (rng *MysqlRng) SetSeed2(seed uint32) {
rng.mu.Lock()
defer rng.mu.Unlock()
rng.seed2 = seed
}
21 changes: 18 additions & 3 deletions expression/rand_test.go → util/math/rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package expression
package math

import (
"testing"
Expand Down Expand Up @@ -55,8 +55,23 @@ func TestRandWithSeed(t *testing.T) {
for _, test := range tests {
rng := NewWithSeed(test.seed)
got1 := rng.Gen()
require.True(t, got1 == test.once)
require.Equal(t, got1, test.once)
got2 := rng.Gen()
require.True(t, got2 == test.twice)
require.Equal(t, got2, test.twice)
}
}

func TestRandWithSeed1AndSeed2(t *testing.T) {
t.Parallel()

seed1 := uint32(10000000)
seed2 := uint32(1000000)

rng := NewWithTime()
rng.SetSeed1(seed1)
rng.SetSeed2(seed2)

require.Equal(t, rng.Gen(), 0.028870999839968048)
require.Equal(t, rng.Gen(), 0.11641535266900002)
require.Equal(t, rng.Gen(), 0.49546379455874096)
}

0 comments on commit 45836a6

Please sign in to comment.