forked from wang-xinyu/tensorrtx
unet: fix bug and coding style (wang-xinyu#1171)
* unet: fix bug
* update readme
* remove useless code
1 parent f98182e, commit cb9efbd
Showing 5 changed files with 298 additions and 440 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,61 @@ | ||
# UNet | ||
This is a TensorRT version UNet, inspired by [tensorrtx](https://github.com/wang-xinyu/tensorrtx) and [pytorch-unet](https://github.com/milesial/Pytorch-UNet).<br> | ||
You can generate TensorRT engine file using this script and customize some params and network structure based on network you trained (FP32/16 precision, input size, different conv, activation function...)<br> | ||
|
||
# Requirements | ||
Pytorch model from [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet). | ||
|
||
TensorRT 7.x or 8.x (you need to install tensorrt first)<br> | ||
Python<br> | ||
opencv<br> | ||
cmake<br> | ||
## Contributors | ||
|
||
# Train .pth file and convert .wts | ||
<a href="https://github.com/YuzhouPeng"><img src="https://avatars.githubusercontent.com/u/13601004?v=4?s=48" width="40px;" alt=""/></a> | ||
<a href="https://github.com/East-Face"><img src="https://avatars.githubusercontent.com/u/35283869?v=4s=48" width="40px;" alt=""/></a> | ||
<a href="https://github.com/irvingzhang0512"><img src="https://avatars.githubusercontent.com/u/22089207?s=48&v=4" width="40px;" alt=""/></a> | ||
<a href="https://github.com/wang-xinyu"><img src="https://avatars.githubusercontent.com/u/15235574?s=48&v=4" width="40px;" alt=""/></a> | ||
|
||
## Create env | ||
## Requirements | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Train .pth file | ||
|
||
Train your dataset by following [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet) and generate .pth file.<br> | ||
Please use TensorRT 7.x. | ||
|
||
Please set bilinear=False, i.e. `UNet(n_channels=3, n_classes=1, bilinear=False)`, because TensorRT doesn't support Upsample layer. | ||
There is a bug with TensorRT 8.x, we are working on it. | ||
|
||
## Convert .pth to .wts | ||
## Build and Run | ||
|
||
1. Generate .wts | ||
``` | ||
cp tensorrtx/unet/gen_wts.py Pytorch-UNet/ | ||
cp {path-of-tensorrtx}/unet/gen_wts.py Pytorch-UNet/ | ||
cd Pytorch-UNet/ | ||
python gen_wts.py | ||
wget https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth | ||
python gen_wts.py unet_carvana_scale0.5_epoch2.pth | ||
``` | ||
|
||
# Generate engine file and infer | ||
|
||
Build: | ||
2. Generate TensorRT engine | ||
``` | ||
cd tensorrtx/unet/ | ||
mkdir build | ||
cd build | ||
cmake .. | ||
make | ||
cp {path-of-Pytorch-UNet}/unet.wts . | ||
./unet -s | ||
``` | ||
|
||
Generate TensorRT engine file: | ||
``` | ||
unet -s | ||
3. Run inference | ||
``` | ||
Inference on images in a folder: | ||
``` | ||
unet -d ../samples | ||
wget https://raw.githubusercontent.com/wang-xinyu/tensorrtx/f60dcc7bec28846cd973fc95ac829c4e57a11395/unet/samples/0cdf5b5d0ce1_01.jpg | ||
./unet -d 0cdf5b5d0ce1_01.jpg | ||
``` | ||
|
||
4. Check result.jpg | ||
|
||
<p align="center"> | ||
<img src="https://user-images.githubusercontent.com/15235574/207358769-dacf908e-f65d-4b2e-bc53-4fa2a9114c2a.jpg" height="360px;"> | ||
</p> | ||
|
||
# Benchmark | ||
the speed of tensorRT engine is much faster | ||
|
||
pytorch | TensorRT FP32 | TensorRT FP16 | ||
---- | ----- | ------ | ||
816x672 | 816x672 | 816x672 | ||
58ms | 43ms (batchsize 8) | 14ms (batchsize 8) | ||
# test img | ||
``` | ||
wget https://raw.githubusercontent.com/wang-xinyu/tensorrtx/f60dcc7bec28846cd973fc95ac829c4e57a11395/unet/samples/0cdf5b5d0ce1_01.jpg | ||
``` | ||
# Further development | ||
Pytorch | TensorRT FP32 | TensorRT FP16 | ||
---- | ----- | ------ | ||
816x672 | 816x672 | 816x672 | ||
58ms | 43ms (batchsize 8) | 14ms (batchsize 8) | ||
|
||
## More Information | ||
|
||
See the readme in [home page.](https://github.com/wang-xinyu/tensorrtx) | ||
|
||
1. add INT8 calibrator<br> | ||
2. add custom plugin<br> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,24 @@ | ||
import torch | ||
from torch import nn | ||
import torchvision | ||
import os | ||
import sys | ||
import struct | ||
from torchsummary import summary | ||
|
||
def main(): | ||
print('cuda device count: ', torch.cuda.device_count()) | ||
net = torch.load('ori_unet.pth') | ||
net = net.to('cuda:0') | ||
net = net.eval() | ||
print('model: ', net) | ||
#print('state dict: ', net.state_dict().keys()) | ||
tmp = torch.ones(1, 3, 224, 224).to('cuda:0') | ||
print('input: ', tmp) | ||
out = net(tmp) | ||
device = torch.device('cpu') | ||
state_dict = torch.load(sys.argv[1], map_location=device) | ||
|
||
print('output:', out) | ||
|
||
summary(net, (3, 224, 224)) | ||
#return | ||
f = open("unet.wts", 'w') | ||
f.write("{}\n".format(len(net.state_dict().keys()))) | ||
for k,v in net.state_dict().items(): | ||
print('key: ', k) | ||
print('value: ', v.shape) | ||
vr = v.reshape(-1).cpu().numpy() | ||
f.write("{} {}".format(k, len(vr))) | ||
for vv in vr: | ||
f.write(" ") | ||
f.write(struct.pack(">f", float(vv)).hex()) | ||
f.write("\n") | ||
f = open("unet.wts", 'w') | ||
f.write("{}\n".format(len(state_dict.keys()))) | ||
for k, v in state_dict.items(): | ||
print('key: ', k) | ||
print('value: ', v.shape) | ||
vr = v.reshape(-1).cpu().numpy() | ||
f.write("{} {}".format(k, len(vr))) | ||
for vv in vr: | ||
f.write(" ") | ||
f.write(struct.pack(">f", float(vv)).hex()) | ||
f.write("\n") | ||
f.close() | ||
|
||
if __name__ == '__main__': | ||
main() | ||
main() | ||
|
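
The updated `gen_wts.py` writes a plain-text `.wts` file: a first line with the number of entries, then one line per `state_dict` entry holding the key, the element count, and each value as the hex of a big-endian float32 (`struct.pack(">f", ...)`). The sketch below mirrors that layout without PyTorch and reads the file back, which can help when checking a generated `unet.wts`. The function names `write_wts` and `read_wts` and the keys `layer.weight` / `layer.bias` are hypothetical, chosen only for illustration; they are not part of the repository.

```python
import os
import struct
import tempfile


def write_wts(path, tensors):
    """Write {name: list of floats} using the text layout seen in gen_wts.py."""
    with open(path, "w") as f:
        # First line: number of entries.
        f.write("{}\n".format(len(tensors)))
        for name, values in tensors.items():
            # One line per entry: key, element count, then hex-encoded values.
            f.write("{} {}".format(name, len(values)))
            for v in values:
                f.write(" ")
                f.write(struct.pack(">f", float(v)).hex())
            f.write("\n")


def read_wts(path):
    """Parse a .wts file back into {name: list of floats}."""
    tensors = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            parts = f.readline().split()
            name, n = parts[0], int(parts[1])
            words = parts[2:]
            if len(words) != n:
                raise ValueError("entry {} declares {} values, found {}".format(name, n, len(words)))
            # Each word is 8 hex digits encoding one big-endian float32.
            tensors[name] = [struct.unpack(">f", bytes.fromhex(w))[0] for w in words]
    return tensors


def roundtrip_demo():
    """Write a tiny hypothetical weight set and read it back."""
    weights = {"layer.weight": [0.5, -1.25, 3.0], "layer.bias": [0.0]}
    fd, path = tempfile.mkstemp(suffix=".wts")
    os.close(fd)
    try:
        write_wts(path, weights)
        return read_wts(path)
    finally:
        os.remove(path)
```

The demo values are exactly representable in float32, so they survive the round trip unchanged; real weights stored as float32 in the checkpoint should do the same, while float64 values would be rounded on write.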