Skip to content

Commit

Permalink
[type] [bug] Support atomic add negative numbers for custom types (ta…
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanke98 authored Dec 8, 2020
1 parent 6012b79 commit e614519
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
3 changes: 3 additions & 0 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,7 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {

#include "internal_functions.h"

// TODO: make here less repetitious.
#define DEFINE_SET_PARTIAL_BITS(N) \
void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \
u##N mask = ((((u##N)1 << bits) - 1) << offset); \
Expand All @@ -1569,11 +1570,13 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
\
u##N atomic_add_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, \
u##N value) { \
u##N mask = ((((u##N)1 << bits) - 1) << offset); \
u##N new_value = 0; \
u##N old_value = *ptr; \
do { \
old_value = *ptr; \
new_value = old_value + (value << offset); \
new_value = (old_value & (~mask)) | (new_value & mask); \
} while ( \
!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \
std::memory_order::memory_order_seq_cst, \
Expand Down
23 changes: 15 additions & 8 deletions tests/python/test_custom_type_atomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,35 @@
@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
def test_custom_int_atomics():
ci13 = ti.type_factory_.get_custom_int_type(13, True)
ci5 = ti.type_factory_.get_custom_int_type(5, True)
cu2 = ti.type_factory_.get_custom_int_type(2, False)

x = ti.field(dtype=ci13)
y = ti.field(dtype=cu2)
y = ti.field(dtype=ci5)
z = ti.field(dtype=cu2)

ti.root._bit_struct(num_bits=32).place(x, y)
ti.root._bit_struct(num_bits=32).place(x, y, z)

x[None] = 3
y[None] = 0
y[None] = 2
z[None] = 0

@ti.kernel
def foo():
for i in range(10):
x[None] += 4

for j in range(3):
y[None] += 1
for j in range(5):
y[None] -= 1

for k in range(3):
z[None] += 1

foo()

assert x[None] == 43
assert y[None] == 3
assert y[None] == -3
assert z[None] == 3


@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
Expand Down Expand Up @@ -70,9 +77,9 @@ def foo():
x[None] = 0.7
y[None] = 123.4
for _ in range(10):
x[None] += 0.4
x[None] -= 0.4
y[None] += 100.1

foo()
assert x[None] == approx(4.7)
assert x[None] == approx(-3.3)
assert y[None] == approx(1124.4)

0 comments on commit e614519

Please sign in to comment.