Skip to content

Commit

Permalink
Small refactors on the keras.utils module (keras-team#13388)
Browse files Browse the repository at this point in the history
* Use .format calls for string interpolation on utils

* Use generators over listcomps whenever possible to save memory
  • Loading branch information
eltonvs authored and gabrieldemarmiesse committed Oct 8, 2019
1 parent c8f66d1 commit b75b2f7
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 27 deletions.
18 changes: 9 additions & 9 deletions keras/utils/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ def normalize_tuple(value, n, name):
try:
value_tuple = tuple(value)
except TypeError:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value))
raise ValueError('The `{}` argument must be a tuple of {} '
'integers. Received: {}'.format(name, n, value))
if len(value_tuple) != n:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value))
raise ValueError('The `{}` argument must be a tuple of {} '
'integers. Received: {}'.format(name, n, value))
for single_value in value_tuple:
try:
int(single_value)
except ValueError:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value) + ' '
'including element ' + str(single_value) + ' of '
'type ' + str(type(single_value)))
raise ValueError('The `{}` argument must be a tuple of {} '
'integers. Received: {} including element {} '
'of type {}'.format(name, n, value, single_value,
type(single_value)))
return value_tuple


Expand All @@ -55,7 +55,7 @@ def normalize_padding(value):
allowed.add('full')
if padding not in allowed:
raise ValueError('The `padding` argument must be one of "valid", "same" '
'(or "causal" for Conv1D). Received: ' + str(padding))
'(or "causal" for Conv1D). Received: {}'.format(padding))
return padding


Expand Down
16 changes: 8 additions & 8 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def chunk_read(response, chunk_size=8192, reporthook=None):
break

with closing(urlopen(url, data)) as response, open(filename, 'wb') as fd:
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
else:
from six.moves.urllib.request import urlretrieve

Expand Down Expand Up @@ -195,10 +195,10 @@ def get_file(fname,
# File found; verify integrity if a hash was provided.
if file_hash is not None:
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
print('A local file was found, but it seems to be '
'incomplete or outdated because the ' + hash_algorithm +
' file hash does not match the original value of ' +
file_hash + ' so we will re-download the data.')
print('A local file was found, but it seems to be incomplete'
' or outdated because the {} file hash does not match '
'the original value of {} so we will re-download the '
'data.'.format(hash_algorithm, file_hash))
download = True
else:
download = True
Expand Down Expand Up @@ -725,9 +725,9 @@ def get(self):
while self.queue.qsize() > 0:
last_ones.append(self.queue.get(block=True))
# Wait for them to complete
list(map(lambda f: f.wait(), last_ones))
[f.wait() for f in last_ones]
# Keep the good ones
last_ones = [future.get() for future in last_ones if future.successful()]
last_ones = (future.get() for future in last_ones if future.successful())
for inputs in last_ones:
if inputs is not None:
yield inputs
Expand Down
16 changes: 8 additions & 8 deletions keras/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def deserialize_keras_object(identifier, module_objects=None,
# In this case we are dealing with a Keras config dictionary.
config = identifier
if 'class_name' not in config or 'config' not in config:
raise ValueError('Improper config format: ' + str(config))
raise ValueError('Improper config format: {}'.format(config))
class_name = config['class_name']
if custom_objects and class_name in custom_objects:
cls = custom_objects[class_name]
Expand All @@ -136,8 +136,8 @@ def deserialize_keras_object(identifier, module_objects=None,
module_objects = module_objects or {}
cls = module_objects.get(class_name)
if cls is None:
raise ValueError('Unknown ' + printable_module_name +
': ' + class_name)
raise ValueError('Unknown {}: {}'.format(printable_module_name,
class_name))
if hasattr(cls, 'from_config'):
custom_objects = custom_objects or {}
if has_arg(cls.from_config, 'custom_objects'):
Expand All @@ -163,12 +163,12 @@ def deserialize_keras_object(identifier, module_objects=None,
else:
fn = module_objects.get(function_name)
if fn is None:
raise ValueError('Unknown ' + printable_module_name +
':' + function_name)
raise ValueError('Unknown {}: {}'.format(printable_module_name,
function_name))
return fn
else:
raise ValueError('Could not interpret serialized ' +
printable_module_name + ': ' + identifier)
raise ValueError('Could not interpret serialized '
'{}: {}'.format(printable_module_name, identifier))


def func_dump(func):
Expand Down Expand Up @@ -514,7 +514,7 @@ def unpack_singleton(x):

def object_list_uid(object_list):
object_list = to_list(object_list)
return ', '.join([str(abs(id(x))) for x in object_list])
return ', '.join((str(abs(id(x))) for x in object_list))


def is_all_none(iterable_or_element):
Expand Down
2 changes: 1 addition & 1 deletion keras/utils/multi_gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False):
if not gpus:
# Using all visible GPUs when not specifying `gpus`
# e.g. CUDA_VISIBLE_DEVICES=0,2 python keras_mgpu.py
gpus = len([x for x in available_devices if '/gpu:' in x])
gpus = len((x for x in available_devices if '/gpu:' in x))

if isinstance(gpus, (list, tuple)):
if len(gpus) <= 1:
Expand Down
2 changes: 1 addition & 1 deletion keras/utils/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def model_to_dot(model,
inputlabels = str(layer.input_shape)
elif hasattr(layer, 'input_shapes'):
inputlabels = ', '.join(
[str(ishape) for ishape in layer.input_shapes])
(str(ishape) for ishape in layer.input_shapes))
else:
inputlabels = 'multiple'
label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
Expand Down

0 comments on commit b75b2f7

Please sign in to comment.