Skip to content

Commit

Permalink
Added batch histogram computation (keras-team#6065)
Browse files Browse the repository at this point in the history
* Added batch histogram computation

* batch_size_histogram renamed to batch_size, default set to 32, added spaces around operators

* PEP8 fix

* Added batch_size in tests/keras/test_callbacks.py::test_TensorBoard_convnet

* PEP8 fix

* Batch size reduced in tests, targets and sample_weights sliced
  • Loading branch information
cattaneod authored and fchollet committed May 10, 2017
1 parent 737ae88 commit 08aa6ae
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
25 changes: 19 additions & 6 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ class TensorBoard(Callback):
write_graph is set to True.
write_grads: whether to visualize gradient histograms in TensorBoard.
`histogram_freq` must be greater than 0.
        batch_size: size of batch of inputs to feed to the network
            for histogram computation.
write_images: whether to write model weights to visualize as
image in TensorBoard.
embeddings_freq: frequency (in epochs) at which selected embedding
Expand All @@ -607,6 +609,7 @@ class TensorBoard(Callback):

def __init__(self, log_dir='./logs',
histogram_freq=0,
batch_size=32,
write_graph=True,
write_grads=False,
write_images=False,
Expand All @@ -626,6 +629,7 @@ def __init__(self, log_dir='./logs',
self.embeddings_freq = embeddings_freq
self.embeddings_layer_names = embeddings_layer_names
self.embeddings_metadata = embeddings_metadata or {}
self.batch_size = batch_size

def set_model(self, model):
self.model = model
Expand Down Expand Up @@ -725,8 +729,6 @@ def on_epoch_end(self, epoch, logs=None):

if self.validation_data and self.histogram_freq:
if epoch % self.histogram_freq == 0:
# TODO: implement batched calls to sess.run
# (current call will likely go OOM on GPU)

val_data = self.validation_data
tensors = (self.model.inputs +
Expand All @@ -737,10 +739,21 @@ def on_epoch_end(self, epoch, logs=None):
tensors += [K.learning_phase()]

assert len(val_data) == len(tensors)
feed_dict = dict(zip(tensors, val_data))
result = self.sess.run([self.merged], feed_dict=feed_dict)
summary_str = result[0]
self.writer.add_summary(summary_str, epoch)
val_size = val_data[0].shape[0]
i = 0
while i < val_size:
step = min(self.batch_size, val_size - i)
batch_val = []
batch_val.append(val_data[0][i:i + step])
batch_val.append(val_data[1][i:i + step])
batch_val.append(val_data[2][i:i + step])
if self.model.uses_learning_phase:
batch_val.append(val_data[3])
feed_dict = dict(zip(tensors, batch_val))
result = self.sess.run([self.merged], feed_dict=feed_dict)
summary_str = result[0]
self.writer.add_summary(summary_str, epoch)
i += self.batch_size

if self.embeddings_freq and self.embeddings_logs:
if epoch % self.embeddings_freq == 0:
Expand Down
6 changes: 4 additions & 2 deletions tests/keras/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ def data_generator_graph(train):
metrics=['accuracy'])

tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1,
write_images=True, write_grads=True)
write_images=True, write_grads=True,
batch_size=5)
cbks = [tsb]

# fit with validation data
Expand Down Expand Up @@ -381,7 +382,8 @@ def test_TensorBoard_convnet():
optimizer='rmsprop',
metrics=['accuracy'])
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1,
write_images=True, write_grads=True)
write_images=True, write_grads=True,
batch_size=16)
cbks = [tsb]
model.summary()
history = model.fit(x_train, y_train, epochs=2, batch_size=16,
Expand Down

0 comments on commit 08aa6ae

Please sign in to comment.