@@ -538,7 +538,8 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
538
538
verbose = 1 , callbacks = [],
539
539
validation_data = None , nb_val_samples = None ,
540
540
class_weight = {},
541
- max_q_size = 10 , ** kwargs ):
541
+ max_q_size = 10 , nb_worker = 1 ,
542
+ pickle_safe = False , ** kwargs ):
542
543
'''Fits a model on data generated batch-by-batch by a Python generator.
543
544
The generator is run in parallel to the model, for efficiency.
544
545
For instance, this allows you to do real-time data augmentation
@@ -599,10 +600,6 @@ def generate_arrays_from_file(path):
599
600
'the model at compile time:\n '
600
601
'`model.compile(optimizer, loss, '
601
602
'metrics=["accuracy"])`' )
602
- if 'nb_worker' in kwargs :
603
- kwargs .pop ('nb_worker' )
604
- warnings .warn ('The "nb_worker" argument is deprecated, '
605
- 'please remove it from your code.' )
606
603
if 'nb_val_worker' in kwargs :
607
604
kwargs .pop ('nb_val_worker' )
608
605
warnings .warn ('The "nb_val_worker" argument is deprecated, '
@@ -647,13 +644,16 @@ def fixed_generator():
647
644
validation_data = validation_data ,
648
645
nb_val_samples = nb_val_samples ,
649
646
class_weight = class_weight ,
650
- max_q_size = max_q_size )
647
+ max_q_size = max_q_size ,
648
+ nb_worker = nb_worker ,
649
+ pickle_safe = pickle_safe )
651
650
self .train_on_batch = self ._train_on_batch
652
651
self .evaluate = self ._evaluate
653
652
return history
654
653
655
654
def evaluate_generator (self , generator , val_samples ,
656
- verbose = 1 , max_q_size = 10 , ** kwargs ):
655
+ verbose = 1 , max_q_size = 10 , nb_worker = 1 ,
656
+ pickle_safe = False , ** kwargs ):
657
657
'''Evaluates the model on a generator. The generator should
658
658
return the same kind of data with every yield as accepted
659
659
by `evaluate`.
@@ -707,7 +707,9 @@ def fixed_generator():
707
707
generator = fixed_generator ()
708
708
history = super (Graph , self ).evaluate_generator (generator ,
709
709
val_samples ,
710
- max_q_size = max_q_size )
710
+ max_q_size = max_q_size ,
711
+ nb_worker = nb_worker ,
712
+ pickle_safe = pickle_safe )
711
713
self .test_on_batch = self ._test_on_batch
712
714
return history
713
715
0 commit comments