Skip to content

Commit

Permalink
Merge pull request pytorch#397 from guoruoqian/aten_size_fix_bug
Browse files Browse the repository at this point in the history
fix bug, when dim of aten::size.int(Tensor self, int dim) -> (int) is…
  • Loading branch information
narendasan authored Apr 30, 2021
2 parents 4da65b3 + efc8202 commit 24a780f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
11 changes: 10 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,18 @@ auto aten_registrations TRTORCH_UNUSED =
auto dim = args.at(n->input(1)).unwrapToInt();
if (tensor_var.isITensor()) {
auto tensor = tensor_var.ITensor();
return util::toVec(tensor->getDimensions())[dim];
auto dims = util::toVec(tensor->getDimensions());
auto nbDims = tensor->getDimensions().nbDims;
if (dim < 0) {
dim += nbDims;
}
return dims[dim];
} else {
auto tensor = tensor_var.unwrapToTensor();
auto nbDims = tensor.sizes().size();
if (dim < 0) {
dim += nbDims;
}
return tensor.sizes()[dim];
}
}
Expand Down
25 changes: 25 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, ATenSizeNegativeConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-1]()
%2 : int = prim::Constant[value=-2]()
%3 : int = aten::size(%0, %1)
%4 : int = aten::size(%0, %2)
%5 : int[] = prim::ListConstruct(%3, %4)
%6 : Tensor = aten::view(%0, %5)
return (%6))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {3, 3}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Evaluators, FloorIntIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
Expand Down

0 comments on commit 24a780f

Please sign in to comment.