forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_custom_float.py
87 lines (63 loc) · 2.05 KB
/
test_custom_float.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import math
from pytest import approx
import taichi as ti
@ti.test(require=ti.extension.quant_basic)
def test_custom_float():
    """A fixed-point type (frac=32, range=2) in a 32-bit bit_struct round-trips
    values written from kernel scope and from Python scope."""
    fixed32 = ti.quant.fixed(frac=32, range=2)
    val = ti.field(dtype=fixed32)
    ti.root.bit_struct(num_bits=32).place(val)

    @ti.kernel
    def write_then_increment():
        val[None] = 0.7
        print(val[None])
        # Read-modify-write through the quantized encoding.
        val[None] = val[None] + 0.4

    write_then_increment()
    assert val[None] == approx(1.1)

    # Python-scope writes must read back as the value written.
    for expected in (0.64, 0.66):
        val[None] = expected
        assert val[None] == approx(expected)
@ti.test(require=ti.extension.quant_basic)
def test_custom_matrix_rotation():
    """Repeated rotation of a 2x2 fixed-point matrix accumulates correctly.

    Entries use ``ti.quant.fixed(frac=16, range=1.2)``; each matrix row is
    packed into its own 32-bit bit_struct. Five 18-degree rotations of the
    identity (90 degrees total) must yield [[0, 1], [-1, 0]] within 1e-4.
    """
    fixed16 = ti.quant.fixed(frac=16, range=1.2)
    mat = ti.Matrix.field(2, 2, dtype=fixed16)
    for row in range(2):
        ti.root.bit_struct(num_bits=32).place(mat(row, 0), mat(row, 1))
    mat[None] = [[1.0, 0.0], [0.0, 1.0]]

    @ti.kernel
    def rotate_18_degrees():
        angle = math.pi / 10
        c = ti.cos(angle)
        s = ti.sin(angle)
        mat[None] = mat[None] @ ti.Matrix([[c, s], [-s, c]])

    for _ in range(5):
        rotate_18_degrees()

    expected = [[0, 1], [-1, 0]]
    for row in range(2):
        for col in range(2):
            assert mat[None][row, col] == approx(expected[row][col], abs=1e-4)
@ti.test(require=ti.extension.quant_basic)
def test_custom_float_implicit_cast():
    """An integer literal assigned in a kernel to a custom-float field
    (13-bit integer significand, scale 0.1) reads back as 10.0."""
    significand13 = ti.quant.int(bits=13)
    scaled_float = ti.type_factory.custom_float(significand_type=significand13,
                                                scale=0.1)
    val = ti.field(dtype=scaled_float)
    ti.root.bit_struct(num_bits=32).place(val)

    @ti.kernel
    def store_int_literal():
        val[None] = 10  # int literal stored into a custom-float field

    store_int_literal()
    assert val[None] == approx(10.0)
@ti.test(require=ti.extension.quant_basic)
def test_cache_read_only():
    """Reading a custom-float field under ``ti.cache_read_only`` yields the
    value last written from Python scope.

    The field uses a 15-bit integer significand with scale 0.1, packed into a
    32-bit bit_struct.
    """
    sig15 = ti.quant.int(bits=15)
    cft = ti.type_factory.custom_float(significand_type=sig15, scale=0.1)
    x = ti.field(dtype=cft)
    ti.root.bit_struct(num_bits=32).place(x)

    # A kernel-scope ``assert`` is only enforced when Taichi is initialised with
    # ``debug=True``; otherwise it is dropped and this test could never fail.
    # Compare inside the kernel and assert on the result in Python scope.
    @ti.kernel
    def cached_read_equals(data: ti.f32) -> ti.i32:
        ti.cache_read_only(x)
        return x[None] == data

    x[None] = 0.7
    assert cached_read_equals(0.7)
    x[None] = 1.2
    assert cached_read_equals(1.2)