forked from skhadem/3D-BoundingBox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Run.py
113 lines (87 loc) · 3.63 KB
/
Run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
This script will use the 2D box from the label rather than from YOLO,
but will still use the neural nets to get the 3D position and plot onto the
image. Press space for next image and escape to quit
"""
from torch_lib.Dataset import *
from library.Math import *
from library.Plotting import *
from torch_lib import Model, ClassAverages
import os
import cv2
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.models import vgg
import numpy as np
# Display mode flag: True pauses the viewer after every detected object
# ("car by car"); False shows one combined window per image.
single_car: bool = False
def plot_regressed_3d_bbox(img, truth_img, cam_to_img, box_2d, dimensions, alpha, theta_ray):
    """Solve for an object's 3D location and draw its boxes.

    The 2D box is drawn onto ``truth_img`` and the regressed 3D box onto
    ``img``.

    Args:
        img: image the 3D box is drawn onto.
        truth_img: image the 2D (label) box is drawn onto.
        cam_to_img: camera calibration passed to ``calc_location`` and
            ``plot_3d_box``.
        box_2d: the 2D bounding box used as the location constraint.
        dimensions: regressed object dimensions.
        alpha: regressed observation angle.
        theta_ray: angle of the ray towards the object.

    Returns:
        The location computed by ``calc_location``.
    """
    # The orientation handed to the 3D plotter is alpha + theta_ray
    # (presumably the global yaw of the object).
    yaw = alpha + theta_ray
    # calc_location also returns the corners used for the constraint; unused here.
    location, _constraint_corners = calc_location(dimensions, cam_to_img, box_2d, alpha, theta_ray)

    plot_2d_box(truth_img, box_2d)
    plot_3d_box(img, cam_to_img, yaw, dimensions, location)  # 3d boxes
    return location
# Key code returned by cv2.waitKey (masked to 8 bits) for the Escape key.
_ESC_KEY = 27


def _latest_checkpoint(weights_path):
    """Return the last ``.pkl`` file name in *weights_path* (sorted by name), or None.

    A missing directory is treated the same as an empty one.
    """
    if not os.path.isdir(weights_path):
        return None
    candidates = sorted(name for name in os.listdir(weights_path) if name.endswith('.pkl'))
    return candidates[-1] if candidates else None


def _regress_alpha(orient, conf, angle_bins):
    """Decode the observation angle from the network's per-bin outputs.

    Picks the bin with the highest confidence and converts that bin's
    (cos, sin) pair to an angle. It then adds ``angle_bins[bin]`` and
    subtracts pi. The -pi shift presumably undoes an offset applied when
    building the training targets; confirm against Dataset.
    """
    best_bin = np.argmax(conf)
    cos, sin = orient[best_bin, 0], orient[best_bin, 1]
    alpha = np.arctan2(sin, cos)
    alpha += angle_bins[best_bin]
    alpha -= np.pi
    return alpha


def _show_stacked(truth_img, img):
    """Show *truth_img* stacked above *img* and block until a key is pressed.

    Returns the pressed key code masked to its low byte. The mask matters
    because waitKey can report extra high bits on some platforms.
    """
    stacked = np.concatenate((truth_img, img), axis=0)
    cv2.imshow('2D detection on top, 3D prediction on bottom', stacked)
    return cv2.waitKey(0) & 0xFF


def main():
    """Regress 3D boxes for every labelled object in ./eval and display them.

    Loads the last ``.pkl`` checkpoint (by sorted file name) from ./weights.
    For each image it draws the label 2D boxes on one copy and the regressed
    3D boxes on another, and prints estimated vs. truth locations.

    Pressing Esc in the viewer quits early, in both display modes.

    Raises:
        SystemExit: with status 1 when no checkpoint is found.
    """
    base_dir = os.path.abspath(os.path.dirname(__file__))
    weights_path = os.path.join(base_dir, 'weights')

    checkpoint_name = _latest_checkpoint(weights_path)
    if checkpoint_name is None:
        print('No previous model found, please train first!')
        # Non-zero status so wrapper scripts can detect the failure.
        raise SystemExit(1)
    print('Using previous model %s' % checkpoint_name)

    # NOTE(review): the strict load_state_dict below presumably overwrites
    # these ImageNet weights entirely; pretrained=False could avoid the download.
    my_vgg = vgg.vgg19_bn(pretrained=True)
    model = Model.Model(features=my_vgg.features, bins=2).cuda()
    checkpoint = torch.load(os.path.join(weights_path, checkpoint_name))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # defaults to /eval
    dataset = Dataset(os.path.join(base_dir, 'eval'))
    averages = ClassAverages.ClassAverages()

    all_images = dataset.all_objects()
    for key in sorted(all_images):
        start_time = time.perf_counter()

        data = all_images[key]
        truth_img = data['Image']
        img = np.copy(truth_img)
        objects = data['Objects']
        cam_to_img = data['Calib']

        for detected_object in objects:
            label = detected_object.label

            # Batch of one, already on the GPU where the model lives.
            input_tensor = torch.zeros([1, 3, 224, 224]).cuda()
            input_tensor[0, :, :, :] = detected_object.img

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                orient, conf, dim = model(input_tensor)
            orient = orient.cpu().numpy()[0, :, :]
            conf = conf.cpu().numpy()[0, :]
            dim = dim.cpu().numpy()[0, :]

            # The network regresses an offset from the per-class average dimensions.
            dim += averages.get_item(label['Class'])
            alpha = _regress_alpha(orient, conf, dataset.angle_bins)

            location = plot_regressed_3d_bbox(
                img, truth_img, cam_to_img, label['Box_2D'], dim, alpha, detected_object.theta_ray
            )

            print('Estimated pose: %s' % location)
            print('Truth pose: %s' % label['Location'])
            print('-------------')

            # plot car by car
            if single_car and _show_stacked(truth_img, img) == _ESC_KEY:
                cv2.destroyAllWindows()
                return

        print('Got %s poses in %.3f seconds\n' % (len(objects), time.perf_counter() - start_time))

        # plot image by image
        if not single_car and _show_stacked(truth_img, img) == _ESC_KEY:
            cv2.destroyAllWindows()
            return


if __name__ == '__main__':
    main()