Skip to content

Commit

Permalink
[Refactor] Add 'to_backend' in BackendManager (open-mmlab#1522)
Browse files Browse the repository at this point in the history
* Refactor to backend

* export_postprocess_mask = False as defailt

* update zh_cn docs

* solve comment

* fix comment
  • Loading branch information
grimoire authored Dec 21, 2022
1 parent 26d71ce commit 5285caf
Show file tree
Hide file tree
Showing 28 changed files with 712 additions and 500 deletions.
32 changes: 10 additions & 22 deletions docs/en/07-developer-guide/support_new_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,32 +123,20 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi
__all__ += ['onnx2ncnn', 'get_output_model_file']
```

Then add the codes about conversion to `tools/deploy.py` using these APIs if necessary.
Create a backend manager class which derive from `BackendManager`, implement its `to_backend` static method.

**Example:**

```Python
# tools/deploy.py
# ...
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import is_available as is_available_ncnn

if not is_available_ncnn():
logging.error('ncnn support is not available.')
exit(-1)

from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file

backend_files = []
for onnx_path in onnx_files:
create_process(
f'onnx2ncnn with {onnx_path}',
target=onnx2ncnn,
args=(onnx_path, args.work_dir),
kwargs=dict(),
ret_value=ret_value)
backend_files += get_output_model_file(onnx_path, args.work_dir)
# ...
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
deploy_cfg: Any,
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
return ir_files
```

6. Convert the models of OpenMMLab to backends (if necessary) and inference on backend engine. If you find some incompatible operators when testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
Expand Down
36 changes: 12 additions & 24 deletions docs/zh_cn/07-developer-guide/support_new_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx”
call([onnx2ncnn_path, onnx_path, save_param, save_bin])\
```

5. 在 `mmdeploy/apis` 中创建新后端库并声明对应 APIs
从 BackendManager 派生类,实现 `to_backend` 类方法。

**例子**

Expand All @@ -128,32 +128,20 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx”
**例子**

```Python
# tools/deploy.py
# ...
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import is_available as is_available_ncnn

if not is_available_ncnn():
logging.error('ncnn support is not available.')
exit(-1)

from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file

backend_files = []
for onnx_path in onnx_files:
create_process(
f'mmdeploy_onnx2ncnn with {onnx_path}',
target=onnx2ncnn,
args=(onnx_path, args.work_dir),
kwargs=dict(),
ret_value=ret_value)
backend_files += get_output_model_file(onnx_path, args.work_dir)
# ...
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
deploy_cfg: Any,
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
return ir_files
```

6. 将 OpenMMLab 的模型转换后(如有必要)并在后端引擎上进行推理。如果在测试时发现一些不兼容的算子,可以尝试按照[重写器教程](support_new_model.md)为后端重写原始模型或添加自定义算子。
5. 将 OpenMMLab 的模型转换后(如有必要)并在后端引擎上进行推理。如果在测试时发现一些不兼容的算子,可以尝试按照[重写器教程](support_new_model.md)为后端重写原始模型或添加自定义算子。

7. 为新后端引擎代码添加相关注释和单元测试:).
6. 为新后端引擎代码添加相关注释和单元测试:).

## 支持后端推理

Expand Down
5 changes: 3 additions & 2 deletions mmdeploy/apis/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_input_data
from .utils import build_task_processor, get_predefined_partition_cfg
from .utils import (build_task_processor, get_predefined_partition_cfg,
to_backend)

__all__ = [
'create_calib_input_data', 'build_task_processor',
'get_predefined_partition_cfg'
'get_predefined_partition_cfg', 'to_backend'
]
37 changes: 37 additions & 0 deletions mmdeploy/apis/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Optional, Sequence

import mmcv

from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
parse_device_id)
from ..core import PIPELINE_MANAGER


def check_backend_device(deploy_cfg: mmcv.Config, device: str):
Expand Down Expand Up @@ -62,3 +66,36 @@ def get_predefined_partition_cfg(deploy_cfg: mmcv.Config, partition_type: str):
codebase = get_codebase_class(codebase_type)
task_processor_class = codebase.get_task_class(task)
return task_processor_class.get_partition_cfg(partition_type)


@PIPELINE_MANAGER.register_pipeline()
def to_backend(backend_name: str,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Optional[Any] = None,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
backend_name (str): The name of the backend.
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be save in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Seqeuence[str]: Backend files.
"""
from mmdeploy.backend.base import get_backend_manager
backend_mgr = get_backend_manager(backend_name)
return backend_mgr.to_backend(
ir_files=ir_files,
work_dir=work_dir,
deploy_cfg=deploy_cfg,
log_level=log_level,
device=device,
**kwargs)
37 changes: 37 additions & 0 deletions mmdeploy/backend/ascend/backend_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Optional, Sequence

from ..base import BACKEND_MANAGERS, BaseBackendManager
Expand Down Expand Up @@ -29,3 +31,38 @@ def build_wrapper(cls,
"""
from .wrapper import AscendWrapper
return AscendWrapper(model=backend_files[0], device=device)

@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be save in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Seqeuence[str]: Backend files.
"""
from mmdeploy.utils import get_model_inputs
from .onnx2ascend import from_onnx

model_inputs = get_model_inputs(deploy_cfg)

om_files = []
for model_id, onnx_path in enumerate(ir_files):
om_path = osp.splitext(onnx_path)[0] + '.om'
from_onnx(onnx_path, work_dir, model_inputs[model_id])
om_files.append(om_path)
backend_files = om_files

return backend_files
8 changes: 5 additions & 3 deletions mmdeploy/backend/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backend_manager import BACKEND_MANAGERS, BaseBackendManager
from .backend_manager import (BACKEND_MANAGERS, BaseBackendManager,
get_backend_manager)
from .backend_wrapper_registry import (BACKEND_WRAPPER, get_backend_file_count,
get_backend_wrapper_class)
from .base_wrapper import BaseWrapper

__all__ = [
'BACKEND_MANAGERS', 'BaseBackendManager', 'BaseWrapper', 'BACKEND_WRAPPER',
'get_backend_wrapper_class', 'get_backend_file_count'
'BACKEND_MANAGERS', 'BaseBackendManager', 'get_backend_manager',
'BaseWrapper', 'BACKEND_WRAPPER', 'get_backend_wrapper_class',
'get_backend_file_count'
]
42 changes: 41 additions & 1 deletion mmdeploy/backend/base/backend_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import logging
from abc import ABCMeta
from typing import Any, Optional, Sequence

Expand Down Expand Up @@ -28,7 +29,31 @@ def build_wrapper(cls,
to None.
"""
raise NotImplementedError(
f'build_wrapper has not been implemented for {cls}')
f'build_wrapper has not been implemented for `{cls.__name__}`')

@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be save in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Seqeuence[str]: Backend files.
"""
raise NotImplementedError(
f'to_backend has not been implemented for `{cls.__name__}`')


class BackendManagerRegistry:
Expand Down Expand Up @@ -89,3 +114,18 @@ def find(self, name: str) -> BaseBackendManager:


BACKEND_MANAGERS = BackendManagerRegistry()


def get_backend_manager(name: str) -> BaseBackendManager:
"""Get backend manager.
Args:
name (str): name of the backend.
Returns:
BaseBackendManager: The backend manager of given name
"""
from enum import Enum
if isinstance(name, Enum):
name = name.value
return BACKEND_MANAGERS.find(name)
35 changes: 35 additions & 0 deletions mmdeploy/backend/coreml/backend_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Optional, Sequence

from ..base import BACKEND_MANAGERS, BaseBackendManager
Expand Down Expand Up @@ -29,3 +31,36 @@ def build_wrapper(cls,
"""
from .wrapper import CoreMLWrapper
return CoreMLWrapper(model_file=backend_files[0])

@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be save in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Seqeuence[str]: Backend files.
"""
from .torchscript2coreml import from_torchscript

coreml_files = []
for model_id, torchscript_path in enumerate(ir_files):
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
output_file_prefix = osp.join(work_dir, torchscript_name)

from_torchscript(model_id, torchscript_path, output_file_prefix,
deploy_cfg, coreml_files)

return coreml_files
49 changes: 47 additions & 2 deletions mmdeploy/backend/ncnn/backend_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.

import logging
import os.path as osp
import sys
from typing import Any, Optional, Sequence

from mmdeploy.utils import get_backend_config
from mmdeploy.utils import get_backend_config, get_root_logger
from ..base import BACKEND_MANAGERS, BaseBackendManager


Expand Down Expand Up @@ -43,3 +45,46 @@ def build_wrapper(cls,
bin_file=backend_files[1],
output_names=output_names,
use_vulkan=use_vulkan)

@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be save in this directory.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Seqeuence[str]: Backend files.
"""
logger = get_root_logger()

from . import is_available

if not is_available():
logger.error('ncnn support is not available, please make sure:\n'
'1) `mmdeploy_onnx2ncnn` existed in `PATH`\n'
'2) python import ncnn success')
sys.exit(1)

from mmdeploy.apis.ncnn import get_output_model_file
from .onnx2ncnn import from_onnx

backend_files = []
for onnx_path in ir_files:
model_param_path, model_bin_path = get_output_model_file(
onnx_path, work_dir)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
from_onnx(onnx_path, osp.join(work_dir, onnx_name))

backend_files += [model_param_path, model_bin_path]

return backend_files
Loading

0 comments on commit 5285caf

Please sign in to comment.