From 111ad8ccd3a16b7cbce57b9cf28a2bf6fdcbc8d1 Mon Sep 17 00:00:00 2001 From: Jiafeng Liu Date: Wed, 23 Dec 2020 10:29:53 +0800 Subject: [PATCH] [type] [bug] Fix global load of CustomFloatType on CUDA (#2115) Co-authored-by: Yuanming Hu --- taichi/backends/cuda/codegen_cuda.cpp | 12 ++++++++---- tests/python/test_custom_float.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index fd270adaa2881..b7f14584e57ec 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -425,19 +425,23 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { auto val_type = ptr_type->get_pointee_type(); llvm::Value *data_ptr = nullptr; llvm::Value *bit_offset = nullptr; + Type *int_in_mem = nullptr; + // For CustomIntType "int_in_mem" refers to the type itself; + // for CustomFloatType "int_in_mem" refers to the CustomIntType of the + // digits. if (auto cit = val_type->cast<CustomIntType>()) { + int_in_mem = val_type; dtype = cit->get_physical_type(); } else if (auto cft = val_type->cast<CustomFloatType>()) { - dtype = cft->get_compute_type() - ->as<CustomIntType>() - ->get_physical_type(); + int_in_mem = cft->get_digits_type(); + dtype = int_in_mem->as<CustomIntType>()->get_physical_type(); } else { TI_NOT_IMPLEMENTED; } read_bit_pointer(llvm_val[stmt->ptr], data_ptr, bit_offset); data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype)); auto data = create_intrinsic_load(dtype, data_ptr); - llvm_val[stmt] = extract_custom_int(data, bit_offset, val_type); + llvm_val[stmt] = extract_custom_int(data, bit_offset, int_in_mem); if (val_type->is<CustomFloatType>()) { llvm_val[stmt] = reconstruct_custom_float(llvm_val[stmt], val_type); } diff --git a/tests/python/test_custom_float.py b/tests/python/test_custom_float.py index f644c968b53bb..b661fc709af74 100644 --- a/tests/python/test_custom_float.py +++ b/tests/python/test_custom_float.py @@ -67,3 +67,22 @@ def foo(): foo() assert x[None] == approx(10.0) + +
+@ti.test(require=ti.extension.quant)
+def test_cache_read_only():  # Writes a custom-float field from Python scope, then reads it back in a kernel that calls ti.cache_read_only on it.
+    ci15 = ti.type_factory_.get_custom_int_type(15, True)  # presumably (num_bits=15, signed=True) -- confirm against type_factory_
+    cft = ti.type_factory_.get_custom_float_type(ci15, ti.f32.get_ptr(), 0.1)  # presumably (digits type, compute type, scale) -- confirm
+    x = ti.field(dtype=cft)
+
+    ti.root._bit_struct(num_bits=32).place(x)  # x is placed inside a 32-bit bit struct
+
+    @ti.kernel
+    def test(data: ti.f32):
+        ti.cache_read_only(x)  # NOTE(review): looks intended to route the load of x through the read-only cached path -- confirm
+        assert x[None] == data  # exact equality between the value read in the kernel and the f32 argument
+
+    x[None] = 0.7  # store from Python scope, then check the kernel reads the same value
+    test(0.7)
+    x[None] = 1.2  # second value, so the check does not depend on a single stored value
+    test(1.2)