Skip to content

Commit

Permalink
fix for 0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Aug 6, 2017
1 parent 5c2b513 commit 9012fae
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
Expand Down Expand Up @@ -77,8 +77,10 @@
transforms.Scale(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
)
]))
elif opt.dataset == 'fake':
dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
transform=transforms.ToTensor())
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))
Expand Down Expand Up @@ -173,7 +175,7 @@ def forward(self, input):
else:
output = self.main(input)

return output.view(-1, 1)
return output.view(-1, 1).squeeze(1)


netD = _netD(ngpu)
Expand Down

0 comments on commit 9012fae

Please sign in to comment.