Skip to content

Commit

Permalink
Add thrift mapper from internal sql types to and from thrift (cadence…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewjdawson2016 authored Dec 2, 2020
1 parent c324af8 commit 4318e53
Show file tree
Hide file tree
Showing 12 changed files with 1,459 additions and 214 deletions.
5 changes: 5 additions & 0 deletions common/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ func TimePtr(v time.Time) *time.Time {
return &v
}

// DurationPtr makes a copy and returns the pointer to a duration
func DurationPtr(v time.Duration) *time.Duration {
return &v
}

// TaskListPtr makes a copy and returns the pointer to a TaskList.
func TaskListPtr(v s.TaskList) *s.TaskList {
return &v
Expand Down
746 changes: 746 additions & 0 deletions common/persistence/serialization/thrift_mapper.go

Large diffs are not rendered by default.

482 changes: 482 additions & 0 deletions common/persistence/serialization/thrift_mapper_test.go

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package sqlplugin
package serialization

import (
"database/sql/driver"
Expand All @@ -44,6 +44,16 @@ func MustParseUUID(s string) UUID {
return u[:]
}

// MustParsePtrUUID returns a UUID parsed from the given string representation
// returns nil if the input is empty string
// panics if the given input is malformed
func MustParsePtrUUID(s *string) UUID {
if s == nil {
return nil
}
return MustParseUUID(*s)
}

// UUIDPtr simply returns a pointer for the given value type
func UUIDPtr(u UUID) *UUID {
return &u
Expand Down
76 changes: 38 additions & 38 deletions common/persistence/sql/sqlExecutionManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ func (m *sqlExecutionManager) createWorkflowExecutionTx(
startVersion := newWorkflow.StartVersion
lastWriteVersion := newWorkflow.LastWriteVersion
shardID := m.shardID
domainID := sqlplugin.MustParseUUID(executionInfo.DomainID)
domainID := serialization.MustParseUUID(executionInfo.DomainID)
workflowID := executionInfo.WorkflowID
runID := sqlplugin.MustParseUUID(executionInfo.RunID)
runID := serialization.MustParseUUID(executionInfo.RunID)

if err := p.ValidateCreateWorkflowModeState(
request.Mode,
Expand Down Expand Up @@ -174,13 +174,13 @@ func (m *sqlExecutionManager) createWorkflowExecutionTx(

case p.CreateWorkflowModeZombie:
// zombie workflow creation with existence of current record, this is a noop
if err := assertRunIDMismatch(sqlplugin.MustParseUUID(executionInfo.RunID), row.RunID); err != nil {
if err := assertRunIDMismatch(serialization.MustParseUUID(executionInfo.RunID), row.RunID); err != nil {
return nil, err
}

case p.CreateWorkflowModeContinueAsNew:
// continueAsNew mode expects a current run exists
if err := assertRunIDMismatch(sqlplugin.MustParseUUID(executionInfo.RunID), row.RunID); err != nil {
if err := assertRunIDMismatch(serialization.MustParseUUID(executionInfo.RunID), row.RunID); err != nil {
return nil, err
}

Expand Down Expand Up @@ -222,8 +222,8 @@ func (m *sqlExecutionManager) GetWorkflowExecution(
request *p.InternalGetWorkflowExecutionRequest,
) (*p.InternalGetWorkflowExecutionResponse, error) {

domainID := sqlplugin.MustParseUUID(request.DomainID)
runID := sqlplugin.MustParseUUID(*request.Execution.RunID)
domainID := serialization.MustParseUUID(request.DomainID)
runID := serialization.MustParseUUID(*request.Execution.RunID)
wfID := *request.Execution.WorkflowID
execution, err := m.db.SelectFromExecutions(ctx, &sqlplugin.ExecutionsFilter{
ShardID: m.shardID, DomainID: domainID, WorkflowID: wfID, RunID: runID})
Expand Down Expand Up @@ -314,9 +314,9 @@ func (m *sqlExecutionManager) GetWorkflowExecution(
}

if info.ParentDomainID != nil {
state.ExecutionInfo.ParentDomainID = sqlplugin.UUID(info.ParentDomainID).String()
state.ExecutionInfo.ParentDomainID = serialization.UUID(info.ParentDomainID).String()
state.ExecutionInfo.ParentWorkflowID = info.GetParentWorkflowID()
state.ExecutionInfo.ParentRunID = sqlplugin.UUID(info.ParentRunID).String()
state.ExecutionInfo.ParentRunID = serialization.UUID(info.ParentRunID).String()
state.ExecutionInfo.InitiatedID = info.GetInitiatedID()
if state.ExecutionInfo.CompletionEvent != nil {
state.ExecutionInfo.CompletionEvent = nil
Expand Down Expand Up @@ -482,9 +482,9 @@ func (m *sqlExecutionManager) updateWorkflowExecutionTx(
newWorkflow := request.NewWorkflowSnapshot

executionInfo := updateWorkflow.ExecutionInfo
domainID := sqlplugin.MustParseUUID(executionInfo.DomainID)
domainID := serialization.MustParseUUID(executionInfo.DomainID)
workflowID := executionInfo.WorkflowID
runID := sqlplugin.MustParseUUID(executionInfo.RunID)
runID := serialization.MustParseUUID(executionInfo.RunID)
shardID := m.shardID

if err := p.ValidateUpdateWorkflowModeState(
Expand Down Expand Up @@ -512,8 +512,8 @@ func (m *sqlExecutionManager) updateWorkflowExecutionTx(
newExecutionInfo := newWorkflow.ExecutionInfo
startVersion := newWorkflow.StartVersion
lastWriteVersion := newWorkflow.LastWriteVersion
newDomainID := sqlplugin.MustParseUUID(newExecutionInfo.DomainID)
newRunID := sqlplugin.MustParseUUID(newExecutionInfo.RunID)
newDomainID := serialization.MustParseUUID(newExecutionInfo.DomainID)
newRunID := serialization.MustParseUUID(newExecutionInfo.RunID)

if !bytes.Equal(domainID, newDomainID) {
return &types.InternalServiceError{
Expand Down Expand Up @@ -596,16 +596,16 @@ func (m *sqlExecutionManager) resetWorkflowExecutionTx(

shardID := m.shardID

domainID := sqlplugin.MustParseUUID(request.NewWorkflowSnapshot.ExecutionInfo.DomainID)
domainID := serialization.MustParseUUID(request.NewWorkflowSnapshot.ExecutionInfo.DomainID)
workflowID := request.NewWorkflowSnapshot.ExecutionInfo.WorkflowID

baseRunID := sqlplugin.MustParseUUID(request.BaseRunID)
baseRunID := serialization.MustParseUUID(request.BaseRunID)
baseRunNextEventID := request.BaseRunNextEventID

currentRunID := sqlplugin.MustParseUUID(request.CurrentRunID)
currentRunID := serialization.MustParseUUID(request.CurrentRunID)
currentRunNextEventID := request.CurrentRunNextEventID

newWorkflowRunID := sqlplugin.MustParseUUID(request.NewWorkflowSnapshot.ExecutionInfo.RunID)
newWorkflowRunID := serialization.MustParseUUID(request.NewWorkflowSnapshot.ExecutionInfo.RunID)
newExecutionInfo := request.NewWorkflowSnapshot.ExecutionInfo
startVersion := request.NewWorkflowSnapshot.StartVersion
lastWriteVersion := request.NewWorkflowSnapshot.LastWriteVersion
Expand Down Expand Up @@ -691,7 +691,7 @@ func (m *sqlExecutionManager) conflictResolveWorkflowExecutionTx(

shardID := m.shardID

domainID := sqlplugin.MustParseUUID(resetWorkflow.ExecutionInfo.DomainID)
domainID := serialization.MustParseUUID(resetWorkflow.ExecutionInfo.DomainID)
workflowID := resetWorkflow.ExecutionInfo.WorkflowID

if err := p.ValidateConflictResolveWorkflowModeState(
Expand All @@ -711,7 +711,7 @@ func (m *sqlExecutionManager) conflictResolveWorkflowExecutionTx(
shardID,
domainID,
workflowID,
sqlplugin.MustParseUUID(resetWorkflow.ExecutionInfo.RunID)); err != nil {
serialization.MustParseUUID(resetWorkflow.ExecutionInfo.RunID)); err != nil {
return err
}

Expand All @@ -724,13 +724,13 @@ func (m *sqlExecutionManager) conflictResolveWorkflowExecutionTx(
startVersion = newWorkflow.StartVersion
lastWriteVersion = newWorkflow.LastWriteVersion
}
runID := sqlplugin.MustParseUUID(executionInfo.RunID)
runID := serialization.MustParseUUID(executionInfo.RunID)
createRequestID := executionInfo.CreateRequestID
state := executionInfo.State
closeStatus := executionInfo.CloseStatus

if currentWorkflow != nil {
prevRunID := sqlplugin.MustParseUUID(currentWorkflow.ExecutionInfo.RunID)
prevRunID := serialization.MustParseUUID(currentWorkflow.ExecutionInfo.RunID)

if err := assertRunIDAndUpdateCurrentExecution(
ctx,
Expand All @@ -752,7 +752,7 @@ func (m *sqlExecutionManager) conflictResolveWorkflowExecutionTx(
}
} else {
// reset workflow is current
prevRunID := sqlplugin.MustParseUUID(resetWorkflow.ExecutionInfo.RunID)
prevRunID := serialization.MustParseUUID(resetWorkflow.ExecutionInfo.RunID)

if err := assertRunIDAndUpdateCurrentExecution(
ctx,
Expand Down Expand Up @@ -801,8 +801,8 @@ func (m *sqlExecutionManager) DeleteWorkflowExecution(
request *p.DeleteWorkflowExecutionRequest,
) error {

domainID := sqlplugin.MustParseUUID(request.DomainID)
runID := sqlplugin.MustParseUUID(request.RunID)
domainID := serialization.MustParseUUID(request.DomainID)
runID := serialization.MustParseUUID(request.RunID)
_, err := m.db.DeleteFromExecutions(ctx, &sqlplugin.ExecutionsFilter{
ShardID: m.shardID,
DomainID: domainID,
Expand All @@ -821,8 +821,8 @@ func (m *sqlExecutionManager) DeleteCurrentWorkflowExecution(
request *p.DeleteCurrentWorkflowExecutionRequest,
) error {

domainID := sqlplugin.MustParseUUID(request.DomainID)
runID := sqlplugin.MustParseUUID(request.RunID)
domainID := serialization.MustParseUUID(request.DomainID)
runID := serialization.MustParseUUID(request.RunID)
_, err := m.db.DeleteFromCurrentExecutions(ctx, &sqlplugin.CurrentExecutionsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -839,7 +839,7 @@ func (m *sqlExecutionManager) GetCurrentExecution(

row, err := m.db.SelectFromCurrentExecutions(ctx, &sqlplugin.CurrentExecutionsFilter{
ShardID: int64(m.shardID),
DomainID: sqlplugin.MustParseUUID(request.DomainID),
DomainID: serialization.MustParseUUID(request.DomainID),
WorkflowID: request.WorkflowID,
})
if err != nil {
Expand Down Expand Up @@ -902,13 +902,13 @@ func (m *sqlExecutionManager) GetTransferTasks(
}
resp.Tasks[i] = &p.TransferTaskInfo{
TaskID: row.TaskID,
DomainID: sqlplugin.UUID(info.DomainID).String(),
DomainID: serialization.UUID(info.DomainID).String(),
WorkflowID: info.GetWorkflowID(),
RunID: sqlplugin.UUID(info.RunID).String(),
RunID: serialization.UUID(info.RunID).String(),
VisibilityTimestamp: time.Unix(0, info.GetVisibilityTimestampNanos()),
TargetDomainID: sqlplugin.UUID(info.TargetDomainID).String(),
TargetDomainID: serialization.UUID(info.TargetDomainID).String(),
TargetWorkflowID: info.GetTargetWorkflowID(),
TargetRunID: sqlplugin.UUID(info.TargetRunID).String(),
TargetRunID: serialization.UUID(info.TargetRunID).String(),
TargetChildWorkflowOnly: info.GetTargetChildWorkflowOnly(),
TaskList: info.GetTaskList(),
TaskType: int(info.GetTaskType()),
Expand Down Expand Up @@ -1012,9 +1012,9 @@ func (m *sqlExecutionManager) populateGetReplicationTasksResponse(

tasks[i] = &p.InternalReplicationTaskInfo{
TaskID: row.TaskID,
DomainID: sqlplugin.UUID(info.DomainID).String(),
DomainID: serialization.UUID(info.DomainID).String(),
WorkflowID: info.GetWorkflowID(),
RunID: sqlplugin.UUID(info.RunID).String(),
RunID: serialization.UUID(info.RunID).String(),
TaskType: int(info.GetTaskType()),
FirstEventID: info.GetFirstEventID(),
NextEventID: info.GetNextEventID(),
Expand Down Expand Up @@ -1179,9 +1179,9 @@ func (m *sqlExecutionManager) CreateFailoverMarkerTasks(
tx,
t,
m.shardID,
sqlplugin.MustParseUUID(task.DomainID),
serialization.MustParseUUID(task.DomainID),
emptyWorkflowID,
sqlplugin.MustParseUUID(emptyReplicationRunID),
serialization.MustParseUUID(emptyReplicationRunID),
m.parser,
); err != nil {
rollBackErr := tx.Rollback()
Expand Down Expand Up @@ -1248,9 +1248,9 @@ func (m *sqlExecutionManager) GetTimerIndexTasks(
resp.Timers[i] = &p.TimerTaskInfo{
VisibilityTimestamp: row.VisibilityTimestamp,
TaskID: row.TaskID,
DomainID: sqlplugin.UUID(info.DomainID).String(),
DomainID: serialization.UUID(info.DomainID).String(),
WorkflowID: info.GetWorkflowID(),
RunID: sqlplugin.UUID(info.RunID).String(),
RunID: serialization.UUID(info.RunID).String(),
TaskType: int(info.GetTaskType()),
TimeoutType: int(info.GetTimeoutType()),
EventID: info.GetEventID(),
Expand Down Expand Up @@ -1319,9 +1319,9 @@ func (m *sqlExecutionManager) PutReplicationTaskToDLQ(
) error {
replicationTask := request.TaskInfo
blob, err := m.parser.ReplicationTaskInfoToBlob(&sqlblobs.ReplicationTaskInfo{
DomainID: sqlplugin.MustParseUUID(replicationTask.DomainID),
DomainID: serialization.MustParseUUID(replicationTask.DomainID),
WorkflowID: &replicationTask.WorkflowID,
RunID: sqlplugin.MustParseUUID(replicationTask.RunID),
RunID: serialization.MustParseUUID(replicationTask.RunID),
TaskType: common.Int16Ptr(int16(replicationTask.TaskType)),
FirstEventID: &replicationTask.FirstEventID,
NextEventID: &replicationTask.NextEventID,
Expand Down
Loading

0 comments on commit 4318e53

Please sign in to comment.