Skip to content

Commit

Permalink
[bug] [aot] Fix texture struct for with cgraph (taichi-dev#6536)
Browse files Browse the repository at this point in the history
Fixes taichi-dev#6518

We should handle graph args based on their symbolic_arg type instead of
runtime type.

Issue: #

### Brief Summary
  • Loading branch information
ailzhang authored Nov 9, 2022
1 parent f957092 commit 4b51af6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
14 changes: 10 additions & 4 deletions taichi/aot/graph_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ void CompiledGraph::run(
TI_ERROR_IF(found == args.end(), "Missing runtime value for {}",
symbolic_arg.name);
const aot::IValue &ival = found->second;
if (ival.tag == aot::ArgKind::kNdarray) {
if (symbolic_arg.tag == aot::ArgKind::kNdarray) {
TI_ASSERT(ival.tag == aot::ArgKind::kNdarray);
Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val);

TI_ERROR_IF(arr->get_element_shape() != symbolic_arg.element_shape,
Expand Down Expand Up @@ -65,12 +66,17 @@ void CompiledGraph::run(
arr_primitive_dtype.to_string());
ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
arr->shape);
} else if (ival.tag == aot::ArgKind::kScalar) {
} else if (symbolic_arg.tag == aot::ArgKind::kScalar ||
symbolic_arg.tag == aot::ArgKind::kMatrix) {
TI_ASSERT(ival.tag == aot::ArgKind::kScalar);
// Matrix args are flattened so they're same as scalars.
ctx.set_arg(i, ival.val);
} else if (ival.tag == aot::ArgKind::kTexture) {
} else if (symbolic_arg.tag == aot::ArgKind::kTexture) {
TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
Texture *tex = reinterpret_cast<Texture *>(ival.val);
ctx.set_arg_texture(i, tex->get_device_allocation_ptr_as_int());
} else if (ival.tag == aot::ArgKind::kRWTexture) {
} else if (symbolic_arg.tag == aot::ArgKind::kRWTexture) {
TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
Texture *tex = reinterpret_cast<Texture *>(ival.val);
ctx.set_arg_rw_texture(i, tex->get_device_allocation_ptr_as_int(),
tex->get_size());
Expand Down
38 changes: 38 additions & 0 deletions tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,41 @@ def test_kernel(arr: ti.types.ndarray()):
g = g_builder.compile()
g.run({'arr': arr})
assert arr.to_numpy() == 1


@test_utils.test(arch=[ti.vulkan])
def test_texture_struct_for():
res = (128, 128)
tex = ti.Texture(ti.Format.r32f, res)
arr = ti.ndarray(ti.f32, res)

@ti.kernel
def write(tex: ti.types.rw_texture(num_dimensions=2,
num_channels=1,
channel_format=ti.f32,
lod=0)):
for i, j in tex:
tex.store(ti.Vector([i, j]), ti.Vector([1.0, 0.0, 0.0, 0.0]))

@ti.kernel
def read(tex: ti.types.texture(num_dimensions=2), arr: ti.types.ndarray()):
for i, j in arr:
arr[i, j] = tex.fetch(ti.Vector([i, j]), 0).x

sym_tex = ti.graph.Arg(ti.graph.ArgKind.RWTEXTURE,
'tex',
channel_format=ti.f32,
shape=res,
num_channels=1)
sym_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'arr',
ti.f32,
field_dim=2)

gb = ti.graph.GraphBuilder()
gb.dispatch(write, sym_tex)
gb.dispatch(read, sym_tex, sym_arr)
graph = gb.compile()

graph.run({'tex': tex, 'arr': arr})
assert arr.to_numpy().sum() == 128 * 128

0 comments on commit 4b51af6

Please sign in to comment.