# convert_stable_diffusion_controlnet_to_tensorrt.py (forked from huggingface/diffusers)
import argparse
import sys

import tensorrt as trt


def convert_models(onnx_path: str, num_controlnet: int, output_path: str, fp16: bool = False, sd_xl: bool = False):
"""
Function to convert models in stable diffusion controlnet pipeline into TensorRT format
Example:
python convert_stable_diffusion_controlnet_to_tensorrt.py
--onnx_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.onnx
--output_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.engine
--fp16
--num_controlnet 2
Example for SD XL:
python convert_stable_diffusion_controlnet_to_tensorrt.py
--onnx_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.onnx
--output_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.engine
--fp16
--num_controlnet 1
--sd_xl
Returns:
unet/model.engine
run test script in diffusers/examples/community
python test_onnx_controlnet.py
--sd_model danbrown/RevAnimated-v1-2-2
--onnx_model_dir path-to-models-stable_diffusion/RevAnimated-v1-2-2
--unet_engine_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.engine
--qr_img_path path-to-qr-code-image
"""
    # UNet input dimensions (static shapes baked into the TensorRT optimization profile)
    if sd_xl:
        batch_size = 1
        unet_in_channels = 4
        unet_sample_size = 64
        num_tokens = 77
        text_hidden_size = 2048
        img_size = 512
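
        # SD XL additionally conditions the UNet on pooled text embeddings and time ids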
        text_embeds_shape = (2 * batch_size, 1280)
        time_ids_shape = (2 * batch_size, 6)
    else:
        batch_size = 1
        unet_in_channels = 4
        unet_sample_size = 64
        num_tokens = 77
        text_hidden_size = 768
        img_size = 512
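
    # Engine input shapes; the batch is doubled (2 * batch_size) to hold the unconditional
    # and text-conditioned halves used for classifier-free guidance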
    latents_shape = (2 * batch_size, unet_in_channels, unet_sample_size, unet_sample_size)
    embed_shape = (2 * batch_size, num_tokens, text_hidden_size)
    controlnet_conds_shape = (num_controlnet, 2 * batch_size, 3, img_size, img_size)
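
    # TensorRT logger, builder (builds the engine), and runtime (used to deserialize the built plan)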
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    TRT_BUILDER = trt.Builder(TRT_LOGGER)
    TRT_RUNTIME = trt.Runtime(TRT_LOGGER)
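
    # Create an explicit-batch network definition and populate it by parsing the exported ONNX UNet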
    network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
    parse_success = onnx_parser.parse_from_file(onnx_path)
    for idx in range(onnx_parser.num_errors):
        print(onnx_parser.get_error(idx))
    if not parse_success:
        sys.exit("ONNX model parsing failed")
    print("Loaded ONNX model")
    profile = TRT_BUILDER.create_optimization_profile()
    profile.set_shape("sample", latents_shape, latents_shape, latents_shape)
    profile.set_shape("encoder_hidden_states", embed_shape, embed_shape, embed_shape)
    profile.set_shape("controlnet_conds", controlnet_conds_shape, controlnet_conds_shape, controlnet_conds_shape)
    if sd_xl:
        profile.set_shape("text_embeds", text_embeds_shape, text_embeds_shape, text_embeds_shape)
        profile.set_shape("time_ids", time_ids_shape, time_ids_shape, time_ids_shape)
    config = TRT_BUILDER.create_builder_config()
    config.add_optimization_profile(profile)
    config.set_preview_feature(trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805, True)
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
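
    # Build the serialized engine; this can take several minutes and must run on the target GPU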
    plan = TRT_BUILDER.build_serialized_network(network, config)
    if plan is None:
        sys.exit("Failed building engine")
    print("Succeeded building engine")
    engine = TRT_RUNTIME.deserialize_cuda_engine(plan)

    # save TRT engine
    with open(output_path, "wb") as f:
        f.write(engine.serialize())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--sd_xl", action="store_true", default=False, help="SD XL pipeline")
    parser.add_argument(
        "--onnx_path",
        type=str,
        required=True,
        help="Path to the onnx checkpoint to convert",
    )
    parser.add_argument(
        "--num_controlnet",
        type=int,
        required=True,
        help="Number of ControlNet conditioning inputs the ONNX UNet was exported with",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()

    convert_models(args.onnx_path, args.num_controlnet, args.output_path, args.fp16, args.sd_xl)