Set up TensorRT configurations for external use, and add a test.
PiperOrigin-RevId: 183347199
aaroey authored and tensorflower-gardener committed Jan 26, 2018
1 parent ffda607 commit 76f6938
Showing 11 changed files with 701 additions and 51 deletions.
119 changes: 119 additions & 0 deletions configure.py
@@ -43,6 +43,7 @@
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
                          'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -959,6 +960,119 @@ def set_tf_cudnn_version(environ_cp):
  write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)


def set_tf_tensorrt_install_path(environ_cp):
  """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.

  Adapted from code contributed by Sami Kama (https://github.com/samikama).

  Args:
    environ_cp: copy of the os.environ.

  Raises:
    ValueError: if this method was called on a non-Linux platform.
    UserInputError: if the user has provided invalid input multiple times.
  """
  if not is_linux():
    raise ValueError('Currently TensorRT is only supported on Linux.')

  # Ask the user whether to add TensorRT support.
  if str(int(get_var(
      environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
    return

  for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
    ask_tensorrt_path = (r'Please specify the location where TensorRT is '
                         'installed. [Default is %s]:') % (
                             _DEFAULT_TENSORRT_PATH_LINUX)
    trt_install_path = get_from_env_or_user_or_default(
        environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
        _DEFAULT_TENSORRT_PATH_LINUX)

    # The result returned from "read" is used unexpanded, which makes "~"
    # unusable. Go through one more level of expansion to handle that.
    trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))

    def find_libs(search_path):
      """Search for libnvinfer.so in "search_path"."""
      fl = set()
      if os.path.exists(search_path) and os.path.isdir(search_path):
        fl.update([os.path.realpath(os.path.join(search_path, x))
                   for x in os.listdir(search_path) if 'libnvinfer.so' in x])
      return fl

    possible_files = find_libs(trt_install_path)
    possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
    possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))

    def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
      """Check compatibility between the TensorRT, cudnn and cudart libraries."""
      ldd_bin = which('ldd') or '/usr/bin/ldd'
      ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
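      # For reference, ldd output lines have this general shape (the "=>"
      # form is standard; the versions and paths here are illustrative):
      #   libcudnn.so.7 => /usr/lib/x86_64-linux-gnu/libcudnn.so.7 (0x...)
      #   libcudart.so.9.0 => /usr/local/cuda/lib64/libcudart.so.9.0 (0x...)
      # The patterns below capture the trailing version ('7' and '9.0' here).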
      cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
      cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
      cudnn = None
      cudart = None
      for line in ldd_out:
        if 'libcudnn.so' in line:
          cudnn = cudnn_pattern.search(line)
        elif 'libcudart.so' in line:
          cudart = cuda_pattern.search(line)
      if cudnn and len(cudnn.group(1)):
        cudnn = convert_version_to_int(cudnn.group(1))
      if cudart and len(cudart.group(1)):
        cudart = convert_version_to_int(cudart.group(1))
      return (cudnn == cudnn_ver) and (cudart == cuda_ver)

    cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
    cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
    nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
    highest_ver = [0, None, None]

    for lib_file in possible_files:
      if is_compatible(lib_file, cuda_ver, cudnn_ver):
        ver_str = nvinfer_pattern.search(lib_file).group(1)
        ver = convert_version_to_int(ver_str) if len(ver_str) else 0
        if ver > highest_ver[0]:
          highest_ver = [ver, ver_str, lib_file]
    if highest_ver[1] is not None:
      trt_install_path = os.path.dirname(highest_ver[2])
      tf_tensorrt_version = highest_ver[1]
      break

    # Try another alternative from ldconfig.
    ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
    ldconfig_output = run_shell([ldconfig_bin, '-p'])
    search_result = re.search(
        '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
    if search_result:
      libnvinfer_path_from_ldconfig = search_result.group(2)
      if os.path.exists(libnvinfer_path_from_ldconfig):
        if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
          trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
          tf_tensorrt_version = search_result.group(1)
          break

    # Reset and retry.
    print('Invalid path to TensorRT. None of the following files can be found:')
    print(trt_install_path)
    print(os.path.join(trt_install_path, 'lib'))
    print(os.path.join(trt_install_path, 'lib64'))
    if search_result:
      print(libnvinfer_path_from_ldconfig)

  # The "else" belongs to the "for" loop above (for/else): it runs only when
  # every prompt attempt finishes without a break, i.e. no valid path was found.
  else:
    raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
                         'times in a row. Assuming this is a scripting '
                         'mistake.' % _DEFAULT_PROMPT_ASK_ATTEMPTS)

  # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
  environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
  write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
  environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
  write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)


def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@@ -1244,9 +1358,11 @@ def main():
    environ_cp['TF_NEED_COMPUTECPP'] = '0'
    environ_cp['TF_NEED_OPENCL'] = '0'
    environ_cp['TF_CUDA_CLANG'] = '0'
    environ_cp['TF_NEED_TENSORRT'] = '0'

  if is_macos():
    environ_cp['TF_NEED_JEMALLOC'] = '0'
    environ_cp['TF_NEED_TENSORRT'] = '0'

  set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
                'with_jemalloc', True)
@@ -1278,6 +1394,8 @@ def main():
      'TF_CUDA_CONFIG_REPO' not in environ_cp):
    set_tf_cuda_version(environ_cp)
    set_tf_cudnn_version(environ_cp)
    if is_linux():
      set_tf_tensorrt_install_path(environ_cp)
    set_tf_cuda_compute_capabilities(environ_cp)

  set_tf_cuda_clang(environ_cp)
@@ -1332,6 +1450,7 @@ def main():
        'more details.')
  config_info_line('mkl', 'Build with MKL support.')
  config_info_line('monolithic', 'Config for mostly static monolithic build.')
  config_info_line('tensorrt', 'Build with TensorRT support.')

if __name__ == '__main__':
  main()
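A note on a helper used above: is_compatible and the libnvinfer version bookkeeping call convert_version_to_int, which already exists elsewhere in configure.py and is not part of this diff. A minimal sketch of the behavior such a helper needs (an illustration under that assumption, not the upstream implementation):

def convert_version_to_int(version):
  """Converts a dotted version string into a single comparable integer.

  Zero-pads each segment to three digits so numeric ordering is preserved,
  e.g. '4.1.2' -> 4001002 and '4.10.0' -> 4010000, so '4.10.0' > '4.2.0'.
  """
  segments = version.split('-')[0].split('.')
  if not all(seg.isdigit() for seg in segments):
    return None
  return int(''.join('%03d' % int(seg) for seg in segments))

With that shape, the cudnn and cudart versions scraped from ldd compare directly against the TF_CUDNN_VERSION and TF_CUDA_VERSION values collected earlier in configure.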
9 changes: 9 additions & 0 deletions tensorflow/BUILD
@@ -370,6 +370,14 @@ config_setting(
    visibility = ["//visibility:public"],
)

# TODO(laigd): consider removing this option and enabling TensorRT
# automatically when CUDA is enabled.
config_setting(
    name = "with_tensorrt_support",
    values = {"define": "with_tensorrt_support=true"},
    visibility = ["//visibility:public"],
)

package_group(
    name = "internal",
    packages = [
@@ -558,6 +566,7 @@ filegroup(
"//tensorflow/contrib/tensor_forest/proto:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/tensorboard/db:all_files",
"//tensorflow/contrib/tensorrt:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof:all_files",
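A note on the new config_setting: with_tensorrt_support matches builds invoked with --define with_tensorrt_support=true (directly, or through a bazelrc --config alias), which is what lets the TensorRT build rules below switch their dependencies.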
45 changes: 45 additions & 0 deletions tensorflow/contrib/tensorrt/BUILD
@@ -0,0 +1,45 @@
# Description:
#   Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with TensorFlow.
#   APIs are meant to change over time.

package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load(
    "@local_config_tensorrt//:build_defs.bzl",
    "if_tensorrt",
)

tf_cuda_cc_test(
    name = "tensorrt_test_cc",
    size = "small",
    srcs = ["tensorrt_test.cc"],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ] + if_tensorrt([
        "@local_config_cuda//cuda:cuda_headers",
        "@local_config_tensorrt//:nv_infer",
    ]),
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
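The if_tensorrt macro loaded above comes from @local_config_tensorrt, a repository generated during the configure step; its build_defs.bzl is not part of this diff. A plausible sketch of what such a macro could reduce to, assuming it keys off the with_tensorrt_support config_setting added to tensorflow/BUILD in this commit:

# build_defs.bzl -- illustrative sketch only, not the generated file.
def if_tensorrt(if_true, if_false = []):
    """Returns if_true when TensorRT support is enabled, else if_false."""
    return select({
        "//tensorflow:with_tensorrt_support": if_true,
        "//conditions:default": if_false,
    })

A build would then opt in with something like bazel build --config=tensorrt (alongside the usual CUDA flags), matching the --config hint that configure.py now prints.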
159 changes: 159 additions & 0 deletions tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -0,0 +1,159 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda.h"
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace {

class Logger : public nvinfer1::ILogger {
 public:
  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
    switch (severity) {
      case Severity::kINFO:
        LOG(INFO) << msg;
        break;
      case Severity::kWARNING:
        LOG(WARNING) << msg;
        break;
      case Severity::kINTERNAL_ERROR:
      case Severity::kERROR:
        LOG(ERROR) << msg;
        break;
      default:
        break;
    }
  }
};

class ScopedWeights {
 public:
  ScopedWeights(float value) : value_(value) {
    w.type = nvinfer1::DataType::kFLOAT;
    w.values = &value_;
    w.count = 1;
  }
  const nvinfer1::Weights& get() { return w; }

 private:
  float value_;
  nvinfer1::Weights w;
};

const char* kInputTensor = "input";
const char* kOutputTensor = "output";

// Creates a network to compute y = 2x + 3.
nvinfer1::IHostMemory* CreateNetwork() {
  Logger logger;
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  ScopedWeights weights(2.0);
  ScopedWeights bias(3.0);

  nvinfer1::INetworkDefinition* network = builder->createNetwork();
  // Add the input.
  auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
                                 nvinfer1::DimsCHW{1, 1, 1});
  EXPECT_NE(input, nullptr);
  // Add the hidden layer.
  auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
  EXPECT_NE(layer, nullptr);
  // Mark the output.
  auto output = layer->getOutput(0);
  output->setName(kOutputTensor);
  network->markOutput(*output);
  // Build the engine.
  builder->setMaxBatchSize(1);
  builder->setMaxWorkspaceSize(1 << 10);
  auto engine = builder->buildCudaEngine(*network);
  EXPECT_NE(engine, nullptr);
  // Serialize the engine to create a model, then close everything.
  nvinfer1::IHostMemory* model = engine->serialize();
  network->destroy();
  engine->destroy();
  builder->destroy();
  return model;
}

// Executes the network.
void Execute(nvinfer1::IExecutionContext& context, const float* input,
             float* output) {
  const nvinfer1::ICudaEngine& engine = context.getEngine();

  // We have two bindings: input and output.
  ASSERT_EQ(engine.getNbBindings(), 2);
  const int input_index = engine.getBindingIndex(kInputTensor);
  const int output_index = engine.getBindingIndex(kOutputTensor);

  // Create GPU buffers and a stream.
  void* buffers[2];
  ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
  ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
  cudaStream_t stream;
  ASSERT_EQ(0, cudaStreamCreate(&stream));

  // Copy the input to the GPU, execute the network, and copy the output back.
  //
  // Note that since the host buffer was not created as pinned memory, these
  // async copies are turned into sync copies. So the following synchronization
  // could be removed.
  ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                               cudaMemcpyHostToDevice, stream));
  context.enqueue(1, buffers, stream, nullptr);
  ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                               cudaMemcpyDeviceToHost, stream));
  cudaStreamSynchronize(stream);

  // Release the stream and the buffers.
  cudaStreamDestroy(stream);
  ASSERT_EQ(0, cudaFree(buffers[input_index]));
  ASSERT_EQ(0, cudaFree(buffers[output_index]));
}

TEST(TensorrtTest, BasicFunctions) {
  // Create the network model.
  nvinfer1::IHostMemory* model = CreateNetwork();
  // Use the model to create an engine and then an execution context.
  Logger logger;
  nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
  nvinfer1::ICudaEngine* engine =
      runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
  model->destroy();
  nvinfer1::IExecutionContext* context = engine->createExecutionContext();

  // Execute the network.
  float input = 1234;
  float output;
  Execute(*context, &input, &output);
  EXPECT_EQ(output, input * 2 + 3);

  // Destroy the engine.
  context->destroy();
  engine->destroy();
  runtime->destroy();
}

}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
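Given the y = 2x + 3 network built in CreateNetwork(), the single test case feeds x = 1234 and expects 2 * 1234 + 3 = 2471 back from the engine. Because the test is tagged "manual" and "notap", wildcard builds skip it and it must be named explicitly; on a machine configured with CUDA and TensorRT the invocation would look something like

  bazel test --config=cuda --config=tensorrt //tensorflow/contrib/tensorrt:tensorrt_test_cc

with the exact flags depending on the local configuration.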
(The remaining 7 changed files are not shown.)