Skip to content

Commit

Permalink
add configs
Browse files Browse the repository at this point in the history
  • Loading branch information
dai gang committed Apr 8, 2023
1 parent 6891fda commit e8089c1
Show file tree
Hide file tree
Showing 2 changed files with 376 additions and 0 deletions.
34 changes: 34 additions & 0 deletions CHINESE_CASIA.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
MODEL:
ENCODER_LAYERS: 3
DECODER_LAYERS: 4
CLS_LAYERS: 2
PA_INI_MODE: XAVIER
NUM_IMGS: 15
NUM_GPUS: 1 # TODO, support multi GPUs
SOLVER:
BASE_LR: 0.0002
MAX_ITER: 1000000
WARMUP_ITERS: 20000
TYPE: Adam # TODO, support optional optimizer
GRAD_L2_CLIP: 5.0
TRAIN:
ISTRAIN: True
IMS_PER_BATCH: 64
SNAPSHOT_BEGIN: 2000
SNAPSHOT_ITERS: 4000
VALIDATE_ITERS: 2000
VALIDATE_BEGIN: 1
SEED: 1001
IMG_H: 64
IMG_W: 64
TEST:
ISTRAIN: False
IMG_H: 64
IMG_W: 64
SAMPLE_STEPS: 2000
DATA_LOADER:
NUM_THREADS: 8
CONCAT_GRID: True
TYPE: ScriptDataset
PATH: data
DATASET: CHINESE
342 changes: 342 additions & 0 deletions parse_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import six
import os
import os.path as osp
import copy
from ast import literal_eval

import numpy as np
from packaging import version
import torch
import torch.nn as nn
from torch.nn import init
import yaml
from easydict import EasyDict

class AttrDict(EasyDict):
IMMUTABLE = '__immutable__'

def __init__(self, *args):
super(EasyDict, self).__init__(*args)

def immutable(self, is_immutable):
"""Set immutability to is_immutable and recursively apply the setting
to all nested AttrDicts.
"""
self.__dict__[AttrDict.IMMUTABLE] = is_immutable
# Recursively set immutable state
for v in self.__dict__.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
for v in self.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)

def is_immutable(self):
return self.__dict__[AttrDict.IMMUTABLE]

__C = AttrDict()
# Consumers can get config by:
# from parse_config import cfg
cfg = __C


# Random note: avoid using '.ON' as a config key since yaml converts it to True;
# prefer 'ENABLED' instead

# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
__C.TRAIN = AttrDict()

# Datasets to train on
__C.TRAIN.ISTRAIN = True

# Height Pixel
__C.TRAIN.IMG_H = 64

# Width Pixel
__C.TRAIN.IMG_W = 64

# keep image aspect while trainning, didn't support True yet
__C.TRAIN.KEEP_ASPECT = False

# Images *per GPU* in the training minibatch
# Total images per minibatch = TRAIN.IMS_PER_BATCH * NUM_GPUS
__C.TRAIN.IMS_PER_BATCH = 64

# Snapshot (model checkpoint) period
# Divide by NUM_GPUS to determine actual period (e.g., 20000/8 => 2500 iters)
# to allow for linear training schedule scaling
__C.TRAIN.SNAPSHOT_ITERS = 3000

__C.TRAIN.SNAPSHOT_BEGIN = 0

__C.TRAIN.VALIDATE_ITERS = 0

__C.TRAIN.VALIDATE_BEGIN = 0

__C.TRAIN.TEST_ITERS = 6000

__C.TRAIN.DATASET = ''

# Dropout probability in dense3
__C.TRAIN.DROPOUT_P = 0.

# Set the random seed
__C.TRAIN.SEED = 1001


# ---------------------------------------------------------------------------- #
# Data loader options
# ---------------------------------------------------------------------------- #
__C.DATA_LOADER = AttrDict()

# Number of Python threads to use for the data loader (warning: using too many
# threads can cause GIL-based interference with Python Ops leading to *slower*
# training; 4 seems to be the sweet spot in our experience)
__C.DATA_LOADER.NUM_THREADS = 8

__C.DATA_LOADER.CONCAT_GRID = False

__C.DATA_LOADER.PATH = 'data'

__C.DATA_LOADER.TYPE = 'ScriptDataset'

__C.DATA_LOADER.DATASET = 'CHINESE'


# ---------------------------------------------------------------------------- #
# Inference ('test') options
# ---------------------------------------------------------------------------- #
__C.TEST = AttrDict()

# Datasets to test on
# Available dataset list: datasets.dataset_catalog.DATASETS.keys()
# If multiple datasets are listed, testing is performed on each one sequentially
__C.TEST.ISTRAIN = False

# Scale to use during testing (can NOT list multiple scales)
# The scale is the pixel size of an image's shortest side
__C.TEST.IMG_H = 64

# Max pixel size of the longest side of a scaled input image
__C.TEST.IMG_W = 64

__C.TEST.KEEP_ASPECT = False

__C.TEST.DATASET = ''

# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
__C.MODEL = AttrDict()

# the number of the encoder layers
__C.MODEL.ENCODER_LAYERS = 3

# the total number of the decoder layers
__C.MODEL.DECODER_LAYERS = 4

# the number of layers for fusing writer-wise styles
__C.MODEL.CLS_LAYERS = 2

# the number of style references
__C.MODEL.NUM_IMGS = 15

__C.MODEL.PA_INI_MODE = 'XAVIER'

# ---------------------------------------------------------------------------- #
# Solver options
# ---------------------------------------------------------------------------- #
__C.SOLVER = AttrDict()

# support 'SGD', 'Adam', 'Adadelta' and 'Rmsprop'
__C.SOLVER.TYPE = 'Adam'

# Base learning rate for the specified schedule
__C.SOLVER.BASE_LR = 0.001

# Maximum number of trainning iterations
__C.SOLVER.MAX_ITER = 7200000

__C.SOLVER.WARMUP_ITERS = 0

# CLIP Gradient L2 Nrom
__C.SOLVER.GRAD_L2_CLIP = 1.0

# ---------------------------------------------------------------------------- #
# MISC options
# ---------------------------------------------------------------------------- #

# Number of GPUs to use (applies to both training and testing)
__C.NUM_GPUS = 1

# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__)))

# Output basedir
__C.OUTPUT_DIR = 'Saved'

def assert_and_infer_cfg(make_immutable=True):
"""Call this function in your script after you have finished setting all cfg
values that are necessary (e.g., merging a config from a file, merging
command line config options, etc.). By default, this function will also
mark the global cfg as immutable to prevent changing the global cfg settings
during script execution (which can lead to hard to debug errors or code
that's harder to understand than is necessary).
"""
if version.parse(torch.__version__) < version.parse('0.4.0'):
__C.PYTORCH_VERSION_LESS_THAN_040 = True
# create alias for PyTorch version less than 0.4.0
init.uniform_ = init.uniform
init.normal_ = init.normal
init.constant_ = init.constant
init.kaiming_normal_ = init.kaiming_normal
torch.nn.utils.clip_grad_norm_ = torch.nn.utils.clip_grad_norm
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
tensor.requires_grad = requires_grad
tensor._backward_hooks = backward_hooks
return tensor
torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
if make_immutable:
cfg.immutable(True)


def merge_cfg_from_file(cfg_filename):
"""Load a yaml config file and merge it into the global config."""
with open(cfg_filename, 'r') as f:
yaml_cfg = AttrDict(yaml.full_load(f))
_merge_a_into_b(yaml_cfg, __C)

cfg_from_file = merge_cfg_from_file


def merge_cfg_from_cfg(cfg_other):
"""Merge `cfg_other` into the global config."""
_merge_a_into_b(cfg_other, __C)


def merge_cfg_from_list(cfg_list):
"""Merge config keys, values in a list (e.g., from command line) into the
global config. For example, `cfg_list = ['TEST.NMS', 0.5]`.
"""
assert len(cfg_list) % 2 == 0
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
# if _key_is_deprecated(full_key):
# continue
# if _key_is_renamed(full_key):
# _raise_key_rename_error(full_key)
key_list = full_key.split('.')
d = __C
for subkey in key_list[:-1]:
assert subkey in d, 'Non-existent key: {}'.format(full_key)
d = d[subkey]
subkey = key_list[-1]
assert subkey in d, 'Non-existent key: {}'.format(full_key)
value = _decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(
value, d[subkey], subkey, full_key
)
d[subkey] = value

cfg_from_list = merge_cfg_from_list


def _merge_a_into_b(a, b, stack=None):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a.
"""
assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'
assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'

for k, v_ in a.items():
full_key = '.'.join(stack) + '.' + k if stack is not None else k
# a must specify keys that are in b
if k not in b:
# if _key_is_deprecated(full_key):
# continue
# elif _key_is_renamed(full_key):
# _raise_key_rename_error(full_key)
# else:
raise KeyError('Non-existent config key: {}'.format(full_key))

v = copy.deepcopy(v_)
v = _decode_cfg_value(v)
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)

# Recursively merge dicts
if isinstance(v, AttrDict):
try:
stack_push = [k] if stack is None else stack + [k]
_merge_a_into_b(v, b[k], stack=stack_push)
except BaseException:
raise
else:
b[k] = v


def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object.
"""
# Configs parsed from raw yaml will contain dictionary keys that need to be
# converted to AttrDict objects
if isinstance(v, dict):
return AttrDict(v)
# All remaining processing is only applied to strings
if not isinstance(v, six.string_types):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v


def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
"""Checks that `value_a`, which is intended to replace `value_b` is of the
right type. The type is correct if it matches exactly or is one of a few
cases in which the type can be easily coerced.
"""
# The types must match (with some exceptions)
type_b = type(value_b)
type_a = type(value_a)
if type_a is type_b:
return value_a

# Exceptions: numpy arrays, strings, tuple<->list
if isinstance(value_b, np.ndarray):
value_a = np.array(value_a, dtype=value_b.dtype)
elif isinstance(value_b, six.string_types):
value_a = str(value_a)
elif isinstance(value_a, tuple) and isinstance(value_b, list):
value_a = list(value_a)
elif isinstance(value_a, list) and isinstance(value_b, tuple):
value_a = tuple(value_a)
else:
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(type_b, type_a, value_b, value_a, full_key)
)
return value_a

0 comments on commit e8089c1

Please sign in to comment.