Fix random seeding in DCGAN (pytorch#108)
* fix seeding

* fix typo

* fix some code-style issues found by flake8
DmitryUlyanov authored and soumith committed Mar 20, 2017
1 parent b8cacb0 · commit a60bd4e
Showing 1 changed file with 18 additions and 6 deletions.
dcgan/main.py (18 additions, 6 deletions)
@@ -26,11 +26,12 @@
 parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
 parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
 parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
-parser.add_argument('--cuda' , action='store_true', help='enables cuda')
-parser.add_argument('--ngpu' , type=int, default=1, help='number of GPUs to use')
+parser.add_argument('--cuda', action='store_true', help='enables cuda')
+parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
 parser.add_argument('--netG', default='', help="path to netG (to continue training)")
 parser.add_argument('--netD', default='', help="path to netD (to continue training)")
 parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
+parser.add_argument('--manualSeed', type=int, help='manual seed')

 opt = parser.parse_args()
 print(opt)
@@ -39,10 +40,14 @@
     os.makedirs(opt.outf)
 except OSError:
     pass
-opt.manualSeed = random.randint(1, 10000) # fix seed
+
+if opt.manualSeed is None:
+    opt.manualSeed = random.randint(1, 10000)
 print("Random Seed: ", opt.manualSeed)
 random.seed(opt.manualSeed)
 torch.manual_seed(opt.manualSeed)
+if opt.cuda:
+    torch.cuda.manual_seed_all(opt.manualSeed)

 cudnn.benchmark = True
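This hunk is the heart of the fix: the seed is now randomized only when no --manualSeed is given, it is printed so a run can be reproduced, and it is applied to Python's RNG, the CPU RNG, and (with --cuda) every GPU RNG. A minimal standalone sketch of the same pattern (structure is illustrative; the real logic lives in dcgan/main.py):

    import argparse
    import random

    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    opt = parser.parse_args()

    # Randomize only when no seed was supplied, and print the value so a
    # good run can be reproduced by passing it back in next time.
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)

    random.seed(opt.manualSeed)        # Python RNG (e.g. shuffling)
    torch.manual_seed(opt.manualSeed)  # CPU RNG (weight init, noise)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)  # all GPU RNGs

Two runs with the same --manualSeed should then start from identical weights and noise vectors; previously the seed was unconditionally randomized on every run, so results could not be pinned down.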
@@ -84,6 +89,7 @@
 ndf = int(opt.ndf)
 nc = 3

+
 # custom weights initialization called on netG and netD
 def weights_init(m):
     classname = m.__class__.__name__
@@ -93,6 +99,7 @@ def weights_init(m):
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)

+
 class _netG(nn.Module):
     def __init__(self, ngpu):
         super(_netG, self).__init__()
@@ -119,18 +126,21 @@ def __init__(self, ngpu):
             nn.Tanh()
             # state size. (nc) x 64 x 64
         )
+
     def forward(self, input):
         gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
             gpu_ids = range(self.ngpu)
         return nn.parallel.data_parallel(self.main, input, gpu_ids)

+
 netG = _netG(ngpu)
 netG.apply(weights_init)
 if opt.netG != '':
     netG.load_state_dict(torch.load(opt.netG))
 print(netG)

+
 class _netD(nn.Module):
     def __init__(self, ngpu):
         super(_netD, self).__init__()
@@ -155,13 +165,15 @@ def __init__(self, ngpu):
             nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
             nn.Sigmoid()
         )
+
     def forward(self, input):
         gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
             gpu_ids = range(self.ngpu)
         output = nn.parallel.data_parallel(self.main, input, gpu_ids)
         return output.view(-1, 1)

+
 netD = _netD(ngpu)
 netD.apply(weights_init)
 if opt.netD != '':
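The blank lines added in this and the previous three hunks are flake8 fixes: PEP 8 expects one blank line before a method and two around top-level classes and functions. Separately, for readers unfamiliar with the multi-GPU dispatch both forward() methods use: nn.parallel.data_parallel replicates a module on the listed GPUs, splits the input batch along dimension 0, and gathers the per-GPU outputs. A rough sketch in current PyTorch style (hypothetical toy module; the diff itself uses the 2017 Variable API):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the models' self.main.
    net = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1), nn.ReLU())
    batch = torch.randn(16, 3, 64, 64)

    ngpu = torch.cuda.device_count()
    if ngpu > 1:
        # Replicate net on GPUs 0..ngpu-1, split the batch along dim 0,
        # run the chunks in parallel, and gather the outputs on GPU 0.
        out = nn.parallel.data_parallel(net.cuda(), batch.cuda(),
                                        list(range(ngpu)))
    else:
        out = net(batch)  # CPU / single-GPU fallback, as in the diff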
@@ -190,8 +202,8 @@ def forward(self, input):
 fixed_noise = Variable(fixed_noise)

 # setup optimizer
-optimizerD = optim.Adam(netD.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
-optimizerG = optim.Adam(netG.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
+optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

 for epoch in range(opt.niter):
     for i, data in enumerate(dataloader, 0):
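This hunk is another of the flake8 fixes: PEP 8 disallows spaces around `=` when it introduces a keyword argument, so `lr = opt.lr` becomes `lr=opt.lr`. The optimizers themselves are unchanged.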
@@ -226,7 +238,7 @@ def forward(self, input):
         # (2) Update G network: maximize log(D(G(z)))
         ###########################
         netG.zero_grad()
-        label.data.fill_(real_label) # fake labels are real for generator cost
+        label.data.fill_(real_label)  # fake labels are real for generator cost
         output = netD(fake)
         errG = criterion(output, label)
         errG.backward()
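The comment-spacing fix above lands on the most subtle line of the training loop: when updating G, the fake batch is deliberately labeled real, turning the BCE criterion into the standard non-saturating generator loss (maximize log D(G(z)) rather than minimize log(1 - D(G(z)))). A small numeric sketch with made-up discriminator outputs:

    import torch
    import torch.nn as nn

    criterion = nn.BCELoss()

    # Suppose D(G(z)) scored a batch of fakes as mostly fake:
    output = torch.tensor([0.10, 0.30, 0.20])
    label = torch.full_like(output, 1.0)  # fake labels are "real" for G's cost

    errG = criterion(output, label)  # = -mean(log D(G(z))) ~ 1.71
    # Early in training D(G(z)) is near 0, so this loss stays large and its
    # gradient strong, whereas log(1 - D(G(z))) would be nearly flat there.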
