# generate.py (repository forked from Zardinality/TF_Deformable_Net)
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
import cv2
import matplotlib.pyplot as plt
from ..utils.blob import im_list_to_blob
from ..utils.timer import Timer
# TODO: remove the dependency on fast_rcnn
# >>>> obsolete, because it depends on something outside of this project
from ..fast_rcnn.config import cfg
# <<<< obsolete
def _vis_proposals(im, dets, thresh=0.5):
    """Plot the boxes in `dets` whose score reaches `thresh`.

    Arguments:
        im (ndarray): image indexed as [row, col, channel]; its channel
            order is reversed before display (presumably BGR -> RGB)
        dets (ndarray): one detection per row; the first four columns are
            used as (x1, y1, x2, y2) and the last column as the score
        thresh (float): minimum score a detection needs to be drawn

    Returns nothing. When no detection reaches `thresh` the function exits
    before creating any figure.
    """
    keep = np.where(dets[:, -1] >= thresh)[0]
    if len(keep) == 0:
        return

    label = 'obj'
    flipped = im[:, :, (2, 1, 0)]
    _, axes = plt.subplots(figsize=(12, 12))
    axes.imshow(flipped, aspect='equal')

    for idx in keep:
        bbox = dets[idx, :4]
        score = dets[idx, -1]

        # Box outline: anchored at (x1, y1), sized by the corner differences.
        outline = plt.Rectangle(
            (bbox[0], bbox[1]),
            bbox[2] - bbox[0],
            bbox[3] - bbox[1],
            fill=False,
            edgecolor='red',
            linewidth=3.5,
        )
        axes.add_patch(outline)

        # Score caption placed just above the box's top-left corner.
        caption = '{:s} {:.3f}'.format(label, score)
        axes.text(
            bbox[0],
            bbox[1] - 2,
            caption,
            bbox=dict(facecolor='blue', alpha=0.5),
            fontsize=14,
            color='white',
        )

    title = '{} detections with p({} | box) >= {:.1f}'.format(
        label, label, thresh)
    axes.set_title(title, fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
def _get_image_blob(im):
    """Convert an image into a network input.

    The image is cast to float32, `cfg.PIXEL_MEANS` is subtracted, and the
    result is resized by the single scale factor configured in
    `cfg.TEST.SCALES_BASE`.

    Arguments:
        im (ndarray): a color image in BGR order

    Returns:
        blob (ndarray): the data blob built by `im_list_to_blob` from the
            single rescaled image
        im_info (ndarray): array of shape (1, 3) holding
            [[height, width, scale]], where height and width are those of
            the rescaled image and scale is the factor applied to `im`
    """
    # Validate and read the scale from the same (TEST) config section.
    # Previously the assert checked TEST.SCALES_BASE while the value was
    # taken from TRAIN.SCALES_BASE, so the checked and the used settings
    # could silently disagree.
    assert len(cfg.TEST.SCALES_BASE) == 1, \
        'proposal generation expects exactly one entry in TEST.SCALES_BASE'
    im_scale = cfg.TEST.SCALES_BASE[0]

    # Work on a float copy so the caller's image is left untouched.
    im_orig = im.astype(np.float32, copy=True)
    im_orig -= cfg.PIXEL_MEANS

    im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                    interpolation=cv2.INTER_LINEAR)
    im_info = np.hstack((im.shape[:2], im_scale))[np.newaxis, :]

    # Create a blob to hold the (single) input image
    blob = im_list_to_blob([im])
    return blob, im_info
def im_proposals(net, im):
    """Generate RPN proposals on a single image.

    Arguments:
        net: network object exposing `blobs['data']`, `blobs['im_info']`
            and a `forward(data=..., im_info=...)` method whose result
            contains 'rois' and 'scores'
        im (ndarray): image passed to `_get_image_blob`

    Returns:
        boxes (ndarray): the 'rois' output with its first column dropped
            (presumably a batch index), divided by the image scale so the
            coordinates refer to the unscaled input image
        scores (ndarray): a copy of the 'scores' output
    """
    data, im_info = _get_image_blob(im)

    # Resize the network's input blobs to fit this particular image.
    net.blobs['data'].reshape(*data.shape)
    net.blobs['im_info'].reshape(*im_info.shape)

    outputs = net.forward(
        data=data.astype(np.float32, copy=False),
        im_info=im_info.astype(np.float32, copy=False),
    )

    # Third entry of im_info is the scale factor applied to the image.
    im_scale = im_info[0, 2]
    rois = outputs['rois']
    boxes = rois[:, 1:].copy() / im_scale
    scores = outputs['scores'].copy()
    return boxes, scores
def imdb_proposals(net, imdb, vis=False):
    """Generate RPN proposals on all images in an imdb.

    Arguments:
        net: network object forwarded unchanged to `im_proposals`
        imdb: dataset object providing `num_images` and `image_path_at(i)`
        vis (bool): when true, plot up to the first three proposals of each
            image that score at least 0.9 and call `plt.show()`. Off by
            default; replaces the former hard-coded `if 0:` debug switch.

    Returns:
        imdb_boxes (list): one entry per image, holding the boxes returned
            by `im_proposals` for that image (scores are not included)

    Raises:
        IOError: if an image cannot be loaded by `cv2.imread`
    """
    _t = Timer()
    imdb_boxes = [[] for _ in range(imdb.num_images)]
    for i in range(imdb.num_images):
        im_path = imdb.image_path_at(i)
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread reports failure by returning None rather than
            # raising; fail here with the offending path instead of later
            # with an unrelated-looking attribute error.
            raise IOError('Failed to read image: {}'.format(im_path))

        # Only the network pass is timed, not the image loading.
        _t.tic()
        imdb_boxes[i], scores = im_proposals(net, im)
        _t.toc()
        print('im_proposals: {:d}/{:d} {:.3f}s'
              .format(i + 1, imdb.num_images, _t.average_time))

        if vis:
            dets = np.hstack((imdb_boxes[i], scores))
            _vis_proposals(im, dets[:3, :], thresh=0.9)
            plt.show()
    return imdb_boxes
def imdb_proposals_det(net, imdb, vis=False):
    """Generate RPN proposals with scores on all images in an imdb.

    Unlike `imdb_proposals`, each returned entry keeps the scores: the
    boxes and scores from `im_proposals` are stacked column-wise.

    Arguments:
        net: network object forwarded unchanged to `im_proposals`
        imdb: dataset object providing `num_images` and `image_path_at(i)`
        vis (bool): when true, plot up to the first three proposals of each
            image that score at least 0.9 and call `plt.show()`. Off by
            default; replaces the former hard-coded `if 0:` debug switch.

    Returns:
        imdb_boxes (list): one ndarray per image whose rows are the box
            columns followed by the score column(s)

    Raises:
        IOError: if an image cannot be loaded by `cv2.imread`
    """
    _t = Timer()
    imdb_boxes = [[] for _ in range(imdb.num_images)]
    for i in range(imdb.num_images):
        im_path = imdb.image_path_at(i)
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread reports failure by returning None rather than
            # raising; fail here with the offending path instead of later
            # with an unrelated-looking attribute error.
            raise IOError('Failed to read image: {}'.format(im_path))

        # Only the network pass is timed, not the image loading.
        _t.tic()
        boxes, scores = im_proposals(net, im)
        _t.toc()
        print('im_proposals: {:d}/{:d} {:.3f}s'
              .format(i + 1, imdb.num_images, _t.average_time))

        dets = np.hstack((boxes, scores))
        imdb_boxes[i] = dets

        if vis:
            _vis_proposals(im, dets[:3, :], thresh=0.9)
            plt.show()
    return imdb_boxes