[TorchTidy] Add option to generate json report (pytorch#82261)
Pull Request resolved: pytorch#82261
Approved by: https://github.com/robieta
davidchencsl authored and pytorchmergebot committed Aug 3, 2022
1 parent 1a74fd1 commit 7922bbe
Showing 2 changed files with 83 additions and 9 deletions.
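
For context, this PR adds two arguments to report_all_anti_patterns: print_enable (default True) and json_report_dir (default None). When a directory is supplied, the findings are written to (or merged into) torchtidy_report.json in that directory instead of only being printed. A minimal usage sketch, mirroring the new test and assuming report_all_anti_patterns is importable from torch.profiler._pattern_matcher as in the test's import block:

import json

import torch
import torch.nn as nn
from torch.profiler import profile
from torch.profiler._pattern_matcher import report_all_anti_patterns

# Profile a small workload; with_stack and record_shapes let the pattern
# matcher resolve source locations and input shapes.
model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 10))
x = torch.ones((100, 100))
with profile(with_stack=True, record_shapes=True) as prof:
    model(x)

# Write the findings to ./torchtidy_report.json instead of printing them.
report_all_anti_patterns(prof, json_report_dir=".", print_enable=False)

with open("torchtidy_report.json") as f:
    print(json.dumps(json.load(f), indent=2))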
29 changes: 28 additions & 1 deletion test/test_profiler.py
@@ -36,7 +36,8 @@
SynchronizedDataLoaderPattern,
GradNotSetToNonePattern,
Conv2dBiasFollowedByBatchNorm2dPattern,
MatMulDimInFP16Pattern)
MatMulDimInFP16Pattern,
report_all_anti_patterns)
from torch.testing._internal.common_device_type import skipCUDAVersionIn

try:
@@ -1728,6 +1729,32 @@ def test_profiler_matmul_dim_fp16_pattern(self):
num_matched.append(len(pattern.matched_events()))
self.assertEqual(num_matched, [i for i, _ in cases])

def test_profiler_pattern_matcher_json_report(self):
x = torch.ones((100, 100))
model = nn.Sequential(
nn.Linear(100, 100),
nn.ReLU(),
nn.Linear(100, 10),
)
optimizer = torch.optim.Adam(model.parameters())
with profile(with_stack=True, record_shapes=True) as prof:
y_hat = model(x)
loss = torch.nn.functional.cross_entropy(y_hat, torch.randint(0, 10, (100,)))
loss.backward()
optimizer.step()
optimizer.zero_grad()
report_all_anti_patterns(prof, json_report_dir=".", print_enable=False)
try:
with open("./torchtidy_report.json") as f:
report = json.load(f)
self.assertTrue("test_profiler.py" in report)
self.assertTrue(len(report["test_profiler.py"]) > 0)
expected_fields = sorted(["line_number", "name", "url", "message"])
for event in report["test_profiler.py"]:
actual_fields = sorted(event.keys())
self.assertEqual(expected_fields, actual_fields)
finally:
os.remove("torchtidy_report.json")

if __name__ == '__main__':
run_tests()
63 changes: 55 additions & 8 deletions torch/profiler/_pattern_matcher.py
@@ -1,4 +1,5 @@
from collections import deque
import json
import math
import os
import re
@@ -26,6 +27,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
self.should_benchmark = should_benchmark
self.name = "Please specify a name for pattern"
self.description = "Please specify a description for pattern"
self.url = ""
assert prof.profiler is not None and prof.profiler.kineto_results is not None
self.event_tree = prof.profiler.kineto_results.experimental_event_tree(
)
@@ -52,11 +54,13 @@ def summary(self, events: List[_ProfilerEvent]):
default_summary = f"{self.name}: {len(events)} events matched."
if self.should_benchmark:
# If benchmark summary is not empty, use it.
return self.benchmark_summary(events) if hasattr( # type: ignore[attr-defined]
self, 'benchmark') else default_summary
return self.benchmark_summary(
events) if hasattr( # type: ignore[attr-defined]
self, 'benchmark') else default_summary
return default_summary

def benchmark_summary(self, events: List[_ProfilerEvent]):

def format_time(time_ns: int):
unit_lst = ["ns", "us", "ms"]
for unit in unit_lst:
@@ -66,7 +70,8 @@ def format_time(time_ns: int):
return f"{time_ns:.2f} s"

assert hasattr(self, 'benchmark'), 'Please implement benchmark()'
shapes_factor_map = self.benchmark(events) # type: ignore[attr-defined]
shapes_factor_map = self.benchmark( # type: ignore[attr-defined]
events)
original_time = sum(event.duration_time_ns for event in events)
new_time = sum(shapes_factor_map[input_shapes(event)] *
event.duration_time_ns for event in events)
@@ -158,6 +163,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "Extra CUDA Copy Pattern"
self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initalize it on GPU."
self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
self.init_ops = {
"aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
}
@@ -273,6 +279,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
"You are currently using GPU that supports TF32. "
"Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
)
self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"

@property
def skip(self):
@@ -341,6 +348,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
"Deteced optimizer running with single tensor implementation. "
"Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
)
self.url = ""

def match(self, event: _ProfilerEvent):
for optimizer in self.optimizers_with_foreach:
@@ -374,10 +382,17 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
"Detected DataLoader running with synchronized implementation. "
"Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
)
self.url = (
"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
"#enable-async-data-loading-and-augmentation")

def match(self, event: _ProfilerEvent):

def is_dataloader_function(name: str, function_name: str):
return name.startswith(os.path.join("torch", "utils", "data", "dataloader.py")) and name.endswith(function_name)
return name.startswith(
os.path.join("torch", "utils", "data",
"dataloader.py")) and name.endswith(function_name)

if not is_dataloader_function(event.name(), "__iter__"):
return False
if not event.children:
@@ -388,7 +403,8 @@ def is_dataloader_function(name: str, function_name: str):
if not event.children:
return False
event = event.children[0]
return not is_dataloader_function(event.name(), "check_worker_number_rationality")
return not is_dataloader_function(event.name(),
"check_worker_number_rationality")
# TODO: We should also check if the loader is bottleneck.


@@ -417,6 +433,9 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
self.description = (
"Detected gradient set to zero instead of None. "
"Please add 'set_to_none=True' when calling zero_grad().")
self.url = (
"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
"#disable-gradient-calculation-for-validation-or-inference")

def match(self, event: _ProfilerEvent):
if not event.name().endswith(": zero_grad"):
@@ -449,6 +468,9 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
self.url = (
"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
"#disable-bias-for-convolutions-directly-followed-by-a-batch-norm")

@property
def skip(self):
@@ -476,6 +498,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"

@property
def skip(self):
@@ -538,7 +561,8 @@ def source_code_location(event: _ProfilerEvent):
assert isinstance(event.extra_fields,
_ExtraFields_PyCall) or isinstance(
event.extra_fields, _ExtraFields_PyCCall)
if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
if not event.extra_fields.caller.file_name.startswith("torch" +
os.sep):
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
event = event.parent
return "No source code location found"
@@ -578,7 +602,11 @@ def eventTreeBFS(event_tree: List[_ProfilerEvent]):
stack.append(child_event)


def report_all_anti_patterns(prof, should_benchmark: bool = False):
def report_all_anti_patterns(prof,
should_benchmark: bool = False,
print_enable: bool = True,
json_report_dir: str = None):
report_dict: Dict = {}
anti_patterns = [
ExtraCUDACopyPattern(prof, should_benchmark),
ForLoopIndexingPattern(prof, should_benchmark),
Expand All @@ -604,8 +632,27 @@ def report_all_anti_patterns(prof, should_benchmark: bool = False):
if report_msg not in reported:
message_list.append(report_msg)
reported.add(report_msg)
src_location, line_no = source_code_location(event).rsplit(":", 1)  # rsplit so a drive-letter colon in the path is not split
report_dict.setdefault(src_location, []).append({
"line_number": int(line_no),
"name": anti_pattern.name,
"url": anti_pattern.url,
"message": anti_pattern.description,
})

if json_report_dir is not None:
json_report_path = os.path.join(json_report_dir,
"torchtidy_report.json")
if os.path.exists(json_report_path):
with open(json_report_path, "r") as f:
existing_report = json.load(f)
existing_report.update(report_dict)
report_dict = existing_report
with open(json_report_path, "w") as f:
json.dump(report_dict, f, indent=4)

message_list.append("Summary:")
message_list += summaries
message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
print("\n".join(message_list))
if print_enable:
print("\n".join(message_list))
