-
Notifications
You must be signed in to change notification settings - Fork 95
/
hook.py
52 lines (41 loc) · 1.52 KB
/
hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
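# Only meant for simple printing of intermediate variables and for calling a hook. mmdetection actually integrates a fairly complete hook system; once you understand the underlying code, there is no need to hand-write hooks like this — you can call it directly.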
import mmcv
import torch
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
from mmdet.apis import inference_detector, show_result
import ipdb
def roialign_forward(module, input, output):
    """Forward hook that prints the shapes of a layer's first two inputs and its output.

    Meant to be passed to ``register_forward_hook``; PyTorch invokes it as
    ``hook(module, input, output)`` after the module's forward pass.

    Args:
        module: The module the hook is attached to (unused).
        input: Sequence of positional inputs; must contain at least two
            items exposing ``.shape``. Only items 0 and 1 are printed
            (for RoIAlign presumably the features and the RoIs -- TODO confirm).
        output: The module's output; must expose ``.shape``.
    """
    print('\n\ninput:')
    # Joining with ' \n ' reproduces print(a, '\n', b) under the default sep=' '.
    leading_pair = (input[0].shape, input[1].shape)
    print(*leading_pair, sep=' \n ')
    print('\n\noutput:')
    print(output.shape)
if __name__ == '__main__':
    # Inputs captured by the pre-hook below; one entry is appended per
    # forward call of the hooked module.
    params = []

    def hook(module, inputs):
        """Forward pre-hook: record the positional inputs passed to *module*.

        Returns None, so PyTorch leaves the inputs unchanged.
        """
        params.append(inputs)

    cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
    # NOTE(review): presumably disables backbone pretrained-weight loading,
    # since the full detector checkpoint is loaded below -- confirm.
    cfg.model.pretrained = None
    torch.cuda.empty_cache()

    # Construct the model and print its module tree (useful for choosing
    # which submodule to hook).
    model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
    print(model)
    handle = model.backbone.conv1.register_forward_pre_hook(hook)
    # Alternative: print RoIAlign input/output shapes instead.
    # model.bbox_roi_extractor.roi_layers[0].register_forward_hook(roialign_forward)
    try:
        _ = load_checkpoint(model, 'weights/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')
        # Test a single image.
        img = mmcv.imread('/py/pic/2.jpg')
        result = inference_detector(model, img, cfg)
        show_result(img, result)
        # To test a list of images instead:
        # imgs = ['/py/pic/4.jpg', '/py/pic/5.jpg']
        # for i, result in enumerate(inference_detector(model, imgs, cfg, device='cuda:0')):
        #     print(i, imgs[i])
        #     show_result(imgs[i], result)
    finally:
        # Detach the hook even if checkpoint loading or inference raises.
        handle.remove()