forked from dailenson/SDT
Commit e8089c1 (1 parent: 6891fda)
dai gang committed Apr 8, 2023
Showing 2 changed files with 376 additions and 0 deletions.
@@ -0,0 +1,34 @@
MODEL:
  ENCODER_LAYERS: 3
  DECODER_LAYERS: 4
  CLS_LAYERS: 2
  PA_INI_MODE: XAVIER
  NUM_IMGS: 15
NUM_GPUS: 1 # TODO, support multi GPUs
SOLVER:
  BASE_LR: 0.0002
  MAX_ITER: 1000000
  WARMUP_ITERS: 20000
  TYPE: Adam # TODO, support optional optimizer
  GRAD_L2_CLIP: 5.0
TRAIN:
  ISTRAIN: True
  IMS_PER_BATCH: 64
  SNAPSHOT_BEGIN: 2000
  SNAPSHOT_ITERS: 4000
  VALIDATE_ITERS: 2000
  VALIDATE_BEGIN: 1
  SEED: 1001
  IMG_H: 64
  IMG_W: 64
TEST:
  ISTRAIN: False
  IMG_H: 64
  IMG_W: 64
  SAMPLE_STEPS: 2000
DATA_LOADER:
  NUM_THREADS: 8
  CONCAT_GRID: True
  TYPE: ScriptDataset
  PATH: data
  DATASET: CHINESE
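
Every key in this YAML must already exist among the defaults declared in parse_config.py below; `_merge_a_into_b` raises a `KeyError` for unknown keys. A minimal load sketch (the YAML's on-disk filename is not shown in this commit, so the path here is hypothetical):

```python
from parse_config import cfg, merge_cfg_from_file

merge_cfg_from_file('Chinese.yml')           # hypothetical path to the file above
assert cfg.SOLVER.BASE_LR == 0.0002          # YAML value overrides the 0.001 default
assert cfg.DATA_LOADER.CONCAT_GRID is True
```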
@@ -0,0 +1,342 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import six
import os
import os.path as osp
import copy
from ast import literal_eval

import numpy as np
from packaging import version
import torch
import torch.nn as nn
from torch.nn import init
import yaml
from easydict import EasyDict


class AttrDict(EasyDict):
    IMMUTABLE = '__immutable__'

    def __init__(self, *args):
        super(EasyDict, self).__init__(*args)

    def immutable(self, is_immutable):
        """Set immutability to is_immutable and recursively apply the setting
        to all nested AttrDicts.
        """
        self.__dict__[AttrDict.IMMUTABLE] = is_immutable
        # Recursively set the immutable state on nested AttrDicts, both those
        # stored as attributes and those stored as dict items
        for v in self.__dict__.values():
            if isinstance(v, AttrDict):
                v.immutable(is_immutable)
        for v in self.values():
            if isinstance(v, AttrDict):
                v.immutable(is_immutable)

    def is_immutable(self):
        return self.__dict__[AttrDict.IMMUTABLE]
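
# Usage sketch (illustrative, not part of the original file): attributes set
# through EasyDict's __setattr__ are mirrored as dict items, and immutable()
# stamps the flag recursively. Note this subclass only records the flag; it
# does not override __setattr__ to actually block writes.
#   d = AttrDict()
#   d.TRAIN = AttrDict()
#   d.TRAIN.SEED = 1001          # also visible as d['TRAIN']['SEED']
#   d.immutable(True)
#   assert d.is_immutable() and d.TRAIN.is_immutable()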

__C = AttrDict()
# Consumers can get the config by:
#   from parse_config import cfg
cfg = __C


# Random note: avoid using '.ON' as a config key since yaml converts it to True;
# prefer 'ENABLED' instead

# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
__C.TRAIN = AttrDict()

# Whether the model runs in training mode
__C.TRAIN.ISTRAIN = True

# Image height in pixels
__C.TRAIN.IMG_H = 64

# Image width in pixels
__C.TRAIN.IMG_W = 64

# Keep the image aspect ratio while training; True is not supported yet
__C.TRAIN.KEEP_ASPECT = False

# Images *per GPU* in the training minibatch
# Total images per minibatch = TRAIN.IMS_PER_BATCH * NUM_GPUS
__C.TRAIN.IMS_PER_BATCH = 64

# Snapshot (model checkpoint) period
# Divide by NUM_GPUS to determine the actual period (e.g., 20000 / 8 => 2500
# iters) to allow for linear training-schedule scaling
__C.TRAIN.SNAPSHOT_ITERS = 3000

__C.TRAIN.SNAPSHOT_BEGIN = 0

__C.TRAIN.VALIDATE_ITERS = 0

__C.TRAIN.VALIDATE_BEGIN = 0

__C.TRAIN.TEST_ITERS = 6000

__C.TRAIN.DATASET = ''

# Dropout probability in dense3
__C.TRAIN.DROPOUT_P = 0.

# Random seed
__C.TRAIN.SEED = 1001


# ---------------------------------------------------------------------------- #
# Data loader options
# ---------------------------------------------------------------------------- #
__C.DATA_LOADER = AttrDict()

# Number of Python threads to use for the data loader (warning: using too many
# threads can cause GIL-based interference with Python ops, leading to *slower*
# training; 4 seems to be the sweet spot in our experience)
__C.DATA_LOADER.NUM_THREADS = 8

__C.DATA_LOADER.CONCAT_GRID = False

__C.DATA_LOADER.PATH = 'data'

__C.DATA_LOADER.TYPE = 'ScriptDataset'

__C.DATA_LOADER.DATASET = 'CHINESE'


# ---------------------------------------------------------------------------- #
# Inference ('test') options
# ---------------------------------------------------------------------------- #
__C.TEST = AttrDict()

# Whether the model runs in training mode (False at test time)
__C.TEST.ISTRAIN = False

# Test image height in pixels
__C.TEST.IMG_H = 64

# Test image width in pixels
__C.TEST.IMG_W = 64

# Sampling period during testing; this default must be declared because the
# YAML above sets TEST.SAMPLE_STEPS and _merge_a_into_b rejects any key that
# is missing from these defaults
__C.TEST.SAMPLE_STEPS = 2000

__C.TEST.KEEP_ASPECT = False

__C.TEST.DATASET = ''

# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
__C.MODEL = AttrDict()

# Number of encoder layers
__C.MODEL.ENCODER_LAYERS = 3

# Total number of decoder layers
__C.MODEL.DECODER_LAYERS = 4

# Number of layers for fusing writer-wise styles
__C.MODEL.CLS_LAYERS = 2

# Number of style reference images
__C.MODEL.NUM_IMGS = 15

# Parameter initialization mode (e.g., Xavier)
__C.MODEL.PA_INI_MODE = 'XAVIER'

# ---------------------------------------------------------------------------- #
# Solver options
# ---------------------------------------------------------------------------- #
__C.SOLVER = AttrDict()

# Supports 'SGD', 'Adam', 'Adadelta' and 'Rmsprop'
__C.SOLVER.TYPE = 'Adam'

# Base learning rate for the specified schedule
__C.SOLVER.BASE_LR = 0.001

# Maximum number of training iterations
__C.SOLVER.MAX_ITER = 7200000

__C.SOLVER.WARMUP_ITERS = 0

# Clip gradients by their L2 norm
__C.SOLVER.GRAD_L2_CLIP = 1.0

# ---------------------------------------------------------------------------- #
# MISC options
# ---------------------------------------------------------------------------- #

# Number of GPUs to use (applies to both training and testing)
__C.NUM_GPUS = 1

# Root directory of the project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__)))

# Output base directory
__C.OUTPUT_DIR = 'Saved'

def assert_and_infer_cfg(make_immutable=True):
    """Call this function in your script after you have finished setting all
    cfg values that are necessary (e.g., merging a config from a file, merging
    command-line config options, etc.). By default, this function will also
    mark the global cfg as immutable to prevent changing the global cfg
    settings during script execution (which can lead to hard-to-debug errors
    or code that's harder to understand than is necessary).
    """
    if version.parse(torch.__version__) < version.parse('0.4.0'):
        __C.PYTORCH_VERSION_LESS_THAN_040 = True
        # Create aliases for PyTorch versions less than 0.4.0
        init.uniform_ = init.uniform
        init.normal_ = init.normal
        init.constant_ = init.constant
        init.kaiming_normal_ = init.kaiming_normal
        torch.nn.utils.clip_grad_norm_ = torch.nn.utils.clip_grad_norm

        def _rebuild_tensor_v2(storage, storage_offset, size, stride,
                               requires_grad, backward_hooks):
            tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
            tensor.requires_grad = requires_grad
            tensor._backward_hooks = backward_hooks
            return tensor
        torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
    if make_immutable:
        cfg.immutable(True)


def merge_cfg_from_file(cfg_filename):
    """Load a yaml config file and merge it into the global config."""
    with open(cfg_filename, 'r') as f:
        yaml_cfg = AttrDict(yaml.full_load(f))
    _merge_a_into_b(yaml_cfg, __C)


cfg_from_file = merge_cfg_from_file


def merge_cfg_from_cfg(cfg_other):
    """Merge `cfg_other` into the global config."""
    _merge_a_into_b(cfg_other, __C)


def merge_cfg_from_list(cfg_list):
    """Merge config keys and values in a list (e.g., from the command line)
    into the global config. For example, `cfg_list = ['SOLVER.BASE_LR', 0.0002]`.
    """
    assert len(cfg_list) % 2 == 0
    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
        # if _key_is_deprecated(full_key):
        #     continue
        # if _key_is_renamed(full_key):
        #     _raise_key_rename_error(full_key)
        key_list = full_key.split('.')
        d = __C
        for subkey in key_list[:-1]:
            assert subkey in d, 'Non-existent key: {}'.format(full_key)
            d = d[subkey]
        subkey = key_list[-1]
        assert subkey in d, 'Non-existent key: {}'.format(full_key)
        value = _decode_cfg_value(v)
        value = _check_and_coerce_cfg_value_type(
            value, d[subkey], subkey, full_key
        )
        d[subkey] = value


cfg_from_list = merge_cfg_from_list
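
# Illustrative overrides (a sketch, not in the original file): values given as
# strings are decoded by _decode_cfg_value before the type check, e.g.
#   merge_cfg_from_list(['SOLVER.BASE_LR', '0.0005', 'TRAIN.ISTRAIN', 'False'])
#   # '0.0005' -> 0.0005 (float) and 'False' -> False (bool) via literal_eval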


def _merge_a_into_b(a, b, stack=None):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
    """
    assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'
    assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'

    for k, v_ in a.items():
        full_key = '.'.join(stack) + '.' + k if stack is not None else k
        # a must specify keys that are in b
        if k not in b:
            # if _key_is_deprecated(full_key):
            #     continue
            # elif _key_is_renamed(full_key):
            #     _raise_key_rename_error(full_key)
            # else:
            raise KeyError('Non-existent config key: {}'.format(full_key))

        v = copy.deepcopy(v_)
        v = _decode_cfg_value(v)
        v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)

        # Recursively merge dicts
        if isinstance(v, AttrDict):
            try:
                stack_push = [k] if stack is None else stack + [k]
                _merge_a_into_b(v, b[k], stack=stack_push)
            except BaseException:
                raise
        else:
            b[k] = v


def _decode_cfg_value(v):
    """Decodes a raw config value (e.g., from a yaml config file or a
    command-line argument) into a Python object.
    """
    # Configs parsed from raw yaml will contain dictionary keys that need to be
    # converted to AttrDict objects
    if isinstance(v, dict):
        return AttrDict(v)
    # All remaining processing is only applied to strings
    if not isinstance(v, six.string_types):
        return v
    # Try to interpret `v` as a:
    # string, number, tuple, list, dict, boolean, or None
    try:
        v = literal_eval(v)
    # The following two excepts allow v to pass through when it represents a
    # string.
    #
    # Longer explanation:
    # The type of v is always a string (before calling literal_eval), but
    # sometimes it *represents* a string and other times a data structure, like
    # a list. In the case that v represents a string, what we got back from the
    # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
    # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
    # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
    # will raise a SyntaxError.
    except ValueError:
        pass
    except SyntaxError:
        pass
    return v
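
# Illustrative behavior (a sketch, not in the original file):
#   _decode_cfg_value('64')        -> 64          (int via literal_eval)
#   _decode_cfg_value('(64, 64)')  -> (64, 64)    (tuple via literal_eval)
#   _decode_cfg_value('foo')       -> 'foo'       (literal_eval raises; passes through)
#   _decode_cfg_value('/data/x')   -> '/data/x'   (SyntaxError; passes through)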


def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
    """Checks that `value_a`, which is intended to replace `value_b`, is of
    the right type. The type is correct if it matches exactly or is one of a
    few cases in which the type can be easily coerced.
    """
    # The types must match (with some exceptions)
    type_b = type(value_b)
    type_a = type(value_a)
    if type_a is type_b:
        return value_a

    # Exceptions: numpy arrays, strings, tuple<->list
    if isinstance(value_b, np.ndarray):
        value_a = np.array(value_a, dtype=value_b.dtype)
    elif isinstance(value_b, six.string_types):
        value_a = str(value_a)
    elif isinstance(value_a, tuple) and isinstance(value_b, list):
        value_a = list(value_a)
    elif isinstance(value_a, list) and isinstance(value_b, tuple):
        value_a = tuple(value_a)
    else:
        raise ValueError(
            'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
            'key: {}'.format(type_b, type_a, value_b, value_a, full_key)
        )
    return value_a
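
Taken together, a consumer of this module merges the YAML shown in the first half of the commit over these defaults, applies any command-line overrides, and then freezes the config. A sketch of the full flow under the definitions above (the YAML path is hypothetical):

```python
from parse_config import (
    cfg, merge_cfg_from_file, merge_cfg_from_list, assert_and_infer_cfg,
)

merge_cfg_from_file('Chinese.yml')                     # hypothetical path
merge_cfg_from_list(['SOLVER.WARMUP_ITERS', '10000'])  # str decoded to int
assert_and_infer_cfg()                                 # marks the cfg tree immutable

# Effective values after merging:
total_batch = cfg.TRAIN.IMS_PER_BATCH * cfg.NUM_GPUS   # 64 * 1 = 64
print(cfg.SOLVER.BASE_LR)                              # 0.0002, from the YAML
```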