From 710beb20281eb65a291edfd041942301296ff125 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen Date: Thu, 13 Aug 2015 23:02:00 +0300 Subject: [PATCH] ENH: use 128-bit integers to avoid overflows in solve_diophantine --- numpy/core/setup.py | 5 +- numpy/core/src/private/mem_overlap.c | 160 +++++-------- numpy/core/src/private/npy_extint128.h | 317 +++++++++++++++++++++++++ numpy/core/tests/test_mem_overlap.py | 12 +- 4 files changed, 381 insertions(+), 113 deletions(-) create mode 100644 numpy/core/src/private/npy_extint128.h diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 0a9d5852a86e..6d9926d89e97 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -763,6 +763,7 @@ def generate_multiarray_templated_sources(ext, build_dir): join('src', 'private', 'templ_common.h.src'), join('src', 'private', 'lowlevel_strided_loops.h'), join('src', 'private', 'mem_overlap.h'), + join('src', 'private', 'npy_extint128.h'), join('include', 'numpy', 'arrayobject.h'), join('include', 'numpy', '_neighborhood_iterator_imp.h'), join('include', 'numpy', 'npy_endian.h'), @@ -962,7 +963,9 @@ def generate_umath_c(ext, build_dir): config.add_extension('multiarray_tests', sources=[join('src', 'multiarray', 'multiarray_tests.c.src'), - join('src', 'private', 'mem_overlap.c')]) + join('src', 'private', 'mem_overlap.c')], + depends=[join('src', 'private', 'mem_overlap.h'), + join('src', 'private', 'npy_extint128.h')]) ####################################################################### # operand_flag_tests module # diff --git a/numpy/core/src/private/mem_overlap.c b/numpy/core/src/private/mem_overlap.c index 68def67a7d85..e87d85b121f8 100644 --- a/numpy/core/src/private/mem_overlap.c +++ b/numpy/core/src/private/mem_overlap.c @@ -155,9 +155,10 @@ - Minimal solutions to a_i x_i + a_j x_j == b are too large, in some of the intermediate equations. 
- Note that array memory bound overlap check is done before integer overflows - can occur, so these are not of so much practical relevance, since we are - working in int64. + We do this part of the computation in 128-bit integers. + + In general, overflows are expected only if array size is close to + NPY_MAX_INT64, requiring ~exabyte size arrays, which is usually not possible. References ---------- @@ -188,103 +189,13 @@ #define NPY_NO_DEPRECATED_API NPY_API_VERSION #include "numpy/ndarraytypes.h" #include "mem_overlap.h" +#include "npy_extint128.h" #define MAX(a, b) (((a) >= (b)) ? (a) : (b)) #define MIN(a, b) (((a) <= (b)) ? (a) : (b)) -/* Integer addition with overflow checking */ -static npy_int64 -safe_add(npy_int64 a, npy_int64 b, char *overflow_flag) -{ - if (a > 0 && b > NPY_MAX_INT64 - a) { - *overflow_flag = 1; - } - else if (a < 0 && b < NPY_MIN_INT64 - a) { - *overflow_flag = 1; - } - return a + b; -} - - -/* Integer subtraction with overflow checking */ -static npy_int64 -safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag) -{ - if (a > 0 && b < a - NPY_MAX_INT64) { - *overflow_flag = 1; - } - else if (a < 0 && b > a - NPY_MIN_INT64) { - *overflow_flag = 1; - } - return a - b; -} - - -/* Integer multiplication with overflow checking */ -static npy_int64 -safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag) -{ - if (a > 0) { - if (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a) { - *overflow_flag = 1; - } - } - else if (a < 0) { - if (b > 0 && a < NPY_MIN_INT64 / b) { - *overflow_flag = 1; - } - else if (b < 0 && a < NPY_MAX_INT64 / b) { - *overflow_flag = 1; - } - } - return a * b; -} - - -/* Divide and round down (positive divisor; no overflows) */ -static npy_int64 -floordiv(npy_int64 a, npy_int64 b) -{ - assert(b > 0); - - /* C division truncates */ - if (a > 0) { - return a / b; - } - else { - npy_int64 v, r; - v = a / b; - r = a % b; - if (r != 0) { - --v; /* cannot overflow */ - } - return v; - } -} - -/* Divide and round up (positive 
divisor; no overflows) */ -static npy_int64 -ceildiv(npy_int64 a, npy_int64 b) -{ - assert(b > 0); - - if (a < 0) { - return a / b; - } - else { - npy_int64 v, r; - v = a / b; - r = a % b; - if (r != 0) { - ++v; /* cannot overflow */ - } - return v; - } -} - - /** * Euclid's algorithm for GCD. * @@ -405,7 +316,8 @@ diophantine_dfs(unsigned int v, npy_int64 *x, Py_ssize_t *count) { - npy_int64 a_gcd, gamma, epsilon, a1, l1, u1, a2, l2, u2, c, r, x1, x2, c1, c2, t_l, t_u, t, b2; + npy_int64 a_gcd, gamma, epsilon, a1, u1, a2, u2, c, r, c1, c2, t, t_l, t_u, b2, x1, x2; + npy_extint128_t x10, x20, t_l1, t_l2, t_u1, t_u2; mem_overlap_t res; char overflow = 0; @@ -416,17 +328,14 @@ diophantine_dfs(unsigned int v, /* Fetch precomputed values for the reduced problem */ if (v == 1) { a1 = E[0].a; - l1 = 0; u1 = E[0].ub; } else { a1 = Ep[v-2].a; - l1 = 0; u1 = Ep[v-2].ub; } a2 = E[v].a; - l2 = 0; u2 = E[v].ub; a_gcd = Ep[v-1].a; @@ -441,16 +350,55 @@ diophantine_dfs(unsigned int v, return MEM_OVERLAP_NO; } - x1 = safe_mul(gamma, c, &overflow); - x2 = safe_mul(epsilon, c, &overflow); - c1 = a2 / a_gcd; c2 = a1 / a_gcd; - t_l = MAX(ceildiv(safe_sub(l1, x1, &overflow), c1), - ceildiv(safe_sub(x2, u2, &overflow), c2)); - t_u = MIN(floordiv(safe_sub(u1, x1, &overflow), c1), - floordiv(safe_sub(x2, l2, &overflow), c2)); + /* + The set to enumerate is: + x1 = gamma*c + c1*t + x2 = epsilon*c - c2*t + t integer + 0 <= x1 <= u1 + 0 <= x2 <= u2 + and we have c, c1, c2 >= 0 + */ + + x10 = mul_64_64(gamma, c); + x20 = mul_64_64(epsilon, c); + + t_l1 = ceildiv_128_64(neg_128(x10), c1); + t_l2 = ceildiv_128_64(sub_128(x20, to_128(u2), &overflow), c2); + + t_u1 = floordiv_128_64(sub_128(to_128(u1), x10, &overflow), c1); + t_u2 = floordiv_128_64(x20, c2); + + if (overflow) { + return MEM_OVERLAP_OVERFLOW; + } + + if (gt_128(t_l2, t_l1)) { + t_l1 = t_l2; + } + + if (gt_128(t_u1, t_u2)) { + t_u1 = t_u2; + } + + if (gt_128(t_l1, t_u1)) { + ++*count; + return MEM_OVERLAP_NO; + } + + t_l = 
to_64(t_l1, &overflow); + t_u = to_64(t_u1, &overflow); + + x10 = add_128(x10, mul_64_64(c1, t_l), &overflow); + x20 = sub_128(x20, mul_64_64(c2, t_l), &overflow); + + t_u = safe_sub(t_u, t_l, &overflow); + t_l = 0; + x1 = to_64(x10, &overflow); + x2 = to_64(x20, &overflow); if (overflow) { return MEM_OVERLAP_OVERFLOW; @@ -624,7 +572,7 @@ diophantine_simplify(unsigned int *n, diophantine_term_t *E, npy_int64 b) m = *n; i = 0; for (j = 0; j < m; ++j) { - E[j].ub = MIN(E[j].ub, floordiv(b, E[j].a)); + E[j].ub = MIN(E[j].ub, b / E[j].a); if (E[j].ub == 0) { /* If the problem is feasible at all, x[i]=0 */ --*n; diff --git a/numpy/core/src/private/npy_extint128.h b/numpy/core/src/private/npy_extint128.h new file mode 100644 index 000000000000..6a35e736fd1a --- /dev/null +++ b/numpy/core/src/private/npy_extint128.h @@ -0,0 +1,317 @@ +#ifndef NPY_EXTINT128_H_ +#define NPY_EXTINT128_H_ + + +typedef struct { + char sign; + npy_uint64 lo, hi; +} npy_extint128_t; + + +/* Integer addition with overflow checking */ +static NPY_INLINE npy_int64 +safe_add(npy_int64 a, npy_int64 b, char *overflow_flag) +{ + if (a > 0 && b > NPY_MAX_INT64 - a) { + *overflow_flag = 1; + } + else if (a < 0 && b < NPY_MIN_INT64 - a) { + *overflow_flag = 1; + } + return a + b; +} + + +/* Integer subtraction with overflow checking */ +static NPY_INLINE npy_int64 +safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag) +{ + if (a >= 0 && b < a - NPY_MAX_INT64) { + *overflow_flag = 1; + } + else if (a < 0 && b > a - NPY_MIN_INT64) { + *overflow_flag = 1; + } + return a - b; +} + + +/* Integer multiplication with overflow checking */ +static NPY_INLINE npy_int64 +safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag) +{ + if (a > 0) { + if (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a) { + *overflow_flag = 1; + } + } + else if (a < 0) { + if (b > 0 && a < NPY_MIN_INT64 / b) { + *overflow_flag = 1; + } + else if (b < 0 && a < NPY_MAX_INT64 / b) { + *overflow_flag = 1; + } + } + return a * b; +} + + +/* 
Long integer init */ +static NPY_INLINE npy_extint128_t +to_128(npy_int64 x) +{ + npy_extint128_t result; + result.sign = (x >= 0 ? 1 : -1); + if (x >= 0) { + result.lo = x; + } + else { + result.lo = (npy_uint64)(-(x + 1)) + 1; + } + result.hi = 0; + return result; +} + + +static NPY_INLINE npy_int64 +to_64(npy_extint128_t x, char *overflow) +{ + if (x.hi != 0 || + (x.sign > 0 && x.lo > NPY_MAX_INT64) || + (x.sign < 0 && x.lo != 0 && x.lo - 1 > -(NPY_MIN_INT64 + 1))) { + *overflow = 1; + } + return x.lo * x.sign; +} + + +/* Long integer multiply */ +static NPY_INLINE npy_extint128_t +mul_64_64(npy_int64 a, npy_int64 b) +{ + npy_extint128_t x, y, z; + npy_uint64 x1, x2, y1, y2, r1, r2, prev; + + x = to_128(a); + y = to_128(b); + + x1 = x.lo & 0xffffffff; + x2 = x.lo >> 32; + + y1 = y.lo & 0xffffffff; + y2 = y.lo >> 32; + + r1 = x1*y2; + r2 = x2*y1; + + z.sign = x.sign * y.sign; + z.hi = x2*y2 + (r1 >> 32) + (r2 >> 32); + z.lo = x1*y1; + + /* Add with carry */ + prev = z.lo; + z.lo += (r1 << 32); + if (z.lo < prev) { + ++z.hi; + } + + prev = z.lo; + z.lo += (r2 << 32); + if (z.lo < prev) { + ++z.hi; + } + + return z; +} + + +/* Long integer add */ +static NPY_INLINE npy_extint128_t +add_128(npy_extint128_t x, npy_extint128_t y, char *overflow) +{ + npy_extint128_t z; + + if (x.sign == y.sign) { + z.sign = x.sign; + z.hi = x.hi + y.hi; + if (z.hi < x.hi) { + *overflow = 1; + } + z.lo = x.lo + y.lo; + if (z.lo < x.lo) { + if (z.hi == NPY_MAX_UINT64) { + *overflow = 1; + } + ++z.hi; + } + } + else if (x.hi > y.hi || (x.hi == y.hi && x.lo >= y.lo)) { + z.sign = x.sign; + z.hi = x.hi - y.hi; + z.lo = x.lo; + z.lo -= y.lo; + if (z.lo > x.lo) { + --z.hi; + } + } + else { + z.sign = y.sign; + z.hi = y.hi - x.hi; + z.lo = y.lo; + z.lo -= x.lo; + if (z.lo > y.lo) { + --z.hi; + } + } + + return z; +} + + +/* Long integer negation */ +static NPY_INLINE npy_extint128_t +neg_128(npy_extint128_t x) +{ + npy_extint128_t z = x; + z.sign *= -1; + return z; +} + + +static NPY_INLINE 
npy_extint128_t +sub_128(npy_extint128_t x, npy_extint128_t y, char *overflow) +{ + return add_128(x, neg_128(y), overflow); +} + + +static NPY_INLINE npy_extint128_t +shl_128(npy_extint128_t v) +{ + npy_extint128_t z; + z = v; + z.hi <<= 1; + z.hi |= (z.lo & (((npy_uint64)1) << 63)) >> 63; + z.lo <<= 1; + return z; +} + + +static NPY_INLINE npy_extint128_t +shr_128(npy_extint128_t v) +{ + npy_extint128_t z; + z = v; + z.lo >>= 1; + z.lo |= (z.hi & 0x1) << 63; + z.hi >>= 1; + return z; +} + +static NPY_INLINE int +gt_128(npy_extint128_t a, npy_extint128_t b) +{ + if (a.sign > 0 && b.sign > 0) { + return (a.hi > b.hi) || (a.hi == b.hi && a.lo > b.lo); + } + else if (a.sign < 0 && b.sign < 0) { + return (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo); + } + else if (a.sign > 0 && b.sign < 0) { + return a.hi != 0 || a.lo != 0 || b.hi != 0 || b.lo != 0; + } + else { + return 0; + } +} + + +/* Long integer divide */ +static NPY_INLINE npy_extint128_t +divmod_128_64(npy_extint128_t x, npy_int64 b, npy_int64 *mod) +{ + npy_extint128_t remainder, pointer, result, divisor; + char overflow = 0; + + assert(b > 0); + + if (b <= 1 || x.hi == 0) { + result.sign = x.sign; + result.lo = x.lo / b; + result.hi = x.hi / b; + *mod = x.sign * (x.lo % b); + return result; + } + + /* Long division, not the most efficient choice */ + remainder = x; + remainder.sign = 1; + + divisor.sign = 1; + divisor.hi = 0; + divisor.lo = b; + + result.sign = 1; + result.lo = 0; + result.hi = 0; + + pointer.sign = 1; + pointer.lo = 1; + pointer.hi = 0; + + while ((divisor.hi & (((npy_uint64)1) << 63)) == 0 && + gt_128(remainder, divisor)) { + divisor = shl_128(divisor); + pointer = shl_128(pointer); + } + + while (pointer.lo || pointer.hi) { + if (!gt_128(divisor, remainder)) { + remainder = sub_128(remainder, divisor, &overflow); + result = add_128(result, pointer, &overflow); + } + divisor = shr_128(divisor); + pointer = shr_128(pointer); + } + + /* Fix signs and return; cannot overflow */ + 
result.sign = x.sign; + *mod = x.sign * remainder.lo; + + return result; +} + + +/* Divide and round down (positive divisor; no overflows) */ +static NPY_INLINE npy_extint128_t +floordiv_128_64(npy_extint128_t a, npy_int64 b) +{ + npy_extint128_t result; + npy_int64 remainder; + char overflow = 0; + assert(b > 0); + result = divmod_128_64(a, b, &remainder); + if (a.sign < 0 && remainder != 0) { + result = sub_128(result, to_128(1), &overflow); + } + return result; +} + + +/* Divide and round up (positive divisor; no overflows) */ +static NPY_INLINE npy_extint128_t +ceildiv_128_64(npy_extint128_t a, npy_int64 b) +{ + npy_extint128_t result; + npy_int64 remainder; + char overflow = 0; + assert(b > 0); + result = divmod_128_64(a, b, &remainder); + if (a.sign > 0 && remainder != 0) { + result = add_128(result, to_128(1), &overflow); + } + return result; +} + +#endif diff --git a/numpy/core/tests/test_mem_overlap.py b/numpy/core/tests/test_mem_overlap.py index f353317dd35c..e48d4891daea 100644 --- a/numpy/core/tests/test_mem_overlap.py +++ b/numpy/core/tests/test_mem_overlap.py @@ -162,13 +162,13 @@ def test_diophantine_overflow(): max_int64 = np.iinfo(np.int64).max if max_int64 <= max_intp: - # The Python wrapper only takes intp inputs, but the solver - # works internally in 64-bit - A = (2, 3) - U = (max_int64//2, max_int64//6) - b = max_int64 - 1 + # Check that the algorithm works internally in 128-bit; + # solving this problem requires large intermediate numbers + A = (max_int64//2, max_int64//2 - 10) + U = (max_int64//2, max_int64//2 - 10) + b = 2*(max_int64//2) - 10 - assert_raises(OverflowError, solve_diophantine, A, U, b) + assert_equal(solve_diophantine(A, U, b), (1, 1)) def check_may_share_memory_exact(a, b):