Skip to content

Commit

Permalink
[Habana] Enable gather and move test from fbcode. (pytorch#2867)
Browse files Browse the repository at this point in the history
  • Loading branch information
rdzhabarov authored May 7, 2019
1 parent 664ed42 commit e2f1c54
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
1 change: 1 addition & 0 deletions lib/Backends/Habana/Habana.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,7 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
case Kinded::Kind::ConcatNodeKind:
case Kinded::Kind::DequantizeNodeKind:
case Kinded::Kind::DivNodeKind:
case Kinded::Kind::GatherNodeKind:
case Kinded::Kind::FullyConnectedNodeKind:
case Kinded::Kind::HabanaFullyConnectedNodeKind:
case Kinded::Kind::HabanaReshapeNodeKind:
Expand Down
52 changes: 50 additions & 2 deletions tests/unittests/HabanaBackendTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1665,8 +1665,7 @@ TEST_F(HabanaBackendTest, DISABLED_GatherPerf) {
printf("GBps: %lf\n", 10 * 3000 * 4.0 / elapsed_secs / 1000 / 1000 / 1000);
}

/// Disable gather op as it does not produce consistent results.
TEST_F(HabanaBackendTest, DISABLED_BatchedGather) {
TEST_F(HabanaBackendTest, BatchedGather) {
/*
DATA = [
[1.0, 1.2, 2.4, 4.5],
Expand Down Expand Up @@ -1711,3 +1710,52 @@ TEST_F(HabanaBackendTest, DISABLED_BatchedGather) {
EXPECT_FLOAT_EQ(H.at({2, 0}), 4.5);
EXPECT_FLOAT_EQ(H.at({2, 1}), 1.2);
}

TEST_F(HabanaBackendTest, BatchedGatherMultipleRuns) {
const unsigned M = 1000;
const unsigned N = 1;

// Fill out the array with random data
std::vector<float> inputData;
inputData.resize(M);
for (unsigned int i = 0; i < M; i++) {
inputData[i] = float(rand() % 1000) / 100;
}

// ID list, to be filled up
unsigned idLen = 10000;
std::vector<int> inputIds;
inputIds.resize(idLen);

// Create placeholder for data
auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {N, M}, "data", false);
ctx_.allocate(data)->getHandle() = inputData;

auto *indices =
mod_.createPlaceholder(ElemKind::Int32ITy, {idLen}, "indices", false);
auto indicesH = ctx_.allocate(indices)->getHandle<int32_t>();
indicesH = inputIds;

// create the net
auto *R = F_->createGather("gather", data, indices, 1);
auto *result = F_->createSave("save", R);
ctx_.allocate(result->getPlaceholder());

EE_.compile(CompilationMode::Infer, F_);

// run this multiple times
for (auto ntimes = 0; ntimes < 10; ntimes++) {
// fill up the ID list with random data
for (unsigned int i = 0; i < inputIds.size(); i++) {
inputIds[i] = rand() % M;
}
indicesH = inputIds;

EE_.run(ctx_);

auto H = ctx_.get(result->getPlaceholder())->getHandle();
for (unsigned i = 0; i < idLen; i++) {
EXPECT_FLOAT_EQ(inputData[inputIds[i]], H.at({0, i}));
}
}
}

0 comments on commit e2f1c54

Please sign in to comment.