Skip to content

Commit

Permalink
minor improvements and bug fixes.
Browse files Browse the repository at this point in the history
see RELEASE.md for more details.
  • Loading branch information
traveller59 committed Mar 21, 2019
1 parent f892b5c commit 6a0bf67
Show file tree
Hide file tree
Showing 30 changed files with 1,852 additions and 1,124 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ ONLY support python 3.6+, pytorch 1.0.0+. Tested in Ubuntu 16.04/18.04.

## News

2019-3-21: SECOND V1.51 (minor improvement and bug fix) released! See [release notes](RELEASE.md) for more details.

2019-1-20: SECOND V1.5 released! See [release notes](RELEASE.md) for more details.


### Performance in KITTI validation set (50/50 split)

```car.fhd.config``` + 160 epochs (25 fps in 1080Ti):
Expand Down
14 changes: 13 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,16 @@
points([N, 4])->voxels([N, 5, 4])->Features([N, 4])->Sparse Convolution Networks->RPN. See [this](https://github.com/traveller59/second.pytorch/blob/master/second/pytorch/models/middle.py) for more details of sparse conv networks.
2. The [SparseConvNet](https://github.com/facebookresearch/SparseConvNet) is deprecated. New library [spconv](https://github.com/traveller59/spconv) is introduced.
3. Super converge (from fastai) is implemented. Now all network can converge to a good result with only 50~80 epoch. For example. ```car.fhd.config``` only needs 50 epochs to reach 78.3 AP (car mod 3d).
4. Target assigner now works correctly when using multi-class.
4. Target assigner now works correctly when using multi-class.

# Release 1.51

## Minor Improvements and Bug fixes

1. Better support for custom lidar data. You need to check KittiDataset for more details. (no test yet, I don't have custom data)
* Change all box to center format.
* Change kitti info format, now you need to regenerate kitti infos and gt database.
* Eval functions now support custom data evaluation. you need to specify z_center and z_axis in eval function.
2. Better RPN, you can add custom block by inherit RPNBase and implement _make_layer method.
3. Update pretrained model.
4. Add a simple inference notebook. everyone should start this project by that notebook.
13 changes: 5 additions & 8 deletions second/builder/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from second.builder import dbsampler_builder
from functools import partial

from second.utils import config_tool

def build(input_reader_config,
model_config,
Expand All @@ -29,11 +29,7 @@ def build(input_reader_config,
generate_bev = model_config.use_bev
without_reflectivity = model_config.without_reflectivity
num_point_features = model_config.num_point_features
out_size_factor = model_config.rpn.layer_strides[0] / model_config.rpn.upsample_strides[0]
out_size_factor *= model_config.middle_feature_extractor.downsample_factor
out_size_factor = int(out_size_factor)
assert out_size_factor > 0

downsample_factor = config_tool.get_downsample_factor(model_config)
cfg = input_reader_config
db_sampler_cfg = input_reader_config.database_sampler
db_sampler = None
Expand All @@ -45,8 +41,9 @@ def build(input_reader_config,
u_db_sampler = dbsampler_builder.build(u_db_sampler_cfg)
grid_size = voxel_generator.grid_size
# [352, 400]
feature_map_size = grid_size[:2] // out_size_factor
feature_map_size = grid_size[:2] // downsample_factor
feature_map_size = [*feature_map_size, 1][::-1]
print("feature_map_size", feature_map_size)
assert all([n != '' for n in target_assigner.classes]), "you must specify class_name in anchor_generators."
prep_func = partial(
prep_pointcloud,
Expand Down Expand Up @@ -77,7 +74,7 @@ def build(input_reader_config,
remove_points_after_sample=cfg.remove_points_after_sample,
remove_environment=cfg.remove_environment,
use_group_id=cfg.use_group_id,
out_size_factor=out_size_factor)
downsample_factor=downsample_factor)
dataset = KittiDataset(
info_path=cfg.kitti_info_path,
root_path=cfg.kitti_root_path,
Expand Down
10 changes: 9 additions & 1 deletion second/builder/voxel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
from spconv.utils import VoxelGenerator
from second.protos import voxel_generator_pb2

class _VoxelGenerator(VoxelGenerator):
@property
def grid_size(self):
point_cloud_range = np.array(self.point_cloud_range)
voxel_size = np.array(self.voxel_size)
g_size = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size
g_size = np.round(g_size).astype(np.int64)
return g_size

def build(voxel_config):
"""Builds a tensor dictionary based on the InputReader config.
Expand All @@ -20,7 +28,7 @@ def build(voxel_config):
if not isinstance(voxel_config, (voxel_generator_pb2.VoxelGenerator)):
raise ValueError('input_reader_config not of type '
'input_reader_pb2.InputReader.')
voxel_generator = VoxelGenerator(
voxel_generator = _VoxelGenerator(
voxel_size=list(voxel_config.voxel_size),
point_cloud_range=list(voxel_config.point_cloud_range),
max_num_points=voxel_config.max_number_of_points_per_voxel,
Expand Down
12 changes: 4 additions & 8 deletions second/configs/all.fhd.config
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [1.6, 3.9, 1.56] # wlh
# anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -32.0, -1.78, 52.8, 32.0, -1.78] # carefully set z center
anchor_ranges: [0, -32.0, -1.0, 52.8, 32.0, -1.0] # carefully set z center
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.6
unmatched_threshold : 0.45
Expand All @@ -94,8 +93,7 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [0.6, 1.76, 1.73] # wlh
# anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -32.0, -1.45, 52.8, 32.0, -1.45] # carefully set z center
anchor_ranges: [0, -32.0, -0.6, 52.8, 32.0, -0.6] # carefully set z center
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.35
unmatched_threshold : 0.2
Expand All @@ -105,8 +103,7 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [0.6, 0.8, 1.73] # wlh
# anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -32.0, -1.45, 52.8, 32.0, -1.45] # carefully set z center
anchor_ranges: [0, -32.0, -0.6, 52.8, 32.0, -0.6] # carefully set z center
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.35
unmatched_threshold : 0.2
Expand All @@ -116,8 +113,7 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [1.87103749, 5.02808195, 2.20964255] # wlh
# anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -32.0, -1.41, 52.8, 32.0, -1.41] # carefully set z center
anchor_ranges: [0, -32.0, -1.0, 52.8, 32.0, -1.0] # carefully set z center
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.6
unmatched_threshold : 0.45
Expand Down
4 changes: 2 additions & 2 deletions second/configs/car.fhd.config
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [1.6, 3.9, 1.56] # wlh
anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -40.0, -1.0, 70.4, 40.0, -1.0] # carefully set z center
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.6
unmatched_threshold : 0.45
Expand All @@ -105,7 +105,7 @@ train_input_reader: {
max_num_epochs : 160
batch_size: 6
prefetch_size : 25
max_number_of_voxels: 16000 # to support batchsize=2 in 1080Ti
max_number_of_voxels: 16000
shuffle_points: true
num_workers: 3
groundtruth_localization_noise_std: [1.0, 1.0, 0.5]
Expand Down
4 changes: 2 additions & 2 deletions second/configs/car.fhd.onestage.config
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [1.6, 3.9, 1.56] # wlh
anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -40.0, -1.0, 70.4, 40.0, -1.0] # carefully set z center
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.6
unmatched_threshold : 0.45
Expand All @@ -105,7 +105,7 @@ train_input_reader: {
max_num_epochs : 160
batch_size: 6
prefetch_size : 25
max_number_of_voxels: 16000 # to support batchsize=2 in 1080Ti
max_number_of_voxels: 16000
shuffle_points: true
num_workers: 3
groundtruth_localization_noise_std: [1.0, 1.0, 0.5]
Expand Down
20 changes: 10 additions & 10 deletions second/configs/car.lite.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ model: {
}
rpn: {
module_class_name: "RPNV2"
layer_nums: [3, 5]
layer_strides: [1, 2]
num_filters: [128, 128]
upsample_strides: [1, 2]
num_upsample_filters: [128, 128]
layer_nums: [5]
layer_strides: [1]
num_filters: [128]
upsample_strides: [1]
num_upsample_filters: [128]
use_groupnorm: false
num_groups: 32
num_input_features: 128
Expand Down Expand Up @@ -83,8 +83,8 @@ model: {
anchor_generators: {
anchor_generator_range: {
sizes: [1.6, 3.9, 1.56] # wlh
# anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center
anchor_ranges: [0, -32.0, -1.78, 52.8, 32.0, -1.78]
# anchor_ranges: [0, -40.0, -1.0, 70.4, 40.0, -1.0] # carefully set z center
anchor_ranges: [0, -32.0, -1.0, 52.8, 32.0, -1.0]
rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code.
matched_threshold : 0.6
unmatched_threshold : 0.45
Expand All @@ -104,11 +104,11 @@ model: {

train_input_reader: {
max_num_epochs : 160
batch_size: 12 # sparse conv use 7633MB GPU memory when batch_size=3
batch_size: 12
prefetch_size : 25
max_number_of_voxels: 15000 # to support batchsize=2 in 1080Ti
max_number_of_voxels: 15000
shuffle_points: true
num_workers: 3
num_workers: 0
groundtruth_localization_noise_std: [1.0, 1.0, 0.5]
# groundtruth_rotation_uniform_noise: [-0.3141592654, 0.3141592654]
groundtruth_rotation_uniform_noise: [-1.57, 1.57]
Expand Down
10 changes: 3 additions & 7 deletions second/core/box_np_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def second_box_encode(boxes, anchors, encode_angle_to_vector=False, smooth_dim=F
# need to convert boxes to z-center format
xa, ya, za, wa, la, ha, ra = np.split(anchors, 7, axis=-1)
xg, yg, zg, wg, lg, hg, rg = np.split(boxes, 7, axis=-1)
zg = zg + hg / 2
za = za + ha / 2
diagonal = np.sqrt(la**2 + wa**2) # 4.3
xt = (xg - xa) / diagonal
yt = (yg - ya) / diagonal
Expand Down Expand Up @@ -71,7 +69,6 @@ def second_box_decode(box_encodings, anchors, encode_angle_to_vector=False, smoo
xt, yt, zt, wt, lt, ht, rtx, rty = np.split(box_encodings, 8, axis=-1)
else:
xt, yt, zt, wt, lt, ht, rt = np.split(box_encodings, 7, axis=-1)
za = za + ha / 2
diagonal = np.sqrt(la**2 + wa**2)
xg = xt * diagonal + xa
yg = yt * diagonal + ya
Expand All @@ -93,7 +90,6 @@ def second_box_decode(box_encodings, anchors, encode_angle_to_vector=False, smoo
rg = np.arctan2(rgy, rgx)
else:
rg = rt + ra
zg = zg - hg / 2
return np.concatenate([xg, yg, zg, wg, lg, hg, rg], axis=-1)

def bev_box_encode(boxes, anchors, encode_angle_to_vector=False, smooth_dim=False):
Expand Down Expand Up @@ -328,7 +324,7 @@ def rotation_box(box_corners, angle):
def center_to_corner_box3d(centers,
dims,
angles=None,
origin=[0.5, 1.0, 0.5],
origin=0.5,
axis=1):
"""convert kitti locations, dimensions and angles to corners
Expand Down Expand Up @@ -399,7 +395,7 @@ def box2d_to_corner_jit(boxes):
return box_corners


def rbbox3d_to_corners(rbboxes, origin=[0.5, 0.5, 0.0], axis=2):
def rbbox3d_to_corners(rbboxes, origin=0.5, axis=2):
return center_to_corner_box3d(
rbboxes[..., :3],
rbboxes[..., 3:6],
Expand Down Expand Up @@ -847,7 +843,7 @@ def assign_label_to_voxel(gt_boxes, coors, voxel_size, coors_range):
gt_boxes[:, :3] - voxel_size * 0.5,
gt_boxes[:, 3:6] + voxel_size,
gt_boxes[:, 6],
origin=[0.5, 0.5, 0],
origin=0.5,
axis=2)
gt_surfaces = corner_to_surfaces_3d(gt_box_corners)
ret = points_in_convex_polygon_3d_jit(voxel_centers, gt_surfaces)
Expand Down
44 changes: 29 additions & 15 deletions second/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from second.data.preprocess import merge_second_batch, prep_pointcloud
from second.protos import pipeline_pb2

import second.data.kitti_common as kitti

class InferenceContext:
def __init__(self):
Expand All @@ -23,23 +23,36 @@ def get_inference_input_dict(self, info, points):
assert self.voxel_generator is not None
assert self.config is not None
assert self.built is True
rect = info['calib/R0_rect']
P2 = info['calib/P2']
Trv2c = info['calib/Tr_velo_to_cam']
kitti.convert_to_kitti_info_version2(info)
pc_info = info["point_cloud"]
image_info = info["image"]
calib = info["calib"]

rect = calib['R0_rect']
Trv2c = calib['Tr_velo_to_cam']
P2 = calib['P2']

input_cfg = self.config.eval_input_reader
model_cfg = self.config.model.second

input_dict = {
'points': points,
'rect': rect,
'Trv2c': Trv2c,
'P2': P2,
'image_shape': np.array(info["img_shape"], dtype=np.int32),
'image_idx': info['image_idx'],
'image_path': info['img_path'],
# 'pointcloud_num_features': num_point_features,
"calib": {
'rect': rect,
'Trv2c': Trv2c,
'P2': P2,
},
"image": {
'image_shape': np.array(image_info["image_shape"], dtype=np.int32),
'image_idx': image_info['image_idx'],
'image_path': image_info['image_path'],
},
}
out_size_factor = model_cfg.rpn.layer_strides[0] // model_cfg.rpn.upsample_strides[0]
out_size_factor = np.prod(model_cfg.rpn.layer_strides)
if len(model_cfg.rpn.upsample_strides) > 0:
out_size_factor /= model_cfg.rpn.upsample_strides[-1]
out_size_factor *= model_cfg.middle_feature_extractor.downsample_factor
out_size_factor = int(out_size_factor)
example = prep_pointcloud(
input_dict=input_dict,
root_path=str(self.root_path),
Expand All @@ -57,9 +70,10 @@ def get_inference_input_dict(self, info, points):
anchor_cache=self.anchor_cache,
out_size_factor=out_size_factor,
out_dtype=np.float32)
example["image_idx"] = info['image_idx']
example["image_shape"] = input_dict["image_shape"]
example["points"] = points
example["metadata"] = {}
if "image" in info:
example["metadata"]["image"] = input_dict["image"]

if "anchors_mask" in example:
example["anchors_mask"] = example["anchors_mask"].astype(np.uint8)
#############
Expand Down
6 changes: 2 additions & 4 deletions second/core/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,11 @@ def noise_per_object_v3_(gt_boxes,
gt_boxes[:, 6], group_centers, valid_mask)
group_nums = np.array(list(group_id_num_dict.values()), dtype=np.int64)

origin = [0.5, 0.5, 0]
gt_box_corners = box_np_ops.center_to_corner_box3d(
gt_boxes[:, :3],
gt_boxes[:, 3:6],
gt_boxes[:, 6],
origin=origin,
origin=0.5,
axis=2)
if group_ids is not None:
if not enable_grot:
Expand Down Expand Up @@ -728,12 +727,11 @@ def noise_per_object_v2_(gt_boxes,
grot_uppers[..., np.newaxis],
size=[num_boxes, num_try])

origin = [0.5, 0.5, 0]
gt_box_corners = box_np_ops.center_to_corner_box3d(
gt_boxes[:, :3],
gt_boxes[:, 3:6],
gt_boxes[:, 6],
origin=origin,
origin=0.5,
axis=2)
if np.abs(global_random_rot_range[0] - global_random_rot_range[1]) < 1e-3:
selected_noise = noise_per_box(gt_boxes[:, [0, 1, 3, 4, 6]],
Expand Down
7 changes: 4 additions & 3 deletions second/core/sample_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ def sample_all(self,
num_point_features,
random_crop=False,
gt_group_ids=None,
rect=None,
Trv2c=None,
P2=None):
calib=None):
sampled_num_dict = {}
sample_num_per_class = []
for class_name, max_sample_num in zip(self._sample_classes,
Expand Down Expand Up @@ -180,6 +178,9 @@ def sample_all(self,
# if np.random.choice([False, True], replace=False, p=[0.3, 0.7]):
# do random crop.
if random_crop:
rect = calib["rect"]
Trv2c = calib["Trv2c"]
P2 = calib["P2"]
s_points_list_new = []
gt_bboxes = box_np_ops.box3d_to_bbox(sampled_gt_boxes, rect,
Trv2c, P2)
Expand Down
Loading

0 comments on commit 6a0bf67

Please sign in to comment.