Skip to content

Commit

Permalink
fix filenames again, fix non-gzip loader, add gzip test
Browse files Browse the repository at this point in the history
  • Loading branch information
sorki committed Jan 9, 2018
1 parent 316c5b0 commit 766faa3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mnist/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def __init__(self, path='.', mode='vanilla', return_type='lists', gz=False):

self._return_type = return_type

self.test_img_fname = 't10k-images.idx3-ubyte'
self.test_lbl_fname = 't10k-labels.idx1-ubyte'
self.test_img_fname = 't10k-images-idx3-ubyte'
self.test_lbl_fname = 't10k-labels-idx1-ubyte'

self.train_img_fname = 'train-images.idx3-ubyte'
self.train_lbl_fname = 'train-labels.idx1-ubyte'
self.train_img_fname = 'train-images-idx3-ubyte'
self.train_lbl_fname = 'train-labels-idx1-ubyte'

self.gz = gz

Expand Down Expand Up @@ -167,7 +167,7 @@ def opener(self, path_fn, *args, **kwargs):
if self.gz:
return gzip.open(path_fn + '.gz', *args, **kwargs)
else:
return open(*args, **kwargs)
return open(path_fn, *args, **kwargs)

def load(self, path_img, path_lbl):
with self.opener(path_lbl, 'rb') as file:
Expand Down
10 changes: 10 additions & 0 deletions tests/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ def test_dataset_lengths(self):
self.assertEqual(len(train_img), len(train_label))
self.assertEqual(len(train_img), 60000)

def test_gzip(self):
mn = mnist.MNIST(DATA_PATH, gz=True)

test_img, test_label = mn.load_testing()
train_img, train_label = mn.load_training()
self.assertEqual(len(test_img), len(test_label))
self.assertEqual(len(test_img), 10000)
self.assertEqual(len(train_img), len(train_label))
self.assertEqual(len(train_img), 60000)

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
unittest.main()

0 comments on commit 766faa3

Please sign in to comment.