change flops printing style
Reviewed By: alexander-kirillov

Differential Revision: D27675478

fbshipit-source-id: 7259563f0eed17f52ac883fa96551af10382ff6a
ppwwyyxx authored and facebook-github-bot committed Apr 9, 2021
1 parent 652529e commit 086dfe4
Showing 2 changed files with 33 additions and 34 deletions.
23 changes: 11 additions & 12 deletions fvcore/nn/print_model_statistics.py
@@ -348,24 +348,24 @@ def flop_count_str(
     >>> inputs = torch.randn((1,10))
     >>> print(flop_count_str(FlopCountAnalysis(model, inputs)))
     TestNet(
-      n_params: 0.44K, n_flops: 0.4K
+      #params: 0.44K, #flops: 0.4K
       (fc1): Linear(
         in_features=10, out_features=10, bias=True
-        n_params: 0.11K, n_flops: 100
+        #params: 0.11K, #flops: 100
       )
       (fc2): Linear(
         in_features=10, out_features=10, bias=True
-        n_params: 0.11K, n_flops: 100
+        #params: 0.11K, #flops: 100
       )
       (inner): InnerNet(
-        n_params: 0.22K, n_flops: 0.2K
+        #params: 0.22K, #flops: 0.2K
         (fc1): Linear(
           in_features=10, out_features=10, bias=True
-          n_params: 0.11K, n_flops: 100
+          #params: 0.11K, #flops: 100
         )
         (fc2): Linear(
           in_features=10, out_features=10, bias=True
-          n_params: 0.11K, n_flops: 100
+          #params: 0.11K, #flops: 100
         )
       )
     )
@@ -388,13 +388,13 @@ def flop_count_str(
     flops.unsupported_ops_warnings(False)
     flops.uncalled_modules_warnings(False)
     flops.tracer_warnings("none")
-    stats = {"n_params": params, "n_flops": flops.by_module()}
+    stats = {"#params": params, "#flops": flops.by_module()}

     if activations is not None:
         activations.unsupported_ops_warnings(False)
         activations.uncalled_modules_warnings(False)
         activations.tracer_warnings("none")
-        stats["n_acts"] = activations.by_module()
+        stats["#acts"] = activations.by_module()

     all_uncalled = flops.uncalled_modules() | (
         activations.uncalled_modules() if activations is not None else set()
@@ -403,14 +403,13 @@
     stats = _group_by_module(stats)
     stats = _remove_zero_statistics(stats, force_keep=all_uncalled)
     stats = _pretty_statistics(stats, sig_figs=2)
-    stats = _indicate_uncalled_modules(stats, "n_flops", flops.uncalled_modules())
+    stats = _indicate_uncalled_modules(stats, "#flops", flops.uncalled_modules())
     if activations is not None:
         stats = _indicate_uncalled_modules(
-            stats, "n_acts", activations.uncalled_modules()
+            stats, "#acts", activations.uncalled_modules()
         )

-    input_sizes = _get_input_sizes(flops._inputs)
-    model_string = "Input sizes (torch.Tensor only): {}\n".format(input_sizes)
+    model_string = ""
     if all_uncalled:
         model_string += (
             "N/A indicates a possibly missing statistic due to how "
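For context, a minimal usage sketch of the reworked output (not part of this commit; the toy model and input shape below are illustrative). After this change, flop_count_str labels each module with "#params"/"#flops" instead of "n_params"/"n_flops" and no longer prepends an "Input sizes" line.

import torch
import torch.nn as nn

from fvcore.nn import FlopCountAnalysis, flop_count_str

# Illustrative toy model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
inputs = (torch.randn(1, 10),)

# Prints the module tree annotated with "#params: ..., #flops: ..." per module.
print(flop_count_str(FlopCountAnalysis(model, inputs)))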
44 changes: 22 additions & 22 deletions tests/test_print_model_statistics.py
@@ -446,9 +446,9 @@ def test_flop_count_str(self) -> None:
         )

         self.assertTrue("N/A indicates a possibly missing statistic" in model_str)
-        self.assertTrue("n_params: 0.11K, n_flops: 100" in model_str)
+        self.assertTrue("#params: 0.11K, #flops: 100" in model_str)
         self.assertTrue("ReLU()" in model_str)  # Suppress trivial statistics
-        self.assertTrue("n_params: 0.11K, n_flops: N/A" in model_str)  # Uncalled stats
+        self.assertTrue("#params: 0.11K, #flops: N/A" in model_str)  # Uncalled stats
         self.assertTrue("[[1, 10]]")  # Input sizes

         # Expected:
@@ -458,32 +458,32 @@ def test_flop_count_str(self) -> None:
         # "module was called. Missing values are still included in the "
         # "parent's total.\n"
         # "TestNet(\n"
-        # "  n_params: 0.33K, n_flops: 0.3K\n"
+        # "  #params: 0.33K, #flops: 0.3K\n"
         # "  (a1): A1(\n"
-        # "    n_params: 0.11K, n_flops: 100\n"
+        # "    #params: 0.11K, #flops: 100\n"
         # "    (b1): A1B1(\n"
-        # "      n_params: 0.11K, n_flops: 100\n"
+        # "      #params: 0.11K, #flops: 100\n"
         # "      (c1): A1B1C1(\n"
-        # "        n_params: 0.11K, n_flops: N/A\n"
+        # "        #params: 0.11K, #flops: N/A\n"
         # "        (d1): Linear(\n"
         # "          in_features=10, out_features=10, bias=True\n"
-        # "          n_params: 0.11K, n_flops: 100\n"
+        # "          #params: 0.11K, #flops: 100\n"
         # "        )\n"
         # "        (d2): ReLU()\n"
         # "      )\n"
         # "    )\n"
         # "  )\n"
         # "  (a2): A2(\n"
-        # "    n_params: 0.22K, n_flops: 0.2K\n"
+        # "    #params: 0.22K, #flops: 0.2K\n"
         # "    (b1): A2B1(\n"
-        # "      n_params: 0.22K, n_flops: 0.2K\n"
+        # "      #params: 0.22K, #flops: 0.2K\n"
         # "      (c1): Linear(\n"
         # "        in_features=10, out_features=10, bias=True\n"
-        # "        n_params: 0.11K, n_flops: 100\n"
+        # "        #params: 0.11K, #flops: 100\n"
         # "      )\n"
         # "      (c2): Linear(\n"
         # "        in_features=10, out_features=10, bias=True\n"
-        # "        n_params: 0.11K, n_flops: 100\n"
+        # "        #params: 0.11K, #flops: 100\n"
         # "      )\n"
         # "    )\n"
         # "  )\n"
@@ -495,8 +495,8 @@ def test_flop_count_str(self) -> None:
             activations=ActivationCountAnalysis(model, inputs).ancestor_mode("caller"),
         )

-        self.assertTrue("n_params: 0.33K, n_flops: 0.3K, n_acts: 30" in model_str)
-        self.assertTrue("n_params: 0.11K, n_flops: N/A, n_acts: N/A" in model_str)
+        self.assertTrue("#params: 0.33K, #flops: 0.3K, #acts: 30" in model_str)
+        self.assertTrue("#params: 0.11K, #flops: N/A, #acts: N/A" in model_str)

         # Expected:
@@ -505,32 +505,32 @@ def test_flop_count_str(self) -> None:
         # "module was called. Missing values are still included in the "
         # "parent's total.\n"
         # "TestNet(\n"
-        # "  n_params: 0.33K, n_flops: 0.3K, n_acts: 30\n"
+        # "  #params: 0.33K, #flops: 0.3K, #acts: 30\n"
         # "  (a1): A1(\n"
-        # "    n_params: 0.11K, n_flops: 100, n_acts: 10\n"
+        # "    #params: 0.11K, #flops: 100, #acts: 10\n"
         # "    (b1): A1B1(\n"
-        # "      n_params: 0.11K, n_flops: 100, n_acts: 10\n"
+        # "      #params: 0.11K, #flops: 100, #acts: 10\n"
         # "      (c1): A1B1C1(\n"
-        # "        n_params: 0.11K, n_flops: N/A, n_acts: N/A\n"
+        # "        #params: 0.11K, #flops: N/A, #acts: N/A\n"
         # "        (d1): Linear(\n"
         # "          in_features=10, out_features=10, bias=True\n"
-        # "          n_params: 0.11K, n_flops: 100, n_acts: 10\n"
+        # "          #params: 0.11K, #flops: 100, #acts: 10\n"
         # "        )\n"
         # "        (d2): ReLU()\n"
         # "      )\n"
         # "    )\n"
         # "  )\n"
         # "  (a2): A2(\n"
-        # "    n_params: 0.22K, n_flops: 0.2K, n_acts: 20\n"
+        # "    #params: 0.22K, #flops: 0.2K, #acts: 20\n"
         # "    (b1): A2B1(\n"
-        # "      n_params: 0.22K, n_flops: 0.2K, n_acts: 20\n"
+        # "      #params: 0.22K, #flops: 0.2K, #acts: 20\n"
         # "      (c1): Linear(\n"
         # "        in_features=10, out_features=10, bias=True\n"
-        # "        n_params: 0.11K, n_flops: 100, n_acts: 10\n"
+        # "        #params: 0.11K, #flops: 100, #acts: 10\n"
         # "      )\n"
         # "      (c2): Linear(\n"
         # "        in_features=10, out_features=10, bias=True\n"
-        # "        n_params: 0.11K, n_flops: 100, n_acts: 10\n"
+        # "        #params: 0.11K, #flops: 100, #acts: 10\n"
         # "      )\n"
         # "    )\n"
         # "  )\n"
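The second half of the test above also passes an activation analysis; the sketch below (again illustrative, not part of this commit) shows the "#acts" column that then appears alongside "#params"/"#flops". Per the message added in the diff, uncalled modules are reported as "N/A" while their counts remain in the parent's total, which is what the assertions check.

import torch
import torch.nn as nn

from fvcore.nn import ActivationCountAnalysis, FlopCountAnalysis, flop_count_str

# Illustrative toy model and input.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
inputs = (torch.randn(1, 10),)

# Supplying an ActivationCountAnalysis adds "#acts" to every module line.
model_str = flop_count_str(
    FlopCountAnalysis(model, inputs),
    activations=ActivationCountAnalysis(model, inputs),
)
print(model_str)
assert "#params" in model_str and "#acts" in model_str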
