Skip to content

Commit

Permalink
Add more comments for infer TBE kernel for vendor optimizations (pyto…
Browse files Browse the repository at this point in the history
…rch#948)

Summary:
Pull Request resolved: pytorch#948

As the title says: this just adds more comments to the infer TBE kernel.

Reviewed By: jspark1105

Differential Revision: D34430782

fbshipit-source-id: 1601f9c6232eff4bea9f4f517cb888656e6d30a3
  • Loading branch information
jianyuh authored and facebook-github-bot committed Feb 25, 2022
1 parent c3a26e1 commit 631bb8d
Showing 1 changed file with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i
continue;
}
const uint32_t* row = reinterpret_cast<const uint32_t*>(&buffers[warp_idx][i][input_row_idx][0]);
// The scale and bias (shift) are stored at the beginning of each row.
// Rationale: keep scale/shift at the start, since they are loaded first
// and then broadcast around, which might speed up the first cache miss.
{% if bit_width in [8, 4, 2] %}
half2 shift_scale = reinterpret_cast<const half2*>(row)[0];
{% endif %}
Expand All @@ -347,6 +350,10 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i
if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
#pragma unroll MaxNum128BRows
for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
// Read the uint8/4/2 values. Note that the first 4 bytes will be discarded later:
// we shift back by 4/8/16 elements to remove the first 4 bytes (which are garbage
// due to the scale/shift handling).
// Reason: to avoid divergence, the first thread in the warp computes garbage.
int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
if (output_d >= 0 && output_d < D) {
Expand Down

0 comments on commit 631bb8d

Please sign in to comment.