Commit

initial codebase for tensorrt backend

shashikg committed Jan 27, 2024
1 parent 1f00776 commit fb62b31
Showing 11 changed files with 1,642 additions and 1 deletion.
14 changes: 14 additions & 0 deletions install_tensorrt.sh
@@ -0,0 +1,14 @@
echo "----------------[ Installing OpenMPI ]----------------"
{ apt-get update && apt-get -y install openmpi-bin libopenmpi-dev; } || { sudo apt-get update && sudo apt-get -y install openmpi-bin libopenmpi-dev; }

echo "----------------[ Installing MPI4PY ]----------------"
MPI4PY_VERSION="3.1.5"
RELEASE_URL="https://github.com/mpi4py/mpi4py/archive/refs/tags/${MPI4PY_VERSION}.tar.gz"
curl -L ${RELEASE_URL} | tar -zx -C /tmp
# Bypassing compatibility issues with higher versions (>= 69) of setuptools.
sed -i 's/>= 40\.9\.0/>= 40.9.0, < 69/g' /tmp/mpi4py-${MPI4PY_VERSION}/pyproject.toml
pip3 install /tmp/mpi4py-${MPI4PY_VERSION}
rm -rf /tmp/mpi4py*

echo "----------------[ Installing TensorRT-LLM ]----------------"
pip3 install tensorrt_llm==0.8.0.dev2024012301 -U --extra-index-url https://pypi.nvidia.com
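
A quick post-install sanity check (a sketch, not part of this commit; run it with python3 after the script finishes):

# Illustrative check: both pinned packages should import cleanly.
from mpi4py import MPI            # the patched mpi4py build
import tensorrt_llm               # the TensorRT-LLM wheel

print(MPI.Get_library_version())  # OpenMPI library version string
print(tensorrt_llm.__version__)   # expected: 0.8.0.dev2024012301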
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ huggingface-hub
accelerate
optimum
transformers
openai-whisper
nvidia-ml-py
134 changes: 134 additions & 0 deletions whisper_s2t/backends/tensorrt/engine_builder/__init__.py
@@ -0,0 +1,134 @@
import json
import hashlib
import os
from pynvml import nvmlInit, nvmlShutdown, nvmlDeviceGetHandleByIndex, nvmlDeviceGetCudaComputeCapability

from .download_utils import SAVE_DIR, download_model


class TRTBuilderConfig:
def __init__(self,
max_batch_size=16,
max_beam_width=1,
max_input_len=4,
max_output_len=448,
world_size=1,
dtype='float16',
quantize_dir='quantize/1-gpu',
use_gpt_attention_plugin='float16',
use_bert_attention_plugin='float16',
use_gemm_plugin='float16',
use_layernorm_plugin=None,
remove_input_padding=False,
use_weight_only_enc=False,
use_weight_only_dec=False,
weight_only_precision='int8',
int8_kv_cache=False,
debug_mode=False,
**kwargs,
):

self.max_batch_size = max_batch_size
self.max_beam_width = max_beam_width
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.world_size = world_size
self.dtype = dtype
self.quantize_dir = quantize_dir
self.use_gpt_attention_plugin = use_gpt_attention_plugin
self.use_bert_attention_plugin = use_bert_attention_plugin
self.use_gemm_plugin = use_gemm_plugin
self.use_layernorm_plugin = use_layernorm_plugin
self.remove_input_padding = remove_input_padding
self.use_weight_only_enc = use_weight_only_enc
self.use_weight_only_dec = use_weight_only_dec
self.weight_only_precision = weight_only_precision
self.int8_kv_cache = int8_kv_cache
self.debug_mode = debug_mode

        # Query GPU 0's CUDA compute capability; it is stored on the instance
        # and so becomes part of the config hash computed by identifier().
        nvmlInit()
self.cuda_compute_capability = list(nvmlDeviceGetCudaComputeCapability(nvmlDeviceGetHandleByIndex(0)))
nvmlShutdown()

plugins_args = [
'use_gemm_plugin',
'use_gpt_attention_plugin',
'use_bert_attention_plugin'
]

for plugin_arg in plugins_args:
if getattr(self, plugin_arg) is None:
print(
f"{plugin_arg} is None, setting it as {self.dtype} automatically."
)
setattr(self, plugin_arg, self.dtype)


def identifier(self):
params = vars(self)
return hashlib.md5(json.dumps(params).encode()).hexdigest()
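
    # Illustrative (not part of this commit): identical constructor arguments
    # on the same GPU hash to the same identifier, so repeated builds of one
    # config land in the same output directory.
    #
    #   cfg_a = TRTBuilderConfig(max_batch_size=8)
    #   cfg_b = TRTBuilderConfig(max_batch_size=8)
    #   assert cfg_a.identifier() == cfg_b.identifier()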


def save_trt_build_configs(trt_build_args):
with open(f'{trt_build_args.output_dir}/trt_build_args.json', 'w') as f:
f.write(json.dumps(vars(trt_build_args)))


def load_trt_build_config(output_dir):
"""
[TODO]: Add cuda_compute_capability verification check
"""

with open(f'{output_dir}/trt_build_args.json', 'r') as f:
trt_build_configs = json.load(f)

trt_build_args = TRTBuilderConfig(**trt_build_configs)
trt_build_args.output_dir = trt_build_configs['output_dir']
trt_build_args.model_path = trt_build_configs['model_path']

return trt_build_args
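
# Round-trip sketch (illustrative): build_trt_engine() below fills in
# output_dir and model_path and calls save_trt_build_configs(); a later
# process can then restore the exact build flags from that directory:
#
#   restored = load_trt_build_config('/path/to/engine_dir')  # hypothetical path
#   print(restored.max_batch_size, restored.dtype)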


def build_trt_engine(model_name='large-v2', args=None, force=False, log_level='error'):

if args is None:
print(f"args is None, using default configs.")
args = TRTBuilderConfig()

args.output_dir = os.path.join(SAVE_DIR, model_name, args.identifier())
args.model_path, tokenizer_path = download_model(model_name)

if force:
print(f"'force' flag is 'True'. Removing previous build.")
os.system(f"rm -rf {args.output_dir}")

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
else:
_files = os.listdir(args.output_dir)

_failed_export = False
        for _req_file in ['tokenizer.json',
'trt_build_args.json',
'encoder_config.json',
'decoder_config.json',
'encoder.engine',
'decoder.engine']:

            if _req_file not in _files:
_failed_export = True
break

if _failed_export:
print(f"Export directory exists but seems like a failed export, regenerating the engine files.")
os.system(f"rm -rf {args.output_dir}")
os.makedirs(args.output_dir)
else:
return args.output_dir

os.system(f"cp {tokenizer_path} {args.output_dir}/tokenizer.json")
save_trt_build_configs(args)

os.system(f"python3 -m whisper_s2t.backends.tensorrt.engine_builder.builder --output_dir='{args.output_dir}' --model_name='{model_name}' --log_level='{log_level}'")

return args.output_dir
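
A minimal end-to-end usage sketch (illustrative argument values; per the code above, build_trt_engine downloads the model weights, builds the engines, and returns the build directory):

from whisper_s2t.backends.tensorrt.engine_builder import TRTBuilderConfig, build_trt_engine

# Hypothetical settings; any TRTBuilderConfig kwargs accepted by __init__ work here.
args = TRTBuilderConfig(max_batch_size=24, max_beam_width=5)
engine_dir = build_trt_engine(model_name='large-v2', args=args)
print(engine_dir)  # <SAVE_DIR>/large-v2/<md5 hash of the build config>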