Skip to content

Commit

Permalink
Add context parameter to data interfaces (cadence-workflow#3547)
Browse files Browse the repository at this point in the history
* Add context parameter to all data interfaces and persistence retryer.
* context.TODO() is used as a placeholder in the caller of data interfaces.
  • Loading branch information
yycptt authored Oct 3, 2020
1 parent 4b35fcd commit a9a41c6
Show file tree
Hide file tree
Showing 130 changed files with 3,917 additions and 2,747 deletions.
3 changes: 2 additions & 1 deletion common/archiver/historyIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package archiver

import (
"context"
"encoding/json"
"errors"

Expand Down Expand Up @@ -228,7 +229,7 @@ func (i *historyIterator) readHistory(firstEventID int64) ([]*shared.History, er
PageSize: i.historyPageSize,
ShardID: common.IntPtr(i.request.ShardID),
}
historyBatches, _, _, err := persistence.ReadFullPageV2EventsByBatch(i.historyV2Manager, req)
historyBatches, _, _, err := persistence.ReadFullPageV2EventsByBatch(context.TODO(), i.historyV2Manager, req)
return historyBatches, err

}
Expand Down
10 changes: 5 additions & 5 deletions common/archiver/historyIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (s *HistoryIteratorSuite) SetupTest() {

func (s *HistoryIteratorSuite) TestReadHistory_Failed_EventsV2() {
mockHistoryV2Manager := &mocks.HistoryV2Manager{}
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything).Return(nil, errors.New("got error reading history branch"))
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything, mock.Anything).Return(nil, errors.New("got error reading history branch"))
itr := s.constructTestHistoryIterator(mockHistoryV2Manager, testDefaultTargetHistoryBlobSize, nil)
history, err := itr.readHistory(common.FirstEventID)
s.Error(err)
Expand All @@ -103,7 +103,7 @@ func (s *HistoryIteratorSuite) TestReadHistory_Success_EventsV2() {
History: []*shared.History{},
NextPageToken: []byte{},
}
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything).Return(&resp, nil)
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything, mock.Anything).Return(&resp, nil)
itr := s.constructTestHistoryIterator(mockHistoryV2Manager, testDefaultTargetHistoryBlobSize, nil)
history, err := itr.readHistory(common.FirstEventID)
s.NoError(err)
Expand Down Expand Up @@ -624,14 +624,14 @@ func (s *HistoryIteratorSuite) constructMockHistoryV2Manager(batchInfo []int, re
ShardID: common.IntPtr(testShardID),
}
if returnErrorOnPage == i {
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", req).Return(nil, errors.New("got error getting workflow execution history"))
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything, req).Return(nil, errors.New("got error getting workflow execution history"))
return mockHistoryV2Manager
}

resp := &persistence.ReadHistoryBranchByBatchResponse{
History: s.constructHistoryBatches(batchInfo, p, firstEventIDs[p.firstbatchIdx]),
}
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", req).Return(resp, nil)
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything, req).Return(resp, nil)
}

if addNotExistCall {
Expand All @@ -642,7 +642,7 @@ func (s *HistoryIteratorSuite) constructMockHistoryV2Manager(batchInfo []int, re
PageSize: testDefaultPersistencePageSize,
ShardID: common.IntPtr(testShardID),
}
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", req).Return(nil, &shared.EntityNotExistsError{Message: "Reach the end"})
mockHistoryV2Manager.On("ReadHistoryBranchByBatch", mock.Anything, req).Return(nil, &shared.EntityNotExistsError{Message: "Reach the end"})
}

return mockHistoryV2Manager
Expand Down
7 changes: 4 additions & 3 deletions common/cache/domainCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package cache

import (
"context"
"hash/fnv"
"sort"
"strconv"
Expand Down Expand Up @@ -424,7 +425,7 @@ func (c *domainCache) refreshDomainsLocked() error {

// first load the metadata record, then load domains
// this can guarantee that domains in the cache are not updated more than metadata record
metadata, err := c.metadataMgr.GetMetadata()
metadata, err := c.metadataMgr.GetMetadata(context.TODO())
if err != nil {
return err
}
Expand All @@ -436,7 +437,7 @@ func (c *domainCache) refreshDomainsLocked() error {

for continuePage {
request.NextPageToken = token
response, err := c.metadataMgr.ListDomains(request)
response, err := c.metadataMgr.ListDomains(context.TODO(), request)
if err != nil {
return err
}
Expand Down Expand Up @@ -511,7 +512,7 @@ func (c *domainCache) checkDomainExists(
id string,
) error {

_, err := c.metadataMgr.GetDomain(&persistence.GetDomainRequest{Name: name, ID: id})
_, err := c.metadataMgr.GetDomain(context.TODO(), &persistence.GetDomainRequest{Name: name, ID: id})
return err
}

Expand Down
37 changes: 19 additions & 18 deletions common/cache/domainCache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"time"

"github.com/pborman/uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/uber-go/tally"
Expand Down Expand Up @@ -152,17 +153,17 @@ func (s *domainCacheSuite) TestListDomain() {

pageToken := []byte("some random page token")

s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
s.clusterMetadata.On("IsGlobalDomainEnabled").Return(true)
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Domains: []*persistence.GetDomainResponse{domainRecord1},
NextPageToken: pageToken,
}, nil).Once()

s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: pageToken,
}).Return(&persistence.ListDomainsResponse{
Expand Down Expand Up @@ -198,7 +199,7 @@ func (s *domainCacheSuite) TestListDomain() {
func (s *domainCacheSuite) TestGetDomain_NonLoaded_GetByName() {
s.clusterMetadata.On("IsGlobalDomainEnabled").Return(true)
domainNotificationVersion := int64(999999) // make this notification version really large for test
s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
domainRecord := &persistence.GetDomainResponse{
Info: &persistence.DomainInfo{ID: uuid.New(), Name: "some random domain name", Data: make(map[string]string)},
Config: &persistence.DomainConfig{
Expand All @@ -222,8 +223,8 @@ func (s *domainCacheSuite) TestGetDomain_NonLoaded_GetByName() {
}
entry := s.buildEntryFromRecord(domainRecord)

s.metadataMgr.On("GetDomain", &persistence.GetDomainRequest{Name: entry.info.Name}).Return(domainRecord, nil).Once()
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("GetDomain", mock.Anything, &persistence.GetDomainRequest{Name: entry.info.Name}).Return(domainRecord, nil).Once()
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Expand All @@ -242,7 +243,7 @@ func (s *domainCacheSuite) TestGetDomain_NonLoaded_GetByName() {
func (s *domainCacheSuite) TestGetDomain_NonLoaded_GetByID() {
s.clusterMetadata.On("IsGlobalDomainEnabled").Return(true)
domainNotificationVersion := int64(999999) // make this notification version really large for test
s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
domainRecord := &persistence.GetDomainResponse{
Info: &persistence.DomainInfo{ID: uuid.New(), Name: "some random domain name", Data: make(map[string]string)},
Config: &persistence.DomainConfig{
Expand All @@ -261,8 +262,8 @@ func (s *domainCacheSuite) TestGetDomain_NonLoaded_GetByID() {
}
entry := s.buildEntryFromRecord(domainRecord)

s.metadataMgr.On("GetDomain", &persistence.GetDomainRequest{ID: entry.info.ID}).Return(domainRecord, nil).Once()
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("GetDomain", mock.Anything, &persistence.GetDomainRequest{ID: entry.info.ID}).Return(domainRecord, nil).Once()
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Expand Down Expand Up @@ -324,9 +325,9 @@ func (s *domainCacheSuite) TestRegisterCallback_CatchUp() {
entry2 := s.buildEntryFromRecord(domainRecord2)
domainNotificationVersion++

s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil).Once()
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil).Once()
s.clusterMetadata.On("IsGlobalDomainEnabled").Return(true)
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Expand Down Expand Up @@ -407,9 +408,9 @@ func (s *domainCacheSuite) TestUpdateCache_TriggerCallBack() {
entry2Old := s.buildEntryFromRecord(domainRecord2Old)
domainNotificationVersion++

s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil).Once()
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil).Once()
s.clusterMetadata.On("IsGlobalDomainEnabled").Return(true)
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Expand Down Expand Up @@ -476,8 +477,8 @@ func (s *domainCacheSuite) TestUpdateCache_TriggerCallBack() {
s.Empty(entriesOld)
s.Empty(entriesNew)

s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil).Once()
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil).Once()
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Expand All @@ -500,7 +501,7 @@ func (s *domainCacheSuite) TestUpdateCache_TriggerCallBack() {
func (s *domainCacheSuite) TestGetTriggerListAndUpdateCache_ConcurrentAccess() {
s.clusterMetadata.On("IsGlobalDomainEnabled").Return(true)
domainNotificationVersion := int64(999999) // make this notification version really large for test
s.metadataMgr.On("GetMetadata").Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
s.metadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: domainNotificationVersion}, nil)
id := uuid.New()
domainRecordOld := &persistence.GetDomainResponse{
Info: &persistence.DomainInfo{ID: id, Name: "some random domain name", Data: make(map[string]string)},
Expand All @@ -521,8 +522,8 @@ func (s *domainCacheSuite) TestGetTriggerListAndUpdateCache_ConcurrentAccess() {
}
entryOld := s.buildEntryFromRecord(domainRecordOld)

s.metadataMgr.On("GetDomain", &persistence.GetDomainRequest{ID: id}).Return(domainRecordOld, nil).Maybe()
s.metadataMgr.On("ListDomains", &persistence.ListDomainsRequest{
s.metadataMgr.On("GetDomain", mock.Anything, &persistence.GetDomainRequest{ID: id}).Return(domainRecordOld, nil).Maybe()
s.metadataMgr.On("ListDomains", mock.Anything, &persistence.ListDomainsRequest{
PageSize: domainCacheRefreshPageSize,
NextPageToken: nil,
}).Return(&persistence.ListDomainsResponse{
Expand Down
15 changes: 11 additions & 4 deletions common/domain/dlqMessageHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
package domain

import (
"context"

"github.com/uber/cadence/.gen/go/replicator"
"github.com/uber/cadence/.gen/go/shared"
"github.com/uber/cadence/common/log"
Expand Down Expand Up @@ -65,12 +67,13 @@ func (d *dlqMessageHandlerImpl) Read(
pageToken []byte,
) ([]*replicator.ReplicationTask, []byte, error) {

ackLevel, err := d.domainReplicationQueue.GetDLQAckLevel()
ackLevel, err := d.domainReplicationQueue.GetDLQAckLevel(context.TODO())
if err != nil {
return nil, nil, err
}

return d.domainReplicationQueue.GetMessagesFromDLQ(
context.TODO(),
ackLevel,
lastMessageID,
pageSize,
Expand All @@ -83,19 +86,21 @@ func (d *dlqMessageHandlerImpl) Purge(
lastMessageID int64,
) error {

ackLevel, err := d.domainReplicationQueue.GetDLQAckLevel()
ackLevel, err := d.domainReplicationQueue.GetDLQAckLevel(context.TODO())
if err != nil {
return err
}

if err := d.domainReplicationQueue.RangeDeleteMessagesFromDLQ(
context.TODO(),
ackLevel,
lastMessageID,
); err != nil {
return err
}

if err := d.domainReplicationQueue.UpdateDLQAckLevel(
context.TODO(),
lastMessageID,
); err != nil {
d.logger.Error("Failed to update DLQ ack level after purging messages", tag.Error(err))
Expand All @@ -111,12 +116,13 @@ func (d *dlqMessageHandlerImpl) Merge(
pageToken []byte,
) ([]byte, error) {

ackLevel, err := d.domainReplicationQueue.GetDLQAckLevel()
ackLevel, err := d.domainReplicationQueue.GetDLQAckLevel(context.TODO())
if err != nil {
return nil, err
}

messages, token, err := d.domainReplicationQueue.GetMessagesFromDLQ(
context.TODO(),
ackLevel,
lastMessageID,
pageSize,
Expand All @@ -142,13 +148,14 @@ func (d *dlqMessageHandlerImpl) Merge(
}

if err := d.domainReplicationQueue.RangeDeleteMessagesFromDLQ(
context.TODO(),
ackLevel,
ackedMessageID,
); err != nil {
d.logger.Error("failed to delete merged tasks on merging domain DLQ message", tag.Error(err))
return nil, err
}
if err := d.domainReplicationQueue.UpdateDLQAckLevel(ackedMessageID); err != nil {
if err := d.domainReplicationQueue.UpdateDLQAckLevel(context.TODO(), ackedMessageID); err != nil {
d.logger.Error("failed to update ack level on merging domain DLQ message", tag.Error(err))
}

Expand Down
Loading

0 comments on commit a9a41c6

Please sign in to comment.