diff --git a/gan.py b/gan.py
index 0873784..e37ba07 100644
--- a/gan.py
+++ b/gan.py
@@ -138,7 +138,7 @@ def _create_model(self):
 
         # This defines the generator network - it takes samples from a noise
         # distribution as input, and passes them through an MLP.
-        with tf.variable_scope('G'):
+        with tf.variable_scope('Gen'):
             self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
             self.G = generator(self.z, self.mlp_hidden_size)
 
@@ -147,7 +147,7 @@ def _create_model(self):
         #
         # Here we create two copies of the discriminator network (that share parameters),
         # as you cannot use the same network with different inputs in TensorFlow.
-        with tf.variable_scope('D') as scope:
+        with tf.variable_scope('Disc') as scope:
             self.x = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
             self.D1 = discriminator(self.x, self.mlp_hidden_size, self.minibatch)
             scope.reuse_variables()
@@ -158,10 +158,9 @@ def _create_model(self):
             self.D2 = discriminator(self.G, self.mlp_hidden_size, self.minibatch)
         self.loss_d = tf.reduce_mean(-tf.log(self.D1) - tf.log(1 - self.D2))
         self.loss_g = tf.reduce_mean(-tf.log(self.D2))
-        vars = tf.trainable_variables()
-        self.d_pre_params = [v for v in vars if v.name.startswith('D_pre/')]
-        self.d_params = [v for v in vars if v.name.startswith('D/')]
-        self.g_params = [v for v in vars if v.name.startswith('G/')]
+        self.d_pre_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D_pre')
+        self.d_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Disc')
+        self.g_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gen')
 
         self.opt_d = optimizer(self.loss_d, self.d_params, self.learning_rate)
         self.opt_g = optimizer(self.loss_g, self.g_params, self.learning_rate)