
Commit

Refactor codegen/embedding_forward_quantized_cpu_template.cpp (pytorch#2196)

Summary:
Pull Request resolved: pytorch#2196

Consolidate the per-weight-type kernel generation calls using a Jinja macro.
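The change follows the usual Jinja consolidation pattern: define one parameterized macro, then instantiate it once per weight type instead of repeating a near-identical block. A minimal illustrative sketch of the pattern (make_kernel is a hypothetical stand-in, not an FBGEMM API; the real macro in the diff below takes use_base/use_nbit/use_fp8 flags and the full argument lists):

{# One macro replaces N near-identical hand-written blocks. #}
{% macro gen_kernel(weight_type) %}
const auto kernel = make_kernel<{{ weight_type }}>(D);  {# hypothetical helper #}
success = kernel(B, reinterpret_cast<const {{ weight_type }}*>(weights));
{% endmacro %}

if (weight_ty == SparseType::FP32) {
  {{ gen_kernel("float") }}
} else if (weight_ty == SparseType::FP16) {
  {{ gen_kernel("float16") }}
}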

Reviewed By: jspark1105

Differential Revision: D51922356

fbshipit-source-id: 8aea70a43f09a980c12eb89932d8d25bac2d7ef2
sryap authored and facebook-github-bot committed Dec 8, 2023
1 parent f8bd441 commit 90e81f5
Showing 1 changed file with 68 additions and 106 deletions: fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
@@ -259,93 +259,73 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
{% if weighted %}
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
{% endif %}

+{% macro generate_and_exec_kernel(weight_type, use_base, use_nbit, use_fp8) %}
+{% set has_asmjit = use_base or use_nbit %}
+{% set kernel_name = "GenerateEmbeddingSpMDMWithStrides"
+    if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides"
+    if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides")
+%}
+const auto kernel = fbgemm::{{ kernel_name }}<
+  {% if use_base %}
+  {{ weight_type }},
+  {% endif %}
+  index_t,
+  index_t,
+  {% if has_asmjit %}
+  fbgemm_out_t,
+  /*THREAD_LOCAL=*/true
+  {% else %}
+  fbgemm_out_t
+  {% endif %}
+  >(
+  {% if use_nbit %}
+  /*bit_rate=*/bit_rate,
+  {% endif %}
+  D,
+  {% if has_asmjit %}
+  has_weight,
+  {% endif %}
+  normalize_by_lengths,
+  {% if has_asmjit %}
+  /*prefetch=*/16,
+  {% endif %}
+  /*is_weight_positional=*/false,
+  /*use_offsets=*/true,
+  /*output_stride=*/{{ "total_D" if not nobag else "D" }},
+  /*input_stride=*/D_bytes / sizeof({{ weight_type }}),
+  {% if use_fp8 %}
+  /*exponent_bits=*/fp8_exponent_bits,
+  /*exponent_bias=*/fp8_exponent_bias,
+  {% endif %}
+  {% if has_asmjit %}
+  /*scale_bias_last=*/false,
+  {% endif %}
+  {% if use_base %}
+  /*no_bag=*/false,
+  {% endif %}
+  /*is_bf16_out=*/output_is_bf16
+);
+success = kernel(
+  {{ "B" if not nobag else "index_size"}},
+  index_size,
+  num_rows,
+  reinterpret_cast<const {{ weight_type }}*>(weights),
+  indices_acc + *offsets_begin_ptr,
+  {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
+  indice_weights_ptr,
+  reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
+{% endmacro %}

if (weight_ty == SparseType::FP32) {
-  auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<float, index_t, index_t, fbgemm_out_t, /*THREAD_LOCAL=*/true>(
-      D,
-      has_weight,
-      normalize_by_lengths,
-      /*prefetch=*/16,
-      /*is_weight_positional=*/false,
-      /*use_offsets=*/true,
-      /*output_stride=*/{{ "total_D" if not nobag else "D" }},
-      /*input_stride=*/D_bytes / sizeof(float),
-      /*scale_bias_last=*/false,
-      /*no_bag=*/false,
-      /*is_bf16_out=*/output_is_bf16);
-  success = kernel(
-      {{ "B" if not nobag else "index_size"}},
-      index_size,
-      num_rows,
-      reinterpret_cast<const float*>(weights),
-      indices_acc + *offsets_begin_ptr,
-      {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
-      indice_weights_ptr,
-      reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
{{ generate_and_exec_kernel("float", True, False, False) }}
} else if (weight_ty == SparseType::FP16) {
-  auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<float16, index_t, index_t, fbgemm_out_t, /*THREAD_LOCAL=*/true>(
-      D,
-      has_weight,
-      normalize_by_lengths,
-      /*prefetch=*/16,
-      /*is_weight_positional=*/false,
-      /*use_offsets=*/true,
-      /*output_stride=*/{{ "total_D" if not nobag else "D" }},
-      /*input_stride=*/D_bytes / sizeof(float16),
-      /*scale_bias_last=*/false,
-      /*no_bag=*/false,
-      /*is_bf16_out=*/output_is_bf16);
-  success = kernel(
-      {{ "B" if not nobag else "index_size"}},
-      index_size,
-      num_rows,
-      reinterpret_cast<const float16*>(weights),
-      indices_acc + *offsets_begin_ptr,
-      {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
-      indice_weights_ptr,
-      reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
{{ generate_and_exec_kernel("float16", True, False, False) }}
+} else if (weight_ty == SparseType::INT8) {
+  {{ generate_and_exec_kernel("uint8_t", True, False, False) }}
} else if (weight_ty == SparseType::FP8) {
assert(fp8_exponent_bits > 0 && fp8_exponent_bias > 0);
-  auto kernel = fbgemm::GenerateEmbeddingSpMDMFP8WithStrides<index_t, index_t, fbgemm_out_t>(
-      D,
-      normalize_by_lengths,
-      /*is_weight_positional=*/false,
-      /*use_offsets=*/true,
-      /*output_stride=*/{{ "total_D" if not nobag else "D" }},
-      /*input_stride=*/D_bytes / sizeof(uint8_t),
-      /*exponent_bits=*/fp8_exponent_bits,
-      /*exponent_bias=*/fp8_exponent_bias,
-      /*is_bf16_out=*/output_is_bf16);
-  success = kernel(
-      {{ "B" if not nobag else "index_size"}},
-      index_size,
-      num_rows,
-      weights,
-      indices_acc + *offsets_begin_ptr,
-      {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
-      indice_weights_ptr,
-      reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
-} else if (weight_ty == SparseType::INT8) {
-  auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<uint8_t, index_t, index_t, fbgemm_out_t, /*THREAD_LOCAL=*/true>(
-      D,
-      has_weight,
-      normalize_by_lengths,
-      /*prefetch=*/16,
-      /*is_weight_positional=*/false,
-      /*use_offsets=*/true,
-      /*output_stride=*/{{ "total_D" if not nobag else "D" }},
-      /*input_stride=*/D_bytes / sizeof(uint8_t),
-      /*scale_bias_last=*/false,
-      /*no_bag=*/false,
-      /*is_bf16_out=*/output_is_bf16);
-  success = kernel(
-      {{ "B" if not nobag else "index_size"}},
-      index_size,
-      num_rows,
-      weights,
-      indices_acc + *offsets_begin_ptr,
-      {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
-      indice_weights_ptr,
-      reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
{{ generate_and_exec_kernel("uint8_t", False, False, True) }}
} else if (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2) {
int bit_rate;
switch (weight_ty) {
@@ -356,31 +336,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
bit_rate = 2;
break;
default:
throw std::logic_error("Unsupported SparseType: " + std::to_string(static_cast<int>(weight_ty)));
throw std::logic_error(
"Unsupported SparseType: " + std::to_string(static_cast<int>(weight_ty)));
}
-  auto kernel = fbgemm::GenerateEmbeddingSpMDMNBitWithStrides<index_t, index_t, fbgemm_out_t, /*THREAD_LOCAL=*/true>(
-      /*bit_rate=*/bit_rate,
-      D,
-      has_weight,
-      normalize_by_lengths,
-      /*prefetch=*/16,
-      /*is_weight_positional=*/false,
-      /*use_offsets=*/true,
-      /*output_stride=*/{{ "total_D" if not nobag else "D" }},
-      /*input_stride=*/D_bytes / sizeof(uint8_t),
-      /*scale_bias_last=*/false,
-      /*is_bf16_out=*/output_is_bf16);
-  success = kernel(
-      {{ "B" if not nobag else "index_size"}},
-      index_size,
-      num_rows,
-      weights,
-      indices_acc + *offsets_begin_ptr,
-      {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
-      indice_weights_ptr,
-      reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
{{ generate_and_exec_kernel("uint8_t", False, True, False) }}
} else {
throw std::logic_error("Unsupported SparseType: " + std::to_string(static_cast<int>(weight_ty)));
throw std::logic_error(
"Unsupported SparseType: " + std::to_string(static_cast<int>(weight_ty)));
}
if (!success) {
fbgemm_gpu::report_embedding_error(
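For reference, hand-expanding the new FP32 branch — {{ generate_and_exec_kernel("float", True, False, False) }} with nobag false — should render to essentially the code the diff deletes, which is why behavior is unchanged:

const auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
    float,                   // weight_type (use_base adds this template arg)
    index_t,
    index_t,
    fbgemm_out_t,
    /*THREAD_LOCAL=*/true>(  // asmjit path: base and nbit kernels only
    D,
    has_weight,
    normalize_by_lengths,
    /*prefetch=*/16,
    /*is_weight_positional=*/false,
    /*use_offsets=*/true,
    /*output_stride=*/total_D,
    /*input_stride=*/D_bytes / sizeof(float),
    /*scale_bias_last=*/false,
    /*no_bag=*/false,
    /*is_bf16_out=*/output_is_bf16);
success = kernel(
    B,
    index_size,
    num_rows,
    reinterpret_cast<const float*>(weights),
    indices_acc + *offsets_begin_ptr,
    offsets_begin_ptr,
    indice_weights_ptr,
    reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));

This matches the hand-written FP32 block removed above, modulo whitespace.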
