1. add sweeps support for nuscenes
2. add exp decay and manual step lr schedulers
3. fix some bugs.
traveller59 committed Apr 2, 2019
1 parent 95c99a1 commit 4fd6d5f
Showing 10 changed files with 386 additions and 126 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -132,8 +132,8 @@ Since the dataset is really large, you can download parts of the dataset.

Then run
```bash
python create_data.py nuscenes_data_prep --data_path=NUSCENES_TRAINVAL_DATASET_ROOT --version="v1.0-trainval"
python create_data.py nuscenes_data_prep --data_path=NUSCENES_TEST_DATASET_ROOT --version="v1.0-test"
python create_data.py nuscenes_data_prep --data_path=NUSCENES_TRAINVAL_DATASET_ROOT --version="v1.0-trainval" --max_sweeps=10
python create_data.py nuscenes_data_prep --data_path=NUSCENES_TEST_DATASET_ROOT --version="v1.0-test" --max_sweeps=10
```
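
The new `--max_sweeps` flag is forwarded to `create_nuscenes_infos` in `second/data/nuscenes_dataset.py` (see the diff below). If you prefer to bypass the fire CLI, a rough equivalent is the following sketch; it assumes `second` is on your `PYTHONPATH` and the dataset-root placeholder is replaced, and it only covers the info generation step:

```python
# Hedged sketch: generate the nuScenes info pickles directly, without create_data.py.
from second.data.nuscenes_dataset import create_nuscenes_infos

create_nuscenes_infos(
    "NUSCENES_TRAINVAL_DATASET_ROOT",  # placeholder, replace with your dataset root
    version="v1.0-trainval",
    max_sweeps=10)
```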

* Modify config file
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -20,6 +20,8 @@

3. remove much unused and deprecated code.

4. add two learning rate schedulers: exponential decay and manual stepping

# Release 1.5.1

## Minor Improvements and Bug fixes
184 changes: 108 additions & 76 deletions second/data/nuscenes_dataset.py
@@ -1,17 +1,18 @@
from pathlib import Path
import json
import pickle
import time
from functools import partial
from copy import deepcopy
import numpy as np
from functools import partial
from pathlib import Path

import fire
import json
import numpy as np

from second.core import box_np_ops
from second.core import preprocess as prep
from second.data import kitti_common as kitti
from second.utils.eval import get_coco_eval_result, get_official_eval_result
from second.data.dataset import Dataset
from second.utils.eval import get_coco_eval_result, get_official_eval_result
from second.utils.progress_bar import progress_bar_iter as prog_bar


@@ -120,7 +121,7 @@ def get_sensor_data(self, query):
read_test_image = False
if isinstance(query, dict):
assert "lidar" in query
idx = query["lidar"]["idx"] # currently only for visualization
idx = query["lidar"]["idx"]
read_test_image = "cam" in query

info = self._nusc_infos[idx]
@@ -133,12 +134,24 @@ def get_sensor_data(self, query):
"token": info["token"]
},
}

lidar_path = Path(info['lidar_path'])
points = np.fromfile(
str(lidar_path), dtype=np.float32, count=-1).reshape([-1,
5])[:, :4]
points[:, -1] /= 255
sweep_points_list = [points]

for sweep in info["sweeps"]:
points_sweep = np.fromfile(
str(sweep["lidar_path"]), dtype=np.float32,
count=-1).reshape([-1, 5])[:, :4]
points_sweep[:, -1] /= 255
points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
"sweep2lidar_rotation"]
points_sweep[:, :3] += sweep["sweep2lidar_translation"]
sweep_points_list.append(points_sweep)
points = np.concatenate(sweep_points_list, axis=0)

if read_test_image:
if Path(info["cam_front_path"]).exists():
with open(str(info["cam_front_path"]), 'rb') as f:
@@ -148,7 +161,7 @@ def get_sensor_data(self, query):
res["cam"] = {
"type": "camera",
"data": image_str,
"datatype": "jpg",
"datatype": Path(info["cam_front_path"]).suffix[1:],
}

# mask = box_np_ops.points_in_rbbox(points, info["gt_boxes"]).any(-1)
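
For context on the sweep loading above: each nuScenes `.pcd.bin` file stores five float32 values per point (commonly documented as x, y, z, intensity, ring index); the loader keeps the first four and rescales intensity to [0, 1]. A small standalone sketch, with a placeholder file path:

```python
# Hedged sketch: inspect a single nuScenes lidar file the same way get_sensor_data does.
import numpy as np

raw = np.fromfile("LIDAR_TOP_example.pcd.bin", dtype=np.float32).reshape(-1, 5)  # placeholder path
xyz = raw[:, :3]               # point coordinates in the sensor frame
intensity = raw[:, 3] / 255.0  # normalized like the dataset loader above
print(xyz.shape, intensity.min(), intensity.max())
```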
@@ -290,7 +303,7 @@ def evaluation_nusc(self, detections, output_dir):
res_path = str(Path(output_dir) / "results_nusc.json")
with open(res_path, "w") as f:
json.dump(nusc_annos, f)
del nusc # release memory
del nusc # release memory
from nuscenes.eval.detection.evaluate import main as eval_main
eval_main(
res_path,
@@ -380,6 +393,7 @@ def _lidar_nusc_box_to_global(nusc, boxes, sample_token):
for box in boxes:
# Move box to ego vehicle coord system
box.rotate(pyquaternion.Quaternion(cs_record['rotation']))

box.translate(np.array(cs_record['translation']))
# Move box to global coord system
box.rotate(pyquaternion.Quaternion(pose_record['rotation']))
@@ -416,39 +430,100 @@ def _get_available_scenes(nusc):
return available_scenes


def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False):
def _fill_trainval_infos(nusc,
train_scenes,
val_scenes,
test=False,
max_sweeps=10):
train_nusc_infos = []
val_nusc_infos = []
from pyquaternion import Quaternion
for sample in prog_bar(nusc.sample):
lidar_token = sample["data"]["LIDAR_TOP"]
cam_front_token = sample["data"]["CAM_FRONT"]
sd_rec = nusc.get('sample_data', sample['data']["LIDAR_TOP"])
cs_record = nusc.get('calibrated_sensor',
sd_rec['calibrated_sensor_token'])
pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])

lidar_path, boxes, _ = nusc.get_sample_data(lidar_token)
cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_front_token)
if Path(lidar_path).exists():
info = {
"lidar_path": lidar_path,
"cam_front_path": cam_path,
"token": sample["token"],
}
if not test:
locs = np.array([b.center for b in boxes]).reshape(-1, 3)
dims = np.array([b.wlh for b in boxes]).reshape(-1, 3)
rots = np.array(
[b.orientation.yaw_pitch_roll[0] for b in boxes]).reshape(
-1, 1)
names = np.array([b.name for b in boxes])
gt_boxes = np.concatenate([locs, dims, -rots - np.pi / 2],
axis=1)
info["gt_boxes"] = gt_boxes
info["gt_names"] = names
if sample["scene_token"] in train_scenes:
train_nusc_infos.append(info)
# if Path(lidar_path).exists():
info = {
"lidar_path": lidar_path,
"cam_front_path": cam_path,
"token": sample["token"],
"sweeps": [],
"lidar2ego_translation": cs_record['translation'],
"lidar2ego_rotation": cs_record['rotation'],
"ego2global_translation": pose_record['translation'],
"ego2global_rotation": pose_record['rotation'],
"timestamp": sample["timestamp"],
}

l2e_r = info["lidar2ego_rotation"]
l2e_t = info["lidar2ego_translation"]
e2g_r = info["ego2global_rotation"]
e2g_t = info["ego2global_translation"]
l2e_r_mat = Quaternion(l2e_r).rotation_matrix
e2g_r_mat = Quaternion(e2g_r).rotation_matrix

sd_rec = nusc.get('sample_data', sample['data']["LIDAR_TOP"])
sweeps = []
while len(sweeps) < max_sweeps:
if not sd_rec['prev'] == "":
sd_rec = nusc.get('sample_data', sd_rec['prev'])
cs_record = nusc.get('calibrated_sensor',
sd_rec['calibrated_sensor_token'])
pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
lidar_path = nusc.get_sample_data_path(sd_rec['token'])
sweep = {
"lidar_path": lidar_path,
"sample_data_token": sd_rec['token'],
"lidar2ego_translation": cs_record['translation'],
"lidar2ego_rotation": cs_record['rotation'],
"ego2global_translation": pose_record['translation'],
"ego2global_rotation": pose_record['rotation'],
"timestamp": sd_rec["timestamp"]
}
l2e_r_s = sweep["lidar2ego_rotation"]
l2e_t_s = sweep["lidar2ego_translation"]
e2g_r_s = sweep["ego2global_rotation"]
e2g_t_s = sweep["ego2global_translation"]
# sweep->ego->global->ego'->lidar
l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix
e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix

R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (
np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(
l2e_r_mat).T) + l2e_t @ np.linalg.inv(l2e_r_mat).T
sweep["sweep2lidar_rotation"] = R
sweep["sweep2lidar_translation"] = T
sweeps.append(sweep)
else:
val_nusc_infos.append(info)
break
info["sweeps"] = sweeps

if not test:
locs = np.array([b.center for b in boxes]).reshape(-1, 3)
dims = np.array([b.wlh for b in boxes]).reshape(-1, 3)
rots = np.array([b.orientation.yaw_pitch_roll[0]
for b in boxes]).reshape(-1, 1)
names = np.array([b.name for b in boxes])
gt_boxes = np.concatenate([locs, dims, -rots - np.pi / 2], axis=1)
info["gt_boxes"] = gt_boxes
info["gt_names"] = names
if sample["scene_token"] in train_scenes:
train_nusc_infos.append(info)
else:
val_nusc_infos.append(info)
return train_nusc_infos, val_nusc_infos
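
The `sweep2lidar_rotation` / `sweep2lidar_translation` expressions above compose sweep->ego->global->ego'->lidar, written with transposes so they apply directly to row-vector points (`points @ R + T` in `get_sensor_data`). A minimal sketch of the same composition in homogeneous 4x4 form, assuming the quaternion and translation fields shown in the info/sweep dicts above:

```python
# Hedged re-derivation of sweep2lidar_rotation / sweep2lidar_translation:
# compose the four known transforms as 4x4 matrices, then convert to the
# row-vector form used by the loop above.
import numpy as np
from pyquaternion import Quaternion


def make_tf(rotation, translation):
    """4x4 homogeneous transform from a nuScenes [w, x, y, z] quaternion and translation."""
    tf = np.eye(4)
    tf[:3, :3] = Quaternion(rotation).rotation_matrix
    tf[:3, 3] = translation
    return tf


def sweep_to_lidar(sweep, info):
    # keyframe lidar <- keyframe ego <- global <- sweep ego <- sweep lidar
    tf = (np.linalg.inv(make_tf(info["lidar2ego_rotation"], info["lidar2ego_translation"]))
          @ np.linalg.inv(make_tf(info["ego2global_rotation"], info["ego2global_translation"]))
          @ make_tf(sweep["ego2global_rotation"], sweep["ego2global_translation"])
          @ make_tf(sweep["lidar2ego_rotation"], sweep["lidar2ego_translation"]))
    R = tf[:3, :3].T  # transposed so row-vector points can use points @ R
    T = tf[:3, 3]
    return R, T
```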


def create_nuscenes_infos(root_path, version="v1.0-trainval"):
def create_nuscenes_infos(root_path, version="v1.0-trainval", max_sweeps=10):
from nuscenes.nuscenes import NuScenes
nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
from nuscenes.utils import splits
@@ -487,48 +562,7 @@ def create_nuscenes_infos(root_path, version="v1.0-trainval"):
print(
f"train scene: {len(train_scenes)}, val scene: {len(val_scenes)}")
train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
nusc, train_scenes, val_scenes, test)
if test:
print(f"test sample: {len(train_nusc_infos)}")
with open(root_path / "infos_test.pkl", 'wb') as f:
pickle.dump(train_nusc_infos, f)
else:
print(
f"train sample: {len(train_nusc_infos)}, val sample: {len(val_nusc_infos)}"
)
with open(root_path / "infos_train.pkl", 'wb') as f:
pickle.dump(train_nusc_infos, f)
with open(root_path / "infos_val.pkl", 'wb') as f:
pickle.dump(val_nusc_infos, f)


def create_nuscenes_infos_custom(
root_path,
version="v1.0-trainval",
split_rate=0.82353, # 700 / 850
test=False):
"""Don't use this because official evaluation tool don't support custom
"""
from nuscenes.nuscenes import NuScenes
nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
from nuscenes.utils import splits
available_vers = ["v1.0-trainval", "v1.0-test", "v1.0-mini"]
root_path = Path(root_path)
# filter exist scenes. you may only download part of dataset.
available_scenes = _get_available_scenes(nusc)
num_train_scene = np.round(split_rate * len(available_scenes)).astype(
np.int64)
train_scenes = set(
[s["token"] for s in available_scenes[:num_train_scene]])
val_scenes = set([s["token"] for s in available_scenes[num_train_scene:]])
if test:
train_scenes = set([s["token"] for s in nusc.scene])
print(f"test scene: {len(train_scenes)}")
else:
print(
f"train scene: {len(train_scenes)}, val scene: {len(val_scenes)}")
train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
nusc, train_scenes, val_scenes)
nusc, train_scenes, val_scenes, test, max_sweeps=max_sweeps)
if test:
print(f"test sample: {len(train_nusc_infos)}")
with open(root_path / "infos_test.pkl", 'wb') as f:
@@ -557,6 +591,4 @@ def get_box_mean(info_path, class_name="vehicle.car"):


if __name__ == "__main__":
# create_nuscenes_infos("/media/yy/My Passport/datasets/nuscene/v1.0-mini",
# "v1.0-mini")
fire.Fire()
fire.Fire()
28 changes: 15 additions & 13 deletions second/kittiviewer/frontend/js/KittiViewer.js
@@ -14,7 +14,7 @@ var KittiViewer = function (pointCloud, logger, imageCanvas) {
this.gtBboxes = [];
this.dtBboxes = [];
this.pointCloud = pointCloud;
this.maxPoints = 150000;
this.maxPoints = 500000;
this.pointVertices = new Float32Array(this.maxPoints * 3);
this.gtBoxColor = "#00ff00";
this.dtBoxColor = "#ff0000";
@@ -238,18 +238,20 @@ KittiViewer.prototype = {
var points = new Float32Array(points_buf);
}
var numFeatures = response["num_features"];
var locs = response["locs"];
var dims = response["dims"];

var rots = response["rots"];
var labels = response["labels"];
self.gtBboxes = response["bbox"];
self.gtBoxes = boxEdgeWithLabel(dims, locs, rots, 2,
self.gtBoxColor, labels,
self.gtLabelColor);
// var boxes = boxEdge(dims, locs, rots, 2, "rgb(0, 255, 0)");
for (var i = 0; i < self.gtBoxes.length; ++i) {
scene.add(self.gtBoxes[i]);
if ("locs" in response){
var locs = response["locs"];
var dims = response["dims"];

var rots = response["rots"];
var labels = response["labels"];
self.gtBboxes = response["bbox"];
self.gtBoxes = boxEdgeWithLabel(dims, locs, rots, 2,
self.gtBoxColor, labels,
self.gtLabelColor);
// var boxes = boxEdge(dims, locs, rots, 2, "rgb(0, 255, 0)");
for (var i = 0; i < self.gtBoxes.length; ++i) {
scene.add(self.gtBoxes[i]);
}
}
if (self.drawDet && response.hasOwnProperty("dt_dims")) {

38 changes: 38 additions & 0 deletions second/protos/optimizer.proto
@@ -47,6 +47,8 @@ message LearningRate {
oneof learning_rate {
MultiPhase multi_phase = 1;
OneCycle one_cycle = 2;
ExponentialDecay exponential_decay = 3;
ManualStepping manual_stepping = 4;
}
}

@@ -65,4 +67,40 @@ message OneCycle {
repeated float moms = 2;
float div_factor = 3;
float pct_start = 4;
}

/*
ExponentialDecay example:
initial_learning_rate = 0.001
decay_length = 0.1
decay_factor = 0.8
staircase = True
detail:
progress 0%~10%, lr=0.001
progress 10%~20%, lr=0.001 * 0.8
progress 20%~30%, lr=0.001 * 0.8 * 0.8
......
*/


message ExponentialDecay {
float initial_learning_rate = 1;
float decay_length = 2; // must be in range (0, 1)
float decay_factor = 3;
bool staircase = 4;
}
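
Not the trainer's actual scheduler code, just a sketch of the decay rule the comment above describes, assuming `progress` is the completed fraction of training in [0, 1]:

```python
import math

def exp_decay_lr(progress, initial_learning_rate=0.001, decay_length=0.1,
                 decay_factor=0.8, staircase=True):
    # With staircase=True the lr is held constant within each decay_length window,
    # matching the 0%~10%, 10%~20%, ... schedule in the comment above.
    steps = progress / decay_length
    if staircase:
        steps = math.floor(steps)
    return initial_learning_rate * decay_factor ** steps

# exp_decay_lr(0.05) -> 0.001, exp_decay_lr(0.15) -> 0.0008, exp_decay_lr(0.25) -> 0.00064
```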

/*
ManualStepping example:
boundaries = [0.8, 0.9]
rates = [0.001, 0.002, 0.003]
detail:
progress 0%~80%, lr=0.001
progress 80%~90%, lr=0.002
progress 90%~100%, lr=0.003
*/

message ManualStepping {
repeated float boundaries = 1; // must be in range (0, 1)
repeated float rates = 2;
}
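
Likewise, a sketch of the manual-stepping rule described above, with `boundaries` as progress fractions and `len(rates) == len(boundaries) + 1`:

```python
def manual_step_lr(progress, boundaries=(0.8, 0.9), rates=(0.001, 0.002, 0.003)):
    # Return the rate for whichever progress interval `progress` falls into.
    for boundary, rate in zip(boundaries, rates):
        if progress < boundary:
            return rate
    return rates[-1]

# manual_step_lr(0.5) -> 0.001, manual_step_lr(0.85) -> 0.002, manual_step_lr(0.95) -> 0.003
```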
