-
Notifications
You must be signed in to change notification settings - Fork 38
/
common.py
105 lines (93 loc) · 4.15 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import time
import os
import argparse
import numpy as np
import pycuda.driver as cuda
import tensorrt as trt
# Compatibility shim: make sure the name FileNotFoundError can be raised later
# in this module, even on interpreters that do not define it.
try:
# Merely referencing the name raises NameError where it is undefined
# (the case this was written for is Python 2).
FileNotFoundError
except NameError:
# Fall back to IOError so `raise FileNotFoundError(...)` keeps working.
FileNotFoundError = IOError
def GiB(val):
    """Convert a size given in gibibytes to bytes.

    Args:
        val (int or float): Number of GiB.

    Returns:
        int or float: ``val`` multiplied by 2**30. Integer input yields an
        integer result; fractional input yields a float.
    """
    # The shift is parenthesised on purpose: ``val * 1 << 30`` parses as
    # ``(val * 1) << 30``, which raises TypeError for non-integer ``val``.
    return val * (1 << 30)
def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=None):
    """Parse the sample's command-line arguments and locate its data.

    Args:
        description (str): Description of the sample, passed to argparse.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (list of str): Filenames to look up inside the data
            directory. Defaults to no files. The sequence passed in is not
            modified; a new list is returned instead.

    Returns:
        str: Path of the data directory, when no files were requested.
        tuple: ``(data_path, files)`` when files were requested, where
        ``files`` is a new list holding the absolute path of every entry of
        ``find_files``, in the same order.

    Raises:
        FileNotFoundError: If the data directory or any requested file does
            not exist.
    """
    kDEFAULT_DATA_ROOT = os.path.abspath("/usr/src/tensorrt/data")

    # Standard command-line arguments for all samples.
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory.")
    # parse_known_args returns unrecognised arguments instead of rejecting
    # them; they are deliberately ignored here.
    args, _ = parser.parse_known_args()

    # If data directory is not specified, use the default.
    data_root = args.datadir if args.datadir else kDEFAULT_DATA_ROOT

    # If the subfolder exists, append it to the path, otherwise use the
    # provided path as-is.
    subfolder_path = os.path.join(data_root, subfolder)
    if os.path.exists(subfolder_path):
        data_path = subfolder_path
    else:
        print("WARNING: " + subfolder_path + " does not exist. Using " + data_root + " instead.")
        data_path = data_root

    # Make sure data directory exists.
    if not os.path.exists(data_path):
        raise FileNotFoundError(data_path + " does not exist. Please provide the correct data path with the -d option.")

    if not find_files:
        return data_path

    # Resolve every requested file into a fresh list so that neither the
    # caller's sequence nor a shared default object is ever mutated.
    located = []
    for filename in find_files:
        full_path = os.path.abspath(os.path.join(data_path, filename))
        if not os.path.exists(full_path):
            raise FileNotFoundError(full_path + " does not exist. Please provide the correct data path with the -d option.")
        located.append(full_path)
    return data_path, located
# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    """Pair of a host buffer and the device buffer that goes with it.

    Attributes:
        host: The host-side buffer handed to the constructor.
        device: The device-side buffer handed to the constructor.
    """

    def __init__(self, host_mem, device_mem):
        self.device = device_mem
        self.host = host_mem

    def __str__(self):
        # Both members are rendered with str(), one labelled section each.
        sections = ("Host:", str(self.host), "Device:", str(self.device))
        return "\n".join(sections)

    def __repr__(self):
        # The debug representation is the same text as the printable one.
        return str(self)
# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
    """Allocate a host/device buffer pair for every binding of ``engine``.

    Args:
        engine: The engine whose bindings need buffers. It is iterated to
            enumerate the bindings and queried for each binding's shape,
            dtype and input/output role, as well as ``max_batch_size``.

    Returns:
        tuple: ``(inputs, outputs, bindings, stream)`` where ``inputs`` and
        ``outputs`` are lists of HostDeviceMem, ``bindings`` lists the device
        buffers converted with ``int()`` in engine iteration order, and
        ``stream`` is a newly created ``cuda.Stream``.
    """
    stream = cuda.Stream()
    inputs, outputs, bindings = [], [], []

    for name in engine:
        # Element count for this binding, scaled by the engine's batch limit.
        element_count = trt.volume(engine.get_binding_shape(name)) * engine.max_batch_size
        np_dtype = trt.nptype(engine.get_binding_dtype(name))

        # Host buffer first; the device buffer is sized from its byte length.
        host_buf = cuda.pagelocked_empty(element_count, np_dtype)
        device_buf = cuda.mem_alloc(host_buf.nbytes)

        # Every binding, input or output, contributes one entry here.
        bindings.append(int(device_buf))

        # File the pair under inputs or outputs according to the engine.
        target = inputs if engine.binding_is_input(name) else outputs
        target.append(HostDeviceMem(host_buf, device_buf))

    return inputs, outputs, bindings, stream
# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one asynchronous inference pass and return the host outputs.

    Args:
        context: Execution context; its ``execute_async`` method is called.
        bindings: Binding list forwarded unchanged to ``execute_async``.
        inputs: List of HostDeviceMem whose host data is copied to the device.
        outputs: List of HostDeviceMem whose device data is copied back.
        stream: Stream on which the copies and the execution are queued.
        batch_size (int): Batch size forwarded to ``execute_async``.

    Returns:
        list: The ``host`` member of every entry of ``outputs``, in order.

    Side effects:
        Prints the wall-clock time of the whole call to stdout.
    """
    t0 = time.time()

    # Queue the host -> device copy of every input.
    for buf in inputs:
        cuda.memcpy_htod_async(buf.device, buf.host, stream)

    # Queue the inference itself on the same stream.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)

    # Queue the device -> host copy of every output.
    for buf in outputs:
        cuda.memcpy_dtoh_async(buf.host, buf.device, stream)

    # Block until everything queued on the stream has finished.
    stream.synchronize()

    elapsed = time.time() - t0
    print("=> time: %.4f" % elapsed)

    # Only the host-side buffers are handed back to the caller.
    return [buf.host for buf in outputs]