[TorchTidy] Add Pattern to detect Synchronous Data Loader (pytorch#81740)

Summary: Setting num_workers > 0 in DataLoader enables asynchronous data loading, which does not block computation and therefore helps speed up training. By matching the call structure of the DataLoader's __iter__, we can detect whether a synchronous DataLoader is being used.
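For illustration (not part of this change), a minimal sketch of the anti-pattern and the suggested fix; the dataset, batch size, and worker count are placeholder values:

import torch
from torch.utils.data import DataLoader

dataset = torch.rand(100, 100)  # placeholder dataset

# Synchronous loading (num_workers defaults to 0): each batch is prepared
# in the main process, blocking the training computation.
sync_loader = DataLoader(dataset, batch_size=10)

# Asynchronous loading: worker processes prefetch batches in the background,
# overlapping data loading with computation. The right worker count depends
# on the system configuration.
async_loader = DataLoader(dataset, batch_size=10, num_workers=4)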

Test Plan:
Added test in test/test_profiler.py

Differential Revision: [D38082644](https://our.internmc.facebook.com/intern/diff/D38082644)
Pull Request resolved: pytorch#81740
Approved by: https://github.com/robieta
davidchencsl authored and pytorchmergebot committed Jul 27, 2022
1 parent df1b7c2 commit d537f86
Showing 2 changed files with 62 additions and 7 deletions.
13 changes: 12 additions & 1 deletion test/test_profiler.py
@@ -32,7 +32,8 @@
ExtraCUDACopyPattern,
ForLoopIndexingPattern,
FP32MatMulPattern,
OptimizerSingleTensorPattern)
OptimizerSingleTensorPattern,
SynchronizedDataLoaderPattern)
from torch.testing._internal.common_device_type import skipCUDAVersionIn

try:
@@ -1655,6 +1656,16 @@ def test_profiler_optimizer_single_tensor_pattern(self):
num_matched.append(len(pattern.matched_events()))
self.assertEqual(num_matched, [i for i, _ in cases])

    def test_profiler_synchronized_dataloader_pattern(self):
        dataset = torch.rand((100, 100))
        sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
        async_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=4)
        with profile(with_stack=True) as prof:
            next(iter(sync_dataloader))
            next(iter(async_dataloader))
        pattern = SynchronizedDataLoaderPattern(prof)
        num_matched = len(pattern.matched_events())
        self.assertEqual(num_matched, 1)

if __name__ == '__main__':
    run_tests()
56 changes: 50 additions & 6 deletions torch/profiler/_pattern_matcher.py
@@ -1,4 +1,5 @@
from collections import deque
import os
import re
from typing import Dict, List, Set

@@ -340,12 +341,11 @@ class OptimizerSingleTensorPattern(Pattern):
def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "Optimizer Single Tensor Pattern"
self.optimizers_with_foreach = [
"adam", "sgd", "adamw"
]
self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
self.description = (
"Deteced optimizer running with single tensor implementation. "
"Please enable multi tensor implementation by passing 'foreach=True' into optimizer.")
"Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
)

def match(self, event: _ProfilerEvent):
for optimizer in self.optimizers_with_foreach:
@@ -354,14 +354,57 @@ def match(self, event: _ProfilerEvent):
return False


class SynchronizedDataLoaderPattern(Pattern):
    '''
    This pattern identifies if we are using num_workers=0 in DataLoader.
    example:
        torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    Add num_workers=N to the arguments. N depends on the system configuration.
    Pattern:
        dataloader.py(...): __iter__
        dataloader.py(...): _get_iterator
        NOT dataloader.py(...): check_worker_number_rationality
    Algorithm:
        If we don't see the check_worker_number_rationality call in the dataloader __iter__,
        it is not an asynchronous dataloader.
    '''

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Synchronized DataLoader Pattern"
        self.description = (
            "Detected DataLoader running with synchronized implementation. "
            "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
        )

    def match(self, event: _ProfilerEvent):
        def is_dataloader_function(name: str, function_name: str):
            return name.startswith(
                os.path.join("torch", "utils", "data", "dataloader.py")
            ) and name.endswith(function_name)

        if not is_dataloader_function(event.name(), "__iter__"):
            return False
        if not event.children:
            return False
        event = event.children[0]
        if not is_dataloader_function(event.name(), "_get_iterator"):
            return False
        if not event.children:
            return False
        event = event.children[0]
        return not is_dataloader_function(event.name(), "check_worker_number_rationality")
        # TODO: We should also check if the loader is the bottleneck.


def source_code_location(event: _ProfilerEvent):
while event:
if event_type(event) == _EventType.PyCall or event_type(
event) == _EventType.PyCCall:
assert isinstance(event.extra_fields,
_ExtraFields_PyCall) or isinstance(
event.extra_fields, _ExtraFields_PyCCall)
if not event.extra_fields.caller.file_name.startswith("torch/"):
if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
event = event.parent
return "No source code location found"
@@ -377,7 +420,8 @@ def report_all_anti_patterns(prof, should_benchmark: bool = False):
ExtraCUDACopyPattern(prof, should_benchmark),
ForLoopIndexingPattern(prof, should_benchmark),
FP32MatMulPattern(prof, should_benchmark),
OptimizerSingleTensorPattern(prof, should_benchmark)
OptimizerSingleTensorPattern(prof, should_benchmark),
SynchronizedDataLoaderPattern(prof, should_benchmark)
]
reported = set()
summaries = []
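As a usage note, the new pattern is also surfaced by report_all_anti_patterns in the same module, torch.profiler._pattern_matcher, as shown in the diff above. A minimal sketch, assuming that private entry point and a placeholder dataset:

import torch
from torch.profiler import profile
from torch.profiler._pattern_matcher import report_all_anti_patterns

dataset = torch.rand(100, 100)  # placeholder dataset
loader = torch.utils.data.DataLoader(dataset, batch_size=10)  # num_workers defaults to 0

# with_stack=True is needed so the matcher can see the dataloader.py call
# structure (__iter__ -> _get_iterator -> check_worker_number_rationality).
with profile(with_stack=True) as prof:
    next(iter(loader))

# Reports all matched anti-patterns, including the
# "Synchronized DataLoader Pattern" added by this commit.
report_all_anti_patterns(prof)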
