test_resampling_dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import unittest

import numpy as np

from fairseq.data import ListDataset, ResamplingDataset


class TestResamplingDataset(unittest.TestCase):
    def setUp(self):
        # Toy dataset: each element is a string, its reported size is the
        # string length, and each string has an associated sampling weight.
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )

    def _test_common(self, resampling_dataset, iters):
        """Resample for `iters` epochs; report whether ordered_indices() kept
        items in non-decreasing size order and how far the empirical sampling
        frequencies drift from the normalized weights."""
        assert len(self.dataset) == len(self.strings) == len(self.weights)
        assert len(resampling_dataset) == self.size_ratio * len(self.strings)

        results = {"ordered_by_size": True, "max_distribution_diff": 0.0}

        totalfreqs = 0
        freqs = collections.defaultdict(int)

        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)

            indices = resampling_dataset.ordered_indices()
            assert len(indices) == len(resampling_dataset)

            prev_size = -1
            for i in indices:
                cur_size = resampling_dataset.size(i)
                # Make sure indices map to same sequences within an epoch
                assert resampling_dataset[i] == resampling_dataset[i]

                # Make sure length of sequence is correct
                assert cur_size == len(resampling_dataset[i])
                freqs[resampling_dataset[i]] += 1
                totalfreqs += 1

                if prev_size > cur_size:
                    results["ordered_by_size"] = False

                prev_size = cur_size

        assert set(freqs.keys()) == set(self.strings)
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / totalfreqs
            expected_freq = weight / sum(self.weights)
            results["max_distribution_diff"] = max(
                results["max_distribution_diff"], abs(expected_freq - freq)
            )

        return results

    def test_resampling_dataset_batch_by_size_false(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=False,
            seed=0,
        )
        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        assert not results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02

    def test_resampling_dataset_batch_by_size_true(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )
        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        assert results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02

if __name__ == "__main__":
    unittest.main()
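
# A minimal usage sketch, kept as a comment so it does not run with the tests.
# It only repeats the calls exercised above (ListDataset, ResamplingDataset,
# set_epoch, ordered_indices, size); the data and variable names here are
# illustrative, not part of the fairseq test suite.
#
#   base = ListDataset(["x", "yy"], np.array([1, 2]))
#   resampled = ResamplingDataset(
#       base, [1.0, 3.0], size_ratio=2, batch_by_size=False, seed=0
#   )
#   resampled.set_epoch(0)
#   for idx in resampled.ordered_indices():
#       print(resampled[idx], resampled.size(idx))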