forked from Sleepychord/ImprovedGAN-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDatasets.py
43 lines (41 loc) · 1.56 KB
/
Datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch.utils.data import TensorDataset
from torchvision import datasets, transforms
import numpy as np
def MnistLabel(class_num):
    """Build a small, class-balanced labelled subset of the MNIST training set.

    The training split is visited in a random order (one draw from NumPy's
    global random state) and the first ``class_num`` examples encountered
    for each of the 10 digit classes are kept. Scanning stops as soon as
    ``10 * class_num`` examples have been collected.

    Args:
        class_num: Number of labelled examples to keep per digit class.

    Returns:
        A ``TensorDataset`` pairing a float tensor of the selected images
        with a long tensor of their labels, in the order they were drawn.
    """
    num_classes = 10
    raw_dataset = datasets.MNIST(
        '../data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Note: Normalize((0.1307,), (0.3081,)) is disabled here.
        ]))

    target_total = num_classes * class_num
    per_class = [0] * num_classes
    data = []
    labels = []

    # A single permutation of all indices gives a uniformly random visiting
    # order without replacement.
    for index in np.random.permutation(len(raw_dataset)):
        datum, label = raw_dataset[int(index)]
        # Depending on the torchvision version the label may be a 0-dim
        # tensor; coerce it so it is a plain list index / array element.
        label = int(label)
        if per_class[label] >= class_num:
            continue  # this class already has its quota
        data.append(datum.numpy())
        labels.append(label)
        per_class[label] += 1
        if len(labels) >= target_total:
            break  # every class is full, no need to scan further

    # Explicit dtypes keep the tensor constructors independent of NumPy's
    # platform-dependent default integer type.
    return TensorDataset(
        torch.FloatTensor(np.array(data, dtype=np.float32)),
        torch.LongTensor(np.array(labels, dtype=np.int64)))
def MnistUnlabel():
    """Return the complete MNIST training split with a ToTensor transform.

    The dataset is stored under ``'../data'`` and downloaded if missing.
    Judging by the name, callers treat it as the unlabelled pool; the
    dataset object itself is returned unmodified.
    """
    # Note: Normalize((0.1307,), (0.3081,)) is disabled here.
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    return datasets.MNIST('../data', train=True, download=True,
                          transform=to_tensor_only)
def MnistTest():
    """Return the MNIST test split (``train=False``) with a ToTensor transform.

    The dataset is stored under ``'../data'`` and downloaded if missing.
    """
    # Note: Normalize((0.1307,), (0.3081,)) is disabled here.
    pipeline = transforms.Compose([transforms.ToTensor()])
    test_split = datasets.MNIST('../data', train=False, download=True,
                                transform=pipeline)
    return test_split
if __name__ == '__main__':
    # Debug aid: list the attributes of the test dataset object.
    # The call form of print works on both Python 2 and Python 3; the
    # statement form is a SyntaxError on Python 3.
    print(dir(MnistTest()))