Skip to content

Commit 40c48dd

Browse files
committed
Run non-optimization updates on predict for stateful RNN's
1 parent 81787dd commit 40c48dd

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ keras/datasets/temp/*
99
docs/site/*
1010
docs/theme/*
1111
tags
12+
Keras.egg-info
1213

1314
# test-related
1415
.coverage

keras/layers/containers.py

+5
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def set_weights(self, weights):
113113
self.layers[i].set_weights(weights[:nb_param])
114114
weights = weights[nb_param:]
115115

116+
def reset_states(self):
117+
for layer in self.layers:
118+
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
119+
layer.reset_states()
120+
116121
def get_config(self):
117122
return {"name": self.__class__.__name__,
118123
"layers": [layer.get_config() for layer in self.layers]}

keras/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def compile(self, optimizer, loss,
419419

420420
self._train = K.function(train_ins, [train_loss], updates=updates)
421421
self._train_with_acc = K.function(train_ins, [train_loss, train_accuracy], updates=updates)
422-
self._predict = K.function(predict_ins, [self.y_test])
422+
self._predict = K.function(predict_ins, [self.y_test], updates=self.updates)
423423
self._test = K.function(test_ins, [test_loss])
424424
self._test_with_acc = K.function(test_ins, [test_loss, test_accuracy])
425425

0 commit comments

Comments
 (0)