From d60041fffc239ee2f2392a67fedee528fc45512d Mon Sep 17 00:00:00 2001 From: Zijian Date: Mon, 18 Mar 2024 15:24:39 -0700 Subject: [PATCH] Fix checksum validation for SQL implementation (#5790) What changed? Add a check for the SQL implementation of GetWorkflowExecution operation to exclude false positive checksum validation failure cases. Why? To make sure the checksum validation result is true, the data we read from GetWorkflowExecution operation are from a consistent view. In the NoSQL implementation, the operation is a single read, so the data is from a consistent view. However, in the SQL implementation, the operation is multiple reads from different table. If there is a concurrent update, the data we read isn't from a consistent view and the checksum validation could fail. Normally, we don't have concurrent updates with reads. But when the shard ownership changed, it might not be the case. --- common/persistence/data_manager_interfaces.go | 1 + common/persistence/data_store_interfaces.go | 1 + common/persistence/executionManager.go | 1 + .../persistence-tests/persistenceTestBase.go | 2 + common/persistence/serializer.go | 3 + common/persistence/serializer_test.go | 1 + common/persistence/sql/sql_execution_store.go | 35 +++-- .../sql/sql_execution_store_test.go | 127 ++++++++++++++++++ .../history/ndc/activity_replicator_test.go | 1 + .../history/ndc/transaction_manager_test.go | 2 + service/history/shard/context.go | 1 + service/history/shard/context_test.go | 3 + 12 files changed, 169 insertions(+), 9 deletions(-) diff --git a/common/persistence/data_manager_interfaces.go b/common/persistence/data_manager_interfaces.go index d9a2e395df4..66cdec42b54 100644 --- a/common/persistence/data_manager_interfaces.go +++ b/common/persistence/data_manager_interfaces.go @@ -864,6 +864,7 @@ type ( DomainID string Execution types.WorkflowExecution DomainName string + RangeID int64 } // GetWorkflowExecutionResponse is the response to GetworkflowExecutionRequest diff --git a/common/persistence/data_store_interfaces.go b/common/persistence/data_store_interfaces.go index 03ccaf4ba1d..ab2527bcdae 100644 --- a/common/persistence/data_store_interfaces.go +++ b/common/persistence/data_store_interfaces.go @@ -531,6 +531,7 @@ type ( InternalGetWorkflowExecutionRequest struct { DomainID string Execution types.WorkflowExecution + RangeID int64 } // InternalGetWorkflowExecutionResponse is the response to GetWorkflowExecution for Persistence Interface diff --git a/common/persistence/executionManager.go b/common/persistence/executionManager.go index bd8b2ea932c..acd4e54d77c 100644 --- a/common/persistence/executionManager.go +++ b/common/persistence/executionManager.go @@ -73,6 +73,7 @@ func (m *executionManagerImpl) GetWorkflowExecution( internalRequest := &InternalGetWorkflowExecutionRequest{ DomainID: request.DomainID, Execution: request.Execution, + RangeID: request.RangeID, } response, err := m.persistence.GetWorkflowExecution(ctx, internalRequest) if err != nil { diff --git a/common/persistence/persistence-tests/persistenceTestBase.go b/common/persistence/persistence-tests/persistenceTestBase.go index 2fb672107e0..1c6faa79163 100644 --- a/common/persistence/persistence-tests/persistenceTestBase.go +++ b/common/persistence/persistence-tests/persistenceTestBase.go @@ -476,6 +476,7 @@ func (s *TestBase) GetWorkflowExecutionInfoWithStats(ctx context.Context, domain response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{ DomainID: domainID, Execution: workflowExecution, + RangeID: s.ShardInfo.RangeID, }) if err != nil { return nil, nil, err @@ -490,6 +491,7 @@ func (s *TestBase) GetWorkflowExecutionInfo(ctx context.Context, domainID string response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{ DomainID: domainID, Execution: workflowExecution, + RangeID: s.ShardInfo.RangeID, }) if err != nil { return nil, err diff --git a/common/persistence/serializer.go b/common/persistence/serializer.go index be20c54ceac..b59ca9baf39 100644 --- a/common/persistence/serializer.go +++ b/common/persistence/serializer.go @@ -304,6 +304,9 @@ func (t *serializerImpl) DeserializeAsyncWorkflowsConfig(data *DataBlob) (*types } func (t *serializerImpl) SerializeChecksum(sum checksum.Checksum, encodingType common.EncodingType) (*DataBlob, error) { + if len(sum.Value) == 0 { + return nil, nil + } return t.serialize(sum, encodingType) } diff --git a/common/persistence/serializer_test.go b/common/persistence/serializer_test.go index 4b861c21a5c..164fb359074 100644 --- a/common/persistence/serializer_test.go +++ b/common/persistence/serializer_test.go @@ -214,6 +214,7 @@ func TestSerializers(t *testing.T) { { name: "checksum", payloads: map[string]any{ + "empty": checksum.Checksum{}, "normal": generateChecksum(), }, serializeFn: func(payload any, encoding common.EncodingType) (*DataBlob, error) { diff --git a/common/persistence/sql/sql_execution_store.go b/common/persistence/sql/sql_execution_store.go index 459bcd3a9aa..b5b4274c4d4 100644 --- a/common/persistence/sql/sql_execution_store.go +++ b/common/persistence/sql/sql_execution_store.go @@ -307,60 +307,60 @@ func (m *sqlExecutionStore) GetWorkflowExecution( var bufferedEvents []*p.DataBlob var signalsRequested map[string]struct{} - g, ctx := errgroup.WithContext(ctx) + g, childCtx := errgroup.WithContext(ctx) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() - executions, e = m.getExecutions(ctx, request, domainID, wfID, runID) + executions, e = m.getExecutions(childCtx, request, domainID, wfID, runID) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() activityInfos, e = getActivityInfoMap( - ctx, m.db, m.shardID, domainID, wfID, runID, m.parser) + childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() timerInfos, e = getTimerInfoMap( - ctx, m.db, m.shardID, domainID, wfID, runID, m.parser) + childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() childExecutionInfos, e = getChildExecutionInfoMap( - ctx, m.db, m.shardID, domainID, wfID, runID, m.parser) + childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() requestCancelInfos, e = getRequestCancelInfoMap( - ctx, m.db, m.shardID, domainID, wfID, runID, m.parser) + childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() signalInfos, e = getSignalInfoMap( - ctx, m.db, m.shardID, domainID, wfID, runID, m.parser) + childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() bufferedEvents, e = getBufferedEvents( - ctx, m.db, m.shardID, domainID, wfID, runID) + childCtx, m.db, m.shardID, domainID, wfID, runID) return e }) g.Go(func() (e error) { defer func() { recoverPanic(recover(), &e) }() signalsRequested, e = getSignalsRequested( - ctx, m.db, m.shardID, domainID, wfID, runID) + childCtx, m.db, m.shardID, domainID, wfID, runID) return e }) @@ -375,6 +375,23 @@ func (m *sqlExecutionStore) GetWorkflowExecution( Message: fmt.Sprintf("GetWorkflowExecution: failed. Error: %v", err), } } + // if we have checksum, we need to make sure the rangeID did not change + // if the rangeID changed, it means the shard ownership might have changed + // and the workflow might have been updated when we read the data, so the data + // we read might not be from a consistent view, the checksum validation might fail + // in that case, we clear the checksum data so that we will not perform the validation + if state.ChecksumData != nil { + row, err := m.db.SelectFromShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(m.shardID)}) + if err != nil { + return nil, convertCommonErrors(m.db, "GetWorkflowExecution", "", err) + } + if row.RangeID != request.RangeID { + // The GetWorkflowExecution operation will not be impacted by this. ChecksumData is purely for validation purposes. + m.logger.Warn("GetWorkflowExecution's checksum is discarded. The shard might have changed owner.") + state.ChecksumData = nil + } + } + state.ActivityInfos = activityInfos state.TimerInfos = timerInfos state.ChildExecutionInfos = childExecutionInfos diff --git a/common/persistence/sql/sql_execution_store_test.go b/common/persistence/sql/sql_execution_store_test.go index 4df8cab37e8..514d10a9e12 100644 --- a/common/persistence/sql/sql_execution_store_test.go +++ b/common/persistence/sql/sql_execution_store_test.go @@ -2975,6 +2975,7 @@ func TestGetWorkflowExecution(t *testing.T) { mockSetup func(*sqlplugin.MockDB, *serialization.MockParser) want *persistence.InternalGetWorkflowExecutionResponse wantErr bool + assertErr func(t *testing.T, err error) }{ { name: "Success case", @@ -2984,6 +2985,7 @@ func TestGetWorkflowExecution(t *testing.T) { WorkflowID: "test-workflow-id", RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f", }, + RangeID: 1, }, mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) { db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{ @@ -3204,6 +3206,9 @@ func TestGetWorkflowExecution(t *testing.T) { Control: []byte("test control"), RequestID: "test-signal-request-id", }, nil) + db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{ + RangeID: 1, + }, nil) }, want: &persistence.InternalGetWorkflowExecutionResponse{ State: &persistence.InternalWorkflowMutableState{ @@ -3366,6 +3371,125 @@ func TestGetWorkflowExecution(t *testing.T) { }, wantErr: false, }, + { + name: "Error - Shard owner changed", + req: &persistence.InternalGetWorkflowExecutionRequest{ + DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d", + Execution: types.WorkflowExecution{ + WorkflowID: "test-workflow-id", + RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f", + }, + }, + mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) { + db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{ + { + ShardID: 0, + DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"), + WorkflowID: "test-workflow-id", + RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"), + NextEventID: 101, + LastWriteVersion: 11, + Data: []byte("test data"), + DataEncoding: "thriftrw", + }, + }, nil) + db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{ + Checksum: []byte("test-checksum"), + ChecksumEncoding: "test-checksum-encoding", + }, nil) + db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{ + RangeID: 1, + }, nil) + }, + want: &persistence.InternalGetWorkflowExecutionResponse{ + State: &persistence.InternalWorkflowMutableState{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{ + DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d", + WorkflowID: "test-workflow-id", + RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f", + NextEventID: 101, + CompletionEventBatchID: -23, + }, + ActivityInfos: map[int64]*persistence.InternalActivityInfo{}, + TimerInfos: map[string]*persistence.TimerInfo{}, + ChildExecutionInfos: map[int64]*persistence.InternalChildExecutionInfo{}, + RequestCancelInfos: map[int64]*persistence.RequestCancelInfo{}, + SignalInfos: map[int64]*persistence.SignalInfo{}, + SignalRequestedIDs: map[string]struct{}{}, + ChecksumData: nil, + }, + }, + wantErr: false, + }, + { + name: "Error - failed to get shard", + req: &persistence.InternalGetWorkflowExecutionRequest{ + DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d", + Execution: types.WorkflowExecution{ + WorkflowID: "test-workflow-id", + RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f", + }, + }, + mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) { + db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{ + { + ShardID: 0, + DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"), + WorkflowID: "test-workflow-id", + RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"), + NextEventID: 101, + LastWriteVersion: 11, + Data: []byte("test data"), + DataEncoding: "thriftrw", + }, + }, nil) + db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{ + Checksum: []byte("test-checksum"), + ChecksumEncoding: "test-checksum-encoding", + }, nil) + db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + }, + wantErr: true, + }, + { + name: "Error - SelectFromExecutions no row", + req: &persistence.InternalGetWorkflowExecutionRequest{ + DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d", + Execution: types.WorkflowExecution{ + WorkflowID: "test-workflow-id", + RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f", + }, + }, + mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) { + db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, &types.EntityNotExistsError{}, err) + }, + }, { name: "Error - SelectFromExecutions failed", req: &persistence.InternalGetWorkflowExecutionRequest{ @@ -3562,6 +3686,9 @@ func TestGetWorkflowExecution(t *testing.T) { resp, err := s.GetWorkflowExecution(context.Background(), tc.req) if tc.wantErr { assert.Error(t, err, "Expected an error for test case") + if tc.assertErr != nil { + tc.assertErr(t, err) + } } else { assert.NoError(t, err, "Did not expect an error for test case") assert.Equal(t, tc.want, resp, "Response mismatch") diff --git a/service/history/ndc/activity_replicator_test.go b/service/history/ndc/activity_replicator_test.go index ee1c26651c6..528bcb6ac88 100644 --- a/service/history/ndc/activity_replicator_test.go +++ b/service/history/ndc/activity_replicator_test.go @@ -141,6 +141,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_WorkflowNotFound() { RunID: runID, }, DomainName: domainName, + RangeID: 1, }).Return(nil, &types.EntityNotExistsError{}) s.mockDomainCache.EXPECT().GetDomainByID(domainID).Return( cache.NewGlobalDomainCacheEntryForTest( diff --git a/service/history/ndc/transaction_manager_test.go b/service/history/ndc/transaction_manager_test.go index 5f0747d3707..a51c147f546 100644 --- a/service/history/ndc/transaction_manager_test.go +++ b/service/history/ndc/transaction_manager_test.go @@ -443,6 +443,7 @@ func (s *transactionManagerSuite) TestCheckWorkflowExists_DoesNotExists() { RunID: runID, }, DomainName: domainName, + RangeID: 1, }).Return(nil, &types.EntityNotExistsError{}).Once() exists, err := s.transactionManager.checkWorkflowExists(ctx, domainID, workflowID, runID) @@ -465,6 +466,7 @@ func (s *transactionManagerSuite) TestCheckWorkflowExists_DoesExists() { RunID: runID, }, DomainName: domainName, + RangeID: 1, }).Return(&persistence.GetWorkflowExecutionResponse{}, nil).Once() exists, err := s.transactionManager.checkWorkflowExists(ctx, domainID, workflowID, runID) diff --git a/service/history/shard/context.go b/service/history/shard/context.go index 26f9c4531bc..f3f223e360d 100644 --- a/service/history/shard/context.go +++ b/service/history/shard/context.go @@ -585,6 +585,7 @@ func (s *contextImpl) GetWorkflowExecution( ctx context.Context, request *persistence.GetWorkflowExecutionRequest, ) (*persistence.GetWorkflowExecutionResponse, error) { + request.RangeID = atomic.LoadInt64(&s.rangeID) // This is to make sure read is not blocked by write, s.rangeID is synced with s.shardInfo.RangeID if s.isClosed() { return nil, ErrShardClosed } diff --git a/service/history/shard/context_test.go b/service/history/shard/context_test.go index 42d0975e873..855434c129f 100644 --- a/service/history/shard/context_test.go +++ b/service/history/shard/context_test.go @@ -297,6 +297,9 @@ func TestGetWorkflowExecution(t *testing.T) { mockExecutionMgr := &mocks.ExecutionManager{} shardContext := &contextImpl{ executionManager: mockExecutionMgr, + shardInfo: &persistence.ShardInfo{ + RangeID: 12, + }, } if tc.isClosed { shardContext.closed = 1