[PyTorch] Optimize no input NVTX collection (pytorch#70133)
Summary:
Pull Request resolved: pytorch#70133

We were creating a `std::stringstream` and concatenating strings in `getNvtxStr` even when there were no inputs, which wasted time. This diff avoids `stringstream` when there are no input shapes. In the benchmark below, mean overhead drops by about 60% (0.970 to 0.384).
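The idea can be sketched in standalone C++ as follows. This is a minimal illustration, not the PyTorch source: `makeLabel` is a hypothetical stand-in for `getNvtxStr` (non-ROCm path only), it uses `std::to_string` instead of `fmt::format` so that it builds without extra dependencies, and the exact layout of the `sizes = [...]` list is an assumption because that part of the function is not shown in this diff.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for getNvtxStr; names and the shape layout are assumptions.
std::string makeLabel(
    const std::string& name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes) {
  // Build the prefix with plain string operations; no stream exists yet.
  std::string str = name;
  if (sequence_nr >= 0) {
    str += ", seq = " + std::to_string(sequence_nr);
  }

  // Fast path: no input shapes, so return without constructing a stringstream.
  if (shapes.empty()) {
    return str;
  }

  // Slow path: the stream is only constructed when shapes must be formatted.
  std::stringstream s;
  s << str << ", sizes = [";
  for (std::size_t i = 0; i < shapes.size(); ++i) {
    if (i > 0) {
      s << ", ";
    }
    s << "[";
    for (std::size_t j = 0; j < shapes[i].size(); ++j) {
      if (j > 0) {
        s << ", ";
      }
      s << shapes[i][j];
    }
    s << "]";
  }
  s << "]";
  return s.str();
}
```

Calling `makeLabel("aten::add", 7, {})` takes the fast path and never constructs a stream, while a call with non-empty shapes falls through to the stream-based formatting.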

Test Plan:
Before
```
I1214 22:48:07.964118 2971180 bench.cpp:154] Mean 0.970494
I1214 22:48:07.964139 2971180 bench.cpp:155] Median 0.969054
I1214 22:48:07.964144 2971180 bench.cpp:156] Min 0.962247
I1214 22:48:07.964148 2971180 bench.cpp:157] stddev 0.00774841
I1214 22:48:07.964154 2971180 bench.cpp:158] stddev / mean 0.00798398
```

After
```
I1214 22:59:00.039872 3437853 bench.cpp:154] Mean 0.384333
I1214 22:59:00.039896 3437853 bench.cpp:155] Median 0.384886
I1214 22:59:00.039899 3437853 bench.cpp:156] Min 0.370235
I1214 22:59:00.039902 3437853 bench.cpp:157] stddev 0.00435907
I1214 22:59:00.039907 3437853 bench.cpp:158] stddev / mean 0.0113419
```

Reviewed By: aaronenyeshi, robieta

Differential Revision: D33137501

fbshipit-source-id: ce0e8cf9aef7ea22fd8aed927e76be4ca375efc3
chaekit authored and facebook-github-bot committed Jan 5, 2022
1 parent 44283c2 commit 12653be
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions torch/csrc/profiler/util.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/profiler/util.h>

 #include <c10/util/ArrayRef.h>
+#include <fmt/format.h>

 #ifdef USE_KINETO
 #include <libkineto.h>
@@ -31,22 +32,20 @@ std::string getNvtxStr(
     int64_t sequence_nr,
     const std::vector<std::vector<int64_t>>& shapes) {
   if (sequence_nr >= -1 || shapes.size() > 0) {
-    std::stringstream s;
-#if defined(USE_ROCM)
-    s << name;
-#endif
+    std::string str;
     if (sequence_nr >= 0) {
-#if defined(USE_ROCM)
-      s << ", seq = " << sequence_nr;
-#else
-      s << name << ", seq = " << sequence_nr;
-#endif
+      str = fmt::format("{}, seq = {}", name, sequence_nr);
     } else if (sequence_nr == -1) {
-#if !defined(USE_ROCM)
-      s << name;
+      str = name;
+    } else {
+#if defined(USE_ROCM)
+      // Only ROCM supports < -1 sequence_nr
+      str = name;
 #endif
     }
     if (shapes.size() > 0) {
+      std::stringstream s;
+      s << str;
       s << ", sizes = [";
       for (const auto idx : c10::irange(shapes.size())) {
         if (shapes[idx].size() > 0) {
@@ -66,8 +65,10 @@ std::string getNvtxStr(
         }
       }
       s << "]";
+      return s.str();
     }
-    return s.str();
+
+    return str;
   } else {
     return name;
   }