Skip to content

Commit 51c9988

Browse files
committed
train and test on HEVI is ready
1 parent b6817e0 commit 51c9988

9 files changed

+218
-17
lines changed

README.md

+26-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,30 @@
1-
# Spatio-Temporal Anomaly Detection in First-Person Videos
1+
# Unsupervised Traffic Accident Detection in First-Person Videos
22

3-
Created by Yu Yao and Mingze Xu for IROS 2019
3+
*Yu Yao, Mingze Xu, Yuchen Wang, David Crandall and Ella Atkins*
4+
5+
This repo contains the code for our [paper](https://arxiv.org/pdf/1903.00618.pdf) on unsupervised traffic accident detection.
6+
7+
:boom: The full code will be released upon the acceptance of our paper.
8+
9+
:boom: So far we have released the pytorch implementation of our ICRA paper [*Egocentric Vision-based Future Vehicle Localization for Intelligent Driving Assistance Systems*](https://arxiv.org/pdf/1809.07408.pdf), which is an important building block for the traffic accident detection. The original project repo is https://github.com/MoonBlvd/fvl-ICRA2019
10+
11+
## Future Object Localization
12+
To train the model, run:
13+
14+
python train_fol.py --load_config YOUR_CONFIG_FILE
15+
16+
To test the model, run:
17+
18+
python test_fol.py --load_config YOUR_CONFIG_FILE
19+
20+
An example of the config file can be found in ```config/fol_ego_train.yaml```
21+
22+
#### evaluation result
23+
Note that we have only evaluated the model performance with prediction horizon 0.5 seconds. We are working on proving the 1 second and 2 seconds results.
24+
25+
| Model | pred horizon | FDE | ADE | FIOU |
26+
|:--------------:|--------------|------|-----|------|
27+
| FOL + Ego pred | 0.5 sec | 10.9 | 6.6 | 0.95 |
428

529
## Run detection
630
Go to Mask-RCNN root directory run:

config/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import yaml
22
import argparse
33

4+
# class Config()
5+
6+
47
def visualize_config(args):
58
"""
69
Visualize the configuration on the terminal to check the state

config/fol_ego_test.yaml

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Directories arguments
2+
data_root: "/media/DATA/HEVI_dataset/fol_data"
3+
ego_data_root: "/media/DATA/HEVI_dataset/ego_motion"
4+
checkpoint_dir: "checkpoints/fol_ego_checkpoints"
5+
6+
best_fol_model: 'checkpoints/fol_ego_checkpoints/fol_epoch_078_loss_0.0013.pt'
7+
best_ego_pred_model: 'checkpoints/fol_ego_checkpoints/ego_pred_epoch_078_loss_0.0016.pt'
8+
9+
test_dataset: "taiwan_sa" #"A3D" #"taiwan_sa"
10+
test_root: #"../data/taiwan_sa/testing" #"/media/DATA/A3D" #"/media/DATA/VAD_datasets/taiwan_sa/testing" AnAnAccident_Detection_Dataset
11+
label_file: '../data/A3D/A3D_labels.pkl'
12+
13+
# dataset arguments
14+
seed_max: 5
15+
segment_len: 16
16+
17+
# training parameters
18+
nb_fol_epoch: 100
19+
nb_ego_pred_epoch: 200
20+
lr: 0.0001
21+
22+
lambda_fol: 1
23+
lambda_ego: 1
24+
device: 'cuda'
25+
26+
# fol model parameters
27+
pred_timesteps: 5
28+
input_embed_size: 512
29+
flow_enc_size: 512
30+
box_enc_size: 512
31+
with_ego: True
32+
33+
enc_hidden_size: 512 # no use
34+
enc_concat_type: "average"
35+
predictor_input_size: 512
36+
dec_hidden_size: 512
37+
pred_dim: 4
38+
39+
40+
# ego_pred model parameters
41+
ego_embed_size: 128
42+
ego_enc_size: 128
43+
ego_dec_size: 128
44+
ego_pred_input_size: 128
45+
ego_dim: 3
46+
47+
# dataloader parameters
48+
batch_size: 1
49+
shuffle: False
50+
num_workers: 0
51+
52+
# image parameters
53+
H: 720
54+
W: 1280
55+
channels: 3
56+
57+
flow_roi_size: [5,5,2]
58+
59+
# Anomaly detection parameters
60+
max_age: 10

config/fol_config.yaml config/fol_ego_train.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Directories arguments
22
data_root: "/media/DATA/HEVI_dataset/fol_data"
33
ego_data_root: "/media/DATA/HEVI_dataset/ego_motion"
4-
checkpoint_dir: "/home/brianyao/Documents/stad2019iros-pytorch/checkpoints/fol_ego_checkpoints"
4+
checkpoint_dir: "checkpoints/fol_ego_checkpoints"
55

6-
best_ego_pred_model: "/home/brianyao/Documents/stad2019iros-pytorch/checkpoints/ego_pred_checkpoints/epoch_080_loss_0.001.pt"
6+
best_ego_pred_model: "checkpoints/ego_pred_checkpoints/epoch_080_loss_0.001.pt"
77
test_dataset: "taiwan_sa" #"A3D" #"taiwan_sa"
88
test_root: #"../data/taiwan_sa/testing" #"/media/DATA/A3D" #"/media/DATA/VAD_datasets/taiwan_sa/testing" AnAnAccident_Detection_Dataset
99
label_file: '../data/A3D/A3D_labels.pkl'

lib/utils/data_prep_utils.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@
66
import copy
77
import pickle as pkl
88

9+
def cxcywh_to_x1y1x2y2(boxes):
10+
'''
11+
Params:
12+
boxes:(Cx, Cy, w, h)
13+
Returns:
14+
(x1, y1, x2, y2 or tlbr
15+
'''
16+
new_boxes = np.zeros_like(boxes)
17+
new_boxes[...,0] = boxes[...,0] - boxes[...,2]/2
18+
new_boxes[...,1] = boxes[...,1] - boxes[...,3]/2
19+
new_boxes[...,2] = boxes[...,0] + boxes[...,2]/2
20+
new_boxes[...,3] = boxes[...,1] + boxes[...,3]/2
21+
return new_boxes
922

1023
def bbox_normalize(bbox,W=1280,H=640):
1124
'''
@@ -32,10 +45,10 @@ def bbox_denormalize(bbox,W=1280,H=640):
3245
bbox: [cx, cy, w, h] with size (times, 4), value from 0 to W or H
3346
'''
3447
new_bbox = copy.deepcopy(bbox)
35-
new_bbox[:,0] *= W
36-
new_bbox[:,1] *= H
37-
new_bbox[:,2] *= W
38-
new_bbox[:,3] *= H
48+
new_bbox[..., 0] *= W
49+
new_bbox[..., 1] *= H
50+
new_bbox[..., 2] *= W
51+
new_bbox[..., 3] *= H
3952

4053
return new_bbox
4154

lib/utils/fol_dataloader.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
class HEVIDataset(data.Dataset):
1212
def __init__(self, args, phase):
1313
'''
14+
HEV-I dataset object. Contains bbox, flow and ego motion.
15+
1416
Params:
1517
args: arguments passed from main file
1618
phase: 'train' or 'val'

lib/utils/train_val_utils.py

-4
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ def train_fol(epoch, model, optimizer, train_gen, verbose=True):
3232
optimizer.step()
3333

3434
#write summery for tensorboardX
35-
# writer.add_scalar('data/train_loss', object_pred_loss, n_iters)
36-
# n_iters += 1
3735
if verbose and batch_idx % 100 == 0:
3836
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
3937
epoch, batch_idx * len(data), len(train_gen.dataset),
@@ -228,8 +226,6 @@ def val_fol_ego(epoch, args, fol_model, ego_pred_model, val_gen, verbose=True):
228226
fol_loss += rmse_loss_fol(fol_predictions, target_bbox).item()
229227
ego_pred_loss += rmse_loss_fol(ego_predictions, target_ego_motion).item()
230228

231-
# total_val_loss = fol_loss + ego_pred_loss
232-
# avg_val_loss = total_val_loss/len(val_gen.dataset)
233229
fol_loss /= len(val_gen.dataset)
234230
ego_pred_loss /= len(val_gen.dataset)
235231
avg_val_loss = fol_loss + ego_pred_loss

test_fol.py

+106-3
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,119 @@
22
import os
33
import numpy as np
44
import time
5+
from tqdm import tqdm
56

67
import torch
78
from torch import nn, optim
8-
from torch.nn import functional as F
99
from torch.utils import data
1010
from torchsummaryX import summary
1111

12-
from lib.utils.train_val_utils import train_fol_ego, val_fol_ego
12+
from lib.utils.train_val_utils import val_fol_ego
1313
from lib.models.rnn_ed import FolRNNED, EgoRNNED
1414
from lib.utils.fol_dataloader import HEVIDataset
1515
from config.config import *
1616
from lib.ego_motion_tracker import EgoTracker
17-
from lib.object_tracker import ObjTracker, AllTrackers
17+
from lib.object_tracker import ObjTracker, AllTrackers
18+
from lib.utils.data_prep_utils import bbox_denormalize, cxcywh_to_x1y1x2y2
19+
from lib.utils.eval_utils import compute_IOU
20+
21+
print("Cuda available: ", torch.cuda.is_available())
22+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23+
24+
25+
def test_fol_ego(fol_model, ego_pred_model, test_gen):
26+
'''
27+
Validate future vehicle localization module
28+
Params:
29+
fol_model: The fol model as nn.Module
30+
ego_pred_model: the ego motion prediction model as nn.Module
31+
test_gen: test data generator
32+
Returns:
33+
34+
'''
35+
fol_model.eval() # Sets the module in training mode.
36+
ego_pred_model.eval()
37+
38+
fol_loss = 0
39+
ego_pred_loss = 0
40+
loader = tqdm(test_gen, total=len(test_gen))
41+
42+
FDE = 0
43+
ADE = 0
44+
FIOU = 0
45+
with torch.set_grad_enabled(False):
46+
for batch_idx, data in enumerate(loader):
47+
input_bbox, input_flow, input_ego_motion, target_bbox, target_ego_motion = data
48+
49+
# run forward
50+
ego_predictions = ego_pred_model(input_ego_motion)
51+
fol_predictions = fol_model(input_bbox, input_flow, ego_predictions)
52+
53+
# convert to numpy array, use [0] since batchsize if 1 for test
54+
# the prediction is the box changes
55+
ego_predictions = ego_predictions.to('cpu').numpy()[0]
56+
fol_predictions = fol_predictions.to('cpu').numpy()[0]
57+
input_bbox = input_bbox.to('cpu').numpy()[0]
58+
target_bbox = target_bbox.to('cpu').numpy()[0]
59+
60+
# compute FDE, ADE and FIOU metrics used in FVL2019ICRA paper
61+
input_bbox = np.expand_dims(input_bbox, axis=1)
62+
target_bbox = input_bbox + target_bbox
63+
fol_predictions = input_bbox + fol_predictions
64+
65+
input_bbox = bbox_denormalize(input_bbox, W=1280, H=640)
66+
fol_predictions = bbox_denormalize(fol_predictions, W=1280, H=640)
67+
target_bbox = bbox_denormalize(target_bbox, W=1280, H=640)
68+
69+
fol_predictions_xyxy = cxcywh_to_x1y1x2y2(fol_predictions)
70+
target_bbox_xyxy = cxcywh_to_x1y1x2y2(target_bbox)
71+
72+
# print(fol_predictions_xyxy[15,...])
73+
# print(target_bbox_xyxy[15,...])
74+
75+
ADE += np.mean(np.sqrt(np.sum((target_bbox_xyxy[:,:,:2] - fol_predictions_xyxy[:,:,:2]) ** 2, axis=-1)))
76+
FDE += np.mean(np.sqrt(np.sum((target_bbox_xyxy[:,-1,:2] - fol_predictions_xyxy[:,-1,:2]) ** 2, axis=-1)))
77+
tmp_FIOU = []
78+
for i in range(target_bbox_xyxy.shape[0]):
79+
tmp_FIOU.append(compute_IOU(target_bbox_xyxy[i,-1,:], fol_predictions_xyxy[i,-1,:]))
80+
FIOU += np.mean(tmp_FIOU)
81+
print("FDE: %4f; ADE: %4f; FIOU: %4f" % (FDE, ADE, FIOU))
82+
ADE /= len(test_gen.dataset)
83+
FDE /= len(test_gen.dataset)
84+
FIOU /= len(test_gen.dataset)
85+
print("FDE: %4f; ADE: %4f; FIOU: %4f" % (FDE, ADE, FIOU))
86+
87+
def main(args):
88+
# initialize model
89+
fol_model = FolRNNED(args).to(device)
90+
fol_model.load_state_dict(torch.load(args.best_fol_model))
91+
92+
if args.with_ego:
93+
print("Initializing pre-trained ego motion predictor...")
94+
ego_pred_model = EgoRNNED(args).to(device)
95+
ego_pred_model.load_state_dict(torch.load(args.best_ego_pred_model))
96+
print("Pre-trained ego_motion predictor done!")
97+
98+
# initialize datasets
99+
print("Initializing test dataset...")
100+
dataloader_params ={
101+
"batch_size": args.batch_size,
102+
"shuffle": args.shuffle,
103+
"num_workers": args.num_workers
104+
}
105+
106+
test_set = HEVIDataset(args, 'val')
107+
print("Number of test samples:", test_set.__len__())
108+
test_gen = data.DataLoader(test_set, **dataloader_params)
109+
110+
input_bbox, input_flow, input_ego_motion, target_bbox, target_ego_motion = test_set.__getitem__(1)
111+
print("input shape: ", input_bbox.shape)
112+
print("target shape: ", target_bbox.shape)
113+
114+
test_fol_ego(fol_model, ego_pred_model, test_gen)
115+
116+
117+
if __name__=='__main__':
118+
# load args
119+
args = parse_args()
120+
main(args)

train_fol.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
# print("Number of training samples:", train_set.__len__())
4848

4949
val_set = HEVIDataset(args, 'val')
50-
val_gen = data.DataLoader(val_set, **dataloader_params)
5150
print("Number of validation samples:", val_set.__len__())
51+
val_gen = data.DataLoader(val_set, **dataloader_params)
5252

5353
# print model summary
5454
if args.with_ego:
@@ -120,7 +120,7 @@
120120
print("Saving checkpoints: " + saved_fol_model_name + ' and ' + saved_ego_pred_model_name)
121121
if not os.path.isdir(args.checkpoint_dir):
122122
os.mkdir(args.checkpoint_dir)
123-
123+
124124
torch.save(fol_model.state_dict(), os.path.join(args.checkpoint_dir, saved_fol_model_name))
125125
torch.save(ego_pred_model.state_dict(), os.path.join(args.checkpoint_dir, saved_ego_pred_model_name))
126126

0 commit comments

Comments
 (0)