Skip to content

Commit

Permalink
[error] Instruct user to use static range when matrix accessed with n…
Browse files Browse the repository at this point in the history
…on constant index (taichi-dev#1420)
  • Loading branch information
archibate authored Jul 6, 2020
1 parent fe74125 commit 5c43dcf
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
14 changes: 11 additions & 3 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from .util import taichi_scope, python_scope, deprecated, to_numpy_type, to_pytorch_type, in_python_scope
from .common_ops import TaichiOperations
from .exception import TaichiSyntaxError
from collections.abc import Iterable
import warnings

Expand Down Expand Up @@ -181,9 +182,16 @@ def linearize_entry_id(self, *args):
assert 0 <= args[1] < self.m
# TODO(#1004): See if it's possible to support indexing at runtime
for i, a in enumerate(args):
assert isinstance(
a, int
), f'The {i}-th index of a Matrix/Vector must be a compile-time constant integer, got {a}'
if not isinstance(a, int):
raise TaichiSyntaxError(
f'The {i}-th index of a Matrix/Vector must be a compile-time constant '
'integer, got {a}. This is because matrix operations will be **unrolled**'
' at compile-time for performance reason.\n'
'If you want to *iterate through matrix elements*, use a static range:\n'
' for i in ti.static(range(3)):\n'
' print(i, "-th component is", vec[i])\n'
'See https://taichi.readthedocs.io/en/stable/meta.html#when-to-use-for-loops-with-ti-static for more details.'
)
return args[0] * self.m + args[1]

def __call__(self, *args, **kwargs):
Expand Down
29 changes: 29 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,32 @@ def run():

assert np.allclose(r1[None].value.to_numpy(), ops(a, b))
assert np.allclose(r2[None].value.to_numpy(), ops(a, c))


@ti.host_arch_only
@ti.must_throw(ti.TaichiSyntaxError)
def test_matrix_non_constant_index():
m = ti.Matrix(2, 2, ti.i32, 5)

@ti.kernel
def func():
for i in range(5):
for j, k in ti.ndrange(2, 2):
m[i][j, k] = 12

func()


@ti.host_arch_only
def test_matrix_constant_index():
m = ti.Matrix(2, 2, ti.i32, 5)

@ti.kernel
def func():
for i in range(5):
for j, k in ti.static(ti.ndrange(2, 2)):
m[i][j, k] = 12

func()

assert np.allclose(m.to_numpy(), np.ones((5, 2, 2), np.int32) * 12)

0 comments on commit 5c43dcf

Please sign in to comment.