Skip to content

Commit

Permalink
Fixed merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
himat committed Apr 21, 2017
2 parents d48665a + e7e048b commit 4a079aa
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
10 changes: 8 additions & 2 deletions discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def denseW(input, output):
return tf.Variable(tf.random_normal([input, output]))


def conv_net(x):
def conv_weights():
weights = {
'c1': filters(4, 64),
'c2': filters(64, 128),
Expand All @@ -58,7 +58,11 @@ def conv_net(x):
'c6': bias(512),
'c7': bias(1),
}
return (weights, biases)


def conv_net(x, vars):
weights, biases = vars
x = tf.reshape(x, shape=[-1, 128, 128, 4])

c1 = conv(x, weights['c1'], biases['c1'], strides=2)
Expand All @@ -76,7 +80,9 @@ def conv_net(x):
c7 = conv(c6, weights['c7'], biases['c7'], strides=2)

res = tf.reshape(c7, [-1, 1])
return tf.nn.sigmoid(res)
vars = list(weights.values())
vars.extend(biases.values())
return (tf.nn.sigmoid(res), vars)


def test_conv_net():
Expand Down
38 changes: 29 additions & 9 deletions net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,22 @@
import matplotlib.pyplot as plt

from generator import u_net
from discriminator import conv_net
from discriminator import conv_net, conv_weights

minib_size = 128
X_dim = 128*128
y_dim = 128*128
Z_dim = 100

X = tf.placeholder(tf.float32, shape=[None, X_dim])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])

""" Discriminator """
D_theta = []

""" Generator """
G_theta = []

epochs = 10
minib_size = 14
Expand All @@ -19,11 +34,11 @@
def generator(x):
    """Generate an image batch from sketch batch ``x`` using the U-Net model."""
    generated = u_net(x)
    return generated

def discriminator(x, g, W, b):
    """Score a (ground-truth, generated) pair with the conv-net discriminator.

    Args:
        x: flattened grayscale batch; reshaped to (N, 128, 128, 1).
        g: flattened 3-channel batch; reshaped to (N, 128, 128, 3).
        W: dict of conv filter weights, as produced by conv_weights().
        b: dict of conv biases, as produced by conv_weights().

    Returns:
        The result of conv_net on the channel-stacked input — per the
        updated conv_net, a (sigmoid output, variable list) tuple.
    """
    x = tf.reshape(x, [-1, 128, 128, 1])
    g = tf.reshape(g, [-1, 128, 128, 3])
    # Stack along the channel axis -> 4 input channels for conv_net.
    # Use a local name that does not shadow the module-level placeholder `y`.
    stacked = tf.concat([x, g], 3)
    return conv_net(stacked, (W, b))

def next_data_batch(minibatch_size):
    """Fetch the next training minibatch of ``minibatch_size`` examples.

    NOTE(review): unimplemented stub — currently returns None.
    """
    return None
Expand Down Expand Up @@ -65,15 +80,20 @@ def next_data_batch(minibatch_size):

# --> Add conditional stuff
G_sample = generator(X_sketch) # add conditional parameter
D_real = discriminator(X_ground_truth, X_sketch)
D_fake = discriminator(X_ground_truth, G_sample)

D_W, D_b = conv_weights()
D_theta.extend(D_W.values())
D_theta.extend(D_b.values())

D_real = discriminator(X_ground_truth, X_sketch, (D_W, D_b))
D_fake = discriminator(X_ground_truth, G_sample, (D_W, D_b))

D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake)) + tf.reduce_mean(X_sketch - D_fake)

# Apply an optimizer here to minimize the above loss functions
D_solver = tf.train.AdamOptimizer().minimize(D_loss) # --> add var_list
G_solver = tf.train.AdamOptimizer().minimize(G_loss) # --> add var_list
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list = D_theta)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list = G_theta)

theta_D = [] ### FILL THIS IN
theta_G = [] ### FILL THIS IN
Expand All @@ -97,9 +117,9 @@ def next_data_batch(minibatch_size):
X_ground_truth = 0 ## Figure out how to get sketches

_, D_loss_curr = sess.run([D_solver, D_loss],
feed_dict={X_sketch: X_sketches, X_ground_truth = X_true})
feed_dict={X_sketch: X_sketches, X_ground_truth: X_true})
_, G_loss_curr = sess.run([G_solver, G_loss],
feed_dict={X_ground_truth = X_true})
feed_dict={X_ground_truth: X_true})

# Stops background threads
coord.request_stop()
Expand Down

0 comments on commit 4a079aa

Please sign in to comment.