Skip to content

Commit

Permalink
Turn on download by default
Browse files Browse the repository at this point in the history
  • Loading branch information
kuangliu committed Jul 17, 2017
1 parent e68081e commit 1fc07cb
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Expand All @@ -65,7 +65,8 @@
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
net = MobileNet()
# net = MobileNet()
net = DPN26()

if use_cuda:
net.cuda()
Expand Down
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .vgg import *
from .dpn import *
from .lenet import *
from .resnet import *
from .resnext import *
Expand Down
4 changes: 2 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def init_params(net):
_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 86.
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
Expand Down Expand Up @@ -81,7 +81,7 @@ def progress_bar(current, total, msg=None):
sys.stdout.write(' ')

# Go back to the center of the bar.
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
sys.stdout.write('\b')
sys.stdout.write(' %d/%d ' % (current+1, total))

Expand Down

0 comments on commit 1fc07cb

Please sign in to comment.