From 4117f968e975dd22572615e41876a0c8f08cfd33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mantas=20=C5=A0idlauskas?= Date: Fri, 4 Aug 2023 03:16:16 +0300 Subject: [PATCH] StartWorkflowExecution: validate RequestID before calling history (#5359) What changed? RequestID is now validated as a UUID at the handler level. Why? When StartWorkflowExecution is called with a malformed UUID, a generic error is returned from the persistence layer (Cassandra, for example). This error is not treated as non-retryable, so Cadence tries to insert the invalid data multiple times. On the client side, only a request timeout is returned, which reveals no details about the nature of the failure. This change validates the UUID on the handler side, so no calls to history/persistence are made. Additionally, the user is told which data is missing or malformed. How did you test it? Updated the unit test to include a malformed-UUID check. --- service/frontend/workflowHandler.go | 65 ++++++++++++------------ service/frontend/workflowHandler_test.go | 9 +++- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/service/frontend/workflowHandler.go b/service/frontend/workflowHandler.go index 025441ef9a7..afe35f77309 100644 --- a/service/frontend/workflowHandler.go +++ b/service/frontend/workflowHandler.go @@ -28,7 +28,7 @@ import ( "sync/atomic" "time" - "github.com/pborman/uuid" + "github.com/google/uuid" "go.uber.org/yarpc" "go.uber.org/yarpc/yarpcerrors" "golang.org/x/sync/errgroup" @@ -131,7 +131,6 @@ var ( errQueryTypeNotSet = &types.BadRequestError{Message: "QueryType is not set on request."} errRequestNotSet = &types.BadRequestError{Message: "Request is nil."} errNoPermission = &types.BadRequestError{Message: "No permission to do this operation."} - errRequestIDNotSet = &types.BadRequestError{Message: "RequestId is not set on request."} errWorkflowTypeNotSet = &types.BadRequestError{Message: "WorkflowType is not set on request."} errInvalidRetention = 
&types.BadRequestError{Message: "RetentionDays is invalid."} errInvalidExecutionStartToCloseTimeoutSeconds = &types.BadRequestError{Message: "A valid ExecutionStartToCloseTimeoutSeconds is not set on request."} @@ -615,7 +614,7 @@ func (wh *WorkflowHandler) PollForActivityTask( ); err != nil { return &types.PollForActivityTaskResponse{}, nil } - pollerID := uuid.New() + pollerID := uuid.New().String() op := func() error { resp, err = wh.GetMatchingClient().PollForActivityTask(ctx, &types.MatchingPollForActivityTaskRequest{ DomainUUID: domainID, @@ -745,7 +744,7 @@ func (wh *WorkflowHandler) PollForDecisionTask( return &types.PollForDecisionTaskResponse{}, nil } - pollerID := uuid.New() + pollerID := uuid.New().String() var matchingResp *types.MatchingPollForDecisionTaskResponse op := func() error { matchingResp, err = wh.GetMatchingClient().PollForDecisionTask(ctx, &types.MatchingPollForDecisionTaskRequest{ @@ -2082,18 +2081,14 @@ func (wh *WorkflowHandler) StartWorkflowExecution( scope, sw := wh.startRequestProfileWithDomain(ctx, metrics.FrontendStartWorkflowExecutionScope, startRequest) defer sw.Stop() - if wh.isShuttingDown() { - return nil, errShuttingDown - } - - if err := wh.versionChecker.ClientSupported(ctx, wh.config.EnableClientVersionCheck()); err != nil { - return nil, wh.error(err, scope) - } - if startRequest == nil { return nil, wh.error(errRequestNotSet, scope) } + if wh.isShuttingDown() { + return nil, errShuttingDown + } + domainName := startRequest.GetDomain() wfExecution := &types.WorkflowExecution{ WorkflowID: startRequest.GetWorkflowID(), @@ -2104,6 +2099,21 @@ func (wh *WorkflowHandler) StartWorkflowExecution( return nil, wh.error(errDomainNotSet, scope, tags...) } + if startRequest.GetWorkflowID() == "" { + return nil, wh.error(errWorkflowIDNotSet, scope, tags...) 
+ } + + if _, err := uuid.Parse(startRequest.RequestID); err != nil { + return nil, wh.error(&types.BadRequestError{Message: fmt.Sprintf("requestId %q is not a valid UUID", startRequest.RequestID)}, scope, tags...) + } + if startRequest.WorkflowType == nil || startRequest.WorkflowType.GetName() == "" { + return nil, wh.error(errWorkflowTypeNotSet, scope, tags...) + } + + if err := wh.versionChecker.ClientSupported(ctx, wh.config.EnableClientVersionCheck()); err != nil { + return nil, wh.error(err, scope) + } + if ok := wh.allow(ratelimitTypeUser, startRequest); !ok { return nil, wh.error(createServiceBusyError(), scope, tags...) } @@ -2121,10 +2131,6 @@ func (wh *WorkflowHandler) StartWorkflowExecution( return nil, wh.error(errDomainTooLong, scope, tags...) } - if startRequest.GetWorkflowID() == "" { - return nil, wh.error(errWorkflowIDNotSet, scope, tags...) - } - if !common.ValidIDLength( startRequest.GetWorkflowID(), scope, @@ -2141,20 +2147,10 @@ func (wh *WorkflowHandler) StartWorkflowExecution( return nil, wh.error(err, scope, tags...) } - if startRequest.GetCronSchedule() != "" { - if _, err := backoff.ValidateSchedule(startRequest.GetCronSchedule()); err != nil { - return nil, wh.error(err, scope, tags...) - } - } - wh.GetLogger().Debug( "Received StartWorkflowExecution. WorkflowID", tag.WorkflowID(startRequest.GetWorkflowID())) - if startRequest.WorkflowType == nil || startRequest.WorkflowType.GetName() == "" { - return nil, wh.error(errWorkflowTypeNotSet, scope, tags...) - } - if !common.ValidIDLength( startRequest.WorkflowType.GetName(), scope, @@ -2189,6 +2185,11 @@ func (wh *WorkflowHandler) StartWorkflowExecution( jitter := startRequest.GetJitterStartSeconds() cron := startRequest.GetCronSchedule() + if cron != "" { + if _, err := backoff.ValidateSchedule(startRequest.GetCronSchedule()); err != nil { + return nil, wh.error(err, scope, tags...) 
+ } + } if jitter > 0 && cron != "" { // Calculate the cron duration and ensure that jitter is not greater than the cron duration, // because that would be confusing to users. @@ -2205,10 +2206,6 @@ func (wh *WorkflowHandler) StartWorkflowExecution( } } - if startRequest.GetRequestID() == "" { - return nil, wh.error(errRequestIDNotSet, scope, tags...) - } - if !common.ValidIDLength( startRequest.GetRequestID(), scope, @@ -4264,8 +4261,10 @@ func validateExecution(w *types.WorkflowExecution) error { if w.GetWorkflowID() == "" { return errWorkflowIDNotSet } - if w.GetRunID() != "" && uuid.Parse(w.GetRunID()) == nil { - return errInvalidRunID + if w.GetRunID() != "" { + if _, err := uuid.Parse(w.GetRunID()); err != nil { + return errInvalidRunID + } } return nil } @@ -4715,7 +4714,7 @@ func (wh *WorkflowHandler) normalizeVersionedErrors(ctx context.Context, err err func constructRestartWorkflowRequest(w *types.WorkflowExecutionStartedEventAttributes, domain string, identity string, workflowID string) *types.StartWorkflowExecutionRequest { startRequest := &types.StartWorkflowExecutionRequest{ - RequestID: uuid.New(), + RequestID: uuid.New().String(), Domain: domain, WorkflowID: workflowID, WorkflowType: &types.WorkflowType{ diff --git a/service/frontend/workflowHandler_test.go b/service/frontend/workflowHandler_test.go index 06dc00fc158..7214ffc73f9 100644 --- a/service/frontend/workflowHandler_test.go +++ b/service/frontend/workflowHandler_test.go @@ -323,7 +323,12 @@ func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_RequestIdNotSet } _, err := wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) s.Error(err) - s.Equal(errRequestIDNotSet, err) + s.Equal(&types.BadRequestError{Message: "requestId \"\" is not a valid UUID"}, err) + startWorkflowExecutionRequest.RequestID = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + _, err = wh.StartWorkflowExecution(context.Background(), startWorkflowExecutionRequest) + s.Error(err) + 
s.Equal(&types.BadRequestError{Message: "requestId \"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx\" is not a valid UUID"}, err) + } func (s *workflowHandlerSuite) TestStartWorkflowExecution_Failed_BadDelayStartSeconds() { @@ -1389,7 +1394,7 @@ func (s *workflowHandlerSuite) TestRestartWorkflowExecution__Success() { }, Identity: "", }) - s.Equal(resp.GetRunID(), testRunID) + s.Equal(testRunID, resp.GetRunID()) s.NoError(err) }