Skip to content

Commit

Permalink
Implement history handler for fetching and responding cross-cluster tasks (cadence-workflow#4329)
Browse files Browse the repository at this point in the history

- Implement history GetCrossClusterTasks handler
- Implement history RespondCrossClusterTasksCompleted handler
  • Loading branch information
yycptt authored Jul 28, 2021
1 parent adbffa4 commit 40c5f18
Show file tree
Hide file tree
Showing 11 changed files with 393 additions and 54 deletions.
51 changes: 11 additions & 40 deletions client/history/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package history

import (
"context"
"math"
"sync"
"time"

Expand Down Expand Up @@ -824,7 +823,7 @@ func (c *clientImpl) GetReplicationMessages(
}

// preserve 5% timeout to return partial of the result if context is timing out
requestContext, cancel := c.createChildContext(ctx, 0.05)
requestContext, cancel := common.CreateChildContext(ctx, 0.05)
defer cancel()

var wg sync.WaitGroup
Expand Down Expand Up @@ -1036,7 +1035,7 @@ func (c *clientImpl) GetCrossClusterTasks(
}

// preserve 5% timeout to return partial of the result if context is timing out
ctx, cancel := c.createChildContext(ctx, 0.05)
ctx, cancel := common.CreateChildContext(ctx, 0.05)
defer cancel()

futureByClient := make(map[Client]future.Future, len(requestByClient))
Expand All @@ -1050,36 +1049,29 @@ func (c *clientImpl) GetCrossClusterTasks(
}

response := &types.GetCrossClusterTasksResponse{
TasksByShard: make(map[int32][]*types.CrossClusterTaskRequest),
TasksByShard: make(map[int32][]*types.CrossClusterTaskRequest),
FailedCauseByShard: make(map[int32]types.GetTaskFailedCause),
}
var err error
for _, future := range futureByClient {
for client, future := range futureByClient {
var resp *types.GetCrossClusterTasksResponse
if futureErr := future.Get(ctx, &resp); futureErr != nil {
c.logger.Error("Failed to get cross cluster tasks", tag.Error(futureErr))
// TODO: return error for each shard and perform backoff at shard level.
// and ensure every shardID in request has a response (either tasks or failed cause).
//
// for _, failedShardID := range requestByClient[client].ShardIDs {
// response.FailedCauseByShard[failedShardID] = ...
// }
//
// for now following the pattern for getting replication tasks:
// ignore errors other than service busy, so that task fetcher in target
// cluster can slow down.
if err == nil && common.IsServiceBusyError(futureErr) {
err = futureErr
for _, failedShardID := range requestByClient[client].ShardIDs {
response.FailedCauseByShard[failedShardID] = common.ConvertErrToGetTaskFailedCause(futureErr)
}
} else {
for shardID, tasks := range resp.TasksByShard {
response.TasksByShard[shardID] = tasks
}
for shardID, failedCause := range resp.FailedCauseByShard {
response.FailedCauseByShard[shardID] = failedCause
}
}
}
// not using a waitGroup for created goroutines as once all futures are unblocked,
// those goroutines will eventually be completed

return response, err
return response, nil
}

func (c *clientImpl) RespondCrossClusterTasksCompleted(
Expand Down Expand Up @@ -1109,27 +1101,6 @@ func (c *clientImpl) RespondCrossClusterTasksCompleted(
return response, nil
}

// createChildContext derives a context from parent whose deadline is moved
// up so that a tailroom fraction of the parent's remaining time is reserved;
// the child keeps (1-tailroom) of the remaining timeout.
// A nil parent, an already-cancelled/expired parent, or a parent without a
// deadline is returned unchanged together with a no-op cancel function.
func (c *clientImpl) createChildContext(
	parent context.Context,
	tailroom float64,
) (context.Context, context.CancelFunc) {
	noop := func() {}
	if parent == nil {
		return nil, noop
	}
	if parent.Err() != nil {
		return parent, noop
	}

	start := time.Now()
	parentDeadline, hasDeadline := parent.Deadline()
	if !hasDeadline || parentDeadline.Before(start) {
		return parent, noop
	}

	// keep (1-tailroom) of the remaining time; round up to avoid a zero span
	remaining := float64(parentDeadline.Sub(start)) * (1.0 - tailroom)
	childDeadline := start.Add(time.Duration(math.Ceil(remaining)))
	return context.WithDeadline(parent, childDeadline)
}

func (c *clientImpl) createContext(parent context.Context) (context.Context, context.CancelFunc) {
if parent == nil {
return context.WithTimeout(context.Background(), c.timeout)
Expand Down
41 changes: 41 additions & 0 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"encoding/pem"
"fmt"
"io/ioutil"
"math"
"math/rand"
"sort"
"strconv"
Expand Down Expand Up @@ -330,6 +331,31 @@ func IsValidContext(ctx context.Context) error {
return nil
}

// CreateChildContext creates a child context which shorted context timeout
// from the given parent context
// tailroom must be in range [0, 1] and
// (1-tailroom) * parent timeout will be the new child context timeout
func CreateChildContext(
parent context.Context,
tailroom float64,
) (context.Context, context.CancelFunc) {
if parent == nil {
return nil, func() {}
}
if parent.Err() != nil {
return parent, func() {}
}

now := time.Now()
deadline, ok := parent.Deadline()
if !ok || deadline.Before(now) {
return parent, func() {}
}

newDeadline := now.Add(time.Duration(math.Ceil(float64(deadline.Sub(now)) * (1.0 - tailroom))))
return context.WithDeadline(parent, newDeadline)
}

// GenerateRandomString is used for generate test string
func GenerateRandomString(n int) string {
rand.Seed(time.Now().UnixNano())
Expand Down Expand Up @@ -932,6 +958,7 @@ func SleepWithMinDuration(desired time.Duration, available time.Duration) time.D
return available - d
}

// LoadRSAPublicKey loads a rsa.PublicKey from the given filepath
func LoadRSAPublicKey(path string) (*rsa.PublicKey, error) {
key, err := ioutil.ReadFile(path)
if err != nil {
Expand All @@ -949,3 +976,17 @@ func LoadRSAPublicKey(path string) (*rsa.PublicKey, error) {
publicKey := pub.(*rsa.PublicKey)
return publicKey, nil
}

// ConvertErrToGetTaskFailedCause maps an error encountered while fetching
// tasks to the corresponding GetTaskFailedCause enum value; errors that
// match none of the known categories are reported as uncategorized.
func ConvertErrToGetTaskFailedCause(err error) types.GetTaskFailedCause {
	switch {
	case IsContextTimeoutError(err):
		return types.GetTaskFailedCauseTimeout
	case IsServiceBusyError(err):
		return types.GetTaskFailedCauseServiceBusy
	default:
		if _, isShardOwnershipLost := err.(*types.ShardOwnershipLostError); isShardOwnershipLost {
			return types.GetTaskFailedCauseShardOwnershipLost
		}
		return types.GetTaskFailedCauseUncategorized
	}
}
28 changes: 28 additions & 0 deletions common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,31 @@ func TestValidateDomainUUID(t *testing.T) {
})
}
}

// TestConvertErrToGetTaskFailedCause verifies that each known error
// category is mapped to its matching GetTaskFailedCause value.
func TestConvertErrToGetTaskFailedCause(t *testing.T) {
	casesByExpectedCause := map[types.GetTaskFailedCause]error{
		types.GetTaskFailedCauseUncategorized:      errors.New("some random error"),
		types.GetTaskFailedCauseTimeout:            context.DeadlineExceeded,
		types.GetTaskFailedCauseServiceBusy:        &types.ServiceBusyError{},
		types.GetTaskFailedCauseShardOwnershipLost: &types.ShardOwnershipLostError{},
	}

	for expectedCause, inputErr := range casesByExpectedCause {
		require.Equal(t, expectedCause, ConvertErrToGetTaskFailedCause(inputErr))
	}
}
2 changes: 2 additions & 0 deletions service/history/engine/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ type (
SyncActivity(ctx context.Context, request *types.SyncActivityRequest) error
GetReplicationMessages(ctx context.Context, pollingCluster string, lastReadMessageID int64) (*types.ReplicationMessages, error)
GetDLQReplicationMessages(ctx context.Context, taskInfos []*types.ReplicationTaskInfo) ([]*types.ReplicationTask, error)
GetCrossClusterTasks(ctx context.Context, targetCluster string) ([]*types.CrossClusterTaskRequest, error)
RespondCrossClusterTasksCompleted(ctx context.Context, targetCluster string, responses []*types.CrossClusterTaskResponse) error
QueryWorkflow(ctx context.Context, request *types.HistoryQueryWorkflowRequest) (*types.HistoryQueryWorkflowResponse, error)
ReapplyEvents(ctx context.Context, domainUUID string, workflowID string, runID string, events []*types.HistoryEvent) error
ReadDLQMessages(ctx context.Context, messagesRequest *types.ReadDLQMessagesRequest) (*types.ReadDLQMessagesResponse, error)
Expand Down
29 changes: 29 additions & 0 deletions service/history/engine/interface_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 74 additions & 6 deletions service/history/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ package history

import (
"context"
"errors"
"fmt"
"strconv"
"sync"
Expand All @@ -36,6 +35,7 @@ import (

"github.com/uber/cadence/common"
"github.com/uber/cadence/common/definition"
"github.com/uber/cadence/common/future"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/log/tag"
"github.com/uber/cadence/common/metrics"
Expand Down Expand Up @@ -1873,8 +1873,56 @@ func (h *handlerImpl) GetCrossClusterTasks(
_, sw := h.startRequestProfile(ctx, metrics.HistoryGetCrossClusterTasksScope)
defer sw.Stop()

// TODO: implement this API when cross cluster queue is implemented
return nil, errors.New("not implemented")
if h.isShuttingDown() {
return nil, errShuttingDown
}

ctx, cancel := common.CreateChildContext(ctx, 0.05)
defer cancel()

futureByShardID := make(map[int32]future.Future, len(request.ShardIDs))
for _, shardID := range request.ShardIDs {
future, settable := future.NewFuture()
futureByShardID[shardID] = future
go func(shardID int32) {
logger := h.GetLogger().WithTags(tag.ShardID(int(shardID)))
engine, err := h.controller.GetEngineForShard(int(shardID))
if err != nil {
logger.Error("History engine not found for shard", tag.Error(err))
var owner string
if info, err := h.GetHistoryServiceResolver().Lookup(strconv.Itoa(int(shardID))); err == nil {
owner = info.GetAddress()
}
settable.Set(nil, shard.CreateShardOwnershipLostError(h.GetHostInfo().GetAddress(), owner))
return
}

if tasks, err := engine.GetCrossClusterTasks(ctx, request.TargetCluster); err != nil {
logger.Error("Failed to get cross cluster tasks", tag.Error(err))
settable.Set(nil, h.convertError(err))
} else {
settable.Set(tasks, nil)
}
}(shardID)
}

response := &types.GetCrossClusterTasksResponse{
TasksByShard: make(map[int32][]*types.CrossClusterTaskRequest),
FailedCauseByShard: make(map[int32]types.GetTaskFailedCause),
}
for shardID, future := range futureByShardID {
var taskRequests []*types.CrossClusterTaskRequest
if futureErr := future.Get(ctx, &taskRequests); futureErr != nil {
response.FailedCauseByShard[shardID] = common.ConvertErrToGetTaskFailedCause(futureErr)
} else {
response.TasksByShard[shardID] = taskRequests
}
}
// not using a waitGroup for created goroutines here
// as once all futures are unblocked,
// those goroutines will eventually be completed

return response, nil
}

func (h *handlerImpl) RespondCrossClusterTasksCompleted(
Expand All @@ -1884,11 +1932,31 @@ func (h *handlerImpl) RespondCrossClusterTasksCompleted(
defer log.CapturePanic(h.GetLogger(), &retError)
h.startWG.Wait()

_, sw := h.startRequestProfile(ctx, metrics.HistoryRespondCrossClusterTasksCompletedScope)
scope, sw := h.startRequestProfile(ctx, metrics.HistoryRespondCrossClusterTasksCompletedScope)
defer sw.Stop()

// TODO: implement this API when cross cluster queue is implemented
return nil, errors.New("not implemented")
if h.isShuttingDown() {
return nil, errShuttingDown
}

engine, err := h.controller.GetEngineForShard(int(request.GetShardID()))
if err != nil {
return nil, h.error(err, scope, "", "")
}

err = engine.RespondCrossClusterTasksCompleted(ctx, request.TargetCluster, request.TaskResponses)
if err != nil {
return nil, h.error(err, scope, "", "")
}

response := &types.RespondCrossClusterTasksCompletedResponse{}
if request.FetchNewTasks {
response.Tasks, err = engine.GetCrossClusterTasks(ctx, request.TargetCluster)
if err != nil {
return nil, h.error(err, scope, "", "")
}
}
return response, nil
}

// convertError is a helper method to convert ShardOwnershipLostError from persistence layer returned by various
Expand Down
Loading

0 comments on commit 40c5f18

Please sign in to comment.