Skip to content

Commit

Permalink
[aot] Fixed header generator (taichi-dev#7455)
Browse files Browse the repository at this point in the history
  • Loading branch information
PENGUINLIONG authored Feb 28, 2023
1 parent 93d2549 commit d136a71
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 75 deletions.
6 changes: 6 additions & 0 deletions python/taichi/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,12 @@ def lint(arguments: list = sys.argv[2:]):
import pylint # pylint: disable=C0415
pylint.lint.Run(options)

@register
def module(self, arguments: list = sys.argv[2:]):
"""Taichi module tools"""
from taichi import _ti_module # pylint: disable=C0415
_ti_module.module(arguments)

@staticmethod
@register
def cache(arguments: list = sys.argv[2:]):
Expand Down
72 changes: 72 additions & 0 deletions python/taichi/_ti_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse
import json
from pathlib import Path
from typing import List

from taichi._ti_module.cppgen import generate_header
from taichi.aot.conventions.gfxruntime140 import GfxRuntime140

import taichi as ti


def module_cppgen(parser: argparse.ArgumentParser):
"""Generate C++ headers for Taichi modules."""
parser.add_argument("MODOLE", help="Path to the module directory.")
parser.add_argument("-n",
"--namespace",
type=str,
help="C++ namespace if wanted.")
parser.add_argument(
"-m",
"--module-name",
type=str,
help=
"Module name to be a part of the module class. By default, it's the directory name.",
default=None)
parser.add_argument("-o",
"--output",
type=str,
help="Output C++ header path.",
default="module.h")
parser.set_defaults(func=module_cppgen_impl)


def module_cppgen_impl(a):
module_path = a.MODOLE

print(
f"Generating C++ header for Taichi module: {Path(module_path).absolute()}"
)

with open(f"{module_path}/metadata.json") as f:
metadata_json = json.load(f)

with open(f"{module_path}/graphs.json") as f:
graphs_json = json.load(f)

if a.module_name:
module_name = a.module_name
else:
module_name = Path(module_path).name
if module_name.endswith(".tcm"):
module_name = module_name[:-4]

out = generate_header(metadata_json, graphs_json, module_name, a.namespace)

with open(a.output, "w") as f:
f.write('\n'.join(out))

print(f"Module header is saved to: {Path(a.output).absolute()}")


def module(arguments: List[str]):
"""Taichi module tools."""
parser = argparse.ArgumentParser(prog='ti module',
description=module.__doc__)
subparsers = parser.add_subparsers(title="Taichi module manager commands",
required=True)

cppgen_parser = subparsers.add_parser('cppgen', help=module_cppgen.__doc__)
module_cppgen(cppgen_parser)
args = parser.parse_args(arguments)
args.func(args)
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
"""
C++ header generator for Taichi AOT modules.
"""
import argparse
import json
from pathlib import Path
from typing import Any, List

import taichi.aot.conventions.gfxruntime140.sr as sr
from taichi.aot.conventions.gfxruntime140 import GfxRuntime140
from taichi.aot.conventions.gfxruntime140 import GfxRuntime140, sr

dtype2ctype = {
sr.DataType.f16: "half_t",
Expand All @@ -31,8 +24,8 @@ def check_arg(actual: str, expect: Any) -> List[str]:
out += [
f" if (value.{actual} != {expect}) {{",
f" ti_set_last_error(TI_ERROR_INVALID_ARGUMENT, \"value.{actual} != {expect}\");",
f" return *this;",
f" }}",
" return *this;",
" }",
]

return out
Expand All @@ -41,8 +34,7 @@ def check_arg(actual: str, expect: Any) -> List[str]:
def get_arg_dst(i: int, is_named: bool) -> str:
if is_named:
return f"args[{i}].argument"
else:
return f"args[{i}]"
return f"args[{i}]"


def generate_scalar_assign(cls_name: str, i: int, arg_name: str,
Expand Down Expand Up @@ -91,7 +83,7 @@ def generate_ndarray_assign(cls_name: str, i: int, arg_name: str,
out = []

out += [
f" {cls_name} &{arg_name}(const TiNdArray &value) {{",
f" {cls_name} &set_{arg_name}(const TiNdArray &value) {{",
]

out += check_arg("elem_type", arg.dtype)
Expand Down Expand Up @@ -120,7 +112,7 @@ def generate_texture_assign(cls_name: str, i: int, arg_name: str,
out = []

out += [
f" {cls_name} &{arg_name}(const TiTexture &value) {{",
f" {cls_name} &set_{arg_name}(const TiTexture &value) {{",
]

assert arg.ndim in [1, 2, 3]
Expand Down Expand Up @@ -151,7 +143,7 @@ def generate_kernel_args_builder(kernel: sr.Kernel) -> List[str]:
f" TiArgument args[{len(kernel.context.args)}];",
"",
f" explicit Kernel_{kernel.name}(TiRuntime runtime, TiKernel kernel) :",
f" runtime(runtime), kernel(kernel) {{}}",
" runtime(runtime), kernel(kernel) {}",
f" explicit Kernel_{kernel.name}(TiRuntime runtime, TiAotModule aot_module) :",
f" runtime(runtime), kernel(ti_get_aot_module_kernel(aot_module, \"{kernel.name}\")) {{}}",
"",
Expand Down Expand Up @@ -190,7 +182,7 @@ def generate_graph_args_builder(graph: sr.Graph) -> List[str]:
f" TiNamedArgument args[{len(graph.args)}];",
"",
f" explicit Graph_{graph.name}(TiRuntime runtime, TiComputeGraph graph) :",
f" runtime(runtime), graph(graph) {{}}",
" runtime(runtime), graph(graph) {}",
f" explicit Graph_{graph.name}(TiRuntime runtime, TiAotModule aot_module) :",
f" runtime(runtime), graph(ti_get_aot_module_compute_graph(aot_module, \"{graph.name}\")) {{}}",
"",
Expand Down Expand Up @@ -231,16 +223,16 @@ def generate_module_content_repr(m: GfxRuntime140,
]
else:
out += [
f"struct Module {{",
"struct Module {",
]

out += [
" TiRuntime runtime;",
" TiAotModule aot_module;",
" bool should_destroy;",
"",
f" explicit Module(TiRuntime runtime, TiAotModule aot_module, bool should_destroy = true) :",
f" runtime(runtime), aot_module(aot_module), should_destroy(should_destroy) {{}}",
" explicit Module(TiRuntime runtime, TiAotModule aot_module, bool should_destroy = true) :",
" runtime(runtime), aot_module(aot_module), should_destroy(should_destroy) {}",
" ~Module() {",
" if (should_destroy) {",
" ti_destroy_aot_module(aot_module);",
Expand Down Expand Up @@ -269,17 +261,14 @@ def generate_module_content_repr(m: GfxRuntime140,

def generate_module_content(m: GfxRuntime140, module_name: str) -> List[str]:
# This has all kernels including all the ones launched by compute graphs.
kernel_names = [x for x in m.metadata.kernels]
cgraph_kernel_names = [
dispatch.kernel.name in kernel_names for graph in m.graphs
for dispatch in graph.dispatches
]
cgraph_kernel_names = set(dispatch.kernel.name for graph in m.graphs
for dispatch in graph.dispatches)

out = []
for kernel in m.metadata.kernels.values():
out += generate_kernel_args_builder(kernel)
if kernel.name in cgraph_kernel_names:
continue
out += generate_kernel_args_builder(kernel)

for graph in m.graphs:
out += generate_graph_args_builder(graph)
Expand All @@ -289,10 +278,12 @@ def generate_module_content(m: GfxRuntime140, module_name: str) -> List[str]:
return out


def generate_header(m: GfxRuntime140, module_name: str,
def generate_header(metadata_json: str, graphs_json: str, module_name: str,
namespace: str) -> List[str]:
out = []

m = GfxRuntime140(metadata_json, graphs_json)

out += [
"// THIS IS A GENERATED HEADER; PLEASE DO NOT MODIFY.",
"#pragma once",
Expand All @@ -312,52 +303,3 @@ def generate_header(m: GfxRuntime140, module_name: str,
out += [f"}} // namespace {namespace}", ""]

return out


def main():
parser = argparse.ArgumentParser()
parser.add_argument("MODOLE",
help="Path to the module directory, or TCM archive.")
parser.add_argument("-n",
"--namespace",
type=str,
help="C++ namespace if wanted.")
parser.add_argument(
"-m",
"--module-name",
type=str,
help=
"Module name to be a part of the module class. By default, it's the directory name.",
default=None)
parser.add_argument("-o",
"--output",
type=str,
help="Output C++ header path.",
default="module.h")
a = parser.parse_args()

module_path = a.MODOLE

with open(f"{module_path}/metadata.json") as f:
metadata_json = json.load(f)

with open(f"{module_path}/graphs.json") as f:
graphs_json = json.load(f)

m = GfxRuntime140(metadata_json, graphs_json)

if a.module_name:
module_name = a.module_name
else:
module_name = Path(module_path).name
if module_name.endswith(".tcm"):
module_name = module_name[:-4]

out = generate_header(m, module_name, a.namespace)

with open(a.output, "w") as f:
f.write('\n'.join(out))


if __name__ == "__main__":
main()

0 comments on commit d136a71

Please sign in to comment.