[PIR] Refactor Inplace strategy (PaddlePaddle#65491)
* Refactor Inplace

* update

* handle for tensorarray

* update

* fix assign_value_

* update

* update

* update

* fix custom meta bug

* fix optional value bug
chen2016013 authored Jul 15, 2024
1 parent 709d182 commit 6716d7f
Showing 12 changed files with 185 additions and 18 deletions.
@@ -410,6 +410,18 @@ CustomKernelInstruction::CustomKernelInstruction(
GetStreamPriority()));
VLOG(6) << "finish process device context";

auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_);
for (auto const& pair : op_inplace_map) {
pir::Value input_value =
op->operand_source(yaml_info_parser.InputName2Id().at(pair.first));
pir::Value output_value =
op->result(yaml_info_parser.OutputName2Id().at(pair.second));
if (IsInvalid(output_value) && IsInvalid(input_value)) {
this->AddInplace(value_exec_info_.GetVarByValue(input_value),
value_exec_info_.GetVarByValue(output_value));
}
}

InitInputsOutputsIds(op, value_exec_info_);
VLOG(6) << "finish process inputs outputs index";

@@ -453,6 +465,7 @@ void CustomKernelInstruction::UpdateOutputMeta(
auto out_meta = phi::DenseTensorUtils::GetMutableMeta(out_in_scope);
out_meta->dims = phi::make_ddim(output_shapes[i]);
out_meta->dtype = output_dtypes[i];
out_meta->strides = out_meta->calc_strides(out_meta->dims);
}
}
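
The strides update added above keeps the output meta self-consistent once dims change. calc_strides here computes contiguous (row-major) strides from the new dims; a stand-alone sketch of that rule, assuming the contiguous case (phi's own implementation is authoritative):

#include <cstdint>
#include <vector>

// Row-major strides: the last axis has stride 1, and each earlier axis
// strides over the product of all later dims, e.g. {2, 3, 4} -> {12, 4, 1}.
std::vector<int64_t> ContiguousStrides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}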

@@ -504,7 +517,9 @@ void CustomKernelInstruction::Run() {
vec_input_name2id_map_,
custom_attrs_);
UpdateOutputMeta(output_shapes, output_dtypes);

for (auto& pair : this->InplaceInfo()) {
ShareVarBuffer(pair.first, pair.second);
}
VLOG(6) << "Run custom op " << custom_op_name_ << " kernel.";
kernel_func_(&custom_kernel_ctx_);
}
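
Note the ordering in Run(): UpdateOutputMeta finalizes each output's dims/dtype/strides first, then every recorded inplace pair aliases the output's buffer to its input via ShareVarBuffer, and only then does the kernel execute, writing through the shared allocation.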
@@ -273,12 +273,12 @@ const std::vector<Variable*>& InstructionBase::EagerGCVars() const {

void InstructionBase::ClearEagerGCVars() { eager_gc_vars_.clear(); }

-const std::vector<std::pair<Variable*, Variable*>>&
+const std::vector<std::pair<const Variable*, Variable*>>&
InstructionBase::InplaceInfo() const {
return vec_inplace_in_to_out_;
}

-void InstructionBase::AddInplace(Variable* in, Variable* out) {
+void InstructionBase::AddInplace(const Variable* in, Variable* out) {
vec_inplace_in_to_out_.emplace_back(in, out);
}

@@ -334,6 +334,17 @@ void InstructionBase::InitInputsOutputsIds(
outputs.emplace(value, outputs_id);
}
}

const auto& value_2_var_name_map = value_exec_info.GetValue2VarName();
for (const auto& inplace_var_pair : this->InplaceInfo()) {
for (const auto& item : value_2_var_name_map) {
if (item.second == value_exec_info.GetVarName(inplace_var_pair.first)) {
std::vector<int> outputs_id = GetValueIds(item.first, value_exec_info);
outputs.emplace(item.first, outputs_id);
break;
}
}
}
SetOutputs(outputs);
VLOG(8) << "finish process outputs_index";
}
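
The loop added above finds, by variable name, the pir::Value that backs each inplace input and registers it among the instruction's outputs as well, presumably so the interpreter's dependency analysis treats the in-place write as a write to that variable when ordering instructions.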
@@ -127,8 +127,8 @@ class InstructionBase {
void AddEagerGCVar(Variable* var);
void ClearEagerGCVars();

-const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
-void AddInplace(Variable* in, Variable* out);
+const std::vector<std::pair<const Variable*, Variable*>>& InplaceInfo() const;
+void AddInplace(const Variable* in, Variable* out);
void ClearInplace();

std::map<int, int>& GetMutableInplaceBackMap() { return inplace_back_map_; }
@@ -207,7 +207,7 @@ class InstructionBase {

std::vector<Variable*> eager_gc_vars_;

-std::vector<std::pair<Variable*, Variable*>>
+std::vector<std::pair<const Variable*, Variable*>>
    vec_inplace_in_to_out_;  // TODO: still needed now that buffers are shared instead of data?

std::map<int, int> inplace_back_map_;
102 changes: 102 additions & 0 deletions paddle/fluid/framework/new_executor/instruction/instruction_util.cc
@@ -406,4 +406,106 @@ bool GetCondData(const phi::DenseTensor& cond) {
return cpu_cond->data<bool>()[0];
}

// NOTE(chenxi67): Here we only perform inplace processing for variables whose
// type is NOT TensorArray; TensorArray variables were already handled in the
// previous step (HandleForInplaceVarOp).
void HandleForInplaceOp(pir::Operation* op,
const ValueExecutionInfo* value_exe_info,
InstructionBase* instr) {
if (op->num_results() < 1) return;
pir::IrContext* ctx = pir::IrContext::Instance();
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<pir::StrAttribute>().AsString();
}

pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
->get_op_info_(op_name),
paddle::dialect::IsLegacyOp(op_name));

for (size_t i = 0; i < op->num_results(); ++i) {
pir::Value value = op->result(i);
if (!IsInvalid(value)) {
VLOG(8) << "Number " << i << " result of " << op_name
<< " is not invalid, so skip build a variable.";
continue;
}
if (IsNeedVarInplace(op, value, op_name)) {
continue;
}
std::string value_name = yaml_parser.OutputNames()[i];
if (yaml_parser.HasInplace(value_name)) {
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
pir::Value inplace_value =
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string input_var_name = value_exe_info->GetVarName(inplace_value);
std::string output_var_name = value_exe_info->GetVarName(value);
PADDLE_ENFORCE_NE(input_var_name,
"",
phi::errors::InvalidArgument(
"The input var name of inplace op is empty."));
PADDLE_ENFORCE_NE(output_var_name,
"",
phi::errors::InvalidArgument(
"The output var name of inplace op is empty."));
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << input_var_name << ")";
instr->AddInplace(value_exe_info->GetVarByValue(inplace_value),
value_exe_info->GetVarByValue(value));
} else if (yaml_parser.HasView(value_name)) {
const std::string& view_name = yaml_parser.ViewName(value_name);
pir::Value view_value =
op->operand_source(yaml_parser.InputName2Id().at(view_name));
std::string input_var_name = value_exe_info->GetVarName(view_value);
std::string output_var_name = value_exe_info->GetVarName(value);

PADDLE_ENFORCE_NE(input_var_name,
"",
platform::errors::InvalidArgument(
"The input var name of view op is empty."));
PADDLE_ENFORCE_NE(output_var_name,
"",
platform::errors::InvalidArgument(
"The output var name of view op is empty."));
VLOG(4) << "view: " << value_name << " -> " << view_name
<< " (var: " << input_var_name << ")";
instr->AddInplace(value_exe_info->GetVarByValue(view_value),
value_exe_info->GetVarByValue(value));
}
}
}

void ShareVarBuffer(const Variable* src_var, Variable* dst_var) {
if (src_var->IsType<phi::DenseTensor>()) {
auto& src_tensor = src_var->Get<phi::DenseTensor>();
auto* tmp_dst_tensor = dst_var->GetMutable<phi::DenseTensor>();
tmp_dst_tensor->ShareBufferWith(src_tensor);
return;
} else if (src_var->IsType<phi::SelectedRows>()) {
auto* tmp_dst_slr = dst_var->GetMutable<phi::SelectedRows>();
auto* dst_t = tmp_dst_slr->mutable_value();
auto& src_slr = src_var->Get<phi::SelectedRows>();
auto& src_t = src_slr.value();
dst_t->ShareBufferWith(src_t);
return;
} else if (src_var->IsType<VariableRefArray>()) {
const auto& src_var_array = src_var->Get<VariableRefArray>();
auto* dst_var_array = dst_var->GetMutable<VariableRefArray>();
for (size_t i = 0; i < src_var_array.size(); ++i) {
Variable* copy_var = const_cast<Variable*>(dst_var_array->at(i));
ShareVarBuffer(src_var_array.at(i), copy_var);
}
return;
} else {
PADDLE_THROW(phi::errors::PreconditionNotMet(
    "ShareVarBuffer only supports DenseTensor, SelectedRows, or "
    "VariableRefArray variables."));
}
}
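
A minimal sketch of the aliasing contract that ShareVarBuffer builds on; the helper below is illustrative, not part of this diff:

#include "paddle/phi/core/dense_tensor.h"

// After ShareBufferWith, both tensors reference one allocation while each
// keeps its own meta (dims/dtype/strides); that is why callers fix the
// output meta before sharing buffers.
void AliasBuffers(const phi::DenseTensor& src, phi::DenseTensor* dst) {
  dst->ShareBufferWith(src);
  // From here on, writes through *dst are visible through src and vice
  // versa; the underlying holder is reference-counted.
}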

} // namespace paddle::framework
@@ -64,5 +64,10 @@ void InsertInplacedExternalInputsToOuts(

bool GetCondData(const phi::DenseTensor& cond);

void HandleForInplaceOp(pir::Operation* op,
const ValueExecutionInfo* value_exe_info,
InstructionBase* instr);

void ShareVarBuffer(const Variable* src_var, Variable* dst_var);
} // namespace framework
} // namespace paddle
@@ -163,6 +163,11 @@ LegacyKernelInstruction::LegacyKernelInstruction(

VLOG(6) << "finish process kernel context";

if (op->attributes().count("is_inplace") != 0 &&
op->attributes().at("is_inplace").dyn_cast<pir::BoolAttribute>().data()) {
HandleForInplaceOp(op, value_exec_info_, this);
}

InitInputsOutputsIds(op, *value_exec_info);
VLOG(6) << "finish process inputs outputs index";

@@ -185,6 +190,9 @@ void LegacyKernelInstruction::Run() {
if (infer_meta_interface_) {
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
}
for (auto& pair : this->InplaceInfo()) {
ShareVarBuffer(pair.first, pair.second);
}
VLOG(6) << "Run op " << legacy_op_name_ << " kernel.";
(*(phi_kernel_))((kernel_context_));
}
@@ -160,7 +160,10 @@ PhiKernelInstruction::PhiKernelInstruction(

kernel_context_.SetDeviceContext(dev_ctx);
VLOG(6) << "finish process kernel context";

if (op->attributes().count("is_inplace") != 0 &&
op->attributes().at("is_inplace").dyn_cast<pir::BoolAttribute>().data()) {
HandleForInplaceOp(op, value_exec_info_, this);
}
InitInputsOutputsIds(op, *value_exec_info);
VLOG(6) << "finish process inputs outputs index";

@@ -181,6 +184,9 @@ void PhiKernelInstruction::Run() {
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
}
VLOG(6) << "End run op " << phi_op_name_ << " infer meta.";
for (auto& pair : this->InplaceInfo()) {
ShareVarBuffer(pair.first, pair.second);
}
VLOG(6) << "Begin run op " << phi_op_name_ << " kernel.";
(*(phi_kernel_))(&(kernel_context_));
VLOG(6) << "End run op " << phi_op_name_ << " kernel.";
6 changes: 3 additions & 3 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -306,12 +306,12 @@ const platform::DeviceContext& Instruction::DeviceContext() const {
return dev_ctx_;
}

-const std::vector<std::pair<Variable*, Variable*>>& Instruction::InplaceInfo()
-    const {
+const std::vector<std::pair<const Variable*, Variable*>>&
+Instruction::InplaceInfo() const {
return vec_inplace_in_to_out_;
}

-void Instruction::AddInplace(Variable* in, Variable* out) {
+void Instruction::AddInplace(const Variable* in, Variable* out) {
vec_inplace_in_to_out_.emplace_back(in, out);
}

6 changes: 3 additions & 3 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -295,9 +295,9 @@ class Instruction {

const platform::DeviceContext& DeviceContext() const;

-const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
+const std::vector<std::pair<const Variable*, Variable*>>& InplaceInfo() const;

-void AddInplace(Variable* in, Variable* out);
+void AddInplace(const Variable* in, Variable* out);

void ClearInplace();

@@ -340,7 +340,7 @@ class Instruction {

std::vector<size_t> gc_check_vars_;

-std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
+std::vector<std::pair<const Variable*, Variable*>> vec_inplace_in_to_out_;

bool pre_define_context_{false};
};
@@ -682,9 +682,21 @@ void HandleForSpecialOp(pir::Operation* op,
}
}

-void HandleForInplaceOp(pir::Operation* op,
-                        const std::string& var_name_prefix,
-                        ValueExecutionInfo* value_exe_info) {
+bool IsNeedVarInplace(pir::Operation* op,
+                      pir::Value value,
+                      std::string op_name) {
+  return (value.type().isa<paddle::dialect::DenseTensorArrayType>() ||
+          op_name == "pd_op.assign_value_");
+}
+
+// NOTE(chenxi67): Here we only perform inplace processing for variables that
+// must be inplaced at the Variable level (mostly TensorArray or re-allocated
+// DenseTensor). For other variable types we share only the DenseTensor's
+// holder, not the Variable* itself, because the vector<DenseTensor> inside a
+// TensorArray (or a re-allocated DenseTensor) cannot be shared wholesale.
+void HandleForInplaceVarOp(pir::Operation* op,
+                           const std::string& var_name_prefix,
+                           ValueExecutionInfo* value_exe_info) {
if (op->num_results() < 1) return;
pir::IrContext* ctx = pir::IrContext::Instance();
std::string op_name = op->name();
@@ -706,6 +718,10 @@ void HandleForInplaceOp(pir::Operation* op,
<< " is not invalid, so skip build a variable.";
continue;
}
if (!IsNeedVarInplace(op, value, op_name)) {
BuildValue(value, var_name_prefix, value_exe_info);
continue;
}
std::string value_name = yaml_parser.OutputNames()[i];
if (yaml_parser.HasInplace(value_name)) {
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
@@ -785,7 +801,7 @@ void BuildScope(const pir::Block& block,
.at("is_inplace")
.dyn_cast<pir::BoolAttribute>()
.data()) {
-HandleForInplaceOp(&op, var_name_prefix, value_exe_info);
+HandleForInplaceVarOp(&op, var_name_prefix, value_exe_info);
continue;
} else {
for (size_t i = 0; i < op.num_results(); ++i) {
@@ -153,6 +153,10 @@ std::shared_ptr<OperatorBase> BuildOperatorBase(
const ValueExecutionInfo& value_exec_info,
const paddle::dialect::OpYamlInfoParser& op_yaml_info);

bool IsNeedVarInplace(pir::Operation* op,
pir::Value value,
std::string op_name);

template <typename Context,
typename InType,
typename OutType,
2 changes: 1 addition & 1 deletion test/cpp/new_executor/standalone_executor_pir_test.cc
@@ -211,7 +211,7 @@ TEST(StandaloneExecutor, run_inplace_sqrt) {
bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);

EXPECT_EQ(scope.kids().size(), 1u);
-EXPECT_EQ(scope.kids().front()->Size(), 1u);
+EXPECT_EQ(scope.kids().front()->Size(), 2u);
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
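
The expectation changes from 1u to 2u because, under the refactored strategy, the inplace output no longer reuses the input's Variable: it gets its own scope variable whose DenseTensor merely shares the input's buffer (per the commit's NOTE about sharing the holder rather than the Variable*), so the child scope holds two variables.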
