Skip to content

Commit

Permalink
deprecate some more caffe2 stuff
Browse files Browse the repository at this point in the history
Summary:
1. make 'import caffe2 ' optional because pytorch is removing it from OSS. this should help with
	 https://fb.workplace.com/groups/277527419809135/posts/955879138640623/

2. remove two functions `export_{caffe2,onnx}_model` that are marked deprecated a long time ago

3. make `add_export_config` a no-op and remove all callsites (there is no callsite of it in d2go).

Reviewed By: wat3rBro

Differential Revision: D32556033

fbshipit-source-id: 6254562ee892c0b57c8455003db17bb00998f5ed
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Nov 22, 2021
1 parent ac31916 commit d29378a
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 53 deletions.
10 changes: 9 additions & 1 deletion detectron2/export/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# -*- coding: utf-8 -*-

from .api import *
try:
from caffe2.proto import caffe2_pb2 as _tmp

# caffe2 is optional
except ImportError:
pass
else:
from .api import *

from .flatten import TracingAdapter
from .torchscript import scripting_with_instances, dump_torchscript_IR

Expand Down
40 changes: 1 addition & 39 deletions detectron2/export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,12 @@

__all__ = [
"add_export_config",
"export_caffe2_model",
"Caffe2Model",
"export_onnx_model",
"Caffe2Tracer",
]


def add_export_config(cfg):
"""
Add options needed by caffe2 export.
Args:
cfg (CfgNode): a detectron2 config
Returns:
CfgNode:
an updated config with new options that will be used by :class:`Caffe2Tracer`.
"""
is_frozen = cfg.is_frozen()
cfg.defrost()
cfg.EXPORT_CAFFE2 = CfgNode()
cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
if is_frozen:
cfg.freeze()
return cfg


Expand Down Expand Up @@ -68,9 +50,7 @@ class Caffe2Tracer:
def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
"""
Args:
cfg (CfgNode): a detectron2 config, with extra export-related options
added by :func:`add_export_config`. It's used to construct
caffe2-compatible model.
cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model.
model (nn.Module): An original pytorch model. Must be among a few official models
in detectron2 that can be converted to become caffe2-compatible automatically.
Weights have to be already loaded to this model.
Expand All @@ -81,8 +61,6 @@ def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
assert isinstance(cfg, CfgNode), cfg
assert isinstance(model, torch.nn.Module), type(model)

if "EXPORT_CAFFE2" not in cfg:
cfg = add_export_config(cfg) # will just the defaults
# TODO make it support custom models, by passing in c2 model directly
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model))
Expand Down Expand Up @@ -255,19 +233,3 @@ def __call__(self, inputs):
if self._predictor is None:
self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
return self._predictor(inputs)


def export_caffe2_model(cfg, model, inputs):
logger = logging.getLogger(__name__)
logger.warning(
"export_caffe2_model() is deprecated. Please use `Caffe2Tracer().export_caffe2() instead."
)
return Caffe2Tracer(cfg, model, inputs).export_caffe2()


def export_onnx_model(cfg, model, inputs):
logger = logging.getLogger(__name__)
logger.warning(
"export_caffe2_model() is deprecated. Please use `Caffe2Tracer().export_onnx() instead."
)
return Caffe2Tracer(cfg, model, inputs).export_onnx()
6 changes: 5 additions & 1 deletion detectron2/export/caffe2_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,12 @@ def __init__(self, cfg, torch_model):
torch_model = patch_generalized_rcnn(torch_model)
super().__init__(cfg, torch_model)

try:
use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
except AttributeError:
use_heatmap_max_keypoint = False
self.roi_heads_patcher = ROIHeadsPatcher(
self._wrapped_model.roi_heads, cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
self._wrapped_model.roi_heads, use_heatmap_max_keypoint
)

def encode_additional_info(self, predict_net, init_net):
Expand Down
3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,6 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
"ResNetBlockBase",
"GroupedBatchSampler",
"build_transform_gen",
"export_caffe2_model",
"export_onnx_model",
"apply_transform_gens",
"TransformGen",
"apply_augmentations",
Expand All @@ -294,6 +292,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
"WarmupMultiStepLR",
"downgrade_config",
"upgrade_config",
"add_export_config",
}
try:
if name in HIDDEN or (
Expand Down
3 changes: 1 addition & 2 deletions tests/test_export_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

from detectron2 import model_zoo
from detectron2.export import Caffe2Model, Caffe2Tracer, add_export_config
from detectron2.export import Caffe2Model, Caffe2Tracer
from detectron2.utils.logger import setup_logger
from detectron2.utils.testing import get_sample_coco_image

Expand All @@ -22,7 +22,6 @@ def setUp(self):

def _test_model(self, config_path, device="cpu"):
cfg = model_zoo.get_config(config_path)
add_export_config(cfg)
cfg.MODEL.DEVICE = device
model = model_zoo.get(config_path, trained=True, device=device)

Expand Down
11 changes: 3 additions & 8 deletions tools/deploy/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,7 @@
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2.export import (
Caffe2Tracer,
TracingAdapter,
add_export_config,
dump_torchscript_IR,
scripting_with_instances,
)
from detectron2.export import TracingAdapter, dump_torchscript_IR, scripting_with_instances
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
Expand All @@ -31,7 +25,6 @@ def setup_cfg(args):
cfg = get_cfg()
# cuda context is initialized before creating dataloader, so we don't fork anymore
cfg.DATALOADER.NUM_WORKERS = 0
cfg = add_export_config(cfg)
add_pointrend_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
Expand All @@ -40,6 +33,8 @@ def setup_cfg(args):


def export_caffe2_tracing(cfg, torch_model, inputs):
from detectron2.export import Caffe2Tracer

tracer = Caffe2Tracer(cfg, torch_model, inputs)
if args.format == "caffe2":
caffe2_model = tracer.export_caffe2()
Expand Down

0 comments on commit d29378a

Please sign in to comment.