Skip to content

Commit

Permalink
Theano conv fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Feb 10, 2017
1 parent 6710396 commit 46649e5
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 105 deletions.
2 changes: 1 addition & 1 deletion keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from . import optimizers
from . import regularizers

__version__ = '1.2.1'
__version__ = '2.0.0'
8 changes: 8 additions & 0 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,12 @@ def repeat_elements(x, rep, axis):
A tensor.
"""
x_shape = x.get_shape().as_list()
if x_shape[axis] is None:
raise ValueError('Axis ' + str(axis) + ' of input tensor '
' should have a defined dimension, but is None. '
'Full tensor shape: ' + str(tuple(x_shape)) + '. '
'Typically you need to pass a fully-defined '
'`input_shape` argument to your first layer.')
# slices along the repeat axis
try:
splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
Expand Down Expand Up @@ -2758,6 +2764,8 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
if isinstance(output_shape, (tuple, list)):
output_shape = tf.stack(output_shape)

x = _preprocess_conv2d_input(x, data_format)
output_shape = _preprocess_deconv_output_shape(x, output_shape, data_format)
Expand Down
111 changes: 54 additions & 57 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,6 @@ def asymmetric_spatial_2d_padding(x, top_pad=1, bottom_pad=1,
input_shape[1] + top_pad + bottom_pad,
input_shape[2] + left_pad + right_pad,
input_shape[3])
print(output_shape)
output = T.zeros(output_shape)
indices = (slice(None),
slice(top_pad, input_shape[1] + top_pad),
Expand All @@ -879,7 +878,7 @@ def asymmetric_spatial_2d_padding(x, top_pad=1, bottom_pad=1,
return T.set_subtensor(output[indices], x)


def spatial_3d_padding(x, padding=(1, 1, 1), data_format=None):
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
"""Pad the 2nd, 3rd and 4th dimensions of a 5D tensor
with "padding[0]", "padding[1]" and "padding[2]" (resp.) zeros left and right.
"""
Expand All @@ -892,27 +891,27 @@ def spatial_3d_padding(x, padding=(1, 1, 1), data_format=None):
if data_format == 'channels_first':
output_shape = (input_shape[0],
input_shape[1],
input_shape[2] + 2 * padding[0],
input_shape[3] + 2 * padding[1],
input_shape[4] + 2 * padding[2])
input_shape[2] + padding[0][0] + padding[0][1],
input_shape[3] + padding[1][0] + padding[1][1],
input_shape[4] + padding[2][0] + padding[2][1])
output = T.zeros(output_shape)
indices = (slice(None),
slice(None),
slice(padding[0], input_shape[2] + padding[0]),
slice(padding[1], input_shape[3] + padding[1]),
slice(padding[2], input_shape[4] + padding[2]))
slice(padding[0][0], input_shape[2] + padding[0][0]),
slice(padding[1][0], input_shape[3] + padding[1][0]),
slice(padding[2][0], input_shape[4] + padding[2][0]))

elif data_format == 'channels_last':
output_shape = (input_shape[0],
input_shape[1] + 2 * padding[0],
input_shape[2] + 2 * padding[1],
input_shape[3] + 2 * padding[2],
input_shape[1] + padding[0][0] + padding[0][1],
input_shape[2] + padding[1][0] + padding[1][1],
input_shape[3] + padding[2][0] + padding[2][1],
input_shape[4])
output = T.zeros(output_shape)
indices = (slice(None),
slice(padding[0], input_shape[1] + padding[0]),
slice(padding[1], input_shape[2] + padding[1]),
slice(padding[2], input_shape[3] + padding[2]),
slice(padding[0][0], input_shape[1] + padding[0][0]),
slice(padding[1][0], input_shape[2] + padding[1][0]),
slice(padding[2][0], input_shape[3] + padding[2][0]),
slice(None))
else:
raise ValueError('Invalid data_format:', data_format)
Expand Down Expand Up @@ -1424,7 +1423,7 @@ def _preprocess_padding(padding):
return th_padding


def _preprocess_conv2d_image_shape(data_format, image_shape):
def _preprocess_conv2d_image_shape(image_shape, data_format):
# Theano might not accept long type
def int_or_none(value):
try:
Expand All @@ -1440,7 +1439,7 @@ def int_or_none(value):
return image_shape


def _preprocess_conv3d_volume_shape(data_format, volume_shape):
def _preprocess_conv3d_volume_shape(volume_shape, data_format):
# Theano might not accept long type
def int_or_none(value):
try:
Expand All @@ -1456,33 +1455,31 @@ def int_or_none(value):
return volume_shape


def _preprocess_conv2d_filter_shape(data_format, filter_shape):
def _preprocess_conv2d_filter_shape(filter_shape, data_format):
# Theano might not accept long type
def int_or_none(value):
try:
return int(value)
except TypeError:
return None
if data_format == 'channels_last':
if filter_shape:
filter_shape = (filter_shape[3], filter_shape[2],
filter_shape[0], filter_shape[1])
if filter_shape:
filter_shape = (filter_shape[3], filter_shape[2],
filter_shape[0], filter_shape[1])
if filter_shape is not None:
filter_shape = tuple(int_or_none(v) for v in filter_shape)
return filter_shape


def _preprocess_conv3d_filter_shape(data_format, filter_shape):
def _preprocess_conv3d_filter_shape(filter_shape, data_format):
# Theano might not accept long type
def int_or_none(value):
try:
return int(value)
except TypeError:
return None
if data_format == 'channels_last':
if filter_shape:
filter_shape = (filter_shape[4], filter_shape[3],
filter_shape[0], filter_shape[1], filter_shape[2])
if filter_shape:
filter_shape = (filter_shape[4], filter_shape[3],
filter_shape[0], filter_shape[1], filter_shape[2])
if filter_shape is not None:
filter_shape = tuple(int_or_none(v) for v in filter_shape)
return filter_shape
Expand Down Expand Up @@ -1525,34 +1522,36 @@ def conv1d(x, kernel, stride=1, padding='valid',
strides: stride integer.
padding: string, "same" or "valid".
data_format: string, one of "channels_last", "channels_first"
dilate_rate: integer.
dilation_rate: integer.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)
if hasattr(x, '_keras_shape'):
shape = x._keras_shape
else:
shape = None
if data_format == 'channels_last':
# original shape: (batch, length, input_dim)
# add dim to x to have (batch, length, 1, input_dim)
x = expand_dims(x, 2)
# update x._keras_shape
if hasattr(x, '_keras_shape'):
shape = x._keras_shape
if shape is not None:
x._keras_shape = (shape[0], shape[1], 1, shape[2])
else:
# original shape: (batch, input_dim, length)
# add dim to x to have (batch, input_dim, length, 1)
x = expand_dims(x, 3)
# update x._keras_shape
if hasattr(x, '_keras_shape'):
shape = x._keras_shape
if shape is not None:
x._keras_shape = (shape[0], shape[1], shape[2], 1)
# update dilation rate, strides
dilation_rate = (dilation_rate, 1)
strides = (stride, 1)
# add dim to kernel (always same format independently of data_format)
# i.e. (rows, 1, input_depth, depth)
kernel = expand_dims(x, 1)
kernel = expand_dims(kernel, 1)
output = conv2d(x, kernel,
strides=strides, padding=padding,
data_format=data_format, dilation_rate=dilation_rate)
Expand Down Expand Up @@ -1581,19 +1580,20 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid',
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ', data_format)

x = _preprocess_conv2d_input(x, data_format)
kernel = _preprocess_conv2d_kernel(kernel, data_format)
th_padding = _preprocess_padding(padding)

if hasattr(x, '_keras_shape'):
image_shape = _preprocess_conv2d_image_shape(int_shape(x), data_format)
else:
image_shape = None
if hasattr(kernel, '_keras_shape'):
kernel_shape = kernel._keras_shape
else:
# Will only work if `kernel` is a shared variable.
kernel_shape = kernel.eval().shape
kernel_shape = _preprocess_conv2d_filter_shape(kernel_shape, data_format)

image_shape = int_shape(x)
image_shape = _preprocess_conv2d_image_shape(data_format, image_shape)
kernel_shape = _preprocess_conv2d_filter_shape(data_format, kernel_shape)
x = _preprocess_conv2d_input(x, data_format)
kernel = _preprocess_conv2d_kernel(kernel, data_format)
th_padding = _preprocess_padding(padding)

conv_out = T.nnet.conv2d(x, kernel,
border_mode=th_padding,
Expand Down Expand Up @@ -1631,22 +1631,18 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
output_shape[1],
output_shape[2])

x = _preprocess_conv2d_input(x, data_format)
kernel = _preprocess_conv2d_kernel(kernel, data_format)

kernel = kernel.dimshuffle((1, 0, 2, 3))
th_padding = _preprocess_padding(padding)

if hasattr(kernel, '_keras_shape'):
kernel_shape = kernel._keras_shape
else:
# Will only work if `kernel` is a shared variable.
kernel_shape = kernel.eval().shape
kernel_shape = _preprocess_conv2d_filter_shape(kernel_shape, data_format)

filter_shape = _preprocess_conv2d_filter_shape(data_format, kernel_shape)
filter_shape = tuple(filter_shape[i] for i in (1, 0, 2, 3))
x = _preprocess_conv2d_input(x, data_format)
kernel = _preprocess_conv2d_kernel(kernel, data_format)

op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=output_shape,
th_padding = _preprocess_padding(padding)
op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=None,
kshp=kernel_shape,
subsample=strides,
border_mode=th_padding,
Expand Down Expand Up @@ -1680,19 +1676,20 @@ def conv3d(x, kernel, strides=(1, 1, 1),
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format:', data_format)

x = _preprocess_conv3d_input(x, data_format)
kernel = _preprocess_conv3d_kernel(kernel, data_format)
th_padding = _preprocess_padding(padding)

if hasattr(x, '_keras_shape'):
volume_shape = _preprocess_conv3d_volume_shape(int_shape(x), data_format)
else:
volume_shape = None
if hasattr(kernel, '_keras_shape'):
kernel_shape = kernel._keras_shape
else:
# Will only work if `kernel` is a shared variable.
kernel_shape = kernel.eval().shape
kernel_shape = _preprocess_conv3d_filter_shape(kernel_shape, data_format)

volume_shape = int_shape(x)
volume_shape = _preprocess_conv3d_volume_shape(data_format, volume_shape)
kernel_shape = _preprocess_conv3d_filter_shape(data_format, kernel_shape)
x = _preprocess_conv3d_input(x, data_format)
kernel = _preprocess_conv3d_kernel(kernel, data_format)
th_padding = _preprocess_padding(padding)

conv_out = T.nnet.conv3d(x, kernel,
border_mode=th_padding,
Expand Down Expand Up @@ -1815,12 +1812,12 @@ def bias_add(x, bias, data_format=None):
x += reshape(bias, (1, bias.shape[0], 1, 1, 1))
elif data_format == 'channels_last':
x += reshape(bias, (1, 1, 1, 1, bias.shape[0]))
if ndim(x) == 4:
elif ndim(x) == 4:
if data_format == 'channels_first':
x += reshape(bias, (1, bias.shape[0], 1, 1))
elif data_format == 'channels_last':
x += reshape(bias, (1, 1, 1, bias.shape[0]))
if ndim(x) == 3:
elif ndim(x) == 3:
if data_format == 'channels_first':
x += reshape(bias, (1, bias.shape[0], 1))
elif data_format == 'channels_last':
Expand Down
11 changes: 5 additions & 6 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,11 +750,10 @@ def call(self, inputs):
else:
output_shape = (batch_size, out_height, out_width, self.filters)

output_shape_tensor = K.stack(output_shape)
outputs = K.conv2d_transpose(
inputs,
self.kernel,
output_shape_tensor,
output_shape,
self.strides,
padding=self.padding,
data_format=self.data_format)
Expand Down Expand Up @@ -1573,12 +1572,12 @@ def get_output_shape_for(self, input_shape):
if self.data_format == 'channels_first':
return (input_shape[0],
input_shape[1],
input_shape[2] - self.cropping[0][0] - self.cropping[0][1],
input_shape[3] - self.cropping[1][0] - self.cropping[1][1])
input_shape[2] - self.cropping[0][0] - self.cropping[0][1] if input_shape[2] else None,
input_shape[3] - self.cropping[1][0] - self.cropping[1][1] if input_shape[3] else None)
elif self.data_format == 'channels_last':
return (input_shape[0],
input_shape[1] - self.cropping[0][0] - self.cropping[0][1],
input_shape[2] - self.cropping[1][0] - self.cropping[1][1],
input_shape[1] - self.cropping[0][0] - self.cropping[0][1] if input_shape[1] else None,
input_shape[2] - self.cropping[1][0] - self.cropping[1][1] if input_shape[2] else None,
input_shape[3])
else:
raise ValueError('Invalid data_format:', self.data_format)
Expand Down
14 changes: 2 additions & 12 deletions keras/utils/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,35 +61,25 @@ def normalize_padding(value):
return padding


def convert_kernel(kernel, data_format=None):
def convert_kernel(kernel):
"""Converts a Numpy kernel matrix from Theano format to TensorFlow format.
Also works reciprocally, since the transformation is its own inverse.
# Arguments
kernel: Numpy array (4D or 5D).
data_format: the data format.
# Returns
The converted kernel.
# Raises
ValueError: in case of invalid kernel shape or invalid data_format.
"""
if data_format is None:
data_format = K.image_data_format()
if not 4 <= kernel.ndim <= 5:
raise ValueError('Invalid kernel shape:', kernel.shape)

slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
no_flip = (slice(None, None), slice(None, None))
if data_format == 'channels_first': # (out_depth, input_depth, ...)
slices[:2] = no_flip
elif data_format == 'channels_last': # (..., input_depth, out_depth)
slices[-2:] = no_flip
else:
raise ValueError('Invalid data_format:', data_format)

slices[-2:] = no_flip
return np.copy(kernel[slices])


Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@


setup(name='Keras',
version='1.2.1',
version='2.0.0',
description='Deep Learning for Python',
author='Francois Chollet',
author_email='[email protected]',
url='https://github.com/fchollet/keras',
download_url='https://github.com/fchollet/keras/tarball/1.2.1',
download_url='https://github.com/fchollet/keras/tarball/2.0.0',
license='MIT',
install_requires=['theano', 'pyyaml', 'six'],
extras_require={
Expand Down
Loading

0 comments on commit 46649e5

Please sign in to comment.