From 41f1acc662adf262bc56b26eb03c9737a43c361b Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Tue, 26 Nov 2019 14:09:10 -0800 Subject: [PATCH] Use new resource struct for frontend (#2851) --- cmd/server/server.go | 2 +- common/resource/resourceImpl.go | 22 + common/resource/resourceTest.go | 37 +- host/dynamicconfig.go | 3 +- host/onebox.go | 196 ++-- host/testcluster.go | 41 +- host/xdc/elasticsearch_test.go | 2 +- service/frontend/adminHandler.go | 93 +- service/frontend/adminHandler_test.go | 61 +- service/frontend/dcRedirectionHandler.go | 97 +- service/frontend/dcRedirectionHandler_test.go | 64 +- service/frontend/service.go | 189 ++-- service/frontend/workflowHandler.go | 304 +++--- service/frontend/workflowHandler_test.go | 905 +++++++----------- service/matching/service.go | 4 + service/worker/service.go | 31 +- 16 files changed, 822 insertions(+), 1229 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 7bf693e371b..c05f001fd2e 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -210,7 +210,7 @@ func (s *server) startService() common.Daemon { switch s.name { case frontendService: - daemon = frontend.NewService(¶ms) + daemon, err = frontend.NewService(¶ms) case historyService: daemon = history.NewService(¶ms) case matchingService: diff --git a/common/resource/resourceImpl.go b/common/resource/resourceImpl.go index c03f17310c8..04f5b65c312 100644 --- a/common/resource/resourceImpl.go +++ b/common/resource/resourceImpl.go @@ -236,6 +236,27 @@ func New( common.IsWhitelistServiceTransientError, ) + historyArchiverBootstrapContainer := &archiver.HistoryBootstrapContainer{ + HistoryV2Manager: persistenceBean.GetHistoryManager(), + Logger: logger, + MetricsClient: params.MetricsClient, + ClusterMetadata: params.ClusterMetadata, + DomainCache: domainCache, + } + visibilityArchiverBootstrapContainer := &archiver.VisibilityBootstrapContainer{ + Logger: logger, + MetricsClient: params.MetricsClient, + ClusterMetadata: params.ClusterMetadata, + DomainCache: domainCache, + } + if err := params.ArchiverProvider.RegisterBootstrapContainer( + serviceName, + historyArchiverBootstrapContainer, + visibilityArchiverBootstrapContainer, + ); err != nil { + return nil, err + } + impl = &Impl{ status: common.DaemonStatusInitialized, @@ -350,6 +371,7 @@ func (h *Impl) Stop() { } h.runtimeMetricsReporter.Stop() h.persistenceBean.Close() + h.visibilityMgr.Close() } // GetServiceName return service name diff --git a/common/resource/resourceTest.go b/common/resource/resourceTest.go index 08db8446736..70d5a7f500b 100644 --- a/common/resource/resourceTest.go +++ b/common/resource/resourceTest.go @@ -27,6 +27,7 @@ import ( "go.uber.org/cadence/.gen/go/cadence/workflowserviceclient" publicservicetest "go.uber.org/cadence/.gen/go/cadence/workflowservicetest" + "github.com/uber/cadence/.gen/go/admin/adminservicetest" "github.com/uber/cadence/.gen/go/cadence/workflowservicetest" "github.com/uber/cadence/.gen/go/history/historyservicetest" "github.com/uber/cadence/.gen/go/matching/matchingservicetest" @@ -78,11 +79,13 @@ type ( // internal services clients - SDKClient *publicservicetest.MockClient - FrontendClient *workflowservicetest.MockClient - MatchingClient *matchingservicetest.MockClient - HistoryClient *historyservicetest.MockClient - ClientBean *client.MockBean + SDKClient *publicservicetest.MockClient + FrontendClient *workflowservicetest.MockClient + MatchingClient *matchingservicetest.MockClient + HistoryClient *historyservicetest.MockClient + RemoteAdminClient *adminservicetest.MockClient + RemoteFrontendClient *workflowservicetest.MockClient + ClientBean *client.MockBean // persistence clients @@ -124,10 +127,14 @@ func NewTest( frontendClient := workflowservicetest.NewMockClient(controller) matchingClient := matchingservicetest.NewMockClient(controller) historyClient := historyservicetest.NewMockClient(controller) + remoteFrontendClient := workflowservicetest.NewMockClient(controller) + remoteAdminClient := adminservicetest.NewMockClient(controller) clientBean := client.NewMockBean(controller) clientBean.EXPECT().GetFrontendClient().Return(frontendClient).AnyTimes() clientBean.EXPECT().GetMatchingClient(gomock.Any()).Return(matchingClient, nil).AnyTimes() clientBean.EXPECT().GetHistoryClient().Return(historyClient).AnyTimes() + clientBean.EXPECT().GetRemoteAdminClient(gomock.Any()).Return(remoteAdminClient).AnyTimes() + clientBean.EXPECT().GetRemoteFrontendClient(gomock.Any()).Return(remoteFrontendClient).AnyTimes() metadataMgr := &mocks.MetadataManager{} taskMgr := &mocks.TaskManager{} @@ -166,11 +173,13 @@ func NewTest( // internal services clients - SDKClient: publicservicetest.NewMockClient(controller), - FrontendClient: frontendClient, - MatchingClient: matchingClient, - HistoryClient: historyClient, - ClientBean: clientBean, + SDKClient: publicservicetest.NewMockClient(controller), + FrontendClient: frontendClient, + MatchingClient: matchingClient, + HistoryClient: historyClient, + RemoteAdminClient: remoteAdminClient, + RemoteFrontendClient: remoteFrontendClient, + ClientBean: clientBean, // persistence clients @@ -314,12 +323,12 @@ func (s *Test) GetMatchingClient() matching.Client { // GetHistoryRawClient for testing func (s *Test) GetHistoryRawClient() history.Client { - return s.ClientBean.GetHistoryClient() + return s.HistoryClient } // GetHistoryClient for testing func (s *Test) GetHistoryClient() history.Client { - return s.ClientBean.GetHistoryClient() + return s.HistoryClient } // GetRemoteAdminClient for testing @@ -327,7 +336,7 @@ func (s *Test) GetRemoteAdminClient( cluster string, ) admin.Client { - return s.ClientBean.GetRemoteAdminClient(cluster) + return s.RemoteAdminClient } // GetRemoteFrontendClient for testing @@ -335,7 +344,7 @@ func (s *Test) GetRemoteFrontendClient( cluster string, ) frontend.Client { - return s.ClientBean.GetRemoteFrontendClient(cluster) + return s.RemoteFrontendClient } // GetClientBean for testing diff --git a/host/dynamicconfig.go b/host/dynamicconfig.go index 32cf7ef3c66..adcb56b4636 100644 --- a/host/dynamicconfig.go +++ b/host/dynamicconfig.go @@ -24,7 +24,6 @@ import ( "time" "github.com/uber/cadence/common" - "github.com/uber/cadence/common/service/dynamicconfig" ) @@ -32,6 +31,8 @@ var ( // Override value for integer keys for dynamic config intKeys = map[dynamicconfig.Key]int{ dynamicconfig.FrontendRPS: 3000, + dynamicconfig.FrontendVisibilityListMaxQPS: 100, + dynamicconfig.FrontendESIndexMaxResultWindow: defaultTestValueOfESIndexMaxResultWindow, dynamicconfig.MatchingNumTasklistWritePartitions: 3, dynamicconfig.MatchingNumTasklistReadPartitions: 3, } diff --git a/host/onebox.go b/host/onebox.go index 1c5db5b0cc8..a4c3c5aab46 100644 --- a/host/onebox.go +++ b/host/onebox.go @@ -27,7 +27,6 @@ import ( "time" "github.com/pborman/uuid" - "github.com/stretchr/testify/mock" "github.com/uber-go/tally" cwsc "go.uber.org/cadence/.gen/go/cadence/workflowserviceclient" @@ -53,7 +52,6 @@ import ( "github.com/uber/cadence/common/membership" "github.com/uber/cadence/common/messaging" "github.com/uber/cadence/common/metrics" - "github.com/uber/cadence/common/mocks" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/service" "github.com/uber/cadence/common/service/config" @@ -74,47 +72,45 @@ type Cadence interface { GetAdminClient() adminserviceclient.Interface GetFrontendClient() workflowserviceclient.Interface FrontendAddress() string - GetFrontendService() service.Service GetHistoryClient() historyserviceclient.Interface GetExecutionManagerFactory() persistence.ExecutionManagerFactory } type ( cadenceImpl struct { + frontendService common.Daemon matchingService common.Daemon workerService common.Daemon - adminHandler *frontend.AdminHandler - frontendHandler *frontend.WorkflowHandler - historyHandlers []*history.Handler - logger log.Logger - clusterMetadata cluster.Metadata - persistenceConfig config.Persistence - dispatcherProvider client.DispatcherProvider - messagingClient messaging.Client - metadataMgr persistence.MetadataManager - shardMgr persistence.ShardManager - historyV2Mgr persistence.HistoryManager - taskMgr persistence.TaskManager - visibilityMgr persistence.VisibilityManager - executionMgrFactory persistence.ExecutionManagerFactory - domainReplicationQueue persistence.DomainReplicationQueue - shutdownCh chan struct{} - shutdownWG sync.WaitGroup - frontEndService service.Service - historyService service.Service - clusterNo int // cluster number - replicator *replicator.Replicator - clientWorker archiver.ClientWorker - indexer *indexer.Indexer - enbaleNDC bool - archiverMetadata carchiver.ArchivalMetadata - archiverProvider provider.ArchiverProvider - historyConfig *HistoryConfig - esConfig *elasticsearch.Config - esClient elasticsearch.Client - workerConfig *WorkerConfig - mockFrontendClient map[string]frontendclient.Client + adminClient adminserviceclient.Interface + frontendClient workflowserviceclient.Interface + historyHandlers []*history.Handler + logger log.Logger + clusterMetadata cluster.Metadata + persistenceConfig config.Persistence + dispatcherProvider client.DispatcherProvider + messagingClient messaging.Client + metadataMgr persistence.MetadataManager + shardMgr persistence.ShardManager + historyV2Mgr persistence.HistoryManager + taskMgr persistence.TaskManager + visibilityMgr persistence.VisibilityManager + executionMgrFactory persistence.ExecutionManagerFactory + shutdownCh chan struct{} + shutdownWG sync.WaitGroup + historyService service.Service + clusterNo int // cluster number + replicator *replicator.Replicator + clientWorker archiver.ClientWorker + indexer *indexer.Indexer + enbaleNDC bool + archiverMetadata carchiver.ArchivalMetadata + archiverProvider provider.ArchiverProvider + historyConfig *HistoryConfig + esConfig *elasticsearch.Config + esClient elasticsearch.Client + workerConfig *WorkerConfig + mockFrontendClient map[string]frontendclient.Client } // HistoryConfig contains configs for history service @@ -148,7 +144,6 @@ type ( ESClient elasticsearch.Client WorkerConfig *WorkerConfig MockFrontendClient map[string]frontendclient.Client - DomainReplicationQueue persistence.DomainReplicationQueue } membershipFactoryImpl struct { @@ -160,28 +155,27 @@ type ( // NewCadence returns an instance that hosts full cadence in one process func NewCadence(params *CadenceParams) Cadence { return &cadenceImpl{ - logger: params.Logger, - clusterMetadata: params.ClusterMetadata, - persistenceConfig: params.PersistenceConfig, - dispatcherProvider: params.DispatcherProvider, - messagingClient: params.MessagingClient, - metadataMgr: params.MetadataMgr, - visibilityMgr: params.VisibilityMgr, - shardMgr: params.ShardMgr, - historyV2Mgr: params.HistoryV2Mgr, - taskMgr: params.TaskMgr, - executionMgrFactory: params.ExecutionMgrFactory, - domainReplicationQueue: params.DomainReplicationQueue, - shutdownCh: make(chan struct{}), - clusterNo: params.ClusterNo, - enbaleNDC: params.EnableNDC, - esConfig: params.ESConfig, - esClient: params.ESClient, - archiverMetadata: params.ArchiverMetadata, - archiverProvider: params.ArchiverProvider, - historyConfig: params.HistoryConfig, - workerConfig: params.WorkerConfig, - mockFrontendClient: params.MockFrontendClient, + logger: params.Logger, + clusterMetadata: params.ClusterMetadata, + persistenceConfig: params.PersistenceConfig, + dispatcherProvider: params.DispatcherProvider, + messagingClient: params.MessagingClient, + metadataMgr: params.MetadataMgr, + visibilityMgr: params.VisibilityMgr, + shardMgr: params.ShardMgr, + historyV2Mgr: params.HistoryV2Mgr, + taskMgr: params.TaskMgr, + executionMgrFactory: params.ExecutionMgrFactory, + shutdownCh: make(chan struct{}), + clusterNo: params.ClusterNo, + enbaleNDC: params.EnableNDC, + esConfig: params.ESConfig, + esClient: params.ESClient, + archiverMetadata: params.ArchiverMetadata, + archiverProvider: params.ArchiverProvider, + historyConfig: params.HistoryConfig, + workerConfig: params.WorkerConfig, + mockFrontendClient: params.MockFrontendClient, } } @@ -229,8 +223,7 @@ func (c *cadenceImpl) Stop() { } else { c.shutdownWG.Add(3) } - c.frontendHandler.Stop() - c.adminHandler.Stop() + c.frontendService.Stop() for _, historyHandler := range c.historyHandlers { historyHandler.Stop() } @@ -384,16 +377,11 @@ func (c *cadenceImpl) WorkerPProfPort() int { } func (c *cadenceImpl) GetAdminClient() adminserviceclient.Interface { - return NewAdminClient(c.frontEndService.GetDispatcher()) + return c.adminClient } func (c *cadenceImpl) GetFrontendClient() workflowserviceclient.Interface { - return NewFrontendClient(c.frontEndService.GetDispatcher()) -} - -// For integration tests to get hold of FE instance. -func (c *cadenceImpl) GetFrontendService() service.Service { - return c.frontEndService + return c.frontendClient } func (c *cadenceImpl) GetHistoryClient() historyserviceclient.Interface { @@ -418,59 +406,23 @@ func (c *cadenceImpl) startFrontend(hosts map[string][]string, startWG *sync.Wai params.DynamicConfig = newIntegrationConfigClient(dynamicconfig.NewNopClient()) params.ArchivalMetadata = c.archiverMetadata params.ArchiverProvider = c.archiverProvider - - var replicationMessageSink messaging.Producer - var err error - if c.workerConfig.EnableReplicator { - replicationMessageSink = c.domainReplicationQueue - } else { - replicationMessageSink = &mocks.KafkaProducer{} - replicationMessageSink.(*mocks.KafkaProducer).On("Publish", mock.Anything).Return(nil) + params.ESConfig = c.esConfig + params.ESClient = c.esClient + if c.esConfig != nil { + esDataStoreName := "es-visibility" + params.PersistenceConfig.AdvancedVisibilityStore = esDataStoreName + params.PersistenceConfig.DataStores[esDataStoreName] = config.DataStore{ + ElasticSearch: c.esConfig, + } } - c.frontEndService = service.New(params) - - dc := dynamicconfig.NewCollection(params.DynamicConfig, c.logger) - frontendConfig := frontend.NewConfig(dc, c.historyConfig.NumHistoryShards, c.esConfig != nil) - domainCache := cache.NewDomainCache(c.metadataMgr, c.clusterMetadata, c.frontEndService.GetMetricsClient(), c.logger) - c.adminHandler = frontend.NewAdminHandler( - c.frontEndService, c.historyConfig.NumHistoryShards, domainCache, c.historyV2Mgr, params, frontendConfig) - c.adminHandler.RegisterHandler() - - historyArchiverBootstrapContainer := &carchiver.HistoryBootstrapContainer{ - HistoryV2Manager: c.historyV2Mgr, - Logger: c.logger, - MetricsClient: c.frontEndService.GetMetricsClient(), - ClusterMetadata: c.clusterMetadata, - DomainCache: domainCache, - } - visibilityArchiverBootstrapContainer := &carchiver.VisibilityBootstrapContainer{ - Logger: c.logger, - MetricsClient: c.frontEndService.GetMetricsClient(), - ClusterMetadata: c.clusterMetadata, - DomainCache: domainCache, - } - err = c.archiverProvider.RegisterBootstrapContainer(common.FrontendServiceName, historyArchiverBootstrapContainer, visibilityArchiverBootstrapContainer) + frontendService, err := frontend.NewService(params) if err != nil { - c.logger.Fatal("Failed to register archiver bootstrap container for frontend service", tag.Error(err)) - } - - c.frontendHandler = frontend.NewWorkflowHandler( - c.frontEndService, - frontendConfig, - c.metadataMgr, - c.historyV2Mgr, - c.visibilityMgr, - replicationMessageSink, - c.domainReplicationQueue, - domainCache) - dcRedirectionHandler := frontend.NewDCRedirectionHandler(c.frontendHandler, params.DCRedirectionPolicy) - dcRedirectionHandler.RegisterHandler() - - // must start base service first - c.frontEndService.Start() + params.Logger.Fatal("unable to start frontend service", tag.Error(err)) + } + if c.mockFrontendClient != nil { - clientBean := c.frontEndService.GetClientBean() + clientBean := frontendService.GetClientBean() if clientBean != nil { for serviceName, frontendClient := range c.mockFrontendClient { clientBean.SetRemoteFrontendClient(serviceName, frontendClient) @@ -478,14 +430,10 @@ func (c *cadenceImpl) startFrontend(hosts map[string][]string, startWG *sync.Wai } } - err = c.adminHandler.Start() - if err != nil { - c.logger.Fatal("Failed to start admin", tag.Error(err)) - } - err = dcRedirectionHandler.Start() - if err != nil { - c.logger.Fatal("Failed to start frontend", tag.Error(err)) - } + c.frontendService = frontendService + c.frontendClient = NewFrontendClient(frontendService.GetDispatcher()) + c.adminClient = NewAdminClient(frontendService.GetDispatcher()) + go frontendService.Start() startWG.Done() <-c.shutdownCh @@ -612,6 +560,8 @@ func (c *cadenceImpl) startMatching(hosts map[string][]string, startWG *sync.Wai params.PersistenceConfig = c.persistenceConfig params.MetricsClient = metrics.NewClient(params.MetricScope, service.GetMetricsServiceIdx(params.Name, c.logger)) params.DynamicConfig = newIntegrationConfigClient(dynamicconfig.NewNopClient()) + params.ArchivalMetadata = c.archiverMetadata + params.ArchiverProvider = c.archiverProvider matchingService, err := matching.NewService(params) if err != nil { diff --git a/host/testcluster.go b/host/testcluster.go index a452cc4efa6..a83bd5a8076 100644 --- a/host/testcluster.go +++ b/host/testcluster.go @@ -153,27 +153,26 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er pConfig := testBase.Config() pConfig.NumHistoryShards = options.HistoryConfig.NumHistoryShards cadenceParams := &CadenceParams{ - ClusterMetadata: clusterMetadata, - PersistenceConfig: pConfig, - DispatcherProvider: client.NewDNSYarpcDispatcherProvider(logger, 0), - MessagingClient: messagingClient, - MetadataMgr: testBase.MetadataManager, - ShardMgr: testBase.ShardMgr, - HistoryV2Mgr: testBase.HistoryV2Mgr, - ExecutionMgrFactory: testBase.ExecutionMgrFactory, - TaskMgr: testBase.TaskMgr, - VisibilityMgr: visibilityMgr, - Logger: logger, - ClusterNo: options.ClusterNo, - EnableNDC: options.EnableNDC, - ESConfig: options.ESConfig, - ESClient: esClient, - ArchiverMetadata: archiverBase.metadata, - ArchiverProvider: archiverBase.provider, - HistoryConfig: options.HistoryConfig, - WorkerConfig: options.WorkerConfig, - MockFrontendClient: options.MockFrontendClient, - DomainReplicationQueue: testBase.DomainReplicationQueue, + ClusterMetadata: clusterMetadata, + PersistenceConfig: pConfig, + DispatcherProvider: client.NewDNSYarpcDispatcherProvider(logger, 0), + MessagingClient: messagingClient, + MetadataMgr: testBase.MetadataManager, + ShardMgr: testBase.ShardMgr, + HistoryV2Mgr: testBase.HistoryV2Mgr, + ExecutionMgrFactory: testBase.ExecutionMgrFactory, + TaskMgr: testBase.TaskMgr, + VisibilityMgr: visibilityMgr, + Logger: logger, + ClusterNo: options.ClusterNo, + EnableNDC: options.EnableNDC, + ESConfig: options.ESConfig, + ESClient: esClient, + ArchiverMetadata: archiverBase.metadata, + ArchiverProvider: archiverBase.provider, + HistoryConfig: options.HistoryConfig, + WorkerConfig: options.WorkerConfig, + MockFrontendClient: options.MockFrontendClient, } cluster := NewCadence(cadenceParams) if err := cluster.Start(); err != nil { diff --git a/host/xdc/elasticsearch_test.go b/host/xdc/elasticsearch_test.go index c36f70455f9..80130bc7810 100644 --- a/host/xdc/elasticsearch_test.go +++ b/host/xdc/elasticsearch_test.go @@ -203,7 +203,7 @@ func (s *esCrossDCTestSuite) TestSearchAttributes() { query := fmt.Sprintf(`WorkflowID = "%s" and %s = "%s"`, id, s.testSearchAttributeKey, s.testSearchAttributeVal) listRequest := &workflow.ListWorkflowExecutionsRequest{ Domain: common.StringPtr(domainName), - PageSize: common.Int32Ptr(100), + PageSize: common.Int32Ptr(5), Query: common.StringPtr(query), } diff --git a/service/frontend/adminHandler.go b/service/frontend/adminHandler.go index 579661b25d0..5a43d4b17dd 100644 --- a/service/frontend/adminHandler.go +++ b/service/frontend/adminHandler.go @@ -26,8 +26,6 @@ import ( "errors" "fmt" "strconv" - "sync" - "sync/atomic" "time" "github.com/olivere/elastic" @@ -40,14 +38,13 @@ import ( h "github.com/uber/cadence/.gen/go/history" hist "github.com/uber/cadence/.gen/go/history" gen "github.com/uber/cadence/.gen/go/shared" - "github.com/uber/cadence/client/history" "github.com/uber/cadence/common" - "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/definition" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/service" "github.com/uber/cadence/common/service/dynamicconfig" historyService "github.com/uber/cadence/service/history" @@ -58,16 +55,11 @@ var _ adminserviceserver.Interface = (*AdminHandler)(nil) type ( // AdminHandler - Thrift handler interface for admin service AdminHandler struct { - status int32 + resource.Resource + numberOfHistoryShards int - service.Service - history history.Client - domainCache cache.DomainCache - metricsClient metrics.Client - historyV2Mgr persistence.HistoryManager - startWG sync.WaitGroup - params *service.BootstrapParams - config *Config + params *service.BootstrapParams + config *Config } getWorkflowRawHistoryV2Token struct { @@ -85,53 +77,29 @@ type ( // NewAdminHandler creates a thrift handler for the cadence admin service func NewAdminHandler( - sVice service.Service, - numberOfHistoryShards int, - domainCache cache.DomainCache, - historyV2Mgr persistence.HistoryManager, + resource resource.Resource, params *service.BootstrapParams, config *Config, ) *AdminHandler { - handler := &AdminHandler{ - status: common.DaemonStatusInitialized, - numberOfHistoryShards: numberOfHistoryShards, - Service: sVice, - domainCache: domainCache, - historyV2Mgr: historyV2Mgr, + return &AdminHandler{ + Resource: resource, + numberOfHistoryShards: params.PersistenceConfig.NumHistoryShards, params: params, config: config, } - // prevent us from trying to serve requests before handler's Start() is complete - handler.startWG.Add(1) - return handler } // RegisterHandler register this handler, must be called before Start() func (adh *AdminHandler) RegisterHandler() { - adh.Service.GetDispatcher().Register(adminserviceserver.New(adh)) + adh.GetDispatcher().Register(adminserviceserver.New(adh)) } // Start starts the handler -func (adh *AdminHandler) Start() error { - if !atomic.CompareAndSwapInt32(&adh.status, common.DaemonStatusInitialized, common.DaemonStatusStarted) { - return nil - } - - adh.domainCache.Start() - - adh.history = adh.GetClientBean().GetHistoryClient() - adh.metricsClient = adh.Service.GetMetricsClient() - adh.startWG.Done() - return nil +func (adh *AdminHandler) Start() { } // Stop stops the handler func (adh *AdminHandler) Stop() { - if !atomic.CompareAndSwapInt32(&adh.status, common.DaemonStatusStarted, common.DaemonStatusStopped) { - return - } - adh.Service.Stop() - adh.domainCache.Stop() } // AddSearchAttribute add search attribute to whitelist @@ -221,10 +189,10 @@ func (adh *AdminHandler) DescribeWorkflowExecution(ctx context.Context, request return nil, adh.error(err, scope) } - domainID, err := adh.domainCache.GetDomainID(request.GetDomain()) + domainID, err := adh.GetDomainCache().GetDomainID(request.GetDomain()) historyAddr := historyHost.GetAddress() - resp2, err := adh.history.DescribeMutableState(ctx, &hist.DescribeMutableStateRequest{ + resp2, err := adh.GetHistoryClient().DescribeMutableState(ctx, &hist.DescribeMutableStateRequest{ DomainUUID: &domainID, Execution: request.Execution, }) @@ -246,7 +214,7 @@ func (adh *AdminHandler) RemoveTask(ctx context.Context, request *gen.RemoveTask if request == nil || request.ShardID == nil || request.Type == nil || request.TaskID == nil { return adh.error(errRequestNotSet, scope) } - err := adh.history.RemoveTask(ctx, request) + err := adh.GetHistoryClient().RemoveTask(ctx, request) return err } @@ -257,7 +225,7 @@ func (adh *AdminHandler) CloseShard(ctx context.Context, request *gen.CloseShard if request == nil || request.ShardID == nil { return adh.error(errRequestNotSet, scope) } - err := adh.history.CloseShard(ctx, request) + err := adh.GetHistoryClient().CloseShard(ctx, request) return err } @@ -275,7 +243,7 @@ func (adh *AdminHandler) DescribeHistoryHost(ctx context.Context, request *gen.D } } - resp, err := adh.history.DescribeHistoryHost(ctx, request) + resp, err := adh.GetHistoryClient().DescribeHistoryHost(ctx, request) return resp, err } @@ -293,11 +261,11 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistory( var err error var size int - domainID, err := adh.domainCache.GetDomainID(request.GetDomain()) + domainID, err := adh.GetDomainCache().GetDomainID(request.GetDomain()) if err != nil { return nil, adh.error(err, scope) } - domainScope := adh.metricsClient.Scope(scope, metrics.DomainTag(request.GetDomain())) + domainScope := adh.GetMetricsClient().Scope(scope, metrics.DomainTag(request.GetDomain())) execution := request.Execution if len(execution.GetWorkflowId()) == 0 { @@ -348,7 +316,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistory( return nil, &gen.BadRequestError{Message: "Invalid FirstEventID && NextEventID combination."} } - response, err := adh.history.GetMutableState(ctx, &h.GetMutableStateRequest{ + response, err := adh.GetHistoryClient().GetMutableState(ctx, &h.GetMutableStateRequest{ DomainUUID: common.StringPtr(domainID), Execution: execution, }) @@ -383,7 +351,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistory( var historyBatches []*gen.History shardID := common.WorkflowIDToHistoryShard(execution.GetWorkflowId(), adh.numberOfHistoryShards) _, historyBatches, token.PersistenceToken, size, err = historyService.PaginateHistory( - adh.historyV2Mgr, + adh.GetHistoryManager(), true, // this means that we are getting history by batch token.BranchToken, token.FirstEventID, @@ -407,7 +375,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistory( // N.B. - Dual emit is required here so that we can see aggregate timer stats across all // domains along with the individual domains stats - adh.metricsClient.RecordTimer(scope, metrics.HistorySize, time.Duration(size)) + adh.GetMetricsClient().RecordTimer(scope, metrics.HistorySize, time.Duration(size)) domainScope.RecordTimer(metrics.HistorySize, time.Duration(size)) blobs := []*gen.DataBlob{} @@ -454,7 +422,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistoryV2( ); err != nil { return nil, adh.error(err, scope) } - domainID, err := adh.domainCache.GetDomainID(request.GetDomain()) + domainID, err := adh.GetDomainCache().GetDomainID(request.GetDomain()) if err != nil { return nil, adh.error(err, scope) } @@ -463,7 +431,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistoryV2( var pageToken *getWorkflowRawHistoryV2Token var targetVersionHistory *persistence.VersionHistory if request.NextPageToken == nil { - response, err := adh.history.GetMutableState(ctx, &h.GetMutableStateRequest{ + response, err := adh.GetHistoryClient().GetMutableState(ctx, &h.GetMutableStateRequest{ DomainUUID: common.StringPtr(domainID), Execution: execution, }) @@ -521,7 +489,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistoryV2( execution.GetWorkflowId(), adh.numberOfHistoryShards, ) - rawHistoryResponse, err := adh.historyV2Mgr.ReadRawHistoryBranch(&persistence.ReadHistoryBranchRequest{ + rawHistoryResponse, err := adh.GetHistoryManager().ReadRawHistoryBranch(&persistence.ReadHistoryBranchRequest{ BranchToken: targetVersionHistory.GetBranchToken(), // GetWorkflowExecutionRawHistoryV2 is exclusive exclusive. // ReadRawHistoryBranch is inclusive exclusive. @@ -548,8 +516,8 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistoryV2( size := rawHistoryResponse.Size // N.B. - Dual emit is required here so that we can see aggregate timer stats across all // domains along with the individual domains stats - adh.metricsClient.RecordTimer(scope, metrics.HistorySize, time.Duration(size)) - domainScope := adh.metricsClient.Scope(scope, metrics.DomainTag(request.GetDomain())) + adh.GetMetricsClient().RecordTimer(scope, metrics.HistorySize, time.Duration(size)) + domainScope := adh.GetMetricsClient().Scope(scope, metrics.DomainTag(request.GetDomain())) domainScope.RecordTimer(metrics.HistorySize, time.Duration(size)) rawBlobs := rawHistoryResponse.HistoryEventBlobs @@ -732,16 +700,15 @@ func (adh *AdminHandler) validatePaginationToken( // startRequestProfile initiates recording of request metrics func (adh *AdminHandler) startRequestProfile(scope int) tally.Stopwatch { - adh.startWG.Wait() - sw := adh.metricsClient.StartTimer(scope, metrics.CadenceLatency) - adh.metricsClient.IncCounter(scope, metrics.CadenceRequests) + sw := adh.GetMetricsClient().StartTimer(scope, metrics.CadenceLatency) + adh.GetMetricsClient().IncCounter(scope, metrics.CadenceRequests) return sw } func (adh *AdminHandler) error(err error, scope int) error { switch err.(type) { case *gen.InternalServiceError: - adh.Service.GetLogger().Error("Internal service error", tag.Error(err)) + adh.GetLogger().Error("Internal service error", tag.Error(err)) return err case *gen.BadRequestError: return err @@ -750,7 +717,7 @@ func (adh *AdminHandler) error(err error, scope int) error { case *gen.EntityNotExistsError: return err default: - adh.Service.GetLogger().Error("Uncategorized error", tag.Error(err)) + adh.GetLogger().Error("Uncategorized error", tag.Error(err)) return &gen.InternalServiceError{Message: err.Error()} } } diff --git a/service/frontend/adminHandler_test.go b/service/frontend/adminHandler_test.go index f113eb44001..17416c81e25 100644 --- a/service/frontend/adminHandler_test.go +++ b/service/frontend/adminHandler_test.go @@ -26,32 +26,28 @@ import ( "fmt" "testing" - "github.com/uber/cadence/common/definition" - "github.com/uber/cadence/common/elasticsearch" - "github.com/uber/cadence/common/service/dynamicconfig" - "github.com/golang/mock/gomock" "github.com/pborman/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/uber-go/tally" "github.com/uber/cadence/.gen/go/admin" "github.com/uber/cadence/.gen/go/history" "github.com/uber/cadence/.gen/go/history/historyservicetest" "github.com/uber/cadence/.gen/go/shared" - "github.com/uber/cadence/client" "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" - "github.com/uber/cadence/common/cluster" + "github.com/uber/cadence/common/definition" + "github.com/uber/cadence/common/elasticsearch" esmock "github.com/uber/cadence/common/elasticsearch/mocks" - "github.com/uber/cadence/common/log" - "github.com/uber/cadence/common/log/loggerimpl" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/mocks" "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/service" + "github.com/uber/cadence/common/service/config" + "github.com/uber/cadence/common/service/dynamicconfig" ) type ( @@ -60,18 +56,14 @@ type ( *require.Assertions controller *gomock.Controller + mockResource *resource.Test mockHistoryClient *historyservicetest.MockClient mockDomainCache *cache.MockDomainCache - mockClientBean *client.MockBean - logger log.Logger - domainName string - domainID string - currentClusterName string - alternativeClusterName string - mockClusterMetadata *mocks.ClusterMetadata - mockHistoryV2Mgr *mocks.HistoryV2Manager - service service.Service + mockHistoryV2Mgr *mocks.HistoryV2Manager + + domainName string + domainID string handler *AdminHandler } @@ -85,37 +77,30 @@ func TestAdminHandlerSuite(t *testing.T) { func (s *adminHandlerSuite) SetupTest() { s.Assertions = require.New(s.T()) - s.controller = gomock.NewController(s.T()) - s.mockHistoryClient = historyservicetest.NewMockClient(s.controller) - s.mockDomainCache = cache.NewMockDomainCache(s.controller) - s.mockClientBean = client.NewMockBean(s.controller) - s.mockDomainCache.EXPECT().Start().AnyTimes() - s.mockDomainCache.EXPECT().Stop().AnyTimes() - s.mockClientBean.EXPECT().GetHistoryClient().Return(s.mockHistoryClient).AnyTimes() - - s.logger = loggerimpl.NewDevelopmentForTest(s.Suite) s.domainName = "some random domain name" s.domainID = "some random domain ID" - s.currentClusterName = cluster.TestCurrentClusterName - s.alternativeClusterName = cluster.TestAlternativeClusterName - - s.mockClusterMetadata = &mocks.ClusterMetadata{} - s.mockClusterMetadata.On("GetCurrentClusterName").Return(s.currentClusterName) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(true) - metricsClient := metrics.NewClient(tally.NoopScope, metrics.Frontend) - s.service = service.NewTestService(s.mockClusterMetadata, nil, metricsClient, s.mockClientBean, nil, nil, nil) - s.mockHistoryV2Mgr = &mocks.HistoryV2Manager{} - params := &service.BootstrapParams{} + s.controller = gomock.NewController(s.T()) + s.mockResource = resource.NewTest(s.controller, metrics.Frontend) + s.mockDomainCache = s.mockResource.DomainCache + s.mockHistoryClient = s.mockResource.HistoryClient + s.mockHistoryV2Mgr = s.mockResource.HistoryMgr + + params := &service.BootstrapParams{ + PersistenceConfig: config.Persistence{ + NumHistoryShards: 1, + }, + } config := &Config{ EnableAdminProtection: dynamicconfig.GetBoolPropertyFn(false), } - s.handler = NewAdminHandler(s.service, 1, s.mockDomainCache, s.mockHistoryV2Mgr, params, config) + s.handler = NewAdminHandler(s.mockResource, params, config) s.handler.Start() } func (s *adminHandlerSuite) TearDownTest() { s.controller.Finish() + s.mockResource.Finish(s.T()) s.handler.Stop() } diff --git a/service/frontend/dcRedirectionHandler.go b/service/frontend/dcRedirectionHandler.go index 71611af6773..c0b1767c9be 100644 --- a/service/frontend/dcRedirectionHandler.go +++ b/service/frontend/dcRedirectionHandler.go @@ -31,11 +31,9 @@ import ( "github.com/uber/cadence/.gen/go/shared" "github.com/uber/cadence/client" "github.com/uber/cadence/common" - "github.com/uber/cadence/common/cache" - "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/metrics" - "github.com/uber/cadence/common/service" + "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/service/config" ) @@ -46,18 +44,15 @@ type ( // DCRedirectionHandlerImpl is simple wrapper over frontend service, doing redirection based on policy DCRedirectionHandlerImpl struct { + resource.Resource + currentClusterName string - timeSource clock.TimeSource - domainCache cache.DomainCache - metricsClient metrics.Client config *Config redirectionPolicy DCRedirectionPolicy tokenSerializer common.TaskTokenSerializer - service service.Service frontendHandler workflowserviceserver.Interface - clientBeanProvider clientBeanProvider - startFn func() error + startFn func() stopFn func() } ) @@ -67,35 +62,31 @@ func NewDCRedirectionHandler(wfHandler *WorkflowHandler, policy config.DCRedirec dcRedirectionPolicy := RedirectionPolicyGenerator( wfHandler.GetClusterMetadata(), wfHandler.config, - wfHandler.domainCache, + wfHandler.GetDomainCache(), policy, ) return &DCRedirectionHandlerImpl{ + Resource: wfHandler.Resource, currentClusterName: wfHandler.GetClusterMetadata().GetCurrentClusterName(), - timeSource: clock.NewRealTimeSource(), - domainCache: wfHandler.domainCache, - metricsClient: wfHandler.metricsClient, config: wfHandler.config, redirectionPolicy: dcRedirectionPolicy, tokenSerializer: common.NewJSONTaskTokenSerializer(), - service: wfHandler.Service, frontendHandler: wfHandler, - clientBeanProvider: func() client.Bean { return wfHandler.Service.GetClientBean() }, - startFn: func() error { return wfHandler.Start() }, + startFn: func() { wfHandler.Start() }, stopFn: func() { wfHandler.Stop() }, } } // RegisterHandler register this handler, must be called before Start() func (handler *DCRedirectionHandlerImpl) RegisterHandler() { - handler.service.GetDispatcher().Register(workflowserviceserver.New(handler)) - handler.service.GetDispatcher().Register(metaserver.New(handler)) + handler.GetDispatcher().Register(workflowserviceserver.New(handler)) + handler.GetDispatcher().Register(metaserver.New(handler)) } // Start starts the handler -func (handler *DCRedirectionHandlerImpl) Start() error { - return handler.startFn() +func (handler *DCRedirectionHandlerImpl) Start() { + handler.startFn() } // Stop stops the handler @@ -214,7 +205,7 @@ func (handler *DCRedirectionHandlerImpl) DescribeTaskList( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.DescribeTaskList(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.DescribeTaskList(ctx, request) } return err @@ -244,7 +235,7 @@ func (handler *DCRedirectionHandlerImpl) DescribeWorkflowExecution( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.DescribeWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.DescribeWorkflowExecution(ctx, request) } return err @@ -274,7 +265,7 @@ func (handler *DCRedirectionHandlerImpl) GetWorkflowExecutionHistory( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.GetWorkflowExecutionHistory(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.GetWorkflowExecutionHistory(ctx, request) } return err @@ -304,7 +295,7 @@ func (handler *DCRedirectionHandlerImpl) ListArchivedWorkflowExecutions( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ListArchivedWorkflowExecutions(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ListArchivedWorkflowExecutions(ctx, request) } return err @@ -334,7 +325,7 @@ func (handler *DCRedirectionHandlerImpl) ListClosedWorkflowExecutions( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ListClosedWorkflowExecutions(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ListClosedWorkflowExecutions(ctx, request) } return err @@ -364,7 +355,7 @@ func (handler *DCRedirectionHandlerImpl) ListOpenWorkflowExecutions( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ListOpenWorkflowExecutions(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ListOpenWorkflowExecutions(ctx, request) } return err @@ -394,7 +385,7 @@ func (handler *DCRedirectionHandlerImpl) ListWorkflowExecutions( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ListWorkflowExecutions(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ListWorkflowExecutions(ctx, request) } return err @@ -423,7 +414,7 @@ func (handler *DCRedirectionHandlerImpl) ScanWorkflowExecutions( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ScanWorkflowExecutions(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ScanWorkflowExecutions(ctx, request) } return err @@ -453,7 +444,7 @@ func (handler *DCRedirectionHandlerImpl) CountWorkflowExecutions( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.CountWorkflowExecutions(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.CountWorkflowExecutions(ctx, request) } return err @@ -498,7 +489,7 @@ func (handler *DCRedirectionHandlerImpl) PollForActivityTask( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.PollForActivityTask(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.PollForActivityTask(ctx, request) } return err @@ -528,7 +519,7 @@ func (handler *DCRedirectionHandlerImpl) PollForDecisionTask( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.PollForDecisionTask(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.PollForDecisionTask(ctx, request) } return err @@ -558,7 +549,7 @@ func (handler *DCRedirectionHandlerImpl) QueryWorkflow( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.QueryWorkflow(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.QueryWorkflow(ctx, request) } return err @@ -593,7 +584,7 @@ func (handler *DCRedirectionHandlerImpl) RecordActivityTaskHeartbeat( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.RecordActivityTaskHeartbeat(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.RecordActivityTaskHeartbeat(ctx, request) } return err @@ -623,7 +614,7 @@ func (handler *DCRedirectionHandlerImpl) RecordActivityTaskHeartbeatByID( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.RecordActivityTaskHeartbeatByID(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.RecordActivityTaskHeartbeatByID(ctx, request) } return err @@ -653,7 +644,7 @@ func (handler *DCRedirectionHandlerImpl) RequestCancelWorkflowExecution( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RequestCancelWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RequestCancelWorkflowExecution(ctx, request) } return err @@ -683,7 +674,7 @@ func (handler *DCRedirectionHandlerImpl) ResetStickyTaskList( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ResetStickyTaskList(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ResetStickyTaskList(ctx, request) } return err @@ -713,7 +704,7 @@ func (handler *DCRedirectionHandlerImpl) ResetWorkflowExecution( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.ResetWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.ResetWorkflowExecution(ctx, request) } return err @@ -748,7 +739,7 @@ func (handler *DCRedirectionHandlerImpl) RespondActivityTaskCanceled( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondActivityTaskCanceled(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondActivityTaskCanceled(ctx, request) } return err @@ -778,7 +769,7 @@ func (handler *DCRedirectionHandlerImpl) RespondActivityTaskCanceledByID( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondActivityTaskCanceledByID(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondActivityTaskCanceledByID(ctx, request) } return err @@ -813,7 +804,7 @@ func (handler *DCRedirectionHandlerImpl) RespondActivityTaskCompleted( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondActivityTaskCompleted(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondActivityTaskCompleted(ctx, request) } return err @@ -843,7 +834,7 @@ func (handler *DCRedirectionHandlerImpl) RespondActivityTaskCompletedByID( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondActivityTaskCompletedByID(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondActivityTaskCompletedByID(ctx, request) } return err @@ -878,7 +869,7 @@ func (handler *DCRedirectionHandlerImpl) RespondActivityTaskFailed( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondActivityTaskFailed(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondActivityTaskFailed(ctx, request) } return err @@ -908,7 +899,7 @@ func (handler *DCRedirectionHandlerImpl) RespondActivityTaskFailedByID( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondActivityTaskFailedByID(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondActivityTaskFailedByID(ctx, request) } return err @@ -943,7 +934,7 @@ func (handler *DCRedirectionHandlerImpl) RespondDecisionTaskCompleted( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.RespondDecisionTaskCompleted(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.RespondDecisionTaskCompleted(ctx, request) } return err @@ -978,7 +969,7 @@ func (handler *DCRedirectionHandlerImpl) RespondDecisionTaskFailed( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondDecisionTaskFailed(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondDecisionTaskFailed(ctx, request) } return err @@ -1013,7 +1004,7 @@ func (handler *DCRedirectionHandlerImpl) RespondQueryTaskCompleted( case targetDC == handler.currentClusterName: err = handler.frontendHandler.RespondQueryTaskCompleted(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.RespondQueryTaskCompleted(ctx, request) } return err @@ -1043,7 +1034,7 @@ func (handler *DCRedirectionHandlerImpl) SignalWithStartWorkflowExecution( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.SignalWithStartWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.SignalWithStartWorkflowExecution(ctx, request) } return err @@ -1073,7 +1064,7 @@ func (handler *DCRedirectionHandlerImpl) SignalWorkflowExecution( case targetDC == handler.currentClusterName: err = handler.frontendHandler.SignalWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.SignalWorkflowExecution(ctx, request) } return err @@ -1102,7 +1093,7 @@ func (handler *DCRedirectionHandlerImpl) StartWorkflowExecution( case targetDC == handler.currentClusterName: resp, err = handler.frontendHandler.StartWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) resp, err = remoteClient.StartWorkflowExecution(ctx, request) } return err @@ -1132,7 +1123,7 @@ func (handler *DCRedirectionHandlerImpl) TerminateWorkflowExecution( case targetDC == handler.currentClusterName: err = handler.frontendHandler.TerminateWorkflowExecution(ctx, request) default: - remoteClient := handler.clientBeanProvider().GetRemoteFrontendClient(targetDC) + remoteClient := handler.GetRemoteFrontendClient(targetDC) err = remoteClient.TerminateWorkflowExecution(ctx, request) } return err @@ -1169,7 +1160,7 @@ func (handler *DCRedirectionHandlerImpl) beforeCall( scope int, ) (metrics.Scope, time.Time) { - return handler.metricsClient.Scope(scope), handler.timeSource.Now() + return handler.GetMetricsClient().Scope(scope), handler.GetTimeSource().Now() } func (handler *DCRedirectionHandlerImpl) afterCall( @@ -1179,11 +1170,11 @@ func (handler *DCRedirectionHandlerImpl) afterCall( retError *error, ) { - log.CapturePanic(handler.service.GetLogger(), retError) + log.CapturePanic(handler.GetLogger(), retError) scope = scope.Tagged(metrics.TargetClusterTag(cluster)) scope.IncCounter(metrics.CadenceDcRedirectionClientRequests) - scope.RecordTimer(metrics.CadenceDcRedirectionClientLatency, handler.timeSource.Now().Sub(startTime)) + scope.RecordTimer(metrics.CadenceDcRedirectionClientLatency, handler.GetTimeSource().Now().Sub(startTime)) if *retError != nil { scope.IncCounter(metrics.CadenceDcRedirectionClientFailures) } diff --git a/service/frontend/dcRedirectionHandler_test.go b/service/frontend/dcRedirectionHandler_test.go index fa68ab27b53..1c98c4ca485 100644 --- a/service/frontend/dcRedirectionHandler_test.go +++ b/service/frontend/dcRedirectionHandler_test.go @@ -28,21 +28,13 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/uber-go/tally" "github.com/uber/cadence/.gen/go/cadence/workflowservicetest" "github.com/uber/cadence/.gen/go/shared" - "github.com/uber/cadence/client" "github.com/uber/cadence/common" - "github.com/uber/cadence/common/archiver" - "github.com/uber/cadence/common/archiver/provider" - "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/cluster" - "github.com/uber/cadence/common/log" - "github.com/uber/cadence/common/log/loggerimpl" "github.com/uber/cadence/common/metrics" - "github.com/uber/cadence/common/mocks" - "github.com/uber/cadence/common/service" + "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/service/config" "github.com/uber/cadence/common/service/dynamicconfig" ) @@ -53,23 +45,18 @@ type ( *require.Assertions controller *gomock.Controller - mockDomainCache *cache.MockDomainCache - mockClientBean *client.MockBean - mockRemoteFrontendClient *workflowservicetest.MockClient + mockResource *resource.Test mockFrontendHandler *MockWorkflowHandler + mockRemoteFrontendClient *workflowservicetest.MockClient + mockClusterMetadata *cluster.MockMetadata + + mockDCRedirectionPolicy *MockDCRedirectionPolicy - logger log.Logger domainName string domainID string currentClusterName string alternativeClusterName string config *Config - service service.Service - - mockDCRedirectionPolicy *MockDCRedirectionPolicy - mockClusterMetadata *mocks.ClusterMetadata - mockArchivalMetadata *archiver.MockArchivalMetadata - mockArchiverProvider *provider.MockArchiverProvider frontendHandler *WorkflowHandler handler *DCRedirectionHandlerImpl @@ -90,45 +77,34 @@ func (s *dcRedirectionHandlerSuite) TearDownSuite() { func (s *dcRedirectionHandlerSuite) SetupTest() { s.Assertions = require.New(s.T()) + s.domainName = "some random domain name" + s.domainID = "some random domain ID" s.currentClusterName = cluster.TestCurrentClusterName s.alternativeClusterName = cluster.TestAlternativeClusterName - s.controller = gomock.NewController(s.T()) - s.mockDomainCache = cache.NewMockDomainCache(s.controller) - s.mockClientBean = client.NewMockBean(s.controller) - s.mockRemoteFrontendClient = workflowservicetest.NewMockClient(s.controller) - s.mockFrontendHandler = NewMockWorkflowHandler(s.controller) - s.mockClientBean.EXPECT().GetRemoteFrontendClient(s.alternativeClusterName).Return(s.mockRemoteFrontendClient).AnyTimes() - - s.logger = loggerimpl.NewDevelopmentForTest(s.Suite) - s.domainName = "some random domain name" - s.domainID = "some random domain ID" - s.config = NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewNopClient(), s.logger), 0, false) + s.mockDCRedirectionPolicy = &MockDCRedirectionPolicy{} - s.mockClusterMetadata = &mocks.ClusterMetadata{} - s.mockClusterMetadata.On("GetCurrentClusterName").Return(s.currentClusterName) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(true) - metricsClient := metrics.NewClient(tally.NoopScope, metrics.Frontend) + s.controller = gomock.NewController(s.T()) + s.mockResource = resource.NewTest(s.controller, metrics.Frontend) + s.mockClusterMetadata = s.mockResource.ClusterMetadata + s.mockRemoteFrontendClient = s.mockResource.RemoteFrontendClient - s.mockArchivalMetadata = &archiver.MockArchivalMetadata{} - s.mockArchiverProvider = &provider.MockArchiverProvider{} - s.service = service.NewTestService(s.mockClusterMetadata, nil, metricsClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(s.currentClusterName).AnyTimes() + s.mockClusterMetadata.EXPECT().IsGlobalDomainEnabled().Return(true).AnyTimes() - frontendHandler := NewWorkflowHandler(s.service, s.config, nil, nil, nil, nil, nil, s.mockDomainCache) - frontendHandler.metricsClient = metricsClient - frontendHandler.startWG.Done() + s.config = NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewNopClient(), s.mockResource.GetLogger()), 0, false) + frontendHandler := NewWorkflowHandler(s.mockResource, s.config, nil) + s.mockFrontendHandler = NewMockWorkflowHandler(s.controller) s.handler = NewDCRedirectionHandler(frontendHandler, config.DCRedirectionPolicy{}) - s.mockDCRedirectionPolicy = &MockDCRedirectionPolicy{} s.handler.frontendHandler = s.mockFrontendHandler s.handler.redirectionPolicy = s.mockDCRedirectionPolicy } func (s *dcRedirectionHandlerSuite) TearDownTest() { - s.mockDCRedirectionPolicy.AssertExpectations(s.T()) - s.mockArchivalMetadata.AssertExpectations(s.T()) - s.mockArchiverProvider.AssertExpectations(s.T()) s.controller.Finish() + s.mockResource.Finish(s.T()) + s.mockDCRedirectionPolicy.AssertExpectations(s.T()) } func (s *dcRedirectionHandlerSuite) TestDescribeTaskList() { diff --git a/service/frontend/service.go b/service/frontend/service.go index 3ff80ea9d9d..7b4f8bbeb0d 100644 --- a/service/frontend/service.go +++ b/service/frontend/service.go @@ -21,18 +21,19 @@ package frontend import ( + mock "github.com/stretchr/testify/mock" + "github.com/uber/cadence/common" - "github.com/uber/cadence/common/archiver" - "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/definition" "github.com/uber/cadence/common/domain" - "github.com/uber/cadence/common/log/loggerimpl" + "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/messaging" "github.com/uber/cadence/common/mocks" "github.com/uber/cadence/common/persistence" - "github.com/uber/cadence/common/persistence/client" + persistenceClient "github.com/uber/cadence/common/persistence/client" espersistence "github.com/uber/cadence/common/persistence/elasticsearch" + "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/service" "github.com/uber/cadence/common/service/config" "github.com/uber/cadence/common/service/dynamicconfig" @@ -124,153 +125,117 @@ func NewConfig(dc *dynamicconfig.Collection, numHistoryShards int, enableReadFro // Service represents the cadence-frontend service type Service struct { + resource.Resource + stopC chan struct{} config *Config params *service.BootstrapParams } // NewService builds a new cadence-frontend service -func NewService(params *service.BootstrapParams) common.Daemon { - isAdvancedVisExistInConfig := len(params.PersistenceConfig.AdvancedVisibilityStore) != 0 - config := NewConfig(dynamicconfig.NewCollection(params.DynamicConfig, params.Logger), params.PersistenceConfig.NumHistoryShards, isAdvancedVisExistInConfig) - params.ThrottledLogger = loggerimpl.NewThrottledLogger(params.Logger, config.ThrottledLogRPS) - params.UpdateLoggerWithServiceName(common.FrontendServiceName) - return &Service{ - params: params, - config: config, - stopC: make(chan struct{}), - } -} - -// Start starts the service -func (s *Service) Start() { - - var params = s.params - var log = params.Logger +func NewService( + params *service.BootstrapParams, +) (resource.Resource, error) { - log.Info("starting", tag.Service(common.FrontendServiceName)) - - base := service.New(params) - - pConfig := params.PersistenceConfig - pConfig.HistoryMaxConns = s.config.HistoryMgrNumConns() - pConfig.SetMaxQPS(pConfig.DefaultStore, s.config.PersistenceMaxQPS()) - pConfig.VisibilityConfig = &config.VisibilityConfig{ - VisibilityListMaxQPS: s.config.VisibilityListMaxQPS, - EnableSampling: s.config.EnableVisibilitySampling, - EnableReadFromClosedExecutionV2: s.config.EnableReadFromClosedExecutionV2, - } - pFactory := client.NewFactory(&pConfig, params.ClusterMetadata.GetCurrentClusterName(), base.GetMetricsClient(), log) - - metadata, err := pFactory.NewMetadataManager() - if err != nil { - log.Fatal("failed to create metadata manager", tag.Error(err)) - } - - visibilityFromDB, err := pFactory.NewVisibilityManager() - if err != nil { - log.Fatal("failed to create visibility manager", tag.Error(err)) + isAdvancedVisExistInConfig := len(params.PersistenceConfig.AdvancedVisibilityStore) != 0 + serviceConfig := NewConfig(dynamicconfig.NewCollection(params.DynamicConfig, params.Logger), params.PersistenceConfig.NumHistoryShards, isAdvancedVisExistInConfig) + + params.PersistenceConfig.HistoryMaxConns = serviceConfig.HistoryMgrNumConns() + params.PersistenceConfig.SetMaxQPS(params.PersistenceConfig.DefaultStore, serviceConfig.PersistenceMaxQPS()) + params.PersistenceConfig.VisibilityConfig = &config.VisibilityConfig{ + VisibilityListMaxQPS: serviceConfig.VisibilityListMaxQPS, + EnableSampling: serviceConfig.EnableVisibilitySampling, + EnableReadFromClosedExecutionV2: serviceConfig.EnableReadFromClosedExecutionV2, } - var visibilityFromES persistence.VisibilityManager - if params.ESConfig != nil { - visibilityIndexName := params.ESConfig.Indices[common.VisibilityAppName] - visibilityConfigForES := &config.VisibilityConfig{ - MaxQPS: s.config.PersistenceMaxQPS, - VisibilityListMaxQPS: s.config.ESVisibilityListMaxQPS, - ESIndexMaxResultWindow: s.config.ESIndexMaxResultWindow, - ValidSearchAttributes: s.config.ValidSearchAttributes, + visibilityManagerInitializer := func( + persistenceBean persistenceClient.Bean, + logger log.Logger, + ) (persistence.VisibilityManager, error) { + visibilityFromDB := persistenceBean.GetVisibilityManager() + + var visibilityFromES persistence.VisibilityManager + if params.ESConfig != nil { + visibilityIndexName := params.ESConfig.Indices[common.VisibilityAppName] + visibilityConfigForES := &config.VisibilityConfig{ + MaxQPS: serviceConfig.PersistenceMaxQPS, + VisibilityListMaxQPS: serviceConfig.ESVisibilityListMaxQPS, + ESIndexMaxResultWindow: serviceConfig.ESIndexMaxResultWindow, + ValidSearchAttributes: serviceConfig.ValidSearchAttributes, + } + visibilityFromES = espersistence.NewESVisibilityManager(visibilityIndexName, params.ESClient, visibilityConfigForES, + nil, params.MetricsClient, logger) } - visibilityFromES = espersistence.NewESVisibilityManager(visibilityIndexName, params.ESClient, visibilityConfigForES, - nil, base.GetMetricsClient(), log) + return persistence.NewVisibilityManagerWrapper( + visibilityFromDB, + visibilityFromES, + serviceConfig.EnableReadVisibilityFromES, + dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), // frontend visibility never write + ), nil } - visibility := persistence.NewVisibilityManagerWrapper( - visibilityFromDB, - visibilityFromES, - s.config.EnableReadVisibilityFromES, - dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), // frontend visibility never write - ) - historyV2, err := pFactory.NewHistoryManager() + serviceResource, err := resource.New( + params, + common.FrontendServiceName, + serviceConfig.ThrottledLogRPS, + visibilityManagerInitializer, + ) if err != nil { - log.Fatal("Creating historyV2 manager persistence failed", tag.Error(err)) + return nil, err } - domainCache := cache.NewDomainCache(metadata, base.GetClusterMetadata(), base.GetMetricsClient(), base.GetLogger()) + return &Service{ + Resource: serviceResource, + config: serviceConfig, + stopC: make(chan struct{}), + params: params, + }, nil +} - historyArchiverBootstrapContainer := &archiver.HistoryBootstrapContainer{ - HistoryV2Manager: historyV2, - Logger: base.GetLogger(), - MetricsClient: base.GetMetricsClient(), - ClusterMetadata: base.GetClusterMetadata(), - DomainCache: domainCache, - } - visibilityArchiverBootstrapContainer := &archiver.VisibilityBootstrapContainer{ - Logger: base.GetLogger(), - MetricsClient: base.GetMetricsClient(), - ClusterMetadata: base.GetClusterMetadata(), - DomainCache: domainCache, - } - err = params.ArchiverProvider.RegisterBootstrapContainer(common.FrontendServiceName, historyArchiverBootstrapContainer, visibilityArchiverBootstrapContainer) - if err != nil { - log.Fatal("Failed to register archiver bootstrap container", tag.Error(err)) - } +// Start starts the service +func (s *Service) Start() { + + logger := s.GetLogger() + logger.Info("frontend starting", tag.Service(common.FrontendServiceName)) var replicationMessageSink messaging.Producer - var domainReplicationQueue persistence.DomainReplicationQueue - clusterMetadata := base.GetClusterMetadata() + clusterMetadata := s.GetClusterMetadata() if clusterMetadata.IsGlobalDomainEnabled() { consumerConfig := clusterMetadata.GetReplicationConsumerConfig() if consumerConfig != nil && consumerConfig.Type == config.ReplicationConsumerTypeRPC { - domainReplicationQueue, err = pFactory.NewDomainReplicationQueue() - if err != nil { - log.Fatal("Failed to create domain replication queue", tag.Error(err)) - } - replicationMessageSink = domainReplicationQueue + replicationMessageSink = s.GetDomainReplicationQueue() } else { - replicationMessageSink, err = base.GetMessagingClient().NewProducerWithClusterName( - base.GetClusterMetadata().GetCurrentClusterName()) + var err error + replicationMessageSink, err = s.GetMessagingClient().NewProducerWithClusterName( + s.GetClusterMetadata().GetCurrentClusterName()) if err != nil { - log.Fatal("Creating replicationMessageSink producer failed", tag.Error(err)) + logger.Fatal("Creating replicationMessageSink producer failed", tag.Error(err)) } } } else { replicationMessageSink = &mocks.KafkaProducer{} + replicationMessageSink.(*mocks.KafkaProducer).On("Publish", mock.Anything).Return(nil) } - wfHandler := NewWorkflowHandler( - base, - s.config, - metadata, - historyV2, - visibility, - replicationMessageSink, - domainReplicationQueue, - domainCache) - dcRedirectionHandler := NewDCRedirectionHandler(wfHandler, params.DCRedirectionPolicy) + wfHandler := NewWorkflowHandler(s, s.config, replicationMessageSink) + dcRedirectionHandler := NewDCRedirectionHandler(wfHandler, s.params.DCRedirectionPolicy) dcRedirectionHandler.RegisterHandler() - adminHandler := NewAdminHandler(base, pConfig.NumHistoryShards, domainCache, historyV2, s.params, s.config) + adminHandler := NewAdminHandler(s, s.params, s.config) adminHandler.RegisterHandler() - // must start base service first - base.Start() - err = dcRedirectionHandler.Start() - if err != nil { - log.Fatal("DC redirection handler failed to start", tag.Error(err)) - } - err = adminHandler.Start() - if err != nil { - log.Fatal("Admin handler failed to start", tag.Error(err)) - } + // must start resource first + s.Resource.Start() + dcRedirectionHandler.Start() + adminHandler.Start() // base (service is not started in frontend or admin handler) in case of race condition in yarpc registration function - log.Info("started", tag.Service(common.FrontendServiceName)) + logger.Info("started", tag.Service(common.FrontendServiceName)) <-s.stopC - base.Stop() + s.Resource.Stop() } // Stop stops the service diff --git a/service/frontend/workflowHandler.go b/service/frontend/workflowHandler.go index a9ae3cbc683..f9d90811b23 100644 --- a/service/frontend/workflowHandler.go +++ b/service/frontend/workflowHandler.go @@ -27,7 +27,6 @@ import ( "encoding/json" "errors" "fmt" - "sync" "time" "github.com/pborman/uuid" @@ -41,8 +40,6 @@ import ( m "github.com/uber/cadence/.gen/go/matching" "github.com/uber/cadence/.gen/go/replicator" gen "github.com/uber/cadence/.gen/go/shared" - "github.com/uber/cadence/client/history" - "github.com/uber/cadence/client/matching" "github.com/uber/cadence/common" "github.com/uber/cadence/common/archiver" "github.com/uber/cadence/common/backoff" @@ -56,7 +53,7 @@ import ( "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/quotas" - "github.com/uber/cadence/common/service" + "github.com/uber/cadence/common/resource" ) const ( @@ -69,24 +66,15 @@ var _ workflowserviceserver.Interface = (*WorkflowHandler)(nil) type ( // WorkflowHandler - Thrift handler interface for workflow service WorkflowHandler struct { - domainCache cache.DomainCache - metadataMgr persistence.MetadataManager - historyV2Mgr persistence.HistoryManager - visibilityMgr persistence.VisibilityManager - history history.Client - matching matching.Client - matchingRawClient matching.Client + resource.Resource + tokenSerializer common.TaskTokenSerializer - metricsClient metrics.Client - startWG sync.WaitGroup rateLimiter quotas.Policy config *Config versionChecker client.VersionChecker domainHandler domain.Handler visibilityQueryValidator *validator.VisibilityQueryValidator searchAttributesValidator *validator.SearchAttributesValidator - domainReplicationQueue persistence.DomainReplicationQueue - service.Service } getHistoryContinuationToken struct { @@ -149,24 +137,14 @@ var ( // NewWorkflowHandler creates a thrift handler for the cadence service func NewWorkflowHandler( - sVice service.Service, + resource resource.Resource, config *Config, - metadataMgr persistence.MetadataManager, - historyV2Mgr persistence.HistoryManager, - visibilityMgr persistence.VisibilityManager, replicationMessageSink messaging.Producer, - domainReplicationQueue persistence.DomainReplicationQueue, - domainCache cache.DomainCache, ) *WorkflowHandler { - handler := &WorkflowHandler{ - Service: sVice, + return &WorkflowHandler{ + Resource: resource, config: config, - metadataMgr: metadataMgr, - historyV2Mgr: historyV2Mgr, - visibilityMgr: visibilityMgr, tokenSerializer: common.NewJSONTaskTokenSerializer(), - metricsClient: sVice.GetMetricsClient(), - domainCache: domainCache, rateLimiter: quotas.NewMultiStageRateLimiter( func() float64 { return float64(config.RPS()) @@ -179,63 +157,41 @@ func NewWorkflowHandler( domainHandler: domain.NewHandler( config.MinRetentionDays(), config.MaxBadBinaries, - sVice.GetLogger(), - metadataMgr, - sVice.GetClusterMetadata(), - domain.NewDomainReplicator(replicationMessageSink, sVice.GetLogger()), - sVice.GetArchivalMetadata(), - sVice.GetArchiverProvider(), + resource.GetLogger(), + resource.GetMetadataManager(), + resource.GetClusterMetadata(), + domain.NewDomainReplicator(replicationMessageSink, resource.GetLogger()), + resource.GetArchivalMetadata(), + resource.GetArchiverProvider(), ), visibilityQueryValidator: validator.NewQueryValidator(config.ValidSearchAttributes), searchAttributesValidator: validator.NewSearchAttributesValidator( - sVice.GetLogger(), + resource.GetLogger(), config.ValidSearchAttributes, config.SearchAttributesNumberOfKeysLimit, config.SearchAttributesSizeOfValueLimit, config.SearchAttributesTotalSizeLimit, ), - domainReplicationQueue: domainReplicationQueue, } - // prevent us from trying to serve requests before handler's Start() is complete - handler.startWG.Add(1) - return handler } // RegisterHandler register this handler, must be called before Start() // if DCRedirectionHandler is also used, use RegisterHandler in DCRedirectionHandler instead func (wh *WorkflowHandler) RegisterHandler() { - wh.Service.GetDispatcher().Register(workflowserviceserver.New(wh)) - wh.Service.GetDispatcher().Register(metaserver.New(wh)) + wh.GetDispatcher().Register(workflowserviceserver.New(wh)) + wh.GetDispatcher().Register(metaserver.New(wh)) } // Start starts the handler -func (wh *WorkflowHandler) Start() error { - wh.domainCache.Start() - - wh.history = wh.GetClientBean().GetHistoryClient() - matchingRawClient, err := wh.GetClientBean().GetMatchingClient(wh.domainCache.GetDomainName) - if err != nil { - return err - } - wh.matchingRawClient = matchingRawClient - wh.matching = matching.NewRetryableClient(wh.matchingRawClient, common.CreateMatchingServiceRetryPolicy(), - common.IsWhitelistServiceTransientError) - wh.startWG.Done() - return nil +func (wh *WorkflowHandler) Start() { } // Stop stops the handler func (wh *WorkflowHandler) Stop() { - wh.domainReplicationQueue.Close() - wh.domainCache.Stop() - wh.metadataMgr.Close() - wh.visibilityMgr.Close() - wh.Service.Stop() } // Health is for health check func (wh *WorkflowHandler) Health(ctx context.Context) (*health.HealthStatus, error) { - wh.startWG.Wait() wh.GetLogger().Debug("Frontend health check endpoint reached.") hs := &health.HealthStatus{Ok: true, Msg: common.StringPtr("frontend good")} return hs, nil @@ -420,11 +376,11 @@ func (wh *WorkflowHandler) PollForActivityTask( return nil, wh.error(errRequestNotSet, scope) } - wh.Service.GetLogger().Debug("Received PollForActivityTask") + wh.GetLogger().Debug("Received PollForActivityTask") if err := common.ValidateLongPollContextTimeout( ctx, "PollForActivityTask", - wh.Service.GetThrottledLogger(), + wh.GetThrottledLogger(), ); err != nil { return nil, wh.error(err, scope) } @@ -444,7 +400,7 @@ func (wh *WorkflowHandler) PollForActivityTask( return nil, wh.error(errIdentityTooLong, scope) } - domainID, err := wh.domainCache.GetDomainID(pollRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(pollRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -452,7 +408,7 @@ func (wh *WorkflowHandler) PollForActivityTask( pollerID := uuid.New() op := func() error { var err error - resp, err = wh.matching.PollForActivityTask(ctx, &m.PollForActivityTaskRequest{ + resp, err = wh.GetMatchingClient().PollForActivityTask(ctx, &m.PollForActivityTaskRequest{ DomainUUID: common.StringPtr(domainID), PollerID: common.StringPtr(pollerID), PollRequest: pollRequest, @@ -470,7 +426,7 @@ func (wh *WorkflowHandler) PollForActivityTask( if ok { ctxTimeout = ctxDeadline.Sub(callTime).String() } - wh.Service.GetLogger().Error("PollForActivityTask failed.", + wh.GetLogger().Error("PollForActivityTask failed.", tag.WorkflowTaskListName(pollRequest.GetTaskList().GetName()), tag.Value(ctxTimeout), tag.Error(err)) @@ -500,11 +456,11 @@ func (wh *WorkflowHandler) PollForDecisionTask( return nil, wh.error(errRequestNotSet, scope) } - wh.Service.GetLogger().Debug("Received PollForDecisionTask") + wh.GetLogger().Debug("Received PollForDecisionTask") if err := common.ValidateLongPollContextTimeout( ctx, "PollForDecisionTask", - wh.Service.GetThrottledLogger(), + wh.GetThrottledLogger(), ); err != nil { return nil, wh.error(err, scope) } @@ -525,13 +481,13 @@ func (wh *WorkflowHandler) PollForDecisionTask( } domainName := pollRequest.GetDomain() - domainEntry, err := wh.domainCache.GetDomain(domainName) + domainEntry, err := wh.GetDomainCache().GetDomain(domainName) if err != nil { return nil, wh.error(err, scope) } domainID := domainEntry.GetInfo().ID - wh.Service.GetLogger().Debug("Poll for decision.", tag.WorkflowDomainName(domainName), tag.WorkflowDomainID(domainID)) + wh.GetLogger().Debug("Poll for decision.", tag.WorkflowDomainName(domainName), tag.WorkflowDomainID(domainID)) if err := wh.checkBadBinary(domainEntry, pollRequest.GetBinaryChecksum()); err != nil { return nil, wh.error(err, scope) } @@ -540,7 +496,7 @@ func (wh *WorkflowHandler) PollForDecisionTask( var matchingResp *m.PollForDecisionTaskResponse op := func() error { var err error - matchingResp, err = wh.matching.PollForDecisionTask(ctx, &m.PollForDecisionTaskRequest{ + matchingResp, err = wh.GetMatchingClient().PollForDecisionTask(ctx, &m.PollForDecisionTaskRequest{ DomainUUID: common.StringPtr(domainID), PollerID: common.StringPtr(pollerID), PollRequest: pollRequest, @@ -558,7 +514,7 @@ func (wh *WorkflowHandler) PollForDecisionTask( if ok { ctxTimeout = ctxDeadline.Sub(callTime).String() } - wh.Service.GetLogger().Error("PollForDecisionTask failed.", + wh.GetLogger().Error("PollForDecisionTask failed.", tag.WorkflowTaskListName(pollRequest.GetTaskList().GetName()), tag.Value(ctxTimeout), tag.Error(err)) @@ -581,7 +537,7 @@ func (wh *WorkflowHandler) checkBadBinary(domainEntry *cache.DomainCacheEntry, b badBinaries := domainEntry.GetConfig().BadBinaries.Binaries _, ok := badBinaries[binaryChecksum] if ok { - wh.metricsClient.IncCounter(metrics.FrontendPollForDecisionTaskScope, metrics.CadenceErrBadBinaryCounter) + wh.GetMetricsClient().IncCounter(metrics.FrontendPollForDecisionTaskScope, metrics.CadenceErrBadBinaryCounter) return &gen.BadRequestError{ Message: fmt.Sprintf("binary %v already marked as bad deployment", binaryChecksum), } @@ -596,7 +552,7 @@ func (wh *WorkflowHandler) cancelOutstandingPoll(ctx context.Context, err error, if ctx.Err() == context.Canceled { // Our rpc stack does not propagates context cancellation to the other service. Lets make an explicit // call to matching to notify this poller is gone to prevent any tasks being dispatched to zombie pollers. - err = wh.matching.CancelOutstandingPoll(context.Background(), &m.CancelOutstandingPollRequest{ + err = wh.GetMatchingClient().CancelOutstandingPoll(context.Background(), &m.CancelOutstandingPollRequest{ DomainUUID: common.StringPtr(domainID), TaskListType: common.Int32Ptr(taskListType), TaskList: taskList, @@ -604,7 +560,7 @@ func (wh *WorkflowHandler) cancelOutstandingPoll(ctx context.Context, err error, }) // We can not do much if this call fails. Just log the error and move on if err != nil { - wh.Service.GetLogger().Warn("Failed to cancel outstanding poller.", + wh.GetLogger().Warn("Failed to cancel outstanding poller.", tag.WorkflowTaskListName(taskList.GetName()), tag.Error(err)) } @@ -635,7 +591,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat( // Count the request in the RPS, but we still accept it even if RPS is exceeded wh.allow(nil) - wh.Service.GetLogger().Debug("Received RecordActivityTaskHeartbeat") + wh.GetLogger().Debug("Received RecordActivityTaskHeartbeat") if heartbeatRequest.TaskToken == nil { return nil, wh.error(errTaskTokenNotSet, scope) } @@ -647,7 +603,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat( return nil, wh.error(errDomainNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return nil, wh.error(err, scope) } @@ -680,7 +636,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat( Details: heartbeatRequest.Details[0:sizeLimitError], Identity: heartbeatRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, }) @@ -689,7 +645,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat( } resp = &gen.RecordActivityTaskHeartbeatResponse{CancelRequested: common.BoolPtr(true)} } else { - resp, err = wh.history.RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{ + resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), HeartbeatRequest: heartbeatRequest, }) @@ -722,8 +678,8 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID( // Count the request in the RPS, but we still accept it even if RPS is exceeded wh.allow(nil) - wh.Service.GetLogger().Debug("Received RecordActivityTaskHeartbeatByID") - domainID, err := wh.domainCache.GetDomainID(heartbeatRequest.GetDomain()) + wh.GetLogger().Debug("Received RecordActivityTaskHeartbeatByID") + domainID, err := wh.GetDomainCache().GetDomainID(heartbeatRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -753,7 +709,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID( return nil, wh.error(err, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return nil, wh.error(err, scope) } @@ -781,7 +737,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID( Details: heartbeatRequest.Details[0:sizeLimitError], Identity: heartbeatRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, }) @@ -796,7 +752,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID( Identity: heartbeatRequest.Identity, } - resp, err = wh.history.RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{ + resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), HeartbeatRequest: req, }) @@ -838,7 +794,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted( return wh.error(errDomainNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -874,7 +830,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted( Details: completeRequest.Result[0:sizeLimitError], Identity: completeRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, }) @@ -882,7 +838,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted( return wh.error(err, scope) } } else { - err = wh.history.RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CompleteRequest: completeRequest, }) @@ -915,7 +871,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID( // Count the request in the RPS, but we still accept it even if RPS is exceeded wh.allow(nil) - domainID, err := wh.domainCache.GetDomainID(completeRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(completeRequest.GetDomain()) if err != nil { return wh.error(err, scope) } @@ -949,7 +905,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID( return wh.error(err, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -977,7 +933,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID( Details: completeRequest.Result[0:sizeLimitError], Identity: completeRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, }) @@ -991,7 +947,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID( Identity: completeRequest.Identity, } - err = wh.history.RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CompleteRequest: req, }) @@ -1033,7 +989,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed( return wh.error(errDomainNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -1068,7 +1024,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed( failedRequest.Details = failedRequest.Details[0:sizeLimitError] } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failedRequest, }) @@ -1099,7 +1055,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailedByID( // Count the request in the RPS, but we still accept it even if RPS is exceeded wh.allow(nil) - domainID, err := wh.domainCache.GetDomainID(failedRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(failedRequest.GetDomain()) if err != nil { return wh.error(err, scope) } @@ -1132,7 +1088,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailedByID( return wh.error(err, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -1165,7 +1121,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailedByID( Identity: failedRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: req, }) @@ -1205,7 +1161,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled( return wh.error(errDomainNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -1242,7 +1198,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled( Details: cancelRequest.Details[0:sizeLimitError], Identity: cancelRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, }) @@ -1250,7 +1206,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled( return wh.error(err, scope) } } else { - err = wh.history.RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{ + err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CancelRequest: cancelRequest, }) @@ -1283,7 +1239,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID( // Count the request in the RPS, but we still accept it even if RPS is exceeded wh.allow(nil) - domainID, err := wh.domainCache.GetDomainID(cancelRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(cancelRequest.GetDomain()) if err != nil { return wh.error(err, scope) } @@ -1316,7 +1272,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID( return wh.error(err, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -1344,7 +1300,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID( Details: cancelRequest.Details[0:sizeLimitError], Identity: cancelRequest.Identity, } - err = wh.history.RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ + err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, }) @@ -1358,7 +1314,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID( Identity: cancelRequest.Identity, } - err = wh.history.RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{ + err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CancelRequest: req, }) @@ -1400,7 +1356,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskCompleted( return nil, wh.error(errDomainNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return nil, wh.error(err, scope) } @@ -1413,7 +1369,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskCompleted( ) defer sw.Stop() - histResp, err := wh.history.RespondDecisionTaskCompleted(ctx, &h.RespondDecisionTaskCompletedRequest{ + histResp, err := wh.GetHistoryClient().RespondDecisionTaskCompleted(ctx, &h.RespondDecisionTaskCompletedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CompleteRequest: completeRequest}, ) @@ -1481,7 +1437,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskFailed( return wh.error(errDomainNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(taskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(taskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -1515,7 +1471,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskFailed( failedRequest.Details = failedRequest.Details[0:sizeLimitError] } - err = wh.history.RespondDecisionTaskFailed(ctx, &h.RespondDecisionTaskFailedRequest{ + err = wh.GetHistoryClient().RespondDecisionTaskFailed(ctx, &h.RespondDecisionTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failedRequest, }) @@ -1555,7 +1511,7 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted( return wh.error(errInvalidTaskToken, scope) } - domainEntry, err := wh.domainCache.GetDomainByID(queryTaskToken.DomainID) + domainEntry, err := wh.GetDomainCache().GetDomainByID(queryTaskToken.DomainID) if err != nil { return wh.error(err, scope) } @@ -1602,7 +1558,7 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted( CompletedRequest: completeRequest, } - err = wh.matching.RespondQueryTaskCompleted(ctx, matchingRequest) + err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest) if err != nil { return wh.error(err, scope) } @@ -1656,7 +1612,7 @@ func (wh *WorkflowHandler) StartWorkflowExecution( return nil, wh.error(err, scope) } - wh.Service.GetLogger().Debug( + wh.GetLogger().Debug( "Received StartWorkflowExecution. WorkflowID", tag.WorkflowID(startRequest.GetWorkflowId())) @@ -1692,8 +1648,8 @@ func (wh *WorkflowHandler) StartWorkflowExecution( return nil, wh.error(err, scope) } - wh.Service.GetLogger().Debug("Start workflow execution request domain", tag.WorkflowDomainName(domainName)) - domainID, err := wh.domainCache.GetDomainID(domainName) + wh.GetLogger().Debug("Start workflow execution request domain", tag.WorkflowDomainName(domainName)) + domainID, err := wh.GetDomainCache().GetDomainID(domainName) if err != nil { return nil, wh.error(err, scope) } @@ -1720,8 +1676,8 @@ func (wh *WorkflowHandler) StartWorkflowExecution( return nil, wh.error(err, scope) } - wh.Service.GetLogger().Debug("Start workflow execution request domainID", tag.WorkflowDomainID(domainID)) - resp, err = wh.history.StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest)) + wh.GetLogger().Debug("Start workflow execution request domainID", tag.WorkflowDomainID(domainID)) + resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest)) if err != nil { return nil, wh.error(err, scope) @@ -1763,7 +1719,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory( getRequest.MaximumPageSize = common.Int32Ptr(int32(wh.config.HistoryMaxPageSize(getRequest.GetDomain()))) } - domainID, err := wh.domainCache.GetDomainID(getRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(getRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -1796,7 +1752,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory( expectedNextEventID int64, currentBranchToken []byte, ) ([]byte, string, int64, int64, bool, error) { - response, err := wh.history.PollMutableState(ctx, &h.PollMutableStateRequest{ + response, err := wh.GetHistoryClient().PollMutableState(ctx, &h.PollMutableStateRequest{ DomainUUID: common.StringPtr(domainUUID), Execution: execution, ExpectedNextEventId: common.Int64Ptr(expectedNextEventID), @@ -1990,7 +1946,7 @@ func (wh *WorkflowHandler) SignalWorkflowExecution( return wh.error(errRequestIDTooLong, scope) } - domainID, err := wh.domainCache.GetDomainID(signalRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(signalRequest.GetDomain()) if err != nil { return wh.error(err, scope) } @@ -2010,7 +1966,7 @@ func (wh *WorkflowHandler) SignalWorkflowExecution( return wh.error(err, scope) } - err = wh.history.SignalWorkflowExecution(ctx, &h.SignalWorkflowExecutionRequest{ + err = wh.GetHistoryClient().SignalWorkflowExecution(ctx, &h.SignalWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), SignalRequest: signalRequest, }) @@ -2110,7 +2066,7 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution( return nil, wh.error(err, scope) } - domainID, err := wh.domainCache.GetDomainID(domainName) + domainID, err := wh.GetDomainCache().GetDomainID(domainName) if err != nil { return nil, wh.error(err, scope) } @@ -2145,7 +2101,7 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution( op := func() error { var err error - resp, err = wh.history.SignalWithStartWorkflowExecution(ctx, &h.SignalWithStartWorkflowExecutionRequest{ + resp, err = wh.GetHistoryClient().SignalWithStartWorkflowExecution(ctx, &h.SignalWithStartWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), SignalWithStartRequest: signalWithStartRequest, }) @@ -2191,12 +2147,12 @@ func (wh *WorkflowHandler) TerminateWorkflowExecution( return err } - domainID, err := wh.domainCache.GetDomainID(terminateRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(terminateRequest.GetDomain()) if err != nil { return wh.error(err, scope) } - err = wh.history.TerminateWorkflowExecution(ctx, &h.TerminateWorkflowExecutionRequest{ + err = wh.GetHistoryClient().TerminateWorkflowExecution(ctx, &h.TerminateWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), TerminateRequest: terminateRequest, }) @@ -2238,12 +2194,12 @@ func (wh *WorkflowHandler) ResetWorkflowExecution( return nil, err } - domainID, err := wh.domainCache.GetDomainID(resetRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(resetRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } - resp, err = wh.history.ResetWorkflowExecution(ctx, &h.ResetWorkflowExecutionRequest{ + resp, err = wh.GetHistoryClient().ResetWorkflowExecution(ctx, &h.ResetWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), ResetRequest: resetRequest, }) @@ -2284,12 +2240,12 @@ func (wh *WorkflowHandler) RequestCancelWorkflowExecution( return err } - domainID, err := wh.domainCache.GetDomainID(cancelRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(cancelRequest.GetDomain()) if err != nil { return wh.error(err, scope) } - err = wh.history.RequestCancelWorkflowExecution(ctx, &h.RequestCancelWorkflowExecutionRequest{ + err = wh.GetHistoryClient().RequestCancelWorkflowExecution(ctx, &h.RequestCancelWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), CancelRequest: cancelRequest, }) @@ -2357,7 +2313,7 @@ func (wh *WorkflowHandler) ListOpenWorkflowExecutions( } domain := listRequest.GetDomain() - domainID, err := wh.domainCache.GetDomainID(domain) + domainID, err := wh.GetDomainCache().GetDomainID(domain) if err != nil { return nil, wh.error(err, scope) } @@ -2376,27 +2332,27 @@ func (wh *WorkflowHandler) ListOpenWorkflowExecutions( if wh.config.DisableListVisibilityByFilter(domain) { err = errNoPermission } else { - persistenceResp, err = wh.visibilityMgr.ListOpenWorkflowExecutionsByWorkflowID( + persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutionsByWorkflowID( &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ ListWorkflowExecutionsRequest: baseReq, WorkflowID: listRequest.ExecutionFilter.GetWorkflowId(), }) } - wh.Service.GetLogger().Info("List open workflow with filter", + wh.GetLogger().Info("List open workflow with filter", tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByID) } else if listRequest.TypeFilter != nil { if wh.config.DisableListVisibilityByFilter(domain) { err = errNoPermission } else { - persistenceResp, err = wh.visibilityMgr.ListOpenWorkflowExecutionsByType(&persistence.ListWorkflowExecutionsByTypeRequest{ + persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutionsByType(&persistence.ListWorkflowExecutionsByTypeRequest{ ListWorkflowExecutionsRequest: baseReq, WorkflowTypeName: listRequest.TypeFilter.GetName(), }) } - wh.Service.GetLogger().Info("List open workflow with filter", + wh.GetLogger().Info("List open workflow with filter", tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByType) } else { - persistenceResp, err = wh.visibilityMgr.ListOpenWorkflowExecutions(&baseReq) + persistenceResp, err = wh.GetVisibilityManager().ListOpenWorkflowExecutions(&baseReq) } if err != nil { @@ -2453,7 +2409,7 @@ func (wh *WorkflowHandler) ListArchivedWorkflowExecutions( return nil, wh.error(&gen.BadRequestError{Message: "Cluster is not configured for reading archived visibility records"}, scope) } - entry, err := wh.domainCache.GetDomain(listRequest.GetDomain()) + entry, err := wh.GetDomainCache().GetDomain(listRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -2562,7 +2518,7 @@ func (wh *WorkflowHandler) ListClosedWorkflowExecutions( } domain := listRequest.GetDomain() - domainID, err := wh.domainCache.GetDomainID(domain) + domainID, err := wh.GetDomainCache().GetDomainID(domain) if err != nil { return nil, wh.error(err, scope) } @@ -2581,38 +2537,38 @@ func (wh *WorkflowHandler) ListClosedWorkflowExecutions( if wh.config.DisableListVisibilityByFilter(domain) { err = errNoPermission } else { - persistenceResp, err = wh.visibilityMgr.ListClosedWorkflowExecutionsByWorkflowID( + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByWorkflowID( &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ ListWorkflowExecutionsRequest: baseReq, WorkflowID: listRequest.ExecutionFilter.GetWorkflowId(), }) } - wh.Service.GetLogger().Info("List closed workflow with filter", + wh.GetLogger().Info("List closed workflow with filter", tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByID) } else if listRequest.TypeFilter != nil { if wh.config.DisableListVisibilityByFilter(domain) { err = errNoPermission } else { - persistenceResp, err = wh.visibilityMgr.ListClosedWorkflowExecutionsByType(&persistence.ListWorkflowExecutionsByTypeRequest{ + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByType(&persistence.ListWorkflowExecutionsByTypeRequest{ ListWorkflowExecutionsRequest: baseReq, WorkflowTypeName: listRequest.TypeFilter.GetName(), }) } - wh.Service.GetLogger().Info("List closed workflow with filter", + wh.GetLogger().Info("List closed workflow with filter", tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByType) } else if listRequest.StatusFilter != nil { if wh.config.DisableListVisibilityByFilter(domain) { err = errNoPermission } else { - persistenceResp, err = wh.visibilityMgr.ListClosedWorkflowExecutionsByStatus(&persistence.ListClosedWorkflowExecutionsByStatusRequest{ + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutionsByStatus(&persistence.ListClosedWorkflowExecutionsByStatusRequest{ ListWorkflowExecutionsRequest: baseReq, Status: listRequest.GetStatusFilter(), }) } - wh.Service.GetLogger().Info("List closed workflow with filter", + wh.GetLogger().Info("List closed workflow with filter", tag.WorkflowDomainName(listRequest.GetDomain()), tag.WorkflowListWorkflowFilterByStatus) } else { - persistenceResp, err = wh.visibilityMgr.ListClosedWorkflowExecutions(&baseReq) + persistenceResp, err = wh.GetVisibilityManager().ListClosedWorkflowExecutions(&baseReq) } if err != nil { @@ -2665,7 +2621,7 @@ func (wh *WorkflowHandler) ListWorkflowExecutions( } domain := listRequest.GetDomain() - domainID, err := wh.domainCache.GetDomainID(domain) + domainID, err := wh.GetDomainCache().GetDomainID(domain) if err != nil { return nil, wh.error(err, scope) } @@ -2677,7 +2633,7 @@ func (wh *WorkflowHandler) ListWorkflowExecutions( NextPageToken: listRequest.NextPageToken, Query: listRequest.GetQuery(), } - persistenceResp, err := wh.visibilityMgr.ListWorkflowExecutions(req) + persistenceResp, err := wh.GetVisibilityManager().ListWorkflowExecutions(req) if err != nil { return nil, wh.error(err, scope) } @@ -2728,7 +2684,7 @@ func (wh *WorkflowHandler) ScanWorkflowExecutions( } domain := listRequest.GetDomain() - domainID, err := wh.domainCache.GetDomainID(domain) + domainID, err := wh.GetDomainCache().GetDomainID(domain) if err != nil { return nil, wh.error(err, scope) } @@ -2740,7 +2696,7 @@ func (wh *WorkflowHandler) ScanWorkflowExecutions( NextPageToken: listRequest.NextPageToken, Query: listRequest.GetQuery(), } - persistenceResp, err := wh.visibilityMgr.ScanWorkflowExecutions(req) + persistenceResp, err := wh.GetVisibilityManager().ScanWorkflowExecutions(req) if err != nil { return nil, wh.error(err, scope) } @@ -2782,7 +2738,7 @@ func (wh *WorkflowHandler) CountWorkflowExecutions( } domain := countRequest.GetDomain() - domainID, err := wh.domainCache.GetDomainID(domain) + domainID, err := wh.GetDomainCache().GetDomainID(domain) if err != nil { return nil, wh.error(err, scope) } @@ -2792,7 +2748,7 @@ func (wh *WorkflowHandler) CountWorkflowExecutions( Domain: domain, Query: countRequest.GetQuery(), } - persistenceResp, err := wh.visibilityMgr.CountWorkflowExecutions(req) + persistenceResp, err := wh.GetVisibilityManager().CountWorkflowExecutions(req) if err != nil { return nil, wh.error(err, scope) } @@ -2847,12 +2803,12 @@ func (wh *WorkflowHandler) ResetStickyTaskList( return nil, err } - domainID, err := wh.domainCache.GetDomainID(resetRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(resetRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } - _, err = wh.history.ResetStickyTaskList(ctx, &h.ResetStickyTaskListRequest{ + _, err = wh.GetHistoryClient().ResetStickyTaskList(ctx, &h.ResetStickyTaskListRequest{ DomainUUID: common.StringPtr(domainID), Execution: resetRequest.Execution, }) @@ -2899,7 +2855,7 @@ func (wh *WorkflowHandler) QueryWorkflow( return nil, wh.error(errQueryTypeNotSet, scope) } - domainID, err := wh.domainCache.GetDomainID(queryRequest.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(queryRequest.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -2923,7 +2879,7 @@ func (wh *WorkflowHandler) QueryWorkflow( DomainUUID: common.StringPtr(domainID), Request: queryRequest, } - hResponse, err := wh.history.QueryWorkflow(ctx, req) + hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req) if err != nil { return nil, wh.error(err, scope) } @@ -2955,7 +2911,7 @@ func (wh *WorkflowHandler) DescribeWorkflowExecution( if request.GetDomain() == "" { return nil, wh.error(errDomainNotSet, scope) } - domainID, err := wh.domainCache.GetDomainID(request.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(request.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -2964,7 +2920,7 @@ func (wh *WorkflowHandler) DescribeWorkflowExecution( return nil, err } - response, err := wh.history.DescribeWorkflowExecution(ctx, &h.DescribeWorkflowExecutionRequest{ + response, err := wh.GetHistoryClient().DescribeWorkflowExecution(ctx, &h.DescribeWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), Request: request, }) @@ -3003,7 +2959,7 @@ func (wh *WorkflowHandler) DescribeTaskList( if request.GetDomain() == "" { return nil, wh.error(errDomainNotSet, scope) } - domainID, err := wh.domainCache.GetDomainID(request.GetDomain()) + domainID, err := wh.GetDomainCache().GetDomainID(request.GetDomain()) if err != nil { return nil, wh.error(err, scope) } @@ -3019,7 +2975,7 @@ func (wh *WorkflowHandler) DescribeTaskList( var response *gen.DescribeTaskListResponse op := func() error { var err error - response, err = wh.matching.DescribeTaskList(ctx, &m.DescribeTaskListRequest{ + response, err = wh.GetMatchingClient().DescribeTaskList(ctx, &m.DescribeTaskListRequest{ DomainUUID: common.StringPtr(domainID), DescRequest: request, }) @@ -3050,7 +3006,7 @@ func (wh *WorkflowHandler) getHistory( shardID := common.WorkflowIDToHistoryShard(*execution.WorkflowId, wh.config.NumHistoryShards) var err error - historyEvents, size, nextPageToken, err = persistence.ReadFullPageV2Events(wh.historyV2Mgr, &persistence.ReadHistoryBranchRequest{ + historyEvents, size, nextPageToken, err = persistence.ReadFullPageV2Events(wh.GetHistoryManager(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -3075,7 +3031,7 @@ func (wh *WorkflowHandler) getHistory( } func (wh *WorkflowHandler) getLoggerForTask(taskToken []byte) log.Logger { - logger := wh.Service.GetLogger() + logger := wh.GetLogger() task, err := wh.tokenSerializer.Deserialize(taskToken) if err == nil { logger = logger.WithTags(tag.WorkflowID(task.WorkflowID), @@ -3087,9 +3043,7 @@ func (wh *WorkflowHandler) getLoggerForTask(taskToken []byte) log.Logger { // startRequestProfile initiates recording of request metrics func (wh *WorkflowHandler) startRequestProfile(scope int) (metrics.Scope, metrics.Stopwatch) { - wh.startWG.Wait() - - metricsScope := wh.metricsClient.Scope(scope).Tagged(metrics.DomainUnknownTag()) + metricsScope := wh.GetMetricsClient().Scope(scope).Tagged(metrics.DomainUnknownTag()) // timer should be emitted with the all tag sw := metricsScope.StartTimer(metrics.CadenceLatency) metricsScope.IncCounter(metrics.CadenceRequests) @@ -3098,13 +3052,11 @@ func (wh *WorkflowHandler) startRequestProfile(scope int) (metrics.Scope, metric // startRequestProfileWithDomain initiates recording of request metrics and returns a domain tagged scope func (wh *WorkflowHandler) startRequestProfileWithDomain(scope int, d domainGetter) (metrics.Scope, metrics.Stopwatch) { - wh.startWG.Wait() - var metricsScope metrics.Scope if d != nil { - metricsScope = wh.metricsClient.Scope(scope).Tagged(metrics.DomainTag(d.GetDomain())) + metricsScope = wh.GetMetricsClient().Scope(scope).Tagged(metrics.DomainTag(d.GetDomain())) } else { - metricsScope = wh.metricsClient.Scope(scope).Tagged(metrics.DomainUnknownTag()) + metricsScope = wh.GetMetricsClient().Scope(scope).Tagged(metrics.DomainUnknownTag()) } sw := metricsScope.StartTimer(metrics.CadenceLatency) metricsScope.IncCounter(metrics.CadenceRequests) @@ -3113,13 +3065,13 @@ func (wh *WorkflowHandler) startRequestProfileWithDomain(scope int, d domainGett // getDefaultScope returns a default scope to use for request metrics func (wh *WorkflowHandler) getDefaultScope(scope int) metrics.Scope { - return wh.metricsClient.Scope(scope).Tagged(metrics.DomainUnknownTag()) + return wh.GetMetricsClient().Scope(scope).Tagged(metrics.DomainUnknownTag()) } func (wh *WorkflowHandler) error(err error, scope metrics.Scope) error { switch err := err.(type) { case *gen.InternalServiceError: - wh.Service.GetLogger().Error("Internal service error", tag.Error(err)) + wh.GetLogger().Error("Internal service error", tag.Error(err)) scope.IncCounter(metrics.CadenceFailures) // NOTE: For internal error, we won't return thrift error from cadence-frontend. // Because in uber internal metrics, thrift errors are counted as user errors @@ -3161,7 +3113,7 @@ func (wh *WorkflowHandler) error(err error, scope metrics.Scope) error { } } - wh.Service.GetLogger().Error("Uncategorized error", + wh.GetLogger().Error("Uncategorized error", tag.Error(err)) scope.IncCounter(metrics.CadenceFailures) return fmt.Errorf("cadence internal uncategorized error, msg: %v", err.Error()) @@ -3243,7 +3195,7 @@ func (wh *WorkflowHandler) createPollForDecisionTaskResponse( if matchingResp.GetStickyExecutionEnabled() { firstEventID = matchingResp.GetPreviousStartedEventId() + 1 } - domain, dErr := wh.domainCache.GetDomainByID(domainID) + domain, dErr := wh.GetDomainCache().GetDomainByID(domainID) if dErr != nil { return nil, dErr } @@ -3331,7 +3283,7 @@ func (wh *WorkflowHandler) historyArchived(ctx context.Context, request *gen.Get DomainUUID: common.StringPtr(domainID), Execution: request.Execution, } - _, err := wh.history.GetMutableState(ctx, getMutableStateRequest) + _, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest) if err == nil { return false } @@ -3349,7 +3301,7 @@ func (wh *WorkflowHandler) getArchivedHistory( domainID string, scope metrics.Scope, ) (*gen.GetWorkflowExecutionHistoryResponse, error) { - entry, err := wh.domainCache.GetDomainByID(domainID) + entry, err := wh.GetDomainCache().GetDomainByID(domainID) if err != nil { return nil, wh.error(err, scope) } @@ -3433,7 +3385,7 @@ func (wh *WorkflowHandler) GetReplicationMessages( return nil, wh.error(errRequestNotSet, scope) } - resp, err = wh.history.GetReplicationMessages(ctx, request) + resp, err = wh.GetHistoryClient().GetReplicationMessages(ctx, request) if err != nil { return nil, wh.error(err, scope) } @@ -3458,7 +3410,7 @@ func (wh *WorkflowHandler) GetDomainReplicationMessages( return nil, wh.error(errRequestNotSet, scope) } - if wh.domainReplicationQueue == nil { + if wh.GetDomainReplicationQueue() == nil { return nil, wh.error(errors.New("domain replication queue not enabled for cluster"), scope) } @@ -3468,7 +3420,7 @@ func (wh *WorkflowHandler) GetDomainReplicationMessages( } if lastMessageID == defaultLastMessageID { - clusterAckLevels, err := wh.domainReplicationQueue.GetAckLevels() + clusterAckLevels, err := wh.GetDomainReplicationQueue().GetAckLevels() if err == nil { if ackLevel, ok := clusterAckLevels[request.GetClusterName()]; ok { lastMessageID = ackLevel @@ -3476,7 +3428,7 @@ func (wh *WorkflowHandler) GetDomainReplicationMessages( } } - replicationTasks, lastMessageID, err := wh.domainReplicationQueue.GetReplicationMessages( + replicationTasks, lastMessageID, err := wh.GetDomainReplicationQueue().GetReplicationMessages( lastMessageID, getDomainReplicationMessageBatchSize) if err != nil { return nil, wh.error(err, scope) @@ -3488,7 +3440,7 @@ func (wh *WorkflowHandler) GetDomainReplicationMessages( } if lastProcessedMessageID != defaultLastMessageID { - err := wh.domainReplicationQueue.UpdateAckLevel(lastProcessedMessageID, request.GetClusterName()) + err := wh.GetDomainReplicationQueue().UpdateAckLevel(lastProcessedMessageID, request.GetClusterName()) if err != nil { wh.GetLogger().Warn("Failed to update domain replication queue ack level.", tag.TaskID(int64(lastProcessedMessageID)), @@ -3533,12 +3485,12 @@ func (wh *WorkflowHandler) ReapplyEvents( if request.GetEvents() == nil { return wh.error(errWorkflowIDNotSet, scope) } - domainEntry, err := wh.domainCache.GetDomain(request.GetDomainName()) + domainEntry, err := wh.GetDomainCache().GetDomain(request.GetDomainName()) if err != nil { return wh.error(err, scope) } - err = wh.history.ReapplyEvents(ctx, &h.ReapplyEventsRequest{ + err = wh.GetHistoryClient().ReapplyEvents(ctx, &h.ReapplyEventsRequest{ DomainUUID: common.StringPtr(domainEntry.GetInfo().ID), Request: request, }) diff --git a/service/frontend/workflowHandler_test.go b/service/frontend/workflowHandler_test.go index fe442eae537..9af619eb150 100644 --- a/service/frontend/workflowHandler_test.go +++ b/service/frontend/workflowHandler_test.go @@ -23,36 +23,28 @@ package frontend import ( "context" "errors" - "fmt" "testing" "time" "github.com/golang/mock/gomock" "github.com/pborman/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/uber-go/tally" "github.com/uber/cadence/.gen/go/history/historyservicetest" "github.com/uber/cadence/.gen/go/shared" - gen "github.com/uber/cadence/.gen/go/shared" - workflow "github.com/uber/cadence/.gen/go/shared" - "github.com/uber/cadence/client" "github.com/uber/cadence/common" "github.com/uber/cadence/common/archiver" "github.com/uber/cadence/common/archiver/provider" "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/cluster" "github.com/uber/cadence/common/domain" - "github.com/uber/cadence/common/log" - "github.com/uber/cadence/common/log/loggerimpl" "github.com/uber/cadence/common/messaging" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/mocks" "github.com/uber/cadence/common/persistence" - cs "github.com/uber/cadence/common/service" + "github.com/uber/cadence/common/resource" dc "github.com/uber/cadence/common/service/dynamicconfig" ) @@ -70,25 +62,24 @@ type ( suite.Suite *require.Assertions - controller *gomock.Controller - mockDomainCache *cache.MockDomainCache - mockClientBean *client.MockBean + controller *gomock.Controller + mockResource *resource.Test + mockDomainCache *cache.MockDomainCache + mockHistoryClient *historyservicetest.MockClient + mockClusterMetadata *cluster.MockMetadata + + mockProducer *mocks.KafkaProducer + mockMessagingClient messaging.Client + mockMetadataMgr *mocks.MetadataManager + mockHistoryV2Mgr *mocks.HistoryV2Manager + mockVisibilityMgr *mocks.VisibilityManager + mockArchivalMetadata *archiver.MockArchivalMetadata + mockArchiverProvider *provider.MockArchiverProvider + mockHistoryArchiver *archiver.HistoryArchiverMock + mockVisibilityArchiver *archiver.VisibilityArchiverMock testDomain string testDomainID string - logger log.Logger - config *Config - - mockClusterMetadata *mocks.ClusterMetadata - mockProducer *mocks.KafkaProducer - mockMetricClient metrics.Client - mockMessagingClient messaging.Client - mockMetadataMgr *mocks.MetadataManager - mockHistoryV2Mgr *mocks.HistoryV2Manager - mockVisibilityMgr *mocks.VisibilityManager - mockService cs.Service - mockArchivalMetadata *archiver.MockArchivalMetadata - mockArchiverProvider *provider.MockArchiverProvider } ) @@ -101,68 +92,41 @@ func (s *workflowHandlerSuite) SetupSuite() { } func (s *workflowHandlerSuite) TearDownSuite() { - } func (s *workflowHandlerSuite) SetupTest() { s.Assertions = require.New(s.T()) - s.controller = gomock.NewController(s.T()) - s.mockDomainCache = cache.NewMockDomainCache(s.controller) - s.mockClientBean = client.NewMockBean(s.controller) - s.testDomain = "test-domain" s.testDomainID = "e4f90ec0-1313-45be-9877-8aa41f72a45a" - s.logger = loggerimpl.NewDevelopmentForTest(s.Suite) - s.mockClusterMetadata = &mocks.ClusterMetadata{} + + s.controller = gomock.NewController(s.T()) + s.mockResource = resource.NewTest(s.controller, metrics.Frontend) + s.mockDomainCache = s.mockResource.DomainCache + s.mockHistoryClient = s.mockResource.HistoryClient + s.mockClusterMetadata = s.mockResource.ClusterMetadata + s.mockMetadataMgr = s.mockResource.MetadataMgr + s.mockHistoryV2Mgr = s.mockResource.HistoryMgr + s.mockVisibilityMgr = s.mockResource.VisibilityMgr + s.mockArchivalMetadata = s.mockResource.ArchivalMetadata + s.mockArchiverProvider = s.mockResource.ArchiverProvider + s.mockProducer = &mocks.KafkaProducer{} - s.mockMetricClient = metrics.NewClient(tally.NoopScope, metrics.Frontend) s.mockMessagingClient = mocks.NewMockMessagingClient(s.mockProducer, nil) - s.mockMetadataMgr = &mocks.MetadataManager{} - s.mockHistoryV2Mgr = &mocks.HistoryV2Manager{} - s.mockVisibilityMgr = &mocks.VisibilityManager{} - s.mockArchivalMetadata = &archiver.MockArchivalMetadata{} - s.mockArchiverProvider = &provider.MockArchiverProvider{} - s.mockService = cs.NewTestService( - s.mockClusterMetadata, - s.mockMessagingClient, - s.mockMetricClient, - s.mockClientBean, - s.mockArchivalMetadata, - s.mockArchiverProvider, - nil, - ) + s.mockHistoryArchiver = &archiver.HistoryArchiverMock{} + s.mockVisibilityArchiver = &archiver.VisibilityArchiverMock{} } func (s *workflowHandlerSuite) TearDownTest() { - s.mockProducer.AssertExpectations(s.T()) - s.mockMetadataMgr.AssertExpectations(s.T()) - s.mockHistoryV2Mgr.AssertExpectations(s.T()) - s.mockVisibilityMgr.AssertExpectations(s.T()) - s.mockArchivalMetadata.AssertExpectations(s.T()) - s.mockArchiverProvider.AssertExpectations(s.T()) s.controller.Finish() + s.mockResource.Finish(s.T()) + s.mockProducer.AssertExpectations(s.T()) + s.mockHistoryArchiver.AssertExpectations(s.T()) + s.mockVisibilityArchiver.AssertExpectations(s.T()) } func (s *workflowHandlerSuite) getWorkflowHandler(config *Config) *WorkflowHandler { - domainCache := cache.NewDomainCache( - s.mockMetadataMgr, - s.mockService.GetClusterMetadata(), - s.mockService.GetMetricsClient(), - s.mockService.GetLogger(), - ) - return NewWorkflowHandler(s.mockService, config, s.mockMetadataMgr, - s.mockHistoryV2Mgr, s.mockVisibilityMgr, s.mockProducer, nil, domainCache) -} - -func (s *workflowHandlerSuite) getWorkflowHandlerHelper() *WorkflowHandler { - s.config = s.newConfig() - wh := s.getWorkflowHandler(s.config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.domainCache = s.mockDomainCache - wh.visibilityMgr = s.mockVisibilityMgr - wh.startWG.Done() - return wh + return NewWorkflowHandler(s.mockResource, config, s.mockProducer) } func (s *workflowHandlerSuite) TestDisableListVisibilityByFilter() { @@ -172,9 +136,6 @@ func (s *workflowHandlerSuite) TestDisableListVisibilityByFilter() { config.DisableListVisibilityByFilter = dc.GetBoolPropertyFnFilteredByDomain(true) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.domainCache = s.mockDomainCache - wh.startWG.Done() s.mockDomainCache.EXPECT().GetDomainID(gomock.Any()).Return(domainID, nil).AnyTimes() @@ -190,8 +151,8 @@ func (s *workflowHandlerSuite) TestDisableListVisibilityByFilter() { }, } _, err := wh.ListOpenWorkflowExecutions(context.Background(), listRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errNoPermission, err) + s.Error(err) + s.Equal(errNoPermission, err) // test list open by workflow type listRequest.ExecutionFilter = nil @@ -199,8 +160,8 @@ func (s *workflowHandlerSuite) TestDisableListVisibilityByFilter() { Name: common.StringPtr("workflow-type"), } _, err = wh.ListOpenWorkflowExecutions(context.Background(), listRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errNoPermission, err) + s.Error(err) + s.Equal(errNoPermission, err) // test list close by wid listRequest2 := &shared.ListClosedWorkflowExecutionsRequest{ @@ -214,8 +175,8 @@ func (s *workflowHandlerSuite) TestDisableListVisibilityByFilter() { }, } _, err = wh.ListClosedWorkflowExecutions(context.Background(), listRequest2) - assert.Error(s.T(), err) - assert.Equal(s.T(), errNoPermission, err) + s.Error(err) + s.Equal(errNoPermission, err) // test list close by workflow type listRequest2.ExecutionFilter = nil @@ -223,51 +184,47 @@ func (s *workflowHandlerSuite) TestDisableListVisibilityByFilter() { Name: common.StringPtr("workflow-type"), } _, err = wh.ListClosedWorkflowExecutions(context.Background(), listRequest2) - assert.Error(s.T(), err) - assert.Equal(s.T(), errNoPermission, err) + s.Error(err) + s.Equal(errNoPermission, err) // test list close by workflow status listRequest2.TypeFilter = nil failedStatus := shared.WorkflowExecutionCloseStatusFailed listRequest2.StatusFilter = &failedStatus _, err = wh.ListClosedWorkflowExecutions(context.Background(), listRequest2) - assert.Error(s.T(), err) - assert.Equal(s.T(), errNoPermission, err) + s.Error(err) + s.Equal(errNoPermission, err) } func (s *workflowHandlerSuite) TestPollForTask_Failed_ContextTimeoutTooShort() { config := s.newConfig() wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() bgCtx := context.Background() _, err := wh.PollForDecisionTask(bgCtx, &shared.PollForDecisionTaskRequest{}) - assert.Error(s.T(), err) - assert.Equal(s.T(), common.ErrContextTimeoutNotSet, err) + s.Error(err) + s.Equal(common.ErrContextTimeoutNotSet, err) _, err = wh.PollForActivityTask(bgCtx, &shared.PollForActivityTaskRequest{}) - assert.Error(s.T(), err) - assert.Equal(s.T(), common.ErrContextTimeoutNotSet, err) + s.Error(err) + s.Equal(common.ErrContextTimeoutNotSet, err) shortCtx, cancel := context.WithTimeout(bgCtx, common.MinLongPollTimeout-time.Millisecond) defer cancel() _, err = wh.PollForDecisionTask(shortCtx, &shared.PollForDecisionTaskRequest{}) - assert.Error(s.T(), err) - assert.Equal(s.T(), common.ErrContextTimeoutTooShort, err) + s.Error(err) + s.Equal(common.ErrContextTimeoutTooShort, err) _, err = wh.PollForActivityTask(shortCtx, &shared.PollForActivityTaskRequest{}) - assert.Error(s.T(), err) - assert.Equal(s.T(), common.ErrContextTimeoutTooShort, err) + s.Error(err) + s.Equal(common.ErrContextTimeoutTooShort, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_RequestIdNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ Domain: common.StringPtr("test-domain"), @@ -289,28 +246,24 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_RequestIdNotSet }, } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errRequestIDNotSet, err) + s.Error(err) + s.Equal(errRequestIDNotSet, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_StartRequestNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() _, err := wh.StartWorkflowExecution(context.Background(), nil) - assert.Error(s.T(), err) - assert.Equal(s.T(), errRequestNotSet, err) + s.Error(err) + s.Equal(errRequestNotSet, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_DomainNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ WorkflowId: common.StringPtr("workflow-id"), @@ -332,16 +285,14 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_DomainNotSet() RequestId: common.StringPtr(uuid.New()), } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errDomainNotSet, err) + s.Error(err) + s.Equal(errDomainNotSet, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_WorkflowIdNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ Domain: common.StringPtr("test-domain"), @@ -363,16 +314,14 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_WorkflowIdNotSe RequestId: common.StringPtr(uuid.New()), } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errWorkflowIDNotSet, err) + s.Error(err) + s.Equal(errWorkflowIDNotSet, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_WorkflowTypeNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ Domain: common.StringPtr("test-domain"), @@ -395,16 +344,14 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_WorkflowTypeNot RequestId: common.StringPtr(uuid.New()), } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errWorkflowTypeNotSet, err) + s.Error(err) + s.Equal(errWorkflowTypeNotSet, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_TaskListNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ Domain: common.StringPtr("test-domain"), @@ -427,16 +374,14 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_TaskListNotSet( RequestId: common.StringPtr(uuid.New()), } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errTaskListNotSet, err) + s.Error(err) + s.Equal(errTaskListNotSet, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_InvalidExecutionStartToCloseTimeout() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ Domain: common.StringPtr("test-domain"), @@ -459,16 +404,14 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_InvalidExecutio RequestId: common.StringPtr(uuid.New()), } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errInvalidExecutionStartToCloseTimeoutSeconds, err) + s.Error(err) + s.Equal(errInvalidExecutionStartToCloseTimeoutSeconds, err) } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_InvalidTaskStartToCloseTimeout() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) wh := s.getWorkflowHandler(config) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() startWorkflowExecutionRequest := &shared.StartWorkflowExecutionRequest{ Domain: common.StringPtr("test-domain"), @@ -491,40 +434,22 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_InvalidTaskStar RequestId: common.StringPtr(uuid.New()), } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) - assert.Error(s.T(), err) - assert.Equal(s.T(), errInvalidTaskStartToCloseTimeoutSeconds, err) -} - -func (s *workflowHandlerSuite) getWorkflowHandlerWithParams(mService cs.Service, config *Config, - mMetadataManager persistence.MetadataManager, mockDomainCache *cache.MockDomainCache) *WorkflowHandler { - return NewWorkflowHandler(mService, config, mMetadataManager, s.mockHistoryV2Mgr, - s.mockVisibilityMgr, s.mockProducer, nil, mockDomainCache) + s.Error(err) + s.Equal(errInvalidTaskStartToCloseTimeoutSeconds, err) } func (s *workflowHandlerSuite) TestRegisterDomain_Failure_InvalidArchivalURI() { - config := s.newConfig() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockClusterMetadata.EXPECT().IsGlobalDomainEnabled().Return(false) + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName) s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) - s.mockClusterMetadata.On("GetNextFailoverVersion", mock.Anything, mock.Anything).Return(int64(0)) - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) - mMetadataManager.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ - ID: "test-id", - }, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(errors.New("invalid URI")) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(errors.New("invalid URI")) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) req := registerDomainRequest( shared.ArchivalStatusEnabled.Ptr(), @@ -533,63 +458,47 @@ func (s *workflowHandlerSuite) TestRegisterDomain_Failure_InvalidArchivalURI() { common.StringPtr(testVisibilityArchivalURI), ) err := wh.RegisterDomain(context.Background(), req) - assert.Error(s.T(), err) + s.Error(err) } func (s *workflowHandlerSuite) TestRegisterDomain_Success_EnabledWithNoArchivalURI() { - config := s.newConfig() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockClusterMetadata.EXPECT().IsGlobalDomainEnabled().Return(false) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", testHistoryArchivalURI)) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", testVisibilityArchivalURI)) - s.mockClusterMetadata.On("GetNextFailoverVersion", mock.Anything, mock.Anything).Return(int64(0)) - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) - mMetadataManager.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) + s.mockMetadataMgr.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ ID: "test-id", }, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) req := registerDomainRequest(shared.ArchivalStatusEnabled.Ptr(), nil, shared.ArchivalStatusEnabled.Ptr(), nil) err := wh.RegisterDomain(context.Background(), req) - assert.NoError(s.T(), err) + s.NoError(err) } func (s *workflowHandlerSuite) TestRegisterDomain_Success_EnabledWithArchivalURI() { - config := s.newConfig() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockClusterMetadata.EXPECT().IsGlobalDomainEnabled().Return(false) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "invalidURI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "invalidURI")) - s.mockClusterMetadata.On("GetNextFailoverVersion", mock.Anything, mock.Anything).Return(int64(0)) - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) - mMetadataManager.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) + s.mockMetadataMgr.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ ID: "test-id", }, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) req := registerDomainRequest( shared.ArchivalStatusEnabled.Ptr(), @@ -598,27 +507,21 @@ func (s *workflowHandlerSuite) TestRegisterDomain_Success_EnabledWithArchivalURI common.StringPtr(testVisibilityArchivalURI), ) err := wh.RegisterDomain(context.Background(), req) - assert.NoError(s.T(), err) + s.NoError(err) } func (s *workflowHandlerSuite) TestRegisterDomain_Success_ClusterNotConfiguredForArchival() { - config := s.newConfig() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockClusterMetadata.EXPECT().IsGlobalDomainEnabled().Return(false) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewDisabledArchvialConfig()) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewDisabledArchvialConfig()) - s.mockClusterMetadata.On("GetNextFailoverVersion", mock.Anything, mock.Anything).Return(int64(0)) - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) - mMetadataManager.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) + s.mockMetadataMgr.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ ID: "test-id", }, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + wh := s.getWorkflowHandler(s.newConfig()) req := registerDomainRequest( shared.ArchivalStatusEnabled.Ptr(), @@ -627,111 +530,88 @@ func (s *workflowHandlerSuite) TestRegisterDomain_Success_ClusterNotConfiguredFo common.StringPtr("invalidURI"), ) err := wh.RegisterDomain(context.Background(), req) - assert.NoError(s.T(), err) + s.NoError(err) } func (s *workflowHandlerSuite) TestRegisterDomain_Success_NotEnabled() { - config := s.newConfig() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockClusterMetadata.EXPECT().IsGlobalDomainEnabled().Return(false) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - s.mockClusterMetadata.On("GetNextFailoverVersion", mock.Anything, mock.Anything).Return(int64(0)) - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) - mMetadataManager.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(nil, &shared.EntityNotExistsError{}) + s.mockMetadataMgr.On("CreateDomain", mock.Anything).Return(&persistence.CreateDomainResponse{ ID: "test-id", }, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + wh := s.getWorkflowHandler(s.newConfig()) req := registerDomainRequest(nil, nil, nil, nil) err := wh.RegisterDomain(context.Background(), req) - assert.NoError(s.T(), err) + s.NoError(err) } func (s *workflowHandlerSuite) TestDescribeDomain_Success_ArchivalDisabled() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: ""}, &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: ""}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + + wh := s.getWorkflowHandler(s.newConfig()) req := &shared.DescribeDomainRequest{ Name: common.StringPtr("test-domain"), } result, err := wh.DescribeDomain(context.Background(), req) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusDisabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), "", result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusDisabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), "", result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusDisabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal("", result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusDisabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal("", result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestDescribeDomain_Success_ArchivalEnabled() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testHistoryArchivalURI}, &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testVisibilityArchivalURI}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + + wh := s.getWorkflowHandler(s.newConfig()) req := &shared.DescribeDomainRequest{ Name: common.StringPtr("test-domain"), } result, err := wh.DescribeDomain(context.Background(), req) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal(testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal(testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestUpdateDomain_Failure_UpdateExistingArchivalURI() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testHistoryArchivalURI}, &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testVisibilityArchivalURI}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest( nil, @@ -740,31 +620,23 @@ func (s *workflowHandlerSuite) TestUpdateDomain_Failure_UpdateExistingArchivalUR nil, ) _, err := wh.UpdateDomain(context.Background(), updateReq) - assert.Error(s.T(), err) + s.Error(err) } func (s *workflowHandlerSuite) TestUpdateDomain_Failure_InvalidArchivalURI() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: ""}, &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: ""}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(errors.New("invalid URI")) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(errors.New("invalid URI")) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest( common.StringPtr("testScheme://invalid/updated/history/URI"), @@ -773,36 +645,29 @@ func (s *workflowHandlerSuite) TestUpdateDomain_Failure_InvalidArchivalURI() { nil, ) _, err := wh.UpdateDomain(context.Background(), updateReq) - assert.Error(s.T(), err) + s.Error(err) } func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalEnabledToArchivalDisabledWithoutSettingURI() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testHistoryArchivalURI}, &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testVisibilityArchivalURI}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mMetadataManager.On("UpdateDomain", mock.Anything).Return(nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + s.mockMetadataMgr.On("UpdateDomain", mock.Anything).Return(nil) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest( nil, @@ -811,75 +676,62 @@ func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalEnabledToArchiva common.ArchivalStatusPtr(shared.ArchivalStatusDisabled), ) result, err := wh.UpdateDomain(context.Background(), updateReq) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusDisabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusDisabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusDisabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal(testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusDisabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal(testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestUpdateDomain_Success_ClusterNotConfiguredForArchival() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: "some random history URI"}, &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: "some random visibility URI"}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mMetadataManager.On("UpdateDomain", mock.Anything).Return(nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewDisabledArchvialConfig()) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewDisabledArchvialConfig()) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest(nil, common.ArchivalStatusPtr(shared.ArchivalStatusDisabled), nil, nil) result, err := wh.UpdateDomain(context.Background(), updateReq) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), "some random history URI", result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), "some random visibility URI", result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal("some random history URI", result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal("some random visibility URI", result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalEnabledToArchivalDisabledWithSettingBucket() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testHistoryArchivalURI}, &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testVisibilityArchivalURI}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mMetadataManager.On("UpdateDomain", mock.Anything).Return(nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + s.mockMetadataMgr.On("UpdateDomain", mock.Anything).Return(nil) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest( common.StringPtr(testHistoryArchivalURI), @@ -888,42 +740,34 @@ func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalEnabledToArchiva common.ArchivalStatusPtr(shared.ArchivalStatusDisabled), ) result, err := wh.UpdateDomain(context.Background(), updateReq) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusDisabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusDisabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusDisabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal(testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusDisabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal(testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalEnabledToEnabled() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testHistoryArchivalURI}, &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: testVisibilityArchivalURI}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mMetadataManager.On("UpdateDomain", mock.Anything).Return(nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest( common.StringPtr(testHistoryArchivalURI), @@ -932,42 +776,35 @@ func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalEnabledToEnabled common.ArchivalStatusPtr(shared.ArchivalStatusEnabled), ) result, err := wh.UpdateDomain(context.Background(), updateReq) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal(testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal(testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalNeverEnabledToEnabled() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetMetadata").Return(&persistence.GetMetadataResponse{ + s.mockMetadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{ NotificationVersion: int64(0), }, nil) getDomainResp := persistenceGetDomainResponse( &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: ""}, &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: ""}, ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - mMetadataManager.On("UpdateDomain", mock.Anything).Return(nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockMetadataMgr.On("GetDomain", mock.Anything).Return(getDomainResp, nil) + s.mockMetadataMgr.On("UpdateDomain", mock.Anything).Return(nil) + s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockArchivalMetadata.On("GetHistoryConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "some random URI")) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} - mHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockHistoryArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockVisibilityArchiver.On("ValidateURI", mock.Anything).Return(nil) + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) updateReq := updateRequest( common.StringPtr(testHistoryArchivalURI), @@ -976,17 +813,18 @@ func (s *workflowHandlerSuite) TestUpdateDomain_Success_ArchivalNeverEnabledToEn common.ArchivalStatusPtr(shared.ArchivalStatusEnabled), ) result, err := wh.UpdateDomain(context.Background(), updateReq) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), result) - assert.NotNil(s.T(), result.Configuration) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) - assert.Equal(s.T(), testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) - assert.Equal(s.T(), shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) - assert.Equal(s.T(), testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) + s.NoError(err) + s.NotNil(result) + s.NotNil(result.Configuration) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetHistoryArchivalStatus()) + s.Equal(testHistoryArchivalURI, result.Configuration.GetHistoryArchivalURI()) + s.Equal(shared.ArchivalStatusEnabled, result.Configuration.GetVisibilityArchivalStatus()) + s.Equal(testVisibilityArchivalURI, result.Configuration.GetVisibilityArchivalURI()) } func (s *workflowHandlerSuite) TestHistoryArchived() { - wh := &WorkflowHandler{} + wh := s.getWorkflowHandler(s.newConfig()) + getHistoryRequest := &shared.GetWorkflowExecutionHistoryRequest{} s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) @@ -995,11 +833,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) - mockHistoryClient := historyservicetest.NewMockClient(s.controller) - mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) - wh = &WorkflowHandler{ - history: mockHistoryClient, - } + s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{ Execution: &shared.WorkflowExecution{ WorkflowId: common.StringPtr(testWorkflowID), @@ -1008,10 +842,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) - mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1) - wh = &WorkflowHandler{ - history: mockHistoryClient, - } + s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1) getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{ Execution: &shared.WorkflowExecution{ WorkflowId: common.StringPtr(testWorkflowID), @@ -1020,10 +851,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } s.True(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) - mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1) - wh = &WorkflowHandler{ - history: mockHistoryClient, - } + s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1) getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{ Execution: &shared.WorkflowExecution{ WorkflowId: common.StringPtr(testWorkflowID), @@ -1034,18 +862,16 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } func (s *workflowHandlerSuite) TestGetArchivedHistory_Failure_DomainCacheEntryError() { - config := s.newConfig() s.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(nil, errors.New("error getting domain")).Times(1) - wh := s.getWorkflowHandlerWithParams(s.mockService, config, nil, s.mockDomainCache) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.getArchivedHistory(context.Background(), getHistoryRequest(nil), s.testDomainID, metrics.NoopScope(metrics.Frontend)) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestGetArchivedHistory_Failure_ArchivalURIEmpty() { - config := s.newConfig() domainEntry := cache.NewLocalDomainCacheEntryForTest( &persistence.DomainInfo{Name: "test-domain"}, &persistence.DomainConfig{ @@ -1057,20 +883,15 @@ func (s *workflowHandlerSuite) TestGetArchivedHistory_Failure_ArchivalURIEmpty() "", nil) s.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(domainEntry, nil).AnyTimes() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, nil, s.mockDomainCache) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.getArchivedHistory(context.Background(), getHistoryRequest(nil), s.testDomainID, metrics.NoopScope(metrics.Frontend)) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestGetArchivedHistory_Failure_InvalidURI() { - config := s.newConfig() domainEntry := cache.NewLocalDomainCacheEntryForTest( &persistence.DomainInfo{Name: "test-domain"}, &persistence.DomainConfig{ @@ -1082,20 +903,15 @@ func (s *workflowHandlerSuite) TestGetArchivedHistory_Failure_InvalidURI() { "", nil) s.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(domainEntry, nil).AnyTimes() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, nil, s.mockDomainCache) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.getArchivedHistory(context.Background(), getHistoryRequest(nil), s.testDomainID, metrics.NoopScope(metrics.Frontend)) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestGetArchivedHistory_Success_GetFirstPage() { - config := s.newConfig() domainEntry := cache.NewLocalDomainCacheEntryForTest( &persistence.DomainInfo{Name: "test-domain"}, &persistence.DomainConfig{ @@ -1107,36 +923,32 @@ func (s *workflowHandlerSuite) TestGetArchivedHistory_Success_GetFirstPage() { "", nil) s.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(domainEntry, nil).AnyTimes() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - mHistoryArchiver := &archiver.HistoryArchiverMock{} + nextPageToken := []byte{'1', '2', '3'} - historyBatch1 := &gen.History{ - Events: []*gen.HistoryEvent{ - &gen.HistoryEvent{EventId: common.Int64Ptr(1)}, - &gen.HistoryEvent{EventId: common.Int64Ptr(2)}, + historyBatch1 := &shared.History{ + Events: []*shared.HistoryEvent{ + &shared.HistoryEvent{EventId: common.Int64Ptr(1)}, + &shared.HistoryEvent{EventId: common.Int64Ptr(2)}, }, } - historyBatch2 := &gen.History{ - Events: []*gen.HistoryEvent{ - &gen.HistoryEvent{EventId: common.Int64Ptr(3)}, - &gen.HistoryEvent{EventId: common.Int64Ptr(4)}, - &gen.HistoryEvent{EventId: common.Int64Ptr(5)}, + historyBatch2 := &shared.History{ + Events: []*shared.HistoryEvent{ + &shared.HistoryEvent{EventId: common.Int64Ptr(3)}, + &shared.HistoryEvent{EventId: common.Int64Ptr(4)}, + &shared.HistoryEvent{EventId: common.Int64Ptr(5)}, }, } - history := &gen.History{} + history := &shared.History{} history.Events = append(history.Events, historyBatch1.Events...) history.Events = append(history.Events, historyBatch2.Events...) - mHistoryArchiver.On("Get", mock.Anything, mock.Anything, mock.Anything).Return(&archiver.GetHistoryResponse{ + s.mockHistoryArchiver.On("Get", mock.Anything, mock.Anything, mock.Anything).Return(&archiver.GetHistoryResponse{ NextPageToken: nextPageToken, - HistoryBatches: []*gen.History{historyBatch1, historyBatch2}, + HistoryBatches: []*shared.History{historyBatch1, historyBatch2}, }, nil) - s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(mHistoryArchiver, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, nil, s.mockDomainCache) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockArchiverProvider.On("GetHistoryArchiver", mock.Anything, mock.Anything).Return(s.mockHistoryArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.getArchivedHistory(context.Background(), getHistoryRequest(nil), s.testDomainID, metrics.NoopScope(metrics.Frontend)) s.NoError(err) s.NotNil(resp) @@ -1147,12 +959,11 @@ func (s *workflowHandlerSuite) TestGetArchivedHistory_Success_GetFirstPage() { } func (s *workflowHandlerSuite) TestGetHistory() { - config := s.newConfig() domainID := uuid.New() firstEventID := int64(100) nextEventID := int64(101) branchToken := []byte{1} - we := gen.WorkflowExecution{ + we := shared.WorkflowExecution{ WorkflowId: common.StringPtr("wid"), RunId: common.StringPtr("rid"), } @@ -1166,7 +977,7 @@ func (s *workflowHandlerSuite) TestGetHistory() { ShardID: common.IntPtr(shardID), } s.mockHistoryV2Mgr.On("ReadHistoryBranch", req).Return(&persistence.ReadHistoryBranchResponse{ - HistoryEvents: []*workflow.HistoryEvent{ + HistoryEvents: []*shared.HistoryEvent{ { EventId: common.Int64Ptr(int64(1)), }, @@ -1175,11 +986,10 @@ func (s *workflowHandlerSuite) TestGetHistory() { Size: 1, LastFirstEventID: nextEventID, }, nil).Once() - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - mMetadataManager := &mocks.MetadataManager{} - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - scope := wh.metricsClient.Scope(0) + + wh := s.getWorkflowHandler(s.newConfig()) + + scope := metrics.NoopScope(metrics.Frontend) history, token, err := wh.getHistory(scope, domainID, we, firstEventID, nextEventID, 0, []byte{}, nil, branchToken) s.NotNil(history) s.Equal([]byte{}, token) @@ -1187,115 +997,94 @@ func (s *workflowHandlerSuite) TestGetHistory() { } func (s *workflowHandlerSuite) TestListArchivedVisibility_Failure_InvalidRequest() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - wh := s.getWorkflowHandlerWithParams(s.mockService, config, mMetadataManager, nil) - wh.startWG.Done() - resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), &workflow.ListArchivedWorkflowExecutionsRequest{}) + wh := s.getWorkflowHandler(s.newConfig()) + + resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), &shared.ListArchivedWorkflowExecutionsRequest{}) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestListArchivedVisibility_Failure_ClusterNotConfiguredForArchival() { - config := s.newConfig() s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewDisabledArchvialConfig()) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, &mocks.MetadataManager{}, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), listArchivedWorkflowExecutionsTestRequest()) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestListArchivedVisibility_Failure_DomainCacheEntryError() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - mMetadataManager.On("GetDomain", mock.Anything).Return(nil, errors.New("error getting domain")).Once() + s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(nil, errors.New("error getting domain")) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) - wh := s.getWorkflowHandlerWithParams(s.mockService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), listArchivedWorkflowExecutionsTestRequest()) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestListArchivedVisibility_Failure_DomainNotConfiguredForArchival() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - getDomainResp := persistenceGetDomainResponse( - &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: "uri without scheme"}, - &domain.ArchivalState{Status: shared.ArchivalStatusDisabled, URI: "uri without scheme"}, - ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(cache.NewLocalDomainCacheEntryForTest( + nil, + &persistence.DomainConfig{ + VisibilityArchivalStatus: shared.ArchivalStatusDisabled, + }, + "", + nil, + ), nil) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), listArchivedWorkflowExecutionsTestRequest()) s.Nil(resp) - fmt.Println(err) s.Error(err) } func (s *workflowHandlerSuite) TestListArchivedVisibility_Failure_InvalidURI() { - config := s.newConfig() - mMetadataManager := &mocks.MetadataManager{} - getDomainResp := persistenceGetDomainResponse( - &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: "uri without scheme"}, - &domain.ArchivalState{Status: shared.ArchivalStatusEnabled, URI: "uri without scheme"}, - ) - mMetadataManager.On("GetDomain", mock.Anything).Return(getDomainResp, nil) - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(cache.NewLocalDomainCacheEntryForTest( + &persistence.DomainInfo{Name: "test-domain"}, + &persistence.DomainConfig{ + VisibilityArchivalStatus: shared.ArchivalStatusDisabled, + VisibilityArchivalURI: "uri without scheme", + }, + "", + nil, + ), nil) s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, mMetadataManager, nil) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), listArchivedWorkflowExecutionsTestRequest()) s.Nil(resp) s.Error(err) } func (s *workflowHandlerSuite) TestListArchivedVisibility_Success() { - config := s.newConfig() - domainEntry := cache.NewLocalDomainCacheEntryForTest( + s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(cache.NewLocalDomainCacheEntryForTest( &persistence.DomainInfo{Name: "test-domain"}, &persistence.DomainConfig{ - HistoryArchivalStatus: shared.ArchivalStatusEnabled, - HistoryArchivalURI: testHistoryArchivalURI, VisibilityArchivalStatus: shared.ArchivalStatusEnabled, VisibilityArchivalURI: testVisibilityArchivalURI, }, "", - nil) - s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(domainEntry, nil).AnyTimes() - s.mockClusterMetadata.On("IsGlobalDomainEnabled").Return(false) - s.mockClusterMetadata.On("GetAllClusterInfo").Return(cluster.TestAllClusterInfo) - s.mockClusterMetadata.On("GetCurrentClusterName").Return(cluster.TestCurrentClusterName) + nil, + ), nil).AnyTimes() s.mockArchivalMetadata.On("GetVisibilityConfig").Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) - mVisibilityArchiver := &archiver.VisibilityArchiverMock{} - mVisibilityArchiver.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(&archiver.QueryVisibilityResponse{}, nil) - s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(mVisibilityArchiver, nil) - mService := cs.NewTestService(s.mockClusterMetadata, s.mockMessagingClient, s.mockMetricClient, s.mockClientBean, s.mockArchivalMetadata, s.mockArchiverProvider, nil) - wh := s.getWorkflowHandlerWithParams(mService, config, nil, s.mockDomainCache) - wh.metricsClient = wh.Service.GetMetricsClient() - wh.startWG.Done() + s.mockVisibilityArchiver.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(&archiver.QueryVisibilityResponse{}, nil) + s.mockArchiverProvider.On("GetVisibilityArchiver", mock.Anything, mock.Anything).Return(s.mockVisibilityArchiver, nil) + + wh := s.getWorkflowHandler(s.newConfig()) + resp, err := wh.ListArchivedWorkflowExecutions(context.Background(), listArchivedWorkflowExecutionsTestRequest()) s.NotNil(resp) - fmt.Println(err) s.NoError(err) } func (s *workflowHandlerSuite) TestGetSearchAttributes() { - wh := s.getWorkflowHandlerHelper() + wh := s.getWorkflowHandler(s.newConfig()) ctx := context.Background() resp, err := wh.GetSearchAttributes(ctx) @@ -1304,14 +1093,15 @@ func (s *workflowHandlerSuite) TestGetSearchAttributes() { } func (s *workflowHandlerSuite) TestListWorkflowExecutions() { - wh := s.getWorkflowHandlerHelper() + config := s.newConfig() + wh := s.getWorkflowHandler(config) s.mockDomainCache.EXPECT().GetDomainID(gomock.Any()).Return(s.testDomainID, nil).AnyTimes() s.mockVisibilityMgr.On("ListWorkflowExecutions", mock.Anything).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Once() listRequest := &shared.ListWorkflowExecutionsRequest{ Domain: common.StringPtr(s.testDomain), - PageSize: common.Int32Ptr(int32(s.config.ESIndexMaxResultWindow())), + PageSize: common.Int32Ptr(int32(config.ESIndexMaxResultWindow())), } ctx := context.Background() @@ -1326,20 +1116,21 @@ func (s *workflowHandlerSuite) TestListWorkflowExecutions() { _, err = wh.ListWorkflowExecutions(ctx, listRequest) s.NotNil(err) - listRequest.PageSize = common.Int32Ptr(int32(s.config.ESIndexMaxResultWindow() + 1)) + listRequest.PageSize = common.Int32Ptr(int32(config.ESIndexMaxResultWindow() + 1)) _, err = wh.ListWorkflowExecutions(ctx, listRequest) s.NotNil(err) } func (s *workflowHandlerSuite) TestScantWorkflowExecutions() { - wh := s.getWorkflowHandlerHelper() + config := s.newConfig() + wh := s.getWorkflowHandler(config) s.mockDomainCache.EXPECT().GetDomainID(gomock.Any()).Return(s.testDomainID, nil).AnyTimes() s.mockVisibilityMgr.On("ScanWorkflowExecutions", mock.Anything).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Once() listRequest := &shared.ListWorkflowExecutionsRequest{ Domain: common.StringPtr(s.testDomain), - PageSize: common.Int32Ptr(int32(s.config.ESIndexMaxResultWindow())), + PageSize: common.Int32Ptr(int32(config.ESIndexMaxResultWindow())), } ctx := context.Background() @@ -1354,13 +1145,13 @@ func (s *workflowHandlerSuite) TestScantWorkflowExecutions() { _, err = wh.ScanWorkflowExecutions(ctx, listRequest) s.NotNil(err) - listRequest.PageSize = common.Int32Ptr(int32(s.config.ESIndexMaxResultWindow() + 1)) + listRequest.PageSize = common.Int32Ptr(int32(config.ESIndexMaxResultWindow() + 1)) _, err = wh.ListWorkflowExecutions(ctx, listRequest) s.NotNil(err) } func (s *workflowHandlerSuite) TestCountWorkflowExecutions() { - wh := s.getWorkflowHandlerHelper() + wh := s.getWorkflowHandler(s.newConfig()) s.mockDomainCache.EXPECT().GetDomainID(gomock.Any()).Return(s.testDomainID, nil).AnyTimes() s.mockVisibilityMgr.On("CountWorkflowExecutions", mock.Anything).Return(&persistence.CountWorkflowExecutionsResponse{}, nil).Once() @@ -1383,7 +1174,7 @@ func (s *workflowHandlerSuite) TestCountWorkflowExecutions() { } func (s *workflowHandlerSuite) TestConvertIndexedKeyToThrift() { - wh := s.getWorkflowHandlerHelper() + wh := s.getWorkflowHandler(s.newConfig()) m := map[string]interface{}{ "key1": float64(0), "key2": float64(1), @@ -1397,32 +1188,32 @@ func (s *workflowHandlerSuite) TestConvertIndexedKeyToThrift() { "key4i": 3, "key5i": 4, "key6i": 5, - "key1t": gen.IndexedValueTypeString, - "key2t": gen.IndexedValueTypeKeyword, - "key3t": gen.IndexedValueTypeInt, - "key4t": gen.IndexedValueTypeDouble, - "key5t": gen.IndexedValueTypeBool, - "key6t": gen.IndexedValueTypeDatetime, + "key1t": shared.IndexedValueTypeString, + "key2t": shared.IndexedValueTypeKeyword, + "key3t": shared.IndexedValueTypeInt, + "key4t": shared.IndexedValueTypeDouble, + "key5t": shared.IndexedValueTypeBool, + "key6t": shared.IndexedValueTypeDatetime, } result := wh.convertIndexedKeyToThrift(m) - s.Equal(gen.IndexedValueTypeString, result["key1"]) - s.Equal(gen.IndexedValueTypeKeyword, result["key2"]) - s.Equal(gen.IndexedValueTypeInt, result["key3"]) - s.Equal(gen.IndexedValueTypeDouble, result["key4"]) - s.Equal(gen.IndexedValueTypeBool, result["key5"]) - s.Equal(gen.IndexedValueTypeDatetime, result["key6"]) - s.Equal(gen.IndexedValueTypeString, result["key1i"]) - s.Equal(gen.IndexedValueTypeKeyword, result["key2i"]) - s.Equal(gen.IndexedValueTypeInt, result["key3i"]) - s.Equal(gen.IndexedValueTypeDouble, result["key4i"]) - s.Equal(gen.IndexedValueTypeBool, result["key5i"]) - s.Equal(gen.IndexedValueTypeDatetime, result["key6i"]) - s.Equal(gen.IndexedValueTypeString, result["key1t"]) - s.Equal(gen.IndexedValueTypeKeyword, result["key2t"]) - s.Equal(gen.IndexedValueTypeInt, result["key3t"]) - s.Equal(gen.IndexedValueTypeDouble, result["key4t"]) - s.Equal(gen.IndexedValueTypeBool, result["key5t"]) - s.Equal(gen.IndexedValueTypeDatetime, result["key6t"]) + s.Equal(shared.IndexedValueTypeString, result["key1"]) + s.Equal(shared.IndexedValueTypeKeyword, result["key2"]) + s.Equal(shared.IndexedValueTypeInt, result["key3"]) + s.Equal(shared.IndexedValueTypeDouble, result["key4"]) + s.Equal(shared.IndexedValueTypeBool, result["key5"]) + s.Equal(shared.IndexedValueTypeDatetime, result["key6"]) + s.Equal(shared.IndexedValueTypeString, result["key1i"]) + s.Equal(shared.IndexedValueTypeKeyword, result["key2i"]) + s.Equal(shared.IndexedValueTypeInt, result["key3i"]) + s.Equal(shared.IndexedValueTypeDouble, result["key4i"]) + s.Equal(shared.IndexedValueTypeBool, result["key5i"]) + s.Equal(shared.IndexedValueTypeDatetime, result["key6i"]) + s.Equal(shared.IndexedValueTypeString, result["key1t"]) + s.Equal(shared.IndexedValueTypeKeyword, result["key2t"]) + s.Equal(shared.IndexedValueTypeInt, result["key3t"]) + s.Equal(shared.IndexedValueTypeDouble, result["key4t"]) + s.Equal(shared.IndexedValueTypeBool, result["key5t"]) + s.Equal(shared.IndexedValueTypeDatetime, result["key6t"]) s.Panics(func() { wh.convertIndexedKeyToThrift(map[string]interface{}{ "invalidType": "unknown", @@ -1431,7 +1222,7 @@ func (s *workflowHandlerSuite) TestConvertIndexedKeyToThrift() { } func (s *workflowHandlerSuite) newConfig() *Config { - return NewConfig(dc.NewCollection(dc.NewNopClient(), s.logger), numHistoryShards, false) + return NewConfig(dc.NewCollection(dc.NewNopClient(), s.mockResource.GetLogger()), numHistoryShards, false) } func updateRequest( diff --git a/service/matching/service.go b/service/matching/service.go index 17c2d945128..4ad86437d75 100644 --- a/service/matching/service.go +++ b/service/matching/service.go @@ -45,6 +45,10 @@ func NewService( ) (resource.Resource, error) { serviceConfig := NewConfig(dynamicconfig.NewCollection(params.DynamicConfig, params.Logger)) + params.PersistenceConfig.SetMaxQPS( + params.PersistenceConfig.DefaultStore, + serviceConfig.PersistenceMaxQPS(), + ) serviceResource, err := resource.New( params, common.MatchingServiceName, diff --git a/service/worker/service.go b/service/worker/service.go index 6be1e4bfd95..5b268e0e9e7 100644 --- a/service/worker/service.go +++ b/service/worker/service.go @@ -25,7 +25,6 @@ import ( "github.com/uber/cadence/.gen/go/shared" "github.com/uber/cadence/common" - carchiver "github.com/uber/cadence/common/archiver" "github.com/uber/cadence/common/definition" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" @@ -78,6 +77,11 @@ func NewService( serviceConfig := NewConfig(params) + params.PersistenceConfig.SetMaxQPS( + params.PersistenceConfig.DefaultStore, + serviceConfig.ReplicationCfg.PersistenceMaxQPS(), + ) + serviceResource, err := resource.New( params, common.WorkerServiceName, @@ -266,29 +270,6 @@ func (s *Service) startIndexer() { } func (s *Service) startArchiver() { - historyArchiverBootstrapContainer := &carchiver.HistoryBootstrapContainer{ - HistoryV2Manager: s.GetHistoryManager(), - Logger: s.GetLogger(), - MetricsClient: s.GetMetricsClient(), - ClusterMetadata: s.GetClusterMetadata(), - DomainCache: s.GetDomainCache(), - } - visibilityArchiverBootstrapContainer := &carchiver.VisibilityBootstrapContainer{ - Logger: s.GetLogger(), - MetricsClient: s.GetMetricsClient(), - ClusterMetadata: s.GetClusterMetadata(), - DomainCache: s.GetDomainCache(), - } - archiverProvider := s.GetArchiverProvider() - err := archiverProvider.RegisterBootstrapContainer( - common.WorkerServiceName, - historyArchiverBootstrapContainer, - visibilityArchiverBootstrapContainer, - ) - if err != nil { - s.GetLogger().Fatal("failed to register archiver bootstrap container", tag.Error(err)) - } - bc := &archiver.BootstrapContainer{ PublicClient: s.GetSDKClient(), MetricsClient: s.GetMetricsClient(), @@ -296,7 +277,7 @@ func (s *Service) startArchiver() { HistoryV2Manager: s.GetHistoryManager(), DomainCache: s.GetDomainCache(), Config: s.config.ArchiverConfig, - ArchiverProvider: archiverProvider, + ArchiverProvider: s.GetArchiverProvider(), } clientWorker := archiver.NewClientWorker(bc) if err := clientWorker.Start(); err != nil {