Skip to content

Commit

Permalink
[spirv] Fix duplicated interface id for global_tmp buffers (taichi-de…
Browse files Browse the repository at this point in the history
…v#6392)

Issue: #

### Brief Summary

Before this PR the generated spirv shader fails spirv validation.
  • Loading branch information
ailzhang authored Oct 20, 2022
1 parent 00ca1cf commit d6322f0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
15 changes: 12 additions & 3 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1733,11 +1733,20 @@ class TaskCodegen : public IRVisitor {
std::vector<spirv::Value> buffers;
if (device_->get_cap(DeviceCapability::spirv_version) > 0x10300) {
buffers = shared_array_binds_;
std::unordered_set<BufferInfo, BufferInfoHasher> unique_bufs;
// One buffer can be bound to different bind points but has to be unique
// in OpEntryPoint interface declarations.
// From Spec: before SPIR-V version 1.4, duplication of these interface id
// is tolerated. Starting with version 1.4, an interface id must not
// appear more than once.
for (const auto &bb : task_attribs_.buffer_binds) {
for (auto &it : buffer_value_map_) {
if (it.first.first == bb.buffer) {
buffers.push_back(it.second);
if (unique_bufs.count(bb.buffer) == 0) {
for (auto &it : buffer_value_map_) {
if (it.first.first == bb.buffer) {
buffers.push_back(it.second);
}
}
unique_bufs.insert(bb.buffer);
}
}
}
Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,3 +645,33 @@ def func(arr_src: ti.template(), arr_dst: ti.template()):
with pytest.raises(ti.TaichiRuntimeTypeError,
match=r"Ndarray shouldn't be passed in via"):
func(arr_0, arr_1)


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_gaussian_kernel():
M_PI = 3.14159265358979323846

@ti.func
def gaussian(x, sigma):
return ti.exp(
-0.5 * ti.pow(x / sigma, 2)) / (sigma * ti.sqrt(2.0 * M_PI))

@ti.kernel
def fill_gaussian_kernel(
ker: ti.types.ndarray(ti.f32, field_dim=1), N: ti.i32):
sum = 0.0
for i in range(2 * N + 1):
ker[i] = gaussian(i - N, ti.sqrt(N))
sum += ker[i]
for i in range(2 * N + 1):
ker[i] = ker[i] / sum

N = 4
arr = ti.ndarray(dtype=ti.f32, shape=(20))
fill_gaussian_kernel(arr, N)
res = arr.to_numpy()

np_arr = np.zeros(20, dtype=np.float32)
fill_gaussian_kernel(np_arr, N)

assert test_utils.allclose(res, np_arr)

0 comments on commit d6322f0

Please sign in to comment.