Fixed bug when logging to tensorboard
peastman committed Sep 26, 2017
1 parent 1b7866b commit d092e7e
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions deepchem/models/tensorgraph/tensor_graph.py
@@ -180,7 +180,7 @@ def create_feed_dict():
       self.session.run(tf.global_variables_initializer())
       if restore:
         self.restore()
-      avg_loss, n_batches = 0.0, 0.0
+      avg_loss, n_averaged_batches = 0.0, 0.0
       n_samples = 0
       n_enqueued = [0]
       final_sample = [None]
@@ -190,7 +190,6 @@ def create_feed_dict():
             args=(self, feed_dict_generator, self._get_tf("Graph"),
                   self.session, n_enqueued, final_sample))
         enqueue_thread.start()
-      fetches = [train_op, self.loss.out_tensor]
       for feed_dict in create_feed_dict():
         if self.use_queue:
           # Don't let this thread get ahead of the enqueue thread, since if
@@ -202,24 +201,27 @@
             time.sleep(0)
           if n_samples == final_sample[0]:
             break
+        n_samples += 1
+        should_log = (self.tensorboard and
+                      n_samples % self.tensorboard_log_frequency == 0)
+        fetches = [train_op, self.loss.out_tensor]
+        if should_log:
+          fetches.append(self._get_tf("summary_op"))
         fetched_values = self.session.run(fetches, feed_dict=feed_dict)
-        loss = fetched_values[-1]
+        if should_log:
+          self._log_tensorboard(fetched_values[2])
+        loss = fetched_values[1]
         avg_loss += loss
-        n_batches += 1
+        n_averaged_batches += 1
         self.global_step += 1
-        n_samples += 1
-        if self.tensorboard and n_samples % self.tensorboard_log_frequency == 0:
-          summary = self.session.run(
-              self._get_tf("summary_op"), feed_dict=feed_dict)
-          self._log_tensorboard(summary)
         if self.global_step % checkpoint_interval == checkpoint_interval - 1:
           saver.save(self.session, self.save_file, global_step=self.global_step)
-          avg_loss = float(avg_loss) / n_batches
+          avg_loss = float(avg_loss) / n_averaged_batches
           print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                             avg_loss))
-          avg_loss, n_batches = 0.0, 0.0
-      if n_batches > 0:
-        avg_loss = float(avg_loss) / n_batches
+          avg_loss, n_averaged_batches = 0.0, 0.0
+      if n_averaged_batches > 0:
+        avg_loss = float(avg_loss) / n_averaged_batches
         print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                           avg_loss))
       saver.save(self.session, self.save_file, global_step=self.global_step)
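
The change folds the summary fetch into the same session.run() call as the training op, so the logged summary is computed on the batch that was just trained on, instead of evaluating summary_op in a second session.run() afterwards. Below is a minimal, self-contained sketch of that pattern in plain TensorFlow 1.x; it is not DeepChem's TensorGraph code, and the toy regression graph, log_frequency, and the /tmp/tb_demo log directory are illustrative assumptions only.

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x, as TensorGraph used in 2017

x = tf.placeholder(tf.float32, shape=(None, 1))
y = tf.placeholder(tf.float32, shape=(None, 1))
w = tf.Variable(tf.zeros((1, 1)))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

tf.summary.scalar('loss', loss)
summary_op = tf.summary.merge_all()
writer = tf.summary.FileWriter('/tmp/tb_demo')

log_frequency = 10  # plays the role of tensorboard_log_frequency
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for step in range(100):
    batch_x = np.random.rand(8, 1).astype(np.float32)
    batch_y = 2.0 * batch_x
    should_log = (step % log_frequency == 0)
    fetches = [train_op, loss]
    if should_log:
      # Fetch the summary in the same run as the training op, as the commit does.
      fetches.append(summary_op)
    values = sess.run(fetches, feed_dict={x: batch_x, y: batch_y})
    if should_log:
      # values[2] is the serialized summary returned by session.run,
      # not the summary_op tensor itself.
      writer.add_summary(values[2], global_step=step)
writer.close()

As in the commit, the train op stays at index 0 and the loss at index 1 of the fetch list, so skipping the optional summary fetch on non-logging steps does not change how the loss is read back.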