extract kernel for easier debugging (pytorch#587)
Summary:
Pull Request resolved: pytorch#587

If an exception is thrown from code inside AT_DISPATCH_... we can't see which line it came from. Extracting that code into a separate, named function makes the exact line show up in the stack trace, which makes debugging easier.
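As an illustration of the pattern only (a minimal sketch, not the FBGEMM code; the scale_cpu/scale_kernel_cpu names and the alpha parameter are hypothetical): the body of the dispatch lambda is moved into a named, templated kernel, so an exception thrown there reports a concrete function and line instead of an anonymous lambda expanded inside the macro.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

namespace {
// Extracted kernel: an exception thrown here points at this function and line
// in the stack trace, rather than at a lambda inside the AT_DISPATCH_* expansion.
template <typename scalar_t>
void scale_kernel_cpu(at::Tensor input, at::Tensor output, double alpha) {
  auto in = input.accessor<scalar_t, 1>();
  auto out = output.accessor<scalar_t, 1>();
  for (int64_t i = 0; i < input.numel(); ++i) {
    out[i] = static_cast<scalar_t>(alpha) * in[i];
  }
}
} // namespace

void scale_cpu(at::Tensor input, at::Tensor output, double alpha) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scale_cpu", [&] {
    // The lambda body is reduced to a single call into the extracted kernel.
    scale_kernel_cpu<scalar_t>(input, output, alpha);
  });
}

The diff below applies the same idea: the loop nest that used to live inside the nested AT_DISPATCH_FLOATING_TYPES / AT_DISPATCH_FLOATING_TYPES_AND_HALF lambdas is moved into split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>, and the generated CPU accessor arguments are passed through explicitly.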

Reviewed By: jianyuh

Differential Revision: D27556686

fbshipit-source-id: cc722a470141d9254939b6f06bf79798bc64b69b
jspark1105 authored and facebook-github-bot committed Apr 5, 2021
1 parent 228cbac commit 62c7209
Showing 3 changed files with 307 additions and 184 deletions.
75 changes: 49 additions & 26 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -47,40 +47,50 @@ def _arg_constructor(
return (
f"{name}.packed_accessor{precision}<{type}, 1, RestrictPtrTraits>()"
if gpu
else f"auto {name}_accessor = {name}.accessor<{type}, 1>()"
else f"{name}.accessor<{type}, 1>()"
)


def _arg(type: str, name: str, precision: int = 32) -> str:
return f"PackedTensorAccessor{precision}<{type}, 1, RestrictPtrTraits> {name}"
def _arg(type: str, name: str, gpu: bool = True, precision: int = 32) -> str:
return (
f"PackedTensorAccessor{precision}<{type}, 1, RestrictPtrTraits> {name}"
if gpu
else f"TensorAccessor<{type}, 1> {name}"
)


def acc_cache_tensor_arg_constructor(name: str) -> str:
return _arg_constructor("acc_type<cache_t, true>", name, precision=64)
def acc_cache_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
return _arg_constructor(
"acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>",
name,
gpu=gpu,
precision=64,
)


def acc_cache_tensor_arg(name: str) -> str:
return _arg("acc_type<cache_t, true>", name, precision=64)
def acc_cache_tensor_arg(name: str, gpu: bool = True) -> str:
return _arg(
"acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>",
name,
gpu=gpu,
precision=64,
)


def long_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
return _arg_constructor("int64_t", name, gpu=gpu)


def long_tensor_arg(name: str) -> str:
return _arg("int64_t", name)
def long_tensor_arg(name: str, gpu: bool = True) -> str:
return _arg("int64_t", name, gpu=gpu)


def int_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
return _arg_constructor("int32_t", name, gpu=gpu)


def int_tensor_arg(name: str) -> str:
return _arg("int32_t", name)


def host_accessor_constructor(name: str) -> str:
return _arg_constructor("acc_type<scalar_t, true>", name, gpu=False)
def int_tensor_arg(name: str, gpu: bool = True) -> str:
return _arg("int32_t", name, gpu=gpu)


def tensor_arg(name: str) -> str:
@@ -158,7 +168,8 @@ def generate(**kwargs: Any) -> None:
class Args:
split_kernel_args: List[str]
split_kernel_arg_constructors: List[str]
split_host_accessor_constructors: List[str]
split_cpu_kernel_args: List[str]
split_cpu_kernel_arg_constructors: List[str]
split_function_args: List[str]
split_saved_tensors: List[str]
split_tensors: List[str]
@@ -190,13 +201,22 @@ def make_kernel_arg_constructor(ty: int, name: str) -> str:
FLOAT: lambda x: x,
}[ty](name)

def make_host_accessor_constructor(ty: int, name: str) -> str:
def make_cpu_kernel_arg(ty: int, name: str) -> str:
return {
TENSOR: host_accessor_constructor,
TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False),
INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False),
LONG_TENSOR: lambda x: long_tensor_arg(x, gpu=False),
INT: int64_arg,
FLOAT: float_arg,
}[ty](name)

def make_cpu_kernel_arg_constructor(ty: int, name: str) -> str:
return {
TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False),
INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
LONG_TENSOR: lambda x: long_tensor_arg_constructor(x, gpu=False),
INT: lambda x: "",
FLOAT: lambda x: "",
INT: lambda x: x,
FLOAT: lambda x: x,
}[ty](name)

def make_function_arg(ty: int, name: str) -> str:
@@ -228,8 +248,11 @@ def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args:
split_kernel_arg_constructors=[
make_kernel_arg_constructor(ty, name) for (ty, name) in split_arg_spec
],
split_host_accessor_constructors=[
make_host_accessor_constructor(ty, name)
split_cpu_kernel_args=[
make_cpu_kernel_arg(ty, name) for (ty, name) in split_arg_spec
],
split_cpu_kernel_arg_constructors=[
make_cpu_kernel_arg_constructor(ty, name)
for (ty, name) in split_arg_spec
],
split_function_args=[
@@ -301,11 +324,11 @@ def adagrad() -> None:
"""
split_weight_update_cpu = """
for (int64_t d = 0; d < D; ++d) {
momentum1_host_accessor[embedding_begin + d] +=
momentum1_host[embedding_begin + d] +=
grad_buffer[d] * grad_buffer[d];
host_weights_data[embedding_begin + d] -=
learning_rate * grad_buffer[d] /
(sqrt(momentum1_host_accessor[embedding_begin + d]) + eps);
(sqrt(momentum1_host[embedding_begin + d]) + eps);
}
"""

@@ -374,8 +397,8 @@ def rowwise_adagrad() -> None:
g_local_sum_square += grad_buffer[d] * grad_buffer[d];
}
auto g_avg_square = g_local_sum_square / D;
acc_type<scalar_t, true> new_sum_square_grads = momentum1_host_accessor[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
momentum1_host_accessor[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
acc_type<scalar_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
acc_type<scalar_t, true> multiplier;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
for (int64_t d = 0; d < D; ++d) {
137 changes: 86 additions & 51 deletions fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp
@@ -15,6 +15,68 @@

using namespace at;

namespace {
template <typename scalar_t, typename grad_t>
void split_embedding_backward_approx_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
const TensorAccessor<int64_t, 1> weights_offsets_data,
const TensorAccessor<int, 1> D_offsets_data,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
int T,
int B,
{% if "momentum1_offsets" in args.split_function_arg_names %}
const TensorAccessor<int64_t, 1> momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
const TensorAccessor<int64_t, 1> momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_args | join(", ") }}) {
auto grad_output_data = grad_output.accessor<grad_t, 2>();
auto host_weights_data = host_weights.accessor<scalar_t, 1>();
const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();
// If indice_weights are not defined, then this accessor won't be used
auto indice_weights_data = indice_weights.defined()
? indice_weights.accessor<grad_t, 1>()
: TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr);

for (int64_t t = 0; t < T; ++t) {
int feature_begin = t; // to conform interface with exact
const auto D_begin = D_offsets_data[t];
const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
const auto table_begin = weights_offsets_data[t];
at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
for (int64_t b = b_begin; b < b_end; ++b) {
const auto pool_begin = offsets_data[t * B + b];
const auto pool_end = offsets_data[t * B + b + 1];
const auto L = pool_end - pool_begin;
const double scale_factor =
// NOTE: MEAN pooling will not work with indice_weights!
(pooling_mode == MEAN && !indice_weights.defined() && L > 0)
? 1.0 / L
: 1.0;
for (auto p = pool_begin; p < pool_end; ++p) {
auto idx = indices_data[p];
const int64_t embedding_begin = table_begin + idx * D;
scalar_t grad_buffer[D];
for (int64_t d = 0; d < D; ++d) {
grad_buffer[d] = scale_factor *
(indice_weights.defined()
? grad_output_data[b][D_begin + d] * indice_weights_data[p]
: grad_output_data[b][D_begin + d]);
}
{{ split_weight_update_cpu }};
} // for each p
} // for each b
}); // parallel for B
} // for each t
}
} // namespace

// The template for approximate optimizers
{{ "void" if not dense else "Tensor" }}
split_embedding_backward_codegen_{{ optimizer }}_cpu(
@@ -43,9 +105,8 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
int64_t B = (offsets.size(0) - 1) / T;
TORCH_CHECK(B > 0);

const auto D_offsets_data = D_offsets.accessor<int, 1>();
const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
const auto D_offsets_data = D_offsets.accessor<int, 1>();
{%if "momentum1_offsets" in args.split_function_arg_names %}
const auto momentum1_offsets_data = momentum1_offsets.accessor<int64_t, 1>();
{% endif %}
@@ -66,11 +127,12 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(

if (use_fbgemm) {
auto grad_stride = grad_output.size(1);
float* host_weights_data = host_weights.data_ptr<float>();
float* momentum1_data = momentum1_host.data_ptr<float>();
const float* grad_output_data = grad_output.data_ptr<float>();
const int64_t* offsets_data = offsets.data_ptr<int64_t>();
float* host_weights_data = host_weights.data_ptr<float>();
const int64_t* indices_data = indices.data_ptr<int64_t>();
const int64_t* offsets_data = offsets.data_ptr<int64_t>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
@@ -121,58 +183,31 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(

{% endif %}

const auto offsets_data = offsets.accessor<int64_t, 1>();
const auto indices_data = indices.accessor<int64_t, 1>();

AT_DISPATCH_FLOATING_TYPES(
grad_output.scalar_type(), "split_embedding_backward_cpu", [&]() {
// If indice_weights are not defined, then this accessor won't be
// used
auto indice_weights_data = indice_weights.defined()
? indice_weights.accessor<scalar_t, 1>()
: TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);

auto grad_output_data = grad_output.accessor<scalar_t, 2>();
using grad_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
host_weights.scalar_type(),
"split_embedding_backward_cpu_inner",
[&]() {
{{ args.split_host_accessor_constructors | join("; ") }}

auto host_weights_data = host_weights.accessor<scalar_t, 1>();

for (int64_t t = 0; t < T; ++t) {
int feature_begin = t; // to conform interface with exact
const auto D_begin = D_offsets_data[t];
const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
const auto table_begin = weights_offsets_data[t];
at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
for (int64_t b = b_begin; b < b_end; ++b) {
const auto pool_begin = offsets_data[t * B + b];
const auto pool_end = offsets_data[t * B + b + 1];
const auto L = pool_end - pool_begin;
const double scale_factor =
// NOTE: MEAN pooling will not work with indice_weights!
(pooling_mode == MEAN && !indice_weights.defined() &&
L > 0)
? 1.0 / L
: 1.0;
for (auto p = pool_begin; p < pool_end; ++p) {
auto idx = indices_data[p];
const int64_t embedding_begin = table_begin + idx * D;
scalar_t grad_buffer[D];
for (int64_t d = 0; d < D; ++d) {
grad_buffer[d] = scale_factor *
(indice_weights.defined()
? grad_output_data[b][D_begin + d] *
indice_weights_data[p]
: grad_output_data[b][D_begin + d]);
}
{{ split_weight_update_cpu }};
} // for each p
} // for each b
}); // parallel for B
} // for each t
split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
indices,
offsets,
pooling_mode,
indice_weights,
T,
B,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
}); // dispatch host_weights.scalar_type()
}); // dispatch grad_output.scalar_type()

