[profiler] add option for kineto synchronization events in the trace (pytorch#105187)

Summary:
## About Sync Events
For CUDA profiling mode, we can enable tracing of CUDA synchronization events.
* This feature captures CUDA synchronization events, including 1) context/device sync, 2) stream sync, 3) CUDA event sync, and 4) CUDA stream wait event (inter-stream synchronization).
* The flag is added through the profiler's experimental config option.
* This PR relies on the pytorch/kineto@7b00363 change in Kineto.

## Usage
Just set the `enable_cuda_sync_events` option in `_ExperimentalConfig`:
```python
from torch.autograd.profiler import profile, _ExperimentalConfig

with profile(
    use_kineto=True,
    use_cuda=True,
    experimental_config=_ExperimentalConfig(enable_cuda_sync_events=True),
) as prof:
    workload()
```
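
This PR also adds a process-wide toggle, `torch._C._profiler._set_cuda_sync_enabled_val()`, which the new unit test exercises as well. A minimal sketch of that route, assuming the same `workload()` placeholder as above (not taken verbatim from the PR):
```python
import torch
from torch.autograd.profiler import profile

# Turn on CUDA sync event collection process-wide, mirroring the new unit test,
# and turn it back off afterwards.
torch._C._profiler._set_cuda_sync_enabled_val(True)
try:
    with profile(use_kineto=True, use_cuda=True) as prof:
        workload()  # assumed to be defined as in the example above
    prof.export_chrome_trace("trace_with_sync_events.json")  # illustrative path
finally:
    torch._C._profiler._set_cuda_sync_enabled_val(False)
```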

**Please wait for the PyTorch GitHub repo to point to pytorch/kineto@7b00363 or a later Kineto commit.**
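
In the exported Chrome trace, the synchronization events show up under the `cuda_sync` category (this is what the new unit test asserts). A short sketch for checking a trace for them; the `trace.json` path is just an example:
```python
import json

with open("trace.json") as f:
    trace = json.load(f)

# Kineto sync events are emitted with category "cuda_sync".
sync_events = [e for e in trace["traceEvents"] if e.get("cat") == "cuda_sync"]
print(f"found {len(sync_events)} cuda_sync events")
for e in sync_events[:5]:
    print(e.get("name"), e.get("ts"), e.get("dur"))
```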

Test Plan:
## Unit Test
Added a unit test

  buck2 test mode/dev-nosan caffe2/test:profiler --local-only -- test_profiler_cuda_sync_events
  Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
https://www.internalfb.com/intern/testinfra/testrun/281475298097379
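
On an OSS CUDA build, a roughly equivalent invocation would be (assuming pytest is available; not part of the original test plan):

  pytest test/profiler/test_profiler.py -k test_profiler_cuda_sync_events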

Reviewed By: davidberard98

Differential Revision: D46244591

Pull Request resolved: pytorch#105187
Approved by: https://github.com/aaronenyeshi
briancoutinho authored and pytorchmergebot committed Jul 26, 2023
1 parent a770295 commit 8d9c889
Showing 10 changed files with 99 additions and 11 deletions.
38 changes: 36 additions & 2 deletions test/profiler/test_profiler.py
@@ -1485,10 +1485,44 @@ def test_profiler_disable_fwd_bwd_link(self):
events = j["traceEvents"]

for e in events:
self.assertNotEqual(getattr(e, "cat", None), "fwdbwd")
self.assertNotEqual(e.get("cat", None), "fwdbwd")
finally:
torch._C._profiler._set_fwd_bwd_enabled_val(True)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_profiler_cuda_sync_events(self):
device = torch.device("cuda:0")
t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)

def workload() -> None:
torch.add(t1, t2)
torch.cuda.synchronize()
torch.add(t1, t2)

def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
with _profile(use_kineto=True, use_cuda=True,
experimental_config=exp_config,
) as prof:
workload()

with TemporaryFileName(mode="w+") as fname:
# fname = "/tmp/kineto_out.json"
prof.export_chrome_trace(fname)
with open(fname) as f:
j = json.load(f)
cats = {e.get("cat", None) for e in j["traceEvents"]}
self.assertTrue("cuda_sync" in cats, "Expected to find cuda_sync event"
f" found = {cats}")

print("Testing enable_cuda_sync_events in _ExperimentalConfig")
trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))

print("Testing _profiler._set_cuda_sync_enabled_val()")
try:
torch._C._profiler._set_cuda_sync_enabled_val(True)
trace_and_check(exp_config=None)
finally:
torch._C._profiler._set_cuda_sync_enabled_val(False)

def test_profiler_type(self):
profiler_type = torch._C._autograd._profiler_type
@@ -1522,7 +1556,7 @@ def test_profiler_correlation_id(self):
model(inputs)
for event in prof.profiler.kineto_results.events():
corr_id = event.correlation_id()
if (corr_id):
if (corr_id) and event.device_type() == DeviceType.CPU:
self.assertTrue(corr_id not in id_uniqueness_set)
id_uniqueness_set.add(corr_id)
self.assertTrue(corr_id < uint32_max)
4 changes: 4 additions & 0 deletions test/profiler/test_profiler_tree.py
@@ -128,6 +128,10 @@ def flatten(nodes, depth=0, out=None):
flat_nodes = flatten(profiler.kineto_results.experimental_event_tree())

# Profiler inserts a `cudaDeviceSynchronize` at the end of profiling.
# and may also insert 'Context Sync' CUDA synchronization event.
if flat_nodes and flat_nodes[-2][1] == "cudaDeviceSynchronize":
flat_nodes = flat_nodes[:-2]

if flat_nodes and flat_nodes[-1][1] == "cudaDeviceSynchronize":
flat_nodes = flat_nodes[:-1]

2 changes: 2 additions & 0 deletions torch/_C/_profiler.pyi
@@ -54,6 +54,7 @@ class _ExperimentalConfig:
profiler_metrics: List[str] = ...,
profiler_measure_per_kernel: bool = ...,
verbose: bool = ...,
enable_cuda_sync_events: bool = ...,
) -> None: ...

class ProfilerConfig:
@@ -216,3 +217,4 @@ def _enable_execution_trace_observer() -> None: ...
def _disable_execution_trace_observer() -> None: ...
def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
def _set_cuda_sync_enabled_val(val: bool) -> None: ...
2 changes: 1 addition & 1 deletion torch/csrc/autograd/profiler_kineto.cpp
@@ -378,7 +378,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {

void materializeOpEvents(std::vector<std::shared_ptr<Result>>& events) {
for (auto& e : events) {
if (e->parent_.expired()) {
if (e->parent_.expired() && e->deviceType() == c10::DeviceType::CPU) {
event_tree_.push_back(e);
}

19 changes: 19 additions & 0 deletions torch/csrc/profiler/collection.cpp
@@ -1460,6 +1460,25 @@ void set_fwd_bwd_enabled_val(bool val) {
fwd_bwd_enabled_fn() = [val]() { return val; };
}

namespace {
std::function<bool()>& cuda_sync_enabled_fn() {
static std::function<bool()> fn = []() { return true; };
return fn;
}
} // namespace

bool get_cuda_sync_enabled() {
return cuda_sync_enabled_fn()();
}

void set_cuda_sync_enabled_fn(std::function<bool()> fn) {
cuda_sync_enabled_fn() = std::move(fn);
}

void set_cuda_sync_enabled_val(bool val) {
cuda_sync_enabled_fn() = [val]() { return val; };
}

} // namespace impl
} // namespace profiler
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/profiler/collection.h
@@ -637,6 +637,10 @@ TORCH_API bool get_fwd_bwd_enabled();
TORCH_API void set_fwd_bwd_enabled_fn(std::function<bool()>);
TORCH_API void set_fwd_bwd_enabled_val(bool);

TORCH_API bool get_cuda_sync_enabled();
TORCH_API void set_cuda_sync_enabled_fn(std::function<bool()>);
TORCH_API void set_cuda_sync_enabled_val(bool);

} // namespace impl
} // namespace profiler
} // namespace torch
6 changes: 6 additions & 0 deletions torch/csrc/profiler/kineto_shim.cpp
@@ -1,3 +1,4 @@
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/kineto_shim.h>

#include <type_traits>
@@ -235,6 +236,10 @@ void prepareTrace(
}
if (activities.count(torch::autograd::profiler::ActivityType::CUDA)) {
k_activities.insert(kCudaTypes.begin(), kCudaTypes.end());
if (config.enable_cuda_sync_events || get_cuda_sync_enabled()) {
LOG(INFO) << "Enabling CUDA Sync Events";
k_activities.insert(libkineto::ActivityType::CUDA_SYNC);
}
}

ExperimentalConfigWrapper configWrap(config);
@@ -320,6 +325,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) {
case libkineto::ActivityType::GPU_MEMCPY:
case libkineto::ActivityType::GPU_MEMSET:
case libkineto::ActivityType::CONCURRENT_KERNEL:
case libkineto::ActivityType::CUDA_SYNC:
case libkineto::ActivityType::GPU_USER_ANNOTATION:
case libkineto::ActivityType::CUDA_PROFILER_RANGE:
// TODO: T151322015
2 changes: 2 additions & 0 deletions torch/csrc/profiler/orchestration/observer.cpp
@@ -18,11 +18,13 @@ ExperimentalConfig::ExperimentalConfig(
bool profiler_measure_per_kernel,
bool verbose,
std::vector<std::string> performance_events,
bool enable_cuda_sync_events,
bool adjust_timestamps)
: profiler_metrics{std::move(profiler_metrics)},
profiler_measure_per_kernel{profiler_measure_per_kernel},
verbose{verbose},
performance_events(std::move(performance_events)),
enable_cuda_sync_events{enable_cuda_sync_events},
adjust_timestamps{adjust_timestamps} {}

/*explicit*/ ExperimentalConfig::operator bool() const {
7 changes: 7 additions & 0 deletions torch/csrc/profiler/orchestration/observer.h
@@ -47,6 +47,7 @@ struct TORCH_API ExperimentalConfig {
bool profiler_measure_per_kernel = false,
bool verbose = false,
std::vector<std::string> performance_events = {},
bool enable_cuda_sync_events = false,
bool adjust_timestamps = false);
~ExperimentalConfig() = default;
explicit operator bool() const;
@@ -59,6 +60,12 @@ struct TORCH_API ExperimentalConfig {
* An empty list will disable performance event based profiling altogether.
*/
std::vector<std::string> performance_events;
/*
* For CUDA profiling mode, enable adding CUDA synchronization events
* that expose CUDA device, stream and event synchronization activities.
* This feature is new and currently disabled by default.
*/
bool enable_cuda_sync_events;
/*
* Controls whether or not timestamp adjustment occurs after profiling.
* The purpose of this is to adjust Vulkan event timelines to align with those
26 changes: 18 additions & 8 deletions torch/csrc/profiler/python/init.cpp
@@ -62,7 +62,8 @@ void initPythonBindings(PyObject* module) {
std::vector<std::string> /* profiler_metrics */,
bool /* profiler_measure_per_kernel */,
bool /* verbose */,
std::vector<std::string> /* performance_events */
std::vector<std::string> /* performance_events */,
bool /* enable_cuda_sync_events */
>(),
"An experimental config for Kineto features. Please note that"
"backward compatibility is not guaranteed.\n"
@@ -72,11 +73,15 @@
" profiler_measure_per_kernel (bool) : whether to profile metrics per kernel\n"
" or for the entire measurement duration.\n"
" verbose (bool) : whether the trace file has `Call stack` field or not.\n"
" performance_events : a list of profiler events to be used for measurement",
" performance_events : a list of profiler events to be used for measurement.\n"
" enable_cuda_sync_events : for CUDA profiling mode, enable adding CUDA synchronization events\n"
" that expose CUDA device, stream and event synchronization activities. This feature is new\n"
" and currently disabled by default.\n",
py::arg("profiler_metrics") = std::vector<std::string>(),
py::arg("profiler_measure_per_kernel") = false,
py::arg("verbose") = false,
py::arg("performance_events") = std::vector<std::string>())
py::arg("performance_events") = std::vector<std::string>(),
py::arg("enable_cuda_sync_events") = false)
.def(py::pickle(
[](const ExperimentalConfig& p) { // __getstate__
py::list py_metrics;
@@ -94,11 +99,12 @@
py_metrics,
p.profiler_measure_per_kernel,
p.verbose,
p.enable_cuda_sync_events,
p.performance_events);
},
[](py::tuple t) { // __setstate__
if (t.size() >= 3) {
throw std::runtime_error("Expected atleast 3 values in state");
if (t.size() >= 4) {
throw std::runtime_error("Expected atleast 4 values in state");
}

py::list py_metrics = t[0].cast<py::list>();
@@ -109,8 +115,8 @@
}

std::vector<std::string> performance_events;
if (t.size() == 4) {
py::list py_perf_events = t[3].cast<py::list>();
if (t.size() == 5) {
py::list py_perf_events = t[4].cast<py::list>();
performance_events.resize(py_perf_events.size());
for (const auto& py_perf_event : py_perf_events) {
performance_events.push_back(py::str(py_perf_event));
@@ -121,7 +127,8 @@
std::move(metrics),
t[1].cast<bool>(),
t[2].cast<bool>(),
std::move(performance_events));
std::move(performance_events),
t[3].cast<bool>());
}));

py::class_<ProfilerConfig>(m, "ProfilerConfig")
@@ -303,6 +310,9 @@ void initPythonBindings(PyObject* module) {
m.def(
"_set_fwd_bwd_enabled_val",
&torch::profiler::impl::set_fwd_bwd_enabled_val);
m.def(
"_set_cuda_sync_enabled_val",
&torch::profiler::impl::set_cuda_sync_enabled_val);

py::class_<CapturedTraceback, std::shared_ptr<CapturedTraceback>>(
m, "CapturedTraceback");