Skip to content

Commit

Permalink
Better support for np.convolve() for prime fields
Browse files Browse the repository at this point in the history
This allows the optimized version to be used when calling the JIT-ed function, not just `__call__()`.
  • Loading branch information
mhostetter committed Nov 8, 2022
1 parent 8a692ff commit 2a27178
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/galois/_domains/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,7 @@ def __call__(self, a: Array, b: Array, mode="full") -> Array:
raise ValueError(f"Operation 'convolve' currently only supports mode of 'full', not {mode!r}.")
dtype = a.dtype

if self.field._is_prime_field:
# Determine the minimum dtype to hold the entire product and summation without overflowing
if self.field.dtypes == [np.object_]:
dtype = np.object_
else:
n_sum = min(a.size, b.size)
max_value = n_sum * (self.field.characteristic - 1)**2
dtypes = [dtype for dtype in self.field.dtypes if np.iinfo(dtype).max >= max_value]
dtype = np.object_ if len(dtypes) == 0 else dtypes[0]
return_dtype = a.dtype
c = np.convolve(a.view(np.ndarray).astype(dtype), b.view(np.ndarray).astype(dtype)) # Compute result using native numpy LAPACK/BLAS implementation
c = c % self.field.characteristic # Reduce the result mod p
if np.isscalar(c):
c = self.field(c, dtype=return_dtype)
else:
c = c.astype(return_dtype)
elif self.field.ufunc_mode != "python-calculate":
if self.field.ufunc_mode != "python-calculate":
c = self.jit(a.astype(np.int64), b.astype(np.int64))
c = c.astype(dtype)
else:
Expand All @@ -132,15 +116,36 @@ def __call__(self, a: Array, b: Array, mode="full") -> Array:
return c

def set_globals(self):
global ADD, MULTIPLY
global IS_PRIME_FIELD, CHARACTERISTIC, ADD, MULTIPLY
IS_PRIME_FIELD = self.field._is_prime_field
CHARACTERISTIC = self.field.characteristic
ADD = self.field._add.ufunc
MULTIPLY = self.field._multiply.ufunc

_SIGNATURE = numba.types.FunctionType(int64[:](int64[:], int64[:]))

@staticmethod
def implementation(a, b):
c = np.zeros(a.size + b.size - 1, dtype=a.dtype)
dtype = a.dtype

if IS_PRIME_FIELD:
try:
max_sum = np.iinfo(dtype).max // (CHARACTERISTIC - 1)**2
n_sum = min(a.size, b.size)
overflow = n_sum > max_sum
except: # pylint: disable=bare-except
# This happens when the dtype is np.object_
overflow = False

if not overflow:
# Compute the result using native NumPy LAPACK/BLAS implementation since it is guaranteed to not
# overflow. Then reduce the result mod p.
c = np.convolve(a, b)
c = c % CHARACTERISTIC
return c

# Fall-back brute force method
c = np.zeros(a.size + b.size - 1, dtype=dtype)
for i in range(a.size):
for j in range(b.size - 1, -1, -1):
c[i + j] = ADD(c[i + j], MULTIPLY(a[i], b[j]))
Expand Down

0 comments on commit 2a27178

Please sign in to comment.