Skip to content

Commit

Permalink
[type] [bug] Fix global load of CustomFloatType on CUDA (taichi-dev#2115
Browse files Browse the repository at this point in the history
)

Co-authored-by: Yuanming Hu <[email protected]>
  • Loading branch information
Hanke98 and yuanming-hu authored Dec 23, 2020
1 parent c43a0de commit 111ad8c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
12 changes: 8 additions & 4 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,23 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
auto val_type = ptr_type->get_pointee_type();
llvm::Value *data_ptr = nullptr;
llvm::Value *bit_offset = nullptr;
Type *int_in_mem = nullptr;
// For CustomIntType "int_in_mem" refers to the type itself;
// for CustomFloatType "int_in_mem" refers to the CustomIntType of the
// digits.
if (auto cit = val_type->cast<CustomIntType>()) {
int_in_mem = val_type;
dtype = cit->get_physical_type();
} else if (auto cft = val_type->cast<CustomFloatType>()) {
dtype = cft->get_compute_type()
->as<CustomIntType>()
->get_physical_type();
int_in_mem = cft->get_digits_type();
dtype = int_in_mem->as<CustomIntType>()->get_physical_type();
} else {
TI_NOT_IMPLEMENTED;
}
read_bit_pointer(llvm_val[stmt->ptr], data_ptr, bit_offset);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_custom_int(data, bit_offset, val_type);
llvm_val[stmt] = extract_custom_int(data, bit_offset, int_in_mem);
if (val_type->is<CustomFloatType>()) {
llvm_val[stmt] = reconstruct_custom_float(llvm_val[stmt], val_type);
}
Expand Down
19 changes: 19 additions & 0 deletions tests/python/test_custom_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,22 @@ def foo():

foo()
assert x[None] == approx(10.0)


@ti.test(require=ti.extension.quant)
def test_cache_read_only():
ci15 = ti.type_factory_.get_custom_int_type(15, True)
cft = ti.type_factory_.get_custom_float_type(ci15, ti.f32.get_ptr(), 0.1)
x = ti.field(dtype=cft)

ti.root._bit_struct(num_bits=32).place(x)

@ti.kernel
def test(data: ti.f32):
ti.cache_read_only(x)
assert x[None] == data

x[None] = 0.7
test(0.7)
x[None] = 1.2
test(1.2)

0 comments on commit 111ad8c

Please sign in to comment.