From 8d9c8897eddf53b01ce760ff9ae9ec9e4c2c08b5 Mon Sep 17 00:00:00 2001 From: Brian Coutinho Date: Wed, 26 Jul 2023 03:44:57 +0000 Subject: [PATCH] [profiler] add option for kineto synchronization events in the trace (#105187) Summary: ## About Sync Events For CUDA profiling mode, we can enable tracing CUDA synchronization events. * This feature captures synchronization events in CUDA including 1) context/device sync, 2) stream sync, 3) CUDA event sync, 4) CUDA stream wait event (inter stream synchronization). Read more * We add this flag using the profiler's experimental config option. * This PR relies on https://github.com/pytorch/kineto/commit/7b003638c6d65537ac9678c101f6e9adf013b157 change in pytorch/kineto ## Usage Just set the `enable_cuda_sync_events` option in `_ExperimentalConfig` ``` from torch.autograd.profiler import profile, _ExperimentalConfig with profile(use_kineto=True, use_cuda=True, experimental_config=_ExperimentalConfig(enable_cuda_sync_events=True), ) as prof: workload() ``` **Please wait for PyTorch github repo to point to https://github.com/pytorch/kineto/commit/7b003638c6d65537ac9678c101f6e9adf013b157 or later commit in Kineto** Test Plan: ## Unit Test Added a unit test buck2 test mode/dev-nosan caffe2/test:profiler --local-only -- test_profiler_cuda_sync_events Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0 ttps://www.internalfb.com/intern/testinfra/testrun/281475298097379 Reviewed By: davidberard98 Differential Revision: D46244591 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105187 Approved by: https://github.com/aaronenyeshi --- test/profiler/test_profiler.py | 38 ++++++++++++++++++- test/profiler/test_profiler_tree.py | 4 ++ torch/_C/_profiler.pyi | 2 + torch/csrc/autograd/profiler_kineto.cpp | 2 +- torch/csrc/profiler/collection.cpp | 19 ++++++++++ torch/csrc/profiler/collection.h | 4 ++ torch/csrc/profiler/kineto_shim.cpp | 6 +++ .../csrc/profiler/orchestration/observer.cpp | 2 + torch/csrc/profiler/orchestration/observer.h | 7 ++++ torch/csrc/profiler/python/init.cpp | 26 +++++++++---- 10 files changed, 99 insertions(+), 11 deletions(-) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 1a2071bbf28a36..c84e801dd23706 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1485,10 +1485,44 @@ def test_profiler_disable_fwd_bwd_link(self): events = j["traceEvents"] for e in events: - self.assertNotEqual(getattr(e, "cat", None), "fwdbwd") + self.assertNotEqual(e.get("cat", None), "fwdbwd") finally: torch._C._profiler._set_fwd_bwd_enabled_val(True) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_profiler_cuda_sync_events(self): + device = torch.device("cuda:0") + t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device) + + def workload() -> None: + torch.add(t1, t2) + torch.cuda.synchronize() + torch.add(t1, t2) + + def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None: + with _profile(use_kineto=True, use_cuda=True, + experimental_config=exp_config, + ) as prof: + workload() + + with TemporaryFileName(mode="w+") as fname: + # fname = "/tmp/kineto_out.json" + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + cats = {e.get("cat", None) for e in j["traceEvents"]} + self.assertTrue("cuda_sync" in cats, "Expected to find cuda_sync event" + f" found = {cats}") + + print("Testing enable_cuda_sync_events in _ExperimentalConfig") + trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True)) + + print("Testing _profiler._set_cuda_sync_enabled_val()") + try: + torch._C._profiler._set_cuda_sync_enabled_val(True) + trace_and_check(exp_config=None) + finally: + torch._C._profiler._set_cuda_sync_enabled_val(False) def test_profiler_type(self): profiler_type = torch._C._autograd._profiler_type @@ -1522,7 +1556,7 @@ def test_profiler_correlation_id(self): model(inputs) for event in prof.profiler.kineto_results.events(): corr_id = event.correlation_id() - if (corr_id): + if (corr_id) and event.device_type() == DeviceType.CPU: self.assertTrue(corr_id not in id_uniqueness_set) id_uniqueness_set.add(corr_id) self.assertTrue(corr_id < uint32_max) diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 1c87193c92cf69..267953a282149e 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -128,6 +128,10 @@ def flatten(nodes, depth=0, out=None): flat_nodes = flatten(profiler.kineto_results.experimental_event_tree()) # Profiler inserts a `cudaDeviceSynchronize` at the end of profiling. + # and may also insert 'Context Sync' CUDA synchronization event. + if flat_nodes and flat_nodes[-2][1] == "cudaDeviceSynchronize": + flat_nodes = flat_nodes[:-2] + if flat_nodes and flat_nodes[-1][1] == "cudaDeviceSynchronize": flat_nodes = flat_nodes[:-1] diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index 8dbb20867f3d7e..c90312660266ab 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -54,6 +54,7 @@ class _ExperimentalConfig: profiler_metrics: List[str] = ..., profiler_measure_per_kernel: bool = ..., verbose: bool = ..., + enable_cuda_sync_events: bool = ..., ) -> None: ... class ProfilerConfig: @@ -216,3 +217,4 @@ def _enable_execution_trace_observer() -> None: ... def _disable_execution_trace_observer() -> None: ... def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ... def _set_fwd_bwd_enabled_val(val: bool) -> None: ... +def _set_cuda_sync_enabled_val(val: bool) -> None: ... diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 30c530bed3ce4e..e59abc859f7153 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -378,7 +378,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { void materializeOpEvents(std::vector>& events) { for (auto& e : events) { - if (e->parent_.expired()) { + if (e->parent_.expired() && e->deviceType() == c10::DeviceType::CPU) { event_tree_.push_back(e); } diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 0da4213c33e292..c3820893418f50 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -1460,6 +1460,25 @@ void set_fwd_bwd_enabled_val(bool val) { fwd_bwd_enabled_fn() = [val]() { return val; }; } +namespace { +std::function& cuda_sync_enabled_fn() { + static std::function fn = []() { return true; }; + return fn; +} +} // namespace + +bool get_cuda_sync_enabled() { + return cuda_sync_enabled_fn()(); +} + +void set_cuda_sync_enabled_fn(std::function fn) { + cuda_sync_enabled_fn() = std::move(fn); +} + +void set_cuda_sync_enabled_val(bool val) { + cuda_sync_enabled_fn() = [val]() { return val; }; +} + } // namespace impl } // namespace profiler } // namespace torch diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 0f0e58f410db0b..20ef0c85836e94 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -637,6 +637,10 @@ TORCH_API bool get_fwd_bwd_enabled(); TORCH_API void set_fwd_bwd_enabled_fn(std::function); TORCH_API void set_fwd_bwd_enabled_val(bool); +TORCH_API bool get_cuda_sync_enabled(); +TORCH_API void set_cuda_sync_enabled_fn(std::function); +TORCH_API void set_cuda_sync_enabled_val(bool); + } // namespace impl } // namespace profiler } // namespace torch diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index e21a0c89860bea..59e721eec41192 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -235,6 +236,10 @@ void prepareTrace( } if (activities.count(torch::autograd::profiler::ActivityType::CUDA)) { k_activities.insert(kCudaTypes.begin(), kCudaTypes.end()); + if (config.enable_cuda_sync_events || get_cuda_sync_enabled()) { + LOG(INFO) << "Enabling CUDA Sync Events"; + k_activities.insert(libkineto::ActivityType::CUDA_SYNC); + } } ExperimentalConfigWrapper configWrap(config); @@ -320,6 +325,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { case libkineto::ActivityType::GPU_MEMCPY: case libkineto::ActivityType::GPU_MEMSET: case libkineto::ActivityType::CONCURRENT_KERNEL: + case libkineto::ActivityType::CUDA_SYNC: case libkineto::ActivityType::GPU_USER_ANNOTATION: case libkineto::ActivityType::CUDA_PROFILER_RANGE: // TODO: T151322015 diff --git a/torch/csrc/profiler/orchestration/observer.cpp b/torch/csrc/profiler/orchestration/observer.cpp index d0cb0823668147..f027ed7feac338 100644 --- a/torch/csrc/profiler/orchestration/observer.cpp +++ b/torch/csrc/profiler/orchestration/observer.cpp @@ -18,11 +18,13 @@ ExperimentalConfig::ExperimentalConfig( bool profiler_measure_per_kernel, bool verbose, std::vector performance_events, + bool enable_cuda_sync_events, bool adjust_timestamps) : profiler_metrics{std::move(profiler_metrics)}, profiler_measure_per_kernel{profiler_measure_per_kernel}, verbose{verbose}, performance_events(std::move(performance_events)), + enable_cuda_sync_events{enable_cuda_sync_events}, adjust_timestamps{adjust_timestamps} {} /*explicit*/ ExperimentalConfig::operator bool() const { diff --git a/torch/csrc/profiler/orchestration/observer.h b/torch/csrc/profiler/orchestration/observer.h index f3945acdb9d3df..5d42f9234c381f 100644 --- a/torch/csrc/profiler/orchestration/observer.h +++ b/torch/csrc/profiler/orchestration/observer.h @@ -47,6 +47,7 @@ struct TORCH_API ExperimentalConfig { bool profiler_measure_per_kernel = false, bool verbose = false, std::vector performance_events = {}, + bool enable_cuda_sync_events = false, bool adjust_timestamps = false); ~ExperimentalConfig() = default; explicit operator bool() const; @@ -59,6 +60,12 @@ struct TORCH_API ExperimentalConfig { * An empty list will disable performance event based profiling altogether. */ std::vector performance_events; + /* + * For CUDA profiling mode, enable adding CUDA synchronization events + * that expose CUDA device, stream and event synchronization activities. + * This feature is new and currently disabled by default. + */ + bool enable_cuda_sync_events; /* * Controls whether or not timestamp adjustment occurs after profiling. * The purpose of this is to adjust Vulkan event timelines to align with those diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 9f31a1b6f66fc1..be3bee9d74de24 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -62,7 +62,8 @@ void initPythonBindings(PyObject* module) { std::vector /* profiler_metrics */, bool /* profiler_measure_per_kernel */, bool /* verbose */, - std::vector /* performance_events */ + std::vector /* performance_events */, + bool /* enable_cuda_sync_events */ >(), "An experimental config for Kineto features. Please note that" "backward compatibility is not guaranteed.\n" @@ -72,11 +73,15 @@ void initPythonBindings(PyObject* module) { " profiler_measure_per_kernel (bool) : whether to profile metrics per kernel\n" " or for the entire measurement duration.\n" " verbose (bool) : whether the trace file has `Call stack` field or not.\n" - " performance_events : a list of profiler events to be used for measurement", + " performance_events : a list of profiler events to be used for measurement.\n" + " enable_cuda_sync_events : for CUDA profiling mode, enable adding CUDA synchronization events\n" + " that expose CUDA device, stream and event synchronization activities. This feature is new\n" + " and currently disabled by default.\n", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, py::arg("verbose") = false, - py::arg("performance_events") = std::vector()) + py::arg("performance_events") = std::vector(), + py::arg("enable_cuda_sync_events") = false) .def(py::pickle( [](const ExperimentalConfig& p) { // __getstate__ py::list py_metrics; @@ -94,11 +99,12 @@ void initPythonBindings(PyObject* module) { py_metrics, p.profiler_measure_per_kernel, p.verbose, + p.enable_cuda_sync_events, p.performance_events); }, [](py::tuple t) { // __setstate__ - if (t.size() >= 3) { - throw std::runtime_error("Expected atleast 3 values in state"); + if (t.size() >= 4) { + throw std::runtime_error("Expected atleast 4 values in state"); } py::list py_metrics = t[0].cast(); @@ -109,8 +115,8 @@ void initPythonBindings(PyObject* module) { } std::vector performance_events; - if (t.size() == 4) { - py::list py_perf_events = t[3].cast(); + if (t.size() == 5) { + py::list py_perf_events = t[4].cast(); performance_events.resize(py_perf_events.size()); for (const auto& py_perf_event : py_perf_events) { performance_events.push_back(py::str(py_perf_event)); @@ -121,7 +127,8 @@ void initPythonBindings(PyObject* module) { std::move(metrics), t[1].cast(), t[2].cast(), - std::move(performance_events)); + std::move(performance_events), + t[3].cast()); })); py::class_(m, "ProfilerConfig") @@ -303,6 +310,9 @@ void initPythonBindings(PyObject* module) { m.def( "_set_fwd_bwd_enabled_val", &torch::profiler::impl::set_fwd_bwd_enabled_val); + m.def( + "_set_cuda_sync_enabled_val", + &torch::profiler::impl::set_cuda_sync_enabled_val); py::class_>( m, "CapturedTraceback");