Skip to content

Commit

Permalink
Add tests for CPU kernel (pytorch#1112)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1112

Add a unit test for CPU kernel. Build with ASAN to test memory errors in both weighted and unweighted case.

Reviewed By: jspark1105

Differential Revision: D36253233

fbshipit-source-id: 9bca32533e7db5fd5b96bddcd83a5acaf4b48cb2
  • Loading branch information
chowarfb authored and facebook-github-bot committed Jun 8, 2022
1 parent 5123493 commit 70c66b3
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions fbgemm_gpu/test/cpu_kernel_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>

#include "deeplearning/fbgemm/fbgemm_gpu/codegen/embedding_forward_split_cpu.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/embedding_common.h"
#include "torch/types.h" // @manual=//caffe2:torch-cpp-cpu

TEST(cpu_kernel_test, radix_sort_parallel_test) {
std::array<int, 8> keys = {1, 2, 4, 5, 4, 3, 2, 9};
std::array<int, 8> values = {0, 0, 0, 0, 1, 1, 1, 1};

int* sorted_keys;
int* sorted_values;

std::array<int, 8> keys_tmp;
std::array<int, 8> values_tmp;

std::tie(sorted_keys, sorted_values) = fbgemm_gpu::radix_sort_parallel(
keys.data(),
values.data(),
keys_tmp.data(),
values_tmp.data(),
keys.size(),
10);

std::array<int, 8> expect_keys_tmp = {1, 2, 2, 3, 4, 4, 5, 9};
std::array<int, 8> expect_values_tmp = {0, 0, 1, 1, 0, 1, 0, 1};
EXPECT_EQ(sorted_keys, keys_tmp.data());
EXPECT_EQ(sorted_values, values_tmp.data());
EXPECT_EQ(keys_tmp, expect_keys_tmp);
EXPECT_EQ(values_tmp, expect_values_tmp);
}

TEST(cpu_kernel_test, csr2csc_test) {
internal::HyperCompressedSparseColumn csc;
int B = 2;
at::Tensor offsets = torch::tensor({0, 4, 8});
at::Tensor indices = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9});
int64_t pooling_mode = (int64_t)fbgemm_gpu::PoolingMode::SUM;
int table_to_feature_offset[2] = {0, 1};
int num_embeddings = 10;

::internal::csr2csc(
csc,
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
at::TensorAccessor<at::acc_type<float, true>, 1>(
nullptr, nullptr, nullptr), // no weights
pooling_mode,
table_to_feature_offset,
num_embeddings);

// sorted list of unique elements in indices
std::array<int, 6> expect_cs_indices = {1, 2, 3, 4, 5, 9};
for (int i = 0; i < expect_cs_indices.size(); ++i) {
EXPECT_EQ(expect_cs_indices[i], csc.column_segment_indices[i]);
}

// column_segment_ptr[i+1]-column_segment_ptr[i] gives the count of
// column_segment_indices[i] in indices
std::array<int, 7> expect_cs_ptr = {0, 1, 3, 4, 6, 7, 8};
for (int i = 0; i < expect_cs_ptr.size(); ++i) {
EXPECT_EQ(expect_cs_ptr[i], csc.column_segment_ptr[i]);
}

// gives the bag of the ith lowest value in indices, where the bag is
// determined according to offsets
std::array<int, 8> expect_row_indices = {0, 0, 1, 1, 0, 1, 0, 1};
for (int i = 0; i < expect_row_indices.size(); ++i) {
EXPECT_EQ(expect_row_indices[i], csc.row_indices[i]);
}

internal::HyperCompressedSparseColumn csc_weighted;
at::Tensor indice_weights = torch::tensor(
{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}, torch::kFloat32);
::internal::csr2csc(
csc_weighted,
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
indice_weights.accessor<at::acc_type<float, true>, 1>(),
pooling_mode,
table_to_feature_offset,
num_embeddings);

for (int i = 0; i < expect_cs_indices.size(); ++i) {
EXPECT_EQ(expect_cs_indices[i], csc_weighted.column_segment_indices[i]);
}

for (int i = 0; i < expect_cs_ptr.size(); ++i) {
EXPECT_EQ(expect_cs_ptr[i], csc_weighted.column_segment_ptr[i]);
}

for (int i = 0; i < expect_row_indices.size(); ++i) {
EXPECT_EQ(expect_row_indices[i], csc_weighted.row_indices[i]);
}

// sorting should be exact, no arithmetic needed. check for strict equality
// of floats, not relative error
std::array<float, 8> expect_weights = {
1.0f, 1.1f, 1.6f, 1.5f, 1.2f, 1.4f, 1.3f, 1.7f};
for (int i = 0; i < expect_weights.size(); ++i) {
EXPECT_EQ(expect_weights[i], csc_weighted.weights[i]);
}
}

0 comments on commit 70c66b3

Please sign in to comment.