diff --git a/executor/executor_test.go b/executor/executor_test.go index c4efffb8479d8..2c40b47382440 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -3731,10 +3731,13 @@ func (s *testSuite3) TestTSOFail(c *C) { tk.MustExec(`drop table if exists t`) tk.MustExec(`create table t(a int)`) - ctx := context.Background() - ctx = context.WithValue(ctx, "mockGetTSFail", struct{}{}) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/session/mockGetTSFail", "return"), IsNil) + ctx := failpoint.WithHook(context.Background(), func(ctx context.Context, fpname string) bool { + return fpname == "github.com/pingcap/tidb/session/mockGetTSFail" + }) _, err := tk.Se.Execute(ctx, `select * from t`) c.Assert(err, NotNil) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/session/mockGetTSFail"), IsNil) } func (s *testSuite3) TestSelectHashPartitionTable(c *C) { diff --git a/session/session_fail_test.go b/session/session_fail_test.go index 44fa7be1173a0..65f1b9ff45365 100644 --- a/session/session_fail_test.go +++ b/session/session_fail_test.go @@ -82,14 +82,18 @@ func (s *testSessionSuite) TestGetTSFailDirtyState(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("create table t (id int)") - ctx := context.Background() - ctx = context.WithValue(ctx, "mockGetTSFail", struct{}{}) - tk.Se.Execute(ctx, "select * from t") + c.Assert(failpoint.Enable("github.com/pingcap/tidb/session/mockGetTSFail", "return"), IsNil) + ctx := failpoint.WithHook(context.Background(), func(ctx context.Context, fpname string) bool { + return fpname == "github.com/pingcap/tidb/session/mockGetTSFail" + }) + _, err := tk.Se.Execute(ctx, "select * from t") + c.Assert(err, NotNil) // Fix a bug that active txn fail set TxnState.fail to error, and then the following write // affected by this fail flag. tk.MustExec("insert into t values (1)") tk.MustQuery(`select * from t`).Check(testkit.Rows("1")) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/session/mockGetTSFail"), IsNil) } func (s *testSessionSuite) TestGetTSFailDirtyStateInretry(c *C) { diff --git a/session/txn.go b/session/txn.go index b4eedb521079b..eb40cb902c7c4 100755 --- a/session/txn.go +++ b/session/txn.go @@ -372,22 +372,24 @@ func mergeToDirtyDB(dirtyDB *executor.DirtyDB, op dirtyTableOperation) { } } +type txnFailFuture struct{} + +func (txnFailFuture) Wait() (uint64, error) { + return 0, errors.New("mock get timestamp fail") +} + // txnFuture is a promise, which promises to return a txn in future. type txnFuture struct { future oracle.Future store kv.Storage - - mockFail bool } func (tf *txnFuture) wait() (kv.Transaction, error) { - if tf.mockFail { - return nil, errors.New("mock get timestamp fail") - } - startTS, err := tf.future.Wait() if err == nil { return tf.store.BeginWithStartTS(startTS) + } else if _, ok := tf.future.(txnFailFuture); ok { + return nil, err } // It would retry get timestamp. @@ -409,9 +411,9 @@ func (s *session) getTxnFuture(ctx context.Context) *txnFuture { tsFuture = oracleStore.GetTimestampAsync(ctx) } ret := &txnFuture{future: tsFuture, store: s.store} - if x := ctx.Value("mockGetTSFail"); x != nil { - ret.mockFail = true - } + failpoint.InjectContext(ctx, "mockGetTSFail", func() { + ret.future = txnFailFuture{} + }) return ret }