diff --git a/keras/callbacks.py b/keras/callbacks.py index 9038ca837dd..f59436056be 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -592,6 +592,8 @@ class TensorBoard(Callback): write_graph is set to True. write_grads: whether to visualize gradient histograms in TensorBoard. `histogram_freq` must be greater than 0. + batch_size: size of batch of inputs to feed to the network + for histograms computation. write_images: whether to write model weights to visualize as image in TensorBoard. embeddings_freq: frequency (in epochs) at which selected embedding @@ -607,6 +609,7 @@ class TensorBoard(Callback): def __init__(self, log_dir='./logs', histogram_freq=0, + batch_size=32, write_graph=True, write_grads=False, write_images=False, @@ -626,6 +629,7 @@ def __init__(self, log_dir='./logs', self.embeddings_freq = embeddings_freq self.embeddings_layer_names = embeddings_layer_names self.embeddings_metadata = embeddings_metadata or {} + self.batch_size = batch_size def set_model(self, model): self.model = model @@ -725,8 +729,6 @@ def on_epoch_end(self, epoch, logs=None): if self.validation_data and self.histogram_freq: if epoch % self.histogram_freq == 0: - # TODO: implement batched calls to sess.run - # (current call will likely go OOM on GPU) val_data = self.validation_data tensors = (self.model.inputs + @@ -737,10 +739,21 @@ def on_epoch_end(self, epoch, logs=None): tensors += [K.learning_phase()] assert len(val_data) == len(tensors) - feed_dict = dict(zip(tensors, val_data)) - result = self.sess.run([self.merged], feed_dict=feed_dict) - summary_str = result[0] - self.writer.add_summary(summary_str, epoch) + val_size = val_data[0].shape[0] + i = 0 + while i < val_size: + step = min(self.batch_size, val_size - i) + batch_val = [] + batch_val.append(val_data[0][i:i + step]) + batch_val.append(val_data[1][i:i + step]) + batch_val.append(val_data[2][i:i + step]) + if self.model.uses_learning_phase: + batch_val.append(val_data[3]) + feed_dict = dict(zip(tensors, batch_val)) + result = self.sess.run([self.merged], feed_dict=feed_dict) + summary_str = result[0] + self.writer.add_summary(summary_str, epoch) + i += self.batch_size if self.embeddings_freq and self.embeddings_logs: if epoch % self.embeddings_freq == 0: diff --git a/tests/keras/test_callbacks.py b/tests/keras/test_callbacks.py index b11bbe7b4ce..460c8dca244 100644 --- a/tests/keras/test_callbacks.py +++ b/tests/keras/test_callbacks.py @@ -318,7 +318,8 @@ def data_generator_graph(train): metrics=['accuracy']) tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1, - write_images=True, write_grads=True) + write_images=True, write_grads=True, + batch_size=5) cbks = [tsb] # fit with validation data @@ -381,7 +382,8 @@ def test_TensorBoard_convnet(): optimizer='rmsprop', metrics=['accuracy']) tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1, - write_images=True, write_grads=True) + write_images=True, write_grads=True, + batch_size=16) cbks = [tsb] model.summary() history = model.fit(x_train, y_train, epochs=2, batch_size=16,