Skip to content

Commit

Permalink
[lang] Support vector and matrix dtypes in ti.field (taichi-dev#7761)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#6572 

For backward compatibility reason we cannot change any positional arg
and kwargs in `ti.field` but we this PR simply added support for vector
and matrix dtypes.

Whether to deprecate `ti.Vector.field` and `ti.Matrix.field` can be a
separate decision.

Testing: to avoid massive duplicated tests I just replaced all
`ti.Vector.field/ti.Matrix.field` in `test_fields.py` to use this
unified API. There're plenty of other files that use the old APIs so
this should be fine for now.

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 326d67e</samp>

Unified the creation of vector and matrix fields with a single `field`
function in `taichi.lang.impl`. Updated the tests in `test_field.py` to
use the new function and removed the deprecated ones. This improves the
consistency and readability of the field API.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 326d67e</samp>

* Rename and redefine `field` function to handle vector and matrix types
([link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-99744c5ae5f6a754d6f68408fdc64fb0d6097216518a7f3d1ef43ffe12599577L19-R19),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-99744c5ae5f6a754d6f68408fdc64fb0d6097216518a7f3d1ef43ffe12599577L698-R704),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-99744c5ae5f6a754d6f68408fdc64fb0d6097216518a7f3d1ef43ffe12599577R760-R805))
* Update `field` function docstring and examples in
`python/taichi/lang/impl.py`
([link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-99744c5ae5f6a754d6f68408fdc64fb0d6097216518a7f3d1ef43ffe12599577R760-R805))
* Replace `Vector.field` and `Matrix.field` with `field` function in
`tests/python/test_field.py`
([link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L38-R39),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L55-R57),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L135-R137),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L141-R144),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L171-R174),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L177-R181),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L197-R203),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L207-R212),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L215-R221),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L295-R300),[link](https://github.com/taichi-dev/taichi/pull/7761/files?diff=unified&w=0#diff-c08dd53cc282976d42e5643ea69e8e30e390e2ebd2f4e73f2789f84ac56f2494L302-R308))
  • Loading branch information
ailzhang authored Apr 10, 2023
1 parent ea8002e commit a533779
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 55 deletions.
97 changes: 54 additions & 43 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
VectorNdarray, make_matrix)
Vector, VectorNdarray, make_matrix)
from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
MeshRelationAccessProxy,
MeshReorderedMatrixFieldProxy,
Expand Down Expand Up @@ -695,48 +695,13 @@ def create_field_member(dtype, name, needs_grad, needs_dual):


@python_scope
def field(dtype,
shape=None,
order=None,
name="",
offset=None,
needs_grad=False,
needs_dual=False):
"""Defines a Taichi field.
A Taichi field can be viewed as an abstract N-dimensional array, hiding away
the complexity of how its underlying :class:`~taichi.lang.snode.SNode` are
actually defined. The data in a Taichi field can be directly accessed by
a Taichi :func:`~taichi.lang.kernel_impl.kernel`.
See also https://docs.taichi-lang.org/docs/field
Args:
dtype (DataType): data type of the field.
shape (Union[int, tuple[int]], optional): shape of the field.
order (str, optional): order of the shape laid out in memory.
name (str, optional): name of the field.
offset (Union[int, tuple[int]], optional): offset of the field domain.
needs_grad (bool, optional): whether this field participates in autodiff (reverse mode)
and thus needs an adjoint field to store the gradients.
needs_dual (bool, optional): whether this field participates in autodiff (forward mode)
and thus needs an dual field to store the gradients.
Example::
The code below shows how a Taichi field can be declared and defined::
>>> x1 = ti.field(ti.f32, shape=(16, 8))
>>> # Equivalently
>>> x2 = ti.field(ti.f32)
>>> ti.root.dense(ti.ij, shape=(16, 8)).place(x2)
>>>
>>> x3 = ti.field(ti.f32, shape=(16, 8), order='ji')
>>> # Equivalently
>>> x4 = ti.field(ti.f32)
>>> ti.root.dense(ti.j, shape=8).dense(ti.i, shape=16).place(x4)
"""
def _field(dtype,
shape=None,
order=None,
name="",
offset=None,
needs_grad=False,
needs_dual=False):
x, x_grad, x_dual = create_field_member(dtype, name, needs_grad,
needs_dual)
x = ScalarField(x)
Expand Down Expand Up @@ -791,6 +756,52 @@ def field(dtype,
return x


@python_scope
def field(dtype, *args, **kwargs):
"""Defines a Taichi field.
A Taichi field can be viewed as an abstract N-dimensional array, hiding away
the complexity of how its underlying :class:`~taichi.lang.snode.SNode` are
actually defined. The data in a Taichi field can be directly accessed by
a Taichi :func:`~taichi.lang.kernel_impl.kernel`.
See also https://docs.taichi-lang.org/docs/field
Args:
dtype (DataType): data type of the field. Note it can be vector or matrix types as well.
shape (Union[int, tuple[int]], optional): shape of the field.
order (str, optional): order of the shape laid out in memory.
name (str, optional): name of the field.
offset (Union[int, tuple[int]], optional): offset of the field domain.
needs_grad (bool, optional): whether this field participates in autodiff (reverse mode)
and thus needs an adjoint field to store the gradients.
needs_dual (bool, optional): whether this field participates in autodiff (forward mode)
and thus needs an dual field to store the gradients.
Example::
The code below shows how a Taichi field can be declared and defined::
>>> x1 = ti.field(ti.f32, shape=(16, 8))
>>> # Equivalently
>>> x2 = ti.field(ti.f32)
>>> ti.root.dense(ti.ij, shape=(16, 8)).place(x2)
>>>
>>> x3 = ti.field(ti.f32, shape=(16, 8), order='ji')
>>> # Equivalently
>>> x4 = ti.field(ti.f32)
>>> ti.root.dense(ti.j, shape=8).dense(ti.i, shape=16).place(x4)
>>>
>>> x5 = ti.field(ti.math.vec3, shape=(16, 8))
"""
if isinstance(dtype, MatrixType):
if dtype.ndim == 1:
return Vector.field(dtype.n, dtype.dtype, *args, **kwargs)
return Matrix.field(dtype.n, dtype.m, dtype.dtype, *args, **kwargs)
return _field(dtype, *args, **kwargs)


@python_scope
def ndarray(dtype, shape, needs_grad=False):
"""Defines a Taichi ndarray with scalar elements.
Expand Down
30 changes: 18 additions & 12 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_scalar_field(dtype, shape):
@pytest.mark.parametrize('shape', field_shapes)
@test_utils.test(arch=get_host_arch_list())
def test_vector_field(n, dtype, shape):
x = ti.Vector.field(n, dtype, shape)
vec_type = ti.types.vector(n, dtype)
x = ti.field(vec_type, shape)

if isinstance(shape, tuple):
assert x.shape == shape
Expand All @@ -52,7 +53,8 @@ def test_vector_field(n, dtype, shape):
@pytest.mark.parametrize('shape', field_shapes)
@test_utils.test(arch=get_host_arch_list())
def test_matrix_field(n, m, dtype, shape):
x = ti.Matrix.field(n, m, dtype=dtype, shape=shape)
mat_type = ti.types.matrix(n, m, dtype)
x = ti.field(dtype=mat_type, shape=shape)

if isinstance(shape, tuple):
assert x.shape == shape
Expand Down Expand Up @@ -132,13 +134,14 @@ def test_field_needs_grad_dtype():
match=
r".* is not supported for field with `needs_grad=True` or `needs_dual=True`."
):
b = ti.Vector.field(3, int, shape=1, needs_grad=True)
b = ti.field(ti.math.ivec3, shape=1, needs_grad=True)
with pytest.raises(
RuntimeError,
match=
r".* is not supported for field with `needs_grad=True` or `needs_dual=True`."
):
c = ti.Matrix.field(2, 3, int, shape=1, needs_grad=True)
mat_type = ti.types.matrix(2, 3, int)
c = ti.field(dtype=mat_type, shape=1, needs_grad=True)
with pytest.raises(
RuntimeError,
match=
Expand Down Expand Up @@ -168,13 +171,14 @@ def test_field_needs_dual_dtype():
match=
r".* is not supported for field with `needs_grad=True` or `needs_dual=True`."
):
b = ti.Vector.field(3, int, shape=1, needs_dual=True)
b = ti.field(ti.math.ivec3, shape=1, needs_dual=True)
with pytest.raises(
RuntimeError,
match=
r".* is not supported for field with `needs_grad=True` or `needs_dual=True`."
):
c = ti.Matrix.field(2, 3, int, shape=1, needs_dual=True)
mat_type = ti.types.matrix(2, 3, int)
c = ti.field(mat_type, shape=1, needs_dual=True)
with pytest.raises(
RuntimeError,
match=
Expand All @@ -194,8 +198,9 @@ def test_field_needs_dual_dtype():
@pytest.mark.parametrize('dtype', [ti.f32, ti.f64])
def test_default_fp(dtype):
ti.init(default_fp=dtype)
vec_type = ti.types.vector(3, dtype)

x = ti.Vector.field(2, float, ())
x = ti.field(vec_type, ())

assert x.dtype == impl.get_runtime().default_fp

Expand All @@ -204,16 +209,16 @@ def test_default_fp(dtype):
def test_default_ip(dtype):
ti.init(default_ip=dtype)

x = ti.Vector.field(2, int, ())
x = ti.field(ti.math.ivec2, ())

assert x.dtype == impl.get_runtime().default_ip


@test_utils.test()
def test_field_name():
a = ti.field(dtype=ti.f32, shape=(2, 3), name='a')
b = ti.Vector.field(3, dtype=ti.f32, shape=(2, 3), name='b')
c = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(5, 4), name='c')
b = ti.field(ti.math.vec3, shape=(2, 3), name='b')
c = ti.field(ti.math.mat3, shape=(5, 4), name='c')
assert a._name == 'a'
assert b._name == 'b'
assert c._name == 'c'
Expand Down Expand Up @@ -292,14 +297,15 @@ def test_indexing_with_np_int():

@test_utils.test()
def test_indexing_vec_field_with_np_int():
val = ti.Vector.field(2, ti.i32, shape=(2))
val = ti.field(ti.math.ivec2, shape=(2))
idx = np.int32(0)
val[idx][idx]


@test_utils.test()
def test_indexing_mat_field_with_np_int():
val = ti.Matrix.field(2, 2, ti.i32, shape=(2))
mat_type = ti.types.matrix(2, 2, int)
val = ti.field(mat_type, shape=(2))
idx = np.int32(0)
val[idx][idx, idx]

Expand Down

0 comments on commit a533779

Please sign in to comment.