Add reflect_type argument for symmetric and reflect modes in padding
minoring committed Dec 12, 2020
1 parent ca326d6 commit f0a248c
Showing 2 changed files with 79 additions and 28 deletions.
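For context, a quick usage sketch of the argument this commit adds. The expected values follow np.pad's documented reflect_type semantics, which jnp.pad is meant to match; the arrays here are illustrative only.

import jax.numpy as jnp

a = jnp.array([1, 2, 3])
# Even (default) reflection mirrors values across the edge.
jnp.pad(a, 2, mode="reflect")                        # [3, 2, 1, 2, 3, 2, 1]
# Odd reflection also flips them around the edge value: pad = 2 * edge - mirrored.
jnp.pad(a, 2, mode="reflect", reflect_type="odd")    # [-1, 0, 1, 2, 3, 4, 5]
# The same keyword applies to mode="symmetric".
jnp.pad(a, 2, mode="symmetric", reflect_type="odd")  # [0, 1, 1, 2, 3, 3, 4]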
70 changes: 44 additions & 26 deletions jax/_src/numpy/lax_numpy.py
@@ -2315,37 +2315,55 @@ def _pad_wrap(array, pad_width):
return array


def _pad_symmetric_or_reflect(array, pad_width, mode):
def _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type):
assert mode in ("symmetric", "reflect")
assert reflect_type in ("even", "odd")

for i in range(ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], mode)
continue

n = array.shape[i]
rarray = lax.rev(array, dimensions=(i,))
offset = 1 if (mode == "reflect" and n > 1) else 0

def build_padding(padding, forward):
xs = []
delta = n - offset
while padding > delta:
padding -= delta
p = array if forward else rarray
xs.append(lax.slice_in_dim(p, offset, n, axis=i))
forward = not forward
if padding > 0:
x = lax.slice_in_dim(array if forward else rarray, offset,
padding + offset, axis=i)
xs.append(x)
return xs

parts = reversed(build_padding(pad_width[i, 0], forward=True))
parts = [lax.rev(x, dimensions=(i,)) for x in parts]
parts += [array]
parts += build_padding(pad_width[i, 1], forward=False)
array = lax.concatenate(parts, dimension=i)
def build_padding(array, padding, before):
if before:
edge = lax.slice_in_dim(array, 0, 1, axis=i)
else:
edge = lax.slice_in_dim(array, -1, None, axis=i)

while padding > 0:
curr_pad = _min(padding, n - offset)
padding -= curr_pad

if before:
start = offset
stop = offset + curr_pad
else:
start = -(curr_pad + offset)
stop = None if (mode == "symmetric" or n == 1) else -1

x = lax.slice_in_dim(array, start, stop, axis=i)
x = flip(x, axis=i)

if reflect_type == 'odd':
x = 2 * edge - x
x = x.astype(array.dtype)  # 2 * edge - x may have promoted the dtype; cast back
if n > 1:
if before:
edge = lax.slice_in_dim(x, 0, 1, axis=i)
else:
edge = lax.slice_in_dim(x, -1, None, axis=i)

if before:
array = lax.concatenate([x, array], dimension=i)
else:
array = lax.concatenate([array, x], dimension=i)
return array

array = build_padding(array, pad_width[i, 0], before=True)
array = build_padding(array, pad_width[i, 1], before=False)
return array
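An aside on the odd-reflection step in build_padding above: each padded block is the flipped slice adjacent to the boundary, shifted so that the edge value becomes a point of antisymmetry, i.e. 2 * edge - x. A standalone 1-D sketch of the before-side computation for mode="reflect" (illustration only, not part of the diff):

import jax.numpy as jnp
from jax import lax

a = jnp.array([1., 2., 3.])
offset = 1                                  # "reflect" skips the edge sample itself
x = lax.slice_in_dim(a, offset, 3, axis=0)  # [2., 3.]
x = jnp.flip(x, axis=0)                     # [3., 2.]
edge = lax.slice_in_dim(a, 0, 1, axis=0)    # [1.], the left edge value
pad = 2 * edge - x                          # [-1., 0.]; prepending yields [-1, 0, 1, 2, 3]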


@@ -2458,8 +2476,8 @@ def _broadcast_to_pairs(nvals, nd, name):
return nvals


@partial(jit, static_argnums=(1, 2, 4, 5))
def _pad(array, pad_width, mode, constant_values, stat_length, end_values):
@partial(jit, static_argnums=(1, 2, 4, 5, 6))
def _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type):
array = asarray(array)
nd = ndim(array)

@@ -2483,7 +2501,7 @@ def _pad(array, pad_width, mode, constant_values, stat_length, end_values):
return _pad_wrap(array, pad_width)

elif mode in ("symmetric", "reflect"):
return _pad_symmetric_or_reflect(array, pad_width, mode)
return _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type)

elif mode == "edge":
return _pad_edge(array, pad_width)
@@ -2504,12 +2522,12 @@ def _pad(array, pad_width, mode, constant_values, stat_length, end_values):

@_wraps(np.pad)
def pad(array, pad_width, mode="constant", constant_values=0, stat_length=None,
end_values=0):
end_values=0, reflect_type="even"):
if isinstance(pad_width, Iterable):
pad_width = tuple(
tuple(int(i) for i in x) if isinstance(x, Iterable) else x
for x in pad_width)
return _pad(array, pad_width, mode, constant_values, stat_length, end_values)
return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)


@_wraps(np.stack)
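One note on the _pad signature change above: reflect_type is added to jit's static_argnums (alongside pad_width, mode, stat_length, and end_values) because it drives plain Python control flow at trace time. A minimal sketch of that pattern, using a hypothetical helper name not taken from the diff:

import functools
from jax import jit
import jax.numpy as jnp

@functools.partial(jit, static_argnums=(1,))
def _demo(x, reflect_type):        # hypothetical helper, not part of this commit
  if reflect_type == "odd":        # ordinary Python branch, so the value must be static
    return 2 * x[:1] - x
  return x

_demo(jnp.arange(3.0), "odd")      # passing "even" instead triggers a retrace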
37 changes: 35 additions & 2 deletions tests/lax_numpy_test.py
@@ -1291,8 +1291,6 @@ def testOperatorRound(self):
"pad_width": pad_width, "constant_values": constant_values}
for mode, shapes in [
('constant', all_shapes),
('symmetric', nonempty_shapes),
('reflect', nonempty_shapes),
('wrap', nonempty_shapes),
('edge', nonempty_shapes),
]
@@ -1385,6 +1383,41 @@ def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length):
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}_pad_width={}_reflect_type={}".format(
jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, reflect_type),
"shape": shape, "dtype": dtype, "mode": mode, "pad_width": pad_width,
"reflect_type": reflect_type}
for mode in ['symmetric', 'reflect']
for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
for pad_width in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 3),),
# (before, after) (not in the docstring but works in numpy)
(2, 1), (1, 2),
# (pad,)
(1,), (2,), (3,),
# pad
0, 5, 7, 10
]
for reflect_type in ['even', 'odd']
if (pad_width != () and
# the following dtypes lack the precision needed when computing odd-reflection values
(reflect_type != 'odd' or dtype not in [np.float16, jnp.bfloat16]))))
def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]

np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type)

self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE,
tol={np.float32: 1e-3, np.complex64: 1e-3})
self._CompileAndCheck(jnp_fun, args_maker)
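For reference on the pad_width cases enumerated in the new parameterization: np.pad broadcasts the scalar and short forms to per-axis (before, after) pairs, and the test checks that jnp.pad matches it for each form. An illustrative equivalence for a 1-D input (values per np.pad's broadcasting rules):

import numpy as np

a = np.array([1, 2, 3])
same = [
    np.pad(a, 2, mode="reflect"),           # pad (scalar)
    np.pad(a, (2,), mode="reflect"),        # (pad,)
    np.pad(a, (2, 2), mode="reflect"),      # (before, after)
    np.pad(a, ((2, 2),), mode="reflect"),   # ((before, after),)
]
assert all((r == same[0]).all() for r in same)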

@unittest.skipIf(numpy_version < (1, 16, 6), "numpy <= 1.16.5 has a bug in linear_ramp")
# https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e
@parameterized.named_parameters(jtu.cases_from_list(
