Skip to content

Commit

Permalink
Merged tx_execute with runTransaction (cadence-workflow#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfateev authored Oct 20, 2018
1 parent 81c0089 commit f2ebf3c
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 39 deletions.
31 changes: 1 addition & 30 deletions common/persistence/sql/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/uber-common/bark"
workflow "github.com/uber/cadence/.gen/go/shared"
"github.com/uber/cadence/common/persistence"
p "github.com/uber/cadence/common/persistence"
)

// TODO: Rename all SQL Managers to Stores
Expand Down Expand Up @@ -59,6 +58,7 @@ func (m *sqlStore) txExecute(operation string, f func(tx *sqlx.Tx) error) error
*persistence.CurrentWorkflowConditionFailedError,
*workflow.InternalServiceError,
*persistence.WorkflowExecutionAlreadyStartedError,
*workflow.DomainAlreadyExistsError,
*persistence.ShardOwnershipLostError:
return err
default:
Expand Down Expand Up @@ -141,35 +141,6 @@ func dereferenceIfNotNil(a *[]byte) []byte {
return nil
}

func runTransaction(name string, db *sqlx.DB, txFunc func(tx *sqlx.Tx) error) error {
convertErr := func(err error) error {
switch err.(type) {
case *workflow.InternalServiceError, *workflow.DomainAlreadyExistsError:
return err
case *p.ShardOwnershipLostError, *p.ConditionFailedError:
return err
default:
return &workflow.InternalServiceError{
Message: fmt.Sprintf("%v: %v", name, err),
}
}
}
tx, err := db.Beginx()
if err != nil {
return &workflow.InternalServiceError{
Message: fmt.Sprintf("%v: failed to begin transaction: %v", name, err),
}
}
if err := txFunc(tx); err != nil {
tx.Rollback()
return convertErr(err)
}
if err := tx.Commit(); err != nil {
return convertErr(err)
}
return nil
}

func serializePageToken(offset int64) []byte {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, uint64(offset))
Expand Down
2 changes: 1 addition & 1 deletion common/persistence/sql/sqlHistoryManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (m *sqlHistoryManager) Close() {
}

func (m *sqlHistoryManager) overWriteHistoryEvents(request *p.InternalAppendHistoryEventsRequest, row *eventsRow) error {
return runTransaction("AppendHistoryEvents", m.db, func(tx *sqlx.Tx) error {
return m.txExecute("AppendHistoryEvents", func(tx *sqlx.Tx) error {
if err := lockEventForUpdate(tx, request); err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions common/persistence/sql/sqlMetadataManagerV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (m *sqlMetadataManagerV2) CreateDomain(request *persistence.CreateDomainReq
}

var resp *persistence.CreateDomainResponse
err = runTransaction("CreateDomain", m.db, func(tx *sqlx.Tx) error {
err = m.txExecute("CreateDomain", func(tx *sqlx.Tx) error {
if _, err1 := tx.NamedExec(createDomainSQLQuery, &domainRow{
domainCommon: domainCommon{
Name: request.Info.Name,
Expand Down Expand Up @@ -394,7 +394,7 @@ func (m *sqlMetadataManagerV2) UpdateDomain(request *persistence.UpdateDomainReq
}
}

return runTransaction("UpdateDomain", m.db, func(tx *sqlx.Tx) error {
return m.txExecute("UpdateDomain", func(tx *sqlx.Tx) error {
result, err := tx.NamedExec(updateDomainSQLQuery, &flatUpdateDomainRequest{
domainCommon: domainCommon{
Name: request.Info.Name,
Expand Down Expand Up @@ -432,14 +432,14 @@ func (m *sqlMetadataManagerV2) UpdateDomain(request *persistence.UpdateDomainReq
}

func (m *sqlMetadataManagerV2) DeleteDomain(request *persistence.DeleteDomainRequest) error {
return runTransaction("DeleteDomain", m.db, func(tx *sqlx.Tx) error {
return m.txExecute("DeleteDomain", func(tx *sqlx.Tx) error {
_, err := tx.NamedExec(deleteDomainByIDSQLQuery, request)
return err
})
}

func (m *sqlMetadataManagerV2) DeleteDomainByName(request *persistence.DeleteDomainByNameRequest) error {
return runTransaction("DeleteDomainByName", m.db, func(tx *sqlx.Tx) error {
return m.txExecute("DeleteDomainByName", func(tx *sqlx.Tx) error {
_, err := m.db.NamedExec(deleteDomainByNameSQLQuery, request)
return err
})
Expand Down
2 changes: 1 addition & 1 deletion common/persistence/sql/sqlShardManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (m *sqlShardManager) UpdateShard(request *persistence.UpdateShardRequest) e
Message: fmt.Sprintf("UpdateShard operation failed. Error: %v", err),
}
}
return runTransaction("UpdateShard", m.db, func(tx *sqlx.Tx) error {
return m.txExecute("UpdateShard", func(tx *sqlx.Tx) error {
if err := lockShard(tx, request.ShardInfo.ShardID, request.PreviousRangeID); err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions common/persistence/sql/sqlTaskManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (m *sqlTaskManager) LeaseTaskList(request *persistence.LeaseTaskListRequest
}

var resp *persistence.LeaseTaskListResponse
err := runTransaction("LeaseTaskList", m.db, func(tx *sqlx.Tx) error {
err := m.txExecute("LeaseTaskList", func(tx *sqlx.Tx) error {
rangeID = row.RangeID
ackLevel = row.AckLevel
// We need to separately check the condition and do the
Expand Down Expand Up @@ -217,7 +217,7 @@ func (m *sqlTaskManager) UpdateTaskList(request *persistence.UpdateTaskListReque
}
}
var resp *persistence.UpdateTaskListResponse
err := runTransaction("UpdateTaskList", m.db, func(tx *sqlx.Tx) error {
err := m.txExecute("UpdateTaskList", func(tx *sqlx.Tx) error {
err1 := lockTaskList(
tx, request.TaskListInfo.DomainID, request.TaskListInfo.Name, request.TaskListInfo.TaskType, request.TaskListInfo.RangeID)
if err1 != nil {
Expand Down Expand Up @@ -271,7 +271,7 @@ func (m *sqlTaskManager) CreateTasks(request *persistence.CreateTasksRequest) (*
}
}
var resp *persistence.CreateTasksResponse
err := runTransaction("CreateTasks", m.db, func(tx *sqlx.Tx) error {
err := m.txExecute("CreateTasks", func(tx *sqlx.Tx) error {
query, args, err1 := m.db.BindNamed(createTaskSQLQuery, tasksRows)
if err1 != nil {
return err1
Expand Down

0 comments on commit f2ebf3c

Please sign in to comment.