Skip to content

Commit

Permalink
added graph reset to the fit() method so that subsequent calls to fit…
Browse files Browse the repository at this point in the history
…() doesn't throw errors
  • Loading branch information
gabrieleangeletti committed May 17, 2016
1 parent b0667a9 commit 526b971
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion command_line/run_stacked_autoencoder_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def load_from_np(dataset_path):
finetune_opt=FLAGS.finetune_opt, finetune_batch_size=FLAGS.finetune_batch_size,
finetune_dropout=FLAGS.finetune_dropout,
dae_enc_act_func=dae_enc_act_func, dae_dec_act_func=dae_dec_act_func,
dae_corr_type=FLAGS.dae_corr_type, dae_corr_frac=FLAGS.dae_corr_frac, dae_l2reg=FLAGS.dae_l2reg,
dae_corr_type=dae_corr_type, dae_corr_frac=dae_corr_frac, dae_l2reg=dae_l2reg,
dataset=FLAGS.dataset, dae_loss_func=dae_loss_func, main_dir=FLAGS.main_dir,
dae_opt=dae_opt, tied_weights=FLAGS.tied_weights,
dae_learning_rate=dae_learning_rate, momentum=FLAGS.momentum, verbose=FLAGS.verbose,
Expand Down
3 changes: 3 additions & 0 deletions yadlt/models/autoencoder_models/denoising_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tensorflow.python.framework import ops
import tensorflow as tf
import numpy as np
import os
Expand Down Expand Up @@ -59,6 +60,8 @@ def fit(self, train_set, validation_set=None, restore_previous_model=False):
"""

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, validation_set)
self.tf_saver.save(self.tf_session, self.model_path)
Expand Down
2 changes: 2 additions & 0 deletions yadlt/models/autoencoder_models/stacked_deep_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def fit(self, train_set, train_ref, validation_set=None, validation_ref=None, re
print('Starting Reconstruction finetuning...')

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1])
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, train_ref, validation_set, validation_ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def fit(self, train_set, train_labels, validation_set=None, validation_labels=No
print('Starting Supervised finetuning...')

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1], train_labels.shape[1])
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, train_labels, validation_set, validation_labels)
Expand Down
3 changes: 3 additions & 0 deletions yadlt/models/convolutional_models/conv_net.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tensorflow.python.framework import ops
import tensorflow as tf
import numpy as np

Expand Down Expand Up @@ -58,6 +59,8 @@ def fit(self, train_set, train_labels, original_shape, validation_set=None, vali
print('Starting training...')

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1], train_labels.shape[1], original_shape)
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, train_labels, validation_set, validation_labels)
Expand Down
3 changes: 3 additions & 0 deletions yadlt/models/misc_models/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tensorflow.python.framework import ops
import tensorflow as tf
import numpy as np

Expand Down Expand Up @@ -86,6 +87,8 @@ def fit(self, train_set, train_labels, validation_set=None, validation_labels=No
"""

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1], train_labels.shape[1])
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, train_labels, validation_set, validation_labels)
Expand Down
2 changes: 2 additions & 0 deletions yadlt/models/rbm_models/dbn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def fit(self, train_set, train_labels, validation_set=None, validation_labels=No
print('Starting Supervised finetuning...')

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1], train_labels.shape[1])
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, train_labels, validation_set, validation_labels)
Expand Down
2 changes: 2 additions & 0 deletions yadlt/models/rbm_models/deep_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def fit(self, train_set, train_ref, validation_set=None, validation_ref=None, re
print('Starting Reconstruction finetuning...')

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1])
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, train_ref, validation_set, validation_ref)
Expand Down
3 changes: 3 additions & 0 deletions yadlt/models/rbm_models/rbm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tensorflow.python.framework import ops
import tensorflow as tf
import numpy as np

Expand Down Expand Up @@ -59,6 +60,8 @@ def fit(self, train_set, validation_set=None, restore_previous_model=False):
"""

with tf.Session() as self.tf_session:
# Reset tensorflow's default graph
ops.reset_default_graph()
self.build_model(train_set.shape[1])
self._initialize_tf_utilities_and_ops(restore_previous_model)
self._train_model(train_set, validation_set)
Expand Down

0 comments on commit 526b971

Please sign in to comment.