test_valid_subset_checks.py (from a fork of facebookresearch/fairseq)

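"""Tests for fairseq's valid-subset checks: configs that would silently skip
validation data present on disk should raise, while the opt-out flags
(--ignore-unused-valid-subsets, --combine-valid-subsets, --disable-validation)
should suppress the error."""
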
import os
import shutil
import tempfile
import unittest

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored

from .utils import create_dummy_data, preprocess_lm_data, train_language_model


def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
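    """Parse a minimal fairseq-train command line for a tiny language model
    and return it as an OmegaConf config."""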
    task_args = [task]
    if data_dir is not None:
        task_args += [data_dir]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            *task_args,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    return cfg


def write_empty_file(path):
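    """Create an empty file at ``path``, standing in for a dataset shard."""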
    with open(path, "w"):
        pass
    assert os.path.exists(path)


class TestValidSubsetsErrors(unittest.TestCase):
    """Test various combinations of on-disk files and command-line args,
    ensuring that errors are raised as expected."""

    def _test_case(self, paths, extra_flags):
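        """Write an empty ``<subset>.bin`` stub for each name in ``paths``
        (plus ``train``), build a config with ``extra_flags``, and run the
        valid-subset check against it."""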
        with tempfile.TemporaryDirectory() as data_dir:
            for p in paths + ["train"]:
                write_empty_file(os.path.join(data_dir, f"{p}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fixed by adding --ignore-unused-valid-subsets
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        self._test_case(
            ["valid1"], ["--valid-subset", "valid1"]
        )  # valid.bin doesn't need to be ignored

    def test_disable_validation(self):
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)


class TestCombineValidSubsets(unittest.TestCase):
    def _train(self, extra_flags):
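        """Train a tiny LM on dummy data, with ``valid1`` duplicated from
        ``valid``, and return the captured log messages."""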
        with self.assertLogs() as logs:
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)
                shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
                shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [x.message for x in logs.records]

    def test_combined(self):
        flags = ["--combine-valid-subsets"]
        logs = self._train(flags)
        assert any("valid1" in x for x in logs)  # valid1 data was loaded
        assert not any("valid1_ppl" in x for x in logs)  # metrics combined under valid

    def test_subsets(self):
        flags = ["--valid-subset", "valid,valid1"]
        logs = self._train(flags)
        assert any("valid_ppl" in x for x in logs)  # metrics reported for valid
        assert any("valid1_ppl" in x for x in logs)  # metrics reported separately for valid1
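

# Hypothetical convenience entry point (not shown in the source listing);
# lets the module be run directly instead of through a test runner.
if __name__ == "__main__":
    unittest.main()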