Skip to content

Commit

Permalink
Fix misue of tae txn context for a single batch (matrixorigin#2775)
Browse files Browse the repository at this point in the history
* fix loading batch

* one txn for new relation

* fix txn

* fix ut
  • Loading branch information
daviszhen authored Jun 3, 2022
1 parent c8effab commit 8a4bc1a
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 18 deletions.
70 changes: 63 additions & 7 deletions pkg/frontend/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@ type SharePart struct {
simdCsvLineArray [][]string

//storage
storage engine.Engine
dbHandler engine.Database
tableHandler engine.Relation
dbName string
tableName string
txnHandler *TxnHandler
oneTxnPerBatch bool

Expand Down Expand Up @@ -502,8 +505,11 @@ func initWriteBatchHandler(handler *ParseLineHandler, wHandler *WriteBatchHandle
wHandler.dataColumnId2TableColumnId = handler.dataColumnId2TableColumnId
wHandler.batchSize = handler.batchSize
wHandler.attrName = handler.attrName
wHandler.storage = handler.storage
wHandler.dbName = handler.dbName
wHandler.dbHandler = handler.dbHandler
wHandler.tableHandler = handler.tableHandler
wHandler.tableName = handler.tableName
wHandler.txnHandler = handler.txnHandler
wHandler.oneTxnPerBatch = handler.oneTxnPerBatch
wHandler.timestamp = handler.timestamp
Expand Down Expand Up @@ -1335,23 +1341,35 @@ func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
wait_a := time.Now()
handler.ThreadInfo.SetTime(wait_a)
handler.ThreadInfo.SetCnt(1)
dbHandler := handler.dbHandler
txnHandler := handler.txnHandler
tableHandler := handler.tableHandler
if !handler.skipWriteBatch {
if handler.oneTxnPerBatch {
txnHandler = InitTxnHandler(config.StorageEngine)
_, err = txnHandler.StartByAutocommitIfNeeded()
if err != nil {
return err
goto handleError
}
dbHandler, err = handler.storage.Database(handler.dbName, txnHandler.GetTxn().GetCtx())
if err != nil {
goto handleError
}
tableHandler, err = dbHandler.Relation(handler.tableName, txnHandler.GetTxn().GetCtx())
if err != nil {
goto handleError
}
}
err = handler.tableHandler.Write(handler.timestamp, handler.batchData, txnHandler.GetTxn().GetCtx())
err = tableHandler.Write(handler.timestamp, handler.batchData, txnHandler.GetTxn().GetCtx())
if handler.oneTxnPerBatch {
err = txnHandler.CommitAfterAutocommitOnly()
if err != nil {
return err
goto handleError
}
}
}

handleError:
handler.ThreadInfo.SetCnt(0)
if err == nil {
handler.result.Records += uint64(handler.batchSize)
Expand All @@ -1365,6 +1383,13 @@ func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
handler.result.Skipped += uint64(handler.batchSize)
}

if handler.oneTxnPerBatch && err != nil {
err2 := txnHandler.RollbackAfterAutocommitOnly()
if err2 != nil {
logutil.Errorf("rollback failed.error:%v", err2)
}
}

handler.writeBatch += time.Since(wait_a)

wait_b := time.Now()
Expand Down Expand Up @@ -1450,22 +1475,34 @@ func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
handler.ThreadInfo.SetTime(wait_a)
handler.ThreadInfo.SetCnt(1)
txnHandler := handler.txnHandler
tableHandler := handler.tableHandler
dbHandler := handler.dbHandler
if !handler.skipWriteBatch {
if handler.oneTxnPerBatch {
txnHandler = InitTxnHandler(config.StorageEngine)
_, err = txnHandler.StartByAutocommitIfNeeded()
if err != nil {
return err
goto handleError2
}
dbHandler, err = handler.storage.Database(handler.dbName, txnHandler.GetTxn().GetCtx())
if err != nil {
goto handleError2
}
//new relation
tableHandler, err = dbHandler.Relation(handler.tableName, txnHandler.GetTxn().GetCtx())
if err != nil {
goto handleError2
}
}
err = handler.tableHandler.Write(handler.timestamp, handler.batchData, txnHandler.GetTxn().GetCtx())
err = tableHandler.Write(handler.timestamp, handler.batchData, txnHandler.GetTxn().GetCtx())
if handler.oneTxnPerBatch {
err = txnHandler.CommitAfterAutocommitOnly()
if err != nil {
return err
goto handleError2
}
}
}
handleError2:
handler.ThreadInfo.SetCnt(0)
if err == nil {
handler.result.Records += uint64(needLen)
Expand All @@ -1478,6 +1515,13 @@ func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
logutil.Errorf("write failed. err:%v \n", err)
handler.result.Skipped += uint64(needLen)
}

if handler.oneTxnPerBatch && err != nil {
err2 := txnHandler.RollbackAfterAutocommitOnly()
if err2 != nil {
logutil.Errorf("rollback failed.error:%v", err2)
}
}
}
}
}
Expand Down Expand Up @@ -1547,7 +1591,7 @@ func PrintThreadInfo(handler *ParseLineHandler, close *CloseFlag, a time.Duratio
/*
LoadLoop reads data from stream, extracts the fields, and saves into the table
*/
func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database, tableHandler engine.Relation) (*LoadResult, error) {
func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database, tableHandler engine.Relation, dbName string) (*LoadResult, error) {
ses := mce.GetSession()

var m sync.Mutex
Expand Down Expand Up @@ -1584,8 +1628,11 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
load: load,
lineIdx: 0,
simdCsvLineArray: make([][]string, curBatchSize),
storage: ses.Pu.StorageEngine,
dbHandler: dbHandler,
tableHandler: tableHandler,
tableName: string(load.Table.Name()),
dbName: dbName,
txnHandler: ses.GetTxnHandler(),
oneTxnPerBatch: ses.Pu.SV.GetOneTxnPerBatchDuringLoad(),
lineCount: 0,
Expand Down Expand Up @@ -1650,6 +1697,15 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
return nil, err
}

//TODO: remove it after tae is ready
if handler.oneTxnPerBatch {
txnHandler := ses.GetTxnHandler()
err = txnHandler.CommitAfterAutocommitOnly()
if err != nil {
return nil, err
}
}

wg := sync.WaitGroup{}

/*
Expand Down
9 changes: 8 additions & 1 deletion pkg/frontend/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,18 @@ func Test_load(t *testing.T) {
if i == 3 {
row2col = gostub.Stub(&row2colChoose, false)
}
_, err := mce.LoadLoop(cws[i], db, rel)
_, err = ses.txnHandler.StartByAutocommitIfNeeded()
convey.So(err, convey.ShouldBeNil)

_, err := mce.LoadLoop(cws[i], db, rel, "T")
if kases[i].fail {
convey.So(err, convey.ShouldBeError)
//err = ses.txnHandler.RollbackAfterAutocommitOnly()
//convey.So(err, convey.ShouldBeNil)
} else {
convey.So(err, convey.ShouldBeNil)
//err = ses.txnHandler.CommitAfterAutocommitOnly()
//convey.So(err, convey.ShouldBeNil)
}

if i == 3 {
Expand Down
13 changes: 9 additions & 4 deletions pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ func (mce *MysqlCmdExecutor) handleLoadData(load *tree.Load) error {
/*
execute load data
*/
result, err := mce.LoadLoop(load, dbHandler, tableHandler)
result, err := mce.LoadLoop(load, dbHandler, tableHandler, loadDb)
if err != nil {
return err
}
Expand Down Expand Up @@ -1620,6 +1620,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
var ret interface{}
var runner ComputationRunner
var selfHandle = false
var fromLoadData = false
var txnErr error

for _, cw := range cws {
Expand Down Expand Up @@ -1651,6 +1652,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
if err != nil {
goto handleFailed
}
logutil.Infof("start autocommit txn in default")
}

switch st := stmt.(type) {
Expand Down Expand Up @@ -1745,6 +1747,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
if err != nil {
goto handleFailed
}
fromLoadData = true
case *tree.SetVar:
selfHandle = true
err = mce.handleSetVar(st)
Expand Down Expand Up @@ -1964,9 +1967,11 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
}
}
handleSucceeded:
txnErr = txnHandler.CommitAfterAutocommitOnly()
if txnErr != nil {
return txnErr
if !fromLoadData {
txnErr = txnHandler.CommitAfterAutocommitOnly()
if txnErr != nil {
return txnErr
}
}
goto handleNext
handleFailed:
Expand Down
16 changes: 10 additions & 6 deletions pkg/frontend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func (ts *TxnState) isState(s int) bool {
}

func (ts *TxnState) switchToState(s int, err error) {
logutil.Infof("switch from %d to %d", ts.state, s)
ts.fromState = ts.state
ts.state = s
ts.err = err
Expand Down Expand Up @@ -96,7 +97,7 @@ var _ moengine.Txn = &TaeTxnDumpImpl{}
type TaeTxnDumpImpl struct {
}

func InitTaeTxnImpl() *TaeTxnDumpImpl {
func InitTaeTxnDumpImpl() *TaeTxnDumpImpl {
return &TaeTxnDumpImpl{}
}

Expand Down Expand Up @@ -136,7 +137,7 @@ type TxnHandler struct {

func InitTxnHandler(storage engine.Engine) *TxnHandler {
return &TxnHandler{
taeTxn: InitTaeTxnImpl(),
taeTxn: InitTaeTxnDumpImpl(),
txnState: InitTxnState(),
storage: storage,
}
Expand Down Expand Up @@ -345,6 +346,7 @@ func (th *TxnHandler) getTxnStateString() string {
// IsInTaeTxn checks the session executes a txn
func (th *TxnHandler) IsInTaeTxn() bool {
st := th.getTxnState()
logutil.Infof("current txn state %d", st)
if st == TxnAutocommit || st == TxnBegan {
return true
}
Expand Down Expand Up @@ -372,11 +374,12 @@ func (th *TxnHandler) createTxn(beganErr, autocommitErr error) (moengine.Txn, er
err = errorTaeTxnInIllegalState
}
if txn == nil {
txn = InitTaeTxnImpl()
txn = InitTaeTxnDumpImpl()
}
} else {
txn = InitTaeTxnImpl()
txn = InitTaeTxnDumpImpl()
}

return txn, err
}

Expand Down Expand Up @@ -412,6 +415,7 @@ func (th *TxnHandler) StartByAutocommitIfNeeded() (bool, error) {
if th.IsInTaeTxn() {
return false, nil
}
logutil.Infof("need create new txn")
err = th.StartByAutocommit()
return true, err
}
Expand Down Expand Up @@ -549,11 +553,11 @@ func (th *TxnHandler) CleanTxn() error {
logutil.Infof("clean tae txn")
switch th.txnState.getState() {
case TxnInit, TxnEnd:
th.taeTxn = InitTaeTxnImpl()
th.taeTxn = InitTaeTxnDumpImpl()
th.txnState.switchToState(TxnInit, nil)
case TxnErr:
logutil.Errorf("clean txn. Get error:%v txnError:%v", th.txnState.getError(), th.taeTxn.GetError())
th.taeTxn = InitTaeTxnImpl()
th.taeTxn = InitTaeTxnDumpImpl()
th.txnState.switchToState(TxnInit, nil)
}
return nil
Expand Down

0 comments on commit 8a4bc1a

Please sign in to comment.