forked from zyang1580/PDA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_iterator.py
224 lines (170 loc) · 6.77 KB
/
data_iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
@author: Zhongchuan Sun
"""
import numpy as np
class Sampler:
    """Abstract base for index samplers.

    A concrete sampler must override ``__iter__`` to produce dataset
    indices and ``__len__`` to report how many indices one pass yields.
    Both raise ``NotImplementedError`` here.
    """

    def __init__(self):
        """No state is kept by the base class."""

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Yield the indices ``0 .. len(data_source) - 1`` in ascending order."""

    def __init__(self, data_source):
        """Create a sampler bound to ``data_source``.

        Args:
            data_source: Object supporting ``len()``; only its length is
                used by this sampler.
        """
        super().__init__()
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # The length is read on every call, so each pass reflects the
        # current size of the data source.
        size = len(self.data_source)
        return iter(range(size))
class RandomSampler(Sampler):
    """Yield every index of the data source once, in random order."""

    def __init__(self, data_source):
        """Create a `RandomSampler` bound to ``data_source``.

        Args:
            data_source: Object supporting ``len()``; only its length is
                used by this sampler.
        """
        super().__init__()
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # A new permutation is drawn from NumPy's global RNG on every call,
        # so each pass over the sampler sees a different order.
        size = len(self.data_source)
        shuffled = np.random.permutation(size)
        return iter(shuffled.tolist())
class BatchSampler(Sampler):
    """Wraps another sampler to yield mini-batches (lists) of indices."""

    def __init__(self, sampler, batch_size, drop_last):
        """Initializes a new `BatchSampler` instance.

        Args:
            sampler (Sampler): Base sampler that produces single indices.
            batch_size (int): Size of each mini-batch; must be a positive
                ``int`` (``bool`` is rejected).
            drop_last (bool): If `True`, the final batch is dropped when it
                holds fewer than `batch_size` indices.

        Raises:
            ValueError: If any argument fails the checks above.
        """
        super(BatchSampler, self).__init__()
        if not isinstance(sampler, Sampler):
            # The check is against this module's own Sampler base class,
            # so the message names it rather than the PyTorch class.
            raise ValueError("sampler should be an instance of Sampler, "
                             "but got sampler={}".format(sampler))
        # bool is a subclass of int, so it has to be excluded explicitly.
        if (not isinstance(batch_size, int) or isinstance(batch_size, bool)
                or batch_size <= 0):
            raise ValueError("batch_size should be a positive integral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of at most `batch_size` indices from the base sampler."""
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Whatever is left over is an incomplete batch.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the number of batches one pass over the sampler yields."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: an incomplete trailing batch still counts.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
class _Dataset(object):
    """Pack several equally long sequences into one indexable dataset.

    Args:
        data (list or tuple): The sequences to pack. All of them must have
            the same length. May be empty, in which case the dataset has
            length 0.

    Raises:
        ValueError: If the sequences do not all have the same length.
    """

    def __init__(self, data):
        if len(data) > 0:
            # Every sequence is compared against the first one's length.
            expected = len(data[0])
            for d in data:
                if len(d) != expected:
                    raise ValueError("The length of the given data are not equal!")
        self.data = data

    def __len__(self):
        """Return the common length of the sequences (0 if there are none)."""
        if len(self.data) == 0:
            # Without this guard, self.data[0] would raise IndexError.
            return 0
        return len(self.data[0])

    def __getitem__(self, idx):
        """Return a list holding element ``idx`` of every packed sequence."""
        return [data[idx] for data in self.data]
class _DataLoaderIter:
    """Single-pass iterator over a loader's batches.

    Index batches are pulled from ``loader.batch_sampler``, the matching
    samples are looked up in ``loader.dataset``, and each batch is returned
    grouped by field rather than by sample.
    """

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.batch_sampler = loader.batch_sampler
        self.sample_iter = iter(self.batch_sampler)

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        # StopIteration from the underlying iterator ends this one as well.
        batch_indices = next(self.sample_iter)
        rows = [self.dataset[i] for i in batch_indices]
        # Transpose: a list of samples becomes one list per field.
        columns = list(map(list, zip(*rows)))
        # With a single field, hand back that list itself, not a 1-item list.
        return columns[0] if len(columns) == 1 else columns
class DataIterator:
    """Batch iterator over one or more aligned sequences.

    The given sequences are packed into a single dataset and served in
    mini-batches; every batch holds one list per input sequence (or a
    single flat list when only one sequence was given).

    For example::

        users = list(range(10))
        items = list(range(10, 20))
        labels = list(range(20, 30))

        data_iter = DataIterator(users, items, labels, batch_size=4, shuffle=False)
        for bat_user, bat_item, bat_label in data_iter:
            print(bat_user, bat_item, bat_label)

        data_iter = DataIterator(users, items, batch_size=4, shuffle=True, drop_last=True)
        for bat_user, bat_item in data_iter:
            print(bat_user, bat_item)
    """

    def __init__(self, *data, batch_size=1, shuffle=False, drop_last=False):
        """
        Args:
            *data: The sequences to iterate over together; all must have
                the same length.
            batch_size (int): Number of samples per batch. Defaults to `1`.
            shuffle (bool): If `True`, indices are drawn in a fresh random
                order on every pass; otherwise they are visited in order.
                Defaults to `False`.
            drop_last (bool): If `True`, a trailing batch smaller than
                `batch_size` is discarded; if `False`, it is yielded as is.
                Defaults to `False`.

        Raises:
            ValueError: If the given sequences differ in length, or if
                `batch_size` / `drop_last` are rejected by `BatchSampler`.
        """
        self.dataset = _Dataset(list(data))
        self.batch_size = batch_size
        self.drop_last = drop_last

        sampler_cls = RandomSampler if shuffle else SequentialSampler
        index_sampler = sampler_cls(self.dataset)
        self.batch_sampler = BatchSampler(index_sampler, batch_size, drop_last)

    def __len__(self):
        # Number of batches per pass, as reported by the batch sampler.
        return len(self.batch_sampler)

    def __iter__(self):
        # Each call starts a new, independent pass over the data.
        return _DataLoaderIter(self)
if __name__ == "__main__":
    # Three aligned toy columns of ten entries each: 0-9, 10-19 and 20-29.
    users, items, labels = (
        list(range(start, start + 10)) for start in (0, 10, 20)
    )

    # Ordered pass over all three columns; the short last batch is kept.
    data_iter = DataIterator(users, items, labels, batch_size=4, shuffle=False)
    for bat_user, bat_item, bat_label in data_iter:
        print(bat_user, bat_item, bat_label)

    # Shuffled pass over two columns; the incomplete last batch is dropped.
    data_iter = DataIterator(users, items, batch_size=4, shuffle=True, drop_last=True)
    for bat_user, bat_item in data_iter:
        print(bat_user, bat_item)