Add TBE annotation in Kineto trace (pytorch#3057)
Summary:
Pull Request resolved: pytorch#3057

X-link: facebookresearch/FBGEMM#153

This diff adds TBE annotations to the Kineto trace. The feature can be enabled
via the `TBE_ANNOTATE_KINETO_TRACE` JustKnobs knob or by setting the
environment variable `FBGEMM_TBE_ANNOTATE_KINETO_TRACE=1` (see the sketch
following this commit message).

Reviewed By: q10, dshi7, shintaro-iwasaki, spcyppt

Differential Revision: D61999485

fbshipit-source-id: 39c582ddee7f2b0f3b1e682f2cf401a6d464c93b
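
A minimal sketch of the pattern this diff threads through the generated
forward/backward host wrappers, assuming only the feature-gate and
record-function APIs that appear in the hunks below. The helper name
`run_with_tbe_annotation` and its callable parameter are hypothetical; the
generated templates inline this logic and additionally build a shape
annotation string that is stashed in the autograd context for the backward
pass.

#include <string>

#include "fbgemm_gpu/config/feature_gates.h"
#include "torch/csrc/autograd/record_function_ops.h"

namespace profiler = torch::autograd::profiler;

// Hypothetical helper: run `fn` (standing in for the generated TBE op call)
// inside a named profiler region when the feature gate is enabled.
template <typename Fn>
auto run_with_tbe_annotation(const std::string& op_name, Fn&& fn) {
  // Evaluated once: JustKnobs internally, or the
  // FBGEMM_TBE_ANNOTATE_KINETO_TRACE environment variable in OSS builds.
  const static bool is_annotate_trace_enabled =
      fbgemm_gpu::config::is_feature_enabled(
          fbgemm_gpu::config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);

  c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
  if (is_annotate_trace_enabled) {
    // The real templates append a "[weighted=...,T=...,max_D=...,...]"
    // annotation built from the op's shape parameters.
    record_trace = profiler::record_function_enter_new(op_name);
  }

  auto output = fn();

  if (is_annotate_trace_enabled) {
    record_trace->record.end();  // close the region so it shows in the trace
  }
  return output;
}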
sryap authored and facebook-github-bot committed Oct 25, 2024
1 parent 2b5316f commit 9921707
Showing 3 changed files with 188 additions and 96 deletions.
@@ -11,6 +11,7 @@
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>
#include "torch/csrc/autograd/record_function_ops.h"

#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"
@@ -20,6 +21,7 @@
using Tensor = at::Tensor;

using namespace fbgemm_gpu;
namespace profiler = torch::autograd::profiler;

{#/* Module description */#}
{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %}
@@ -62,55 +64,59 @@ enum SSDTensor {
.findSchemaOrThrow("fbgemm::{{ forward_op }}", "")
.typed<decltype({{ forward_op }})>();

return {
embedding_codegen_forward_op.call(
flatten_dev_weights,
{%- if not dense %}
uvm_weights,
lxu_cache_weights,
weights_placements,
{%- endif %}
weights_offsets,
{%- if nobag %}
D,
{%- else %}
D_offsets,
total_D,
max_D,
{%- endif %}
indices,
offsets,
{%- if not nobag %}
pooling_mode,
{%- endif %} {# /* if not nobag */ #}
{%- if weighted %}
*indice_weights,
{%- endif %}
{%- if not dense %}
{{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }},
uvm_cache_stats_,
{%- endif %}
output_dtype,
{%- if not nobag %}
{%- if vbe %}
vbe_row_output_offsets,
vbe_b_t_map,
vbe_output_size,
info_B_num_bits,
info_B_mask_int64,
{%- endif %} {# /* if vbe */ #}
{%- if is_gwd %}
hash_size_cumsum,
prev_iter_dev_,
learning_rate,
weight_decay,
iter,
gwd_lower_bound,
{%- endif %} {# /* if is_gwd */ #}
{%- endif %} {# /* if not nobag */ #}
{{ "is_experimental" if has_experimental else "false" }}
)
};
auto output = embedding_codegen_forward_op.call(
flatten_dev_weights,
{%- if not dense %}
uvm_weights,
lxu_cache_weights,
weights_placements,
{%- endif %}
weights_offsets,
{%- if nobag %}
D,
{%- else %}
D_offsets,
total_D,
max_D,
{%- endif %}
indices,
offsets,
{%- if not nobag %}
pooling_mode,
{%- endif %} {# /* if not nobag */ #}
{%- if weighted %}
*indice_weights,
{%- endif %}
{%- if not dense %}
{{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }},
uvm_cache_stats_,
{%- endif %}
output_dtype,
{%- if not nobag %}
{%- if vbe %}
vbe_row_output_offsets,
vbe_b_t_map,
vbe_output_size,
info_B_num_bits,
info_B_mask_int64,
{%- endif %} {# /* if vbe */ #}
{%- if is_gwd %}
hash_size_cumsum,
prev_iter_dev_,
learning_rate,
weight_decay,
iter,
gwd_lower_bound,
{%- endif %} {# /* if is_gwd */ #}
{%- endif %} {# /* if not nobag */ #}
{{ "is_experimental" if has_experimental else "false" }}
);

if (is_annotate_trace_enabled) {
record_trace->record.end();
}

return {output};
{%- endmacro %}

/* This macro generates a code blob for dispatching corresponding weighted and
@@ -195,6 +201,11 @@ enum SSDTensor {
/*unused=*/0
{%- endif %}
);

if (is_annotate_trace_enabled) {
record_trace->record.end();
}

return {
{%- if not dense %}
Tensor(), // placeholder autograd tensor
@@ -630,6 +641,31 @@ class {{ autograd_func }} :
const auto max_B_ = offsets.sym_size(0) / T;
{%- endif %}

// Annotate Kineto trace
const static bool is_annotate_trace_enabled = config::is_feature_enabled(
config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
std::string op_annotation = "";
c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
if (is_annotate_trace_enabled) {
std::stringstream ss;
ss << "["
<< "weighted={{ "T" if weighted else "F" }},"
<< "pooled={{ "T" if not nobag else "F" }},"
<< "vbe={{ "T" if vbe else "F" }},"
<< "avg_B=" << ({{ "max_B_" if not vbe else "max_B_ / T" }}) << ","
<< "max_B=" << max_B_ << ","
<< "T=" << T << ","
<< "avg_D=" << ({{ "total_D / T" if not nobag else "D" }}) << ","
<< "max_D=" << {{ "max_D" if not nobag else "D" }} << ","
<< "num_indices=" << indices.numel() << ","
<< "avg_pooling_fac=" << (static_cast<float>(indices.numel()) / T / max_B_)
<< "]";
op_annotation = ss.str();
record_trace = profiler::record_function_enter_new(
"{{ fwd_mdesc }}_tbe_fwd" + op_annotation);
ctx->saved_data["op_annotation"] = op_annotation;
}

{%- if not dense %}
// NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
// TODO: Hook up with frontend code
@@ -874,6 +910,15 @@ class {{ autograd_func }} :
{%- endfor %}
{%- endif %}

const static bool is_annotate_trace_enabled = config::is_feature_enabled(
config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
if (is_annotate_trace_enabled) {
auto& op_annotation = ctx->saved_data["op_annotation"].toStringRef();
record_trace = profiler::record_function_enter_new(
"{{ bwd_mdesc }}_tbe_bwd" + op_annotation);
}

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef USE_ROCM
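For reference, the hunks that follow apply the same annotation pattern to a
second host template (the PT2 wrapper variant, per the pt2_autograd_utils.h
include). When the gate is on, the forward region recorded by the template
above appears in the Kineto trace with a label of the following shape; the
prefix comes from fwd_mdesc and all numeric values here are hypothetical:

  split_tbe_fwd[weighted=F,pooled=T,vbe=F,avg_B=512,max_B=512,T=16,avg_D=128,max_D=256,num_indices=262144,avg_pooling_fac=32]
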
@@ -38,13 +38,16 @@
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/split_embeddings_utils.h"
#include "fbgemm_gpu/config/feature_gates.h"
#include "torch/csrc/autograd/record_function_ops.h"

{%- if has_vbe_support %}
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
{%- endif %}

using Tensor = at::Tensor;

using namespace fbgemm_gpu;
namespace profiler = torch::autograd::profiler;

{#/* Module description */#}
{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %}
@@ -129,52 +132,56 @@ enum SSDTensor {
const int64_t /*output_dtype*/
)>();

return {
embedding_codegen_forward_op.call(
weights_host,
flatten_weights_dev,
weights_uvm,
lxu_cache_weights,
weights_placements,
weights_offsets,
{%- if nobag %}
D,
{%- else %}
D_offsets,
total_D,
max_D,
{%- endif %}
hash_size_cumsum,
indices,
offsets,
{%- if not nobag %}
pooling_mode,
indice_weights_value,
{%- endif %} {# /* if not nobag */ #}
{%- if not dense %}
{{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }},
uvm_cache_stats_,
{%- endif %}
{%- if not nobag %}
{%- if vbe %}
vbe_row_output_offsets,
vbe_b_t_map,
vbe_output_size,
info_B_num_bits,
info_B_mask_int64,
{%- endif %} {# /* if vbe */ #}
{%- if is_gwd %}
prev_iter_dev_,
learning_rate,
weight_decay,
iter,
gwd_lower_bound,
{%- endif %} {# /* if is_gwd */ #}
{%- endif %} {# /* if not nobag */ #}
is_experimental,
output_dtype
)
};
auto output = embedding_codegen_forward_op.call(
weights_host,
flatten_weights_dev,
weights_uvm,
lxu_cache_weights,
weights_placements,
weights_offsets,
{%- if nobag %}
D,
{%- else %}
D_offsets,
total_D,
max_D,
{%- endif %}
hash_size_cumsum,
indices,
offsets,
{%- if not nobag %}
pooling_mode,
indice_weights_value,
{%- endif %} {# /* if not nobag */ #}
{%- if not dense %}
{{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }},
uvm_cache_stats_,
{%- endif %}
{%- if not nobag %}
{%- if vbe %}
vbe_row_output_offsets,
vbe_b_t_map,
vbe_output_size,
info_B_num_bits,
info_B_mask_int64,
{%- endif %} {# /* if vbe */ #}
{%- if is_gwd %}
prev_iter_dev_,
learning_rate,
weight_decay,
iter,
gwd_lower_bound,
{%- endif %} {# /* if is_gwd */ #}
{%- endif %} {# /* if not nobag */ #}
is_experimental,
output_dtype
);

if (is_annotate_trace_enabled) {
record_trace->record.end();
}

return {output};
{%- endmacro %}

/* This macro generates a code blob for dispatching corresponding weighted and
@@ -309,6 +316,11 @@ enum SSDTensor {
, output_dtype
{%- endif %}
);

if (is_annotate_trace_enabled) {
record_trace->record.end();
}

return {
{%- if not dense %}
Tensor(), // placeholder autograd tensor
@@ -585,6 +597,31 @@ class {{ autograd_func }} :
const auto max_B_ = offsets.sym_size(0) / T;
{%- endif %}

// Annotate Kineto trace
const static bool is_annotate_trace_enabled = config::is_feature_enabled(
config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
std::string op_annotation = "";
c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
if (is_annotate_trace_enabled) {
std::stringstream ss;
ss << "["
<< "weighted={{ "T" if weighted else "F" }},"
<< "pooled={{ "T" if not nobag else "F" }},"
<< "vbe={{ "T" if vbe else "F" }},"
<< "avg_B=" << ({{ "max_B_" if not vbe else "max_B_ / T" }}) << ","
<< "max_B=" << max_B_ << ","
<< "T=" << T << ","
<< "avg_D=" << ({{ "total_D / T" if not nobag else "D" }}) << ","
<< "max_D=" << {{ "max_D" if not nobag else "D" }} << ","
<< "num_indices=" << indices.numel() << ","
<< "avg_pooling_fac=" << (static_cast<float>(indices.numel()) / T / max_B_)
<< "]";
op_annotation = ss.str();
record_trace = profiler::record_function_enter_new(
"{{ fwd_mdesc }}_tbe_fwd" + op_annotation);
ctx->saved_data["op_annotation"] = op_annotation;
}

// NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
// TODO: Hook up with frontend code
const auto uvm_cache_stats_ = uvm_cache_stats
@@ -830,6 +867,15 @@ static torch::autograd::variable_list backward(
auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
{%- endfor %}

const static bool is_annotate_trace_enabled = config::is_feature_enabled(
config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
if (is_annotate_trace_enabled) {
auto& op_annotation = ctx->saved_data["op_annotation"].toStringRef();
record_trace = profiler::record_function_enter_new(
"{{ bwd_mdesc }}_tbe_bwd" + op_annotation);
}

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef USE_ROCM
3 changes: 2 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
@@ -57,7 +57,8 @@ namespace fbgemm_gpu::config {
/// For OSS: The environment variable will be evaluated as f"FBGEMM_{ENUM}"
#define ENUMERATE_ALL_FEATURE_FLAGS \
X(TBE_V2) \
X(TBE_ENSEMBLE_ROWWISE_ADAGRAD)
X(TBE_ENSEMBLE_ROWWISE_ADAGRAD) \
X(TBE_ANNOTATE_KINETO_TRACE)
// X(EXAMPLE_FEATURE_FLAG)

/// @ingroup fbgemm-gpu-config
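With the new entry in place, OSS builds can toggle the gate through the
environment, per the f"FBGEMM_{ENUM}" convention noted in the header comment
above (hence FBGEMM_TBE_ANNOTATE_KINETO_TRACE=1 in the summary). Below is a
minimal sketch of how an X-macro list like this is typically expanded into a
FeatureGateName enum, plus an illustrative OSS-style environment check; the
real definitions and lookup live in feature_gates.h and may differ in detail:

#include <cstdlib>
#include <string>

#define ENUMERATE_ALL_FEATURE_FLAGS \
  X(TBE_V2) \
  X(TBE_ENSEMBLE_ROWWISE_ADAGRAD) \
  X(TBE_ANNOTATE_KINETO_TRACE)

// Expand the list into an enum: each flag name becomes an enumerator.
enum class FeatureGateName {
#define X(name) name,
  ENUMERATE_ALL_FEATURE_FLAGS
#undef X
};

// Illustrative OSS-style check: the flag maps to the FBGEMM_<flag> variable.
inline bool tbe_annotate_kineto_trace_enabled_from_env() {
  const char* v = std::getenv("FBGEMM_TBE_ANNOTATE_KINETO_TRACE");
  return v != nullptr && std::string(v) == "1";
}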
