Skip to content

Commit 530eff6

Browse files
AnishShahfchollet
authored andcommitted
[issue keras-team#3942] Add GlobalMaxPooling3D and GlobalAveragePooling3D (keras-team#3983)
1 parent 4de7eaa commit 530eff6

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

keras/layers/pooling.py

+80
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,83 @@ def call(self, x, mask=None):
519519
return K.max(x, axis=[1, 2])
520520
else:
521521
return K.max(x, axis=[2, 3])
522+
523+
524+
class _GlobalPooling3D(Layer):
525+
526+
def __init__(self, dim_ordering='default', **kwargs):
527+
super(_GlobalPooling3D, self).__init__(**kwargs)
528+
if dim_ordering == 'default':
529+
dim_ordering = K.image_dim_ordering()
530+
self.dim_ordering = dim_ordering
531+
self.input_spec = [InputSpec(ndim=5)]
532+
533+
def get_output_shape_for(self, input_shape):
534+
if self.dim_ordering == 'tf':
535+
return (input_shape[0], input_shape[4])
536+
else:
537+
return (input_shape[0], input_shape[1])
538+
539+
def call(self, x, mask=None):
540+
raise NotImplementedError
541+
542+
def get_config(self):
543+
config = {'dim_ordering': self.dim_ordering}
544+
base_config = super(_GlobalPooling3D, self).get_config()
545+
return dict(list(base_config.items()) + list(config.items()))
546+
547+
548+
class GlobalAveragePooling3D(_GlobalPooling3D):
549+
'''Global Average pooling operation for 3D data.
550+
551+
# Arguments
552+
dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
553+
(the depth) is at index 1, in 'tf' mode is it at index 4.
554+
It defaults to the `image_dim_ordering` value found in your
555+
Keras config file at `~/.keras/keras.json`.
556+
If you never set it, then it will be "tf".
557+
558+
# Input shape
559+
5D tensor with shape:
560+
`(samples, channels, len_pool_dim1, len_pool_dim2, len_pool_dim3)` if dim_ordering='th'
561+
or 5D tensor with shape:
562+
`(samples, len_pool_dim1, len_pool_dim2, len_pool_dim3, channels)` if dim_ordering='tf'.
563+
564+
# Output shape
565+
2D tensor with shape:
566+
`(nb_samples, channels)`
567+
'''
568+
569+
def call(self, x, mask=None):
570+
if self.dim_ordering == 'tf':
571+
return K.mean(x, axis=[1, 2, 3])
572+
else:
573+
return K.mean(x, axis=[2, 3, 4])
574+
575+
576+
class GlobalMaxPooling3D(_GlobalPooling3D):
577+
'''Global Max pooling operation for 3D data.
578+
579+
# Arguments
580+
dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
581+
(the depth) is at index 1, in 'tf' mode is it at index 4.
582+
It defaults to the `image_dim_ordering` value found in your
583+
Keras config file at `~/.keras/keras.json`.
584+
If you never set it, then it will be "tf".
585+
586+
# Input shape
587+
5D tensor with shape:
588+
`(samples, channels, len_pool_dim1, len_pool_dim2, len_pool_dim3)` if dim_ordering='th'
589+
or 5D tensor with shape:
590+
`(samples, len_pool_dim1, len_pool_dim2, len_pool_dim3, channels)` if dim_ordering='tf'.
591+
592+
# Output shape
593+
2D tensor with shape:
594+
`(nb_samples, channels)`
595+
'''
596+
597+
def call(self, x, mask=None):
598+
if self.dim_ordering == 'tf':
599+
return K.max(x, axis=[1, 2, 3])
600+
else:
601+
return K.max(x, axis=[2, 3, 4])

tests/keras/layers/test_convolutional.py

+16
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,22 @@ def test_globalpooling_2d():
269269
input_shape=(3, 5, 6, 4))
270270

271271

272+
@keras_test
273+
def test_globalpooling_3d():
274+
layer_test(pooling.GlobalMaxPooling3D,
275+
kwargs={'dim_ordering': 'th'},
276+
input_shape=(3, 4, 3, 4, 3))
277+
layer_test(pooling.GlobalMaxPooling3D,
278+
kwargs={'dim_ordering': 'tf'},
279+
input_shape=(3, 4, 3, 4, 3))
280+
layer_test(pooling.GlobalAveragePooling3D,
281+
kwargs={'dim_ordering': 'th'},
282+
input_shape=(3, 4, 3, 4, 3))
283+
layer_test(pooling.GlobalAveragePooling3D,
284+
kwargs={'dim_ordering': 'tf'},
285+
input_shape=(3, 4, 3, 4, 3))
286+
287+
272288
@keras_test
273289
def test_maxpooling_2d():
274290
pool_size = (3, 3)

0 commit comments

Comments
 (0)