Commit

[cuda] Support numpy and torch tensors with zeros in shapes (e.g., (5, 0, 5)) (taichi-dev#1305)
yuanming-hu authored Jun 24, 2020
1 parent 2ac833b commit f4e7db2
Showing 3 changed files with 34 additions and 6 deletions.
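
Background, as a minimal sketch (not part of the commit): a numpy array or torch tensor with a zero anywhere in its shape holds no elements, so the byte size the CUDA backend receives as args[i].size is 0, and the previously unconditional device malloc/memcpy would then be asked for a zero-byte allocation, which the CUDA driver does not accept. The variable names below are illustrative only.

import numpy as np
import torch

a = np.empty(shape=(5, 0, 5), dtype=np.int32)
print(a.size, a.nbytes)                          # 0 0 -- zero elements, zero bytes

t = torch.zeros((5, 0, 5), dtype=torch.int32)
print(t.numel(), t.numel() * t.element_size())   # 0 0

# A zero byte count is exactly what reaches the CUDA codegen as
# args[i].size, so the host-to-device copy has to be skipped.
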
18 changes: 12 additions & 6 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -9,6 +9,7 @@
 #include "taichi/program/program.h"
 #include "taichi/lang_util.h"
 #include "taichi/backends/cuda/cuda_driver.h"
+#include "taichi/backends/cuda/cuda_context.h"
 #include "taichi/codegen/codegen_llvm.h"
 
 TLANG_NAMESPACE_BEGIN
@@ -57,18 +58,23 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
             kernel = this->kernel](Context &context) {
       // copy data to GRAM
       auto args = kernel->args;
-      std::vector<void *> host_buffers(args.size());
-      std::vector<void *> device_buffers(args.size());
+      std::vector<void *> host_buffers(args.size(), nullptr);
+      std::vector<void *> device_buffers(args.size(), nullptr);
       bool has_buffer = false;
       for (int i = 0; i < (int)args.size(); i++) {
         if (args[i].is_nparray) {
           has_buffer = true;
-          CUDADriver::get_instance().malloc(&device_buffers[i], args[i].size);
           // replace host buffer with device buffer
           host_buffers[i] = get_current_program().context.get_arg<void *>(i);
+          if (args[i].size > 0) {
+            // Note: both numpy and PyTorch support arrays/tensors with zeros
+            // in shapes, e.g., shape=(0) or shape=(100, 0, 200). This makes
+            // args[i].size = 0.
+            CUDADriver::get_instance().malloc(&device_buffers[i], args[i].size);
+            CUDADriver::get_instance().memcpy_host_to_device(
+                (void *)device_buffers[i], host_buffers[i], args[i].size);
+          }
           kernel->set_arg_nparray(i, (uint64)device_buffers[i], args[i].size);
-          CUDADriver::get_instance().memcpy_host_to_device(
-              (void *)device_buffers[i], host_buffers[i], args[i].size);
         }
       }
       if (has_buffer) {
@@ -87,7 +93,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
         CUDADriver::get_instance().stream_synchronize(nullptr);
       }
       for (int i = 0; i < (int)args.size(); i++) {
-        if (args[i].is_nparray) {
+        if (args[i].is_nparray && args[i].size > 0) {
           CUDADriver::get_instance().memcpy_device_to_host(
               host_buffers[i], (void *)device_buffers[i], args[i].size);
           CUDADriver::get_instance().mem_free((void *)device_buffers[i]);
11 changes: 11 additions & 0 deletions tests/python/test_numpy.py
@@ -182,3 +182,14 @@ def test_numpy(a: ti.ext_arr(), b: ti.ext_arr()):
 def test_index_mismatch():
     val = ti.var(ti.i32, shape=(1, 2, 3))
     val[0, 0] = 1
+
+
+@ti.all_archs
+def test_numpy_zero():
+    @ti.kernel
+    def test_numpy(arr: ti.ext_arr()):
+        pass
+
+    test_numpy(np.empty(shape=(0), dtype=np.int32))
+    test_numpy(np.empty(shape=(0, 5), dtype=np.int32))
+    test_numpy(np.empty(shape=(5, 0), dtype=np.int32))
11 changes: 11 additions & 0 deletions tests/python/test_torch_io.py
@@ -223,3 +223,14 @@ def test_shape_vector():
     X1 = x.to_torch()
 
     assert (X == X1).all()
+
+
+@ti.torch_test
+def test_torch_zero():
+    @ti.kernel
+    def test_torch(arr: ti.ext_arr()):
+        pass
+
+    test_torch(torch.zeros((0), dtype=torch.int32))
+    test_torch(torch.zeros((0, 5), dtype=torch.int32))
+    test_torch(torch.zeros((5, 0, 5), dtype=torch.int32))
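
For reference, a hedged end-to-end sketch of what the new tests exercise: calling a CUDA kernel with both an empty and a non-empty external array. The kernel name fill, the explicit ti.init(arch=ti.cuda), and the element-count argument n are illustrative assumptions, not part of this commit.

import numpy as np
import taichi as ti

ti.init(arch=ti.cuda)  # assumption: a CUDA device is available

@ti.kernel
def fill(arr: ti.ext_arr(), n: ti.i32):
    for i in range(n):
        arr[i] = i

empty = np.empty(shape=(0), dtype=np.int32)  # zero elements -> args[i].size == 0
fill(empty, 0)   # before this commit, this hit a zero-byte device malloc/memcpy on CUDA

data = np.zeros(8, dtype=np.int32)
fill(data, 8)    # the non-empty path still allocates, copies in, and copies back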
