forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdownload_mnist.py
93 lines (79 loc) · 2.77 KB
/
download_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import gzip
import os
import sys
from urllib.error import URLError
from urllib.request import urlretrieve
# Base URLs hosting the MNIST archives. download() appends a resource file
# name to each entry and tries them in list order, stopping at the first
# mirror that succeeds, so every entry must end with a trailing slash.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]
# File names of the gzip archives fetched by main(); each one is joined onto
# a mirror URL for download and onto the destination directory for storage.
# unzip() strips the final ".gz" extension to name the extracted file.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]
def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a 64-column text progress bar on stdout.

    Intended as the ``reporthook`` callback of ``urlretrieve``, which calls
    it with the number of chunks transferred so far, the chunk size in
    bytes, and the total size of the file.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total file size in bytes. A non-positive value (``-1``
            when the size is unknown, or ``0`` for an empty body) means no
            meaningful percentage exists, so nothing is drawn.
    """
    # Guard against unknown (-1) AND zero sizes; dividing by a zero
    # file size would raise ZeroDivisionError in the middle of a download.
    if file_size <= 0:
        return

    # The last chunk usually overshoots the file size, so clamp to 100%.
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    # "\r" returns to the start of the line so the bar is redrawn in place.
    sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
    # No newline is written, so flush explicitly or a buffered stdout may
    # not show the bar until the download has finished.
    sys.stdout.flush()
def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Fetch ``resource`` from the first working mirror into ``destination_path``.

    If ``destination_path`` already exists nothing is downloaded. Otherwise
    each entry of ``MIRRORS`` is tried in order until one succeeds.

    The data is first written to ``destination_path + ".part"`` and only
    renamed to ``destination_path`` after the transfer completed. This way an
    interrupted or truncated transfer never leaves a partial file under the
    final name, which a later run would mistake for a finished download.

    Args:
        destination_path: Where the downloaded file should be stored.
        resource: File name appended to each mirror base URL.
        quiet: If true, suppress the "already exists" message, the progress
            bar and the trailing newline. The "Downloading"/"Failed" messages
            are printed regardless.

    Raises:
        RuntimeError: If the resource could not be fetched from any mirror.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print("{} already exists, skipping ...".format(destination_path))
        return

    partial_path = destination_path + ".part"
    hook = None if quiet else report_download_progress
    try:
        for mirror in MIRRORS:
            url = mirror + resource
            print("Downloading {} ...".format(url))
            try:
                urlretrieve(url, partial_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                print("Failed to download (trying next):\n{}".format(e))
                continue
            finally:
                if not quiet:
                    # Just a newline, to end the in-place progress bar line.
                    print()
            # Transfer finished: publish the file under its final name.
            os.replace(partial_path, destination_path)
            return
        raise RuntimeError("Error downloading resource!")
    finally:
        # Reached on success (nothing left to remove), on failure of every
        # mirror, and on interruption: never leave a partial file behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress the gzip file ``zipped_path`` next to itself.

    The output path is ``zipped_path`` with its last extension removed
    (``foo.gz`` -> ``foo``). If that file already exists nothing is done.

    The data is streamed into ``<output>.part`` and renamed to the final name
    only after decompression succeeded, so a corrupt or truncated archive
    never leaves a partial output file that a later run would skip over.

    Args:
        zipped_path: Path of the gzip-compressed input file.
        quiet: If true, do not print status messages.

    Raises:
        OSError: If the archive cannot be opened or is not valid gzip data.
    """
    # Local import: the module-level imports of this file do not include it.
    import shutil

    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print("{} already exists, skipping ... ".format(unzipped_path))
        return

    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file:
            with open(partial_path, "wb") as partial_file:
                # Copy in chunks instead of holding the whole decompressed
                # payload in memory at once.
                shutil.copyfileobj(zipped_file, partial_file)
        os.replace(partial_path, unzipped_path)
    finally:
        # Only still present if decompression failed or was interrupted.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    if not quiet:
        print("Unzipped {} ...".format(zipped_path))
def main() -> None:
    """Command-line entry point: download and unpack every MNIST archive.

    Parses ``-d/--destination`` (target directory, default ``"."``) and
    ``-q/--quiet`` from the command line, creates the target directory if it
    is missing, then downloads and unzips each entry of ``RESOURCES`` in
    turn. A Ctrl-C during that loop is reported as "Interrupted" instead of
    surfacing as a traceback.
    """
    cli = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    cli.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    cli.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    args = cli.parse_args()
    target_dir = args.destination
    quiet = args.quiet

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    try:
        for archive_name in RESOURCES:
            archive_path = os.path.join(target_dir, archive_name)
            # Fetch the archive first, then extract it next to itself.
            download(archive_path, archive_name, quiet)
            unzip(archive_path, quiet)
    except KeyboardInterrupt:
        print("Interrupted")
# Run the downloader only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()