Skip to content

Commit

Permalink
update ptq
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobo-y committed Oct 6, 2023
1 parent 0ced3b5 commit cc51d9d
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 114 deletions.
9 changes: 3 additions & 6 deletions scripts/trt_quant/README
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
yolov5 tensorrt int8 or fp16 量化, 代码参考自 [nanodet_tensorrt_int8_tools](https://github.com/Wulingtian/nanodet_tensorrt_int8_tools)

将onnx 模型进行float16 或者int8 量化

int8 engine
```shell script
python scripts/trt_quant/convert_trt_quant.py --img_dir /XXXX/train/ --img_size 640 --batch_size 6 --batch 200 --onnx_model runs/train/exp1/weights/bast.onnx --mode int8
python scripts/trt_quant/generate_int8_engine.py --onnx path --images-dir img_path --save-engine engine_path
```

this scripts run in tensorrt 7

62 changes: 47 additions & 15 deletions scripts/trt_quant/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pycuda.driver as cuda
import ctypes
import logging
import pycuda.autoinit



Expand All @@ -12,34 +13,65 @@
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]


class Calibrator(trt.IInt8MinMaxCalibrator):
def __init__(self, stream, cache_file=""):
trt.IInt8MinMaxCalibrator.__init__(self)
self.stream = stream
self.d_input = cuda.mem_alloc(self.stream.calibration_data.nbytes)
class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
def __init__(self, dataloader, cache_file):
trt.IInt8EntropyCalibrator2.__init__(self)
self.cache_file = cache_file
stream.reset()
self.data = iter(dataloader)
self.batch_size = dataloader.batch_size
self.num_image = len(dataloader.dataset)
self.current_index = 0
self.nbytes = dataloader.dataset.nbytes
self.device_input = cuda.mem_alloc(self.nbytes*self.batch_size)

def get_batch_size(self):
return self.stream.batch_size
return self.batch_size

def get_batch(self, names):
batch = self.stream.next_batch()
if not batch.size:
def get_batch(self, name):
if self.current_index + self.batch_size > self.num_image:
return None
batch = next(self.data).numpy().astype("float32").ravel()
cuda.memcpy_htod(self.device_input, batch)
self.current_index += self.batch_size
return [self.device_input]

def read_calibration_cache(self):
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
return f.read()

def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
f.write(cache)


class MinMaxCalibrator(trt.IInt8MinMaxCalibrator):
def __init__(self, dataloader, cache_file):
trt.IInt8MinMaxCalibrator.__init__(self)
self.cache_file = cache_file
self.data = iter(dataloader)
self.batch_size = dataloader.batch_size
self.num_image = len(dataloader.dataset)
self.current_index = 0
self.nbytes = dataloader.dataset.nbytes
self.device_input = cuda.mem_alloc(self.nbytes*self.batch_size)

cuda.memcpy_htod(self.d_input, batch)
def get_batch_size(self):
return self.batch_size

return [int(self.d_input)]
def get_batch(self, name):
if self.current_index + self.batch_size > self.num_image:
return None
batch = next(self.data).numpy().astype("float32").ravel()
cuda.memcpy_htod(self.device_input, batch)
self.current_index += self.batch_size
return [self.device_input]

def read_calibration_cache(self):
# If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
logger.info("Using calibration cache to save time: {:}".format(self.cache_file))
return f.read()

def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
logger.info("Caching calibration data for future use: {:}".format(self.cache_file))
f.write(cache)
93 changes: 0 additions & 93 deletions scripts/trt_quant/convert_trt_quant.py

This file was deleted.

29 changes: 29 additions & 0 deletions scripts/trt_quant/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import cv2
from glob import glob
import random


class Dataset():
def __init__(self, image_path, num=-1, transform=None):
self.imgs_path = []
for ext in ["*.png", "*.jpg", "*.jpeg"]:
self.imgs_path.extend(glob(os.path.join(image_path, ext)))
random.shuffle(self.imgs_path)
if num > 0:
self.imgs_path = self.imgs_path[:num]
self.trans = transform

def __len__(self):
return len(self.imgs_path)

def __getitem__(self, idx):
img = cv2.imread(self.imgs_path[idx])
if self.trans:
img = self.trans(img)
return img

@property
def nbytes(self):
size = self[0].nbytes
return size
130 changes: 130 additions & 0 deletions scripts/trt_quant/generate_int8_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import argparse
from dataclasses import dataclass
import cv2
import os
from glob import glob
import tensorrt as trt
from data import Dataset
from torch.utils.data import DataLoader
from calibrator import EntropyCalibrator, MinMaxCalibrator
import ctypes


EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
EXPLICIT_PRECISION = 1 << (int)(
trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)


def GiB(val):
return val * 1 << 30

def MiB(val):
return val * 1 << 20

@dataclass
class BuildConfig:
min_timing_iterations: int = None
avg_timing_iterations: int = None
int8_calibrator: trt.IInt8Calibrator = None
max_workspace_size: int = MiB(1024)
flags: int = None
profile_stream: int = None
num_optimization_profiles: int = None
default_device_type: trt.DeviceType = trt.DeviceType.GPU
DLA_core: int = None
profiling_verbosity: int = None
engine_capability: int = None


class Transform():
def __init__(self, h=640, w=640):
self.h = h
self.w = w

def __call__(self, img):
img = img.astype("float32")
img = (img - 128.0)/128.0
img = cv2.resize(img, (self.w, self.h))
img = img.transpose(2, 0, 1)
return img

def build_int8_engine(trt_logger, onnx_path, build_params={}):
builder = trt.Builder(trt_logger)
network = builder.create_network(EXPLICIT_BATCH)
config = builder.create_builder_config()
parser = trt.OnnxParser(network, trt_logger)
build_config = build_params.get("build_config", None)
if build_config:
for key, val in build_config.__dict__.items():
if val is not None:
setattr(config, key, val)
with open(onnx_path, 'rb') as f:
parser.parse(f.read())
for index in range(parser.num_errors):
print(parser.get_error(index))
if builder.platform_has_tf32:
config.clear_flag(trt.BuilderFlag.TF32)
engine = builder.build_serialized_network(network, config)
return engine


if __name__ == '__main__':
def parser_arg():
parser = argparse.ArgumentParser(
description="calibrate int8 model and generate model")
parser.add_argument("--onnx", type=str, required=True)
parser.add_argument("--images-dir", type=str, required=True)
parser.add_argument("--save-engine", type=str, required=True)
parser.add_argument('--verbose', action="store_true",
default=False, required=False)
parser.add_argument("--w", type=int, default=640)
parser.add_argument('--h', type=int, default=640)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--images-num', type=int, default=1000)
parser.add_argument('--calibrator', type=str, default='kl', help='kl or minmax')
parser.add_argument('--plugin-dir', type=str, default=None, required=False, help='plugin dir')
parser.add_argument('--cache-file', type=str, default='sample.cache', required=False)
args = parser.parse_args()
return args

args = parser_arg()
samples_imgs = args.images_dir
onnx_model = args.onnx
save_engine = args.save_engine
h = args.h
w = args.w
bs = args.batch_size
num = args.images_num
calibrator = args.calibrator
plugin_dir = args.plugin_dir
cache_file = args.cache_file

if plugin_dir is not None:
paths = glob(os.path.join(plugin_dir, "*.so"))
for path in paths:
ctypes.cdll.LoadLibrary(path)

build_config = BuildConfig()
build_config.flags = 1 << int(trt.BuilderFlag.INT8)
build_config.max_workspace_size = MiB(2048)

transform = Transform(h=h, w=w)
dataset = Dataset(samples_imgs, num=num, transform=transform)
dataloader = DataLoader(dataset, batch_size=bs, num_workers=8, drop_last=True, shuffle=True, prefetch_factor=2)
if calibrator == 'kl':
calibr = EntropyCalibrator(dataloader=dataloader, cache_file=cache_file)
elif calibrator == 'minmax':
calibr = MinMaxCalibrator(dataloader=dataloader, cache_file=cache_file)
else:
assert False, "not support calibrator"

build_config.int8_calibrator = calibr
if args.verbose is True:
logger = trt.Logger(trt.Logger.VERBOSE)
else:
logger = trt.Logger(trt.Logger.INFO)
build_params = {"build_config": build_config}
engine = build_int8_engine(
logger, onnx_path=onnx_model, build_params=build_params)
with open(save_engine, 'wb') as f:
f.write(engine)

0 comments on commit cc51d9d

Please sign in to comment.