#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/1/18
# @Author : zengwb
# Forked from: zengwb-lx/Yolov5-Deepsort-Fastreid
import cv2
import torch
import warnings
import argparse
import numpy as np

from utils.datasets import LoadImages
from utils.draw import draw_boxes
from utils.general import check_img_size
from utils.torch_utils import time_synchronized
from person_detect_yolov5 import Person_detect
from deep_sort import build_tracker
from utils.parser import get_config
from utils.log import get_logger

# counting helpers
import math
from collections import Counter, deque
from PIL import Image, ImageDraw, ImageFont

def tlbr_midpoint(box):
    """Return the center of a (minX, minY, maxX, maxY) box as integer pixel coordinates."""
    minX, minY, maxX, maxY = box
    return (int((minX + maxX) / 2), int((minY + maxY) / 2))

def intersect(A, B, C, D):
    """True if segment AB crosses segment CD."""
    return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)


def ccw(A, B, C):
    """True if the points A, B, C are in counter-clockwise order."""
    return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
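
# A quick sanity check of the crossing test (hypothetical values, not from the
# original script): a track midpoint moving straight down through a horizontal line.
#   intersect((50, 10), (50, 30), (0, 20), (100, 20))  -> True  (path crosses the line)
#   intersect((50, 10), (50, 15), (0, 20), (100, 20))  -> False (still above the line)
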
def vector_angle(midpoint, previous_midpoint):
    """Angle in degrees of the motion vector from previous_midpoint to midpoint."""
    x = midpoint[0] - previous_midpoint[0]
    y = midpoint[1] - previous_midpoint[1]
    return math.degrees(math.atan2(y, x))
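
# Sign convention: the caller passes midpoints flipped into a bottom-left origin
# (y grows upward), so a positive angle means the person moved up in the frame and
# a negative angle means they moved down. math.atan2 returns values in (-180, 180]
# degrees, so the up/down split at 0 is well defined.
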
def get_size_with_pil(label, size=25):
    """Measure the (width, height) of rendered text using a CJK-capable TrueType font."""
    font = ImageFont.truetype("./configs/simkai.ttf", size, encoding="utf-8")  # or simhei.ttf
    return font.getsize(label)  # note: getsize() was removed in Pillow 10; use getbbox() there

# OpenCV's putText cannot render CJK glyphs, so draw labels with PIL instead.
def put_text_to_cv2_img_with_pil(cv2_img, label, pt, color):
    pil_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)  # OpenCV stores BGR, PIL expects RGB
    pilimg = Image.fromarray(pil_img)                   # wrap the array as a PIL image
    draw = ImageDraw.Draw(pilimg)                       # draw the text onto the PIL image
    font = ImageFont.truetype("./configs/simkai.ttf", 25, encoding="utf-8")  # or simhei.ttf
    draw.text(pt, label, color, font=font)
    return cv2.cvtColor(np.array(pilimg), cv2.COLOR_RGB2BGR)  # back to a BGR array for cv2.imshow()
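
# Hypothetical usage: overlay a label at (x, y) on a BGR frame.
#   frame = put_text_to_cv2_img_with_pil(frame, "Total: 3", (20, 20), (255, 0, 0))
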
# Anchor colors for the label palette; get_color() interpolates between neighbours.
colors = np.array([
    [1, 0, 1],
    [0, 0, 1],
    [0, 1, 1],
    [0, 1, 0],
    [1, 1, 0],
    [1, 0, 0],
])


def get_color(c, x, max_num):
    """Interpolate channel c of the palette at position x in [0, max_num]."""
    ratio = (x / max_num) * 5
    i = math.floor(ratio)
    j = math.ceil(ratio)
    ratio -= i
    return (1 - ratio) * colors[i][c] + ratio * colors[j][c]


def compute_color_for_labels(class_id, class_total=80):
    """Map a class id to a deterministic color tuple."""
    offset = class_id * 123457 % class_total
    red = get_color(2, offset, class_total)
    green = get_color(1, offset, class_total)
    blue = get_color(0, offset, class_total)
    return (int(red * 256), int(green * 256), int(blue * 256))
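
# Because the mapping is deterministic, compute_color_for_labels(2) (used for the
# overlay boxes below) yields the same color on every frame.
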
class YoloReid:
    def __init__(self, cfg, args, path):
        self.logger = get_logger("root")
        self.args = args
        self.video_path = path
        use_cuda = args.use_cuda and torch.cuda.is_available()
        if not use_cuda:
            warnings.warn("Running in CPU mode, which may be very slow!", UserWarning)

        self.person_detect = Person_detect(self.args, self.video_path)
        imgsz = check_img_size(args.img_size, s=32)  # verify img_size is a multiple of the model stride
        self.dataset = LoadImages(self.video_path, img_size=imgsz)
        self.deepsort = build_tracker(cfg, args.sort, use_cuda=use_cuda)
    def deep_sort(self):
        idx_frame = 0
        results = []
        paths = {}                          # recent midpoints per track ID (deque of length 2)
        track_cls = 0                       # only class 0 (person) is tracked
        last_track_id = -1
        total_track = 0
        angle = -1
        total_counter = 0
        up_count = 0
        down_count = 0
        class_counter = Counter()           # line crossings per class
        already_counted = deque(maxlen=50)  # short-term memory of IDs that were already counted
        for video_path, img, ori_img, vid_cap in self.dataset:
            idx_frame += 1
            t1 = time_synchronized()

            # YOLOv5 detection
            bbox_xywh, cls_conf, cls_ids, xy = self.person_detect.detect(video_path, img, ori_img, vid_cap)

            # DeepSORT tracking
            outputs = self.deepsort.update(bbox_xywh, cls_conf, ori_img)

            # 1. Draw a horizontal yellow counting line at 48% of the frame height
            line = [(0, int(0.48 * ori_img.shape[0])), (int(ori_img.shape[1]), int(0.48 * ori_img.shape[0]))]
            cv2.line(ori_img, line[0], line[1], (0, 255, 255), 4)
            # 2. Count people crossing the line
            for track in outputs:
                bbox = track[:4]
                track_id = track[-1]
                midpoint = tlbr_midpoint(bbox)
                # flip y so the midpoint is relative to a bottom-left origin
                origin_midpoint = (midpoint[0], ori_img.shape[0] - midpoint[1])

                if track_id not in paths:
                    paths[track_id] = deque(maxlen=2)
                    total_track = track_id  # newest track ID doubles as the running total of people seen

                paths[track_id].append(midpoint)
                previous_midpoint = paths[track_id][0]
                origin_previous_midpoint = (previous_midpoint[0], ori_img.shape[0] - previous_midpoint[1])

                if intersect(midpoint, previous_midpoint, line[0], line[1]) and track_id not in already_counted:
                    class_counter[track_cls] += 1
                    total_counter += 1
                    last_track_id = track_id
                    # flash the counting line red for this frame
                    cv2.line(ori_img, line[0], line[1], (0, 0, 255), 10)
                    already_counted.append(track_id)  # remember this ID so it is not counted twice

                    angle = vector_angle(origin_midpoint, origin_previous_midpoint)
                    if angle > 0:
                        up_count += 1
                    if angle < 0:
                        down_count += 1

                if len(paths) > 50:
                    del paths[list(paths)[0]]  # evict the oldest track's path
            # 3. Draw tracked boxes
            if len(outputs) > 0:
                bbox_tlwh = []
                bbox_xyxy = outputs[:, :4]
                identities = outputs[:, -1]
                ori_img = draw_boxes(ori_img, bbox_xyxy, identities)

                for bb_xyxy in bbox_xyxy:
                    bbox_tlwh.append(self.deepsort._xyxy_to_tlwh(bb_xyxy))

                # results.append((idx_frame - 1, bbox_tlwh, identities))

            print("yolo+deepsort:", time_synchronized() - t1)
            # 4. Draw the statistics overlay
            label = "Total count: {}".format(total_track)
            t_size = get_size_with_pil(label, 25)
            x1 = 20
            y1 = 50
            color = compute_color_for_labels(2)
            cv2.rectangle(ori_img, (x1 - 1, y1), (x1 + t_size[0] + 10, y1 - t_size[1]), color, 2)
            ori_img = put_text_to_cv2_img_with_pil(ori_img, label, (x1 + 5, y1 - t_size[1] - 2), (0, 0, 0))

            label = "Line crossings: {} ({} up, {} down)".format(total_counter, up_count, down_count)
            t_size = get_size_with_pil(label, 25)
            x1 = 20
            y1 = 100
            color = compute_color_for_labels(2)
            cv2.rectangle(ori_img, (x1 - 1, y1), (x1 + t_size[0] + 10, y1 - t_size[1]), color, 2)
            ori_img = put_text_to_cv2_img_with_pil(ori_img, label, (x1 + 5, y1 - t_size[1] - 2), (0, 0, 0))

            if last_track_id >= 0:
                label = "Latest: person {} crossed the line going {}".format(
                    last_track_id, "up" if angle >= 0 else "down")
                t_size = get_size_with_pil(label, 25)
                x1 = 20
                y1 = 150
                color = compute_color_for_labels(2)
                cv2.rectangle(ori_img, (x1 - 1, y1), (x1 + t_size[0] + 10, y1 - t_size[1]), color, 2)
                ori_img = put_text_to_cv2_img_with_pil(ori_img, label, (x1 + 5, y1 - t_size[1] - 2), (255, 0, 0))
            end = time_synchronized()

            if self.args.display:
                cv2.imshow("test", ori_img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            self.logger.info("{}/time: {:.03f}s, fps: {:.03f}, detection numbers: {}, tracking numbers: {}"
                             .format(idx_frame, end - t1, 1 / (end - t1),
                                     bbox_xywh.shape[0], len(outputs)))

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", default='./MOT16-03.mp4', type=str)
    parser.add_argument("--camera", action="store", dest="cam", type=int, default=-1)
    parser.add_argument('--device', default='cuda:0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    # yolov5
    parser.add_argument('--weights', nargs='+', type=str, default='./weights/yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--img-size', type=int, default=960, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--classes', nargs='+', type=int, default=[0], help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')

    # deep_sort
    parser.add_argument("--sort", default=True, help='True: sort model, False: reid model')
    parser.add_argument("--config_deepsort", type=str, default="./configs/deep_sort.yaml")
    parser.add_argument("--display", default=True, help='show result')
    parser.add_argument("--frame_interval", type=int, default=1)
    parser.add_argument("--cpu", dest="use_cuda", action="store_false", default=True)

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    cfg = get_config()
    cfg.merge_from_file(args.config_deepsort)

    yolo_reid = YoloReid(cfg, args, path=args.video_path)
    with torch.no_grad():
        yolo_reid.deep_sort()
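
# Example invocation (paths are the repo defaults and may need adjusting for your checkout):
#   python person_count.py --video_path ./MOT16-03.mp4 --weights ./weights/yolov5s.pt --img-size 960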