forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_poly_timed.py
54 lines (40 loc) · 1.21 KB
/
test_poly_timed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from autograd import grad
import taichi as ti
from taichi import approx
# Note: value and gradient are checked at a single point, v = 0.234 by default.
def grad_test(tifunc, npfunc=None, default_fp=ti.f32, v=0.234):
    """Check a Taichi function's value and reverse-mode gradient against autograd.

    Builds one-element fields ``x`` and ``y`` (with gradient fields), runs a
    kernel computing ``y[i] = tifunc(x[i])`` with ``x[0] = v``, then runs the
    kernel's gradient and compares both results against the reference
    function and ``autograd.grad`` of it.

    Args:
        tifunc: Scalar function applied inside the Taichi kernel.
        npfunc: Reference function evaluated and differentiated by
            ``autograd``; defaults to ``tifunc`` when omitted.
        default_fp: Floating-point type used for the fields and passed to
            ``ti.all_archs_with``.
        v: Point at which the value and gradient are compared
            (defaults to 0.234, the previously hard-coded value).

    Raises:
        AssertionError: If the value or the gradient does not match the
            reference within ``taichi.approx`` tolerance.
    """
    if npfunc is None:
        npfunc = tifunc

    # NOTE(review): the decorator presumably runs impl once per enabled
    # backend (hence the arch printout) -- confirm against taichi docs.
    @ti.all_archs_with(default_fp=default_fp)
    def impl():
        print(f'arch={ti.cfg.arch} default_fp={ti.cfg.default_fp}')
        x = ti.field(default_fp)
        y = ti.field(default_fp)
        # Allocate x, y and their gradient fields in one dense 1-element layout.
        ti.root.dense(ti.i, 1).place(x, x.grad, y, y.grad)

        @ti.kernel
        def func():
            for i in x:
                y[i] = tifunc(x[i])

        # Seed the output adjoint with 1 so that, after func.grad(),
        # x.grad[0] holds dy/dx at x = v (compared against autograd below).
        y.grad[0] = 1
        x[0] = v
        func()
        func.grad()
        assert y[0] == approx(npfunc(v))
        assert x.grad[0] == approx(grad(npfunc)(v))

    impl()
def test_poly():
    """Run grad_test over several polynomial expressions and report timing.

    Each lambda serves as both the Taichi kernel function and the autograd
    reference. Afterwards it prints Taichi's profile summary and the total
    elapsed time for the whole run.
    """
    import time

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted, which would corrupt the duration.
    start = time.perf_counter()
    grad_test(lambda x: x)
    grad_test(lambda x: -x)
    grad_test(lambda x: x * x)
    grad_test(lambda x: x**2)
    grad_test(lambda x: x * x * x)
    grad_test(lambda x: x * x * x * x)
    grad_test(lambda x: 0.4 * x * x - 3)
    grad_test(lambda x: (x - 3) * (x - 1))
    grad_test(lambda x: (x - 3) * (x - 1) + x * x)
    ti.core.print_profile_info()
    print('total_time', time.perf_counter() - start)
if __name__ == '__main__':
    # Only run when executed as a script; under pytest, test_poly is
    # collected and run by the test runner rather than at import time.
    test_poly()