forked from dailenson/SDT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
parse_config.py
343 lines (271 loc) · 11 KB
/
parse_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import six
import os
import os.path as osp
import copy
from ast import literal_eval
import numpy as np
from packaging import version
import torch
import torch.nn as nn
from torch.nn import init
import yaml
from easydict import EasyDict
class AttrDict(EasyDict):
    """Config node built on EasyDict, plus a recorded immutability flag.

    Attribute-style access comes from EasyDict (not from this class). This
    class adds ``immutable()`` / ``is_immutable()``. It does NOT override
    attribute or item assignment, so the flag is only recorded here and does
    not by itself prevent mutation.
    """

    # Key under which the flag lives in the instance ``__dict__``. It is
    # written there directly (not via __setattr__/__setitem__), so it never
    # becomes a dict item.
    IMMUTABLE = '__immutable__'

    def __init__(self, *args):
        # super(EasyDict, self) starts the MRO lookup *after* EasyDict, so
        # EasyDict.__init__ is deliberately skipped and ``args`` go straight to
        # the next base (presumably ``dict`` -- confirm against the installed
        # easydict version).
        super(EasyDict, self).__init__(*args)

    def immutable(self, is_immutable):
        """Set the immutability flag on this node and every nested AttrDict.

        Nested nodes are looked up both among instance attributes and among
        dict values. A child reachable both ways is visited only once.

        Args:
            is_immutable: value to store (normally True or False).
        """
        self.__dict__[AttrDict.IMMUTABLE] = is_immutable
        # Dedupe by identity. The original walked __dict__ and items()
        # separately, so each child subtree was processed twice per level.
        children = {}
        for value in list(self.__dict__.values()) + list(self.values()):
            if isinstance(value, AttrDict):
                children[id(value)] = value
        for child in children.values():
            child.immutable(is_immutable)

    def is_immutable(self):
        """Return the stored immutability flag, or False if it was never set."""
        # .get(): nodes created without a call to immutable() (e.g. sections
        # added during a merge) have no flag yet. Indexing directly raised
        # KeyError for them.
        return self.__dict__.get(AttrDict.IMMUTABLE, False)
# Global configuration tree. Consumers obtain it with:
#   from parse_config import cfg
# ``__C`` is the module-private handle used below to register the defaults;
# both names are bound to the very same object.
#
# Note: do not use '.ON' as a config key -- YAML reads ON as the boolean True.
# Prefer 'ENABLED' instead.
cfg = __C = AttrDict()
# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
__C.TRAIN = AttrDict()
# Mode flag for the training split: True here, False under TEST.ISTRAIN.
# The code that reads it is not in this file.
__C.TRAIN.ISTRAIN = True
# Input image height, in pixels
__C.TRAIN.IMG_H = 64
# Input image width, in pixels
__C.TRAIN.IMG_W = 64
# Keep the image aspect ratio while training; per the original note, True is
# not supported yet
__C.TRAIN.KEEP_ASPECT = False
# Images *per GPU* in the training minibatch.
# NOTE(review): the formula below is inherited from Detectron -- confirm the
# training script really scales the batch this way:
# Total images per minibatch = TRAIN.IMS_PER_BATCH * NUM_GPUS
__C.TRAIN.IMS_PER_BATCH = 64
# Snapshot (model checkpoint) period, in iterations.
# NOTE(review): the NUM_GPUS division below is inherited from Detectron;
# confirm the training loop implements it:
# Divide by NUM_GPUS to determine actual period (e.g., 20000/8 => 2500 iters)
# to allow for linear training schedule scaling
__C.TRAIN.SNAPSHOT_ITERS = 3000
# Start iteration / period knobs for snapshotting, validation and testing.
# Their exact semantics are defined by the training loop (not in this file).
__C.TRAIN.SNAPSHOT_BEGIN = 0
__C.TRAIN.VALIDATE_ITERS = 0
__C.TRAIN.VALIDATE_BEGIN = 0
__C.TRAIN.TEST_ITERS = 6000
# Training dataset identifier (empty by default; interpreted by the consumer)
__C.TRAIN.DATASET = ''
# Dropout probability (the original note says it is applied in 'dense3')
__C.TRAIN.DROPOUT_P = 0.
# Random seed
__C.TRAIN.SEED = 1001
# ---------------------------------------------------------------------------- #
# Data loader options
# ---------------------------------------------------------------------------- #
__C.DATA_LOADER = AttrDict()
# Number of data-loading threads/workers. The note inherited from Detectron
# warned that too many Python threads can slow training through GIL
# contention and suggested 4. This project's default is 8.
__C.DATA_LOADER.NUM_THREADS = 8
# The remaining loader options are interpreted by the data-loading code,
# which is not in this file. Presumably:
#   CONCAT_GRID - boolean switch for the loader (semantics: TODO confirm)
#   PATH        - root directory of the data
#   TYPE        - name of the dataset class to instantiate
#   DATASET     - which dataset/script variant to load
__C.DATA_LOADER.CONCAT_GRID = False
__C.DATA_LOADER.PATH = 'data'
__C.DATA_LOADER.TYPE = 'ScriptDataset'
__C.DATA_LOADER.DATASET = 'CHINESE'
# ---------------------------------------------------------------------------- #
# Inference ('test') options
# ---------------------------------------------------------------------------- #
__C.TEST = AttrDict()
# Mode flag for the test split: False here, True under TRAIN.ISTRAIN
__C.TEST.ISTRAIN = False
# Input image height, in pixels, at test time
__C.TEST.IMG_H = 64
# Input image width, in pixels, at test time
__C.TEST.IMG_W = 64
# Keep the image aspect ratio at test time (counterpart of TRAIN.KEEP_ASPECT)
__C.TEST.KEEP_ASPECT = False
# Test dataset identifier (empty by default; interpreted by the consumer)
__C.TEST.DATASET = ''
# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
__C.MODEL = AttrDict()
# Number of encoder layers
__C.MODEL.ENCODER_LAYERS = 3
# Number of layers that fuse writer-wise styles
__C.MODEL.WRI_DEC_LAYERS = 2
# Number of layers that fuse character (glyph)-wise styles
__C.MODEL.GLY_DEC_LAYERS = 2
# Number of layers in each style head
__C.MODEL.NUM_HEAD_LAYERS = 1
# Number of style reference images
__C.MODEL.NUM_IMGS = 15
# ---------------------------------------------------------------------------- #
# Solver options
# ---------------------------------------------------------------------------- #
__C.SOLVER = AttrDict()
# Optimizer type. Per the original note, 'SGD', 'Adam', 'Adadelta' and
# 'Rmsprop' are supported (the optimizer factory is not in this file).
__C.SOLVER.TYPE = 'Adam'
# Base learning rate for the specified schedule
__C.SOLVER.BASE_LR = 0.001
# Maximum number of training iterations
__C.SOLVER.MAX_ITER = 7200000
# Number of warm-up iterations (0 presumably disables warm-up -- confirm in
# the scheduler)
__C.SOLVER.WARMUP_ITERS = 0
# Maximum L2 norm used for gradient clipping
__C.SOLVER.GRAD_L2_CLIP = 1.0
# ---------------------------------------------------------------------------- #
# MISC options
# ---------------------------------------------------------------------------- #
# Number of GPUs to use (applies to both training and testing)
__C.NUM_GPUS = 1
# Absolute path of the directory containing this file (the project root).
# (Calling osp.join with a single argument was a no-op, so it was dropped.)
__C.ROOT_DIR = osp.abspath(osp.dirname(__file__))
# Base directory for outputs
__C.OUTPUT_DIR = 'Saved'
def _alias_legacy_torch_api():
    """Backfill PyTorch >= 0.4.0 API names on an older PyTorch install."""
    # The trailing-underscore (in-place) initializers and clip_grad_norm_ are
    # aliased to the older names this code path assumes exist.
    init.uniform_ = init.uniform
    init.normal_ = init.normal
    init.constant_ = init.constant
    init.kaiming_normal_ = init.kaiming_normal
    torch.nn.utils.clip_grad_norm_ = torch.nn.utils.clip_grad_norm

    # Presumably lets old PyTorch unpickle checkpoints whose pickles reference
    # torch._utils._rebuild_tensor_v2 -- TODO confirm.
    def _rebuild_tensor_v2(storage, storage_offset, size, stride,
                           requires_grad, backward_hooks):
        rebuilt = torch._utils._rebuild_tensor(
            storage, storage_offset, size, stride)
        rebuilt.requires_grad = requires_grad
        rebuilt._backward_hooks = backward_hooks
        return rebuilt

    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2


def assert_and_infer_cfg(make_immutable=True):
    """Finalize the global config once every merge has been applied.

    Call this after your script has finished populating ``cfg`` (merging a
    YAML file, command-line overrides, ...). On PyTorch older than 0.4.0 it
    sets ``cfg.PYTORCH_VERSION_LESS_THAN_040 = True`` and installs
    compatibility aliases. On newer PyTorch that key is not created at all.

    Args:
        make_immutable: when True (default), call ``immutable(True)`` on the
            global config, which records the immutability flag on it and on
            every nested AttrDict.
    """
    torch_is_legacy = version.parse(torch.__version__) < version.parse('0.4.0')
    if torch_is_legacy:
        __C.PYTORCH_VERSION_LESS_THAN_040 = True
        _alias_legacy_torch_api()
    if make_immutable:
        __C.immutable(True)
def merge_cfg_from_file(cfg_filename):
    """Load a YAML config file and merge it into the global config.

    Every key in the file must already exist in ``cfg``. Values are decoded
    and type-checked against the current defaults before being stored. An
    empty file (or one holding only comments) is treated as "no overrides".

    Args:
        cfg_filename: path to the YAML file.

    Raises:
        KeyError: if the file sets a key that has no default in ``cfg``.
        ValueError: if a value's type cannot be coerced to the default's type.
    """
    # Binary mode lets PyYAML detect UTF-8/UTF-16 itself (as the YAML spec
    # requires) instead of decoding with the platform's locale encoding.
    with open(cfg_filename, 'rb') as f:
        # NOTE(review): full_load accepts more YAML tags than safe_load. Config
        # files are assumed to be trusted; switch to yaml.safe_load otherwise.
        loaded = yaml.full_load(f)
    # An empty document loads as None, and AttrDict(None) raises TypeError.
    yaml_cfg = AttrDict(loaded if loaded is not None else {})
    _merge_a_into_b(yaml_cfg, __C)


cfg_from_file = merge_cfg_from_file
def merge_cfg_from_cfg(cfg_other):
    """Merge another config tree into the global ``cfg`` in place.

    Args:
        cfg_other: AttrDict of overrides. Every key it holds must already
            exist in ``cfg``. Matching values in ``cfg`` are replaced after
            being decoded and type-checked by ``_merge_a_into_b``.
    """
    destination = __C
    _merge_a_into_b(cfg_other, destination, stack=None)
def merge_cfg_from_list(cfg_list):
    """Merge dotted-key/value pairs (e.g. from the command line) into ``cfg``.

    ``cfg_list`` alternates keys and values, for example
    ``['TRAIN.IMG_H', '128', 'SOLVER.BASE_LR', '0.0005']``. Each key must name
    an existing option. Each value is decoded with ``_decode_cfg_value`` and
    type-checked against the option's current value before being stored.

    Raises:
        AssertionError: if ``cfg_list`` has odd length or a key does not exist.
        ValueError: if a value cannot be coerced to the existing option's type.
    """
    # Validation raises explicitly instead of using ``assert``, so it still
    # runs under ``python -O``. Without it, zip() would silently drop a
    # dangling last key. AssertionError is kept so callers see the same
    # exception type as before.
    if len(cfg_list) % 2 != 0:
        raise AssertionError(
            'cfg_list must hold key/value pairs, got {} items: {}'.format(
                len(cfg_list), cfg_list))
    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = full_key.split('.')
        d = __C
        # Walk the parent sections. Each must exist and be a mapping: a path
        # that runs through a scalar (e.g. 'TRAIN.IMG_H.X') is a missing key,
        # not a TypeError.
        for subkey in key_list[:-1]:
            if not isinstance(d, dict) or subkey not in d:
                raise AssertionError('Non-existent key: {}'.format(full_key))
            d = d[subkey]
        subkey = key_list[-1]
        if not isinstance(d, dict) or subkey not in d:
            raise AssertionError('Non-existent key: {}'.format(full_key))
        value = _decode_cfg_value(v)
        value = _check_and_coerce_cfg_value_type(
            value, d[subkey], subkey, full_key
        )
        d[subkey] = value


cfg_from_list = merge_cfg_from_list
def _merge_a_into_b(a, b, stack=None):
    """Recursively merge config ``a`` into config ``b`` in place.

    Every key in ``a`` must already exist in ``b``. Each value from ``a`` is
    deep-copied, decoded and type-checked against ``b``'s current value, then
    clobbers it. Nested AttrDicts are merged key by key, so defaults that
    ``a`` does not mention are kept.

    Args:
        a: AttrDict of overriding values (e.g. loaded from YAML).
        b: AttrDict of current values; modified in place.
        stack: parent key names, used to build dotted key paths for error
            messages; None at the top level.

    Raises:
        AssertionError: if ``a`` or ``b`` is not an AttrDict.
        KeyError: if ``a`` contains a key that ``b`` does not.
        ValueError: if a value cannot be coerced to the type found in ``b``.
    """
    assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'
    assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'

    for k, v_ in a.items():
        # str(k): YAML can yield non-string keys (e.g. ints). Concatenating
        # them raised a TypeError before the real error could be reported.
        full_key = '.'.join((stack or []) + [str(k)])
        # a must only specify keys that already exist in b
        if k not in b:
            raise KeyError('Non-existent config key: {}'.format(full_key))

        # Deep-copy so later mutation of cfg cannot leak back into ``a``.
        v = copy.deepcopy(v_)
        v = _decode_cfg_value(v)
        v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)

        if isinstance(v, AttrDict):
            # Recurse instead of replacing the whole section.
            stack_push = [k] if stack is None else stack + [k]
            _merge_a_into_b(v, b[k], stack=stack_push)
        else:
            b[k] = v
def _decode_cfg_value(v):
    """Turn a raw config value into the Python object it represents.

    Raw values come from YAML files or command-line arguments:

    * dicts are wrapped in ``AttrDict`` so they can be merged as sections;
    * strings go through ``ast.literal_eval``, so text such as ``"[1, 2]"``,
      ``"0.5"`` or ``"None"`` becomes the matching object. Strings that are
      not Python literals are returned unchanged;
    * anything else is returned as-is.
    """
    if isinstance(v, dict):
        return AttrDict(v)
    if not isinstance(v, six.string_types):
        return v
    # A YAML string may *represent* a literal (a list, a number, ...) or just
    # be plain text. The YAML parser hands over plain text without quotes
    # ('foo', not '"foo"'), so literal_eval rejects it:
    #   - ValueError for text that parses but is not a literal ('foo', 'a/b')
    #   - SyntaxError for text that does not parse at all ('a b', '/abs/path')
    # In both cases the value is an ordinary string and is kept verbatim.
    try:
        return literal_eval(v)
    except (ValueError, SyntaxError):
        return v
def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
    """Validate, and if needed coerce, a replacement config value.

    ``value_a`` is the incoming value meant to replace the existing
    ``value_b``. It is accepted unchanged when ``type(value_a)`` is exactly
    ``type(value_b)``. Otherwise it is coerced only in these cases:

    * ``value_b`` is a numpy array: ``np.array(value_a, dtype=value_b.dtype)``
    * ``value_b`` is a float and ``value_a`` a (non-bool) int:
      ``float(value_a)``, so e.g. ``DROPOUT_P: 0`` or ``SOLVER.BASE_LR 1``
      are accepted
    * ``value_b`` is a string: ``str(value_a)``
    * tuple <-> list: converted to ``value_b``'s sequence type

    Args:
        value_a: incoming (already decoded) value.
        value_b: current/default value whose type is authoritative.
        key: leaf key name (unused; kept for the existing call signature).
        full_key: dotted key path, used in the error message.

    Returns:
        The accepted, possibly coerced, value.

    Raises:
        ValueError: if the types differ and no coercion rule applies.
    """
    type_a = type(value_a)
    type_b = type(value_b)
    if type_a is type_b:
        return value_a

    if isinstance(value_b, np.ndarray):
        return np.array(value_a, dtype=value_b.dtype)
    # YAML and literal_eval turn '0' / '1' into int, which used to make
    # integer-looking values for float options fail with a type mismatch.
    # Widen int -> float. bool is excluded because it subclasses int.
    if (isinstance(value_b, float) and isinstance(value_a, int)
            and not isinstance(value_a, bool)):
        return float(value_a)
    if isinstance(value_b, six.string_types):
        return str(value_a)
    if isinstance(value_a, tuple) and isinstance(value_b, list):
        return list(value_a)
    if isinstance(value_a, list) and isinstance(value_b, tuple):
        return tuple(value_a)

    raise ValueError(
        'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
        'key: {}'.format(type_b, type_a, value_b, value_a, full_key)
    )