diff --git a/discriminator.py b/discriminator.py
index 0e1c9a5..98688ce 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -38,7 +38,7 @@ def denseW(input, output):
     return tf.Variable(tf.random_normal([input, output]))
 
 
-def conv_net(x):
+def conv_weights():
     weights = {
         'c1': filters(4, 64),
         'c2': filters(64, 128),
@@ -58,7 +58,11 @@ def conv_net(x):
         'c6': bias(512),
         'c7': bias(1),
     }
+    return (weights, biases)
+
+def conv_net(x, params):
+    weights, biases = params
 
     x = tf.reshape(x, shape=[-1, 128, 128, 4])
     c1 = conv(x, weights['c1'], biases['c1'], strides=2)
@@ -76,7 +80,9 @@ def conv_net(x):
     c7 = conv(c6, weights['c7'], biases['c7'], strides=2)
 
     res = tf.reshape(c7, [-1, 1])
-    return tf.nn.sigmoid(res)
+    net_vars = list(weights.values())
+    net_vars.extend(biases.values())
+    return (tf.nn.sigmoid(res), net_vars)
 
 
 def test_conv_net():
diff --git a/net.py b/net.py
index 785ea75..8fbf01b 100644
--- a/net.py
+++ b/net.py
@@ -4,7 +4,21 @@ import matplotlib.pyplot as plt
 
 from generator import u_net
-from discriminator import conv_net
+from discriminator import conv_net, conv_weights
+
+X_dim = 128*128
+y_dim = 128*128
+Z_dim = 100
+
+X = tf.placeholder(tf.float32, shape=[None, X_dim])
+y = tf.placeholder(tf.float32, shape=[None, y_dim])
+Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
+
+""" Discriminator """
+D_theta = []
+
+""" Generator """
+G_theta = []
 
 epochs = 10
 minib_size = 14
 
@@ -19,11 +33,12 @@ def generator(x):
     return u_net(x)
 
-def discriminator(x, g):
+def discriminator(x, g, W, b):
     x = tf.reshape(x, [-1, 128, 128, 1])
     g = tf.reshape(g, [-1, 128, 128, 3])
     y = tf.concat([x, g], 3)
-    return conv_net(y)
+    out, _ = conv_net(y, (W, b))
+    return out
 
 def next_data_batch(minibatch_size):
     pass
 
@@ -65,15 +80,18 @@ def next_data_batch(minibatch_size):
 # --> Add conditional stuff
 G_sample = generator(X_sketch)  # add conditional parameter
-D_real = discriminator(X_ground_truth, X_sketch)
-D_fake = discriminator(X_ground_truth, G_sample)
+
+D_W, D_b = conv_weights()
+D_theta.extend(D_W.values())
+D_theta.extend(D_b.values())
+
+D_real = discriminator(X_ground_truth, X_sketch, D_W, D_b)
+D_fake = discriminator(X_ground_truth, G_sample, D_W, D_b)
 
 D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
 G_loss = -tf.reduce_mean(tf.log(D_fake)) + tf.reduce_mean(X_sketch - D_fake)
 
 # Apply an optimizer here to minimize the above loss functions
-D_solver = tf.train.AdamOptimizer().minimize(D_loss)  # --> add var_list
-G_solver = tf.train.AdamOptimizer().minimize(G_loss)  # --> add var_list
+D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=D_theta)
+G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=G_theta)  # --> populate G_theta with the u_net variables
 
-theta_D = []  ### FILL THIS IN
-theta_G = []  ### FILL THIS IN
 
@@ -97,9 +115,9 @@
         X_ground_truth = 0  ## Figure out how to get sketches
 
         _, D_loss_curr = sess.run([D_solver, D_loss],
-                feed_dict={X_sketch: X_sketches, X_ground_truth = X_true})
+                feed_dict={X_sketch: X_sketches, X_ground_truth: X_true})
 
         _, G_loss_curr = sess.run([G_solver, G_loss],
-                feed_dict={X_ground_truth = X_true})
+                feed_dict={X_sketch: X_sketches, X_ground_truth: X_true})
 # Stops background threads
 coord.request_stop()
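
Note for reviewers on the pattern this patch lands: conv_weights() creates the discriminator's variables exactly once, both discriminator calls reuse them (so D_real and D_fake are weight-tied), and each optimizer gets an explicit var_list so a D step cannot move the generator's weights, which D_loss reaches through D_fake -> G_sample. The sketch below shows the same pattern in miniature against the TensorFlow 1.x API the repo already uses; the toy_* names and 2-D shapes are illustrative only, not code from this repo.

    import numpy as np
    import tensorflow as tf

    def toy_weights():
        # Variables are created exactly once, outside the network function,
        # so every call to toy_net reuses the same parameters.
        W = tf.Variable(tf.random_normal([2, 1]))
        b = tf.Variable(tf.zeros([1]))
        return W, b

    def toy_net(x, params):
        W, b = params
        return tf.nn.sigmoid(tf.matmul(x, W) + b)

    x_real = tf.placeholder(tf.float32, shape=[None, 2])
    x_fake = tf.placeholder(tf.float32, shape=[None, 2])

    W, b = toy_weights()
    d_real = toy_net(x_real, (W, b))  # same W and b in both calls,
    d_fake = toy_net(x_fake, (W, b))  # exactly like D_real / D_fake

    d_loss = -tf.reduce_mean(tf.log(d_real) + tf.log(1. - d_fake))

    # var_list pins the update to these variables only; in net.py this is
    # what keeps the D step from also stepping the u_net weights.
    d_solver = tf.train.AdamOptimizer().minimize(d_loss, var_list=[W, b])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _, loss = sess.run([d_solver, d_loss],
                           feed_dict={x_real: np.ones((4, 2), dtype=np.float32),
                                      x_fake: np.zeros((4, 2), dtype=np.float32)})

Two consequences worth flagging. First, G_theta is still empty, and TF1's minimize() raises a ValueError when handed an empty var_list, so G_solver will not build until the u_net variables are collected into G_theta. Second, the G step's feed_dict must also supply X_sketch, since G_loss depends on X_sketch both directly and through G_sample = generator(X_sketch); the patch above includes that fix.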