Skip to content

Commit

Permalink
New conv ops (keras-team#3134)
Browse files Browse the repository at this point in the history
* New function signature for conv2d in backend

* Clean up stuff

* Touch-up TF deconv op

* More cleanup

* Support for TF 3D conv/pool

* Move pooling layers to their own file

* Update TF version in Travis config

* Fix conv3d tests
  • Loading branch information
fchollet authored Jul 4, 2016
1 parent 229f13a commit ee8ff00
Show file tree
Hide file tree
Showing 9 changed files with 733 additions and 491 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ install:

# install TensorFlow
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.7.1-cp27-none-linux_x86_64.whl;
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl;
elif [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.7.1-cp34-none-linux_x86_64.whl;
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl;
fi
# command to run tests
script:
Expand Down
250 changes: 191 additions & 59 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import copy
import warnings
from .common import _FLOATX, _EPSILON
from .common import _FLOATX, _EPSILON, _IMAGE_DIM_ORDERING

# INTERNAL UTILS

Expand Down Expand Up @@ -998,55 +998,177 @@ def l2_normalize(x, axis):

# CONVOLUTIONS

def _preprocess_conv2d_input(x, dim_ordering):
    '''Prepares a 2D convolution/pooling input for the TensorFlow ops.

    Casts `x` to float32 when the backend float type is float64 and,
    for "th" ordering, moves the channel axis to the last position.

    # Arguments
        x: input tensor.
        dim_ordering: "th" or "tf". Any value other than "th"
            leaves the axis order untouched.

    # Returns
        The (possibly cast and transposed) tensor.
    '''
    if _FLOATX == 'float64':
        x = tf.cast(x, 'float32')
    if dim_ordering != 'th':
        return x
    # TH input shape: (samples, input_depth, rows, cols)
    # TF input shape: (samples, rows, cols, input_depth)
    channels_last = (0, 2, 3, 1)
    return tf.transpose(x, channels_last)

def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th',
image_shape=None, filter_shape=None):
'''2D convolution.

# Arguments
kernel: kernel tensor.
strides: strides tuple.
border_mode: string, "same" or "valid".
dim_ordering: "tf" or "th". Whether to use Theano or TensorFlow dimension ordering
in inputs/kernels/outputs.
'''
if border_mode == 'same':
padding = 'SAME'
elif border_mode == 'valid':
padding = 'VALID'
else:
raise Exception('Invalid border mode: ' + str(border_mode))
def _preprocess_conv3d_input(x, dim_ordering):
    '''Prepares a 3D convolution/pooling input for the TensorFlow ops.

    Casts `x` to float32 when the backend float type is float64 and,
    for "th" ordering, moves the channel axis to the last position.

    # Arguments
        x: input tensor.
        dim_ordering: "th" or "tf". Any value other than "th"
            leaves the axis order untouched.

    # Returns
        The (possibly cast and transposed) tensor.
    '''
    if _FLOATX == 'float64':
        x = tf.cast(x, 'float32')
    if dim_ordering != 'th':
        return x
    # TH input shape: (samples, input_depth, conv_dim1, conv_dim2, conv_dim3)
    # TF input shape: (samples, conv_dim1, conv_dim2, conv_dim3, input_depth)
    channels_last = (0, 2, 3, 4, 1)
    return tf.transpose(x, channels_last)

strides = (1,) + strides + (1,)

def _preprocess_conv2d_kernel(kernel, dim_ordering):
    '''Prepares a 2D convolution kernel for the TensorFlow ops.

    Only the kernel is touched here: input handling, the convolution
    itself and output post-processing belong to the caller and to the
    sibling `_preprocess_conv2d_input` / `_postprocess_conv2d_output`
    helpers.

    # Arguments
        kernel: kernel tensor.
        dim_ordering: "th" or "tf". Any value other than "th"
            leaves the axis order untouched.

    # Returns
        The kernel, cast to float32 when the backend float type is
        float64 and transposed to TensorFlow layout for "th" ordering.
    '''
    if _FLOATX == 'float64':
        kernel = tf.cast(kernel, 'float32')
    if dim_ordering == 'th':
        # TF uses the last dimension as channel dimension,
        # instead of the 2nd one.
        # TH kernel shape: (depth, input_depth, rows, cols)
        # TF kernel shape: (rows, cols, input_depth, depth)
        kernel = tf.transpose(kernel, (2, 3, 1, 0))
    return kernel


def _preprocess_conv3d_kernel(kernel, dim_ordering):
    '''Prepares a 3D convolution kernel for the TensorFlow ops.

    # Arguments
        kernel: kernel tensor.
        dim_ordering: "th" or "tf". Any value other than "th"
            leaves the axis order untouched.

    # Returns
        The kernel, cast to float32 when the backend float type is
        float64 and transposed to TensorFlow layout for "th" ordering.
    '''
    if _FLOATX == 'float64':
        kernel = tf.cast(kernel, 'float32')
    if dim_ordering != 'th':
        return kernel
    # TH kernel shape: (out_depth, input_depth, kernel_dim1, kernel_dim2, kernel_dim3)
    # TF kernel shape: (kernel_dim1, kernel_dim2, kernel_dim3, input_depth, out_depth)
    tf_layout = (2, 3, 4, 1, 0)
    return tf.transpose(kernel, tf_layout)


def _preprocess_border_mode(border_mode):
if border_mode == 'same':
padding = 'SAME'
elif border_mode == 'valid':
padding = 'VALID'
else:
raise Exception('Unknown dim_ordering: ' + str(dim_ordering))
raise Exception('Invalid border mode: ' + str(border_mode))
return padding


def _postprocess_conv2d_output(x, dim_ordering):
    '''Converts a 2D op output back to the caller's conventions.

    # Arguments
        x: output tensor of a TensorFlow 2D op (channels last).
        dim_ordering: "th" or "tf". For "th" the channel axis is
            moved back to position 1.

    # Returns
        The tensor, transposed for "th" ordering and cast to float64
        when the backend float type is float64.
    '''
    if dim_ordering == 'th':
        channels_first = (0, 3, 1, 2)
        x = tf.transpose(x, channels_first)
    if _FLOATX != 'float64':
        return x
    return tf.cast(x, 'float64')


def _postprocess_conv3d_output(x, dim_ordering):
    '''Converts a 3D op output back to the caller's conventions.

    # Arguments
        x: output tensor of a TensorFlow 3D op (channels last).
        dim_ordering: "th" or "tf". For "th" the channel axis is
            moved back to position 1.

    # Returns
        The tensor, transposed for "th" ordering and cast to float64
        when the backend float type is float64.
    '''
    if dim_ordering == 'th':
        channels_first = (0, 4, 1, 2, 3)
        x = tf.transpose(x, channels_first)
    if _FLOATX != 'float64':
        return x
    return tf.cast(x, 'float64')


def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
           dim_ordering=_IMAGE_DIM_ORDERING,
           image_shape=None, filter_shape=None):
    '''2D convolution.

    # Arguments
        x: input tensor.
        kernel: kernel tensor.
        strides: strides tuple (one entry per spatial dimension).
        border_mode: string, "same" or "valid".
        dim_ordering: "tf" or "th".
            Whether to use Theano or TensorFlow dimension ordering
            in inputs/kernels/outputs.
        image_shape: not used by this implementation.
        filter_shape: not used by this implementation.

    # Returns
        The convolved tensor, in the requested `dim_ordering`.

    # Raises
        Exception: if `dim_ordering` or `border_mode` is not recognized.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    # The preprocessing order is kept as-is so TF ops are created
    # in the same sequence.
    tf_input = _preprocess_conv2d_input(x, dim_ordering)
    tf_kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
    tf_padding = _preprocess_border_mode(border_mode)
    # TF expects one stride per input axis: batch and channel strides are 1.
    tf_strides = (1,) + strides + (1,)

    output = tf.nn.conv2d(tf_input, tf_kernel, tf_strides, padding=tf_padding)
    return _postprocess_conv2d_output(output, dim_ordering)


def deconv2d(x, kernel, output_shape, strides=(1, 1),
             border_mode='valid',
             dim_ordering=_IMAGE_DIM_ORDERING,
             image_shape=None, filter_shape=None):
    '''2D deconvolution (transposed convolution).

    # Arguments
        x: input tensor.
        kernel: kernel tensor.
        output_shape: shape of the output, with 4 entries.
            NOTE(review): assumed to be an indexable sequence expressed
            in the same `dim_ordering` as `x` - confirm against callers.
        strides: strides tuple (one entry per spatial dimension).
        border_mode: string, "same" or "valid".
        dim_ordering: "tf" or "th".
            Whether to use Theano or TensorFlow dimension ordering
            in inputs/kernels/outputs.
        image_shape: not used by this implementation.
        filter_shape: not used by this implementation.

    # Returns
        The deconvolved tensor, in the requested `dim_ordering`.

    # Raises
        Exception: if `dim_ordering` or `border_mode` is not recognized.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    x = _preprocess_conv2d_input(x, dim_ordering)
    kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
    padding = _preprocess_border_mode(border_mode)
    strides = (1,) + strides + (1,)

    if dim_ordering == 'th':
        # `x` has been moved to channels-last above, so the requested
        # output shape must be reordered the same way:
        # (samples, depth, rows, cols) -> (samples, rows, cols, depth).
        output_shape = (output_shape[0], output_shape[2],
                        output_shape[3], output_shape[1])

    x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides,
                               padding=padding)
    return _postprocess_conv2d_output(x, dim_ordering)


def atrous_conv2d(x, kernel, rate=1,
                  border_mode='valid',
                  dim_ordering=_IMAGE_DIM_ORDERING,
                  image_shape=None, filter_shape=None):
    '''2D atrous (dilated) convolution.

    # Arguments
        x: input tensor.
        kernel: kernel tensor.
        rate: value forwarded to `tf.nn.atrous_conv2d`. When it is 1
            the call is delegated to `conv2d` with unit strides.
        border_mode: string, "same" or "valid".
        dim_ordering: "tf" or "th".
            Whether to use Theano or TensorFlow dimension ordering
            in inputs/kernels/outputs.
        image_shape: not used by this implementation.
        filter_shape: not used by this implementation.

    # Returns
        The convolved tensor, in the requested `dim_ordering`.

    # Raises
        Exception: if `dim_ordering` or `border_mode` is not recognized.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    if rate == 1:
        # This case is handed to the regular convolution.
        return conv2d(x, kernel, strides=(1, 1), border_mode=border_mode,
                      dim_ordering=dim_ordering)

    tf_input = _preprocess_conv2d_input(x, dim_ordering)
    tf_kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
    tf_padding = _preprocess_border_mode(border_mode)

    output = tf.nn.atrous_conv2d(tf_input, tf_kernel, rate, tf_padding)
    return _postprocess_conv2d_output(output, dim_ordering)


def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
                     border_mode='valid', dim_ordering=_IMAGE_DIM_ORDERING):
    '''2D separable convolution (depthwise followed by pointwise).

    # Arguments
        x: input tensor.
        depthwise_kernel: kernel tensor for the depthwise step.
        pointwise_kernel: kernel tensor for the pointwise step.
        strides: strides tuple (one entry per spatial dimension).
        border_mode: string, "same" or "valid".
        dim_ordering: "tf" or "th".
            Whether to use Theano or TensorFlow dimension ordering
            in inputs/kernels/outputs.

    # Returns
        The convolved tensor, in the requested `dim_ordering`.

    # Raises
        Exception: if `dim_ordering` or `border_mode` is not recognized.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    x = _preprocess_conv2d_input(x, dim_ordering)
    depthwise_kernel = _preprocess_conv2d_kernel(depthwise_kernel, dim_ordering)
    pointwise_kernel = _preprocess_conv2d_kernel(pointwise_kernel, dim_ordering)
    padding = _preprocess_border_mode(border_mode)
    strides = (1,) + strides + (1,)

    # The result must be captured: without the assignment the function
    # would post-process and return the un-convolved input.
    x = tf.nn.separable_conv2d(x, depthwise_kernel, pointwise_kernel,
                               strides, padding)
    return _postprocess_conv2d_output(x, dim_ordering)


def conv3d(x, kernel, strides=(1, 1, 1),
           border_mode='valid', dim_ordering=_IMAGE_DIM_ORDERING,
           volume_shape=None, filter_shape=None):
    '''3D convolution.

    # Arguments
        x: input tensor.
        kernel: kernel tensor.
        strides: strides tuple (one entry per spatial dimension).
        border_mode: string, "same" or "valid".
        dim_ordering: "tf" or "th".
            Whether to use Theano or TensorFlow dimension ordering
            in inputs/kernels/outputs.
        volume_shape: not used by this implementation.
        filter_shape: not used by this implementation.

    # Returns
        The convolved tensor, in the requested `dim_ordering`.

    # Raises
        Exception: if `dim_ordering` or `border_mode` is not recognized.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    tf_input = _preprocess_conv3d_input(x, dim_ordering)
    tf_kernel = _preprocess_conv3d_kernel(kernel, dim_ordering)
    tf_padding = _preprocess_border_mode(border_mode)
    # TF expects one stride per input axis: batch and channel strides are 1.
    tf_strides = (1,) + strides + (1,)

    output = tf.nn.conv3d(tf_input, tf_kernel, tf_strides, tf_padding)
    return _postprocess_conv3d_output(output, dim_ordering)


def pool2d(x, pool_size, strides=(1, 1),
border_mode='valid', dim_ordering='th', pool_mode='max'):
border_mode='valid', dim_ordering=_IMAGE_DIM_ORDERING,
pool_mode='max'):
'''2D Pooling.
# Arguments
Expand All @@ -1056,43 +1178,53 @@ def pool2d(x, pool_size, strides=(1, 1),
dim_ordering: one of "th", "tf".
pool_mode: one of "max", "avg".
'''
if border_mode == 'same':
padding = 'SAME'
elif border_mode == 'valid':
padding = 'VALID'
if dim_ordering not in {'th', 'tf'}:
raise Exception('Unknown dim_ordering ' + str(dim_ordering))

padding = _preprocess_border_mode(border_mode)
strides = (1,) + strides + (1,)
pool_size = (1,) + pool_size + (1,)

x = _preprocess_conv2d_input(x, dim_ordering)

if pool_mode == 'max':
x = tf.nn.max_pool(x, pool_size, strides, padding=padding)
elif pool_mode == 'avg':
x = tf.nn.avg_pool(x, pool_size, strides, padding=padding)
else:
raise Exception('Invalid border mode: ' + str(border_mode))
raise Exception('Invalid pooling mode: ' + str(pool_mode))

return _postprocess_conv2d_output(x, dim_ordering)


def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
           dim_ordering=_IMAGE_DIM_ORDERING, pool_mode='max'):
    '''3D Pooling.

    # Arguments
        x: input tensor.
        pool_size: tuple of 3 integers.
        strides: tuple of 3 integers.
        border_mode: one of "valid", "same".
        dim_ordering: one of "th", "tf".
        pool_mode: one of "max", "avg".

    # Returns
        The pooled tensor, in the requested `dim_ordering`.

    # Raises
        Exception: if `dim_ordering`, `border_mode` or `pool_mode`
            is not recognized.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    padding = _preprocess_border_mode(border_mode)
    # TF expects one entry per input axis: batch and channel entries are 1.
    strides = (1,) + strides + (1,)
    pool_size = (1,) + pool_size + (1,)

    # Float cast and the "th" -> channels-last transpose both happen here;
    # no additional 2D-style cast/transpose must be applied to the 5D input.
    x = _preprocess_conv3d_input(x, dim_ordering)

    if pool_mode == 'max':
        x = tf.nn.max_pool3d(x, pool_size, strides, padding=padding)
    elif pool_mode == 'avg':
        x = tf.nn.avg_pool3d(x, pool_size, strides, padding=padding)
    else:
        raise Exception('Invalid pooling mode: ' + str(pool_mode))

    # Single exit point: restores the caller's axis order and float type.
    return _postprocess_conv3d_output(x, dim_ordering)


# RANDOMNESS
Expand Down
35 changes: 31 additions & 4 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from theano.sandbox.softsign import softsign as T_softsign
import inspect
import numpy as np
from .common import _FLOATX, _EPSILON
from .common import _FLOATX, _EPSILON, _IMAGE_DIM_ORDERING


# INTERNAL UTILS
Expand Down Expand Up @@ -810,10 +810,18 @@ def l2_normalize(x, axis):

# CONVOLUTIONS

def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th',
def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
dim_ordering=_IMAGE_DIM_ORDERING,
image_shape=None, filter_shape=None):
'''
border_mode: string, "same" or "valid".
'''2D convolution.
# Arguments
kernel: kernel tensor.
strides: strides tuple.
border_mode: string, "same" or "valid".
dim_ordering: "tf" or "th".
Whether to use Theano or TensorFlow dimension ordering
in inputs/kernels/outputs.
'''
if dim_ordering not in {'th', 'tf'}:
raise Exception('Unknown dim_ordering ' + str(dim_ordering))
Expand Down Expand Up @@ -872,6 +880,25 @@ def int_or_none(value):
return conv_out


def deconv2d(x, kernel, output_shape, strides=(1, 1),
             border_mode='valid',
             dim_ordering=_IMAGE_DIM_ORDERING,
             image_shape=None, filter_shape=None):
    '''2D deconvolution placeholder.

    The signature is declared here but there is no implementation
    in this backend module yet.

    # Raises
        NotImplementedError: always.
    '''
    raise NotImplementedError


def atrous_conv2d(x, kernel, rate=1,
                  border_mode='valid',
                  dim_ordering=_IMAGE_DIM_ORDERING,
                  image_shape=None, filter_shape=None):
    '''2D atrous convolution placeholder.

    The signature is declared here but there is no implementation
    in this backend module yet.

    # Raises
        NotImplementedError: always.
    '''
    raise NotImplementedError


def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
                     border_mode='valid', dim_ordering=_IMAGE_DIM_ORDERING):
    '''2D separable convolution placeholder.

    The signature is declared here but there is no implementation
    in this backend module yet.

    # Raises
        NotImplementedError: always.
    '''
    raise NotImplementedError


def conv3d(x, kernel, strides=(1, 1, 1),
border_mode='valid', dim_ordering='th',
volume_shape=None, filter_shape=None):
Expand Down
1 change: 1 addition & 0 deletions keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ..engine import Layer, Input, InputLayer, Merge, merge, InputSpec
from .core import *
from .convolutional import *
from .pooling import *
from .recurrent import *
from .normalization import *
from .embeddings import *
Expand Down
Loading

0 comments on commit ee8ff00

Please sign in to comment.