Skip to content

Commit

Permalink
Merge pull request pytorch#402 from guoruoqian/split_fix_bug
Browse files Browse the repository at this point in the history
fix bugs in split converter
  • Loading branch information
narendasan authored Mar 22, 2021
2 parents 0360198 + da15d9a commit c0b9ec0
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 15 deletions.
13 changes: 8 additions & 5 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
auto in = args[0].ITensor();
auto axis = args[2].unwrapToInt();
auto inDimSize = in->getDimensions().d[axis];
auto numOutputs = 1;
auto numOutputs = 1, numRemainder = 0;
std::vector<int64_t> sizes;

if (split_list) {
Expand All @@ -27,10 +27,13 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
} else {
auto split_size = args[1].unwrapToInt();
numOutputs = inDimSize / split_size;
if (numOutputs == 1) {
numRemainder = inDimSize % split_size;
for (int64_t i = 0; i < numOutputs; i++) {
sizes.push_back(split_size);
} else {
sizes = std::vector<int64_t>(numOutputs, 1);
}
if (numRemainder) {
numOutputs += 1;
sizes.push_back(numRemainder);
}
}

Expand All @@ -42,7 +45,7 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
list.reserve(numOutputs);

int start_idx = 0;
for (int i = 0; i < numOutputs; i++) {
for (int64_t i = 0; i < numOutputs; i++) {
at::Tensor indices = torch::arange(start_idx, start_idx + sizes[i], 1).to(torch::kI32);
auto indicesTensor = tensor_to_const(ctx, indices);

Expand Down
3 changes: 2 additions & 1 deletion core/conversion/var/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ cc_library(
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/conversion/converters:weights"
"//core/conversion/converters:weights",
"//core/conversion/tensorcontainer:tensorcontainer"
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
Expand Down
24 changes: 15 additions & 9 deletions core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,30 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
LOG_DEBUG(ctx->logger, "Found IValue containing object of type " << *(ptr_.ivalue->type()));
}
TRTORCH_CHECK(
isITensor() || (isIValue() && ptr_.ivalue->isTensor()),
isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());

nvinfer1::ITensor* out;

if (isIValue()) {
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
if (ptr_.ivalue->isTensor()) {
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());

auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");

out = const_layer->getOutput(0);
out = const_layer->getOutput(0);

std::ostringstream tensor_id;
tensor_id << reinterpret_cast<int*>(out);
std::ostringstream tensor_id;
tensor_id << reinterpret_cast<int*>(out);

LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
} else {
// Split converter generates c10::IValue which hold TensorContainer.
auto output_container = ptr_.ivalue->toCustomClass<TensorContainer>();
out = output_container.get()->tensor();
}
} else {
out = ptr_.tensor;
}
Expand Down
1 change: 1 addition & 0 deletions core/conversion/var/Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/Weights.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/ir.h"

Expand Down
57 changes: 57 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,60 @@ TEST(Converters, ATenSplitFixedConvertsCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

// Verifies aten::split with a fixed chunk size that does NOT evenly divide the
// split dimension: dim 1 has extent 5, chunk size 2 -> output sizes 2, 2, 1.
// The TRT engine results are compared element-wise against the TorchScript
// interpreter's results for each of the three unpacked outputs.
TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) {
  const auto graph_ir = R"IR(
graph(%argument_1.1 : Tensor):
%2 : int = prim::Constant[value=2]()
%2.1 : int = prim::Constant[value=1]()
%3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1)
%4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
return (%4, %5, %6))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, &*parsed);

  // Random integer input so mismatches are unambiguous; dim 1 = 5 exercises
  // the remainder path in the split converter.
  auto input = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA});

  auto jit_input = at::clone(input);
  auto named_params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto jit_outputs = trtorch::tests::util::RunGraph(parsed, named_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_outputs = trtorch::tests::util::RunGraphEngine(parsed, named_params, {trt_input});

  for (size_t i = 0; i < jit_outputs.size(); ++i) {
    // Reshape to the JIT output's sizes before comparing, since the engine
    // may return a flattened layout.
    auto trt_out = trt_outputs[i].reshape(jit_outputs[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_outputs[i], trt_out, 2e-6));
  }
}

// Verifies that the outputs of aten::split can be consumed by a downstream
// converter (aten::add here). Dim 1 has extent 4, chunk size 2 -> two equal
// halves, which are summed; the single fused result from the TRT engine is
// compared against the TorchScript interpreter output.
TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
  const auto graph_ir = R"IR(
graph(%argument_1.1 : Tensor):
%2 : int = prim::Constant[value=2]()
%2.1 : int = prim::Constant[value=1]()
%3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1)
%4 : Tensor, %5 : Tensor = prim::ListUnpack(%3)
%6 : Tensor = aten::add(%4, %5, %2.1)
return (%6))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, &*parsed);

  // Dim 1 = 4 splits evenly into two chunks of 2 so the add shapes match.
  auto input = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});

  auto jit_input = at::clone(input);
  auto named_params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto jit_outputs = trtorch::tests::util::RunGraph(parsed, named_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_outputs = trtorch::tests::util::RunGraphEngine(parsed, named_params, {trt_input});

  for (size_t i = 0; i < jit_outputs.size(); ++i) {
    // Reshape to the JIT output's sizes before comparing, since the engine
    // may return a flattened layout.
    auto trt_out = trt_outputs[i].reshape(jit_outputs[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_outputs[i], trt_out, 2e-6));
  }
}

0 comments on commit c0b9ec0

Please sign in to comment.