diff --git a/service/history/replication/poller_manager.go b/service/history/replication/poller_manager.go
index 7e392077d2f..90fde64536c 100644
--- a/service/history/replication/poller_manager.go
+++ b/service/history/replication/poller_manager.go
@@ -25,6 +25,7 @@ package replication
 
 import (
+    "errors"
     "fmt"
 
     "go.temporal.io/server/common/cluster"
@@ -32,7 +33,7 @@ import (
 
 type (
     pollerManager interface {
-        getSourceClusterShardIDs(sourceClusterName string) []int32
+        getSourceClusterShardIDs(sourceClusterName string) ([]int32, error)
     }
 
     pollerManagerImpl struct {
@@ -53,18 +54,27 @@ func newPollerManager(
     }
 }
 
-func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) []int32 {
+func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) ([]int32, error) {
     currentCluster := p.clusterMetadata.GetCurrentClusterName()
     allClusters := p.clusterMetadata.GetAllClusterInfo()
     currentClusterInfo, ok := allClusters[currentCluster]
     if !ok {
-        panic("Cannot get current cluster info from cluster metadata cache")
+        return nil, errors.New("cannot get current cluster info from cluster metadata cache")
     }
     remoteClusterInfo, ok := allClusters[sourceClusterName]
     if !ok {
-        panic(fmt.Sprintf("Cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
+        return nil, fmt.Errorf("cannot get source cluster %s info from cluster metadata cache", sourceClusterName)
     }
-    return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)
+
+    // The remote shard count and local shard count must be multiples of each other.
+    large, small := remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount
+    if small > large {
+        large, small = small, large
+    }
+    if large%small != 0 {
+        return nil, fmt.Errorf("remote shard count %d and local shard count %d are not multiples", remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount)
+    }
+    return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount), nil
 }
 
 func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
@@ -75,12 +85,7 @@ func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCoun
         }
         return pollingShards
     }
-    // remoteShardCount > localShardCount, replication poller will poll from multiple remote shard.
-    // The remote shard count and local shard count must be multiples.
-    if remoteShardCount%localShardCount != 0 {
-        panic(fmt.Sprintf("Remote shard count %d and local shard count %d are not multiples.", remoteShardCount, localShardCount))
-    }
     for i := localShardId; i <= remoteShardCount; i += localShardCount {
         pollingShards = append(pollingShards, i)
     }
diff --git a/service/history/replication/poller_manager_test.go b/service/history/replication/poller_manager_test.go
index abe69d9f79f..1166ad65dc0 100644
--- a/service/history/replication/poller_manager_test.go
+++ b/service/history/replication/poller_manager_test.go
@@ -36,60 +36,53 @@ func TestGetPollingShardIds(t *testing.T) {
         shardID          int32
         remoteShardCount int32
         localShardCount  int32
-        expectedPanic    bool
         expectedShardIDs []int32
     }{
         {
             1,
             4,
             4,
-            false,
             []int32{1},
         },
         {
             1,
             2,
             4,
-            false,
             []int32{1},
         },
         {
             3,
             2,
             4,
-            false,
-            []int32{},
+            nil,
         },
         {
             1,
             16,
             4,
-            false,
             []int32{1, 5, 9, 13},
         },
         {
             4,
             16,
             4,
-            false,
             []int32{4, 8, 12, 16},
         },
         {
             4,
             17,
             4,
-            true,
-            []int32{},
+            []int32{4, 8, 12, 16},
+        },
+        {
+            1,
+            17,
+            4,
+            []int32{1, 5, 9, 13, 17},
         },
     }
     for idx, tt := range testCases {
         t.Run(fmt.Sprintf("Testcase %d", idx), func(t *testing.T) {
-            t.Parallel()
-            defer func() {
-                if r := recover(); tt.expectedPanic && r == nil {
-                    t.Errorf("The code did not panic")
-                }
-            }()
             shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
             assert.Equal(t, tt.expectedShardIDs, shardIDs)
         })
diff --git a/service/history/replication/task_processor_manager.go b/service/history/replication/task_processor_manager.go
index 8da536a875e..22a1b3d5214 100644
--- a/service/history/replication/task_processor_manager.go
+++ b/service/history/replication/task_processor_manager.go
@@ -26,7 +26,6 @@ package replication
 import (
     "context"
-    "fmt"
     "sync"
     "sync/atomic"
     "time"
@@ -72,7 +71,7 @@ type (
         logger            log.Logger
 
         taskProcessorLock sync.RWMutex
-        taskProcessors    map[string]TaskProcessor
+        taskProcessors    map[string][]TaskProcessor // cluster name -> task processors
         minTxAckedTaskID  int64
         shutdownChan      chan struct{}
     }
@@ -114,7 +113,7 @@ func NewTaskProcessorManager(
         ),
         logger:               shard.GetLogger(),
         metricsHandler:       shard.GetMetricsHandler(),
-        taskProcessors:       make(map[string]TaskProcessor),
+        taskProcessors:       make(map[string][]TaskProcessor),
         taskExecutorProvider: taskExecutorProvider,
         taskPollerManager:    newPollerManager(shard.GetShardID(), shard.GetClusterMetadata()),
         minTxAckedTaskID:     persistence.EmptyQueueMessageID,
@@ -149,8 +148,10 @@ func (r *taskProcessorManagerImpl) Stop() {
     r.shard.GetClusterMetadata().UnRegisterMetadataChangeCallback(r)
     r.taskProcessorLock.Lock()
-    for _, replicationTaskProcessor := range r.taskProcessors {
-        replicationTaskProcessor.Stop()
+    for _, taskProcessors := range r.taskProcessors {
+        for _, processor := range taskProcessors {
+            processor.Stop()
+        }
     }
     r.taskProcessorLock.Unlock()
 }
@@ -170,44 +171,57 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate(
     r.taskProcessorLock.Lock()
     defer r.taskProcessorLock.Unlock()
     currentClusterName := r.shard.GetClusterMetadata().GetCurrentClusterName()
+    // The metadata triggers an update when any of the following fields changes: 1. Enabled 2. Initial Failover Version 3. Cluster address.
+    // The callback covers three cases:
+    // Case 1: Remove a cluster. Case 2: Add a new cluster. Case 3: Refresh cluster metadata (a combination of 1 and 2).
+
+    // Case 1 and Case 3
     for clusterName := range oldClusterMetadata {
         if clusterName == currentClusterName {
             continue
         }
-        sourceShardIds := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
+        for _, processor := range r.taskProcessors[clusterName] {
+            processor.Stop()
+            delete(r.taskProcessors, clusterName)
+        }
+    }
+
+    // Case 2 and Case 3
+    for clusterName := range newClusterMetadata {
+        if clusterName == currentClusterName {
+            continue
+        }
+        if clusterInfo := newClusterMetadata[clusterName]; clusterInfo == nil || !clusterInfo.Enabled {
+            continue
+        }
+        sourceShardIds, err := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
+        if err != nil {
+            r.logger.Error("Failed to get source shard id list", tag.Error(err), tag.ClusterName(clusterName))
+            continue
+        }
+        var processors []TaskProcessor
         for _, sourceShardId := range sourceShardIds {
-            perShardTaskProcessorKey := fmt.Sprintf(clusterCallbackKey, clusterName, sourceShardId)
-            // The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
-            // The callback covers three cases:
-            // Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata.
-            if processor, ok := r.taskProcessors[perShardTaskProcessorKey]; ok {
-                // Case 1 and Case 3
-                processor.Stop()
-                delete(r.taskProcessors, perShardTaskProcessorKey)
-            }
-            if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
-                // Case 2 and Case 3
-                fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
-                replicationTaskProcessor := NewTaskProcessor(
-                    sourceShardId,
-                    r.shard,
-                    r.engine,
-                    r.config,
-                    r.shard.GetMetricsHandler(),
-                    fetcher,
-                    r.taskExecutorProvider(TaskExecutorParams{
-                        RemoteCluster:   clusterName,
-                        Shard:           r.shard,
-                        HistoryResender: r.resender,
-                        DeleteManager:   r.deleteMgr,
-                        WorkflowCache:   r.workflowCache,
-                    }),
-                    r.eventSerializer,
-                )
-                replicationTaskProcessor.Start()
-                r.taskProcessors[perShardTaskProcessorKey] = replicationTaskProcessor
-            }
+            fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
+            replicationTaskProcessor := NewTaskProcessor(
+                sourceShardId,
+                r.shard,
+                r.engine,
+                r.config,
+                r.shard.GetMetricsHandler(),
+                fetcher,
+                r.taskExecutorProvider(TaskExecutorParams{
+                    RemoteCluster:   clusterName,
+                    Shard:           r.shard,
+                    HistoryResender: r.resender,
+                    DeleteManager:   r.deleteMgr,
+                    WorkflowCache:   r.workflowCache,
+                }),
+                r.eventSerializer,
+            )
+            replicationTaskProcessor.Start()
+            processors = append(processors, replicationTaskProcessor)
         }
+        r.taskProcessors[clusterName] = processors
     }
 }
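Note on the shard mapping above: getSourceClusterShardIDs now rejects shard-count pairs that are not integer multiples of each other before calling generateShardIDs, and generateShardIDs maps a local shard to the remote shards it should poll by stepping through the remote shard IDs in increments of the local shard count. A minimal, self-contained Go sketch of that mapping follows; the helper name pollingShardIDs is hypothetical and only mirrors the behavior of generateShardIDs for illustration, it is not the production code.

package main

import "fmt"

// pollingShardIDs mirrors the mapping in generateShardIDs above: when the remote
// cluster has more shards than the local cluster, local shard localShardID polls
// remote shards localShardID, localShardID+localShardCount, ..., up to
// remoteShardCount; otherwise it polls at most the remote shard with its own ID.
func pollingShardIDs(localShardID, localShardCount, remoteShardCount int32) []int32 {
	var shards []int32
	if remoteShardCount <= localShardCount {
		if localShardID <= remoteShardCount {
			shards = append(shards, localShardID)
		}
		return shards
	}
	for i := localShardID; i <= remoteShardCount; i += localShardCount {
		shards = append(shards, i)
	}
	return shards
}

func main() {
	fmt.Println(pollingShardIDs(1, 4, 16)) // [1 5 9 13]: local shard 1 of 4 polls four of the 16 remote shards
	fmt.Println(pollingShardIDs(4, 4, 16)) // [4 8 12 16]
	fmt.Println(pollingShardIDs(3, 4, 2))  // []: the remote cluster has no shard 3, so there is nothing to poll
}

When the remote cluster has at least as many shards and the counts are exact multiples (remoteShardCount = k * localShardCount), the sequences produced for local shards 1..localShardCount partition the remote shard IDs, so each remote shard is polled by exactly one local shard; that is the invariant the new multiples check in getSourceClusterShardIDs enforces.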