Skip to content

Commit

Permalink
Add support for MNIST
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarras committed Feb 3, 2021
1 parent f0a4246 commit 1d25833
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pickle
import sys
import tarfile
import gzip
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
Expand Down Expand Up @@ -165,6 +166,36 @@ def iterate_images():

#----------------------------------------------------------------------------

def open_mnist(images_gz: str, *, max_images: Optional[int]):
    """Open the MNIST training set for conversion.

    Args:
        images_gz: Path to the gzip-compressed image file. The matching label
            file is located by substituting '-images-idx3-ubyte.gz' with
            '-labels-idx1-ubyte.gz' in this path.
        max_images: Optional cap on the number of images; it is combined with
            the dataset size through maybe_min().

    Returns:
        A tuple (max_idx, iterator). The iterator yields one dict per image
        with keys 'img' (32x32 uint8 array) and 'label' (int), and produces
        at most max_idx items.
    """
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    # If the substitution changed nothing we would read the image file as labels.
    assert labels_gz != images_gz

    # Skip the 16-byte header of the image file and the 8-byte header of the
    # label file; everything after that is raw uint8 payload.
    with gzip.open(images_gz, 'rb') as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(labels_gz, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    # Zero-pad each 28x28 digit by 2 pixels per side to get 32x32 images.
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        # Bound the loop up front so that max_images <= 0 yields nothing,
        # keeping the iterator consistent with the returned count.
        limit = max(max_idx, 0)
        for img, label in zip(images[:limit], labels[:limit]):
            yield dict(img=img, label=int(label))

    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def make_transform(
transform: Optional[str],
output_width: Optional[int],
Expand Down Expand Up @@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]):
else:
return open_image_folder(source, max_images=max_images)
elif os.path.isfile(source):
if source.endswith('cifar-10-python.tar.gz'):
if os.path.basename(source) == 'cifar-10-python.tar.gz':
return open_cifar10(source, max_images=max_images)
ext = file_ext(source)
if ext == 'zip':
elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
return open_mnist(source, max_images=max_images)
elif file_ext(source) == 'zip':
return open_image_zip(source, max_images=max_images)
else:
assert False, 'unknown archive type'
Expand Down Expand Up @@ -293,17 +325,18 @@ def convert_dataset(
The input dataset format is guessed from the --source argument:
\b
--source *_lmdb/ - Load LSUN dataset
--source cifar-10-python.tar.gz - Load CIFAR-10 dataset
--source path/ - Recursively load all images from path/
--source dataset.zip - Recursively load all images from dataset.zip
--source *_lmdb/ Load LSUN dataset
--source cifar-10-python.tar.gz Load CIFAR-10 dataset
--source train-images-idx3-ubyte.gz Load MNIST dataset
--source path/ Recursively load all images from path/
--source dataset.zip Recursively load all images from dataset.zip
The output dataset format can be either an image folder or a zip archive. Specifying
the output format and path:
The output dataset format can be either an image folder or a zip archive.
Specifying the output format and path:
\b
--dest /path/to/dir - Save output files under /path/to/dir
--dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive
--dest /path/to/dir Save output files under /path/to/dir
--dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
Images within the dataset archive will be stored as uncompressed PNG.
Expand Down

0 comments on commit 1d25833

Please sign in to comment.