Skip to content

Commit

Permalink
[metal] Add mem_offset_in_parent to AOT module (taichi-dev#3245)
Browse files Browse the repository at this point in the history
* [metal] Add mem_offset_in_parent to AOT module

* Auto Format

* tweak

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
k-ye and taichi-gardener authored Nov 5, 2021
1 parent 29fd273 commit 1a19f97
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 30 deletions.
5 changes: 3 additions & 2 deletions python/taichi/aot/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def add_field(self, name, field):
column_num = field.n
else:
assert isinstance(field, ScalarField)
self._aot_builder.add_field(name, is_scalar, field.dtype,
field.snode.shape, row_num, column_num)
self._aot_builder.add_field(name, field.snode.ptr, is_scalar,
field.dtype, field.snode.shape, row_num,
column_num)

def add_kernel(self, kernel_fn, name=None):
"""Add a taichi kernel to the AOT module.
Expand Down
26 changes: 18 additions & 8 deletions taichi/backends/metal/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ namespace metal {
AotModuleBuilderImpl::AotModuleBuilderImpl(
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
const BufferMetaData &buffer_meta_data)
const std::unordered_set<const SNode *> &fields,
BufferMetaData buffer_meta_data)
: compiled_runtime_module_(compiled_runtime_module),
compiled_snode_trees_(compiled_snode_trees),
buffer_meta_data_(buffer_meta_data) {
fields_(fields) {
buffer_meta_data.root_buffer_size = compiled_snode_trees_[0].root_size;
ti_aot_data_.metadata = buffer_meta_data;
}

void AotModuleBuilderImpl::metalgen(const std::string &dir,
const std::string &filename,
const CompiledKernelData &k) const {
void AotModuleBuilderImpl::write_metal_file(const std::string &dir,
const std::string &filename,
const CompiledKernelData &k) const {
const std::string mtl_path =
fmt::format("{}/{}_{}.metal", dir, filename, k.kernel_name);
std::ofstream fs{mtl_path};
Expand All @@ -41,12 +43,12 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
ts.write_to_file(txt_path);

for (const auto &k : ti_aot_data_.kernels) {
metalgen(output_dir, filename, k);
write_metal_file(output_dir, filename, k);
}

for (const auto &k : ti_aot_data_.tmpl_kernels) {
for (auto &ki : k.kernel_tmpl_map) {
metalgen(output_dir, filename, ki.second);
write_metal_file(output_dir, filename, ki.second);
}
}
}
Expand All @@ -59,18 +61,26 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
ti_aot_data_.kernels.push_back(std::move(compiled));
}

void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier,
void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
int row_num,
int column_num) {
const auto *dense_snode = rep_snode->parent;
TI_ASSERT_INFO(fields_.find(dense_snode) != fields_.end(),
"dense_snode: id={} type={}", dense_snode->id,
dense_snode->get_node_type_name_hinted());
const auto &dense_desc =
compiled_snode_trees_[0].snode_descriptors.at(dense_snode->id);
CompiledFieldData field_data;
field_data.field_name = identifier;
field_data.is_scalar = is_scalar;
field_data.dtype = to_metal_type(dt);
field_data.dtype_name = metal_data_type_name(dt);
field_data.shape = shape;
field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent;
field_data.row_num = row_num;
field_data.column_num = column_num;
ti_aot_data_.fields.push_back(field_data);
Expand Down
18 changes: 11 additions & 7 deletions taichi/backends/metal/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <string>
#include <vector>
#include <unordered_set>

#include "taichi/backends/metal/aot_utils.h"
#include "taichi/backends/metal/struct_metal.h"
Expand All @@ -16,14 +17,17 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
explicit AotModuleBuilderImpl(
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
const BufferMetaData &buffer_meta_data);
const std::unordered_set<const SNode *> &fields,
BufferMetaData buffer_meta_data);

void dump(const std::string &output_dir,
const std::string &filename) const override;

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
void add_per_backend_field(const std::string &identifier,

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand All @@ -34,15 +38,15 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
Kernel *kernel) override;

private:
void write_metal_file(const std::string &dir,
const std::string &filename,
const CompiledKernelData &k) const;

const CompiledRuntimeModule *compiled_runtime_module_;
const std::vector<CompiledStructs> &compiled_snode_trees_;
BufferMetaData buffer_meta_data_;
const std::unordered_set<const SNode *> fields_;
PrintStringTable strtab_;
TaichiAotData ti_aot_data_;

void metalgen(const std::string &dir,
const std::string &filename,
const CompiledKernelData &k) const;
};

} // namespace metal
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/metal/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ struct CompiledFieldData {
MetalDataType dtype;
std::string dtype_name;
std::vector<int> shape;
int mem_offset_in_parent{0};
bool is_scalar{false};
int row_num{0};
int column_num{0};
Expand All @@ -288,6 +289,7 @@ struct CompiledFieldData {
dtype,
dtype_name,
shape,
mem_offset_in_parent,
is_scalar,
row_num,
column_num);
Expand Down
67 changes: 60 additions & 7 deletions taichi/backends/metal/metal_program.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,46 @@
#include <unordered_set>

#include "metal_program.h"
#include "taichi/backends/metal/codegen_metal.h"
#include "taichi/backends/metal/struct_metal.h"

namespace taichi {
namespace lang {
namespace {

std::unordered_set<const SNode *> find_all_dense_snodes(
const metal::SNodeDescriptorsMap &snodes_map) {
std::unordered_set<const SNode *> res;
for (const auto [_, desc] : snodes_map) {
const auto *sn = desc.snode;
if (sn->type == SNodeType::dense) {
res.insert(sn);
}
}
return res;
}

bool all_fields_are_dense(
const std::unordered_set<const SNode *> &placed_snodes) {
for (const auto *sn : placed_snodes) {
for (const auto &ch : sn->ch) {
if (ch->type != SNodeType::place) {
return false;
}
}
const auto *parent = sn->parent;
if (!parent) {
return false;
}
if (parent->type != SNodeType::root) {
return false;
}
}
return true;
}

} // namespace

MetalProgramImpl::MetalProgramImpl(CompileConfig &config_)
: ProgramImpl(config_) {
}
Expand Down Expand Up @@ -43,24 +80,40 @@ void MetalProgramImpl::materialize_runtime(MemoryPool *memory_pool,
metal_kernel_mgr_ = std::make_unique<metal::KernelManager>(std::move(params));
}

void MetalProgramImpl::compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) {
(void)compile_snode_tree_types_impl(tree);
}

void MetalProgramImpl::materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &,
uint64 *result_buffer) {
// TODO: support materializing multiple snode trees
TI_ASSERT_INFO(config->use_llvm,
"Metal arch requires that LLVM being enabled");
auto *const root = tree->root();
auto csnode_tree = metal::compile_structs(*root);
const auto &csnode_tree = compile_snode_tree_types_impl(tree);
metal_kernel_mgr_->add_compiled_snode_tree(csnode_tree);
compiled_snode_trees_.push_back(std::move(csnode_tree));
}

std::unique_ptr<AotModuleBuilder> MetalProgramImpl::make_aot_module_builder() {
TI_ERROR_IF(compiled_snode_trees_.size() > 1,
"AOT: only supports one SNodeTree");
const auto fields =
find_all_dense_snodes(compiled_snode_trees_[0].snode_descriptors);
TI_ERROR_IF(!all_fields_are_dense(fields), "AOT: only supports dense field");
return std::make_unique<metal::AotModuleBuilderImpl>(
&(compiled_runtime_module_.value()), compiled_snode_trees_,
&(compiled_runtime_module_.value()), compiled_snode_trees_, fields,
metal_kernel_mgr_->get_buffer_meta_data());
}

const metal::CompiledStructs &MetalProgramImpl::compile_snode_tree_types_impl(
SNodeTree *tree) {
TI_ASSERT_INFO(config->use_llvm,
"Metal arch requires that LLVM being enabled");
auto *const root = tree->root();
auto csnode_tree = metal::compile_structs(*root);
compiled_snode_trees_.push_back(std::move(csnode_tree));
return compiled_snode_trees_.back();
}

} // namespace lang
} // namespace taichi
6 changes: 6 additions & 0 deletions taichi/backends/metal/metal_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class MetalProgramImpl : public ProgramImpl {
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override;

void compile_snode_tree_types(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees) override;

void materialize_snode_tree(
SNodeTree *tree,
std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
Expand All @@ -46,6 +50,8 @@ class MetalProgramImpl : public ProgramImpl {
std::unique_ptr<AotModuleBuilder> make_aot_module_builder() override;

private:
const metal::CompiledStructs &compile_snode_tree_types_impl(SNodeTree *tree);

std::optional<metal::CompiledRuntimeModule> compiled_runtime_module_{
std::nullopt};
std::vector<metal::CompiledStructs> compiled_snode_trees_;
Expand Down
3 changes: 2 additions & 1 deletion taichi/backends/opengl/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
aot_data_.kernels.push_back({compiled, identifier});
}

void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier,
void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand Down
4 changes: 3 additions & 1 deletion taichi/backends/opengl/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class AotModuleBuilderImpl : public AotModuleBuilder {

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
void add_per_backend_field(const std::string &identifier,

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand Down
3 changes: 2 additions & 1 deletion taichi/backends/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
name_list_.push_back(name);
}

void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier,
void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand Down
3 changes: 2 additions & 1 deletion taichi/backends/wasm/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
void add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) override;
void add_per_backend_field(const std::string &Identifier,
void add_field_per_backend(const std::string &Identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand Down
4 changes: 3 additions & 1 deletion taichi/program/aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) {
}

void AotModuleBuilder::add_field(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
int row_num,
int column_num) {
add_per_backend_field(identifier, is_scalar, dt, shape, row_num, column_num);
add_field_per_backend(identifier, rep_snode, is_scalar, dt, shape, row_num,
column_num);
}

void AotModuleBuilder::add_kernel_template(const std::string &identifier,
Expand Down
6 changes: 5 additions & 1 deletion taichi/program/aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <string>
#include <vector>

#include "taichi/ir/snode.h"

namespace taichi {
namespace lang {

Expand All @@ -16,6 +18,7 @@ class AotModuleBuilder {
void add(const std::string &identifier, Kernel *kernel);

void add_field(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand All @@ -35,7 +38,8 @@ class AotModuleBuilder {
*/
virtual void add_per_backend(const std::string &identifier,
Kernel *kernel) = 0;
virtual void add_per_backend_field(const std::string &identifier,
virtual void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
Expand Down

0 comments on commit 1a19f97

Please sign in to comment.