4-bit SLS with emb dim not a multiple of 2 (pytorch#243)
Summary:
Pull Request resolved: pytorch#243

As title

Reviewed By: dskhudia

Differential Revision: D19330217

fbshipit-source-id: fefb48fc5e40f892784850cd4fccef46822d5852
jspark1105 authored and facebook-github-bot committed Jan 9, 2020
1 parent b48bdff commit b6e1ce6
Showing 1 changed file with 4 additions and 2 deletions.
test/EmbeddingSpMDM4BitTest.cc: 4 additions, 2 deletions
@@ -36,6 +36,7 @@ static vector<vector<int>> GetInputs_() {
       {10, 4000, 56, 100},
       {10, 4000, 2, 100},
       {10, 4000, 4, 100},
+      {10, 4000, 7, 100},
       // These were from C2 tests
       {10, 40, 16, 10},
       {10, 40, 86, 10},
@@ -98,12 +99,13 @@ TEST_P(Fused4BitRowwiseEmbeddingLookupTest, basicTest) {
   uint8_t* fused_embedding_table =
       new uint8_t[num_rows * fused_embedding_dim];
   for (int i = 0; i < num_rows; i++) {
-    for (int ii = 0; ii < embedding_dim / 2; ii++) {
+    for (int ii = 0; ii < (embedding_dim + 1) / 2; ii++) {
       fused_embedding_table[i * fused_embedding_dim + ii] =
           entries(generator);
     }
     float16* scale_bias = reinterpret_cast<float16*>(
-        fused_embedding_table + i * fused_embedding_dim + embedding_dim / 2);
+        fused_embedding_table + i * fused_embedding_dim +
+        (embedding_dim + 1) / 2);
     float scale = embedding_distribution(generator);
     float bias = embedding_distribution(generator);
     FloatToFloat16_ref(&scale, scale_bias, 1, true /* clip */);
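Note on the change: the fused 4-bit rowwise layout packs two quantized values per byte, with a float16 scale and bias stored right after the packed bytes of each row (which is where the test places scale_bias). The packed payload therefore occupies ceil(embedding_dim / 2) = (embedding_dim + 1) / 2 bytes; the old embedding_dim / 2 truncates for odd dimensions such as the new {10, 4000, 7, 100} input, filling one byte too few and placing scale/bias one byte too early. Below is a minimal sketch of that arithmetic, with illustrative helper names (NumPackedBytes, FusedRowBytes, PackRow are not FBGEMM API).

#include <cstddef>
#include <cstdint>
#include <vector>

// 16-bit storage stand-in for fbgemm::float16.
using float16 = std::uint16_t;

// Bytes occupied by the packed 4-bit values of one row: two values per byte,
// so an odd embedding_dim needs one extra byte whose upper nibble is padding.
inline int NumPackedBytes(int embedding_dim) {
  return (embedding_dim + 1) / 2;
}

// Full fused row width under this layout: packed nibbles followed by an
// fp16 scale and an fp16 bias.
inline int FusedRowBytes(int embedding_dim) {
  return NumPackedBytes(embedding_dim) + 2 * static_cast<int>(sizeof(float16));
}

// Pack one row of 4-bit codes (each in [0, 15]) into bytes, low nibble first.
std::vector<std::uint8_t> PackRow(const std::vector<std::uint8_t>& codes) {
  std::vector<std::uint8_t> out(NumPackedBytes(static_cast<int>(codes.size())), 0);
  for (std::size_t j = 0; j < codes.size(); ++j) {
    out[j / 2] |= static_cast<std::uint8_t>((codes[j] & 0xF) << (4 * (j % 2)));
  }
  return out;
}

For embedding_dim = 7 this gives NumPackedBytes(7) == 4 packed bytes and FusedRowBytes(7) == 8 bytes per row under this layout, matching the (embedding_dim + 1) / 2 offset at which the test now places scale_bias.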
