Skip to content

Commit

Permalink
[type] [bug] Correct mask used in setting bits partially (taichi-dev#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanke98 authored Dec 15, 2020
1 parent 31683c9 commit dba8d6b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
9 changes: 7 additions & 2 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,9 +1554,14 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
#include "internal_functions.h"

// TODO: make here less repetitious.
// Original implementation is
// u##N mask = ((((u##N)1 << bits) - 1) << offset);
// When N equals bits equals 32, 32 times of left shifting will be carried on
// which is an undefined behavior.
// see #2096 for more details
#define DEFINE_SET_PARTIAL_BITS(N) \
void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \
u##N mask = ((((u##N)1 << bits) - 1) << offset); \
u##N mask = ((~(u##N)0) << (N - bits)) >> (N - offset - bits); \
u##N new_value = 0; \
u##N old_value = *ptr; \
do { \
Expand All @@ -1570,7 +1575,7 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
\
u##N atomic_add_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, \
u##N value) { \
u##N mask = ((((u##N)1 << bits) - 1) << offset); \
u##N mask = ((~(u##N)0) << (N - bits)) >> (N - offset - bits); \
u##N new_value = 0; \
u##N old_value = *ptr; \
do { \
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_bit_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ def verify_val(idx: ti.i32):
verify_val.__wrapped__(idx)


@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
def test_custom_int_full_struct():
cit = ti.type_factory_.get_custom_int_type(32, True)
x = ti.field(dtype=cit)
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(x)

@ti.kernel
def set_val():
x[0] = 15

@ti.kernel
def varify_val1():
assert x[0] == 15

@ti.kernel
def set_val2():
x[0] = 12

@ti.kernel
def varify_val2():
assert x[0] == 12

set_val()
varify_val1()
set_val2()
varify_val2()


def test_bit_struct():
def test_single_bit_struct(physical_type, compute_type, custom_bits,
test_case):
Expand Down

0 comments on commit dba8d6b

Please sign in to comment.