forked from facebookresearch/detectron2
Reviewed By: wat3rBro
Differential Revision: D19255042
fbshipit-source-id: 6b12bba55c1a86fa65868d235aead7f393133c19
Commit 1a780e0 (1 parent: 8cab00c). Showing 10 changed files with 224 additions and 12 deletions.
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

from .api import *
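This package `__init__` simply re-exports the public API from `.api` (see `__all__` in the module below), so downstream code can import the entry points directly from `detectron2.export`. A minimal sketch of the resulting import surface:

# These names are re-exported by the package `__init__` via `from .api import *`.
from detectron2.export import Caffe2Model, add_export_config, export_caffe2_model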
@@ -0,0 +1,152 @@
import logging
import os
import torch
from caffe2.proto import caffe2_pb2
from torch import nn

from detectron2.config import CfgNode as CN

from .caffe2_export import export_caffe2_detection_model, run_and_save_graph
from .caffe2_inference import ProtobufDetectionModel
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
from .shared import get_pb_arg_vali, save_graph

__all__ = ["add_export_config", "export_caffe2_model", "Caffe2Model"]


def add_export_config(cfg):
    """
    Args:
        cfg (CfgNode): a detectron2 config
    Returns:
        CfgNode: an updated config with new options that :func:`export_caffe2_model` will need.
    """
    is_frozen = cfg.is_frozen()
    cfg.defrost()
    cfg.EXPORT_CAFFE2 = CN()
    cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
    if is_frozen:
        cfg.freeze()
    return cfg


def export_caffe2_model(cfg, model, inputs):
    """
    Export a detectron2 model to caffe2 format.
    Args:
        cfg (CfgNode): a detectron2 config, with extra export-related options
            added by :func:`add_export_config`.
        model (nn.Module): a model built by
            :func:`detectron2.modeling.build_model`.
            It will be modified by this function.
        inputs: sample inputs that the given model takes for inference.
            Will be used to trace the model.
    Returns:
        Caffe2Model
    """
    assert isinstance(cfg, CN), cfg
    C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
    c2_compatible_model = C2MetaArch(cfg, model)
    c2_format_input = c2_compatible_model.get_tensors_input(inputs)
    predict_net, init_net = export_caffe2_detection_model(c2_compatible_model, c2_format_input)
    return Caffe2Model(predict_net, init_net)


class Caffe2Model(nn.Module):
    def __init__(self, predict_net, init_net):
        super().__init__()
        self.eval()  # always in eval mode
        self._predict_net = predict_net
        self._init_net = init_net
        self._predictor = None

    @property
    def predict_net(self):
        """
        Returns:
            core.Net: the underlying caffe2 predict net
        """
        return self._predict_net

    @property
    def init_net(self):
        """
        Returns:
            core.Net: the underlying caffe2 init net
        """
        return self._init_net

    __init__.__HIDE_SPHINX_DOC__ = True

    def save_protobuf(self, output_dir):
        """
        Save the model as caffe2's protobuf format.
        Args:
            output_dir (str): the output directory to save protobuf files.
        """
        logger = logging.getLogger(__name__)
        logger.info("Saving model to {} ...".format(output_dir))
        os.makedirs(output_dir, exist_ok=True)

        with open(os.path.join(output_dir, "model.pb"), "wb") as f:
            f.write(self._predict_net.SerializeToString())
        with open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
            f.write(str(self._predict_net))
        with open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
            f.write(self._init_net.SerializeToString())

    def save_graph(self, output_file, inputs=None):
        """
        Save the graph as SVG format.
        Args:
            output_file (str): a SVG file
            inputs: optional inputs given to the model.
                If given, the inputs will be used to run the graph to record
                shape of every tensor. The shape information will be
                saved together with the graph.
        """
        if inputs is None:
            save_graph(self._predict_net, output_file, op_only=False)
        else:
            size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
            inputs = convert_batched_inputs_to_c2_format(
                inputs, size_divisibility, torch.device("cpu")
            )
            inputs = [x.numpy() for x in inputs]
            run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)

    @staticmethod
    def load_protobuf(dir):
        """
        Args:
            dir (str): a directory used to save Caffe2Model with
                :meth:`save_protobuf`.
                The files "model.pb" and "model_init.pb" are needed.
        Returns:
            Caffe2Model: the caffe2 model loaded from this directory.
        """
        predict_net = caffe2_pb2.NetDef()
        with open(os.path.join(dir, "model.pb"), "rb") as f:
            predict_net.ParseFromString(f.read())

        init_net = caffe2_pb2.NetDef()
        with open(os.path.join(dir, "model_init.pb"), "rb") as f:
            init_net.ParseFromString(f.read())

        return Caffe2Model(predict_net, init_net)

    def __call__(self, inputs):
        """
        An interface that wraps around a caffe2 model and mimics detectron2's models'
        input & output format. This is used to compare the caffe2 model
        with its original torch model.
        """
        if self._predictor is None:
            self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
        return self._predictor(inputs)
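Taken together, the module above gives an end-to-end export path. The sketch below is illustrative only; the config path and output directory are placeholders, not part of this commit, and it mirrors what the converter script later in this commit does:

# Minimal export sketch; file paths below are placeholders, not part of this commit.
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.export import Caffe2Model, add_export_config, export_caffe2_model
from detectron2.modeling import build_model

cfg = get_cfg()
cfg = add_export_config(cfg)            # adds the EXPORT_CAFFE2 options used during export
cfg.merge_from_file("my_config.yaml")   # hypothetical config path
cfg.freeze()

torch_model = build_model(cfg)
DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)

# A sample batch in detectron2's standard inference format; it is used to trace the
# model, so it should be representative of real inference inputs.
data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
first_batch = next(iter(data_loader))

caffe2_model = export_caffe2_model(cfg, torch_model, first_batch)
caffe2_model.save_protobuf("./caffe2_model")    # writes model.pb, model.pbtxt, model_init.pb
caffe2_model.save_graph("./caffe2_model/model.svg", inputs=first_batch)

# The exported nets can be reloaded later and called like a detectron2 model.
reloaded = Caffe2Model.load_protobuf("./caffe2_model")
outputs = reloaded(first_batch)

Note that, as its docstring states, `export_caffe2_model` modifies the torch model it is given, so rebuild the model with `build_model` if the original torch version is still needed afterwards.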
(Diffs for seven of the remaining changed files are not shown in this view.)
@@ -0,0 +1,60 @@
import argparse
import os

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2.export import add_export_config, export_caffe2_model
from detectron2.modeling import build_model
from detectron2.utils.logger import setup_logger


def setup_cfg(args):
    cfg = get_cfg()
    cfg = add_export_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a model to Caffe2")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--run-eval", action="store_true")
    parser.add_argument("--output", help="output directory for the converted caffe2 model")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))

    cfg = setup_cfg(args)

    # create a torch model
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)

    # get a sample data
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    first_batch = next(iter(data_loader))

    # convert and save caffe2 model
    caffe2_model = export_caffe2_model(cfg, torch_model, first_batch)
    caffe2_model.save_protobuf(args.output)
    # draw the caffe2 graph
    caffe2_model.save_graph(os.path.join(args.output, "model_def.svg"), inputs=first_batch)

    # run evaluation with the converted model
    if args.run_eval:
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        # NOTE: hard-coded evaluator. change to the evaluator for your dataset
        evaluator = COCOEvaluator(dataset, cfg, True, args.output)
        metrics = inference_on_dataset(caffe2_model, data_loader, evaluator)
        print_csv_format(metrics)
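For completeness, here is a sketch of how this script might be driven and how its output can be reused; the script path, config file, weights file, and output directory are placeholders rather than values taken from this commit. The reload snippet mirrors the `--run-eval` branch above, but runs against protobufs saved earlier by `save_protobuf`:

# Example invocation (sketch; all names are placeholders):
#   python <path to this script> --config-file my_config.yaml --output ./caffe2_output \
#       --run-eval MODEL.WEIGHTS my_weights.pkl

# Re-running evaluation later from the saved protobufs, without re-exporting:
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2.export import Caffe2Model, add_export_config

cfg = get_cfg()
cfg = add_export_config(cfg)            # keeps EXPORT_CAFFE2 keys valid if the config sets them
cfg.merge_from_file("my_config.yaml")   # must match the config used for export
cfg.freeze()

caffe2_model = Caffe2Model.load_protobuf("./caffe2_output")  # directory written by save_protobuf
dataset = cfg.DATASETS.TEST[0]
data_loader = build_detection_test_loader(cfg, dataset)
evaluator = COCOEvaluator(dataset, cfg, True, "./caffe2_output")  # hard-coded COCO evaluator, as in the script
metrics = inference_on_dataset(caffe2_model, data_loader, evaluator)
print_csv_format(metrics)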