Skip to content

Commit

Permalink
v0.3.8; fixed @data_oriented with nn.Module
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Dec 23, 2019
1 parent 8362d6d commit 6e045d4
Showing 3 changed files with 10 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@ project(taichi)

SET(TI_VERSION_MAJOR 0)
SET(TI_VERSION_MINOR 3)
SET(TI_VERSION_PATCH 7)
SET(TI_VERSION_PATCH 8)

execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
2 changes: 1 addition & 1 deletion docs/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.7
0.3.8
25 changes: 8 additions & 17 deletions python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
@@ -25,8 +25,6 @@ def remove_indent(lines):


# The ti.func decorator


def func(foo):
from .impl import get_runtime
src = remove_indent(inspect.getsource(foo))
@@ -330,29 +328,22 @@ def __call__(self, *args, **kwargs):
self.func(*args, **kwargs)

def data_oriented(cls):
class new_class:
def __init__(self, *args, **kwargs):
self.instance = cls(*args, **kwargs)
def getattr(self, item):
x = super(cls, self).__getattribute__(item)
if hasattr(x, '_classkernel'):
return DifferentiableMethod(x)
else:
return x

def __getattribute__(self, item):
if item == 'instance':
return super().__getattribute__(item)

x = self.instance.__getattribute__(item)
if hasattr(x, '_classkernel'):
print('Calling classkernel')
return DifferentiableMethod(x)
else:
return x
cls.__getattribute__ = getattr

return new_class
return cls

def classkernel(foo):
primal = Kernel(foo, False, classkernel=True)
adjoint = Kernel(foo, True, classkernel=True)

def decorated(*args, _gradient=False, **kwargs):
print(kwargs)
if _gradient:
adjoint(*args, **kwargs)
else:

0 comments on commit 6e045d4

Please sign in to comment.