Skip to content

Commit

Permalink
Fix checksum validation for SQL implementation (cadence-workflow#5790)
Browse files Browse the repository at this point in the history
What changed?
Add a check to the SQL implementation of the GetWorkflowExecution operation to rule out false-positive checksum validation failures.

Why?
For the checksum validation result to be trustworthy, the data read by the GetWorkflowExecution operation must come from a consistent view. In the NoSQL implementation, the operation is a single read, so the data comes from a consistent view. In the SQL implementation, however, the operation consists of multiple reads from different tables. If there is a concurrent update, the data read is not from a consistent view and the checksum validation can fail. Normally there are no updates concurrent with reads, but that may not hold when shard ownership changes.
  • Loading branch information
Shaddoll authored Mar 18, 2024
1 parent d43b582 commit d60041f
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 9 deletions.
1 change: 1 addition & 0 deletions common/persistence/data_manager_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ type (
DomainID string
Execution types.WorkflowExecution
DomainName string
RangeID int64
}

// GetWorkflowExecutionResponse is the response to GetWorkflowExecutionRequest
Expand Down
1 change: 1 addition & 0 deletions common/persistence/data_store_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ type (
InternalGetWorkflowExecutionRequest struct {
DomainID string
Execution types.WorkflowExecution
RangeID int64
}

// InternalGetWorkflowExecutionResponse is the response to GetWorkflowExecution for Persistence Interface
Expand Down
1 change: 1 addition & 0 deletions common/persistence/executionManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (m *executionManagerImpl) GetWorkflowExecution(
internalRequest := &InternalGetWorkflowExecutionRequest{
DomainID: request.DomainID,
Execution: request.Execution,
RangeID: request.RangeID,
}
response, err := m.persistence.GetWorkflowExecution(ctx, internalRequest)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions common/persistence/persistence-tests/persistenceTestBase.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ func (s *TestBase) GetWorkflowExecutionInfoWithStats(ctx context.Context, domain
response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{
DomainID: domainID,
Execution: workflowExecution,
RangeID: s.ShardInfo.RangeID,
})
if err != nil {
return nil, nil, err
Expand All @@ -490,6 +491,7 @@ func (s *TestBase) GetWorkflowExecutionInfo(ctx context.Context, domainID string
response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{
DomainID: domainID,
Execution: workflowExecution,
RangeID: s.ShardInfo.RangeID,
})
if err != nil {
return nil, err
Expand Down
3 changes: 3 additions & 0 deletions common/persistence/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ func (t *serializerImpl) DeserializeAsyncWorkflowsConfig(data *DataBlob) (*types
}

// SerializeChecksum encodes sum into a DataBlob using encodingType.
//
// A checksum whose Value is empty is treated as "no checksum": the method
// returns a nil blob together with a nil error instead of encoding it.
// Otherwise the work is delegated to t.serialize and its result is returned
// unchanged.
func (t *serializerImpl) SerializeChecksum(sum checksum.Checksum, encodingType common.EncodingType) (*DataBlob, error) {
	if len(sum.Value) > 0 {
		return t.serialize(sum, encodingType)
	}
	// Nothing to encode; callers receive (nil, nil) for an empty checksum.
	return nil, nil
}

Expand Down
1 change: 1 addition & 0 deletions common/persistence/serializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ func TestSerializers(t *testing.T) {
{
name: "checksum",
payloads: map[string]any{
"empty": checksum.Checksum{},
"normal": generateChecksum(),
},
serializeFn: func(payload any, encoding common.EncodingType) (*DataBlob, error) {
Expand Down
35 changes: 26 additions & 9 deletions common/persistence/sql/sql_execution_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,60 +307,60 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
var bufferedEvents []*p.DataBlob
var signalsRequested map[string]struct{}

g, ctx := errgroup.WithContext(ctx)
g, childCtx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
executions, e = m.getExecutions(ctx, request, domainID, wfID, runID)
executions, e = m.getExecutions(childCtx, request, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
activityInfos, e = getActivityInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
timerInfos, e = getTimerInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
childExecutionInfos, e = getChildExecutionInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
requestCancelInfos, e = getRequestCancelInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
signalInfos, e = getSignalInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
bufferedEvents, e = getBufferedEvents(
ctx, m.db, m.shardID, domainID, wfID, runID)
childCtx, m.db, m.shardID, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
signalsRequested, e = getSignalsRequested(
ctx, m.db, m.shardID, domainID, wfID, runID)
childCtx, m.db, m.shardID, domainID, wfID, runID)
return e
})

Expand All @@ -375,6 +375,23 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
Message: fmt.Sprintf("GetWorkflowExecution: failed. Error: %v", err),
}
}
// If a checksum is present, verify that the rangeID has not changed.
// A changed rangeID means the shard ownership might have changed and the
// workflow might have been updated while we were reading, so the data we read
// might not come from a consistent view and checksum validation could fail.
// In that case, clear the checksum data so that validation is not performed.
if state.ChecksumData != nil {
row, err := m.db.SelectFromShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(m.shardID)})
if err != nil {
return nil, convertCommonErrors(m.db, "GetWorkflowExecution", "", err)
}
if row.RangeID != request.RangeID {
// The GetWorkflowExecution operation will not be impacted by this. ChecksumData is purely for validation purposes.
m.logger.Warn("GetWorkflowExecution's checksum is discarded. The shard might have changed owner.")
state.ChecksumData = nil
}
}

state.ActivityInfos = activityInfos
state.TimerInfos = timerInfos
state.ChildExecutionInfos = childExecutionInfos
Expand Down
127 changes: 127 additions & 0 deletions common/persistence/sql/sql_execution_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2975,6 +2975,7 @@ func TestGetWorkflowExecution(t *testing.T) {
mockSetup func(*sqlplugin.MockDB, *serialization.MockParser)
want *persistence.InternalGetWorkflowExecutionResponse
wantErr bool
assertErr func(t *testing.T, err error)
}{
{
name: "Success case",
Expand All @@ -2984,6 +2985,7 @@ func TestGetWorkflowExecution(t *testing.T) {
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
RangeID: 1,
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
Expand Down Expand Up @@ -3204,6 +3206,9 @@ func TestGetWorkflowExecution(t *testing.T) {
Control: []byte("test control"),
RequestID: "test-signal-request-id",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{
RangeID: 1,
}, nil)
},
want: &persistence.InternalGetWorkflowExecutionResponse{
State: &persistence.InternalWorkflowMutableState{
Expand Down Expand Up @@ -3366,6 +3371,125 @@ func TestGetWorkflowExecution(t *testing.T) {
},
wantErr: false,
},
{
name: "Error - Shard owner changed",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
{
ShardID: 0,
DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"),
WorkflowID: "test-workflow-id",
RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"),
NextEventID: 101,
LastWriteVersion: 11,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}, nil)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{
Checksum: []byte("test-checksum"),
ChecksumEncoding: "test-checksum-encoding",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{
RangeID: 1,
}, nil)
},
want: &persistence.InternalGetWorkflowExecutionResponse{
State: &persistence.InternalWorkflowMutableState{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
NextEventID: 101,
CompletionEventBatchID: -23,
},
ActivityInfos: map[int64]*persistence.InternalActivityInfo{},
TimerInfos: map[string]*persistence.TimerInfo{},
ChildExecutionInfos: map[int64]*persistence.InternalChildExecutionInfo{},
RequestCancelInfos: map[int64]*persistence.RequestCancelInfo{},
SignalInfos: map[int64]*persistence.SignalInfo{},
SignalRequestedIDs: map[string]struct{}{},
ChecksumData: nil,
},
},
wantErr: false,
},
{
name: "Error - failed to get shard",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
{
ShardID: 0,
DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"),
WorkflowID: "test-workflow-id",
RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"),
NextEventID: 101,
LastWriteVersion: 11,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}, nil)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{
Checksum: []byte("test-checksum"),
ChecksumEncoding: "test-checksum-encoding",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes()
},
wantErr: true,
},
{
name: "Error - SelectFromExecutions no row",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.IsType(t, &types.EntityNotExistsError{}, err)
},
},
{
name: "Error - SelectFromExecutions failed",
req: &persistence.InternalGetWorkflowExecutionRequest{
Expand Down Expand Up @@ -3562,6 +3686,9 @@ func TestGetWorkflowExecution(t *testing.T) {
resp, err := s.GetWorkflowExecution(context.Background(), tc.req)
if tc.wantErr {
assert.Error(t, err, "Expected an error for test case")
if tc.assertErr != nil {
tc.assertErr(t, err)
}
} else {
assert.NoError(t, err, "Did not expect an error for test case")
assert.Equal(t, tc.want, resp, "Response mismatch")
Expand Down
1 change: 1 addition & 0 deletions service/history/ndc/activity_replicator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_WorkflowNotFound() {
RunID: runID,
},
DomainName: domainName,
RangeID: 1,
}).Return(nil, &types.EntityNotExistsError{})
s.mockDomainCache.EXPECT().GetDomainByID(domainID).Return(
cache.NewGlobalDomainCacheEntryForTest(
Expand Down
2 changes: 2 additions & 0 deletions service/history/ndc/transaction_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ func (s *transactionManagerSuite) TestCheckWorkflowExists_DoesNotExists() {
RunID: runID,
},
DomainName: domainName,
RangeID: 1,
}).Return(nil, &types.EntityNotExistsError{}).Once()

exists, err := s.transactionManager.checkWorkflowExists(ctx, domainID, workflowID, runID)
Expand All @@ -465,6 +466,7 @@ func (s *transactionManagerSuite) TestCheckWorkflowExists_DoesExists() {
RunID: runID,
},
DomainName: domainName,
RangeID: 1,
}).Return(&persistence.GetWorkflowExecutionResponse{}, nil).Once()

exists, err := s.transactionManager.checkWorkflowExists(ctx, domainID, workflowID, runID)
Expand Down
1 change: 1 addition & 0 deletions service/history/shard/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ func (s *contextImpl) GetWorkflowExecution(
ctx context.Context,
request *persistence.GetWorkflowExecutionRequest,
) (*persistence.GetWorkflowExecutionResponse, error) {
request.RangeID = atomic.LoadInt64(&s.rangeID) // This is to make sure read is not blocked by write, s.rangeID is synced with s.shardInfo.RangeID
if s.isClosed() {
return nil, ErrShardClosed
}
Expand Down
3 changes: 3 additions & 0 deletions service/history/shard/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ func TestGetWorkflowExecution(t *testing.T) {
mockExecutionMgr := &mocks.ExecutionManager{}
shardContext := &contextImpl{
executionManager: mockExecutionMgr,
shardInfo: &persistence.ShardInfo{
RangeID: 12,
},
}
if tc.isClosed {
shardContext.closed = 1
Expand Down

0 comments on commit d60041f

Please sign in to comment.