diff --git a/agent/agent_server_cmd.go b/agent/agent_server_cmd.go index 89d7219..184c2c5 100644 --- a/agent/agent_server_cmd.go +++ b/agent/agent_server_cmd.go @@ -27,5 +27,9 @@ func (as *AgentServer) handleCommandConnection(conn net.Conn, reply.Type = cmd.ControlMessage_GetStatusResponse.Enum() reply.GetStatusResponse = as.handleStatus(command.GetStatusRequest) } + if command.GetType() == cmd.ControlMessage_StopRequest { + reply.Type = cmd.ControlMessage_StopResponse.Enum() + reply.StopResponse = as.handleStopRequest(command.StopRequest) + } return reply } diff --git a/agent/agent_server_executor_start.go b/agent/agent_server_executor_start.go index 877d964..978a23a 100644 --- a/agent/agent_server_executor_start.go +++ b/agent/agent_server_executor_start.go @@ -55,6 +55,7 @@ func (as *AgentServer) handleStart(conn net.Conn, } else { reply.Pid = proto.Int32(int32(cmd.Process.Pid)) } + stat.Process = cmd.Process cmd.Wait() stat.StopTime = time.Now() diff --git a/agent/agent_server_executor_status.go b/agent/agent_server_executor_status.go index 10e70e0..38a28dc 100644 --- a/agent/agent_server_executor_status.go +++ b/agent/agent_server_executor_status.go @@ -1,8 +1,6 @@ package agent import ( - "fmt" - "github.com/chrislusf/glow/driver/cmd" "github.com/golang/protobuf/proto" ) @@ -18,7 +16,18 @@ func (as *AgentServer) handleStatus(getStatusRequest *cmd.GetStatusRequest) *cmd StopTime: proto.Int64(stat.StopTime.Unix()), } - fmt.Printf("stat: %v\n", stat) + return reply +} + +func (as *AgentServer) handleStopRequest(stopRequest *cmd.StopRequest) *cmd.StopResponse { + requestId := stopRequest.GetStartRequestHash() + stat := as.localExecutorManager.getExecutorStatus(requestId) + + stat.Process.Kill() + + reply := &cmd.StopResponse{ + StartRequestHash: proto.Int32(requestId), + } return reply } diff --git a/agent/local_executor_status.go b/agent/local_executor_status.go index af0355a..18e477a 100644 --- a/agent/local_executor_status.go +++ b/agent/local_executor_status.go @@ -1,6 +1,7 @@ package agent import ( + "os" "time" ) @@ -11,6 +12,7 @@ type ExecutorStatus struct { OutputLength int StartTime time.Time StopTime time.Time + Process *os.Process LastAccessTime time.Time // used for expiring entries } diff --git a/driver/context_driver.go b/driver/context_driver.go index afc799c..9a3e670 100644 --- a/driver/context_driver.go +++ b/driver/context_driver.go @@ -106,6 +106,8 @@ func (fcd *FlowContextDriver) Run(fc *flow.FlowContext) { flow.OnInterrupt(func() { fcd.OnInterrupt(fc, sched) + }, func() { + fcd.OnExit(fc, sched) }) // schedule to run the steps diff --git a/driver/context_driver_on_interrupt.go b/driver/context_driver_on_interrupt.go index 15ceebc..ee49590 100644 --- a/driver/context_driver_on_interrupt.go +++ b/driver/context_driver_on_interrupt.go @@ -19,6 +19,32 @@ func (fcd *FlowContextDriver) OnInterrupt( fcd.printDistributedStatus(sched, status) } +func (fcd *FlowContextDriver) OnExit( + fc *flow.FlowContext, + sched *scheduler.Scheduler) { + var wg sync.WaitGroup + for _, tg := range fcd.taskGroups { + wg.Add(1) + go func(tg *plan.TaskGroup) { + defer wg.Done() + + requestId := tg.RequestId + request, ok := sched.RemoteExecutorStatuses[requestId] + if !ok { + fmt.Printf("No executors for %v\n", tg) + return + } + // println("checking", request.Allocation.Location.URL(), requestId) + if err := askExecutorToStopRequest(request.Allocation.Location.URL(), requestId); err != nil { + fmt.Printf("Error to stop request %d on %s: %v\n", request.Allocation.Location.URL(), requestId, err) + return + } + }(tg) + } + wg.Wait() + +} + func (fcd *FlowContextDriver) printDistributedStatus(sched *scheduler.Scheduler, stats []*RemoteExecutorStatus) { fmt.Print("\n") for _, stepGroup := range fcd.stepGroups { @@ -112,3 +138,8 @@ func askExecutorStatusForRequest(server string, requestId int32) (*RemoteExecuto }, }, nil } + +func askExecutorToStopRequest(server string, requestId int32) (err error) { + _, err = scheduler.RemoteDirectCommand(server, scheduler.NewStopRequest(requestId)) + return +} diff --git a/driver/scheduler/command_execution.go b/driver/scheduler/command_execution.go index a9e4a11..1009df7 100644 --- a/driver/scheduler/command_execution.go +++ b/driver/scheduler/command_execution.go @@ -53,6 +53,15 @@ func NewGetStatusRequest(requestId int32) *cmd.ControlMessage { } } +func NewStopRequest(requestId int32) *cmd.ControlMessage { + return &cmd.ControlMessage{ + Type: cmd.ControlMessage_StopRequest.Enum(), + StopRequest: &cmd.StopRequest{ + StartRequestHash: proto.Int32(requestId), + }, + } +} + func NewDeleteDatasetShardRequest(name string) *cmd.ControlMessage { return &cmd.ControlMessage{ Type: cmd.ControlMessage_DeleteDatasetShardRequest.Enum(), diff --git a/flow/context_run.go b/flow/context_run.go index 9945515..1dbad0d 100644 --- a/flow/context_run.go +++ b/flow/context_run.go @@ -71,7 +71,7 @@ func (fc *FlowContext) runFlowContextInStandAloneMode() { isDatasetStarted := make(map[int]bool) - OnInterrupt(fc.OnInterrupt) + OnInterrupt(fc.OnInterrupt, nil) // start all task edges for _, step := range fc.Steps { diff --git a/flow/signal_handling.go b/flow/signal_handling.go index c008a5d..79476c8 100644 --- a/flow/signal_handling.go +++ b/flow/signal_handling.go @@ -9,7 +9,7 @@ import ( "syscall" ) -func OnInterrupt(fn func()) { +func OnInterrupt(fn func(), onExitFunc func()) { // deal with control+c,etc signalChan := make(chan os.Signal, 1) // controlling terminal close, daemon not exit @@ -28,6 +28,9 @@ func OnInterrupt(fn func()) { for sig := range signalChan { fn() if sig != syscall.SIGINFO { + if onExitFunc != nil { + onExitFunc() + } os.Exit(0) } }