diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp
index 1ebefde10cbde..2cb9eafcd2acc 100644
--- a/taichi/aot/graph_data.cpp
+++ b/taichi/aot/graph_data.cpp
@@ -23,7 +23,8 @@ void CompiledGraph::run(
       TI_ERROR_IF(found == args.end(), "Missing runtime value for {}",
                   symbolic_arg.name);
       const aot::IValue &ival = found->second;
-      if (ival.tag == aot::ArgKind::kNdarray) {
+      if (symbolic_arg.tag == aot::ArgKind::kNdarray) {
+        TI_ASSERT(ival.tag == aot::ArgKind::kNdarray);
         Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val);
 
         TI_ERROR_IF(arr->get_element_shape() != symbolic_arg.element_shape,
@@ -65,12 +66,17 @@ void CompiledGraph::run(
                     arr_primitive_dtype.to_string());
         ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
                             arr->shape);
-      } else if (ival.tag == aot::ArgKind::kScalar) {
+      } else if (symbolic_arg.tag == aot::ArgKind::kScalar ||
+                 symbolic_arg.tag == aot::ArgKind::kMatrix) {
+        TI_ASSERT(ival.tag == aot::ArgKind::kScalar);
+        // Matrix args are flattened, so they're the same as scalars.
         ctx.set_arg(i, ival.val);
-      } else if (ival.tag == aot::ArgKind::kTexture) {
+      } else if (symbolic_arg.tag == aot::ArgKind::kTexture) {
+        TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
         Texture *tex = reinterpret_cast<Texture *>(ival.val);
         ctx.set_arg_texture(i, tex->get_device_allocation_ptr_as_int());
-      } else if (ival.tag == aot::ArgKind::kRWTexture) {
+      } else if (symbolic_arg.tag == aot::ArgKind::kRWTexture) {
+        TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
         Texture *tex = reinterpret_cast<Texture *>(ival.val);
         ctx.set_arg_rw_texture(i, tex->get_device_allocation_ptr_as_int(),
                                tex->get_size());
diff --git a/tests/python/test_graph.py b/tests/python/test_graph.py
index dad64359702a1..1cd653e4a7133 100644
--- a/tests/python/test_graph.py
+++ b/tests/python/test_graph.py
@@ -448,3 +448,41 @@ def test_kernel(arr: ti.types.ndarray()):
     g = g_builder.compile()
     g.run({'arr': arr})
     assert arr.to_numpy() == 1
+
+
+@test_utils.test(arch=[ti.vulkan])
+def test_texture_struct_for():
+    res = (128, 128)
+    tex = ti.Texture(ti.Format.r32f, res)
+    arr = ti.ndarray(ti.f32, res)
+
+    @ti.kernel
+    def write(tex: ti.types.rw_texture(num_dimensions=2,
+                                       num_channels=1,
+                                       channel_format=ti.f32,
+                                       lod=0)):
+        for i, j in tex:
+            tex.store(ti.Vector([i, j]), ti.Vector([1.0, 0.0, 0.0, 0.0]))
+
+    @ti.kernel
+    def read(tex: ti.types.texture(num_dimensions=2), arr: ti.types.ndarray()):
+        for i, j in arr:
+            arr[i, j] = tex.fetch(ti.Vector([i, j]), 0).x
+
+    sym_tex = ti.graph.Arg(ti.graph.ArgKind.RWTEXTURE,
+                           'tex',
+                           channel_format=ti.f32,
+                           shape=res,
+                           num_channels=1)
+    sym_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
+                           'arr',
+                           ti.f32,
+                           field_dim=2)
+
+    gb = ti.graph.GraphBuilder()
+    gb.dispatch(write, sym_tex)
+    gb.dispatch(read, sym_tex, sym_arr)
+    graph = gb.compile()
+
+    graph.run({'tex': tex, 'arr': arr})
+    assert arr.to_numpy().sum() == 128 * 128