Skip to content

Commit

Permalink
[Refactor] add config checker (open-mmlab#2153)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored and Tau-J committed Apr 6, 2023
1 parent c08d272 commit 3aa5ef0
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 1 deletion.
2 changes: 2 additions & 0 deletions demo/MMPose_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@
"!git clone https://github.com/open-mmlab/mmpose.git -b 1.x\n",
"# \"-b 1.x\" means checkout to the `1.x` branch.\n",
"%cd mmpose\n",
"# for better Colab compatibility, install xtcocotools from source",
"%pip install git+https://github.com/jin-s13/xtcocoapi",
"%pip install -r requirements.txt\n",
"%pip install -v -e .\n",
"# \"-v\" means verbose, or more output\n",
Expand Down
7 changes: 7 additions & 0 deletions mmpose/models/pose_estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.models.utils import check_and_update_config
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType,
Optional, OptMultiConfig, OptSampleList,
Expand Down Expand Up @@ -45,6 +46,12 @@ def __init__(self,

self.backbone = MODELS.build(backbone)

# the PR #2108 and #2126 modified the interface of neck and head.
# The following function automatically detects outdated
# configurations and updates them accordingly, while also providing
# clear and concise information on the changes made.
neck, head = check_and_update_config(neck, head)

if neck is not None:
self.neck = MODELS.build(neck)

Expand Down
3 changes: 2 additions & 1 deletion mmpose/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .check_and_update_config import check_and_update_config
from .ckpt_convert import pvt_convert
from .rtmcc_block import RTMCCBlock, rope
from .transformer import PatchEmbed, nchw_to_nlc, nlc_to_nchw

__all__ = [
'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert', 'RTMCCBlock',
'rope'
'rope', 'check_and_update_config'
]
230 changes: 230 additions & 0 deletions mmpose/models/utils/check_and_update_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union

from mmengine.config import Config, ConfigDict
from mmengine.dist import master_only
from mmengine.logging import MMLogger

ConfigType = Union[Config, ConfigDict]


def process_input_transform(input_transform: str, head: Dict, head_new: Dict,
head_deleted_dict: Dict, head_append_dict: Dict,
neck_new: Dict, input_index: Tuple[int],
align_corners: bool) -> None:
"""Process the input_transform field and update head and neck
dictionaries."""
if input_transform == 'resize_concat':
in_channels = head_new.pop('in_channels')
head_deleted_dict['in_channels'] = str(in_channels)
in_channels = sum([in_channels[i] for i in input_index])
head_new['in_channels'] = in_channels
head_append_dict['in_channels'] = str(in_channels)

neck_new.update(
dict(
type='FeatureMapProcessor',
concat=True,
select_index=input_index,
))
if align_corners:
neck_new['align_corners'] = align_corners

elif input_transform == 'select':
if input_index != (-1, ):
neck_new.update(
dict(type='FeatureMapProcessor', select_index=input_index))
if isinstance(head['in_channels'], tuple):
in_channels = head_new.pop('in_channels')
head_deleted_dict['in_channels'] = str(in_channels)
if isinstance(input_index, int):
in_channels = in_channels[input_index]
else:
in_channels = tuple([in_channels[i] for i in input_index])
head_new['in_channels'] = in_channels
head_append_dict['in_channels'] = str(in_channels)
if align_corners:
neck_new['align_corners'] = align_corners

else:
raise ValueError(f'model.head get invalid value for argument '
f'input_transform: {input_transform}')


def process_extra_field(extra: Dict, head_new: Dict, head_deleted_dict: Dict,
head_append_dict: Dict, neck_new: Dict) -> None:
"""Process the extra field and update head and neck dictionaries."""
head_deleted_dict['extra'] = 'dict('
for key, value in extra.items():
head_deleted_dict['extra'] += f'{key}={value},'
head_deleted_dict['extra'] = head_deleted_dict['extra'][:-1] + ')'
if 'final_conv_kernel' in extra:
kernel_size = extra['final_conv_kernel']
if kernel_size > 1:
padding = kernel_size // 2
head_new['final_layer'] = dict(
kernel_size=kernel_size, padding=padding)
head_append_dict[
'final_layer'] = f'dict(kernel_size={kernel_size}, ' \
f'padding={padding})'
else:
head_new['final_layer'] = dict(kernel_size=kernel_size)
head_append_dict[
'final_layer'] = f'dict(kernel_size={kernel_size})'
if 'upsample' in extra:
neck_new.update(
dict(
type='FeatureMapProcessor',
scale_factor=float(extra['upsample']),
apply_relu=True,
))


def process_has_final_layer(has_final_layer: bool, head_new: Dict,
head_deleted_dict: Dict,
head_append_dict: Dict) -> None:
"""Process the has_final_layer field and update the head dictionary."""
head_deleted_dict['has_final_layer'] = str(has_final_layer)
if not has_final_layer:
if 'final_layer' not in head_new:
head_new['final_layer'] = None
head_append_dict['final_layer'] = 'None'


def check_and_update_config(neck: Optional[ConfigType],
head: ConfigType) -> Tuple[Optional[Dict], Dict]:
"""Check and update the configuration of the head and neck components.
Args:
neck (Optional[ConfigType]): Configuration for the neck component.
head (ConfigType): Configuration for the head component.
Returns:
Tuple[Optional[Dict], Dict]: Updated configurations for the neck
and head components.
"""
head_new, neck_new = head.copy(), neck.copy() if isinstance(neck,
dict) else {}
head_deleted_dict, head_append_dict = {}, {}

if 'input_transform' in head:
input_transform = head_new.pop('input_transform')
head_deleted_dict['input_transform'] = f'\'{input_transform}\''
else:
input_transform = 'select'

if 'input_index' in head:
input_index = head_new.pop('input_index')
head_deleted_dict['input_index'] = str(input_index)
else:
input_index = (-1, )

if 'align_corners' in head:
align_corners = head_new.pop('align_corners')
head_deleted_dict['align_corners'] = str(align_corners)
else:
align_corners = False

process_input_transform(input_transform, head, head_new, head_deleted_dict,
head_append_dict, neck_new, input_index,
align_corners)

if 'extra' in head:
extra = head_new.pop('extra')
process_extra_field(extra, head_new, head_deleted_dict,
head_append_dict, neck_new)

if 'has_final_layer' in head:
has_final_layer = head_new.pop('has_final_layer')
process_has_final_layer(has_final_layer, head_new, head_deleted_dict,
head_append_dict)

display_modifications(head_deleted_dict, head_append_dict, neck_new)

neck_new = neck_new if len(neck_new) else None
return neck_new, head_new


@master_only
def display_modifications(head_deleted_dict: Dict, head_append_dict: Dict,
neck: Dict) -> None:
"""Display the modifications made to the head and neck configurations.
Args:
head_deleted_dict (Dict): Dictionary of deleted fields in the head.
head_append_dict (Dict): Dictionary of appended fields in the head.
neck (Dict): Updated neck configuration.
"""
if len(head_deleted_dict) + len(head_append_dict) == 0:
return

old_model_info, new_model_info = build_model_info(head_deleted_dict,
head_append_dict, neck)

total_info = '\nThe config you are using is outdated. '\
'The following section of the config:\n```\n'
total_info += old_model_info
total_info += '```\nshould be updated to\n```\n'
total_info += new_model_info
total_info += '```\nFor more information, please refer to '\
'https://mmpose.readthedocs.io/en/dev-1.x/' \
'guide_to_framework.html#step3-model'

logger: MMLogger = MMLogger.get_current_instance()
logger.warning(total_info)


def build_model_info(head_deleted_dict: Dict, head_append_dict: Dict,
neck: Dict) -> Tuple[str, str]:
"""Build the old and new model information strings.
Args:
head_deleted_dict (Dict): Dictionary of deleted fields in the head.
head_append_dict (Dict): Dictionary of appended fields in the head.
neck (Dict): Updated neck configuration.
Returns:
Tuple[str, str]: Old and new model information strings.
"""
old_head_info = build_head_info(head_deleted_dict)
new_head_info = build_head_info(head_append_dict)
neck_info = build_neck_info(neck)

old_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' + old_head_info
new_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' \
+ neck_info + new_head_info

return old_model_info, new_model_info


def build_head_info(head_dict: Dict) -> str:
"""Build the head information string.
Args:
head_dict (Dict): Dictionary of fields in the head configuration.
Returns:
str: Head information string.
"""
head_info = ' ' * 4 + 'head=dict(\n'
for key, value in head_dict.items():
head_info += ' ' * 8 + f'{key}={value},\n'
head_info += ' ' * 8 + '...),\n'
return head_info


def build_neck_info(neck: Dict) -> str:
"""Build the neck information string.
Args:
neck (Dict): Updated neck configuration.
Returns:
str: Neck information string.
"""
if len(neck) > 0:
neck = neck.copy()
neck_info = ' ' * 4 + 'neck=dict(\n' + ' ' * 8 + \
f'type=\'{neck.pop("type")}\',\n'
for key, value in neck.items():
neck_info += ' ' * 8 + f'{key}={str(value)},\n'
neck_info += ' ' * 4 + '),\n'
else:
neck_info = ''
return neck_info
115 changes: 115 additions & 0 deletions tests/test_models/test_utils/test_check_and_update_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

from mmpose.models.utils import check_and_update_config


class TestCheckAndUpdateConfig(unittest.TestCase):

def test_case_1(self):
neck = None
head = dict(
type='HeatmapHead',
in_channels=768,
out_channels=17,
deconv_out_channels=[],
deconv_kernel_sizes=[],
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder='codec',
align_corners=False,
extra=dict(upsample=4, final_conv_kernel=3))
neck, head = check_and_update_config(neck, head)

self.assertDictEqual(
neck,
dict(
type='FeatureMapProcessor',
scale_factor=4.0,
apply_relu=True,
))
self.assertIn('final_layer', head)
self.assertDictEqual(head['final_layer'],
dict(kernel_size=3, padding=1))
self.assertNotIn('extra', head)
self.assertNotIn('input_transform', head)
self.assertNotIn('input_index', head)
self.assertNotIn('align_corners', head)

def test_case_2(self):
neck = None
head = dict(
type='CIDHead',
in_channels=(48, 96, 192, 384),
num_keypoints=17,
gfd_channels=48,
input_transform='resize_concat',
input_index=(0, 1, 2, 3),
coupled_heatmap_loss=dict(
type='FocalHeatmapLoss', loss_weight=1.0),
decoupled_heatmap_loss=dict(
type='FocalHeatmapLoss', loss_weight=4.0),
)
neck, head = check_and_update_config(neck, head)

self.assertDictEqual(
neck,
dict(
type='FeatureMapProcessor',
concat=True,
select_index=(0, 1, 2, 3),
))
self.assertEqual(head['in_channels'], 720)
self.assertNotIn('input_transform', head)
self.assertNotIn('input_index', head)
self.assertNotIn('align_corners', head)

def test_case_3(self):
neck = None
head = dict(
type='HeatmapHead',
in_channels=(64, 128, 320, 512),
out_channels=17,
input_index=3,
has_final_layer=False,
loss=dict(type='KeypointMSELoss', use_target_weight=True))
neck, head = check_and_update_config(neck, head)

self.assertDictEqual(
neck, dict(
type='FeatureMapProcessor',
select_index=3,
))
self.assertEqual(head['in_channels'], 512)
self.assertIn('final_layer', head)
self.assertIsNone(head['final_layer'])
self.assertNotIn('input_transform', head)
self.assertNotIn('input_index', head)
self.assertNotIn('align_corners', head)

def test_case_4(self):
neck = None
head = dict(
type='RTMCCHead',
in_channels=768,
out_channels=17,
input_size='input_size',
in_featuremap_size=(9, 12),
simcc_split_ratio='simcc_split_ratio',
final_layer_kernel_size=7,
gau_cfg=dict(
hidden_dims=256,
s=128,
expansion_factor=2,
dropout_rate=0.,
drop_path=0.,
act_fn='SiLU',
use_rel_bias=False,
pos_enc=False))
neck, head_new = check_and_update_config(neck, head)

self.assertIsNone(neck)
self.assertDictEqual(head, head_new)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3aa5ef0

Please sign in to comment.