Skip to content

Commit

Permalink
[Dist Dialect] Add MoE-related api in PIR dist dialect (PaddlePaddle#…
Browse files Browse the repository at this point in the history
…66462)

* add two MoE api in distributed dialect

* polish the dist_op and add unit test

* remove simple_net_ep unit test

* remove redundant print

* bug fix, replace platform::errors with phi::errors
  • Loading branch information
pkuzyc authored Jul 30, 2024
1 parent dbfc48b commit 8718d78
Show file tree
Hide file tree
Showing 12 changed files with 833 additions and 2 deletions.
49 changes: 49 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,53 @@ pir::Value reshard(const pir::Value& x,
return reshard_op.result(0);
}

std::vector<pir::Value> local_tensors_from_dist(
const pir::Value& input,
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
const std::vector<int64_t>& local_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
const phi::distributed::ProcessMesh& global_mesh,
const std::vector<int64_t>& global_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status) {
pir::IrContext* ctx = pir::IrContext::Instance();
std::vector<TensorDistAttribute> local_dist_attrs;
for (const phi::distributed::ProcessMesh& mesh : local_mesh_list) {
local_dist_attrs.emplace_back(TensorDistAttribute::get(
ctx, mesh, local_dims_mapping, local_partial_status));
}
TensorDistAttribute global_dist_attr = TensorDistAttribute::get(
ctx, global_mesh, global_dims_mapping, global_partial_status);

auto op = ApiBuilder::Instance().GetBuilder()->Build<LocalTensorsFromDistOp>(
input, local_dist_attrs, global_dist_attr);
return op.results();
}

pir::Value dist_tensor_from_locals(
const std::vector<pir::Value>& inputs,
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
const std::vector<int64_t>& local_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
const phi::distributed::ProcessMesh& global_mesh,
const std::vector<int64_t>& global_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status,
const std::vector<int64_t>& global_shape) {
pir::IrContext* ctx = pir::IrContext::Instance();

std::vector<TensorDistAttribute> local_dist_attrs;
for (const phi::distributed::ProcessMesh& mesh : local_mesh_list) {
local_dist_attrs.emplace_back(TensorDistAttribute::get(
ctx, mesh, local_dims_mapping, local_partial_status));
}

TensorDistAttribute global_dist_attr = TensorDistAttribute::get(
ctx, global_mesh, global_dims_mapping, global_partial_status);

phi::DDim global_ddim = phi::make_ddim(global_shape);

auto op = ApiBuilder::Instance().GetBuilder()->Build<DistTensorFromLocalsOp>(
inputs, local_dist_attrs, global_dist_attr, global_ddim);
return op.result(0);
}

} // namespace paddle::dialect
19 changes: 19 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,24 @@ pir::Value reshard(
pir::Value reshard(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr);

std::vector<pir::Value> local_tensors_from_dist(
const pir::Value& input,
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
const std::vector<int64_t>& local_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
const phi::distributed::ProcessMesh& global_mesh,
const std::vector<int64_t>& global_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status);

pir::Value dist_tensor_from_locals(
const std::vector<pir::Value>& inputs,
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
const std::vector<int64_t>& local_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
const phi::distributed::ProcessMesh& global_mesh,
const std::vector<int64_t>& global_dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status,
const std::vector<int64_t>& global_shape);

} // namespace dialect
} // namespace paddle
5 changes: 4 additions & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ void DistDialect::initialize() {
TensorDistAttribute,
OperationDistAttribute>();
RegisterTypes<DistDenseTensorType>();
RegisterOps<ShardTensorOp, ReshardOp>();
RegisterOps<ShardTensorOp,
ReshardOp,
LocalTensorsFromDistOp,
DistTensorFromLocalsOp>();
}

void DistDialect::PrintType(pir::Type type, std::ostream &os) const {
Expand Down
Loading

0 comments on commit 8718d78

Please sign in to comment.