Support Pipeline in Training Runner (microsoft#3770)
wschin authored May 7, 2020
1 parent c222ed6 commit 0aeb383
Showing 15 changed files with 1,123 additions and 169 deletions.
71 changes: 62 additions & 9 deletions orttraining/orttraining/core/graph/pipeline_transformer.cc
@@ -71,7 +71,9 @@ NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
Status AddRecordBackward(Graph& graph,
Node* send_bw,
std::vector<std::string>& new_input_names,
std::vector<std::string>& new_output_names) {
std::vector<std::string>& new_output_names,
std::string &event_id_tensor_name,
std::string &output_tensor_name) {
std::vector<NodeArg*> input_args;
AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
std::vector<NodeArg*> output_args{};
@@ -99,10 +101,20 @@ Status AddRecordBackward(Graph& graph,
output_args,
nullptr,
kMSDomain);

// First input argument is the recorded event ID tensor.
event_id_tensor_name = input_args.front()->Name();
// Use the first output as an output signal. It will be fetched outside to make sure
// the event operator is computed.
output_tensor_name = output_args.front()->Name();
return Status::OK();
}

Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector<std::string>& new_input_names) {
Status AddWaitForward(Graph& graph,
Node* /* recv_fw */,
std::vector<std::string>& new_input_names,
std::string& forward_waited_event_name,
std::string& output_tensor_name) {
// Append old_input to input_args and return its pass-through value. Note that
// input_args and output_args are Wait's inputs and outputs, respectively.
auto update_wait_input_output = [&](NodeArg* old_input,
@@ -148,11 +160,19 @@ Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector<std::string
output_args,
nullptr,
kMSDomain);

forward_waited_event_name = input_args.front()->Name();
output_tensor_name = output_args.front()->Name();
return Status::OK();
}

Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* recv_bw, std::vector<std::string>& new_input_names) {
Status AddOrSkipRecordForwardWaitBackward(Graph& graph,
Node* send_fw,
Node* recv_bw,
std::vector<std::string>& new_input_names,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& forward_output_name,
std::string& backward_output_name) {
if (!send_fw != !recv_bw) {
ORT_THROW("Graph requires either having both send forward node "
"and recv backward node, or none of them. Currently the graph "
Expand Down Expand Up @@ -189,6 +209,9 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* rec
{}, /* attribute */
kMSDomain);
record_node = &new_node;

forward_recorded_event_name = record_node->InputDefs()[0]->Name();
forward_output_name = record_node->OutputDefs()[0]->Name();
}
// Insert WaitEvent
{
@@ -213,13 +236,24 @@ Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* rec
{}, /* attribute */
kMSDomain);
wait_node = &new_node;
ORT_UNUSED_PARAMETER(wait_node);

backward_waited_event_name = wait_node->InputDefs()[0]->Name();
backward_output_name = wait_node->OutputDefs()[0]->Name();
}

return Status::OK();
}

Status TransformGraphForPipeline(Graph& graph) {
Status TransformGraphForPipeline(
Graph& graph,
std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& backward_recorded_event_name,
std::string& forward_waited_output_name,
std::string& forward_recorded_output_name,
std::string& backward_waited_output_name,
std::string& backward_recorded_output_name) {
// insert WaitEvent and RecordEvent to the partition
Node* send_fw{nullptr};
Node* send_bw{nullptr};
@@ -244,9 +278,28 @@ Status TransformGraphForPipeline(Graph& graph) {
std::vector<std::string> new_input_names;
std::vector<std::string> new_output_names;

ORT_RETURN_IF_ERROR(AddRecordBackward(graph, send_bw, new_input_names, new_output_names));
ORT_RETURN_IF_ERROR(AddWaitForward(graph, recv_fw, new_input_names));
ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(graph, send_fw, recv_bw, new_input_names));
ORT_RETURN_IF_ERROR(AddRecordBackward(
graph,
send_bw,
new_input_names,
new_output_names,
backward_recorded_event_name,
backward_recorded_output_name));
ORT_RETURN_IF_ERROR(AddWaitForward(
graph,
recv_fw,
new_input_names,
forward_waited_event_name,
forward_waited_output_name));
ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(
graph,
send_fw,
recv_bw,
new_input_names,
forward_recorded_event_name,
backward_waited_event_name,
forward_recorded_output_name,
backward_waited_output_name));

auto fill_node_args = [&](const Graph& graph,
const std::vector<const NodeArg*>& existed_node_args,
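As a usage note for the interface above, here is a minimal caller sketch (hypothetical code, not part of this commit; the model variable is assumed) showing how the new output parameters might be consumed:

// Hypothetical caller: run the pipeline transform and collect the
// event/output tensor names it reports back.
std::string fw_waited_event, fw_recorded_event;
std::string bw_waited_event, bw_recorded_event;
std::string fw_waited_out, fw_recorded_out;
std::string bw_waited_out, bw_recorded_out;
ORT_RETURN_IF_ERROR(TransformGraphForPipeline(
    model->MainGraph(),
    fw_waited_event, fw_recorded_event,
    bw_waited_event, bw_recorded_event,
    fw_waited_out, fw_recorded_out,
    bw_waited_out, bw_recorded_out));
// The *_event names identify new graph inputs that must be fed with event
// IDs at run time; the *_out names identify outputs that are fetched so the
// executor cannot prune the event operators.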
11 changes: 10 additions & 1 deletion orttraining/orttraining/core/graph/pipeline_transformer.h
@@ -9,7 +9,16 @@ namespace onnxruntime {
namespace training {

void GetPipelineSendOutput(const Graph& graph, std::string& loss_name);
common::Status TransformGraphForPipeline(Graph& graph);
common::Status TransformGraphForPipeline(
Graph& graph,
std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& backward_recorded_event_name,
std::string& forward_waited_output_name,
std::string& forward_recorded_output_name,
std::string& backward_waited_output_name,
std::string& backward_recorded_output_name);

} // namespace training
} // namespace onnxruntime
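Conceptually, a WaitEvent op blocks until a given event ID has been signaled, while a RecordEvent op signals it; that contract is what lets adjacent pipeline stages hand work to each other. A rough standalone illustration of the idea (plain C++ under assumed semantics, not ORT's actual kernel code):

#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <unordered_set>

// Toy event pool: Record() marks an event ID as signaled; Wait() blocks
// until that ID has been recorded.
class ToyEventPool {
 public:
  void Record(int64_t event_id) {
    std::lock_guard<std::mutex> guard(mu_);
    signaled_.insert(event_id);
    cv_.notify_all();
  }
  void Wait(int64_t event_id) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return signaled_.count(event_id) > 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::unordered_set<int64_t> signaled_;
};

With this picture, the inserted order WaitEvent --> Forward --> RecordEvent --> WaitEvent --> Backward --> RecordEvent is simply each stage waiting for its inputs to be ready and announcing when its outputs are.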
71 changes: 64 additions & 7 deletions orttraining/orttraining/core/session/training_session.cc
@@ -119,13 +119,16 @@ Status TrainingSession::ConfigureForTraining(

TrainingConfigurationResult config_result{};

ORT_ENFORCE(config.distributed_config.pipeline_parallel_size > 0,
"This parameter should be 1 if there is no pipelie parallelism. Otherwise, it's the number of pipeline stages.");

DistributedRunContext::CreateInstance({config.distributed_config.world_rank,
config.distributed_config.world_size,
config.distributed_config.local_rank,
config.distributed_config.local_size,
config.distributed_config.data_parallel_size,
config.distributed_config.horizontal_parallel_size,
config.distributed_config.pipeline_stage_size});
config.distributed_config.pipeline_parallel_size});

ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph());

@@ -134,11 +137,12 @@ Status TrainingSession::ConfigureForTraining(
std::string loss_name{};
optional<std::string> loss_scale_input_name =
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
if (config.use_pipeline) {
if (config.pipeline_config.has_value()) {
// If pipeline is used, first check whether the model contains a Send op. If it does, set the
// Send node's output as the start tensor for building the gradient graph.
GetPipelineSendOutput(model_->MainGraph(), loss_name);
}

if (loss_name.empty()) {
const optional<LossFunctionInfo> loss_function_info =
config.loss_function_config.has_value()
@@ -193,8 +197,30 @@ Status TrainingSession::ConfigureForTraining(
weight_names_to_train, mixed_precision_config.use_fp16_initializers, fp32_weight_name_to_fp16_node_arg));
}

if (config.use_pipeline) {
ORT_RETURN_IF_ERROR(InsertPipelineOps());
if (config.pipeline_config.has_value()) {
TrainingConfigurationResult::PipelineConfigurationResult pipeline_result{};
ORT_RETURN_IF_ERROR(InsertPipelineOps(pipeline_result.forward_waited_event_name,
pipeline_result.forward_recorded_event_name,
pipeline_result.backward_waited_event_name,
pipeline_result.backward_recorded_event_name,
pipeline_result.forward_waited_output_name,
pipeline_result.forward_recorded_output_name,
pipeline_result.backward_waited_output_name,
pipeline_result.backward_recorded_output_name));
// The following loop filters out fetch names that are not present in this pipeline stage.
for (size_t i = 0; i < config.pipeline_config.value().fetch_names.size(); ++i) {
auto name = config.pipeline_config.value().fetch_names[i];
const auto* node_arg = model_->MainGraph().GetNodeArg(name);
if (!node_arg) {
// This pipeline stage doesn't contain this name,
// so don't fetch it.
continue;
}
pipeline_result.fetch_names.push_back(name);
}
pipeline_result.pipeline_stage_id = config.distributed_config.world_rank /
(config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size);
config_result.pipeline_config_result = pipeline_result;
}

// All non-float tensors are not trainable. Remove those weights.
@@ -266,7 +292,7 @@ Status TrainingSession::ConfigureForTraining(
tensorboard_config.histogram_node_names, tensorboard_config.norm_node_names,
tensorboard_config.dump_convergence_metrics));
}

// add GIST encoding
if (config.gist_config.has_value()) {
ORT_RETURN_IF_ERROR(AddGistEncoding());
@@ -277,6 +303,20 @@ Status TrainingSession::ConfigureForTraining(
config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
}

// After pipeline partition, we need to return the inputs and outputs visible in this partition.
if (config.pipeline_config.has_value()) {
const auto& allowed_inputs = model_->MainGraph().GetInputsIncludingInitializers();
const auto& allowed_outputs = model_->MainGraph().GetOutputs();
for (size_t i = 0; i < allowed_inputs.size(); ++i) {
const auto name = allowed_inputs[i]->Name();
config_result.pipeline_config_result.value().feed_names.push_back(name);
}
for (size_t i = 0; i < allowed_outputs.size(); ++i) {
const auto name = allowed_outputs[i]->Name();
config_result.pipeline_config_result.value().fetch_names.push_back(name);
}
}

config_result_out = std::move(config_result);
is_configured_ = true;

@@ -471,8 +511,25 @@ Status TrainingSession::AddTensorboard(const std::string& summary_name,
return DoPostLoadProcessing(*model_);
}

Status TrainingSession::InsertPipelineOps() {
ORT_RETURN_IF_ERROR(TransformGraphForPipeline(model_->MainGraph()));
Status TrainingSession::InsertPipelineOps(
std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& backward_recorded_event_name,
std::string& forward_waited_output_name,
std::string& forward_recorded_output_name,
std::string& backward_waited_output_name,
std::string& backward_recorded_output_name) {
ORT_RETURN_IF_ERROR(TransformGraphForPipeline(
model_->MainGraph(),
forward_waited_event_name,
forward_recorded_event_name,
backward_waited_event_name,
backward_recorded_event_name,
forward_waited_output_name,
forward_recorded_output_name,
backward_waited_output_name,
backward_recorded_output_name));
return DoPostLoadProcessing(*model_);
}

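The pipeline_stage_id arithmetic above assumes ranks are laid out so that all data-parallel and horizontal-parallel ranks of one stage occupy a contiguous block. A worked example under that assumption:

// Assumed layout: world_size = 8, pipeline_parallel_size = 2.
int world_rank = 5;                 // example rank
int data_parallel_size = 2;
int horizontal_parallel_size = 2;
// stage id = world_rank / (2 * 2): ranks 0..3 -> stage 0, ranks 4..7 -> stage 1
int pipeline_stage_id = world_rank / (data_parallel_size * horizontal_parallel_size);  // == 1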
65 changes: 58 additions & 7 deletions orttraining/orttraining/core/session/training_session.h
@@ -57,12 +57,12 @@ class TrainingSession : public InferenceSession {
int world_size{1};
// The number of local ranks on a node.
int local_size{1};
// The number of ranks for data parallel group
// The number of ranks for data parallel group.
int data_parallel_size{1};
// The number of ranks for horizontal model parallel group
// The number of ranks for horizontal model parallel group.
int horizontal_parallel_size{1};
// The number of stages for pipeline model parallel group
int pipeline_stage_size{1};
// The number of pipeline stages.
int pipeline_parallel_size{1};
};
// The distributed training configuration.
DistributedConfiguration distributed_config{};
@@ -135,8 +135,19 @@ class TrainingSession : public InferenceSession {
// If not provided, no optimizer is added.
optional<OptimizerConfiguration> optimizer_config{};

// Whether to use pipeline in training.
bool use_pipeline{false};
struct PipelineConfiguration {
// If model partition happens outside ORT, this flag should be false.
// Otherwise, use true to trigger ORT's pipeline partition.
bool do_partition;
// Tensors to fetch as specified by the user.
// Each pipeline stage should pick up some strings from this field.
std::vector<std::string> fetch_names;
// [TODO] Add cut information.
};

// If pipeline is enabled, this field's has_value() returns true.
// Otherwise, it returns false.
optional<PipelineConfiguration> pipeline_config{};
};

/**
@@ -158,6 +169,33 @@
// The optimizer configuration output.
// This is only set if an optimizer is added.
optional<OptimizerConfigurationResult> opt_config_result;

// The names of pipeline events in model's input list.
// If an event is not used, its name should be empty.
struct PipelineConfigurationResult {
// Index of obtained pipeline stage. The first stage is indexed by 0.
int pipeline_stage_id;
// The names of pipeline events in model's input list.
std::string forward_waited_event_name;
std::string forward_recorded_event_name;
std::string backward_waited_event_name;
std::string backward_recorded_event_name;

std::string forward_waited_output_name;
std::string forward_recorded_output_name;
std::string backward_waited_output_name;
std::string backward_recorded_output_name;

// Tensors to feed at this pipeline stage.
std::vector<std::string> feed_names;
// Tensors to fetch at this pipeline stage.
// It's a subset of PipelineConfiguration.fetch_names.
std::vector<std::string> fetch_names;
};

// The pipeline configuration output.
// This is only set if a pipeline is enabled.
optional<PipelineConfigurationResult> pipeline_config_result;
};

/**
@@ -283,7 +321,20 @@ class TrainingSession : public InferenceSession {
const std::vector<std::string>& norm_nodes,
const bool dump_convergence_metrics);

common::Status InsertPipelineOps();
// Insert operators for running pipeline and return event tensor names.
// For an intermediate pipeline stage, two WaitEvent and two RecordEvent nodes
// would be inserted. The dependent event tensor names are returned.
// The resulting computation order is
// WaitEvent --> Forward --> RecordEvent --> WaitEvent --> Backward --> RecordEvent
common::Status InsertPipelineOps(std::string& forward_waited_event_name,
std::string& forward_recorded_event_name,
std::string& backward_waited_event_name,
std::string& backward_recorded_event_name,
std::string& forward_waited_output_name,
std::string& forward_recorded_output_name,
std::string& backward_waited_output_name,
std::string& backward_recorded_output_name);

common::Status ApplyTransformationsToMainGraph();

/** configure initial transformers for training */
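Putting the new configuration surface together, a hypothetical setup might look like the sketch below. The struct and field names come from this commit; the session variable, the example rank values, and the tensor name "loss" are illustrative assumptions:

TrainingSession::TrainingConfiguration config{};
config.distributed_config.world_rank = 5;   // example rank
config.distributed_config.world_size = 8;
config.distributed_config.data_parallel_size = 2;
config.distributed_config.horizontal_parallel_size = 2;
config.distributed_config.pipeline_parallel_size = 2;

TrainingSession::TrainingConfiguration::PipelineConfiguration pipeline{};
pipeline.do_partition = true;       // let ORT partition the graph
pipeline.fetch_names = {"loss"};    // hypothetical tensor name
config.pipeline_config = pipeline;

TrainingSession::TrainingConfigurationResult result{};
ORT_RETURN_IF_ERROR(session.ConfigureForTraining(config, result));
// result.pipeline_config_result now carries this stage's id, the event
// tensor names to feed, and the feed/fetch names valid on this partition.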