unet: fix bug and coding style (wang-xinyu#1171)
* unet: fix bug

* update readme

* remove useless code
wang-xinyu authored Dec 13, 2022
1 parent f98182e commit cb9efbd
Showing 5 changed files with 298 additions and 440 deletions.
4 changes: 2 additions & 2 deletions unet/CMakeLists.txt
@@ -13,8 +13,8 @@ include_directories(/usr/local/cuda/include/)
 link_directories(/usr/local/cuda/lib64/)
 
 # tensorrt
-include_directories(/workspace/TensorRT-8.4.1.5/include/)
-link_directories(/workspace/TensorRT-8.4.1.5/lib/)
+include_directories(/workspace/TensorRT-7.2.3.4/include/)
+link_directories(/workspace/TensorRT-7.2.3.4/lib/)
 
 # opencv library
 find_package(OpenCV)
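The paths above hard-code a TensorRT tree unpacked under /workspace; point both lines at your own install if it lives elsewhere. If you also use the TensorRT Python wheel, you can quickly confirm it matches the 7.x requirement this commit reverts to — a minimal sketch, not part of the commit:

```
# Hypothetical sanity check (not part of this commit): confirm the TensorRT
# Python bindings, if installed, are a 7.x build, matching CMakeLists.txt
# and the README's "Please use TensorRT 7.x" note.
import tensorrt

print(tensorrt.__version__)  # e.g. '7.2.3.4'
assert tensorrt.__version__.startswith('7'), 'this port expects TensorRT 7.x'
```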
75 changes: 34 additions & 41 deletions unet/README.md
@@ -1,68 +1,61 @@
 # UNet
-This is a TensorRT version UNet, inspired by [tensorrtx](https://github.com/wang-xinyu/tensorrtx) and [pytorch-unet](https://github.com/milesial/Pytorch-UNet).<br>
-You can generate TensorRT engine file using this script and customize some params and network structure based on network you trained (FP32/16 precision, input size, different conv, activation function...)<br>
 
-# Requirements
+Pytorch model from [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet).
 
-TensorRT 7.x or 8.x (you need to install tensorrt first)<br>
-Python<br>
-opencv<br>
-cmake<br>
+## Contributors
 
-# Train .pth file and convert .wts
+<a href="https://github.com/YuzhouPeng"><img src="https://avatars.githubusercontent.com/u/13601004?v=4?s=48" width="40px;" alt=""/></a>
+<a href="https://github.com/East-Face"><img src="https://avatars.githubusercontent.com/u/35283869?v=4s=48" width="40px;" alt=""/></a>
+<a href="https://github.com/irvingzhang0512"><img src="https://avatars.githubusercontent.com/u/22089207?s=48&v=4" width="40px;" alt=""/></a>
+<a href="https://github.com/wang-xinyu"><img src="https://avatars.githubusercontent.com/u/15235574?s=48&v=4" width="40px;" alt=""/></a>
 
-## Create env
+## Requirements
 
-```
-pip install -r requirements.txt
-```
-
-## Train .pth file
-
-Train your dataset by following [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet) and generate .pth file.<br>
+Please use TensorRT 7.x.
 
-Please set bilinear=False, i.e. `UNet(n_channels=3, n_classes=1, bilinear=False)`, because TensorRT doesn't support Upsample layer.
+There is a bug with TensorRT 8.x, we are working on it.
 
-## Convert .pth to .wts
+## Build and Run
 
 1. Generate .wts
 ```
-cp tensorrtx/unet/gen_wts.py Pytorch-UNet/
+cp {path-of-tensorrtx}/unet/gen_wts.py Pytorch-UNet/
 cd Pytorch-UNet/
-python gen_wts.py
+wget https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth
+python gen_wts.py unet_carvana_scale0.5_epoch2.pth
 ```
 
-# Generate engine file and infer
-
-Build:
+2. Generate TensorRT engine
 ```
 cd tensorrtx/unet/
 mkdir build
 cd build
 cmake ..
 make
+cp {path-of-Pytorch-UNet}/unet.wts .
+./unet -s
 ```
 
-Generate TensorRT engine file:
-```
-unet -s
+3. Run inference
 ```
-Inference on images in a folder:
-```
-unet -d ../samples
+wget https://raw.githubusercontent.com/wang-xinyu/tensorrtx/f60dcc7bec28846cd973fc95ac829c4e57a11395/unet/samples/0cdf5b5d0ce1_01.jpg
+./unet -d 0cdf5b5d0ce1_01.jpg
 ```
 
+4. Check result.jpg
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/15235574/207358769-dacf908e-f65d-4b2e-bc53-4fa2a9114c2a.jpg" height="360px;">
+</p>
+
 # Benchmark
-the speed of tensorRT engine is much faster
 
-pytorch | TensorRT FP32 | TensorRT FP16
----- | ----- | ------
-816x672 | 816x672 | 816x672
-58ms | 43ms (batchsize 8) | 14ms (batchsize 8)
-# test img
-```
-wget https://raw.githubusercontent.com/wang-xinyu/tensorrtx/f60dcc7bec28846cd973fc95ac829c4e57a11395/unet/samples/0cdf5b5d0ce1_01.jpg
-```
-# Further development
+Pytorch | TensorRT FP32 | TensorRT FP16
+---- | ----- | ------
+816x672 | 816x672 | 816x672
+58ms | 43ms (batchsize 8) | 14ms (batchsize 8)
 
+## More Information
+
+See the readme in [home page.](https://github.com/wang-xinyu/tensorrtx)
 
-1. add INT8 calibrator<br>
-2. add custom plugin<br>
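One piece of advice the rewritten README drops: if you train your own model rather than using the downloaded checkpoint, the pre-commit text asked for bilinear=False, i.e. `UNet(n_channels=3, n_classes=1, bilinear=False)`, because this port does not cover the bilinear Upsample path. A sketch of loading a checkpoint that way — the `from unet import UNet` import assumes the Pytorch-UNet repository layout, and the file is assumed to be a plain state_dict, which is what the updated gen_wts.py below expects:

```
# Sketch only: build the model with bilinear=False (transposed-conv
# upsampling), per the pre-commit README note, and load a checkpoint.
# The import path assumes this runs inside a Pytorch-UNet checkout.
import torch
from unet import UNet  # package layout of milesial/Pytorch-UNet; assumed

net = UNet(n_channels=3, n_classes=1, bilinear=False)
state_dict = torch.load('unet_carvana_scale0.5_epoch2.pth', map_location='cpu')
net.load_state_dict(state_dict, strict=False)  # tolerate bookkeeping keys, if any
net.eval()
```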
44 changes: 1 addition & 43 deletions unet/common.hpp
@@ -6,7 +6,6 @@
 #include <sstream>
 #include <vector>
 #include <opencv2/opencv.hpp>
-#include <dirent.h>
 #include "NvInfer.h"
 
 #define CHECK(status) \
@@ -95,46 +94,5 @@ IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map<std::string, W
     return scale_1;
 }
 
-ILayer* convBlock(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, int outch, int ksize, int s, int g, std::string lname) {
-    Weights emptywts{DataType::kFLOAT, nullptr, 0};
-    int p = ksize / 2;
-    IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + ".conv.weight"], emptywts);
-    assert(conv1);
-    conv1->setStrideNd(DimsHW{s, s});
-    conv1->setPaddingNd(DimsHW{p, p});
-    conv1->setNbGroups(g);
-    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn", 1e-3);
-
-    // hard_swish = x * hard_sigmoid
-    auto hsig = network->addActivation(*bn1->getOutput(0), ActivationType::kHARD_SIGMOID);
-    assert(hsig);
-    hsig->setAlpha(1.0 / 6.0);
-    hsig->setBeta(0.5);
-    auto ew = network->addElementWise(*bn1->getOutput(0), *hsig->getOutput(0), ElementWiseOperation::kPROD);
-    assert(ew);
-    return ew;
-}
-
-int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
-    DIR *p_dir = opendir(p_dir_name);
-    if (p_dir == nullptr) {
-        return -1;
-    }
-
-    struct dirent* p_file = nullptr;
-    while ((p_file = readdir(p_dir)) != nullptr) {
-        if (strcmp(p_file->d_name, ".") != 0 &&
-            strcmp(p_file->d_name, "..") != 0) {
-            //std::string cur_file_name(p_dir_name);
-            //cur_file_name += "/";
-            //cur_file_name += p_file->d_name;
-            std::string cur_file_name(p_file->d_name);
-            file_names.push_back(cur_file_name);
-        }
-    }
-
-    closedir(p_dir);
-    return 0;
-}
-
 #endif
 
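The helper that survives, addBatchNorm2d (its tail is visible above), realizes BatchNorm2d as a TensorRT IScaleLayer: for a conv output x it applies y = scale * x + shift per channel, with scale = gamma / sqrt(var + eps) and shift = beta - mean * scale. A minimal numpy sketch of that folding, assuming standard PyTorch BatchNorm2d parameter names and the eps of 1e-3 passed above:

```
# Sketch of the batch-norm folding done by addBatchNorm2d, assuming
# PyTorch BatchNorm2d naming: weight=gamma, bias=beta, running_mean,
# running_var. TensorRT's IScaleLayer then applies y = scale * x + shift.
import numpy as np

def fold_bn(gamma, beta, mean, var, eps=1e-3):
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    return scale, shift

# One example channel: gamma=1, beta=0, mean=0.5, var=0.25
scale, shift = fold_bn(np.array([1.0]), np.array([0.0]),
                       np.array([0.5]), np.array([0.25]))
print(scale, shift)  # ≈ [1.996] [-0.998]
```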
46 changes: 17 additions & 29 deletions unet/gen_wts.py
@@ -1,36 +1,24 @@
 import torch
-from torch import nn
-import torchvision
-import os
+import sys
 import struct
-from torchsummary import summary
 
 def main():
-    print('cuda device count: ', torch.cuda.device_count())
-    net = torch.load('ori_unet.pth')
-    net = net.to('cuda:0')
-    net = net.eval()
-    print('model: ', net)
-    #print('state dict: ', net.state_dict().keys())
-    tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
-    print('input: ', tmp)
-    out = net(tmp)
+    device = torch.device('cpu')
+    state_dict = torch.load(sys.argv[1], map_location=device)
 
-    print('output:', out)
-
-    summary(net, (3, 224, 224))
-    #return
-    f = open("unet.wts", 'w')
-    f.write("{}\n".format(len(net.state_dict().keys())))
-    for k,v in net.state_dict().items():
-        print('key: ', k)
-        print('value: ', v.shape)
-        vr = v.reshape(-1).cpu().numpy()
-        f.write("{} {}".format(k, len(vr)))
-        for vv in vr:
-            f.write(" ")
-            f.write(struct.pack(">f", float(vv)).hex())
-        f.write("\n")
+    f = open("unet.wts", 'w')
+    f.write("{}\n".format(len(state_dict.keys())))
+    for k, v in state_dict.items():
+        print('key: ', k)
+        print('value: ', v.shape)
+        vr = v.reshape(-1).cpu().numpy()
+        f.write("{} {}".format(k, len(vr)))
+        for vv in vr:
+            f.write(" ")
+            f.write(struct.pack(">f", float(vv)).hex())
+        f.write("\n")
+    f.close()
 
 if __name__ == '__main__':
-    main()
+    main()
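The .wts format this script writes is plain text: a first line with the tensor count, then one line per tensor of the form `name length hex hex ...`, where each value is a float32 serialized big-endian via struct.pack('>f', v).hex(). For reference, a small reader that inverts it — a hypothetical verification helper, not part of the repo:

```
# Hypothetical helper (not in the repo): read a .wts file written by
# gen_wts.py back into numpy arrays. Each value is 8 hex characters
# encoding a big-endian float32.
import struct
import numpy as np

def load_wts(path):
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            name, length, *hexvals = f.readline().split()
            vals = [struct.unpack('>f', bytes.fromhex(h))[0] for h in hexvals]
            assert len(vals) == int(length)
            weights[name] = np.array(vals, dtype=np.float32)
    return weights

w = load_wts('unet.wts')
print(len(w), 'tensors loaded')
```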
