Skip to content

Commit

Permalink
matching: refactor internalTask type to support remote forwarded tasks (
Browse files Browse the repository at this point in the history
  • Loading branch information
venkat1109 authored Jul 15, 2019
1 parent 314cdda commit 4045618
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 91 deletions.
8 changes: 4 additions & 4 deletions service/matching/matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (tm *TaskMatcher) Offer(ctx context.Context, task *internalTask) (bool, err
if task.isQuery() {
select {
case tm.queryTaskC <- task:
<-task.syncResponseCh
<-task.responseC
return true, nil
case <-ctx.Done():
return false, ctx.Err()
Expand All @@ -97,10 +97,10 @@ func (tm *TaskMatcher) Offer(ctx context.Context, task *internalTask) (bool, err

select {
case tm.taskC <- task: // poller picked up the task
if task.syncResponseCh != nil {
if task.responseC != nil {
// if there is a response channel, block until resp is received
// and return error if the response contains error
err = <-task.syncResponseCh
err = <-task.responseC
return true, err
}
return false, nil
Expand Down Expand Up @@ -135,7 +135,7 @@ func (tm *TaskMatcher) MustOffer(ctx context.Context, task *internalTask) error
func (tm *TaskMatcher) Poll(ctx context.Context) (*internalTask, error) {
select {
case task := <-tm.taskC:
if task.syncResponseCh != nil {
if task.responseC != nil {
tm.scope().IncCounter(metrics.PollSuccessWithSyncCounter)
}
tm.scope().IncCounter(metrics.PollSuccessCounter)
Expand Down
88 changes: 49 additions & 39 deletions service/matching/matchingEngine.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,24 +303,30 @@ pollLoop:
return nil, err
}

if task.queryInfo != nil {
if task.isForwarded() {
// forwarded tasks are already started and are matched remotely on a
// different matching host. So, simply forward the response
return task.forwarded.decisionTaskInfo, nil
}

if task.isQuery() {
task.finish(nil) // this only means query task sync match succeed.

// For a query task, we don't need to update history to record that the decision task started, but we do need
// the NextEventID so that the frontend knows which history events to load for this decision task.
mutableStateResp, err := e.historyService.GetMutableState(ctx, &h.GetMutableStateRequest{
DomainUUID: req.DomainUUID,
Execution: &task.workflowExecution,
Execution: task.workflowExecution(),
})
if err != nil {
// will notify query client that the query task failed
e.deliverQueryResult(task.queryInfo.taskID, &queryResult{err: err})
e.deliverQueryResult(task.query.taskID, &queryResult{err: err})
return emptyPollForDecisionTaskResponse, nil
}

if mutableStateResp.GetPreviousStartedEventId() <= 0 {
// first decision task is not processed by worker yet.
e.deliverQueryResult(task.queryInfo.taskID,
e.deliverQueryResult(task.query.taskID,
&queryResult{err: errQueryBeforeFirstDecisionCompleted, waitNextEventID: mutableStateResp.GetNextEventId()})
return emptyPollForDecisionTaskResponse, nil
}
Expand Down Expand Up @@ -352,7 +358,7 @@ pollLoop:
switch err.(type) {
case *workflow.EntityNotExistsError, *h.EventAlreadyStartedError:
e.logger.Debug(fmt.Sprintf("Duplicated decision task taskList=%v, taskID=%v",
taskListName, task.info.TaskID))
taskListName, task.generic.TaskID))
task.finish(nil)
default:
task.finish(err)
Expand Down Expand Up @@ -404,12 +410,19 @@ pollLoop:
}
return nil, err
}

if task.isForwarded() {
// forwarded tasks are already started and are matched remotely on a
// different matching host. So, simply forward the response
return task.forwarded.activityTaskInfo, nil
}

resp, err := e.recordActivityTaskStarted(ctx, request, task)
if err != nil {
switch err.(type) {
case *workflow.EntityNotExistsError, *h.EventAlreadyStartedError:
e.logger.Debug(fmt.Sprintf("Duplicated activity task taskList=%v, taskID=%v",
taskListName, task.info.TaskID))
taskListName, task.generic.TaskID))
task.finish(nil)
default:
task.finish(err)
Expand Down Expand Up @@ -440,22 +453,19 @@ query_loop:
if err != nil {
return nil, err
}
queryTask := &queryTaskInfo{
queryRequest: queryRequest,
taskID: uuid.New(),
}
err = tlMgr.DispatchQueryTask(ctx, queryTask)
taskID := uuid.New()
err = tlMgr.DispatchQueryTask(ctx, taskID, queryRequest)
if err != nil {
return nil, err
}

queryResultCh := make(chan *queryResult, 1)
e.queryMapLock.Lock()
e.queryTaskMap[queryTask.taskID] = queryResultCh
e.queryTaskMap[taskID] = queryResultCh
e.queryMapLock.Unlock()
defer func() {
e.queryMapLock.Lock()
delete(e.queryTaskMap, queryTask.taskID)
delete(e.queryTaskMap, taskID)
e.queryMapLock.Unlock()
}()

Expand Down Expand Up @@ -608,33 +618,33 @@ func (e *matchingEngineImpl) createPollForDecisionTaskResponse(
) *m.PollForDecisionTaskResponse {

var token []byte
if task.queryInfo != nil {
if task.isQuery() {
// for a query task
queryRequest := task.queryInfo.queryRequest
queryRequest := task.query.queryRequest
taskToken := &common.QueryTaskToken{
DomainID: *queryRequest.DomainUUID,
TaskList: *queryRequest.TaskList.Name,
TaskID: task.queryInfo.taskID,
TaskID: task.query.taskID,
}
token, _ = e.tokenSerializer.SerializeQueryTaskToken(taskToken)
} else {
taskoken := &common.TaskToken{
DomainID: task.info.DomainID,
WorkflowID: task.info.WorkflowID,
RunID: task.info.RunID,
DomainID: task.generic.DomainID,
WorkflowID: task.generic.WorkflowID,
RunID: task.generic.RunID,
ScheduleID: historyResponse.GetScheduledEventId(),
ScheduleAttempt: historyResponse.GetAttempt(),
}
token, _ = e.tokenSerializer.Serialize(taskoken)
if task.syncResponseCh == nil {
if task.responseC == nil {
scope := e.metricsClient.Scope(metrics.MatchingPollForDecisionTaskScope)
scope.Tagged(metrics.DomainTag(task.domainName)).RecordTimer(metrics.AsyncMatchLatency, time.Since(task.info.CreatedTime))
scope.Tagged(metrics.DomainTag(task.domainName)).RecordTimer(metrics.AsyncMatchLatency, time.Since(task.generic.CreatedTime))
}
}

response := common.CreateMatchingPollForDecisionTaskResponse(historyResponse, workflowExecutionPtr(task.workflowExecution), token)
if task.queryInfo != nil {
response.Query = task.queryInfo.queryRequest.QueryRequest.Query
response := common.CreateMatchingPollForDecisionTaskResponse(historyResponse, task.workflowExecution(), token)
if task.query != nil {
response.Query = task.query.queryRequest.QueryRequest.Query
}
response.BacklogCountHint = common.Int64Ptr(task.backlogCountHint)
return response
Expand All @@ -654,17 +664,17 @@ func (e *matchingEngineImpl) createPollForActivityTaskResponse(
if attributes.ActivityId == nil {
panic("ActivityTaskScheduledEventAttributes.ActivityID is not set")
}
if task.syncResponseCh == nil {
if task.responseC == nil {
scope := e.metricsClient.Scope(metrics.MatchingPollForActivityTaskScope)
scope.Tagged(metrics.DomainTag(task.domainName)).RecordTimer(metrics.AsyncMatchLatency, time.Since(task.info.CreatedTime))
scope.Tagged(metrics.DomainTag(task.domainName)).RecordTimer(metrics.AsyncMatchLatency, time.Since(task.generic.CreatedTime))
}

response := &workflow.PollForActivityTaskResponse{}
response.ActivityId = attributes.ActivityId
response.ActivityType = attributes.ActivityType
response.Header = attributes.Header
response.Input = attributes.Input
response.WorkflowExecution = workflowExecutionPtr(task.workflowExecution)
response.WorkflowExecution = task.workflowExecution()
response.ScheduledTimestampOfThisAttempt = historyResponse.ScheduledTimestampOfThisAttempt
response.ScheduledTimestamp = common.Int64Ptr(*scheduledEvent.Timestamp)
response.ScheduleToCloseTimeoutSeconds = common.Int32Ptr(*attributes.ScheduleToCloseTimeoutSeconds)
Expand All @@ -673,10 +683,10 @@ func (e *matchingEngineImpl) createPollForActivityTaskResponse(
response.HeartbeatTimeoutSeconds = common.Int32Ptr(*attributes.HeartbeatTimeoutSeconds)

token := &common.TaskToken{
DomainID: task.info.DomainID,
WorkflowID: task.info.WorkflowID,
RunID: task.info.RunID,
ScheduleID: task.info.ScheduleID,
DomainID: task.generic.DomainID,
WorkflowID: task.generic.WorkflowID,
RunID: task.generic.RunID,
ScheduleID: task.generic.ScheduleID,
ScheduleAttempt: historyResponse.GetAttempt(),
}

Expand All @@ -694,10 +704,10 @@ func (e *matchingEngineImpl) recordDecisionTaskStarted(
task *internalTask,
) (*h.RecordDecisionTaskStartedResponse, error) {
request := &h.RecordDecisionTaskStartedRequest{
DomainUUID: &task.info.DomainID,
WorkflowExecution: &task.workflowExecution,
ScheduleId: &task.info.ScheduleID,
TaskId: &task.info.TaskID,
DomainUUID: &task.generic.DomainID,
WorkflowExecution: task.workflowExecution(),
ScheduleId: &task.generic.ScheduleID,
TaskId: &task.generic.TaskID,
RequestId: common.StringPtr(uuid.New()),
PollRequest: pollReq,
}
Expand All @@ -723,10 +733,10 @@ func (e *matchingEngineImpl) recordActivityTaskStarted(
task *internalTask,
) (*h.RecordActivityTaskStartedResponse, error) {
request := &h.RecordActivityTaskStartedRequest{
DomainUUID: &task.info.DomainID,
WorkflowExecution: &task.workflowExecution,
ScheduleId: &task.info.ScheduleID,
TaskId: &task.info.TaskID,
DomainUUID: &task.generic.DomainID,
WorkflowExecution: task.workflowExecution(),
ScheduleId: &task.generic.ScheduleID,
TaskId: &task.generic.TaskID,
RequestId: common.StringPtr(uuid.New()),
PollRequest: pollReq,
}
Expand Down
8 changes: 4 additions & 4 deletions service/matching/matchingEngine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1396,10 +1396,10 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() {
ctx2, err := s.matchingEngine.getTask(context.Background(), tlID, nil, tlKind)
s.NoError(err)

s.NotEqual(ctx.info.TaskID, ctx2.info.TaskID)
s.Equal(ctx.info.WorkflowID, ctx2.info.WorkflowID)
s.Equal(ctx.info.RunID, ctx2.info.RunID)
s.Equal(ctx.info.ScheduleID, ctx2.info.ScheduleID)
s.NotEqual(ctx.generic.TaskID, ctx2.generic.TaskID)
s.Equal(ctx.generic.WorkflowID, ctx2.generic.WorkflowID)
s.Equal(ctx.generic.RunID, ctx2.generic.RunID)
s.Equal(ctx.generic.ScheduleID, ctx2.generic.ScheduleID)

ctx2.finish(nil)
s.EqualValues(0, s.taskManager.getTaskCount(tlID))
Expand Down
91 changes: 63 additions & 28 deletions service/matching/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,67 +27,102 @@ import (
)

type (
// genericTaskInfo contains the info for an activity or decision task
genericTaskInfo struct {
*persistence.TaskInfo
completionFunc func(*persistence.TaskInfo, error)
}
// queryTaskInfo contains the info for a query task
queryTaskInfo struct {
taskID string
queryRequest *m.QueryWorkflowRequest
}
// internalTask represents an activity, decision or query task
// holds task specific info and additional metadata
// forwardedTaskInfo contains the forwarded info for any task forwarded from
// another matching host. Forwarded tasks are already started
forwardedTaskInfo struct {
decisionTaskInfo *m.PollForDecisionTaskResponse
activityTaskInfo *s.PollForActivityTaskResponse
}
// internalTask represents an activity, decision, query or remote forwarded task.
// this struct is more like a union and only one of [ query, generic, forwarded ] is
// non-nil for any given task
//
// NOTE(review): this listing looks like a rendered diff without +/- markers, so
// the struct below shows two sets of fields (info ... completionFunc, then
// query ... backlogCountHint) with domainName and backlogCountHint declared
// twice -- confirm the real field list against the repository.
internalTask struct {
info *persistence.TaskInfo
syncResponseCh chan error
workflowExecution s.WorkflowExecution
queryInfo *queryTaskInfo
backlogCountHint int64
domainName string
completionFunc func(*internalTask, error)
query *queryTaskInfo // non-nil for a locally matched query task
generic *genericTaskInfo // non-nil for a locally matched activity or decision task
forwarded *forwardedTaskInfo // non-nil for a remote forwarded task
domainName string
responseC chan error // non-nil only when there is a caller waiting for a response (sync-match)
backlogCountHint int64
}
)

func newInternalTask(
info *persistence.TaskInfo,
completionFunc func(*internalTask, error),
completionFunc func(*persistence.TaskInfo, error),
forSyncMatch bool,
) *internalTask {
task := &internalTask{
info: info,
completionFunc: completionFunc,
workflowExecution: s.WorkflowExecution{
WorkflowId: &info.WorkflowID,
RunId: &info.RunID,
generic: &genericTaskInfo{
TaskInfo: info,
completionFunc: completionFunc,
},
}
if forSyncMatch {
task.syncResponseCh = make(chan error, 1)
task.responseC = make(chan error, 1)
}
return task
}

func newInternalQueryTask(
queryInfo *queryTaskInfo,
completionFunc func(*internalTask, error),
taskID string,
request *m.QueryWorkflowRequest,
) *internalTask {
return &internalTask{
info: &persistence.TaskInfo{
DomainID: queryInfo.queryRequest.GetDomainUUID(),
WorkflowID: queryInfo.queryRequest.QueryRequest.Execution.GetWorkflowId(),
RunID: queryInfo.queryRequest.QueryRequest.Execution.GetRunId(),
query: &queryTaskInfo{
taskID: taskID,
queryRequest: request,
},
completionFunc: completionFunc,
queryInfo: queryInfo,
workflowExecution: *queryInfo.queryRequest.QueryRequest.GetExecution(),
syncResponseCh: make(chan error, 1),
responseC: make(chan error, 1),
}
}

// newInternalForwardedTask builds an internalTask whose only populated field
// is forwarded, which is set to info. Every other field keeps its zero value.
func newInternalForwardedTask(info *forwardedTaskInfo) *internalTask {
	task := new(internalTask)
	task.forwarded = info
	return task
}

// isQuery returns true if the underlying task is a query task
func (task *internalTask) isQuery() bool {
return task.queryInfo != nil
return task.query != nil
}

// isForwarded reports whether this task was forwarded by a remote matching
// host, i.e. whether its forwarded field is set. Forwarded tasks are already
// marked as started in history.
func (task *internalTask) isForwarded() bool {
	info := task.forwarded
	return info != nil
}

// workflowExecution returns the workflow execution this task refers to,
// picking the source by whichever task variant is populated:
//   - generic: a new WorkflowExecution pointing at the task info's WorkflowID and RunID fields
//   - query: the execution carried by the query request
//   - forwarded: the execution from the decision poll response, else from the activity poll response
//
// When none of these is set, an empty WorkflowExecution is returned.
func (task *internalTask) workflowExecution() *s.WorkflowExecution {
	if info := task.generic; info != nil {
		return &s.WorkflowExecution{WorkflowId: &info.WorkflowID, RunId: &info.RunID}
	}
	if query := task.query; query != nil {
		return query.queryRequest.GetQueryRequest().GetExecution()
	}
	if fwd := task.forwarded; fwd != nil {
		// decision info takes precedence over activity info, as in the original case order
		if fwd.decisionTaskInfo != nil {
			return fwd.decisionTaskInfo.WorkflowExecution
		}
		if fwd.activityTaskInfo != nil {
			return fwd.activityTaskInfo.WorkflowExecution
		}
	}
	return &s.WorkflowExecution{}
}

// finish marks a task as finished. It should be called after a poller picks up
// a task and marks it as started. If the task cannot be marked as started, this
// method should be called with a non-nil error argument.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers; the
// first statement below appears to be the pre-change body and the switch its
// replacement -- confirm against the repository.
func (task *internalTask) finish(err error) {
task.completionFunc(task, err)
switch {
// a response channel is set: send the outcome on it
case task.responseC != nil:
task.responseC <- err
// NOTE(review): this case dereferences task.generic without a nil check; a task
// built by newInternalForwardedTask has neither responseC nor generic set and
// would panic here -- confirm such tasks never reach finish.
case task.generic.completionFunc != nil:
task.generic.completionFunc(task.generic.TaskInfo, err)
}
}
Loading

0 comments on commit 4045618

Please sign in to comment.