[Mesh] [opt] Optimize reordered index mapping case
* wip

* add config

* wip

* work

* auto format
g1n0st authored and yuanming-hu committed Nov 21, 2021
1 parent b0ccf69 commit bc61265
Showing 7 changed files with 121 additions and 55 deletions.
5 changes: 5 additions & 0 deletions python/taichi/lang/mesh.py
@@ -230,6 +230,11 @@ def set_l2g(self, element_type: MeshElementType,
_ti_core.set_l2g(self.mesh_ptr, element_type,
total_offset.vars[0].ptr.snode())

def set_l2r(self, element_type: MeshElementType,
total_offset: ScalarField):
_ti_core.set_l2r(self.mesh_ptr, element_type,
total_offset.vars[0].ptr.snode())

def set_num_patches(self, num_patches: int):
_ti_core.set_num_patches(self.mesh_ptr, num_patches)

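The new `set_l2r` hook mirrors the `set_l2g` one above it: it registers the field holding the local-to-reordered index mapping for one element type. As a plain-Python sketch (not the Taichi API; all numbers below are made up), the two tables relate like this for a single patch -- reordering packs a patch's owned elements contiguously, so the owned prefix of the l2r table is just consecutive indices:

```python
# One patch touches 6 vertices: 4 owned, then 2 ghosts.  All numbers here are
# invented for illustration; only the owned-first local ordering matters.
num_owned = 4

# l2g: local index -> original (global) vertex index.
l2g = [40, 41, 7, 93, 12, 88]

# l2r: local index -> vertex index after mesh reordering.  The owned prefix is
# consecutive starting at this patch's owned offset; ghosts still need a real
# lookup table.
owned_offset = 120
l2r = [owned_offset + i for i in range(num_owned)] + [300, 57]

assert l2r[:num_owned] == [120, 121, 122, 123]
```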
25 changes: 17 additions & 8 deletions taichi/analysis/gather_mesh_thread_local.cpp
@@ -15,20 +15,24 @@ class GatherMeshThreadLocal : public BasicStmtVisitor {

GatherMeshThreadLocal(OffloadedStmt *offload_,
MeshElementTypeSet *owned_ptr_,
MeshElementTypeSet *total_ptr_) {
MeshElementTypeSet *total_ptr_,
bool optimize_mesh_reordered_mapping_) {
allow_undefined_visitor = true;
invoke_default_visitor = true;

this->offload = offload_;
this->owned_ptr = owned_ptr_;
this->total_ptr = total_ptr_;
this->optimize_mesh_reordered_mapping = optimize_mesh_reordered_mapping_;
}

static void run(OffloadedStmt *offload,
MeshElementTypeSet *owned_ptr,
MeshElementTypeSet *total_ptr) {
MeshElementTypeSet *total_ptr,
const CompileConfig &config) {
TI_ASSERT(offload->task_type == OffloadedStmt::TaskType::mesh_for);
GatherMeshThreadLocal analyser(offload, owned_ptr, total_ptr);
GatherMeshThreadLocal analyser(offload, owned_ptr, total_ptr,
config.optimize_mesh_reordered_mapping);
offload->accept(&analyser);
}

@@ -48,22 +52,27 @@ class GatherMeshThreadLocal : public BasicStmtVisitor {

void visit(MeshIndexConversionStmt *stmt) override {
this->total_ptr->insert(stmt->from_type());
if (optimize_mesh_reordered_mapping &&
stmt->conv_type == mesh::ConvType::l2r) {
this->owned_ptr->insert(stmt->from_type());
}
}

OffloadedStmt *offload;
MeshElementTypeSet *owned_ptr;
MeshElementTypeSet *total_ptr;
OffloadedStmt *offload{nullptr};
MeshElementTypeSet *owned_ptr{nullptr};
MeshElementTypeSet *total_ptr{nullptr};
bool optimize_mesh_reordered_mapping{false};
};

namespace irpass::analysis {

std::pair</* owned= */ MeshElementTypeSet,
/* total= */ MeshElementTypeSet>
gather_mesh_thread_local(OffloadedStmt *offload) {
gather_mesh_thread_local(OffloadedStmt *offload, const CompileConfig &config) {
MeshElementTypeSet local_owned{};
MeshElementTypeSet local_total{};

GatherMeshThreadLocal::run(offload, &local_owned, &local_total);
GatherMeshThreadLocal::run(offload, &local_owned, &local_total, config);
return std::make_pair(local_owned, local_total);
}

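With the new flag, the analysis keeps its old behavior for the `total` set (every index conversion records its source element type) and additionally records the type as `owned` when the conversion is l2r, because the optimized prologue then needs the owned offset and count as well. A small plain-Python model of that rule (the statement list and enums are invented, and only the index-conversion rule shown in this hunk is modeled):

```python
from enum import Enum, auto

class ConvType(Enum):
    l2g = auto()
    l2r = auto()

class ElemType(Enum):
    Vertex = auto()
    Cell = auto()

def gather(conversions, optimize_mesh_reordered_mapping=True):
    """Model of GatherMeshThreadLocal over a list of (from_type, conv_type)."""
    owned, total = set(), set()
    for from_type, conv in conversions:
        total.add(from_type)  # every conversion needs the total offset/num
        if optimize_mesh_reordered_mapping and conv is ConvType.l2r:
            owned.add(from_type)  # l2r additionally needs the owned offset/num
    return owned, total

owned, total = gather([(ElemType.Vertex, ConvType.l2r),
                       (ElemType.Cell, ConvType.l2g)])
print(owned)  # {ElemType.Vertex}
print(total)  # {ElemType.Vertex, ElemType.Cell}
```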
2 changes: 1 addition & 1 deletion taichi/ir/analysis.h
@@ -239,7 +239,7 @@ void verify(IRNode *root);
void gather_meshfor_relation_types(IRNode *node);
std::pair</* owned= */ std::unordered_set<mesh::MeshElementType>,
/* total= */ std::unordered_set<mesh::MeshElementType>>
gather_mesh_thread_local(OffloadedStmt *offload);
gather_mesh_thread_local(OffloadedStmt *offload, const CompileConfig &config);

} // namespace analysis
} // namespace irpass
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
@@ -94,6 +94,7 @@ struct CompileConfig {
// Mesh related.
// MeshTaichi options
bool make_mesh_index_mapping_local{true};
bool optimize_mesh_reordered_mapping{true};
bool mesh_localize_from_end_mapping{false};

// helpers
9 changes: 8 additions & 1 deletion taichi/python/export_lang.cpp
@@ -213,7 +213,9 @@ void export_lang(py::module &m) {
.def_readwrite("make_mesh_index_mapping_local",
&CompileConfig::make_mesh_index_mapping_local)
.def_readwrite("mesh_localize_from_end_mapping",
&CompileConfig::mesh_localize_from_end_mapping);
&CompileConfig::mesh_localize_from_end_mapping)
.def_readwrite("optimize_mesh_reordered_mapping",
&CompileConfig::optimize_mesh_reordered_mapping);

m.def("reset_default_compile_config",
[&]() { default_compile_config = CompileConfig(); });
@@ -1216,6 +1218,11 @@ void export_lang(py::module &m) {
mesh_ptr.ptr->l2g_map.insert(std::pair(type, snode));
});

m.def("set_l2r",
[](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) {
mesh_ptr.ptr->l2r_map.insert(std::pair(type, snode));
});

m.def("set_relation_fixed",
[](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value) {
mesh_ptr.ptr->relations.insert(
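Once the flag is exported, it should be switchable from Python like the other mesh options, assuming `ti.init` keeps forwarding keyword arguments to the exported `CompileConfig` fields (a hedged sketch; per `compile_config.h` above the flag defaults to `True`):

```python
import taichi as ti

# Disable the reordered-mapping fast path to compare against the generic
# BLS fill; drop the keyword to keep the default (True).
ti.init(arch=ti.cuda,
        make_mesh_index_mapping_local=True,
        optimize_mesh_reordered_mapping=False)
```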
132 changes: 88 additions & 44 deletions taichi/transforms/make_mesh_index_mapping_local.cpp
@@ -23,7 +23,7 @@ void make_mesh_index_mapping_local_offload(OffloadedStmt *offload,

// TODO(changyu): An analyzer to determine which mapping should be localized
mappings.insert(std::make_pair(mesh::MeshElementType::Vertex,
mesh::ConvType::l2g)); // FIXME: A hack
mesh::ConvType::l2r)); // FIXME: A hack

std::size_t bls_offset_in_bytes = 0;
auto &block = offload->bls_prologue;
@@ -51,60 +51,104 @@ void make_mesh_index_mapping_local_offload(OffloadedStmt *offload,
block->parent_stmt = offload;
}

// int i = threadIdx.x;
// while (i < total_{}_num) {
// mapping_shared[i] = mapping[i + total_{}_offset];
// i += blockDim.x;
// }

// Step 1:
// Fetch mapping to BLS

// TODO(changyu): if the target index space is reordered, we can do more
// optimization
{
auto bls_mapping_loop = [&](Stmt *start_val, Stmt *end_val,
std::function<Stmt *(Block *, Stmt *)>
global_val) {
Stmt *idx = block->push_back<AllocaStmt>(data_type);
[[maybe_unused]] Stmt *init_val =
block->push_back<LocalStoreStmt>(idx, start_val);
Stmt *bls_element_offset_bytes = block->push_back<ConstStmt>(
LaneAttribute<TypedConstant>{(int32)bls_offset_in_bytes});
Stmt *block_dim_val = block->push_back<ConstStmt>(
LaneAttribute<TypedConstant>{offload->block_dim});

std::unique_ptr<Block> body = std::make_unique<Block>();
{
Stmt *idx_val = body->push_back<LocalLoadStmt>(LocalAddress{idx, 0});
Stmt *cond = body->push_back<BinaryOpStmt>(BinaryOpType::cmp_lt,
idx_val, end_val);
{ body->push_back<WhileControlStmt>(nullptr, cond); }
Stmt *idx_val_byte = body->push_back<BinaryOpStmt>(
BinaryOpType::mul, idx_val,
body->push_back<ConstStmt>(TypedConstant(dtype_size)));
Stmt *offset = body->push_back<BinaryOpStmt>(
BinaryOpType::add, bls_element_offset_bytes, idx_val_byte);
Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
offset,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
[[maybe_unused]] Stmt *bls_store = body->push_back<GlobalStoreStmt>(
bls_ptr, global_val(body.get(), idx_val));

Stmt *idx_val_ = body->push_back<BinaryOpStmt>(
BinaryOpType::add, idx_val, block_dim_val);
[[maybe_unused]] Stmt *idx_store =
body->push_back<LocalStoreStmt>(idx, idx_val_);
}
block->push_back<WhileStmt>(std::move(body));
Stmt *idx_val = block->push_back<LocalLoadStmt>(LocalAddress{idx, 0});
return idx_val;
};

Stmt *thread_idx_stmt = block->push_back<LoopLinearIndexStmt>(
offload); // Equivalent to CUDA threadIdx
Stmt *idx = block->push_back<AllocaStmt>(data_type);
[[maybe_unused]] Stmt *init_val =
block->push_back<LocalStoreStmt>(idx, thread_idx_stmt);
Stmt *bls_element_offset_bytes = block->push_back<ConstStmt>(
LaneAttribute<TypedConstant>{(int32)bls_offset_in_bytes});
Stmt *block_dim_val = block->push_back<ConstStmt>(
LaneAttribute<TypedConstant>{offload->block_dim});
Stmt *total_element_num =
offload->total_num_local.find(element_type)->second;
Stmt *total_element_offset =
offload->total_offset_local.find(element_type)->second;

std::unique_ptr<Block> body = std::make_unique<Block>();
{
Stmt *idx_val = body->push_back<LocalLoadStmt>(LocalAddress{idx, 0});
Stmt *cond = body->push_back<BinaryOpStmt>(BinaryOpType::cmp_lt,
idx_val, total_element_num);
{ body->push_back<WhileControlStmt>(nullptr, cond); }
Stmt *idx_val_byte = body->push_back<BinaryOpStmt>(
BinaryOpType::mul, idx_val,
body->push_back<ConstStmt>(TypedConstant(dtype_size)));
Stmt *offset = body->push_back<BinaryOpStmt>(
BinaryOpType::add, bls_element_offset_bytes, idx_val_byte);
Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
offset,
TypeFactory::create_vector_or_scalar_type(1, data_type, true));
Stmt *global_offset = body->push_back<BinaryOpStmt>(
BinaryOpType::add, total_element_offset, idx_val);
Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
LaneAttribute<SNode *>{snode}, std::vector<Stmt *>{global_offset});
Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
[[maybe_unused]] Stmt *bls_store =
body->push_back<GlobalStoreStmt>(bls_ptr, global_load);

Stmt *idx_val_ = body->push_back<BinaryOpStmt>(BinaryOpType::add,
idx_val, block_dim_val);
[[maybe_unused]] Stmt *idx_store =
body->push_back<LocalStoreStmt>(idx, idx_val_);
if (config.optimize_mesh_reordered_mapping &&
conv_type == mesh::ConvType::l2r) {
// int i = threadIdx.x;
// while (i < owned_{}_num) {
// mapping_shared[i] = i + owned_{}_offset;
// i += blockDim.x;
// }
// while (i < total_{}_num) {
// mapping_shared[i] = mapping[i + total_{}_offset];
// i += blockDim.x;
// }
Stmt *owned_element_num =
offload->owned_num_local.find(element_type)->second;
Stmt *owned_element_offset =
offload->owned_offset_local.find(element_type)->second;
Stmt *pre_idx_val = bls_mapping_loop(
thread_idx_stmt, owned_element_num,
[&](Block *body, Stmt *idx_val) {
Stmt *global_index = body->push_back<BinaryOpStmt>(
BinaryOpType::add, idx_val, owned_element_offset);
return global_index;
});
bls_mapping_loop(
pre_idx_val, total_element_num, [&](Block *body, Stmt *idx_val) {
Stmt *global_offset = body->push_back<BinaryOpStmt>(
BinaryOpType::add, total_element_offset, idx_val);
Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
LaneAttribute<SNode *>{snode},
std::vector<Stmt *>{global_offset});
Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
return global_load;
});
} else {
// int i = threadIdx.x;
// while (i < total_{}_num) {
// mapping_shared[i] = mapping[i + total_{}_offset];
// i += blockDim.x;
// }
bls_mapping_loop(
thread_idx_stmt, total_element_num,
[&](Block *body, Stmt *idx_val) {
Stmt *global_offset = body->push_back<BinaryOpStmt>(
BinaryOpType::add, total_element_offset, idx_val);
Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
LaneAttribute<SNode *>{snode},
std::vector<Stmt *>{global_offset});
Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
return global_load;
});
}
block->push_back<WhileStmt>(std::move(body));
}

// Step 2:
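The payoff of the refactor: for an l2r mapping with the flag on, the reordered indices of a patch's owned elements are simply `owned_offset + i`, so phase 1 of the prologue writes them into shared memory without touching the global mapping array, and only the remaining (ghost) entries are loaded in phase 2. Below is a plain-Python emulation of one block executing the two pseudocode loops above, serially with one thread per outer iteration; all sizes and table contents are made up:

```python
block_dim = 4
owned_num, total_num = 6, 9          # this patch: 6 owned + 3 ghost elements
owned_offset, total_offset = 100, 40

# Stand-in for the global l2r mapping array.  For a reordered mapping the
# owned prefix equals owned_offset + i, which is exactly what phase 1 avoids
# re-reading; the ghost tail holds arbitrary reordered indices.
mapping = [0] * 256
for i in range(owned_num):
    mapping[total_offset + i] = owned_offset + i
mapping[total_offset + owned_num:total_offset + total_num] = [7, 3, 250]

mapping_shared = [None] * total_num  # stand-in for block-local (shared) memory

for tid in range(block_dim):         # each iteration plays one thread
    i = tid
    while i < owned_num:             # phase 1: owned entries, no global read
        mapping_shared[i] = i + owned_offset
        i += block_dim
    while i < total_num:             # phase 2: ghost entries, read the table
        mapping_shared[i] = mapping[i + total_offset]
        i += block_dim

assert mapping_shared == [mapping[total_offset + i] for i in range(total_num)]
```

Each thread resumes phase 2 from wherever its phase-1 counter stopped, matching how the generated code seeds the second `bls_mapping_loop` with the final index value of the first.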
2 changes: 1 addition & 1 deletion taichi/transforms/make_mesh_thread_local.cpp
@@ -20,7 +20,7 @@ void make_mesh_thread_local_offload(OffloadedStmt *offload,

std::pair</* owned= */ std::unordered_set<mesh::MeshElementType>,
/* total= */ std::unordered_set<mesh::MeshElementType>>
accessed = analysis::gather_mesh_thread_local(offload);
accessed = analysis::gather_mesh_thread_local(offload, config);

std::size_t tls_offset = offload->tls_size;

