Skip to content

Commit

Permalink
support FlopCountAnalysis from fvcore
Browse files Browse the repository at this point in the history
Summary: Print flop table

Reviewed By: theschnitz

Differential Revision: D26810342

fbshipit-source-id: 2f958f53d6b988bdcdd40ffdc84c0c1eab6dbde0
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Mar 20, 2021
1 parent a27cd63 commit c08150c
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 30 deletions.
6 changes: 4 additions & 2 deletions .github/ISSUE_TEMPLATE/documentation.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
---
name: "\U0001F4DA Documentation Issues"
about: Docs and comments are missing, incorrect, or not clear enough
about: Docs or comments are missing or inaccurate
labels: documentation

---

## 📚 Documentation

* Links to the relevant documentation/comment:
* Link to the relevant documentation/comment:

* What is missing or inaccurate?
6 changes: 3 additions & 3 deletions detectron2/export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ class Caffe2Model(nn.Module):
Examples:
::
model = Caffe2Model.load_protobuf("dir/with/pb/files")
c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
inputs = [{"image": img_tensor_CHW}]
outputs = model(inputs)
outputs = c2_model(inputs)
orig_outputs = torch_model(inputs)
"""

def __init__(self, predict_net, init_net):
Expand Down
75 changes: 62 additions & 13 deletions detectron2/export/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,13 @@ class TracingAdapter(nn.Module):
Schema of the output produced by calling the given model with inputs.
"""

def __init__(self, model: nn.Module, inputs, inference_func: Optional[Callable] = None):
def __init__(
self,
model: nn.Module,
inputs,
inference_func: Optional[Callable] = None,
allow_non_tensor: bool = False,
):
"""
Args:
model: an nn.Module
Expand All @@ -231,6 +237,13 @@ def __init__(self, model: nn.Module, inputs, inference_func: Optional[Callable]
model with inputs, and return outputs. By default it
is ``lambda model, *inputs: model(*inputs)``. Can be override
if you need to call the model differently.
allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
This option will filter out non-tensor objects to make the
model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
used anymore because inputs/outputs cannot be rebuilt from pure tensors.
This is useful when you're only interested in the single trace of
execution (e.g. for flop count), but not interested in
generalizing the traced graph to new inputs.
"""
super().__init__()
if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
Expand All @@ -239,29 +252,65 @@ def __init__(self, model: nn.Module, inputs, inference_func: Optional[Callable]
if not isinstance(inputs, tuple):
inputs = (inputs,)
self.inputs = inputs
self.allow_non_tensor = allow_non_tensor

if inference_func is None:
inference_func = lambda model, *inputs: model(*inputs) # noqa
self.inference_func = inference_func

self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)
for input in self.flattened_inputs:
if not isinstance(input, torch.Tensor):
raise ValueError(
f"Inputs for tracing must only contain tensors. Got a {type(input)} instead."
)

if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
return
if self.allow_non_tensor:
self.flattened_inputs = tuple(
[x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
)
self.inputs_schema = None
else:
for input in self.flattened_inputs:
if not isinstance(input, torch.Tensor):
raise ValueError(
"Inputs for tracing must only contain tensors. "
f"Got a {type(input)} instead."
)

def forward(self, *args: torch.Tensor):
with torch.no_grad(), patch_builtin_len():
inputs_orig_format = self.inputs_schema(args)
if self.inputs_schema is not None:
inputs_orig_format = self.inputs_schema(args)
else:
if args != self.flattened_inputs:
raise ValueError(
"TracingAdapter does not contain valid inputs_schema."
" So it cannot generalize to other inputs and must be"
" traced with `.flattened_inputs`."
)
inputs_orig_format = self.inputs

outputs = self.inference_func(self.model, *inputs_orig_format)
flattened_outputs, schema = flatten_to_tuple(outputs)
if self.outputs_schema is None:
self.outputs_schema = schema
else:
assert (
self.outputs_schema == schema
), "Model should always return outputs with the same structure so it can be traced!"

flattened_output_tensors = tuple(
[x for x in flattened_outputs if isinstance(x, torch.Tensor)]
)
if len(flattened_output_tensors) < len(flattened_outputs):
if self.allow_non_tensor:
flattened_outputs = flattened_output_tensors
self.outputs_schema = None
else:
raise ValueError(
"Model cannot be traced because some model outputs "
"cannot flatten to tensors."
)
else: # schema is valid
if self.outputs_schema is None:
self.outputs_schema = schema
else:
assert self.outputs_schema == schema, (
"Model should always return outputs with the same "
"structure so it can be traced!"
)
return flattened_outputs

def _create_wrapper(self, traced_model):
Expand Down
2 changes: 2 additions & 0 deletions detectron2/modeling/proposal_generator/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def losses(
normalizer = self.batch_size_per_image * num_images
losses = {
"loss_rpn_cls": objectness_loss / normalizer,
# The original Faster R-CNN paper uses a slightly different normalizer
# for loc loss. But it doesn't matter in practice
"loss_rpn_loc": localization_loss / normalizer,
}
losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
Expand Down
44 changes: 39 additions & 5 deletions detectron2/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
# -*- coding: utf-8 -*-

import typing
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
import fvcore
from fvcore.nn import (
activation_count,
flop_count,
flop_count_table,
parameter_count,
parameter_count_table,
)
from torch import nn

from detectron2.export import TracingAdapter

__all__ = [
"activation_count_operators",
"flop_count_operators",
"flop_count_table",
"parameter_count_table",
"parameter_count",
]
Expand Down Expand Up @@ -48,13 +56,28 @@
}


def flop_count_operators(
model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
"""
Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
"""

def __init__(self, model, inputs):
"""
Args:
model (nn.Module):
inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
"""
wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
super().__init__(wrapper, wrapper.flattened_inputs)
self.set_op_handle(**{k: None for k in _IGNORED_OPS})


def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
"""
Implement operator-level flops counting using jit.
This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard
detection models in detectron2.
Please use :class:`FlopCountAnalysis` for more advanced functionalities.
Note:
The function runs the input through the model to compute flops.
Expand All @@ -69,8 +92,16 @@ def flop_count_operators(
model: a detectron2 model that takes `list[dict]` as input.
inputs (list[dict]): inputs to model, in detectron2's standard format.
Only "image" key will be used.
supported_ops (dict[str, Handle]): see documentation of :func:`fvcore.nn.flop_count`
Returns:
Counter: Gflop count per operator
"""
return _wrapper_count_operators(model=model, inputs=inputs, mode=FLOPS_MODE, **kwargs)
old_train = model.training
model.eval()
ret = FlopCountAnalysis(model, inputs).by_operator()
model.train(old_train)
return {k: v / 1e9 for k, v in ret.items()}


def activation_count_operators(
Expand All @@ -91,6 +122,9 @@ def activation_count_operators(
model: a detectron2 model that takes `list[dict]` as input.
inputs (list[dict]): inputs to model, in detectron2's standard format.
Only "image" key will be used.
Returns:
Counter: activation count per operator
"""
return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions docs/modules/fvcore.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ fvcore.common

.. automodule:: fvcore.common.param_scheduler
:members:
:inherited-members:
:undoc-members:
:show-inheritance:

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def get_model_zoo_configs() -> List[str]:
"matplotlib",
"tqdm>4.29.0",
"tensorboard",
"fvcore>=0.1.3,<0.1.4", # required like this to make it pip installable
"fvcore>=0.1.4,<0.1.5", # required like this to make it pip installable
"iopath>=0.1.2",
"pycocotools>=2.0.2", # corresponds to https://github.com/ppwwyyxx/cocoapi
"future", # used by caffe2
Expand Down
22 changes: 16 additions & 6 deletions tools/analyze_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from detectron2.engine import default_argument_parser
from detectron2.modeling import build_model
from detectron2.utils.analysis import (
FlopCountAnalysis,
activation_count_operators,
flop_count_operators,
flop_count_table,
parameter_count_table,
)
from detectron2.utils.logger import setup_logger
Expand All @@ -27,6 +28,7 @@ def setup(args):
cfg.DATALOADER.NUM_WORKERS = 0
cfg.merge_from_list(args.opts)
cfg.freeze()
setup_logger(name="fvcore")
setup_logger()
return cfg

Expand All @@ -40,13 +42,20 @@ def do_flop(cfg):
counts = Counter()
total_flops = []
for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa
count = flop_count_operators(model, data)
counts += count
total_flops.append(sum(count.values()))
flops = FlopCountAnalysis(model, data)
if idx > 0:
flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False)
counts += flops.by_operator()
total_flops.append(flops.total())

logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops))
logger.info(
"Average GFlops for each type of operators:\n"
+ str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()])
)
logger.info(
"(G)Flops for Each Type of Operators:\n" + str([(k, v / idx) for k, v in counts.items()])
"Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9)
)
logger.info("Total (G)Flops: {}±{}".format(np.mean(total_flops), np.std(total_flops)))


def do_activation(cfg):
Expand Down Expand Up @@ -106,6 +115,7 @@ def do_structure(cfg):
nargs="+",
)
parser.add_argument(
"-n",
"--num-inputs",
default=100,
type=int,
Expand Down

0 comments on commit c08150c

Please sign in to comment.