Skip to content

Commit

Permalink
removed two deprecated function calls, added __name__ check to addres…
Browse files Browse the repository at this point in the history
…s multithreading bug in dataloader (pytorch#414)
  • Loading branch information
jacobaustin123 authored and soumith committed Sep 21, 2018
1 parent 6fd43cd commit 753d086
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()


torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")
Expand Down Expand Up @@ -61,7 +60,7 @@ def reparameterize(self, mu, logvar):

def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))
return torch.sigmoid(self.fc4(h3))

def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
Expand All @@ -75,7 +74,7 @@ def forward(self, x):

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
Expand Down Expand Up @@ -125,12 +124,12 @@ def test(epoch):
test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))


for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)
with torch.no_grad():
sample = torch.randn(64, 20).to(device)
sample = model.decode(sample).cpu()
save_image(sample.view(64, 1, 28, 28),
'results/sample_' + str(epoch) + '.png')
if __name__ == "__main__":
for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)
with torch.no_grad():
sample = torch.randn(64, 20).to(device)
sample = model.decode(sample).cpu()
save_image(sample.view(64, 1, 28, 28),
'results/sample_' + str(epoch) + '.png')

0 comments on commit 753d086

Please sign in to comment.