diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp
index 452c283eb5d90..efd0559927869 100644
--- a/taichi/backends/cuda/codegen_cuda.cpp
+++ b/taichi/backends/cuda/codegen_cuda.cpp
@@ -9,6 +9,7 @@
 #include "taichi/program/program.h"
 #include "taichi/lang_util.h"
 #include "taichi/backends/cuda/cuda_driver.h"
+#include "taichi/backends/cuda/cuda_context.h"
 #include "taichi/codegen/codegen_llvm.h"
 
 TLANG_NAMESPACE_BEGIN
@@ -57,18 +58,23 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
             kernel = this->kernel](Context &context) {
       // copy data to GRAM
       auto args = kernel->args;
-      std::vector<void *> host_buffers(args.size());
-      std::vector<void *> device_buffers(args.size());
+      std::vector<void *> host_buffers(args.size(), nullptr);
+      std::vector<void *> device_buffers(args.size(), nullptr);
       bool has_buffer = false;
       for (int i = 0; i < (int)args.size(); i++) {
         if (args[i].is_nparray) {
           has_buffer = true;
-          CUDADriver::get_instance().malloc(&device_buffers[i], args[i].size);
           // replace host buffer with device buffer
           host_buffers[i] = get_current_program().context.get_arg<void *>(i);
+          if (args[i].size > 0) {
+            // Note: both numpy and PyTorch support arrays/tensors with zeros
+            // in shapes, e.g., shape=(0) or shape=(100, 0, 200). This makes
+            // args[i].size = 0.
+            CUDADriver::get_instance().malloc(&device_buffers[i], args[i].size);
+            CUDADriver::get_instance().memcpy_host_to_device(
+                (void *)device_buffers[i], host_buffers[i], args[i].size);
+          }
           kernel->set_arg_nparray(i, (uint64)device_buffers[i], args[i].size);
-          CUDADriver::get_instance().memcpy_host_to_device(
-              (void *)device_buffers[i], host_buffers[i], args[i].size);
         }
       }
       if (has_buffer) {
@@ -87,7 +93,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
         CUDADriver::get_instance().stream_synchronize(nullptr);
       }
       for (int i = 0; i < (int)args.size(); i++) {
-        if (args[i].is_nparray) {
+        if (args[i].is_nparray && args[i].size > 0) {
           CUDADriver::get_instance().memcpy_device_to_host(
               host_buffers[i], (void *)device_buffers[i], args[i].size);
           CUDADriver::get_instance().mem_free((void *)device_buffers[i]);
diff --git a/tests/python/test_numpy.py b/tests/python/test_numpy.py
index 20eccfaf1919e..f8157b8144db0 100644
--- a/tests/python/test_numpy.py
+++ b/tests/python/test_numpy.py
@@ -182,3 +182,14 @@ def test_numpy(a: ti.ext_arr(), b: ti.ext_arr()):
 def test_index_mismatch():
     val = ti.var(ti.i32, shape=(1, 2, 3))
     val[0, 0] = 1
+
+
+@ti.all_archs
+def test_numpy_zero():
+    @ti.kernel
+    def test_numpy(arr: ti.ext_arr()):
+        pass
+
+    test_numpy(np.empty(shape=(0), dtype=np.int32))
+    test_numpy(np.empty(shape=(0, 5), dtype=np.int32))
+    test_numpy(np.empty(shape=(5, 0), dtype=np.int32))
diff --git a/tests/python/test_torch_io.py b/tests/python/test_torch_io.py
index a0716f827b2f6..83e17b00aa100 100644
--- a/tests/python/test_torch_io.py
+++ b/tests/python/test_torch_io.py
@@ -223,3 +223,14 @@ def test_shape_vector():
 
     X1 = x.to_torch()
     assert (X == X1).all()
+
+
+@ti.torch_test
+def test_torch_zero():
+    @ti.kernel
+    def test_torch(arr: ti.ext_arr()):
+        pass
+
+    test_torch(torch.zeros((0), dtype=torch.int32))
+    test_torch(torch.zeros((0, 5), dtype=torch.int32))
+    test_torch(torch.zeros((5, 0, 5), dtype=torch.int32))