Skip to content

Commit

Permalink
Change reusing of Variables (pytorch#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
bartolsthoorn authored and soumith committed Jun 6, 2017
1 parent dc10cd8 commit 1c6d9d2
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,6 @@ def forward(self, input):
input, label = input.cuda(), label.cuda()
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

# setup optimizer
Expand All @@ -216,21 +213,25 @@ def forward(self, input):
netD.zero_grad()
real_cpu, _ = data
batch_size = real_cpu.size(0)
input.data.resize_(real_cpu.size()).copy_(real_cpu)
label.data.resize_(batch_size).fill_(real_label)

output = netD(input)
errD_real = criterion(output, label)
if opt.cuda:
real_cpu = real_cpu.cuda()
input.resize_as_(real_cpu).copy_(real_cpu)
label.resize_(batch_size).fill_(real_label)
inputv = Variable(input)
labelv = Variable(label)

output = netD(inputv)
errD_real = criterion(output, labelv)
errD_real.backward()
D_x = output.data.mean()

# train with fake
noise.data.resize_(batch_size, nz, 1, 1)
noise.data.normal_(0, 1)
fake = netG(noise)
label.data.fill_(fake_label)
noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
noisev = Variable(noise)
fake = netG(noisev)
labelv = Variable(label.fill_(fake_label))
output = netD(fake.detach())
errD_fake = criterion(output, label)
errD_fake = criterion(output, labelv)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
Expand All @@ -240,9 +241,9 @@ def forward(self, input):
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.data.fill_(real_label) # fake labels are real for generator cost
labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost
output = netD(fake)
errG = criterion(output, label)
errG = criterion(output, labelv)
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
Expand Down

0 comments on commit 1c6d9d2

Please sign in to comment.