Return error if cluster metadata is invalid (temporalio#3879)
* Return error if cluster metadata is invalid
yux0 authored Feb 1, 2023
1 parent af00719 commit d9bac92
Showing 3 changed files with 74 additions and 62 deletions.
25 changes: 15 additions & 10 deletions service/history/replication/poller_manager.go
@@ -25,14 +25,15 @@
package replication

import (
"errors"
"fmt"

"go.temporal.io/server/common/cluster"
)

type (
pollerManager interface {
getSourceClusterShardIDs(sourceClusterName string) []int32
getSourceClusterShardIDs(sourceClusterName string) ([]int32, error)
}

pollerManagerImpl struct {
@@ -53,18 +54,27 @@ func newPollerManager(
}
}

func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) []int32 {
func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) ([]int32, error) {
currentCluster := p.clusterMetadata.GetCurrentClusterName()
allClusters := p.clusterMetadata.GetAllClusterInfo()
currentClusterInfo, ok := allClusters[currentCluster]
if !ok {
panic("Cannot get current cluster info from cluster metadata cache")
return nil, errors.New("cannot get current cluster info from cluster metadata cache")
}
remoteClusterInfo, ok := allClusters[sourceClusterName]
if !ok {
panic(fmt.Sprintf("Cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
return nil, errors.New(fmt.Sprintf("cannot get source cluster %s info from cluster metadata cache", sourceClusterName))
}
return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount)

// The remote shard count and local shard count must be multiples.
large, small := remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount
if small > large {
large, small = small, large
}
if large%small != 0 {
return nil, errors.New(fmt.Sprintf("remote shard count %d and local shard count %d are not multiples.", remoteClusterInfo.ShardCount, currentClusterInfo.ShardCount))
}
return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount), nil
}

func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 {
@@ -75,12 +85,7 @@ func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCoun
}
return pollingShards
}

// remoteShardCount > localShardCount, replication poller will poll from multiple remote shard.
// The remote shard count and local shard count must be multiples.
if remoteShardCount%localShardCount != 0 {
panic(fmt.Sprintf("Remote shard count %d and local shard count %d are not multiples.", remoteShardCount, localShardCount))
}
for i := localShardId; i <= remoteShardCount; i += localShardCount {
pollingShards = append(pollingShards, i)
}
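For reference, here is a minimal, self-contained sketch of the shard mapping this change exercises, reconstructed from the hunks above; the helper names and the main function are illustrative, not the repository's exact code. generateShardIDs strides through the remote shards owned by a local shard, and the multiples check that used to panic now surfaces as an error from getSourceClusterShardIDs.

package main

import "fmt"

// generateShardIDs mirrors the mapping shown above: when the remote cluster has
// no more shards than the local one, a local shard polls at most the remote shard
// with the same ID; otherwise it polls a strided set of remote shards.
func generateShardIDs(localShardID, localShardCount, remoteShardCount int32) []int32 {
	var pollingShards []int32
	if remoteShardCount <= localShardCount {
		if localShardID <= remoteShardCount {
			pollingShards = append(pollingShards, localShardID)
		}
		return pollingShards
	}
	for i := localShardID; i <= remoteShardCount; i += localShardCount {
		pollingShards = append(pollingShards, i)
	}
	return pollingShards
}

// validateShardCounts reproduces the precondition that now returns an error
// instead of panicking: the larger of the two shard counts must be an exact
// multiple of the smaller one.
func validateShardCounts(localShardCount, remoteShardCount int32) error {
	large, small := remoteShardCount, localShardCount
	if small > large {
		large, small = small, large
	}
	if large%small != 0 {
		return fmt.Errorf("remote shard count %d and local shard count %d are not multiples", remoteShardCount, localShardCount)
	}
	return nil
}

func main() {
	// Local shard 1 of a 4-shard cluster against a 16-shard remote cluster.
	if err := validateShardCounts(4, 16); err == nil {
		fmt.Println(generateShardIDs(1, 4, 16)) // [1 5 9 13]
	}
	// 17 and 4 are not multiples: callers now get an error instead of a panic.
	fmt.Println(validateShardCounts(4, 17))
}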
23 changes: 8 additions & 15 deletions service/history/replication/poller_manager_test.go
@@ -36,60 +36,53 @@ func TestGetPollingShardIds(t *testing.T) {
shardID int32
remoteShardCount int32
localShardCount int32
expectedPanic bool
expectedShardIDs []int32
}{
{
1,
4,
4,
false,
[]int32{1},
},
{
1,
2,
4,
false,
[]int32{1},
},
{
3,
2,
4,
false,
[]int32{},
nil,
},
{
1,
16,
4,
false,
[]int32{1, 5, 9, 13},
},
{
4,
16,
4,
false,
[]int32{4, 8, 12, 16},
},
{
4,
17,
4,
true,
[]int32{},
[]int32{4, 8, 12, 16},
},
{
1,
17,
4,
[]int32{1, 5, 9, 13, 17},
},
}
for idx, tt := range testCases {
t.Run(fmt.Sprintf("Testcase %d", idx), func(t *testing.T) {
t.Parallel()
defer func() {
if r := recover(); tt.expectedPanic && r == nil {
t.Errorf("The code did not panic")
}
}()
shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
assert.Equal(t, tt.expectedShardIDs, shardIDs)
})
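As a usage note on the test changes above: with the panic path gone, the expectedPanic column and the recover() block disappear, and the table simply asserts the returned shard IDs. A condensed, hypothetical version of the resulting table-driven test (consolidating the scattered hunks above, not the exact file contents) could look like this:

package replication

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestGetPollingShardIDsSketch(t *testing.T) {
	testCases := []struct {
		shardID          int32
		remoteShardCount int32
		localShardCount  int32
		expectedShardIDs []int32
	}{
		{1, 4, 4, []int32{1}},
		{1, 2, 4, []int32{1}},
		{3, 2, 4, nil}, // local shard 3 has no counterpart in a 2-shard remote cluster
		{1, 16, 4, []int32{1, 5, 9, 13}},
		{4, 16, 4, []int32{4, 8, 12, 16}},
		{4, 17, 4, []int32{4, 8, 12, 16}}, // previously expected to panic
		{1, 17, 4, []int32{1, 5, 9, 13, 17}},
	}
	for idx, tt := range testCases {
		tt := tt
		t.Run(fmt.Sprintf("Testcase %d", idx), func(t *testing.T) {
			t.Parallel()
			shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount)
			assert.Equal(t, tt.expectedShardIDs, shardIDs)
		})
	}
}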
88 changes: 51 additions & 37 deletions service/history/replication/task_processor_manager.go
@@ -26,7 +26,6 @@ package replication

import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
@@ -72,7 +71,7 @@ type (
logger log.Logger

taskProcessorLock sync.RWMutex
taskProcessors map[string]TaskProcessor
taskProcessors map[string][]TaskProcessor // cluster name - processor
minTxAckedTaskID int64
shutdownChan chan struct{}
}
@@ -114,7 +113,7 @@ func NewTaskProcessorManager(
),
logger: shard.GetLogger(),
metricsHandler: shard.GetMetricsHandler(),
taskProcessors: make(map[string]TaskProcessor),
taskProcessors: make(map[string][]TaskProcessor),
taskExecutorProvider: taskExecutorProvider,
taskPollerManager: newPollerManager(shard.GetShardID(), shard.GetClusterMetadata()),
minTxAckedTaskID: persistence.EmptyQueueMessageID,
@@ -149,8 +148,10 @@ func (r *taskProcessorManagerImpl) Stop() {

r.shard.GetClusterMetadata().UnRegisterMetadataChangeCallback(r)
r.taskProcessorLock.Lock()
for _, replicationTaskProcessor := range r.taskProcessors {
replicationTaskProcessor.Stop()
for _, taskProcessors := range r.taskProcessors {
for _, processor := range taskProcessors {
processor.Stop()
}
}
r.taskProcessorLock.Unlock()
}
@@ -170,44 +171,57 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate(
r.taskProcessorLock.Lock()
defer r.taskProcessorLock.Unlock()
currentClusterName := r.shard.GetClusterMetadata().GetCurrentClusterName()
// The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
// The callback covers three cases:
// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata(1 + 2).

// Case 1 and Case 3
for clusterName := range oldClusterMetadata {
if clusterName == currentClusterName {
continue
}
sourceShardIds := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
for _, processor := range r.taskProcessors[clusterName] {
processor.Stop()
delete(r.taskProcessors, clusterName)
}
}

// Case 2 and Case 3
for clusterName := range newClusterMetadata {
if clusterName == currentClusterName {
continue
}
if clusterInfo := newClusterMetadata[clusterName]; clusterInfo == nil || !clusterInfo.Enabled {
continue
}
sourceShardIds, err := r.taskPollerManager.getSourceClusterShardIDs(clusterName)
if err != nil {
r.logger.Error("Failed to get source shard id list", tag.Error(err), tag.ClusterName(clusterName))
continue
}
var processors []TaskProcessor
for _, sourceShardId := range sourceShardIds {
perShardTaskProcessorKey := fmt.Sprintf(clusterCallbackKey, clusterName, sourceShardId)
// The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address
// The callback covers three cases:
// Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata.
if processor, ok := r.taskProcessors[perShardTaskProcessorKey]; ok {
// Case 1 and Case 3
processor.Stop()
delete(r.taskProcessors, perShardTaskProcessorKey)
}
if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
// Case 2 and Case 3
fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
replicationTaskProcessor := NewTaskProcessor(
sourceShardId,
r.shard,
r.engine,
r.config,
r.shard.GetMetricsHandler(),
fetcher,
r.taskExecutorProvider(TaskExecutorParams{
RemoteCluster: clusterName,
Shard: r.shard,
HistoryResender: r.resender,
DeleteManager: r.deleteMgr,
WorkflowCache: r.workflowCache,
}),
r.eventSerializer,
)
replicationTaskProcessor.Start()
r.taskProcessors[perShardTaskProcessorKey] = replicationTaskProcessor
}
fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName)
replicationTaskProcessor := NewTaskProcessor(
sourceShardId,
r.shard,
r.engine,
r.config,
r.shard.GetMetricsHandler(),
fetcher,
r.taskExecutorProvider(TaskExecutorParams{
RemoteCluster: clusterName,
Shard: r.shard,
HistoryResender: r.resender,
DeleteManager: r.deleteMgr,
WorkflowCache: r.workflowCache,
}),
r.eventSerializer,
)
replicationTaskProcessor.Start()
processors = append(processors, replicationTaskProcessor)
}
r.taskProcessors[clusterName] = processors
}
}

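Finally, a rough, self-contained sketch of the new bookkeeping in task_processor_manager.go, using simplified stand-in types (manager, fakeProcessor, and rebuild are hypothetical names, not the repository API): each remote cluster now maps to a slice of per-source-shard processors, the whole slice is stopped before the cluster's entry is rebuilt, and an invalid-metadata error from the shard ID lookup is returned to the caller (which logs and continues) rather than panicking.

package main

import "fmt"

// TaskProcessor is a stand-in for the repository interface of the same name.
type TaskProcessor interface {
	Start()
	Stop()
}

type fakeProcessor struct {
	cluster string
	shard   int32
}

func (p *fakeProcessor) Start() { fmt.Printf("start %s shard %d\n", p.cluster, p.shard) }
func (p *fakeProcessor) Stop()  { fmt.Printf("stop %s shard %d\n", p.cluster, p.shard) }

// manager mirrors the map[string][]TaskProcessor change: one entry per remote
// cluster, holding one processor per source shard.
type manager struct {
	taskProcessors map[string][]TaskProcessor
}

// rebuild stops every processor previously started for the cluster, then starts
// a fresh slice based on the source shard IDs; a metadata error is returned to
// the caller instead of panicking.
func (m *manager) rebuild(cluster string, getSourceShardIDs func(string) ([]int32, error)) error {
	for _, p := range m.taskProcessors[cluster] {
		p.Stop()
	}
	delete(m.taskProcessors, cluster)

	shardIDs, err := getSourceShardIDs(cluster)
	if err != nil {
		return err
	}
	var processors []TaskProcessor
	for _, shardID := range shardIDs {
		p := &fakeProcessor{cluster: cluster, shard: shardID}
		p.Start()
		processors = append(processors, p)
	}
	m.taskProcessors[cluster] = processors
	return nil
}

func main() {
	m := &manager{taskProcessors: make(map[string][]TaskProcessor)}
	lookup := func(string) ([]int32, error) { return []int32{1, 5, 9, 13}, nil }
	_ = m.rebuild("cluster-b", lookup)
	_ = m.rebuild("cluster-b", lookup) // a later metadata update stops the old four before starting new ones
}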
