forked from open-mmlab/mmengine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_test_aug_time.py
117 lines (88 loc) · 3.63 KB
/
test_test_aug_time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
from torch.utils.data import DataLoader, Dataset
from mmengine.dataset.utils import pseudo_collate
from mmengine.model import BaseModel, BaseTTAModel
from mmengine.registry import DATASETS, MODELS, TRANSFORMS
from mmengine.testing import RunnerTestCase
class ToyTTAPipeline:
    """Wrap every field of a sample in a one-element list.

    Mimics a test-time-augmentation pipeline that yields a single
    "view" per sample, the structure ``BaseTTAModel`` consumes.
    """

    def __call__(self, result):
        wrapped = {}
        for key, value in result.items():
            wrapped[key] = [value]
        return wrapped
class ToyTestTimeAugModel(BaseTTAModel):
    """Toy TTA wrapper that fuses the predictions of all views by summing."""

    def merge_preds(self, data_samples_list):
        # Each entry holds the predictions for every augmented view of one
        # sample; collapse each group into a single value via summation.
        merged = []
        for per_sample_preds in data_samples_list:
            merged.append(sum(per_sample_preds))
        return merged
class ToyModel(BaseModel):
    """Pass-through model used as the wrapped module in the TTA tests.

    ``forward`` simply echoes ``data_samples`` back, so the surrounding
    TTA machinery can be exercised without any real inference.
    """

    def __init__(self):
        super().__init__()
        # DDPWrapper requires at least one parameter.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        # No computation: the predictions are the samples themselves.
        return data_samples
class ToyDatasetTTA(Dataset):
    """Toy map-style dataset producing ``dict(inputs=..., data_samples=...)``.

    Args:
        pipeline (dict | callable | None): Per-sample transform. A config
            dict is built through the ``TRANSFORMS`` registry; a callable is
            used as-is; ``None`` (the default) means identity. The previous
            implementation passed ``None`` straight to ``TRANSFORMS.build``,
            which raised and made the declared default unusable.
    """
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    def __init__(self, pipeline=None):
        if pipeline is None:
            # Identity transform: keeps the dataset usable without a
            # registered pipeline (the old default crashed here).
            self.pipeline = lambda result: result
        elif callable(pipeline):
            self.pipeline = pipeline
        else:
            self.pipeline = TRANSFORMS.build(pipeline)

    @property
    def metainfo(self):
        # Mirrors mmengine's dataset interface used by the Runner.
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        result = dict(inputs=self.data[index], data_samples=self.label[index])
        result = self.pipeline(result)
        return result
class TestBaseTTAModel(RunnerTestCase):
    """Tests for ``BaseTTAModel``: construction, ``test_step`` and Runner use."""

    def setUp(self) -> None:
        super().setUp()
        # Register the toy components so configs can refer to them by name.
        registrations = (
            (DATASETS, ToyDatasetTTA),
            (MODELS, ToyTestTimeAugModel),
            (MODELS, ToyModel),
            (TRANSFORMS, ToyTTAPipeline),
        )
        for registry, module in registrations:
            registry.register_module(module=module, force=True)

    def tearDown(self):
        super().tearDown()
        # Undo the registrations to keep the global registries clean.
        removals = (
            (DATASETS, 'ToyDatasetTTA'),
            (MODELS, 'ToyTestTimeAugModel'),
            (MODELS, 'ToyModel'),
            (TRANSFORMS, 'ToyTTAPipeline'),
        )
        for registry, name in removals:
            registry.module_dict.pop(name, None)

    def test_test_step(self):
        tta_model = ToyTestTimeAugModel(ToyModel())
        dict_samples = [
            dict(inputs=[1, 2], data_samples=[3, 4]) for _ in range(10)
        ]
        tuple_samples = [([1, 2], [3, 4]) for _ in range(10)]
        # Both dict- and tuple-shaped samples must merge to 3 + 4 == 7.
        for samples in (dict_samples, tuple_samples):
            loader = DataLoader(
                samples, batch_size=2, collate_fn=pseudo_collate)
            for batch in loader:
                self.assertEqual(tta_model.test_step(batch), [7, 7])

    def test_init(self):
        # Wrapping an instance keeps the very same object as ``module``.
        inner = ToyModel()
        self.assertIs(ToyTestTimeAugModel(inner).module, inner)
        # Wrapping a config dict builds the module through the registry.
        built = ToyTestTimeAugModel(dict(type='ToyModel'))
        self.assertIsInstance(built.module, ToyModel)

    def test_with_runner(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.model = dict(
            type='ToyTestTimeAugModel', module=dict(type='ToyModel'))
        cfg.test_dataloader.dataset = dict(type='ToyDatasetTTA')
        cfg.test_dataloader.dataset['pipeline'] = dict(type='ToyTTAPipeline')
        self.build_runner(cfg).test()
        # Repeat under distributed launch when NCCL + CUDA are available.
        if torch.cuda.is_available() and torch.distributed.is_nccl_available():
            cfg.launcher = 'pytorch'
            self.setup_dist_env()
            self.build_runner(cfg).test()