simplified torch io in tests
yuanming-hu committed Dec 4, 2019
1 parent d8d6cab commit b116fe1
Showing 4 changed files with 7 additions and 100 deletions.
48 changes: 1 addition & 47 deletions examples/pytorch_tensor_ad.py
@@ -1,47 +1 @@
-import taichi as ti
-import numpy as np
-import torch
-
-@ti.host_arch
-def test_torch_ad():
-  if not ti.has_pytorch():
-    return
-  n = 32
-
-  x = ti.var(ti.f32, shape=n, needs_grad=True)
-  y = ti.var(ti.f32, shape=n, needs_grad=True)
-
-  @ti.kernel
-  def torch_kernel():
-    for i in range(n):
-      # Do whatever complex operations here a little bit fancier
-      y[n - i - 1] = x[i] * x[i]
-
-  # https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
-
-  class Sqr(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inp):
-      outp = torch.zeros_like(inp)
-      ti.from_torch(x, inp)
-      torch_kernel()
-      ti.to_torch(y, outp)
-      return outp
-
-    @staticmethod
-    def backward(ctx, outp_grad):
-      inp_grad = torch.zeros_like(outp_grad)
-
-      ti.clear_all_gradients()
-      ti.from_torch(y.grad, outp_grad)
-      torch_kernel.grad()
-      ti.to_torch(x.grad, inp_grad)
-
-      return inp_grad
-
-  sqr = Sqr.apply
-  for i in range(10):
-    X = torch.tensor(2 * np.ones((n, ), dtype=np.float32), requires_grad=True)
-    sqr(X).sum().backward()
-    print(X.grad.cpu().numpy())
-
+# Moved to test/python/test_torch_ad
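Note: the deleted example follows the standard torch.autograd.Function extension pattern from the PyTorch tutorial it links: forward() produces the output tensor, backward() receives the upstream gradient and returns the gradient w.r.t. the input. A minimal pure-PyTorch sketch of the same squaring function, with no Taichi involved (names are illustrative):

import torch

class Square(torch.autograd.Function):
  @staticmethod
  def forward(ctx, inp):
    ctx.save_for_backward(inp)
    return inp * inp

  @staticmethod
  def backward(ctx, outp_grad):
    inp, = ctx.saved_tensors
    return 2 * inp * outp_grad  # chain rule: d(x^2)/dx = 2x

X = torch.full((4,), 2.0, requires_grad=True)
Square.apply(X).sum().backward()
print(X.grad)  # tensor([4., 4., 4., 4.])

The Taichi version replaces the tensor arithmetic in forward/backward with kernel launches plus the tensor I/O that this commit simplifies.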
45 changes: 1 addition & 44 deletions examples/pytorch_tensor_io.py
@@ -1,45 +1,2 @@
-import taichi as ti
-import numpy as np
-import torch
-
-ti.cfg.arch = ti.cuda
-
-n = 32
-
-# https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
-
-@ti.kernel
-def torch_kernel(t: ti.ext_arr(), o: ti.ext_arr()):
-  for i in range(n):
-    o[i] = t[i] * t[i]
-
-@ti.kernel
-def torch_kernel_2(t_grad: ti.ext_arr(), t: ti.ext_arr(), o_grad: ti.ext_arr()):
-  for i in range(n):
-    print(o_grad[i])
-    t_grad[i] = 2 * t[i] * o_grad[i]
-
-
-class Sqr(torch.autograd.Function):
-  @staticmethod
-  def forward(ctx, inp):
-    outp = torch.zeros_like(inp)
-    ctx.save_for_backward(inp)
-    torch_kernel(inp, outp)
-    return outp
-
-  @staticmethod
-  def backward(ctx, outp_grad):
-    outp_grad = outp_grad.contiguous()
-    inp_grad = torch.zeros_like(outp_grad)
-    inp, = ctx.saved_tensors
-    torch_kernel_2(inp_grad, inp, outp_grad)
-    return inp_grad
-
-#, device=torch.device('cuda:0')
-
-sqr = Sqr.apply
-X = torch.tensor(2 * np.ones((n, ), dtype=np.float32), requires_grad=True)
-sqr(X).sum().backward()
-print(X.grad.cpu())
+# Moved to test/python/test_torch_io
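Note: this deleted example shows the other interop path: kernels whose arguments are annotated ti.ext_arr() read and write torch tensors in place, so no from_torch/to_torch copies are needed at all. A condensed sketch of that pattern, assuming the era's ti.cfg-style configuration used in the example above:

import taichi as ti
import torch

ti.cfg.arch = ti.cuda  # as in the deleted example; use a CPU arch if no GPU (assumption)
n = 32

@ti.kernel
def double(src: ti.ext_arr(), dst: ti.ext_arr()):
  for i in range(n):
    dst[i] = 2 * src[i]

t = torch.ones(n)
out = torch.zeros(n)
double(t, out)  # torch tensors are passed to the kernel directly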

2 changes: 1 addition & 1 deletion python/taichi/lang/expr.py
@@ -305,7 +305,7 @@ def from_numpy(self, arr):
     numpy_to_tensor(arr, self)
 
   def from_torch(self, arr):
-    self.from_numpy(arr)
+    self.from_numpy(arr.contiguous())
 
 def make_var_vector(size):
   import taichi as ti
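Note: the expr.py change guards against non-contiguous inputs. Torch views such as transposes and slices can have strides that do not match the flat row-major buffer from_numpy expects; .contiguous() materializes a compact copy when needed and is a no-op on already-contiguous tensors. A small illustration:

import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = a.t()                  # a transposed view; the underlying buffer is unchanged
print(b.is_contiguous())   # False: strides no longer match row-major layout
c = b.contiguous()         # compact row-major copy, safe for flat-memory consumers
print(c.is_contiguous())   # True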
12 changes: 4 additions & 8 deletions tests/python/test_torch_ad.py
@@ -20,21 +20,17 @@ def torch_kernel():
   class Sqr(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp):
-      outp = torch.zeros_like(inp)
-      ti.from_torch(x, inp)
+      x.from_torch(inp)
       torch_kernel()
-      ti.to_torch(y, outp)
+      outp = y.to_torch()
       return outp
 
     @staticmethod
     def backward(ctx, outp_grad):
-      inp_grad = torch.zeros_like(outp_grad)
-
       ti.clear_all_gradients()
-      ti.from_torch(y.grad, outp_grad)
+      y.grad.from_torch(outp_grad)
       torch_kernel.grad()
-      ti.to_torch(x.grad, inp_grad)
-
+      inp_grad = x.grad.to_torch()
       return inp_grad
 
   sqr = Sqr.apply
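Note: the net effect of the test update is that tensor I/O becomes methods on the Taichi variable, and to_torch() allocates the destination tensor itself, so the torch.zeros_like() staging buffers disappear. A minimal round trip with the updated API (shape chosen arbitrarily):

import taichi as ti
import torch

x = ti.var(ti.f32, shape=16)

t = torch.arange(16, dtype=torch.float32)
x.from_torch(t)   # torch -> taichi; .contiguous() is applied internally (see expr.py above)
u = x.to_torch()  # taichi -> torch; the result tensor is allocated for the caller
assert torch.allclose(t, u)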
