forked from open-mmlab/mmdeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_rewrite.py
115 lines (96 loc) · 3.88 KB
/
demo_rewrite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import shutil
import torch
from pyppeteer import launch
from torchvision.models import resnet18
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext, patch_model
from mmdeploy.utils import get_root_logger
@FUNCTION_REWRITER.register_rewriter(
    func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(self, x):
    """Replacement for ``torchvision.models.ResNet._forward_impl``.

    Instead of running the whole network, only the stem
    (``conv1`` -> ``bn1`` -> ``relu`` -> ``maxpool``) and ``layer1`` are
    applied, and the resulting feature map is returned early, i.e. after
    the two down-sampling steps of the stem.

    Args:
        self: the ResNet instance whose ``_forward_impl`` is being rewritten.
        x: the network input.

    Returns:
        The output of ``self.layer1``.
    """
    # Applied strictly in this order; every later stage of the original
    # forward is skipped.
    truncated_pipeline = (
        self.conv1,
        self.bn1,
        self.relu,
        self.maxpool,
        self.layer1,
    )
    for stage in truncated_pipeline:
        x = stage(x)
    return x
def rewrite_resnet18(original_path: str, rewritten_path: str):
    """Export an untrained resnet18 to ONNX twice: plain and rewritten.

    Args:
        original_path: destination of the ONNX export of the stock model.
        rewritten_path: destination of the ONNX export traced from the
            ``patch_model`` result inside a ``RewriterContext`` (both from
            ``mmdeploy.core``), so the rewriters registered in this file are
            expected to apply.
    """
    # Both the random input and the randomly initialised weights draw from
    # torch's global RNG, so this creation order determines their values.
    dummy_input = torch.rand(1, 3, 224, 224)
    net = resnet18(pretrained=False)

    # Baseline: torchvision's own forward implementation.
    torch.onnx.export(net, dummy_input, original_path)

    # Rewritten: patch the model, then trace with function rewriters active.
    patched_net = patch_model(net, cfg={}, backend='default')
    rewrite_ctx = RewriterContext(cfg={}, backend='default')
    with rewrite_ctx, torch.no_grad():
        torch.onnx.export(patched_net, dummy_input, rewritten_path)
def screen_size():
    """Return the screen size ``(width, height)`` in pixels via tkinter.

    A throw-away Tk root is created only to query the screen geometry and
    is always destroyed before returning, even if a query fails.

    Returns:
        tuple: ``(width, height)`` as reported by
        ``winfo_screenwidth`` / ``winfo_screenheight``.

    Raises:
        tkinter.TclError: typically when no display is available
            (e.g. on a headless host).
    """
    import tkinter
    root = tkinter.Tk()
    try:
        # Keep the helper window hidden; it only exists to ask Tk about
        # the screen.
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        # ``quit()`` merely stops a running mainloop (none runs here);
        # ``destroy()`` actually tears down the window and Tcl interpreter.
        root.destroy()
async def visualize(original_path: str, rewritten_path: str):
    """Load two ONNX models into netron.app and screenshot each one.

    A visible (non-headless) browser is launched via pyppeteer, both files
    are uploaded to separate https://netron.app/ pages, and a PNG is written
    next to each model (``'.onnx'`` replaced by ``'.png'`` in its path).
    The browser is closed even if any step fails.

    Args:
        original_path: ONNX file of the unmodified model.
        rewritten_path: ONNX file of the rewritten model.

    Raises:
        RuntimeError: if a netron page exposes no ``input[type='file']``.
    """
    width, height = screen_size()
    # Horizontally centred half of the viewport, full height.
    clip = {'x': width / 4, 'y': 0, 'width': width / 2, 'height': height}

    browser = await launch(headless=False, args=['--start-maximized'])
    try:
        # page2 is opened first, as in the original flow.
        page2 = await browser.newPage()
        page1 = await browser.newPage()
        for page in (page1, page2):
            await page.setViewport({'width': width, 'height': height})
        await page1.goto('https://netron.app/')
        await page2.goto('https://netron.app/')
        await asyncio.sleep(2)

        # Upload the function's own arguments (previously module globals
        # from ``__main__`` were used, breaking any other caller).
        for page, model_path in ((page1, original_path),
                                 (page2, rewritten_path)):
            file_input = await page.querySelector("input[type='file']")
            if file_input is None:
                raise RuntimeError(
                    'netron.app page has no file input to upload '
                    f'{model_path!r}')
            await file_input.uploadFile(model_path)
        await asyncio.sleep(4)

        # Only the original model's page is zoomed out; presumably the
        # truncated rewritten graph already fits -- confirm visually.
        for _ in range(6):
            await page1.click('#zoom-out-button')
            await asyncio.sleep(0.3)
        await asyncio.sleep(1)

        # Pass a fresh copy of ``clip`` per call so neither screenshot can
        # observe changes the library might make to the dict.
        await page1.screenshot(
            {'path': original_path.replace('.onnx', '.png')}, clip=dict(clip))
        await page2.screenshot(
            {'path': rewritten_path.replace('.onnx', '.png')},
            clip=dict(clip))
    finally:
        await browser.close()
if __name__ == '__main__':
    import tempfile

    # Imported up front so a missing mmcv fails before any slow work.
    import mmcv

    logger = get_root_logger()

    # Use a private, freshly created temp directory. The former fixed
    # ``<cwd>/tmp`` location was removed with ``rmtree`` at the end, which
    # would wipe a pre-existing user directory of that name; it was also
    # left behind whenever a step raised.
    with tempfile.TemporaryDirectory() as tmp_dir:
        original_file_path = os.path.join(tmp_dir, 'original.onnx')
        rewritten_file_path = os.path.join(tmp_dir, 'rewritten.onnx')

        logger.info('Generating resnet18 and its rewritten model...')
        rewrite_resnet18(original_file_path, rewritten_file_path)

        logger.info('Visualizing models through netron...')
        # NOTE(review): get_event_loop() without a running loop is
        # deprecated on newer Pythons; kept for pyppeteer compatibility.
        asyncio.get_event_loop().run_until_complete(
            visualize(original_file_path, rewritten_file_path))

        # visualize() writes each screenshot next to its model as *.png.
        image1 = mmcv.imread(original_file_path.replace('.onnx', '.png'))
        image2 = mmcv.imread(rewritten_file_path.replace('.onnx', '.png'))
        mmcv.imshow(image1, win_name='original')
        mmcv.imshow(image2, win_name='rewritten')