From ea8838f2c3896c80036247afd989b334fc5e1c8d Mon Sep 17 00:00:00 2001 From: Jiafeng Liu Date: Sun, 29 Nov 2020 09:18:40 +0800 Subject: [PATCH] [type] Use zext instruction to cast unsigned int (#2066) --- taichi/codegen/codegen_llvm.cpp | 9 ++++- tests/python/test_cast.py | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index a3882125765e3..c514098a64432 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -333,8 +333,13 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { from_size = data_type_size(from); } if (from_size < data_type_size(to)) { - llvm_val[stmt] = builder->CreateSExt( - llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); + if (is_signed(from)) { + llvm_val[stmt] = builder->CreateSExt( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); + } else { + llvm_val[stmt] = builder->CreateZExt( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); + } } else { llvm_val[stmt] = builder->CreateTrunc( llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); diff --git a/tests/python/test_cast.py b/tests/python/test_cast.py index 86f94a3fdf10f..cc5dc39dc407f 100644 --- a/tests/python/test_cast.py +++ b/tests/python/test_cast.py @@ -87,3 +87,68 @@ def func2(): func1() func2() assert z[None] == 2333 + + +@ti.test(arch=ti.cpu) +def test_int_extension(): + x = ti.field(dtype=ti.i32, shape=2) + y = ti.field(dtype=ti.u32, shape=2) + + a = ti.field(dtype=ti.i8, shape=1) + b = ti.field(dtype=ti.u8, shape=1) + + @ti.kernel + def run_cast_i32(): + x[0] = ti.cast(a[0], ti.i32) + x[1] = ti.cast(b[0], ti.i32) + + @ti.kernel + def run_cast_u32(): + y[0] = ti.cast(a[0], ti.u32) + y[1] = ti.cast(b[0], ti.u32) + + a[0] = -128 + b[0] = -128 + + run_cast_i32() + assert x[0] == -128 + assert x[1] == 128 + + run_cast_u32() + assert y[0] == 0xFFFFFF80 + assert y[1] == 128 + + +@ti.test(arch=ti.cpu) +def test_custom_int_extension(): + x = ti.field(dtype=ti.i32, shape=2) + y = ti.field(dtype=ti.u32, shape=2) + + ci5 = ti.type_factory_.get_custom_int_type(5, True, 16) + cu7 = ti.type_factory_.get_custom_int_type(7, False, 16) + + a = ti.field(dtype=ci5) + b = ti.field(dtype=cu7) + + ti.root._bit_struct(num_bits=32).place(a, b) + + @ti.kernel + def run_cast_int(): + x[0] = ti.cast(a[None], ti.i32) + x[1] = ti.cast(b[None], ti.i32) + + @ti.kernel + def run_cast_uint(): + y[0] = ti.cast(a[None], ti.u32) + y[1] = ti.cast(b[None], ti.u32) + + a[None] = -16 + b[None] = -64 + + run_cast_int() + assert x[0] == -16 + assert x[1] == 64 + + run_cast_uint() + assert y[0] == 0xFFFFFFF0 + assert y[1] == 64