Skip to content

Commit

Permalink
Unlock InceptionResNetV2 for CNTK (keras-team#7914)
Browse files Browse the repository at this point in the history
* Unlock InceptionResNetV2 for CNTK

* Reduce the number of models created in test

* Use Process() to make sure memory is properly reclaimed

* Eliminate possible deadlock at Queue.get()

* Add comments and TODO about the use of multiprocessing in tests
  • Loading branch information
myutwo150 authored and fchollet committed Sep 21, 2017
1 parent bd5d616 commit 71a791c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/templates/applications.md
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ keras.applications.inception_resnet_v2.InceptionResNetV2(include_top=True, weigh

Inception-ResNet V2 model, with weights pre-trained on ImageNet.

This model is available for both the Theano and TensorFlow backend (but not CNTK), and can be built both
This model is available for Theano, TensorFlow and CNTK backends, and can be built both
with `'channels_first'` data format (channels, height, width) or `'channels_last'` data format (height, width, channels).

The default input size for this model is 299x299.
Expand Down
8 changes: 2 additions & 6 deletions keras/applications/inception_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def InceptionResNetV2(include_top=True,
set `"image_data_format": "channels_last"` in your Keras config
at `~/.keras/keras.json`.
The model and the weights are compatible with both TensorFlow and Theano
backends (but not CNTK). The data format convention used by the model is
The model and the weights are compatible with TensorFlow, Theano and
CNTK backends. The data format convention used by the model is
the one specified in your Keras config file.
Note that the default input image size for this model is 299x299, instead
Expand Down Expand Up @@ -226,11 +226,7 @@ def InceptionResNetV2(include_top=True,
# Raises
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
RuntimeError: If attempting to run this model with an unsupported backend.
"""
if K.backend() in {'cntk'}:
raise RuntimeError(K.backend() + ' backend is currently unsupported for this model.')

if weights not in {'imagenet', None}:
raise ValueError('The `weights` argument should be either '
'`None` (random initialization) or `imagenet` '
Expand Down
100 changes: 64 additions & 36 deletions tests/keras/applications/applications_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from multiprocessing import Process, Queue
from keras.utils.test_utils import keras_test
from keras.utils.test_utils import layer_test
from keras.utils.generic_utils import CustomObjectScope
Expand Down Expand Up @@ -173,61 +174,88 @@ def test_inceptionv3_variable_input_channels():


@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason='InceptionResNetV2 is not supported on CNTK')
def test_inceptionresnetv2():
model = applications.InceptionResNetV2(weights=None)
assert model.output_shape == (None, 1000)
# Create model in a subprocess so that the memory consumed by InceptionResNetV2 will be
# released back to the system after this test (to deal with OOM error on CNTK backend)
# TODO: remove the use of multiprocessing from these tests once a memory clearing mechanism
# is implemented in the CNTK backend
def target(queue):
model = applications.InceptionResNetV2(weights=None)
queue.put(model.output_shape)
queue = Queue()
p = Process(target=target, args=(queue,))
p.start()
p.join()

# The error in a subprocess won't propagate to the main process, so we check if the model
# is successfully created by checking if the output shape has been put into the queue
assert not queue.empty(), 'Model creation failed.'
model_output_shape = queue.get_nowait()
assert model_output_shape == (None, 1000)


@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason='InceptionResNetV2 is not supported on CNTK')
def test_inceptionresnetv2_notop():
def target(queue):
model = applications.InceptionResNetV2(weights=None, include_top=False)
queue.put(model.output_shape)

global_image_data_format = K.image_data_format()
queue = Queue()

K.set_image_data_format('channels_first')
model = applications.InceptionResNetV2(weights=None, include_top=False)
assert model.output_shape == (None, 1536, None, None)
p = Process(target=target, args=(queue,))
p.start()
p.join()
K.set_image_data_format(global_image_data_format)
assert not queue.empty(), 'Model creation failed.'
model_output_shape = queue.get_nowait()
assert model_output_shape == (None, 1536, None, None)

K.set_image_data_format('channels_last')
model = applications.InceptionResNetV2(weights=None, include_top=False)
assert model.output_shape == (None, None, None, 1536)

p = Process(target=target, args=(queue,))
p.start()
p.join()
K.set_image_data_format(global_image_data_format)
assert not queue.empty(), 'Model creation failed.'
model_output_shape = queue.get_nowait()
assert model_output_shape == (None, None, None, 1536)


@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason='InceptionResNetV2 is not supported on CNTK')
def test_inceptionresnetv2_pooling():
model = applications.InceptionResNetV2(weights=None, include_top=False, pooling='avg')
assert model.output_shape == (None, 1536)
def target(queue):
model = applications.InceptionResNetV2(weights=None, include_top=False, pooling='avg')
queue.put(model.output_shape)
queue = Queue()
p = Process(target=target, args=(queue,))
p.start()
p.join()
assert not queue.empty(), 'Model creation failed.'
model_output_shape = queue.get_nowait()
assert model_output_shape == (None, 1536)


@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason='InceptionResNetV2 is not supported on CNTK')
def test_inceptionresnetv2_variable_input_channels():
global_image_data_format = K.image_data_format()

K.set_image_data_format('channels_first')
input_shape = (1, None, None)
model = applications.InceptionResNetV2(weights=None, include_top=False, input_shape=input_shape)
assert model.output_shape == (None, 1536, None, None)
input_shape = (4, None, None)
model = applications.InceptionResNetV2(weights=None, include_top=False, input_shape=input_shape)
assert model.output_shape == (None, 1536, None, None)

K.set_image_data_format('channels_last')
input_shape = (None, None, 1)
model = applications.InceptionResNetV2(weights=None, include_top=False, input_shape=input_shape)
assert model.output_shape == (None, None, None, 1536)
input_shape = (None, None, 4)
model = applications.InceptionResNetV2(weights=None, include_top=False, input_shape=input_shape)
assert model.output_shape == (None, None, None, 1536)

K.set_image_data_format(global_image_data_format)
def target(queue, input_shape):
model = applications.InceptionResNetV2(weights=None, include_top=False, input_shape=input_shape)
queue.put(model.output_shape)

queue = Queue()
p = Process(target=target, args=(queue, (None, None, 1)))
p.start()
p.join()
assert not queue.empty(), 'Model creation failed.'
model_output_shape = queue.get_nowait()
assert model_output_shape == (None, None, None, 1536)

p = Process(target=target, args=(queue, (None, None, 4)))
p.start()
p.join()
assert not queue.empty(), 'Model creation failed.'
model_output_shape = queue.get_nowait()
assert model_output_shape == (None, None, None, 1536)


@keras_test
Expand Down

0 comments on commit 71a791c

Please sign in to comment.