Skip to content

Commit

Permalink
use tf.get_collection
Browse files Browse the repository at this point in the history
I have also changed the name of the scopes since 'D' and 'D_Pre' where not going to work with tf.get_collection.
  • Loading branch information
yassersouri authored Jan 22, 2017
1 parent 6fb4cf9 commit 93bb948
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _create_model(self):

# This defines the generator network - it takes samples from a noise
# distribution as input, and passes them through an MLP.
with tf.variable_scope('G'):
with tf.variable_scope('Gen'):
self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
self.G = generator(self.z, self.mlp_hidden_size)

Expand All @@ -147,7 +147,7 @@ def _create_model(self):
#
# Here we create two copies of the discriminator network (that share parameters),
# as you cannot use the same network with different inputs in TensorFlow.
with tf.variable_scope('D') as scope:
with tf.variable_scope('Disc') as scope:
self.x = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
self.D1 = discriminator(self.x, self.mlp_hidden_size, self.minibatch)
scope.reuse_variables()
Expand All @@ -158,10 +158,9 @@ def _create_model(self):
self.loss_d = tf.reduce_mean(-tf.log(self.D1) - tf.log(1 - self.D2))
self.loss_g = tf.reduce_mean(-tf.log(self.D2))

vars = tf.trainable_variables()
self.d_pre_params = [v for v in vars if v.name.startswith('D_pre/')]
self.d_params = [v for v in vars if v.name.startswith('D/')]
self.g_params = [v for v in vars if v.name.startswith('G/')]
self.d_pre_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D_pre')
self.d_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Disc')
self.g_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gen')

self.opt_d = optimizer(self.loss_d, self.d_params, self.learning_rate)
self.opt_g = optimizer(self.loss_g, self.g_params, self.learning_rate)
Expand Down

0 comments on commit 93bb948

Please sign in to comment.