Skip to content

Commit 3feca20

Browse files
Thomas Boquetfchollet
Thomas Boquet
authored andcommitted
+ multiprocessing in legacy - unused imports (keras-team#4139)
1 parent f1bc3c0 commit 3feca20

File tree

5 files changed

+14
-12
lines changed

5 files changed

+14
-12
lines changed

keras/legacy/models.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,8 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
538538
verbose=1, callbacks=[],
539539
validation_data=None, nb_val_samples=None,
540540
class_weight={},
541-
max_q_size=10, **kwargs):
541+
max_q_size=10, nb_worker=1,
542+
pickle_safe=False, **kwargs):
542543
'''Fits a model on data generated batch-by-batch by a Python generator.
543544
The generator is run in parallel to the model, for efficiency.
544545
For instance, this allows you to do real-time data augmentation
@@ -599,10 +600,6 @@ def generate_arrays_from_file(path):
599600
'the model at compile time:\n'
600601
'`model.compile(optimizer, loss, '
601602
'metrics=["accuracy"])`')
602-
if 'nb_worker' in kwargs:
603-
kwargs.pop('nb_worker')
604-
warnings.warn('The "nb_worker" argument is deprecated, '
605-
'please remove it from your code.')
606603
if 'nb_val_worker' in kwargs:
607604
kwargs.pop('nb_val_worker')
608605
warnings.warn('The "nb_val_worker" argument is deprecated, '
@@ -647,13 +644,16 @@ def fixed_generator():
647644
validation_data=validation_data,
648645
nb_val_samples=nb_val_samples,
649646
class_weight=class_weight,
650-
max_q_size=max_q_size)
647+
max_q_size=max_q_size,
648+
nb_worker=nb_worker,
649+
pickle_safe=pickle_safe)
651650
self.train_on_batch = self._train_on_batch
652651
self.evaluate = self._evaluate
653652
return history
654653

655654
def evaluate_generator(self, generator, val_samples,
656-
verbose=1, max_q_size=10, **kwargs):
655+
verbose=1, max_q_size=10, nb_worker=1,
656+
pickle_safe=False, **kwargs):
657657
'''Evaluates the model on a generator. The generator should
658658
return the same kind of data with every yield as accepted
659659
by `evaluate`.
@@ -707,7 +707,9 @@ def fixed_generator():
707707
generator = fixed_generator()
708708
history = super(Graph, self).evaluate_generator(generator,
709709
val_samples,
710-
max_q_size=max_q_size)
710+
max_q_size=max_q_size,
711+
nb_worker=nb_worker,
712+
pickle_safe=pickle_safe)
711713
self.test_on_batch = self._test_on_batch
712714
return history
713715

tests/keras/engine/test_training.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from keras.layers import Dense, Dropout
66
from keras.engine.topology import merge, Input
77
from keras.engine.training import Model
8-
from keras.models import Sequential, Graph
8+
from keras.models import Sequential
99
from keras import backend as K
1010
from keras.utils.test_utils import keras_test
1111

tests/keras/layers/test_normalization.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from keras.layers.core import Dense, Activation
66
from keras.utils.test_utils import layer_test, keras_test
77
from keras.layers import normalization
8-
from keras.models import Sequential, Graph
8+
from keras.models import Sequential
99
from keras import backend as K
1010

1111
input_1 = np.arange(10)

tests/keras/test_sequential_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
np.random.seed(1337)
77

88
from keras import backend as K
9-
from keras.models import Graph, Sequential
9+
from keras.models import Sequential
1010
from keras.layers.core import Dense, Activation, Merge, Lambda
1111
from keras.utils import np_utils
1212
from keras.utils.test_utils import get_test_data, keras_test

tests/test_loss_weighting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
np.random.seed(1337)
66

77
from keras.utils.test_utils import get_test_data
8-
from keras.models import Sequential, Graph
8+
from keras.models import Sequential
99
from keras.layers import Dense, Activation, RepeatVector, TimeDistributedDense, GRU
1010
from keras.utils import np_utils
1111
from keras.utils.test_utils import keras_test

0 commit comments

Comments
 (0)