From d092e7ef7dddba02e7bb381e74f6aea622d7f3c4 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Tue, 26 Sep 2017 13:09:36 -0700 Subject: [PATCH] Fixed bug when logging to tensorboard --- deepchem/models/tensorgraph/tensor_graph.py | 28 +++++++++++---------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/deepchem/models/tensorgraph/tensor_graph.py b/deepchem/models/tensorgraph/tensor_graph.py index 82e7032911..917e34e451 100644 --- a/deepchem/models/tensorgraph/tensor_graph.py +++ b/deepchem/models/tensorgraph/tensor_graph.py @@ -180,7 +180,7 @@ def create_feed_dict(): self.session.run(tf.global_variables_initializer()) if restore: self.restore() - avg_loss, n_batches = 0.0, 0.0 + avg_loss, n_averaged_batches = 0.0, 0.0 n_samples = 0 n_enqueued = [0] final_sample = [None] @@ -190,7 +190,6 @@ def create_feed_dict(): args=(self, feed_dict_generator, self._get_tf("Graph"), self.session, n_enqueued, final_sample)) enqueue_thread.start() - fetches = [train_op, self.loss.out_tensor] for feed_dict in create_feed_dict(): if self.use_queue: # Don't let this thread get ahead of the enqueue thread, since if @@ -202,24 +201,27 @@ def create_feed_dict(): time.sleep(0) if n_samples == final_sample[0]: break + n_samples += 1 + should_log = (self.tensorboard and + n_samples % self.tensorboard_log_frequency == 0) + fetches = [train_op, self.loss.out_tensor] + if should_log: + fetches.append(self._get_tf("summary_op")) fetched_values = self.session.run(fetches, feed_dict=feed_dict) - loss = fetched_values[-1] + if should_log: + self._log_tensorboard(fetches[2]) + loss = fetched_values[1] avg_loss += loss - n_batches += 1 + n_averaged_batches += 1 self.global_step += 1 - n_samples += 1 - if self.tensorboard and n_samples % self.tensorboard_log_frequency == 0: - summary = self.session.run( - self._get_tf("summary_op"), feed_dict=feed_dict) - self._log_tensorboard(summary) if self.global_step % checkpoint_interval == checkpoint_interval - 1: saver.save(self.session, self.save_file, global_step=self.global_step) - avg_loss = float(avg_loss) / n_batches + avg_loss = float(avg_loss) / n_averaged_batches print('Ending global_step %d: Average loss %g' % (self.global_step, avg_loss)) - avg_loss, n_batches = 0.0, 0.0 - if n_batches > 0: - avg_loss = float(avg_loss) / n_batches + avg_loss, n_averaged_batches = 0.0, 0.0 + if n_averaged_batches > 0: + avg_loss = float(avg_loss) / n_averaged_batches print('Ending global_step %d: Average loss %g' % (self.global_step, avg_loss)) saver.save(self.session, self.save_file, global_step=self.global_step)