Skip to content

Commit

Permalink
changed G Loss
Browse files Browse the repository at this point in the history
  • Loading branch information
irjaved committed Apr 25, 2017
1 parent 727039c commit f175aee
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from discriminator import conv_net, conv_weights
from util import plot_single, plot_save_batch

epochs = 30
epochs = 50
mb_size = 4

train_path = os.path.join("data", "train") # "data/"
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_edges_file(f):

# Calculate CGAN (classic) losses
# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake)) + tf.reduce_mean(X_sketch - D_fake)
# G_loss = -tf.reduce_mean(tf.log(D_fake)) #+ tf.reduce_mean(X_ground_truth - G_sample)


# Calculate CGAN (alternative) losses
Expand All @@ -129,12 +129,12 @@ def get_edges_file(f):
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
lmbda = 0.2 # fix scaling
lmbda = 1 # fix scaling
G_loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_logit_fake,
labels=tf.ones_like(D_logit_fake))) + lmbda*tf.reduce_mean(
X_ground_truth - G_sample)
labels=tf.ones_like(D_logit_fake))) #+ lmbda*tf.reduce_mean(
#X_ground_truth - G_sample)

# Apply an optimizer to minimize the above loss functions
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
Expand All @@ -160,6 +160,7 @@ def get_edges_file(f):
# Get next batch
[X_truth_batch, X_edges_batch] = sess.run([truth_images_batch,
edges_images_batch])
# print(sess.run((D_fake, D_logit_fake)))

if i % iter_to_print == 0:
produced_image = sess.run(G_sample,
Expand All @@ -170,10 +171,10 @@ def get_edges_file(f):




_, D_loss_curr = sess.run([D_solver, D_loss],
feed_dict={X_ground_truth: X_truth_batch,
X_sketch: X_edges_batch})
for j in range(3):
_, D_loss_curr = sess.run([D_solver, D_loss],
feed_dict={X_ground_truth: X_truth_batch,
X_sketch: X_edges_batch})
_, G_loss_curr = sess.run([G_solver, G_loss],
feed_dict={X_ground_truth: X_truth_batch,
X_sketch: X_edges_batch})
Expand Down

0 comments on commit f175aee

Please sign in to comment.