Skip to content

Commit

Permalink
Kernel.foo->func
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Oct 21, 2019
1 parent 3c69dc8 commit 59787b3
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def message(self):


class Kernel:
def __init__(self, foo, is_grad):
self.foo = foo
def __init__(self, func, is_grad):
self.func = func
self.is_grad = is_grad
self.materialized = False
self.arguments = []
Expand All @@ -288,7 +288,7 @@ def __init__(self, foo, is_grad):
self.compiled_functions = pytaichi.compiled_grad_functions

def extract_arguments(self):
sig = inspect.signature(self.foo)
sig = inspect.signature(self.func)
params = sig.parameters
arg_names = params.keys()
for arg_name in arg_names:
Expand Down Expand Up @@ -321,9 +321,9 @@ def materialize(self, extra_frame_backtrace=-1):
grad_suffix = ""
if self.is_grad:
grad_suffix = "_grad"
print("Compiling kernel {}{}...".format(self.foo.__name__, grad_suffix))
print("Compiling kernel {}{}...".format(self.func.__name__, grad_suffix))

src = remove_indent(inspect.getsource(self.foo))
src = remove_indent(inspect.getsource(self.func))
tree = ast.parse(src)
# print(astor.to_source(tree.body[0]))

Expand All @@ -338,22 +338,22 @@ def materialize(self, extra_frame_backtrace=-1):
if pytaichi.print_preprocessed:
print(astor.to_source(tree.body[0], indent_with=' '))

ast.increment_lineno(tree, inspect.getsourcelines(self.foo)[1] - 1)
ast.increment_lineno(tree, inspect.getsourcelines(self.func)[1] - 1)

pytaichi.inside_kernel = True
frame = inspect.currentframe()
for t in range(extra_frame_backtrace + 2):
frame = frame.f_back
exec(compile(tree, filename=inspect.getsourcefile(self.foo), mode='exec'),
exec(compile(tree, filename=inspect.getsourcefile(self.func), mode='exec'),
dict(frame.f_globals, **frame.f_locals), locals())
pytaichi.inside_kernel = False
compiled = locals()[self.foo.__name__]
compiled = locals()[self.func.__name__]

taichi_kernel = taichi_lang_core.create_kernel(self.foo.__name__ + grad_suffix,
self.is_grad)
taichi_kernel = taichi_lang_core.create_kernel(self.func.__name__ + grad_suffix,
self.is_grad)
taichi_kernel = taichi_kernel.define(lambda: compiled())

self.compiled_functions[self.foo] = self.get_function_body(taichi_kernel)
self.compiled_functions[self.func] = self.get_function_body(taichi_kernel)


def get_function_body(self, t_kernel):
Expand Down Expand Up @@ -405,7 +405,7 @@ def func__(*args):

def __call__(self, *args, extra_frame_backtrace=0):
self.materialize(extra_frame_backtrace=extra_frame_backtrace)
self.compiled_functions[self.foo](*args)
self.compiled_functions[self.func](*args)


def kernel(foo):
Expand Down

0 comments on commit 59787b3

Please sign in to comment.