Skip to content

Commit

Permalink
[type] Use zext instruction to cast unsigned int (taichi-dev#2066)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanke98 authored Nov 29, 2020
1 parent 3b4adaa commit ea8838f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
9 changes: 7 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,13 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
from_size = data_type_size(from);
}
if (from_size < data_type_size(to)) {
llvm_val[stmt] = builder->CreateSExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
if (is_signed(from)) {
llvm_val[stmt] = builder->CreateSExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else {
llvm_val[stmt] = builder->CreateZExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
}
} else {
llvm_val[stmt] = builder->CreateTrunc(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
Expand Down
65 changes: 65 additions & 0 deletions tests/python/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,68 @@ def func2():
func1()
func2()
assert z[None] == 2333


@ti.test(arch=ti.cpu)
def test_int_extension():
x = ti.field(dtype=ti.i32, shape=2)
y = ti.field(dtype=ti.u32, shape=2)

a = ti.field(dtype=ti.i8, shape=1)
b = ti.field(dtype=ti.u8, shape=1)

@ti.kernel
def run_cast_i32():
x[0] = ti.cast(a[0], ti.i32)
x[1] = ti.cast(b[0], ti.i32)

@ti.kernel
def run_cast_u32():
y[0] = ti.cast(a[0], ti.u32)
y[1] = ti.cast(b[0], ti.u32)

a[0] = -128
b[0] = -128

run_cast_i32()
assert x[0] == -128
assert x[1] == 128

run_cast_u32()
assert y[0] == 0xFFFFFF80
assert y[1] == 128


@ti.test(arch=ti.cpu)
def test_custom_int_extension():
x = ti.field(dtype=ti.i32, shape=2)
y = ti.field(dtype=ti.u32, shape=2)

ci5 = ti.type_factory_.get_custom_int_type(5, True, 16)
cu7 = ti.type_factory_.get_custom_int_type(7, False, 16)

a = ti.field(dtype=ci5)
b = ti.field(dtype=cu7)

ti.root._bit_struct(num_bits=32).place(a, b)

@ti.kernel
def run_cast_int():
x[0] = ti.cast(a[None], ti.i32)
x[1] = ti.cast(b[None], ti.i32)

@ti.kernel
def run_cast_uint():
y[0] = ti.cast(a[None], ti.u32)
y[1] = ti.cast(b[None], ti.u32)

a[None] = -16
b[None] = -64

run_cast_int()
assert x[0] == -16
assert x[1] == 64

run_cast_uint()
assert y[0] == 0xFFFFFFF0
assert y[1] == 64

0 comments on commit ea8838f

Please sign in to comment.