Skip to content

Commit

Permalink
ENH: use 128-bit integers to avoid overflows in solve_diophantine
Browse files Browse the repository at this point in the history
  • Loading branch information
pv committed Aug 29, 2015
1 parent 74c4454 commit 710beb2
Show file tree
Hide file tree
Showing 4 changed files with 381 additions and 113 deletions.
5 changes: 4 additions & 1 deletion numpy/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,7 @@ def generate_multiarray_templated_sources(ext, build_dir):
join('src', 'private', 'templ_common.h.src'),
join('src', 'private', 'lowlevel_strided_loops.h'),
join('src', 'private', 'mem_overlap.h'),
join('src', 'private', 'npy_extint128.h'),
join('include', 'numpy', 'arrayobject.h'),
join('include', 'numpy', '_neighborhood_iterator_imp.h'),
join('include', 'numpy', 'npy_endian.h'),
Expand Down Expand Up @@ -962,7 +963,9 @@ def generate_umath_c(ext, build_dir):

config.add_extension('multiarray_tests',
sources=[join('src', 'multiarray', 'multiarray_tests.c.src'),
join('src', 'private', 'mem_overlap.c')])
join('src', 'private', 'mem_overlap.c')],
depends=[join('src', 'private', 'mem_overlap.h'),
join('src', 'private', 'npy_extint128.h')])

#######################################################################
# operand_flag_tests module #
Expand Down
160 changes: 54 additions & 106 deletions numpy/core/src/private/mem_overlap.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@
- Minimal solutions to a_i x_i + a_j x_j == b are too large,
in some of the intermediate equations.
Note that the array memory bound overlap check is done before integer overflows
can occur, so these are of little practical relevance, since we are
working in int64.
We do this part of the computation in 128-bit integers.
In general, overflows are expected only if the array size is close to
NPY_MAX_INT64, requiring ~exabyte size arrays, which is usually not possible.
References
----------
Expand Down Expand Up @@ -188,103 +189,13 @@
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include "numpy/ndarraytypes.h"
#include "mem_overlap.h"
#include "npy_extint128.h"


#define MAX(a, b) (((a) >= (b)) ? (a) : (b))
#define MIN(a, b) (((a) <= (b)) ? (a) : (b))


/*
 * Integer addition with overflow checking.
 *
 * Returns a + b. If the mathematical sum is not representable as
 * npy_int64, *overflow_flag is set to 1; the returned value is then
 * meaningless (on usual platforms, the two's-complement wrapped sum).
 * The flag is never cleared here, so several operations can be chained
 * and the flag tested once afterwards.
 */
static npy_int64
safe_add(npy_int64 a, npy_int64 b, char *overflow_flag)
{
    if (a > 0 && b > NPY_MAX_INT64 - a) {
        *overflow_flag = 1;
    }
    else if (a < 0 && b < NPY_MIN_INT64 - a) {
        *overflow_flag = 1;
    }
    else {
        /* a == 0, or the sum is known to fit: plain addition is safe */
        return a + b;
    }

    /*
     * Overflow detected: evaluating a + b on signed operands would be
     * undefined behaviour, so wrap around in unsigned arithmetic instead.
     */
    return (npy_int64)((unsigned long long)a + (unsigned long long)b);
}


/*
 * Integer subtraction with overflow checking.
 *
 * Returns a - b. If the mathematical difference is not representable as
 * npy_int64, *overflow_flag is set to 1; the returned value is then
 * meaningless (on usual platforms, the two's-complement wrapped result).
 * The flag is never cleared here, so several operations can be chained
 * and the flag tested once afterwards.
 */
static npy_int64
safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag)
{
    /*
     * The first test must include a == 0: 0 - NPY_MIN_INT64 is not
     * representable. For a >= 0, a - NPY_MAX_INT64 cannot itself overflow.
     */
    if (a >= 0 && b < a - NPY_MAX_INT64) {
        *overflow_flag = 1;
    }
    else if (a < 0 && b > a - NPY_MIN_INT64) {
        *overflow_flag = 1;
    }
    else {
        /* The difference is known to fit: plain subtraction is safe */
        return a - b;
    }

    /*
     * Overflow detected: evaluating a - b on signed operands would be
     * undefined behaviour, so wrap around in unsigned arithmetic instead.
     */
    return (npy_int64)((unsigned long long)a - (unsigned long long)b);
}


/*
 * Integer multiplication with overflow checking.
 *
 * Returns a * b. If the mathematical product is not representable as
 * npy_int64, *overflow_flag is set to 1; the returned value is then
 * meaningless (on usual platforms, the low 64 bits of the product).
 * The flag is never cleared here, so several operations can be chained
 * and the flag tested once afterwards.
 */
static npy_int64
safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag)
{
    int overflowed = 0;

    /*
     * Each test divides a limit by a non-zero operand whose sign is known,
     * so none of the divisions can trap or overflow. a == 0 or b == 0
     * never overflows and falls through with overflowed == 0.
     */
    if (a > 0) {
        overflowed = (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a);
    }
    else if (a < 0) {
        if (b > 0) {
            overflowed = (a < NPY_MIN_INT64 / b);
        }
        else if (b < 0) {
            overflowed = (a < NPY_MAX_INT64 / b);
        }
    }

    if (overflowed) {
        *overflow_flag = 1;
        /*
         * Evaluating a * b on signed operands would be undefined
         * behaviour here, so wrap around in unsigned arithmetic instead.
         */
        return (npy_int64)((unsigned long long)a * (unsigned long long)b);
    }
    return a * b;
}


/*
 * Floor division: largest integer q with q * b <= a.
 * The divisor must be positive; the result cannot overflow.
 */
static npy_int64
floordiv(npy_int64 a, npy_int64 b)
{
    npy_int64 quot;

    assert(b > 0);

    /* C division truncates toward zero, which is already the floor for a > 0 */
    quot = a / b;
    if (a <= 0 && a % b != 0) {
        /* Inexact, non-positive dividend: truncation rounded up, step down */
        --quot; /* cannot overflow */
    }
    return quot;
}

/*
 * Ceiling division: smallest integer q with q * b >= a.
 * The divisor must be positive; the result cannot overflow.
 */
static npy_int64
ceildiv(npy_int64 a, npy_int64 b)
{
    npy_int64 quot;

    assert(b > 0);

    /* C division truncates toward zero, which is already the ceiling for a < 0 */
    quot = a / b;
    if (a >= 0 && a % b != 0) {
        /* Inexact, non-negative dividend: truncation rounded down, step up */
        ++quot; /* cannot overflow */
    }
    return quot;
}


/**
* Euclid's algorithm for GCD.
*
Expand Down Expand Up @@ -405,7 +316,8 @@ diophantine_dfs(unsigned int v,
npy_int64 *x,
Py_ssize_t *count)
{
npy_int64 a_gcd, gamma, epsilon, a1, l1, u1, a2, l2, u2, c, r, x1, x2, c1, c2, t_l, t_u, t, b2;
npy_int64 a_gcd, gamma, epsilon, a1, u1, a2, u2, c, r, c1, c2, t, t_l, t_u, b2, x1, x2;
npy_extint128_t x10, x20, t_l1, t_l2, t_u1, t_u2;
mem_overlap_t res;
char overflow = 0;

Expand All @@ -416,17 +328,14 @@ diophantine_dfs(unsigned int v,
/* Fetch precomputed values for the reduced problem */
if (v == 1) {
a1 = E[0].a;
l1 = 0;
u1 = E[0].ub;
}
else {
a1 = Ep[v-2].a;
l1 = 0;
u1 = Ep[v-2].ub;
}

a2 = E[v].a;
l2 = 0;
u2 = E[v].ub;

a_gcd = Ep[v-1].a;
Expand All @@ -441,16 +350,55 @@ diophantine_dfs(unsigned int v,
return MEM_OVERLAP_NO;
}

x1 = safe_mul(gamma, c, &overflow);
x2 = safe_mul(epsilon, c, &overflow);

c1 = a2 / a_gcd;
c2 = a1 / a_gcd;

t_l = MAX(ceildiv(safe_sub(l1, x1, &overflow), c1),
ceildiv(safe_sub(x2, u2, &overflow), c2));
t_u = MIN(floordiv(safe_sub(u1, x1, &overflow), c1),
floordiv(safe_sub(x2, l2, &overflow), c2));
/*
The set to enumerate is:
x1 = gamma*c + c1*t
x2 = epsilon*c - c2*t
t integer
0 <= x1 <= u1
0 <= x2 <= u2
and we have c, c1, c2 >= 0
*/

x10 = mul_64_64(gamma, c);
x20 = mul_64_64(epsilon, c);

t_l1 = ceildiv_128_64(neg_128(x10), c1);
t_l2 = ceildiv_128_64(sub_128(x20, to_128(u2), &overflow), c2);

t_u1 = floordiv_128_64(sub_128(to_128(u1), x10, &overflow), c1);
t_u2 = floordiv_128_64(x20, c2);

if (overflow) {
return MEM_OVERLAP_OVERFLOW;
}

if (gt_128(t_l2, t_l1)) {
t_l1 = t_l2;
}

if (gt_128(t_u1, t_u2)) {
t_u1 = t_u2;
}

if (gt_128(t_l1, t_u1)) {
++*count;
return MEM_OVERLAP_NO;
}

t_l = to_64(t_l1, &overflow);
t_u = to_64(t_u1, &overflow);

x10 = add_128(x10, mul_64_64(c1, t_l), &overflow);
x20 = sub_128(x20, mul_64_64(c2, t_l), &overflow);

t_u = safe_sub(t_u, t_l, &overflow);
t_l = 0;
x1 = to_64(x10, &overflow);
x2 = to_64(x20, &overflow);

if (overflow) {
return MEM_OVERLAP_OVERFLOW;
Expand Down Expand Up @@ -624,7 +572,7 @@ diophantine_simplify(unsigned int *n, diophantine_term_t *E, npy_int64 b)
m = *n;
i = 0;
for (j = 0; j < m; ++j) {
E[j].ub = MIN(E[j].ub, floordiv(b, E[j].a));
E[j].ub = MIN(E[j].ub, b / E[j].a);
if (E[j].ub == 0) {
/* If the problem is feasible at all, x[i]=0 */
--*n;
Expand Down
Loading

0 comments on commit 710beb2

Please sign in to comment.