forked from Qidian213/deep_sort_yolov3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
131 lines (99 loc) · 3.87 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
from timeit import time
import warnings
import cv2
import numpy as np
from PIL import Image
from yolo import YOLO
from deep_sort import preprocessing
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet
import imutils.video
from videocaptureasync import VideoCaptureAsync
warnings.filterwarnings('ignore')
def main(yolo):
    """Run YOLOv3 detection + Deep SORT multi-object tracking on a video file.

    Reads frames from ``video.webm``, detects objects with *yolo*, associates
    detections across frames with Deep SORT, draws tracks/detections on each
    frame, optionally writes an annotated output video, and reports FPS.

    Args:
        yolo: Detector instance whose ``detect_image(PIL.Image)`` returns a
              2-tuple ``(boxes, confidences)`` — boxes in tlwh format.
              (Inferred from the two indexed calls in the original; confirm
              against the project's YOLO class.)
    """
    # Deep SORT association parameters.
    max_cosine_distance = 0.3   # appearance-metric gating threshold
    nn_budget = None            # unlimited feature gallery per track
    nms_max_overlap = 1.0       # NMS IoU threshold (1.0 = keep everything)

    # Deep SORT appearance encoder + tracker.
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    metric = nn_matching.NearestNeighborDistanceMetric(
        "cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(metric)

    writeVideo_flag = True
    asyncVideo_flag = False

    file_path = 'video.webm'
    if asyncVideo_flag:
        video_capture = VideoCaptureAsync(file_path)
    else:
        video_capture = cv2.VideoCapture(file_path)

    if asyncVideo_flag:
        video_capture.start()

    if writeVideo_flag:
        # Property ids 3/4 are CAP_PROP_FRAME_WIDTH / CAP_PROP_FRAME_HEIGHT.
        if asyncVideo_flag:
            w = int(video_capture.cap.get(3))
            h = int(video_capture.cap.get(4))
        else:
            w = int(video_capture.get(3))
            h = int(video_capture.get(4))
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter('output_yolov3.avi', fourcc, 30, (w, h))
        frame_index = -1

    fps = 0.0
    fps_imutils = imutils.video.FPS().start()

    while True:
        ret, frame = video_capture.read()  # frame shape 640*480*3
        if not ret:
            break

        t1 = time.time()

        image = Image.fromarray(frame[..., ::-1])  # bgr to rgb
        # BUG FIX: the original called yolo.detect_image(image) twice per
        # frame (once indexed [0] for boxes, once [1] for confidences),
        # running the detector twice and roughly halving throughput.
        # One call, unpack both results.
        boxs, confidence = yolo.detect_image(image)
        features = encoder(frame, boxs)
        detections = [Detection(bbox, conf, feature)
                      for bbox, conf, feature in zip(boxs, confidence, features)]

        # Run non-maxima suppression.
        boxes = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Call the tracker.
        tracker.predict()
        tracker.update(detections)

        # Draw confirmed tracks: white boxes, green track ids.
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])),
                          (int(bbox[2]), int(bbox[3])), (255, 255, 255), 2)
            cv2.putText(frame, str(track.track_id),
                        (int(bbox[0]), int(bbox[1])), 0, 5e-3 * 200,
                        (0, 255, 0), 2)

        # Draw raw detections: blue boxes with confidence percentage.
        for det in detections:
            bbox = det.to_tlbr()
            score = "%.2f" % round(det.confidence * 100, 2)
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])),
                          (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2)
            cv2.putText(frame, score + '%', (int(bbox[0]), int(bbox[3])),
                        0, 5e-3 * 130, (0, 255, 0), 2)

        cv2.imshow('', frame)

        if writeVideo_flag:  # and not asyncVideo_flag:
            # save a frame
            out.write(frame)
            frame_index = frame_index + 1

        fps_imutils.update()

        # Exponential-ish running average of instantaneous FPS.
        fps = (fps + (1. / (time.time() - t1))) / 2
        print("FPS = %f" % (fps))

        # Press Q to stop!
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    fps_imutils.stop()
    print('imutils FPS: {}'.format(fps_imutils.fps()))

    if asyncVideo_flag:
        video_capture.stop()
    else:
        video_capture.release()

    if writeVideo_flag:
        out.release()

    cv2.destroyAllWindows()
# Script entry point: build the YOLO detector (loads weights) and run the demo.
if __name__ == '__main__':
    main(YOLO())