person_search_reid.py (forked from zengwb-lx/Yolov5-Deepsort-Fastreid)
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2021/1/18
# @Author : zengwb
import os
import cv2
import numpy as np
import torch
import warnings
import argparse

from utils.datasets import LoadStreams, LoadImages
from utils.draw import draw_boxes, draw_person
from utils.general import check_img_size
from person_detect_yolov5 import Person_detect
from deep_sort import build_tracker
from utils.parser import get_config
from utils.log import get_logger
from utils.torch_utils import select_device, load_classifier, time_synchronized
from sklearn.metrics.pairwise import cosine_similarity


class yolo_reid():
    def __init__(self, cfg, args, path):
        self.args = args
        self.video_path = path
        use_cuda = args.use_cuda and torch.cuda.is_available()
        if not use_cuda:
            warnings.warn("Running in CPU mode, which may be very slow!", UserWarning)

        # Person_detect: YOLOv5-based pedestrian detector
        self.person_detect = Person_detect(self.args, self.video_path)
        # DeepSORT tracker
        self.deepsort = build_tracker(cfg, args.sort, use_cuda=use_cuda)
        imgsz = check_img_size(args.img_size, s=32)  # verify img_size is a multiple of the model stride
        self.dataset = LoadImages(self.video_path, img_size=imgsz)
        # Pre-computed query gallery: (N, 512) ReID features and the matching person names
        self.query_feat = np.load(args.query)
        self.names = np.load(args.names)

    def deep_sort(self):
        idx_frame = 0
        results = []
        for video_path, img, ori_img, vid_cap in self.dataset:
            idx_frame += 1
            t1 = time_synchronized()

            # YOLOv5 person detection
            bbox_xywh, cls_conf, cls_ids, xy = self.person_detect.detect(video_path, img, ori_img, vid_cap)

            # DeepSORT tracking; features are the 512-dim embeddings produced by the ReID model
            outputs, features = self.deepsort.update(bbox_xywh, cls_conf, ori_img)
            print(len(outputs), len(bbox_xywh), features.shape)

            # Match every tracked person's feature against the query gallery
            person_cossim = cosine_similarity(features, self.query_feat)
            max_idx = np.argmax(person_cossim, axis=1)
            maximum = np.max(person_cossim, axis=1)
            max_idx[maximum < 0.6] = -1  # below the similarity threshold: treat as unknown
            score = maximum
            reid_results = max_idx
            draw_person(ori_img, xy, reid_results, self.names)  # draw matched person names

            if len(outputs) > 0:
                bbox_tlwh = []
                bbox_xyxy = outputs[:, :4]
                identities = outputs[:, -1]
                ori_im = draw_boxes(ori_img, bbox_xyxy, identities)

                for bb_xyxy in bbox_xyxy:
                    bbox_tlwh.append(self.deepsort._xyxy_to_tlwh(bb_xyxy))
                # results.append((idx_frame - 1, bbox_tlwh, identities))

            # print("yolo+deepsort:", time_synchronized() - t1)
            if self.args.display:
                cv2.imshow("test", ori_img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", default='./MOT16-03.mp4', type=str)
    parser.add_argument("--camera", action="store", dest="cam", type=int, default=-1)
    parser.add_argument('--device', default='cuda:0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    # yolov5
    parser.add_argument('--weights', nargs='+', type=str, default='./weights/yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--img-size', type=int, default=1080, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--classes', nargs='+', default=[0], type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')

    # deep_sort
    parser.add_argument("--sort", default=False, help='True: SORT-only tracking, False: tracking with the ReID model')
    parser.add_argument("--config_deepsort", type=str, default="./configs/deep_sort.yaml")
    parser.add_argument("--display", default=True, help='show results in a window')
    parser.add_argument("--frame_interval", type=int, default=1)
    parser.add_argument("--cpu", dest="use_cuda", action="store_false", default=True)

    # reid
    parser.add_argument("--query", type=str, default="./fast_reid/query/query_features.npy")
    parser.add_argument("--names", type=str, default="./fast_reid/query/names.npy")
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    cfg = get_config()
    cfg.merge_from_file(args.config_deepsort)

    yolo_reid = yolo_reid(cfg, args, path=args.video_path)
    with torch.no_grad():
        yolo_reid.deep_sort()
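
The script expects --query and --names to point at a pre-computed gallery: an (N, 512) feature matrix and an N-element array of person names in the same row order, since the per-detection argmax over the cosine-similarity matrix is passed to draw_person together with self.names. Below is a minimal sketch of how such files could be produced; extract_feature is a hypothetical stand-in for the repository's fast-reid feature extractor, not its actual API.

# Sketch only: extract_feature is hypothetical; in this repo the embeddings
# come from the fast-reid model used by the tracker.
import numpy as np

def build_query_gallery(person_crops, person_names, extract_feature):
    # person_crops: list of HxWx3 person image crops, one per known identity
    # person_names: list of strings in the same order as person_crops
    # extract_feature: callable returning a (512,) embedding for a single crop
    feats = np.stack([extract_feature(crop) for crop in person_crops])  # (N, 512)
    names = np.array(person_names)                                      # (N,)
    np.save('./fast_reid/query/query_features.npy', feats)
    np.save('./fast_reid/query/names.npy', names)
    return feats, names

With the gallery files in place, the script can be run with its defaults, e.g. python person_search_reid.py --video_path ./MOT16-03.mp4.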