# base_runner.py (forked from ShenghaiRong/BECO)

from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List

import torch

import hooks
from utils import buffer
from utils import distributed
from utils.distributed import get_dist_info


class BaseRunner(metaclass=ABCMeta):
    """Base runner class.

    Args:
        args: the CLI arguments
        model: the model object
        optimizer: the optimizer object
        scheduler: the learning-rate scheduler object
        dataloaders: a dict of dataloaders
        samplers: a dict of samplers for the dataloaders
        workflow: the workflow control sequence
    """

    def __init__(
        self,
        args,
        model,
        optimizer,
        scheduler,
        dataloaders: Dict,
        samplers: Dict,
        workflow: Dict,
        **kwargs
    ) -> None:
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.dataloaders = dataloaders
        self.samplers = samplers
        self.net_dict = distributed.skip_module_for_dict(self.model.nets)
        self.rank = get_dist_info()[0]
        self.epoch = 0
        self.iter = 0
        self.inner_iter = 0
        self._workflow = workflow

        # Buffer to store runtime information. Usually a hook exchanges data with
        # other hooks or objects asynchronously through this buffer.
        self.buffer = buffer.SimpleBuffer()
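        # Illustrative use only (values are hypothetical): a hook might publish
        # results here, e.g.
        #     runner.buffer.update_from_dict({'loss': 0.5})
        # and another hook or the logger can read them back later through
        # SimpleBuffer's accessors (the exact read API lives in utils.buffer).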

        # The attribute names listed here will be saved to the checkpoint. A child
        # class can append more elements to this list.
        self.stat_dict_keys = ['epoch', 'iter', 'inner_iter']

        # A logger to output runtime information. In this framework, `TqdmLoggerHook`
        # can be used for a better experience. The logger should be initialized in
        # child classes.
        self.logger = None

        # self.register_hook(config.hooks)
        self._hook_dict = hooks.get_hooks(self.args.logging_path,
                                          self.args.ckpt,
                                          self.args.amp)
        self._init(**kwargs)

    @abstractmethod
    def _init(self):
        """Further init operations for child classes."""
        pass

    # Main functions
    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self):
        """Entry point to start the runner."""
        pass

    @torch.no_grad()
    def test(self):
        """Perform full testing on dataloaders['test']."""
        self.change_net_val()
        self.is_train = False
        self.logger.write("Start testing...")
        # self.call_hook('before_test')
        # self._hook_dict['RestoreCkptHook'].load_ckpt(self)
        self._hook_dict['CheckpointLoadHook'].load_ckpt(self)
        if self._hook_dict['TqdmLoggerHook'] is not None:
            self._hook_dict['TqdmLoggerHook'].init_bar_iter_test(self)
        for i, data in enumerate(self.dataloaders['test']):
            # self.call_hook('before_test_iter')
            self.test_iter(data)
            self.inner_iter = i
            # self.call_hook('after_test_iter')
            if self._hook_dict['TqdmLoggerHook'] is not None:
                self._hook_dict['TqdmLoggerHook'].update_bar_iter(self)
        # self.call_hook('after_test')
        self._hook_dict['MetricHook'].get_test_metric(self)
        if self._hook_dict['TqdmLoggerHook'] is not None:
            self._hook_dict['TqdmLoggerHook'].log_test(self)
        if self._hook_dict['TBLoggerHook'] is not None:
            self._hook_dict['TBLoggerHook'].log_test(self)

    def test_iter(self, data):
        func_step = getattr(self.model, "test_step", self.model.val_step)
        if callable(func_step):
            output = func_step(data)
            if output is not None:
                self.buffer.update_from_dict(output)

    def close(self):
        """Close all registered hooks before exiting."""
        for _, hook in self._hook_dict.items():
            # Skip master_only hooks
            if hook is not None:
                hook.close()

    # ckpt related funcs
    def state_dict(self) -> Dict[str, Any]:
        state_dict = dict()
        for k in self.stat_dict_keys:
            state_dict[k] = getattr(self, k)
        state_dict['buffer'] = self.buffer.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        buffer_dict = state_dict.pop('buffer')
        self.buffer.load_state_dict(buffer_dict)
        for k, v in state_dict.items():
            setattr(self, k, v)
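
    # Hypothetical round-trip of the two methods above (sketch only; the path
    # is a placeholder, real checkpointing goes through the checkpoint hooks):
    #     torch.save(runner.state_dict(), 'runner_state.pth')
    #     runner.load_state_dict(torch.load('runner_state.pth'))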

    # Misc
    def change_net_val(self):
        for _, net in self.model.nets.items():
            net.eval()

    def change_net_train(self):
        for _, net in self.model.nets.items():
            net.train()
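

# ---------------------------------------------------------------------------
# Minimal sketch of a concrete runner built on BaseRunner (illustration only,
# not part of the BECO code base). It assumes the model exposes a
# `train_step(data)` method analogous to `val_step`, and that `args` carries a
# hypothetical `max_epochs` attribute; adapt both to the real project.
# ---------------------------------------------------------------------------
class SimpleRunner(BaseRunner):

    def _init(self, **kwargs):
        """Sketch of the child-class hook for further initialization."""
        # A child class can persist extra attributes by extending stat_dict_keys.
        self.is_train = True
        self.stat_dict_keys.append('is_train')
        # A real child class would also create self.logger here
        # (e.g. a TqdmLoggerHook-backed logger) before train/val/test run.

    def train(self):
        self.change_net_train()
        self.is_train = True
        for i, data in enumerate(self.dataloaders['train']):
            self.inner_iter = i
            # train_step is assumed to run backward/optimizer updates internally.
            output = self.model.train_step(data)
            if output is not None:
                self.buffer.update_from_dict(output)
            self.iter += 1
        self.scheduler.step()

    @torch.no_grad()
    def val(self):
        self.change_net_val()
        self.is_train = False
        for i, data in enumerate(self.dataloaders['val']):
            self.inner_iter = i
            output = self.model.val_step(data)
            if output is not None:
                self.buffer.update_from_dict(output)

    def run(self):
        # `max_epochs` is a hypothetical CLI argument used only in this sketch.
        for epoch in range(self.args.max_epochs):
            self.epoch = epoch
            self.train()
            self.val()
        self.test()
        self.close()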