Finish the gan_toy.py
caogang committed May 8, 2017
1 parent 64f73e3 · commit 2d467c8
Showing 1 changed file: gan_toy.py (17 additions, 22 deletions)
@@ -34,6 +34,7 @@
 ITERS = 100000 # how many generator iterations to train for
 use_cuda = True

+# ==================Definition Start======================

 class Generator(nn.Module):

@@ -103,11 +104,17 @@ def generate_image(true_dist):
     points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
     points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
     points = points.reshape((-1, 2))
-    samples, disc_map = session.run(
-        [fake_data, disc_real],
-        feed_dict={real_data: points}
-    )
-    disc_map = session.run(disc_real, feed_dict={real_data: points})
+
+    noise = torch.randn(BATCH_SIZE, 2)
+    if use_cuda:
+        noise = noise.cuda()
+    noisev = autograd.Variable(noise, volatile=True)
+    samples = netG(noisev).cpu().data.numpy()
+
+    points_v = autograd.Variable(torch.Tensor(points), volatile=True)
+    if use_cuda:
+        points_v = points_v.cuda()
+    disc_map = netD(points_v).cpu().data.numpy()

     plt.clf()
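Note: the deleted `session.run` calls were leftovers from the TensorFlow original this script was ported from; the added lines perform the same two evaluations natively in PyTorch. Below is a minimal standalone sketch of that inference pattern under the pre-0.4 API the file targets (`volatile=True` marks inference-only Variables). The function name `sample_and_score` is illustrative, and the `np.mgrid` one-liner is just a compact equivalent of the linspace grid above:

    import numpy as np
    import torch
    import torch.autograd as autograd

    def sample_and_score(netG, netD, batch_size=256, use_cuda=False):
        # Generator samples: latent noise -> fake 2-D points.
        noise = torch.randn(batch_size, 2)
        if use_cuda:
            noise = noise.cuda()
        noisev = autograd.Variable(noise, volatile=True)  # inference only
        samples = netG(noisev).cpu().data.numpy()

        # Critic map: score every point of a 128x128 grid over [-3, 3]^2.
        grid = np.mgrid[-3:3:128j, -3:3:128j].reshape(2, -1).T.astype('float32')
        grid_v = autograd.Variable(torch.from_numpy(grid), volatile=True)
        if use_cuda:
            grid_v = grid_v.cuda()
        disc_map = netD(grid_v).cpu().data.numpy()
        return samples, disc_map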

@@ -183,31 +190,23 @@ def calc_gradient_penalty(netD, real_data, fake_data):
     alpha = alpha.expand(real_data.size())
     alpha = alpha.cuda() if use_cuda else alpha

-    # print real_data[0]
-    # print fake_data[0]
-    # print alpha[0]
     interpolates = alpha * real_data + ((1 - alpha) * fake_data)
-    # print interpolates[0]

     if use_cuda:
         interpolates = interpolates.cuda()
     interpolates = autograd.Variable(interpolates, requires_grad=True)

     disc_interpolates = netD(interpolates)
-    # print disc_interpolates
-    # print disc_interpolates.size(), interpolates.size()

     gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                               grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else torch.ones(
                                   disc_interpolates.size()),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]

-    for p in netD.parameters():
-        print p.grad
-    # print gradients
-    # print (gradients.norm(2, dim=1) - 1) ** 2
     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
     return gradient_penalty

+
+# ==================Definition End======================

 netG = Generator()
 netD = Discriminator()
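What survives the cleanup is exactly the two-sided penalty of WGAN-GP (Gulrajani et al., 2017): interpolate between real and fake samples and push the critic's gradient norm toward 1. Here is a condensed, CPU-only restatement of the function with the debug prints gone, plus a sanity check using a toy linear critic; the unit-norm critic and the LAMBDA value are my own illustration, not code from the repo:

    import torch
    import torch.nn as nn
    import torch.autograd as autograd

    LAMBDA = 0.1  # penalty coefficient (illustrative value)

    def gradient_penalty(netD, real_data, fake_data):
        # x_hat = alpha * x_real + (1 - alpha) * x_fake, alpha ~ U[0, 1]
        alpha = torch.rand(real_data.size(0), 1).expand(real_data.size())
        interpolates = alpha * real_data + (1 - alpha) * fake_data
        interpolates = autograd.Variable(interpolates, requires_grad=True)
        disc_interpolates = netD(interpolates)
        grads = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True,
                              only_inputs=True)[0]
        return ((grads.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    # Sanity check: for D(x) = w . x with ||w||_2 = 1 the gradient norm is
    # exactly 1 everywhere, so the penalty must be (near) zero.
    netD = nn.Linear(2, 1, bias=False)
    netD.weight.data = torch.Tensor([[0.6, 0.8]])  # unit-norm weight
    real, fake = torch.randn(64, 2), torch.randn(64, 2)
    print(gradient_penalty(netD, real, fake))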
@@ -266,8 +265,6 @@ def calc_gradient_penalty(netD, real_data, fake_data):
         # train with gradient penalty
         gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
         gradient_penalty.backward()
-        for p in netD.parameters():
-            print p.grad

         D = D_fake - D_real + gradient_penalty
         optimizerD.step()
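For orientation, the deleted prints sat in the middle of a single critic update. The schematic below reconstructs that update around the lines shown in this hunk, using pre-0.4 semantics where `.mean()` returns a 1-element Variable; the stand-in linear nets, the `zero_grad` call, the real/fake passes, and the `one`/`mone` gradient seeds are assumptions consistent with standard WGAN code, not a quote of the file:

    import torch
    import torch.nn as nn
    import torch.autograd as autograd
    import torch.optim as optim

    netG, netD = nn.Linear(2, 2), nn.Linear(2, 1)   # stand-in networks
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    one = torch.FloatTensor([1])
    mone = one * -1

    real_data_v = autograd.Variable(torch.randn(64, 2))
    noisev = autograd.Variable(torch.randn(64, 2))

    netD.zero_grad()
    D_real = netD(real_data_v).mean()
    D_real.backward(mone)                         # raise D on real data

    fake = autograd.Variable(netG(noisev).data)   # block gradients into G
    D_fake = netD(fake).mean()
    D_fake.backward(one)                          # lower D on fake data

    # train with gradient penalty (the lines in this hunk)
    gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
    gradient_penalty.backward()

    D = D_fake - D_real + gradient_penalty        # logged, not backpropagated
    optimizerD.step()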
@@ -289,12 +286,10 @@ def calc_gradient_penalty(netD, real_data, fake_data):
     G.backward(mone)
     optimizerG.step()

-    # print D, G
-
     # Write logs and save samples
-    lib.plot.plot('disc cost', D)
-    lib.plot.plot('gen cost', G)
+    lib.plot.plot('disc cost', D.cpu().data.numpy())
+    lib.plot.plot('gen cost', G.cpu().data.numpy())
     if iteration % 100 == 99:
         lib.plot.flush()
-        # generate_image(_data)
+        generate_image(_data)
     lib.plot.tick()
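The plotting fix matters because `D` and `G` are autograd Variables, possibly resident on the GPU, while `lib.plot` expects plain values; unwrapping them first keeps the logger framework-agnostic. A minimal sketch of the conversion under the same pre-0.4 API (under PyTorch >= 0.4 a scalar would instead be extracted with `.item()`):

    import torch
    import torch.autograd as autograd

    cost = autograd.Variable(torch.randn(1))  # stand-in for the D or G cost
    if torch.cuda.is_available():
        cost = cost.cuda()

    # Variable -> CPU tensor -> numpy array, safe to hand to a plain logger.
    print(cost.cpu().data.numpy())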
