Merge pull request pytorch#425 from NVIDIA/plugins
feat(core/plugins): Plugins redesign
narendasan authored May 3, 2021
2 parents f053d32 + eb6ed7b commit 5b8b819
Showing 28 changed files with 1,113 additions and 245 deletions.
BUILD (7 changes: 4 additions & 3 deletions)
@@ -15,10 +15,10 @@ pkg_tar(
         "//core/conversion:include",
         "//core/conversion/conversionctx:include",
         "//core/conversion/converters:include",
-        "//core/conversion/converters/impl/plugins:include",
-        "//core/conversion/evaluators:include",
-        "//core/conversion/tensorcontainer:include",
         "//core/conversion/var:include",
+        "//core/conversion/tensorcontainer:include",
+        "//core/conversion/evaluators:include",
+        "//core/plugins:include",
         "//core/lowering:include",
         "//core/lowering/passes:include",
         "//core/runtime:include",
@@ -42,6 +42,7 @@ pkg_tar(
         "//conditions:default": [
             "//cpp/api/lib:libtrtorch.so",
             "//cpp/api/lib:libtrtorchrt.so",
+            "//cpp/api/lib:libtrtorch_plugins.so",
         ],
     }),
     mode = "0755",
core/conversion/conversionctx/ConversionCtx.cpp (3 changes: 1 addition & 2 deletions)
@@ -1,9 +1,8 @@
+#include "core/conversion/conversionctx/ConversionCtx.h"
 #include <iostream>
 #include <sstream>
 #include <utility>
 
-#include "core/conversion/conversionctx/ConversionCtx.h"
-
 namespace trtorch {
 namespace core {
 namespace conversion {
core/conversion/converters/BUILD (3 changes: 2 additions & 1 deletion)
@@ -43,6 +43,7 @@ cc_library(
         "impl/linear.cpp",
         "impl/lstm_cell.cpp",
         "impl/matrix_multiply.cpp",
+        "impl/normalize.cpp",
         "impl/pooling.cpp",
         "impl/reduce.cpp",
         "impl/replication_pad.cpp",
@@ -65,7 +66,7 @@ cc_library(
         "//core/conversion/var",
         "//core/conversion/tensorcontainer",
         "//core/conversion/conversionctx",
-        "//core/conversion/converters/impl/plugins",
+        "//core/plugins:trtorch_plugins",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
core/conversion/converters/impl/activation.cpp (30 changes: 29 additions & 1 deletion)
@@ -146,7 +146,6 @@ auto acthardtanh TRTORCH_UNUSED =
 
                  auto new_layer = ctx->net->addActivation(*self, nvinfer1::ActivationType::kLEAKY_RELU);
                  new_layer->setAlpha(negative_slopeScalar);
-
                  new_layer->setName(util::node_info(n).c_str());
                  auto out_tensor = new_layer->getOutput(0);
                  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
@@ -167,6 +166,35 @@
                  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
                  LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                  return true;
+                }})
+      .pattern({"aten::gelu(Tensor self) -> (Tensor)",
+                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                  auto in = args[0].ITensorOrFreeze(ctx);
+                  nvinfer1::DataType type = in->getType();
+                  TRTORCH_CHECK(
+                      type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF,
+                      "gelu only supports kFLOAT and kHALF");
+                  std::string pluginName = "CustomGeluPluginDynamic";
+                  nvinfer1::PluginFieldCollection fc;
+                  std::vector<nvinfer1::PluginField> f;
+                  int type_id = ctx->settings.op_precision == nvinfer1::DataType::kFLOAT
+                      ? 0
+                      : 1; // Integer encoding the DataType (0: FP32, 1: FP16)
+                  f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
+                  fc.nbFields = f.size();
+                  fc.fields = f.data();
+
+                  auto creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1", "");
+                  auto gelu_plugin = creator->createPlugin("gelu", &fc);
+
+                  TRTORCH_CHECK(gelu_plugin, "Unable to create gelu plugin from TensorRT plugin registry" << *n);
+                  auto new_layer =
+                      ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *gelu_plugin);
+                  new_layer->setName("gelu");
+                  auto out_tensor = new_layer->getOutput(0);
+                  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+                  LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+                  return true;
                 }});
 
 } // namespace
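Aside on the new GELU converter above: instead of a TRTorch-owned creator class, it pulls a creator out of TensorRT's global plugin registry and builds the plugin from a PluginFieldCollection. A minimal standalone sketch of that lookup-then-create pattern, assuming TensorRT 7.x with CustomGeluPluginDynamic already registered (libnvinfer_plugin registers it, e.g. via initLibNvInferPlugins); make_gelu_plugin is an illustrative name, not part of this PR:

#include <vector>

#include "NvInfer.h"
#include "NvInferRuntimeCommon.h"

// Sketch: build a GELU plugin instance through the TensorRT plugin registry.
// The creator must already be registered under the default ("") namespace.
nvinfer1::IPluginV2* make_gelu_plugin(int32_t type_id /* 0: FP32, 1: FP16 */) {
  auto* creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1", "");
  if (!creator) {
    return nullptr;  // plugin library not loaded, nothing registered
  }

  // PluginField stores a pointer, not a copy, so type_id must stay alive
  // until createPlugin() returns.
  std::vector<nvinfer1::PluginField> fields;
  fields.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));

  nvinfer1::PluginFieldCollection fc;
  fc.nbFields = static_cast<int32_t>(fields.size());
  fc.fields = fields.data();
  return creator->createPlugin("gelu", &fc);
}

Note that the converter checks the created plugin but not the creator itself; the sketch guards both.
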
core/conversion/converters/impl/instance_norm.cpp (118 changes: 0 additions & 118 deletions)

This file was deleted.

core/conversion/converters/impl/interpolate.cpp (34 changes: 29 additions & 5 deletions)
@@ -2,7 +2,6 @@
 #include "NvInferRuntimeCommon.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
-#include "plugins/interpolate_plugin.h"
 #include "torch/torch.h"
 
 namespace trtorch {
@@ -28,11 +27,36 @@ void create_plugin(
     bool align_corners,
     bool use_scales = false) {
   LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected");
-
-  auto creator = new plugins::InterpolatePluginCreator();
-  auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, scales, mode, align_corners, use_scales);
+  nvinfer1::PluginFieldCollection fc;
+  std::vector<nvinfer1::PluginField> f;
+
+  std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
+  f.emplace_back(
+      nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));
+
+  std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
+  f.emplace_back(
+      nvinfer1::PluginField("out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));
+
+  std::vector<int32_t> out_size_casted(out_size.begin(), out_size.end());
+  f.emplace_back(
+      nvinfer1::PluginField("out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size.size()));
+
+  f.emplace_back(nvinfer1::PluginField("scales", scales.data(), nvinfer1::PluginFieldType::kFLOAT64, scales.size()));
+  f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));
+
+  int32_t align_corners_casted = static_cast<int32_t>(align_corners);
+  f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));
+
+  int32_t use_scales_casted = static_cast<int32_t>(use_scales);
+  f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));
+
+  fc.nbFields = f.size();
+  fc.fields = f.data();
+  auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
+  auto interpolate_plugin = creator->createPlugin(name, &fc);
 
-  auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
+  auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
 TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
 
 resize_layer->setName(util::node_info(n).c_str());
@@ -779,4 +803,4 @@
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
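
One detail the rewritten create_plugin above relies on: nvinfer1::PluginField stores a raw pointer to its payload rather than copying it, so each narrowed buffer (in_shape_casted and friends) is kept as a function-scope local that outlives the createPlugin call. A reduced sketch of that safe shape, with an illustrative helper name that is not part of this PR:

#include <cstdint>
#include <vector>

#include "NvInfer.h"

// Sketch: narrow an int64 shape to int32 and expose it as a PluginField.
// `storage` is supplied by the caller so the buffer outlives the field;
// building the field from a temporary vector would leave a dangling pointer.
void add_shape_field(std::vector<nvinfer1::PluginField>& fields,
                     const char* name,
                     const std::vector<int64_t>& shape,
                     std::vector<int32_t>& storage) {
  storage.assign(shape.begin(), shape.end());  // int64 -> int32 narrowing
  fields.emplace_back(nvinfer1::PluginField(
      name, storage.data(), nvinfer1::PluginFieldType::kINT32, static_cast<int32_t>(storage.size())));
}
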
core/conversion/converters/impl/normalize.cpp (79 changes: 79 additions & 0 deletions)
@@ -0,0 +1,79 @@
#include "NvInfer.h"
#include "NvInferRuntimeCommon.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

/*
* Helper functions
*/
void create_plugin(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* in,
int64_t order,
std::vector<int32_t> axes,
bool keep_dims,
const char* name) {
LOG_WARNING("Normalize layer will be run through ATen, not TensorRT. Performance may be lower than expected");
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;
f.emplace_back(nvinfer1::PluginField("order", &order, nvinfer1::PluginFieldType::kINT32, 1));
f.emplace_back(nvinfer1::PluginField("axes", axes.data(), nvinfer1::PluginFieldType::kINT32, axes.size()));
f.emplace_back(nvinfer1::PluginField("keep_dims", &keep_dims, nvinfer1::PluginFieldType::kINT32, 1));
fc.nbFields = f.size();
fc.fields = f.data();

auto inputnbDims = in->getDimensions().nbDims;
for (int64_t i = 0; i < (int64_t)axes.size(); i++) {
if (axes[i] < 0) {
axes[i] += inputnbDims;
}
if (axes[i] > inputnbDims - 1) {
TRTORCH_THROW_ERROR("Axis of normalization layer cannot exceed input rank");
}
}

auto creator = getPluginRegistry()->getPluginCreator("NormalizePlugin", "1", "trtorch");
auto plugin = creator->createPlugin(name, &fc);
auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
TRTORCH_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n);

normalize_layer->setName(util::node_info(n).c_str());

auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], normalize_layer->getOutput(0));

LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
}

auto normalize_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());
auto order = args[1].unwrapToScalar().to<int32_t>();
auto axes_values = args[2].unwrapToIntList().vec();
std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
auto keep_dims = (int32_t)args[3].unwrapToBool();
LOG_DEBUG("Order of normalize_plugin: " << order);
LOG_DEBUG("Axis: " << axes);
LOG_DEBUG("keep_dims: " << keep_dims);
create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePlugintrtorch");
return true;
}

});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
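
Aside on the axis handling in normalize.cpp above: the "axes" field is built before negative dimensions are wrapped, which still works because PluginField points into the axes vector and the in-place fix-up runs before createPlugin. The wrapping rule itself, as a standalone sketch with an illustrative name:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch: PyTorch-style axis normalization, matching the loop in
// normalize.cpp's create_plugin (e.g. axis -1 on a rank-4 input becomes 3).
std::vector<int32_t> normalize_axes(std::vector<int32_t> axes, int32_t rank) {
  for (auto& axis : axes) {
    if (axis < 0) {
      axis += rank;  // wrap negative indices
    }
    if (axis < 0 || axis > rank - 1) {
      throw std::out_of_range("Axis of normalization layer cannot exceed input rank");
    }
  }
  return axes;
}
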
core/conversion/converters/impl/plugins/BUILD (41 changes: 0 additions & 41 deletions)

This file was deleted.
