From 6a0bf67112e1e91d7e12cedac3160bcea061c92f Mon Sep 17 00:00:00 2001 From: traveller59 Date: Thu, 21 Mar 2019 12:57:52 +0800 Subject: [PATCH] minor improvements and bug fixes. see RELEASE.md for more details. --- README.md | 3 + RELEASE.md | 14 +- second/builder/dataset_builder.py | 13 +- second/builder/voxel_builder.py | 10 +- second/configs/all.fhd.config | 12 +- second/configs/car.fhd.config | 4 +- second/configs/car.fhd.onestage.config | 4 +- second/configs/car.lite.config | 20 +- second/core/box_np_ops.py | 10 +- second/core/inference.py | 44 +- second/core/preprocess.py | 6 +- second/core/sample_ops.py | 7 +- second/data/dataset.py | 139 +++++- second/data/kitti_common.py | 115 ++++- second/data/preprocess.py | 342 +++++++------- second/kittiviewer/backend.py | 93 ++-- second/kittiviewer/frontend/index.html | 18 +- second/kittiviewer/frontend/js/KittiViewer.js | 61 ++- second/pytorch/builder/second_builder.py | 16 + second/pytorch/core/box_torch_ops.py | 6 +- second/pytorch/inference.py | 14 +- second/pytorch/models/rpn.py | 408 ++++++---------- second/pytorch/models/voxelnet.py | 166 +++++-- second/pytorch/train.py | 442 +++++++----------- second/script.py | 47 ++ second/simple-inference.ipynb | 269 +++++++++++ second/utils/config_tool.py | 59 +++ second/utils/eval.py | 408 +++++++--------- second/utils/log_tool.py | 35 ++ second/utils/simplevis.py | 191 ++++++++ 30 files changed, 1852 insertions(+), 1124 deletions(-) create mode 100644 second/script.py create mode 100644 second/simple-inference.ipynb create mode 100644 second/utils/config_tool.py create mode 100644 second/utils/log_tool.py create mode 100644 second/utils/simplevis.py diff --git a/README.md b/README.md index 52b211af..cb6f8b95 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,11 @@ ONLY support python 3.6+, pytorch 1.0.0+. Tested in Ubuntu 16.04/18.04. ## News +2019-3-21: SECOND V1.51 (minor improvement and bug fix) released! See [release notes](RELEASE.md) for more details. + 2019-1-20: SECOND V1.5 released! See [release notes](RELEASE.md) for more details. + ### Performance in KITTI validation set (50/50 split) ```car.fhd.config``` + 160 epochs (25 fps in 1080Ti): diff --git a/RELEASE.md b/RELEASE.md index 5d04da71..c09796c2 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -6,4 +6,16 @@ points([N, 4])->voxels([N, 5, 4])->Features([N, 4])->Sparse Convolution Networks->RPN. See [this](https://github.com/traveller59/second.pytorch/blob/master/second/pytorch/models/middle.py) for more details of sparse conv networks. 2. The [SparseConvNet](https://github.com/facebookresearch/SparseConvNet) is deprecated. New library [spconv](https://github.com/traveller59/spconv) is introduced. 3. Super converge (from fastai) is implemented. Now all network can converge to a good result with only 50~80 epoch. For example. ```car.fhd.config``` only needs 50 epochs to reach 78.3 AP (car mod 3d). -4. Target assigner now works correctly when using multi-class. \ No newline at end of file +4. Target assigner now works correctly when using multi-class. + +# Release 1.51 + +## Minor Improvements and Bug fixes + +1. Better support for custom lidar data. You need to check KittiDataset for more details. (no test yet, I don't have custom data) +* Change all box to center format. +* Change kitti info format, now you need to regenerate kitti infos and gt database. +* Eval functions now support custom data evaluation. you need to specify z_center and z_axis in eval function. +2. Better RPN, you can add custom block by inherit RPNBase and implement _make_layer method. +3. Update pretrained model. +4. Add a simple inference notebook. everyone should start this project by that notebook. \ No newline at end of file diff --git a/second/builder/dataset_builder.py b/second/builder/dataset_builder.py index 55feef8b..795148dd 100644 --- a/second/builder/dataset_builder.py +++ b/second/builder/dataset_builder.py @@ -4,7 +4,7 @@ import numpy as np from second.builder import dbsampler_builder from functools import partial - +from second.utils import config_tool def build(input_reader_config, model_config, @@ -29,11 +29,7 @@ def build(input_reader_config, generate_bev = model_config.use_bev without_reflectivity = model_config.without_reflectivity num_point_features = model_config.num_point_features - out_size_factor = model_config.rpn.layer_strides[0] / model_config.rpn.upsample_strides[0] - out_size_factor *= model_config.middle_feature_extractor.downsample_factor - out_size_factor = int(out_size_factor) - assert out_size_factor > 0 - + downsample_factor = config_tool.get_downsample_factor(model_config) cfg = input_reader_config db_sampler_cfg = input_reader_config.database_sampler db_sampler = None @@ -45,8 +41,9 @@ def build(input_reader_config, u_db_sampler = dbsampler_builder.build(u_db_sampler_cfg) grid_size = voxel_generator.grid_size # [352, 400] - feature_map_size = grid_size[:2] // out_size_factor + feature_map_size = grid_size[:2] // downsample_factor feature_map_size = [*feature_map_size, 1][::-1] + print("feature_map_size", feature_map_size) assert all([n != '' for n in target_assigner.classes]), "you must specify class_name in anchor_generators." prep_func = partial( prep_pointcloud, @@ -77,7 +74,7 @@ def build(input_reader_config, remove_points_after_sample=cfg.remove_points_after_sample, remove_environment=cfg.remove_environment, use_group_id=cfg.use_group_id, - out_size_factor=out_size_factor) + downsample_factor=downsample_factor) dataset = KittiDataset( info_path=cfg.kitti_info_path, root_path=cfg.kitti_root_path, diff --git a/second/builder/voxel_builder.py b/second/builder/voxel_builder.py index d9e5de19..c671d2a0 100644 --- a/second/builder/voxel_builder.py +++ b/second/builder/voxel_builder.py @@ -3,6 +3,14 @@ from spconv.utils import VoxelGenerator from second.protos import voxel_generator_pb2 +class _VoxelGenerator(VoxelGenerator): + @property + def grid_size(self): + point_cloud_range = np.array(self.point_cloud_range) + voxel_size = np.array(self.voxel_size) + g_size = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size + g_size = np.round(g_size).astype(np.int64) + return g_size def build(voxel_config): """Builds a tensor dictionary based on the InputReader config. @@ -20,7 +28,7 @@ def build(voxel_config): if not isinstance(voxel_config, (voxel_generator_pb2.VoxelGenerator)): raise ValueError('input_reader_config not of type ' 'input_reader_pb2.InputReader.') - voxel_generator = VoxelGenerator( + voxel_generator = _VoxelGenerator( voxel_size=list(voxel_config.voxel_size), point_cloud_range=list(voxel_config.point_cloud_range), max_num_points=voxel_config.max_number_of_points_per_voxel, diff --git a/second/configs/all.fhd.config b/second/configs/all.fhd.config index bfe8b47a..52bb22c9 100644 --- a/second/configs/all.fhd.config +++ b/second/configs/all.fhd.config @@ -83,8 +83,7 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [1.6, 3.9, 1.56] # wlh - # anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center - anchor_ranges: [0, -32.0, -1.78, 52.8, 32.0, -1.78] # carefully set z center + anchor_ranges: [0, -32.0, -1.0, 52.8, 32.0, -1.0] # carefully set z center rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.6 unmatched_threshold : 0.45 @@ -94,8 +93,7 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [0.6, 1.76, 1.73] # wlh - # anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center - anchor_ranges: [0, -32.0, -1.45, 52.8, 32.0, -1.45] # carefully set z center + anchor_ranges: [0, -32.0, -0.6, 52.8, 32.0, -0.6] # carefully set z center rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.35 unmatched_threshold : 0.2 @@ -105,8 +103,7 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [0.6, 0.8, 1.73] # wlh - # anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center - anchor_ranges: [0, -32.0, -1.45, 52.8, 32.0, -1.45] # carefully set z center + anchor_ranges: [0, -32.0, -0.6, 52.8, 32.0, -0.6] # carefully set z center rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.35 unmatched_threshold : 0.2 @@ -116,8 +113,7 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [1.87103749, 5.02808195, 2.20964255] # wlh - # anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center - anchor_ranges: [0, -32.0, -1.41, 52.8, 32.0, -1.41] # carefully set z center + anchor_ranges: [0, -32.0, -1.0, 52.8, 32.0, -1.0] # carefully set z center rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.6 unmatched_threshold : 0.45 diff --git a/second/configs/car.fhd.config b/second/configs/car.fhd.config index 1f6369bf..1679f189 100644 --- a/second/configs/car.fhd.config +++ b/second/configs/car.fhd.config @@ -83,7 +83,7 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [1.6, 3.9, 1.56] # wlh - anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center + anchor_ranges: [0, -40.0, -1.0, 70.4, 40.0, -1.0] # carefully set z center rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.6 unmatched_threshold : 0.45 @@ -105,7 +105,7 @@ train_input_reader: { max_num_epochs : 160 batch_size: 6 prefetch_size : 25 - max_number_of_voxels: 16000 # to support batchsize=2 in 1080Ti + max_number_of_voxels: 16000 shuffle_points: true num_workers: 3 groundtruth_localization_noise_std: [1.0, 1.0, 0.5] diff --git a/second/configs/car.fhd.onestage.config b/second/configs/car.fhd.onestage.config index 5130b8e5..0096823a 100644 --- a/second/configs/car.fhd.onestage.config +++ b/second/configs/car.fhd.onestage.config @@ -83,7 +83,7 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [1.6, 3.9, 1.56] # wlh - anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center + anchor_ranges: [0, -40.0, -1.0, 70.4, 40.0, -1.0] # carefully set z center rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.6 unmatched_threshold : 0.45 @@ -105,7 +105,7 @@ train_input_reader: { max_num_epochs : 160 batch_size: 6 prefetch_size : 25 - max_number_of_voxels: 16000 # to support batchsize=2 in 1080Ti + max_number_of_voxels: 16000 shuffle_points: true num_workers: 3 groundtruth_localization_noise_std: [1.0, 1.0, 0.5] diff --git a/second/configs/car.lite.config b/second/configs/car.lite.config index 2c696dab..3a9f50fe 100644 --- a/second/configs/car.lite.config +++ b/second/configs/car.lite.config @@ -22,11 +22,11 @@ model: { } rpn: { module_class_name: "RPNV2" - layer_nums: [3, 5] - layer_strides: [1, 2] - num_filters: [128, 128] - upsample_strides: [1, 2] - num_upsample_filters: [128, 128] + layer_nums: [5] + layer_strides: [1] + num_filters: [128] + upsample_strides: [1] + num_upsample_filters: [128] use_groupnorm: false num_groups: 32 num_input_features: 128 @@ -83,8 +83,8 @@ model: { anchor_generators: { anchor_generator_range: { sizes: [1.6, 3.9, 1.56] # wlh - # anchor_ranges: [0, -40.0, -1.78, 70.4, 40.0, -1.78] # carefully set z center - anchor_ranges: [0, -32.0, -1.78, 52.8, 32.0, -1.78] + # anchor_ranges: [0, -40.0, -1.0, 70.4, 40.0, -1.0] # carefully set z center + anchor_ranges: [0, -32.0, -1.0, 52.8, 32.0, -1.0] rotations: [0, 1.57] # DON'T modify this unless you are very familiar with my code. matched_threshold : 0.6 unmatched_threshold : 0.45 @@ -104,11 +104,11 @@ model: { train_input_reader: { max_num_epochs : 160 - batch_size: 12 # sparse conv use 7633MB GPU memory when batch_size=3 + batch_size: 12 prefetch_size : 25 - max_number_of_voxels: 15000 # to support batchsize=2 in 1080Ti + max_number_of_voxels: 15000 shuffle_points: true - num_workers: 3 + num_workers: 0 groundtruth_localization_noise_std: [1.0, 1.0, 0.5] # groundtruth_rotation_uniform_noise: [-0.3141592654, 0.3141592654] groundtruth_rotation_uniform_noise: [-1.57, 1.57] diff --git a/second/core/box_np_ops.py b/second/core/box_np_ops.py index 21207ae9..b34e2c2c 100644 --- a/second/core/box_np_ops.py +++ b/second/core/box_np_ops.py @@ -31,8 +31,6 @@ def second_box_encode(boxes, anchors, encode_angle_to_vector=False, smooth_dim=F # need to convert boxes to z-center format xa, ya, za, wa, la, ha, ra = np.split(anchors, 7, axis=-1) xg, yg, zg, wg, lg, hg, rg = np.split(boxes, 7, axis=-1) - zg = zg + hg / 2 - za = za + ha / 2 diagonal = np.sqrt(la**2 + wa**2) # 4.3 xt = (xg - xa) / diagonal yt = (yg - ya) / diagonal @@ -71,7 +69,6 @@ def second_box_decode(box_encodings, anchors, encode_angle_to_vector=False, smoo xt, yt, zt, wt, lt, ht, rtx, rty = np.split(box_encodings, 8, axis=-1) else: xt, yt, zt, wt, lt, ht, rt = np.split(box_encodings, 7, axis=-1) - za = za + ha / 2 diagonal = np.sqrt(la**2 + wa**2) xg = xt * diagonal + xa yg = yt * diagonal + ya @@ -93,7 +90,6 @@ def second_box_decode(box_encodings, anchors, encode_angle_to_vector=False, smoo rg = np.arctan2(rgy, rgx) else: rg = rt + ra - zg = zg - hg / 2 return np.concatenate([xg, yg, zg, wg, lg, hg, rg], axis=-1) def bev_box_encode(boxes, anchors, encode_angle_to_vector=False, smooth_dim=False): @@ -328,7 +324,7 @@ def rotation_box(box_corners, angle): def center_to_corner_box3d(centers, dims, angles=None, - origin=[0.5, 1.0, 0.5], + origin=0.5, axis=1): """convert kitti locations, dimensions and angles to corners @@ -399,7 +395,7 @@ def box2d_to_corner_jit(boxes): return box_corners -def rbbox3d_to_corners(rbboxes, origin=[0.5, 0.5, 0.0], axis=2): +def rbbox3d_to_corners(rbboxes, origin=0.5, axis=2): return center_to_corner_box3d( rbboxes[..., :3], rbboxes[..., 3:6], @@ -847,7 +843,7 @@ def assign_label_to_voxel(gt_boxes, coors, voxel_size, coors_range): gt_boxes[:, :3] - voxel_size * 0.5, gt_boxes[:, 3:6] + voxel_size, gt_boxes[:, 6], - origin=[0.5, 0.5, 0], + origin=0.5, axis=2) gt_surfaces = corner_to_surfaces_3d(gt_box_corners) ret = points_in_convex_polygon_3d_jit(voxel_centers, gt_surfaces) diff --git a/second/core/inference.py b/second/core/inference.py index 1993e7cb..160b36f8 100644 --- a/second/core/inference.py +++ b/second/core/inference.py @@ -6,7 +6,7 @@ from second.data.preprocess import merge_second_batch, prep_pointcloud from second.protos import pipeline_pb2 - +import second.data.kitti_common as kitti class InferenceContext: def __init__(self): @@ -23,23 +23,36 @@ def get_inference_input_dict(self, info, points): assert self.voxel_generator is not None assert self.config is not None assert self.built is True - rect = info['calib/R0_rect'] - P2 = info['calib/P2'] - Trv2c = info['calib/Tr_velo_to_cam'] + kitti.convert_to_kitti_info_version2(info) + pc_info = info["point_cloud"] + image_info = info["image"] + calib = info["calib"] + + rect = calib['R0_rect'] + Trv2c = calib['Tr_velo_to_cam'] + P2 = calib['P2'] + input_cfg = self.config.eval_input_reader model_cfg = self.config.model.second input_dict = { 'points': points, - 'rect': rect, - 'Trv2c': Trv2c, - 'P2': P2, - 'image_shape': np.array(info["img_shape"], dtype=np.int32), - 'image_idx': info['image_idx'], - 'image_path': info['img_path'], - # 'pointcloud_num_features': num_point_features, + "calib": { + 'rect': rect, + 'Trv2c': Trv2c, + 'P2': P2, + }, + "image": { + 'image_shape': np.array(image_info["image_shape"], dtype=np.int32), + 'image_idx': image_info['image_idx'], + 'image_path': image_info['image_path'], + }, } - out_size_factor = model_cfg.rpn.layer_strides[0] // model_cfg.rpn.upsample_strides[0] + out_size_factor = np.prod(model_cfg.rpn.layer_strides) + if len(model_cfg.rpn.upsample_strides) > 0: + out_size_factor /= model_cfg.rpn.upsample_strides[-1] + out_size_factor *= model_cfg.middle_feature_extractor.downsample_factor + out_size_factor = int(out_size_factor) example = prep_pointcloud( input_dict=input_dict, root_path=str(self.root_path), @@ -57,9 +70,10 @@ def get_inference_input_dict(self, info, points): anchor_cache=self.anchor_cache, out_size_factor=out_size_factor, out_dtype=np.float32) - example["image_idx"] = info['image_idx'] - example["image_shape"] = input_dict["image_shape"] - example["points"] = points + example["metadata"] = {} + if "image" in info: + example["metadata"]["image"] = input_dict["image"] + if "anchors_mask" in example: example["anchors_mask"] = example["anchors_mask"].astype(np.uint8) ############# diff --git a/second/core/preprocess.py b/second/core/preprocess.py index 3a80e4c1..7adb18c3 100644 --- a/second/core/preprocess.py +++ b/second/core/preprocess.py @@ -650,12 +650,11 @@ def noise_per_object_v3_(gt_boxes, gt_boxes[:, 6], group_centers, valid_mask) group_nums = np.array(list(group_id_num_dict.values()), dtype=np.int64) - origin = [0.5, 0.5, 0] gt_box_corners = box_np_ops.center_to_corner_box3d( gt_boxes[:, :3], gt_boxes[:, 3:6], gt_boxes[:, 6], - origin=origin, + origin=0.5, axis=2) if group_ids is not None: if not enable_grot: @@ -728,12 +727,11 @@ def noise_per_object_v2_(gt_boxes, grot_uppers[..., np.newaxis], size=[num_boxes, num_try]) - origin = [0.5, 0.5, 0] gt_box_corners = box_np_ops.center_to_corner_box3d( gt_boxes[:, :3], gt_boxes[:, 3:6], gt_boxes[:, 6], - origin=origin, + origin=0.5, axis=2) if np.abs(global_random_rot_range[0] - global_random_rot_range[1]) < 1e-3: selected_noise = noise_per_box(gt_boxes[:, [0, 1, 3, 4, 6]], diff --git a/second/core/sample_ops.py b/second/core/sample_ops.py index b540fed3..b00bb1cc 100644 --- a/second/core/sample_ops.py +++ b/second/core/sample_ops.py @@ -99,9 +99,7 @@ def sample_all(self, num_point_features, random_crop=False, gt_group_ids=None, - rect=None, - Trv2c=None, - P2=None): + calib=None): sampled_num_dict = {} sample_num_per_class = [] for class_name, max_sample_num in zip(self._sample_classes, @@ -180,6 +178,9 @@ def sample_all(self, # if np.random.choice([False, True], replace=False, p=[0.3, 0.7]): # do random crop. if random_crop: + rect = calib["rect"] + Trv2c = calib["Trv2c"] + P2 = calib["P2"] s_points_list_new = [] gt_bboxes = box_np_ops.box3d_to_bbox(sampled_gt_boxes, rect, Trv2c, P2) diff --git a/second/data/dataset.py b/second/data/dataset.py index 0130fd1f..7d0810c2 100644 --- a/second/data/dataset.py +++ b/second/data/dataset.py @@ -8,7 +8,7 @@ from second.core import box_np_ops from second.core import preprocess as prep from second.data import kitti_common as kitti -from second.data.preprocess import _read_and_prep_v9 +from second.utils.eval import get_coco_eval_result, get_official_eval_result class Dataset(object): @@ -25,7 +25,6 @@ def __len__(self): raise NotImplementedError - class KittiDataset(Dataset): def __init__(self, info_path, root_path, num_point_features, target_assigner, feature_map_size, prep_func): @@ -37,8 +36,8 @@ def __init__(self, info_path, root_path, num_point_features, self._num_point_features = num_point_features print("remain number of infos:", len(self._kitti_infos)) # generate anchors cache - # [352, 400] ret = target_assigner.generate_anchors(feature_map_size) + self._class_names = target_assigner.classes anchors_dict = target_assigner.generate_anchors_dict(feature_map_size) anchors = ret["anchors"] anchors = anchors.reshape([-1, 7]) @@ -59,12 +58,132 @@ def __len__(self): return len(self._kitti_infos) @property - def kitti_infos(self): - return self._kitti_infos + def ground_truth_annotations(self): + """ + If you want to eval by my eval function, you must + provide this property. + ground_truth_annotations format: + { + bbox: [N, 4], if you fill fake data, MUST HAVE >25 HEIGHT!!!!!! + alpha: [N], you can use zero. + occluded: [N], you can use zero. + truncated: [N], you can use zero. + name: [N] + location: [N, 3] center of 3d box. + dimensions: [N, 3] dim of 3d box. + rotation_y: [N] angle. + } + all fields must be filled, but some fields can fill + zero. + """ + if "annos" not in self._kitti_infos[0]: + return None + gt_annos = [info["annos"] for info in self._kitti_infos] + return gt_annos + + def evaluation(self, dt_annos): + """dt_annos have same format as ground_truth_annotations. + When you want to eval your own dataset, you MUST set correct + the z axis and box z center. + """ + gt_annos = self.ground_truth_annotations + if gt_annos is None: + return None, None + z_axis = 1 # KITTI camera format use y as regular "z" axis. + z_center = 1.0 # KITTI camera box's center is [0.5, 1, 0.5] + # for regular raw lidar data, z_axis = 2, z_center = 0.5. + result_official = get_official_eval_result( + gt_annos, + dt_annos, + self._class_names, + z_axis=z_axis, + z_center=z_center) + result_coco = get_coco_eval_result( + gt_annos, + dt_annos, + self._class_names, + z_axis=z_axis, + z_center=z_center) + return result_official, result_coco def __getitem__(self, idx): - return _read_and_prep_v9( - info=self._kitti_infos[idx], - root_path=self._root_path, - num_point_features=self._num_point_features, - prep_func=self._prep_func) + """ + you need to create a input dict in this function for network inference. + format: { + anchors + voxels + num_points + coordinates + ground_truth: { + gt_boxes + gt_names + [optional]difficulty + [optional]group_ids + } + [optional]anchors_mask, slow in SECOND v1.5, don't use this. + [optional]metadata, in kitti, image index is saved in metadata + } + """ + info = self._kitti_infos[idx] + kitti.convert_to_kitti_info_version2(info) + pc_info = info["point_cloud"] + if "points" not in pc_info: + velo_path = pathlib.Path(pc_info['velodyne_path']) + if not velo_path.is_absolute(): + velo_path = pathlib.Path(self._root_path) / pc_info['velodyne_path'] + velo_reduced_path = velo_path.parent.parent / ( + velo_path.parent.stem + '_reduced') / velo_path.name + if velo_reduced_path.exists(): + velo_path = velo_reduced_path + points = np.fromfile( + str(velo_path), dtype=np.float32, + count=-1).reshape([-1, self._num_point_features]) + input_dict = { + 'points': points, + } + if "image" in info: + input_dict["image"] = info["image"] + if "calib" in info: + calib = info["calib"] + calib_dict = { + 'rect': calib['R0_rect'], + 'Trv2c': calib['Tr_velo_to_cam'], + 'P2': calib['P2'], + } + input_dict["calib"] = calib_dict + if 'annos' in info: + annos = info['annos'] + annos_dict = {} + # we need other objects to avoid collision when sample + annos = kitti.remove_dontcare(annos) + loc = annos["location"] + dims = annos["dimensions"] + rots = annos["rotation_y"] + gt_names = annos["name"] + gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], + axis=1).astype(np.float32) + if "calib" in info: + calib = info["calib"] + gt_boxes = box_np_ops.box_camera_to_lidar( + gt_boxes, calib["R0_rect"], calib["Tr_velo_to_cam"]) + # only center format is allowed. so we need to convert + # kitti [0.5, 0.5, 0] center to [0.5, 0.5, 0.5] + box_np_ops.change_box3d_center_(gt_boxes, [0.5, 0.5, 0], [0.5, 0.5, 0.5]) + gt_dict = { + 'gt_boxes': gt_boxes, + 'gt_names': gt_names, + } + if 'difficulty' in annos: + gt_dict['difficulty'] = annos["difficulty"] + if 'group_ids' in annos: + gt_dict['group_ids'] = annos["group_ids"] + input_dict["ground_truth"] = gt_dict + example = self._prep_func(input_dict=input_dict) + example["metadata"] = {} + if "image" in info: + example["metadata"]["image"] = input_dict["image"] + if "anchors_mask" in example: + example["anchors_mask"] = example["anchors_mask"].astype(np.uint8) + return example + + diff --git a/second/data/kitti_common.py b/second/data/kitti_common.py index eb80db70..eba03083 100644 --- a/second/data/kitti_common.py +++ b/second/data/kitti_common.py @@ -1,3 +1,4 @@ + import concurrent.futures as futures import os import pathlib @@ -121,6 +122,35 @@ def _extend_matrix(mat): return mat +def _check_kitti_directory(root_path): + path = pathlib.Path(root_path) + results = [] + results.append((path / 'training').exists()) + results.append((path / 'testing').exists()) + path_train_image_2 = path / 'training' / 'image_2' + results.append(path_train_image_2.exists()) + results.append(len(path_train_image_2.glob('*.png')) == 7481) + path_train_label_2 = path / 'training' / 'label_2' + results.append(path_train_label_2.exists()) + path_train_lidar = path / 'training' / 'velodyne' + results.append(path_train_lidar.exists()) + path_train_calib = path / 'training' / 'calib' + results.append(path_train_calib.exists()) + results.append(len(path_train_label_2.glob('*.txt')) == 7481) + results.append(len(path_train_lidar.glob('*.bin')) == 7481) + results.append(len(path_train_calib.glob('*.txt')) == 7481) + path_test_image_2 = path / 'testing' / 'image_2' + results.append(path_test_image_2.exists()) + results.append(len(path_test_image_2.glob('*.png')) == 7518) + path_test_lidar = path / 'testing' / 'velodyne' + results.append(path_test_lidar.exists()) + path_test_calib = path / 'testing' / 'calib' + results.append(path_test_calib.exists()) + results.append(len(path_test_lidar.glob('*.bin')) == 7518) + results.append(len(path_test_calib.glob('*.txt')) == 7518) + return np.array(results, dtype=np.bool) + + def get_kitti_image_info(path, training=True, label_info=True, @@ -132,29 +162,63 @@ def get_kitti_image_info(path, relative_path=True, with_imageshape=True): # image_infos = [] + """ + KITTI annotation format version 2: + { + [optional]points: [N, 3+] point cloud + [optional, for kitti]image: { + image_idx: ... + image_path: ... + image_shape: ... + } + point_cloud: { + num_features: 4 + velodyne_path: ... + } + [optional, for kitti]calib: { + R0_rect: ... + Tr_velo_to_cam: ... + P2: ... + } + annos: { + location: [num_gt, 3] array + dimensions: [num_gt, 3] array + rotation_y: [num_gt] angle array + name: [num_gt] ground truth name array + [optional]difficulty: kitti difficulty + [optional]group_ids: used for multi-part object + } + } + """ root_path = pathlib.Path(path) if not isinstance(image_ids, list): image_ids = list(range(image_ids)) def map_func(idx): - image_info = {'image_idx': idx, 'pointcloud_num_features': 4} + info = {} + pc_info = {'num_features': 4} + calib_info = {} + + image_info = {'image_idx': idx} annotations = None if velodyne: - image_info['velodyne_path'] = get_velodyne_path( + pc_info['velodyne_path'] = get_velodyne_path( idx, path, training, relative_path) - image_info['img_path'] = get_image_path(idx, path, training, + image_info['image_path'] = get_image_path(idx, path, training, relative_path) if with_imageshape: - img_path = image_info['img_path'] + img_path = image_info['image_path'] if relative_path: img_path = str(root_path / img_path) - image_info['img_shape'] = np.array( + image_info['image_shape'] = np.array( io.imread(img_path).shape[:2], dtype=np.int32) if label_info: label_path = get_label_path(idx, path, training, relative_path) if relative_path: label_path = str(root_path / label_path) annotations = get_label_anno(label_path) + info["image"] = image_info + info["point_cloud"] = pc_info if calib: calib_path = get_calib_path( idx, path, training, relative_path=False) @@ -177,10 +241,6 @@ def map_func(idx): P1 = _extend_matrix(P1) P2 = _extend_matrix(P2) P3 = _extend_matrix(P3) - image_info['calib/P0'] = P0 - image_info['calib/P1'] = P1 - image_info['calib/P2'] = P2 - image_info['calib/P3'] = P3 R0_rect = np.array([ float(info) for info in lines[4].split(' ')[1:10] ]).reshape([3, 3]) @@ -190,7 +250,7 @@ def map_func(idx): rect_4x4[:3, :3] = R0_rect else: rect_4x4 = R0_rect - image_info['calib/R0_rect'] = rect_4x4 + Tr_velo_to_cam = np.array([ float(info) for info in lines[5].split(' ')[1:13] ]).reshape([3, 4]) @@ -200,17 +260,42 @@ def map_func(idx): if extend_matrix: Tr_velo_to_cam = _extend_matrix(Tr_velo_to_cam) Tr_imu_to_velo = _extend_matrix(Tr_imu_to_velo) - image_info['calib/Tr_velo_to_cam'] = Tr_velo_to_cam - image_info['calib/Tr_imu_to_velo'] = Tr_imu_to_velo + calib_info['P0'] = P0 + calib_info['P1'] = P1 + calib_info['P2'] = P2 + calib_info['P3'] = P3 + calib_info['R0_rect'] = rect_4x4 + calib_info['Tr_velo_to_cam'] = Tr_velo_to_cam + calib_info['Tr_imu_to_velo'] = Tr_imu_to_velo + info["calib"] = calib_info + if annotations is not None: - image_info['annos'] = annotations - add_difficulty_to_annos(image_info) - return image_info + info['annos'] = annotations + add_difficulty_to_annos(info) + return info with futures.ThreadPoolExecutor(num_worker) as executor: image_infos = executor.map(map_func, image_ids) return list(image_infos) +def convert_to_kitti_info_version2(info): + """convert kitti info v1 to v2 if possible. + """ + if "image" not in info or "calib" not in info or "point_cloud" not in info: + info["image"] = { + 'image_shape': info["img_shape"], + 'image_idx': info['image_idx'], + 'image_path': info['img_path'], + } + info["calib"] = { + "R0_rect": info['calib/R0_rect'], + "Tr_velo_to_cam": info['calib/Tr_velo_to_cam'], + "P2": info['calib/P2'], + } + info["point_cloud"] = { + "velodyne_path": info['velodyne_path'], + } + def label_str_to_int(labels, remove_dontcare=True, dtype=np.int32): class_to_label = get_class_to_label_map() diff --git a/second/data/preprocess.py b/second/data/preprocess.py index 65e867f7..39ecd488 100644 --- a/second/data/preprocess.py +++ b/second/data/preprocess.py @@ -2,7 +2,9 @@ import pickle import time from collections import defaultdict +from functools import partial +import cv2 import numpy as np from skimage import io as imgio @@ -10,14 +12,16 @@ from second.core import preprocess as prep from second.core.geometry import points_in_convex_polygon_3d_jit from second.data import kitti_common as kitti +from second.utils import simplevis -def merge_second_batch(batch_list, _unused=False): +def merge_second_batch(batch_list, unlabeled_training=False): example_merged = defaultdict(list) for example in batch_list: for k, v in example.items(): example_merged[k].append(v) ret = {} + voxel_nums_list = example_merged["num_voxels"] example_merged.pop("num_voxels") for key, elems in example_merged.items(): if key in [ @@ -25,22 +29,59 @@ def merge_second_batch(batch_list, _unused=False): 'match_indices' ]: ret[key] = np.concatenate(elems, axis=0) + elif key == 'metadata': + ret[key] = elems + elif key == 'match_indices_num': ret[key] = np.concatenate(elems, axis=0) + elif key == "calib": + ret[key] = {} + for elem in elems: + for k1, v1 in elem.items(): + if k1 not in ret[key]: + ret[key][k1] = [v1] + else: + ret[key][k1].append(v1) + for k1, v1 in ret[key].items(): + ret[key][k1] = np.stack(v1, axis=0) elif key == 'coordinates': coors = [] - for i, coor in enumerate(elems): - coor_pad = np.pad( - coor, ((0, 0), (1, 0)), - mode='constant', - constant_values=i) - coors.append(coor_pad) + if unlabeled_training: + batch_idx = 0 + for i, coor in enumerate(elems): + idx = 0 + for voxel_num in voxel_nums_list[i]: + coor_pad = np.pad( + coor[idx:idx + voxel_num], ((0, 0), (1, 0)), + mode='constant', + constant_values=batch_idx) + coors.append(coor_pad) + idx += voxel_num + batch_idx += 1 + else: + for i, coor in enumerate(elems): + coor_pad = np.pad( + coor, ((0, 0), (1, 0)), + mode='constant', + constant_values=i) + coors.append(coor_pad) ret[key] = np.concatenate(coors, axis=0) else: - ret[key] = np.stack(elems, axis=0) + if unlabeled_training: + ret[key] = np.concatenate(elems, axis=0) + else: + ret[key] = np.stack(elems, axis=0) return ret +def _dict_select(dict_, inds): + for k, v in dict_.items(): + if isinstance(v, dict): + _dict_select(v, inds) + else: + dict_[k] = v[inds] + + def prep_pointcloud(input_dict, root_path, voxel_generator, @@ -71,121 +112,151 @@ def prep_pointcloud(input_dict, random_crop=False, reference_detections=None, add_rgb_to_points=False, - lidar_input=False, unlabeled_db_sampler=None, - out_size_factor=2, + downsample_factor=2, min_gt_point_dict=None, bev_only=False, use_group_id=False, out_dtype=np.float32): """convert point cloud to voxels, create targets if ground truths exists. + + input_dict format: + { + points: [N, 3+] + ground_truth: { + gt_boxes: [num_gt, 7], must in lidar coord, must be center format + gt_names: [num_gt], must be np.ndarray, dtype=np.str + [optional]difficulty: [num_gt] + [optional]group_ids: [num_gt] + } + [optional, for kitti]image: { + image_shape: ... + image_idx: ... + image_path: ... + } + [optional, for kitti]calib: { + rect: ... + Trv2c: ... + P2: ... + } + } + """ + # t = time.time() points = input_dict["points"] if training: - gt_boxes = input_dict["gt_boxes"] - gt_names = input_dict["gt_names"] - difficulty = input_dict["difficulty"] - group_ids = None - if use_group_id and "group_ids" in input_dict: - group_ids = input_dict["group_ids"] - rect = input_dict["rect"] - Trv2c = input_dict["Trv2c"] - P2 = input_dict["P2"] - unlabeled_training = unlabeled_db_sampler is not None - image_idx = input_dict["image_idx"] + ground_truth_dict = input_dict["ground_truth"] + gt_dict = { + "gt_boxes": ground_truth_dict["gt_boxes"], + "gt_names": ground_truth_dict["gt_names"] + } + if "difficulty" not in ground_truth_dict: + difficulty = np.zeros([ground_truth_dict["gt_boxes"].shape[0]], + dtype=np.int32) + gt_dict["difficulty"] = difficulty + if use_group_id and "group_ids" in ground_truth_dict: + group_ids = ground_truth_dict["group_ids"] + gt_dict["group_ids"] = group_ids + calib = None + if "calib" in input_dict: + calib = input_dict["calib"] + if add_rgb_to_points: + assert calib is not None and "image" in input_dict + image_path = input_dict["image"]["image_path"] + image = imgio.imread(str(pathlib.Path(root_path) / image_path)).astype( + np.float32) / 255 + points_rgb = box_np_ops.add_rgb_to_points(points, image, calib["rect"], + calib["Trv2c"], calib["P2"]) + points = np.concatenate([points, points_rgb], axis=1) + num_point_features += 3 if reference_detections is not None: - C, R, T = box_np_ops.projection_matrix_to_CRT_kitti(P2) + assert calib is not None and "image" in input_dict + C, R, T = box_np_ops.projection_matrix_to_CRT_kitti(calib["P2"]) frustums = box_np_ops.get_frustum_v2(reference_detections, C) frustums -= T - # frustums = np.linalg.inv(R) @ frustums.T frustums = np.einsum('ij, akj->aki', np.linalg.inv(R), frustums) - frustums = box_np_ops.camera_to_lidar(frustums, rect, Trv2c) + frustums = box_np_ops.camera_to_lidar(frustums, calib["rect"], + calib["Trv2c"]) surfaces = box_np_ops.corner_to_surfaces_3d_jit(frustums) masks = points_in_convex_polygon_3d_jit(points, surfaces) points = points[masks.any(-1)] - if remove_outside_points and not lidar_input: - image_shape = input_dict["image_shape"] - points = box_np_ops.remove_outside_points(points, rect, Trv2c, P2, - image_shape) + if remove_outside_points: + assert calib is not None + image_shape = input_dict["image"]["image_shape"] + points = box_np_ops.remove_outside_points( + points, calib["rect"], calib["Trv2c"], calib["P2"], image_shape) if remove_environment is True and training: - selected = kitti.keep_arrays_by_name(gt_names, class_names) - gt_boxes = gt_boxes[selected] - gt_names = gt_names[selected] - difficulty = difficulty[selected] - if group_ids is not None: - group_ids = group_ids[selected] - points = prep.remove_points_outside_boxes(points, gt_boxes) - if training: - # print(gt_names) - selected = kitti.drop_arrays_by_name(gt_names, ["DontCare"]) - gt_boxes = gt_boxes[selected] - gt_names = gt_names[selected] - difficulty = difficulty[selected] - if group_ids is not None: - group_ids = group_ids[selected] + selected = kitti.keep_arrays_by_name(gt_dict["gt_names"], class_names) + _dict_select(gt_dict, selected) + masks = box_np_ops.points_in_rbbox(points, gt_dict["gt_boxes"]) + points = points[masks.any(-1)] - gt_boxes = box_np_ops.box_camera_to_lidar(gt_boxes, rect, Trv2c) + if training: + boxes_lidar = gt_dict["gt_boxes"] + bev_map = simplevis.kitti_vis(points, boxes_lidar) + cv2.imshow('pre-noise', bev_map) + selected = kitti.drop_arrays_by_name(gt_dict["gt_names"], ["DontCare"]) + _dict_select(gt_dict, selected) if remove_unknown: - remove_mask = difficulty == -1 + remove_mask = gt_dict["difficulty"] == -1 """ gt_boxes_remove = gt_boxes[remove_mask] gt_boxes_remove[:, 3:6] += 0.25 points = prep.remove_points_in_boxes(points, gt_boxes_remove) """ keep_mask = np.logical_not(remove_mask) - gt_boxes = gt_boxes[keep_mask] - gt_names = gt_names[keep_mask] - difficulty = difficulty[keep_mask] - if group_ids is not None: - group_ids = group_ids[keep_mask] + _dict_select(gt_dict, keep_mask) gt_boxes_mask = np.array( - [n in class_names for n in gt_names], dtype=np.bool_) + [n in class_names for n in gt_dict["gt_names"]], dtype=np.bool_) if db_sampler is not None: + group_ids = None + if "group_ids" in gt_dict: + group_ids = gt_dict["group_ids"] + sampled_dict = db_sampler.sample_all( root_path, - gt_boxes, - gt_names, + gt_dict["gt_boxes"], + gt_dict["gt_names"], num_point_features, random_crop, gt_group_ids=group_ids, - rect=rect, - Trv2c=Trv2c, - P2=P2) + calib=calib) if sampled_dict is not None: sampled_gt_names = sampled_dict["gt_names"] sampled_gt_boxes = sampled_dict["gt_boxes"] sampled_points = sampled_dict["points"] sampled_gt_masks = sampled_dict["gt_masks"] - # gt_names = gt_names[gt_boxes_mask].tolist() - gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0) - # gt_names += [s["name"] for s in sampled] - gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes]) + gt_dict["gt_names"] = np.concatenate( + [gt_dict["gt_names"], sampled_gt_names], axis=0) + gt_dict["gt_boxes"] = np.concatenate( + [gt_dict["gt_boxes"], sampled_gt_boxes]) gt_boxes_mask = np.concatenate( [gt_boxes_mask, sampled_gt_masks], axis=0) if group_ids is not None: sampled_group_ids = sampled_dict["group_ids"] - group_ids = np.concatenate([group_ids, sampled_group_ids]) + gt_dict["group_ids"] = np.concatenate( + [gt_dict["group_ids"], sampled_group_ids]) if remove_points_after_sample: - points = prep.remove_points_in_boxes( - points, sampled_gt_boxes) + masks = box_np_ops.points_in_rbbox(points, + sampled_gt_boxes) + points = points[np.logical_not(masks.any(-1))] points = np.concatenate([sampled_points, points], axis=0) - # unlabeled_mask = np.zeros((gt_boxes.shape[0], ), dtype=np.bool_) - if without_reflectivity: - used_point_axes = list(range(num_point_features)) - used_point_axes.pop(3) - points = points[:, used_point_axes] pc_range = voxel_generator.point_cloud_range if bev_only: # set z and h to limits - gt_boxes[:, 2] = pc_range[2] - gt_boxes[:, 5] = pc_range[5] - pc_range[2] + gt_dict["gt_boxes"][:, 2] = pc_range[2] + gt_dict["gt_boxes"][:, 5] = pc_range[5] - pc_range[2] + group_ids = None + if "group_ids" in gt_dict: + group_ids = gt_dict["group_ids"] + prep.noise_per_object_v3_( - gt_boxes, + gt_dict["gt_boxes"], points, gt_boxes_mask, rotation_perturb=gt_rotation_noise, @@ -193,31 +264,33 @@ def prep_pointcloud(input_dict, global_random_rot_range=global_random_rot_range, group_ids=group_ids, num_try=100) + # should remove unrelated objects after noise per object - gt_boxes = gt_boxes[gt_boxes_mask] - gt_names = gt_names[gt_boxes_mask] - if group_ids is not None: - group_ids = group_ids[gt_boxes_mask] + _dict_select(gt_dict, gt_boxes_mask) gt_classes = np.array( - [class_names.index(n) + 1 for n in gt_names], dtype=np.int32) + [class_names.index(n) + 1 for n in gt_dict["gt_names"]], + dtype=np.int32) + gt_dict["gt_classes"] = gt_classes - gt_boxes, points = prep.random_flip(gt_boxes, points) - gt_boxes, points = prep.global_rotation( - gt_boxes, points, rotation=global_rotation_noise) - gt_boxes, points = prep.global_scaling_v2(gt_boxes, points, - *global_scaling_noise) + gt_dict["gt_boxes"], points = prep.random_flip(gt_dict["gt_boxes"], + points) + gt_dict["gt_boxes"], points = prep.global_rotation( + gt_dict["gt_boxes"], points, rotation=global_rotation_noise) + gt_dict["gt_boxes"], points = prep.global_scaling_v2( + gt_dict["gt_boxes"], points, *global_scaling_noise) bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]] - mask = prep.filter_gt_box_outside_range(gt_boxes, bv_range) - gt_boxes = gt_boxes[mask] - gt_classes = gt_classes[mask] - gt_names = gt_names[mask] - if group_ids is not None: - group_ids = group_ids[mask] + mask = prep.filter_gt_box_outside_range(gt_dict["gt_boxes"], bv_range) + _dict_select(gt_dict, mask) # limit rad to [-pi, pi] - gt_boxes[:, 6] = box_np_ops.limit_period( - gt_boxes[:, 6], offset=0.5, period=2 * np.pi) + gt_dict["gt_boxes"][:, 6] = box_np_ops.limit_period( + gt_dict["gt_boxes"][:, 6], offset=0.5, period=2 * np.pi) + + boxes_lidar = gt_dict["gt_boxes"] + bev_map = simplevis.kitti_vis(points, boxes_lidar) + cv2.imshow('post-noise', bev_map) + cv2.waitKey(0) if shuffle_points: # shuffle is a little slow. @@ -231,41 +304,31 @@ def prep_pointcloud(input_dict, voxels, coordinates, num_points = voxel_generator.generate( points, max_voxels) - example = { 'voxels': voxels, 'num_points': num_points, 'coordinates': coordinates, "num_voxels": np.array([voxels.shape[0]], dtype=np.int64) } - example.update({ - 'rect': rect, - 'Trv2c': Trv2c, - 'P2': P2, - }) - # if not lidar_input: - feature_map_size = grid_size[:2] // out_size_factor + if calib is not None: + example["calib"] = calib + feature_map_size = grid_size[:2] // downsample_factor feature_map_size = [*feature_map_size, 1][::-1] if anchor_cache is not None: anchors = anchor_cache["anchors"] anchors_bv = anchor_cache["anchors_bv"] - matched_thresholds = anchor_cache["matched_thresholds"] - unmatched_thresholds = anchor_cache["unmatched_thresholds"] anchors_dict = anchor_cache["anchors_dict"] else: ret = target_assigner.generate_anchors(feature_map_size) anchors = ret["anchors"] anchors = anchors.reshape([-1, 7]) - matched_thresholds = ret["matched_thresholds"] - unmatched_thresholds = ret["unmatched_thresholds"] anchors_dict = target_assigner.generate_anchors_dict(feature_map_size) anchors_bv = box_np_ops.rbbox2d_to_near_bbox( anchors[:, [0, 1, 3, 4, 6]]) example["anchors"] = anchors - # print("debug", anchors.shape, matched_thresholds.shape) - # anchors_bv = anchors_bv.reshape([-1, 4]) anchors_mask = None if anchor_area_threshold >= 0: + # slow with high resolution. recommend disable this forever. coors = coordinates dense_voxel_map = box_np_ops.sparse_sum_for_anchors_mask( coors, tuple(grid_size[::-1][1:])) @@ -278,73 +341,26 @@ def prep_pointcloud(input_dict, example['anchors_mask'] = anchors_mask if not training: return example + # voxel_labels = box_np_ops.assign_label_to_voxel(gt_boxes, coordinates, + # voxel_size, coors_range) + """ + example.update({ + 'gt_boxes': gt_boxes.astype(out_dtype), + 'num_gt': np.array([gt_boxes.shape[0]]), + # 'voxel_labels': voxel_labels, + }) + """ if create_targets: targets_dict = target_assigner.assign_v2( anchors_dict, - gt_boxes, + gt_dict["gt_boxes"], anchors_mask, - gt_classes=gt_classes, - gt_names=gt_names) + gt_classes=gt_dict["gt_classes"], + gt_names=gt_dict["gt_names"]) + example.update({ 'labels': targets_dict['labels'], 'reg_targets': targets_dict['bbox_targets'], 'reg_weights': targets_dict['bbox_outside_weights'], }) return example - - -def _read_and_prep_v9(info, root_path, num_point_features, prep_func): - """read data from KITTI-format infos, then call prep function. - """ - # velodyne_path = str(pathlib.Path(root_path) / info['velodyne_path']) - # velodyne_path += '_reduced' - v_path = pathlib.Path(root_path) / info['velodyne_path'] - v_path = v_path.parent.parent / ( - v_path.parent.stem + "_reduced") / v_path.name - - points = np.fromfile( - str(v_path), dtype=np.float32, - count=-1).reshape([-1, num_point_features]) - image_idx = info['image_idx'] - rect = info['calib/R0_rect'].astype(np.float32) - Trv2c = info['calib/Tr_velo_to_cam'].astype(np.float32) - P2 = info['calib/P2'].astype(np.float32) - - input_dict = { - 'points': points, - 'rect': rect, - 'Trv2c': Trv2c, - 'P2': P2, - 'image_shape': np.array(info["img_shape"], dtype=np.int32), - 'image_idx': image_idx, - 'image_path': info['img_path'], - # 'pointcloud_num_features': num_point_features, - } - - if 'annos' in info: - annos = info['annos'] - # we need other objects to avoid collision when sample - annos = kitti.remove_dontcare(annos) - loc = annos["location"] - dims = annos["dimensions"] - rots = annos["rotation_y"] - gt_names = annos["name"] - # print(gt_names, len(loc)) - gt_boxes = np.concatenate( - [loc, dims, rots[..., np.newaxis]], axis=1).astype(np.float32) - # gt_boxes = box_np_ops.box_camera_to_lidar(gt_boxes, rect, Trv2c) - difficulty = annos["difficulty"] - input_dict.update({ - 'gt_boxes': gt_boxes, - 'gt_names': gt_names, - 'difficulty': difficulty, - }) - if 'group_ids' in annos: - input_dict['group_ids'] = annos["group_ids"] - example = prep_func(input_dict=input_dict) - example["image_idx"] = image_idx - example["image_shape"] = input_dict["image_shape"] - if "anchors_mask" in example: - example["anchors_mask"] = example["anchors_mask"].astype(np.uint8) - return example - diff --git a/second/kittiviewer/backend.py b/second/kittiviewer/backend.py index f02256e7..4b3a1e91 100644 --- a/second/kittiviewer/backend.py +++ b/second/kittiviewer/backend.py @@ -7,17 +7,28 @@ from flask import Flask, jsonify, request from flask_cors import CORS +import fire + +import io as sysio +import json import pickle import sys +import time from functools import partial from pathlib import Path +import datetime +import fire +import matplotlib.pyplot as plt +import numba +import skimage +from shapely.geometry import Polygon +from skimage import io import second.core.box_np_ops as box_np_ops import second.core.preprocess as prep from second.core.box_coders import GroundBox3dCoder from second.core.region_similarity import ( DistanceSimilarity, NearestIouSimilarity, RotateIouSimilarity) -from second.core.sample_ops import DataBaseSamplerV2 from second.core.target_assigner import TargetAssigner from second.data import kitti_common as kitti from second.protos import pipeline_pb2 @@ -73,7 +84,7 @@ def readinfo(): with open(info_path, 'rb') as f: kitti_infos = pickle.load(f) BACKEND.kitti_infos = kitti_infos - BACKEND.image_idxes = [info["image_idx"] for info in kitti_infos] + BACKEND.image_idxes = [info["image"]["image_idx"] for info in kitti_infos] response["image_indexes"] = BACKEND.image_idxes response = jsonify(results=[response]) @@ -112,12 +123,19 @@ def get_pointcloud(): if BACKEND.kitti_infos is None: return error_response("kitti info is not loaded") image_idx = instance["image_idx"] + enable_int16 = instance["enable_int16"] + idx = BACKEND.image_idxes.index(image_idx) kitti_info = BACKEND.kitti_infos[idx] - rect = kitti_info['calib/R0_rect'] - P2 = kitti_info['calib/P2'] - Trv2c = kitti_info['calib/Tr_velo_to_cam'] - img_shape = kitti_info["img_shape"] # hw + pc_info = kitti_info["point_cloud"] + image_info = kitti_info["image"] + calib = kitti_info["calib"] + + rect = calib['R0_rect'] + Trv2c = calib['Tr_velo_to_cam'] + P2 = calib['P2'] + + img_shape = image_info["image_shape"] # hw wh = np.array(img_shape[::-1]) whwh = np.tile(wh, 2) if 'annos' in kitti_info: @@ -143,10 +161,22 @@ def get_pointcloud(): response["bbox"] = bbox.tolist() response["labels"] = labels[:num_obj].tolist() - - v_path = str(Path(BACKEND.root_path) / kitti_info['velodyne_path']) - with open(v_path, 'rb') as f: - pc_str = base64.encodestring(f.read()) + response["num_features"] = pc_info["num_features"] + v_path = str(Path(BACKEND.root_path) / pc_info['velodyne_path']) + points = np.fromfile( + v_path, dtype=np.float32, count=-1).reshape([-1, pc_info["num_features"]]) + if instance['remove_outside']: + if 'image_shape' in image_info: + image_shape = image_info['image_shape'] + points = box_np_ops.remove_outside_points( + points, rect, Trv2c, P2, image_shape) + else: + points = points[points[..., 2] > 0] + if enable_int16: + int16_factor = instance["int16_factor"] + points *= int16_factor + points = points.astype(np.int16) + pc_str = base64.b64encode(points.tobytes()) response["pointcloud"] = pc_str.decode("utf-8") if "with_det" in instance and instance["with_det"]: if BACKEND.dt_annos is None: @@ -178,7 +208,7 @@ def get_pointcloud(): # response["score"] = score.tolist() response = jsonify(results=[response]) response.headers['Access-Control-Allow-Headers'] = '*' - print("send response!") + print("send response with size {}!".format(len(pc_str))) return response @app.route('/api/get_image', methods=['POST']) @@ -193,25 +223,16 @@ def get_image(): image_idx = instance["image_idx"] idx = BACKEND.image_idxes.index(image_idx) kitti_info = BACKEND.kitti_infos[idx] - rect = kitti_info['calib/R0_rect'] - P2 = kitti_info['calib/P2'] - Trv2c = kitti_info['calib/Tr_velo_to_cam'] - if 'img_path' in kitti_info: - img_path = kitti_info['img_path'] - if img_path != "": - image_path = BACKEND.root_path / img_path + image_info = kitti_info["image"] + if 'image_path' in image_info: + image_path = image_info['image_path'] + if image_path != "": + image_path = BACKEND.root_path / image_path print(image_path) with open(str(image_path), 'rb') as f: image_str = f.read() response["image_b64"] = base64.b64encode(image_str).decode("utf-8") response["image_b64"] = 'data:image/{};base64,'.format(image_path.suffix[1:]) + response["image_b64"] - '''# - response["rect"] = rect.tolist() - response["P2"] = P2.tolist() - response["Trv2c"] = Trv2c.tolist() - response["L2CMat"] = ((rect @ Trv2c).T).tolist() - response["C2LMat"] = np.linalg.inv((rect @ Trv2c).T).tolist() - ''' print("send an image with size {}!".format(len(response["image_b64"]))) response = jsonify(results=[response]) response.headers['Access-Control-Allow-Headers'] = '*' @@ -253,24 +274,30 @@ def inference_by_idx(): if BACKEND.inference_ctx is None: return error_response("inference_ctx is not loaded") image_idx = instance["image_idx"] + remove_outside = instance["remove_outside"] idx = BACKEND.image_idxes.index(image_idx) kitti_info = BACKEND.kitti_infos[idx] + pc_info = kitti_info["point_cloud"] + image_info = kitti_info["image"] + calib = kitti_info["calib"] - v_path = str(Path(BACKEND.root_path) / kitti_info['velodyne_path']) + v_path = str(Path(BACKEND.root_path) / pc_info['velodyne_path']) num_features = 4 points = np.fromfile( str(v_path), dtype=np.float32, count=-1).reshape([-1, num_features]) - rect = kitti_info['calib/R0_rect'] - P2 = kitti_info['calib/P2'] - Trv2c = kitti_info['calib/Tr_velo_to_cam'] - if 'img_shape' in kitti_info: - image_shape = kitti_info['img_shape'] + + + rect = calib['R0_rect'] + Trv2c = calib['Tr_velo_to_cam'] + P2 = calib['P2'] + if remove_outside and 'image_shape' in image_info: + image_shape = image_info['image_shape'] points = box_np_ops.remove_outside_points( points, rect, Trv2c, P2, image_shape) print(points.shape[0]) - img_shape = kitti_info["img_shape"] # hw - wh = np.array(img_shape[::-1]) + image_shape = image_info["image_shape"] # hw + wh = np.array(image_shape[::-1]) whwh = np.tile(wh, 2) t = time.time() diff --git a/second/kittiviewer/frontend/index.html b/second/kittiviewer/frontend/index.html index ac1c31ab..37dcfc3a 100644 --- a/second/kittiviewer/frontend/index.html +++ b/second/kittiviewer/frontend/index.html @@ -267,7 +267,8 @@ // pointParicle.position.needsUpdate = true; // required after the first render var gui = new dat.GUI(); var coreParams = { - backgroundcolor: "#000000" + backgroundcolor: "#000000", + useCameraHelper: true, }; var cameraGui = gui.addFolder("core"); cameraGui.add(camera, "fov"); @@ -276,6 +277,16 @@ .onChange(function (value) { renderer.setClearColor(value, 1); }); + /* + cameraGui.add(coreParams, "useCameraHelper") + .onChange(function (value) { + if (value){ + scene.remove(camhelper); + }else{ + scene.add(camhelper); + } + }); + */ cameraGui.open(); var kittiGui = gui.addFolder("kitti controllers"); kittiGui.add(viewer, "backend").onChange(function (value) { @@ -320,6 +331,10 @@ kittiGui.add(viewer, "prev"); */ kittiGui.add(viewer, "inference"); + kittiGui.add(viewer, "enableInt16"); + kittiGui.add(viewer, "int16Factor", 1, 200); + kittiGui.add(viewer, "removeOutside"); + viewer.screenshot = function(){ viewer.saveAsImage(renderer); }; @@ -386,6 +401,7 @@ camera.aspect = window.innerWidth / window.innerHeight; camera.updateProjectionMatrix(); camhelper.update(); + renderer.setSize(window.innerWidth, window.innerHeight); labelRenderer.setSize(window.innerWidth, window.innerHeight); } diff --git a/second/kittiviewer/frontend/js/KittiViewer.js b/second/kittiviewer/frontend/js/KittiViewer.js index b8f3207c..5b2fa09e 100644 --- a/second/kittiviewer/frontend/js/KittiViewer.js +++ b/second/kittiviewer/frontend/js/KittiViewer.js @@ -22,6 +22,9 @@ var KittiViewer = function (pointCloud, logger, imageCanvas) { this.logger = logger; this.imageCanvas = imageCanvas; this.image = ''; + this.enableInt16 = true; + this.int16Factor = 100; + this.removeOutside = true; }; KittiViewer.prototype = { @@ -114,7 +117,7 @@ KittiViewer.prototype = { }, inference: function( ){ let self = this; - let data = {"image_idx": self.imageIndex}; + let data = {"image_idx": self.imageIndex, "remove_outside": self.removeOutside}; return $.ajax({ url: this.addhttp(this.backend) + '/api/inference_by_idx', method: 'POST', @@ -205,6 +208,9 @@ KittiViewer.prototype = { let data = {}; data["image_idx"] = image_idx; data["with_det"] = this.drawDet; + data["enable_int16"] = this.enableInt16; + data["int16_factor"] = this.int16Factor; + data["remove_outside"] = this.removeOutside; let self = this; var ajax1 = $.ajax({ url: this.addhttp(this.backend) + '/api/get_pointcloud', @@ -219,18 +225,23 @@ KittiViewer.prototype = { self.clear(); response = response["results"][0]; var points_buf = str2buffer(atob(response["pointcloud"])); - var points = new Float32Array(points_buf); - if (response.hasOwnProperty("dims")){ - var locs = response["locs"]; - var dims = response["dims"]; - - var rots = response["rots"]; - var labels = response["labels"]; - self.gtBboxes = response["bbox"]; - self.gtBoxes = boxEdgeWithLabel(dims, locs, rots, 2, - self.gtBoxColor, labels, - self.gtLabelColor); + var points; + if (self.enableInt16){ + var points = new Int16Array(points_buf); } + else{ + var points = new Float32Array(points_buf); + } + var numFeatures = response["num_features"]; + var locs = response["locs"]; + var dims = response["dims"]; + + var rots = response["rots"]; + var labels = response["labels"]; + self.gtBboxes = response["bbox"]; + self.gtBoxes = boxEdgeWithLabel(dims, locs, rots, 2, + self.gtBoxColor, labels, + self.gtLabelColor); // var boxes = boxEdge(dims, locs, rots, 2, "rgb(0, 255, 0)"); for (var i = 0; i < self.gtBoxes.length; ++i) { scene.add(self.gtBoxes[i]); @@ -242,6 +253,7 @@ KittiViewer.prototype = { var rots = response["dt_rots"]; var scores = response["dt_scores"]; self.dtBboxes = response["dt_bbox"]; + console.log("draw det", dims.length); let label_with_score = []; for (var i = 0; i < locs.length; ++i) { label_with_score.push("score=" + scores[i].toFixed(2).toString()); @@ -253,17 +265,18 @@ KittiViewer.prototype = { scene.add(self.dtBoxes[i]); } } - for (var i = 0; i < Math.min(points.length / 4, self.maxPoints); i++) { - self.pointCloud.geometry.attributes.position.array[i * 3] = points[ - i * 4]; - self.pointCloud.geometry.attributes.position.array[i * 3 + 1] = - points[i * 4 + - 1]; - self.pointCloud.geometry.attributes.position.array[i * 3 + 2] = - points[i * 4 + - 2]; + for (var i = 0; i < Math.min(points.length / numFeatures, self.maxPoints); i++) { + for (var j = 0; j < numFeatures; ++j){ + self.pointCloud.geometry.attributes.position.array[i * 3 + j] = points[ + i * numFeatures + j]; + } } - self.pointCloud.geometry.setDrawRange(0, Math.min(points.length / 4, + if (self.enableInt16){ + for (var i = 0; i < self.pointCloud.geometry.attributes.position.array.length; i++) { + self.pointCloud.geometry.attributes.position.array[i] /=self.int16Factor; + } + } + self.pointCloud.geometry.setDrawRange(0, Math.min(points.length / numFeatures, self.maxPoints)); self.pointCloud.geometry.attributes.position.needsUpdate = true; self.pointCloud.geometry.computeBoundingSphere(); @@ -296,7 +309,7 @@ KittiViewer.prototype = { console.log("out of range!"); } } - }, + }, drawImage : function(){ if (this.image === ''){ console.log("??????"); @@ -366,4 +379,4 @@ KittiViewer.prototype = { } } -} \ No newline at end of file +} diff --git a/second/pytorch/builder/second_builder.py b/second/pytorch/builder/second_builder.py index 32a40947..afeeb3ef 100644 --- a/second/pytorch/builder/second_builder.py +++ b/second/pytorch/builder/second_builder.py @@ -1,3 +1,17 @@ +# Copyright 2017 yanyan. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== """VoxelNet builder. """ @@ -25,6 +39,7 @@ def build(model_cfg: second_pb2.VoxelNet, voxel_generator, 0: LossNormType.NormByNumExamples, 1: LossNormType.NormByNumPositives, 2: LossNormType.NormByNumPosNeg, + 3: LossNormType.DontNorm, } loss_norm_type = loss_norm_type_dict[model_cfg.loss_norm_type] @@ -81,5 +96,6 @@ def build(model_cfg: second_pb2.VoxelNet, voxel_generator, cls_loss_ftor=cls_loss_ftor, target_assigner=target_assigner, measure_time=measure_time, + voxel_generator=voxel_generator, ) return net diff --git a/second/pytorch/core/box_torch_ops.py b/second/pytorch/core/box_torch_ops.py index 5c6d5b6c..97e4319a 100644 --- a/second/pytorch/core/box_torch_ops.py +++ b/second/pytorch/core/box_torch_ops.py @@ -20,8 +20,6 @@ def second_box_encode(boxes, anchors, encode_angle_to_vector=False, smooth_dim=F """ xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1) xg, yg, zg, wg, lg, hg, rg = torch.split(boxes, 1, dim=-1) - za = za + ha / 2 - zg = zg + hg / 2 diagonal = torch.sqrt(la**2 + wa**2) xt = (xg - xa) / diagonal yt = (yg - ya) / diagonal @@ -65,7 +63,6 @@ def second_box_decode(box_encodings, anchors, encode_angle_to_vector=False, smoo xt, yt, zt, wt, lt, ht, rt = torch.split(box_encodings, 1, dim=-1) # xt, yt, zt, wt, lt, ht, rt = torch.split(box_encodings, 1, dim=-1) - za = za + ha / 2 diagonal = torch.sqrt(la**2 + wa**2) xg = xt * diagonal + xa yg = yt * diagonal + ya @@ -87,7 +84,6 @@ def second_box_decode(box_encodings, anchors, encode_angle_to_vector=False, smoo rg = torch.atan2(rgy, rgx) else: rg = rt + ra - zg = zg - hg / 2 return torch.cat([xg, yg, zg, wg, lg, hg, rg], dim=-1) def bev_box_encode(boxes, anchors, encode_angle_to_vector=False, smooth_dim=False): @@ -301,7 +297,7 @@ def rotation_2d(points, angles): def center_to_corner_box3d(centers, dims, angles, - origin=[0.5, 1.0, 0.5], + origin=0.5, axis=1): """convert kitti locations, dimensions and angles to corners diff --git a/second/pytorch/inference.py b/second/pytorch/inference.py index 1fb51b65..2d6774a3 100644 --- a/second/pytorch/inference.py +++ b/second/pytorch/inference.py @@ -9,7 +9,7 @@ from second.builder import target_assigner_builder, voxel_builder from second.pytorch.builder import box_coder_builder, second_builder from second.pytorch.models.voxelnet import VoxelNet -from second.pytorch.train import predict_kitti_to_anno, example_convert_to_torch +from second.pytorch.train import predict_to_kitti_label, example_convert_to_torch class TorchInferenceContext(InferenceContext): @@ -38,7 +38,6 @@ def _build(self): out_size_factor = model_cfg.rpn.layer_strides[0] / model_cfg.rpn.upsample_strides[0] out_size_factor *= model_cfg.middle_feature_extractor.downsample_factor out_size_factor = int(out_size_factor) - assert out_size_factor > 0 self.net = second_builder.build(model_cfg, voxel_generator, target_assigner) self.net.cuda().eval() @@ -56,13 +55,14 @@ def _build(self): unmatched_thresholds = ret["unmatched_thresholds"] anchors_bv = box_np_ops.rbbox2d_to_near_bbox( anchors[:, [0, 1, 3, 4, 6]]) - self.anchor_cache = { + anchor_cache = { "anchors": anchors, "anchors_bv": anchors_bv, "matched_thresholds": matched_thresholds, "unmatched_thresholds": unmatched_thresholds, "anchors_dict": anchors_dict, } + self.anchor_cache = anchor_cache def _restore(self, ckpt_path): ckpt_path = Path(ckpt_path) @@ -73,12 +73,8 @@ def _inference(self, example): train_cfg = self.config.train_config input_cfg = self.config.eval_input_reader model_cfg = self.config.model.second - if train_cfg.enable_mixed_precision: - float_dtype = torch.half - else: - float_dtype = torch.float32 - example_torch = example_convert_to_torch(example, float_dtype) - result_annos = predict_kitti_to_anno( + example_torch = example_convert_to_torch(example) + result_annos = predict_to_kitti_label( self.net, example_torch, list( self.target_assigner.classes), model_cfg.post_center_limit_range, model_cfg.lidar_input) diff --git a/second/pytorch/models/rpn.py b/second/pytorch/models/rpn.py index f4190cf4..2cceab27 100644 --- a/second/pytorch/models/rpn.py +++ b/second/pytorch/models/rpn.py @@ -2,6 +2,7 @@ from enum import Enum import numpy as np +import sparseconvnet as scn import torch from torch import nn from torch.nn import functional as F @@ -9,17 +10,18 @@ import torchplus from torchplus.nn import Empty, GroupNorm, Sequential from torchplus.tools import change_default_args +from torchvision.models import resnet class RPN(nn.Module): def __init__(self, use_norm=True, num_class=2, - layer_nums=[3, 5, 5], - layer_strides=[2, 2, 2], - num_filters=[128, 128, 256], - upsample_strides=[1, 2, 4], - num_upsample_filters=[256, 256, 256], + layer_nums=(3, 5, 5), + layer_strides=(2, 2, 2), + num_filters=(128, 128, 256), + upsample_strides=(1, 2, 4), + num_upsample_filters=(256, 256, 256), num_input_features=128, num_anchor_per_loc=2, encode_background_as_zeros=True, @@ -42,8 +44,10 @@ def __init__(self, assert len(num_upsample_filters) == len(layer_nums) factors = [] for i in range(len(layer_nums)): - assert int(np.prod(layer_strides[:i + 1])) % upsample_strides[i] == 0 - factors.append(np.prod(layer_strides[:i + 1]) // upsample_strides[i]) + assert int(np.prod( + layer_strides[:i + 1])) % upsample_strides[i] == 0 + factors.append( + np.prod(layer_strides[:i + 1]) // upsample_strides[i]) assert all([x == factors[0] for x in factors]) if use_norm: if use_groupnorm: @@ -80,7 +84,8 @@ def __init__(self, self.block1 = Sequential( nn.ZeroPad2d(1), Conv2d( - num_input_features, num_filters[0], 3, stride=layer_strides[0]), + num_input_features, num_filters[0], 3, + stride=layer_strides[0]), BatchNorm2d(num_filters[0]), nn.ReLU(), ) @@ -196,17 +201,24 @@ def forward(self, x, bev=None): return ret_dict -class RPNV2(nn.Module): - """Compare with RPN, RPNV2 support arbitrary number of stage. +class RPNBase(nn.Module): + """Base RPN class. you need to subclass this class and implement + _make_layer function to generate block for RPN. + + Notes: + 1. upsample_strides and num_upsample_filters + if len(upsample_strides) == 2 and len(layer_nums) == 3 (3 downsample stage), + then upsample process start with stage 2. + if len(upsample_strides) == 0, no upsample. """ def __init__(self, use_norm=True, num_class=2, - layer_nums=[3, 5, 5], - layer_strides=[2, 2, 2], - num_filters=[128, 128, 256], - upsample_strides=[1, 2, 4], - num_upsample_filters=[256, 256, 256], + layer_nums=(3, 5, 5), + layer_strides=(2, 2, 2), + num_filters=(128, 128, 256), + upsample_strides=(1, 2, 4), + num_upsample_filters=(256, 256, 256), num_input_features=128, num_anchor_per_loc=2, encode_background_as_zeros=True, @@ -217,23 +229,32 @@ def __init__(self, box_code_size=7, use_rc_net=False, name='rpn'): - super(RPNV2, self).__init__() + super(RPNBase, self).__init__() self._num_anchor_per_loc = num_anchor_per_loc self._use_direction_classifier = use_direction_classifier self._use_bev = use_bev self._use_rc_net = use_rc_net - # assert len(layer_nums) == 3 + + self._layer_strides = layer_strides + self._num_filters = num_filters + self._layer_nums = layer_nums + self._upsample_strides = upsample_strides + self._num_upsample_filters = num_upsample_filters + self._num_input_features = num_input_features + self._use_norm = use_norm + self._use_groupnorm = use_groupnorm + self._num_groups = num_groups assert len(layer_strides) == len(layer_nums) assert len(num_filters) == len(layer_nums) - assert len(upsample_strides) == len(layer_nums) - assert len(num_upsample_filters) == len(layer_nums) - """ - factors = [] - for i in range(len(layer_nums)): - assert int(np.prod(layer_strides[:i + 1])) % upsample_strides[i] == 0 - factors.append(np.prod(layer_strides[:i + 1]) // upsample_strides[i]) - assert all([x == factors[0] for x in factors]) - """ + assert len(num_upsample_filters) == len(upsample_strides) + self._upsample_start_idx = len(layer_nums) - len(upsample_strides) + must_equal_list = [] + for i in range(len(upsample_strides)): + must_equal_list.append(upsample_strides[i] / + layer_strides[i + self._upsample_start_idx]) + for val in must_equal_list: + assert val == must_equal_list[0] + if use_norm: if use_groupnorm: BatchNorm2d = change_default_args( @@ -251,64 +272,64 @@ def __init__(self, nn.ConvTranspose2d) in_filters = [num_input_features, *num_filters[:-1]] - # note that when stride > 1, conv2d with same padding isn't - # equal to pad-conv2d. we should use pad-conv2d. blocks = [] deblocks = [] - + for i, layer_num in enumerate(layer_nums): - block = Sequential( - nn.ZeroPad2d(1), - Conv2d( - in_filters[i], num_filters[i], 3, stride=layer_strides[i]), - BatchNorm2d(num_filters[i]), - nn.ReLU(), - ) - for j in range(layer_num): - block.add( - Conv2d(num_filters[i], num_filters[i], 3, padding=1)) - block.add(BatchNorm2d(num_filters[i])) - block.add(nn.ReLU()) + block, num_out_filters = self._make_layer( + in_filters[i], + num_filters[i], + layer_num, + stride=layer_strides[i]) blocks.append(block) - deblock = Sequential( - ConvTranspose2d( - num_filters[i], - num_upsample_filters[i], - upsample_strides[i], - stride=upsample_strides[i]), - BatchNorm2d(num_upsample_filters[i]), - nn.ReLU(), - ) - deblocks.append(deblock) + if i - self._upsample_start_idx >= 0: + deblock = nn.Sequential( + ConvTranspose2d( + num_out_filters, + num_upsample_filters[i - self._upsample_start_idx], + upsample_strides[i - self._upsample_start_idx], + stride=upsample_strides[i - self._upsample_start_idx]), + BatchNorm2d( + num_upsample_filters[i - self._upsample_start_idx]), + nn.ReLU(), + ) + deblocks.append(deblock) self.blocks = nn.ModuleList(blocks) self.deblocks = nn.ModuleList(deblocks) + if encode_background_as_zeros: num_cls = num_anchor_per_loc * num_class else: num_cls = num_anchor_per_loc * (num_class + 1) - self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1) - self.conv_box = nn.Conv2d( - sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1) + if len(num_upsample_filters) == 0: + final_num_filters = num_out_filters + else: + final_num_filters = sum(num_upsample_filters) + self.conv_cls = nn.Conv2d(final_num_filters, num_cls, 1) + self.conv_box = nn.Conv2d(final_num_filters, + num_anchor_per_loc * box_code_size, 1) if use_direction_classifier: - self.conv_dir_cls = nn.Conv2d( - sum(num_upsample_filters), num_anchor_per_loc * 2, 1) + self.conv_dir_cls = nn.Conv2d(final_num_filters, + num_anchor_per_loc * 2, 1) - if self._use_rc_net: - self.conv_rc = nn.Conv2d( - sum(num_upsample_filters), num_anchor_per_loc * box_code_size, - 1) + @property + def downsample_factor(self): + factor = np.prod(self._layer_strides) + if len(self._upsample_strides) > 0: + factor /= self._upsample_strides[-1] + return factor + + def _make_layer(self, inplanes, planes, num_blocks, stride=1): + raise NotImplementedError def forward(self, x, bev=None): - # t = time.time() - # torch.cuda.synchronize() ups = [] for i in range(len(self.blocks)): x = self.blocks[i](x) - ups.append(self.deblocks[i](x)) - if len(ups) > 1: + if i - self._upsample_start_idx >= 0: + ups.append(self.deblocks[i - self._upsample_start_idx](x)) + if len(ups) > 0: x = torch.cat(ups, dim=1) - else: - x = ups[0] box_preds = self.conv_box(x) cls_preds = self.conv_cls(x) # [N, C, y(H), x(W)] @@ -322,210 +343,81 @@ def forward(self, x, bev=None): dir_cls_preds = self.conv_dir_cls(x) dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous() ret_dict["dir_cls_preds"] = dir_cls_preds - if self._use_rc_net: - rc_preds = self.conv_rc(x) - rc_preds = rc_preds.permute(0, 2, 3, 1).contiguous() - ret_dict["rc_preds"] = rc_preds - # torch.cuda.synchronize() - # print("rpn forward time", time.time() - t) return ret_dict +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) -class Squeeze(nn.Module): - def forward(self, x): - return x.squeeze(2) -class SparseRPN(nn.Module): - """Don't use this. - """ - def __init__(self, - output_shape, - num_input_features=128, - num_filters_down1=[64], - num_filters_down2=[64, 64], - use_norm=True, - num_class=2, - layer_nums=[3, 5, 5], - layer_strides=[2, 2, 2], - num_filters=[128, 128, 256], - upsample_strides=[1, 2, 4], - num_upsample_filters=[256, 256, 256], - num_anchor_per_loc=2, - encode_background_as_zeros=True, - use_direction_classifier=True, - use_groupnorm=False, - num_groups=32, - use_bev=False, - box_code_size=7, - use_rc_net=False, - name='sparse_rpn'): - super(SparseRPN, self).__init__() - self._num_anchor_per_loc = num_anchor_per_loc - self._use_direction_classifier = use_direction_classifier - self.name = name - if use_norm: - BatchNorm2d = change_default_args( - eps=1e-3, momentum=0.01)(nn.BatchNorm2d) - BatchNorm1d = change_default_args( - eps=1e-3, momentum=0.01)(nn.BatchNorm1d) +class ResNetRPN(RPNBase): + def __init__(self, *args, **kw): + self.inplanes = -1 + super(ResNetRPN, self).__init__(*args, **kw) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # if zero_init_residual: + for m in self.modules(): + if isinstance(m, resnet.Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, resnet.BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + + def _make_layer(self, inplanes, planes, num_blocks, stride=1): + if self.inplanes == -1: + self.inplanes = self._num_input_features + block = resnet.BasicBlock # Bottleneck is bad for this project, need tune? + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, + stride), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers), self.inplanes + + +class RPNV2(RPNBase): + def _make_layer(self, inplanes, planes, num_blocks, stride=1): + if self._use_norm: + if self._use_groupnorm: + BatchNorm2d = change_default_args( + num_groups=self._num_groups, eps=1e-3)(GroupNorm) + else: + BatchNorm2d = change_default_args( + eps=1e-3, momentum=0.01)(nn.BatchNorm2d) Conv2d = change_default_args(bias=False)(nn.Conv2d) - SpConv3d = change_default_args(bias=False)(spconv.SparseConv3d) - SubMConv3d = change_default_args(bias=False)(spconv.SubMConv3d) ConvTranspose2d = change_default_args(bias=False)( nn.ConvTranspose2d) else: BatchNorm2d = Empty - BatchNorm1d = Empty Conv2d = change_default_args(bias=True)(nn.Conv2d) - SpConv3d = change_default_args(bias=True)(spconv.SparseConv3d) - SubMConv3d = change_default_args(bias=True)(spconv.SubMConv3d) ConvTranspose2d = change_default_args(bias=True)( nn.ConvTranspose2d) - sparse_shape = np.array(output_shape[1:4]) + [1, 0, 0] - # sparse_shape[0] = 11 - print(sparse_shape) - self.sparse_shape = sparse_shape - self.voxel_output_shape = output_shape - # [11, 400, 352] - self.block1 = spconv.SparseSequential( - SpConv3d( - num_input_features, num_filters[0], 3, stride=[2, layer_strides[0], layer_strides[0]], padding=[0, 1, 1]), - BatchNorm1d(num_filters[0]), - nn.ReLU()) - # [5, 200, 176] - for i in range(layer_nums[0]): - self.block1.add(SubMConv3d( - num_filters[0], num_filters[0], 3, padding=1, indice_key="subm0")) - self.block1.add(BatchNorm1d(num_filters[0])) - self.block1.add(nn.ReLU()) - - self.deconv1 = spconv.SparseSequential( - SpConv3d( - num_filters[0], num_filters[0], (3, 1, 1), stride=(2, 1, 1)), - BatchNorm1d(num_filters[0]), - nn.ReLU(), - SpConv3d( - num_filters[0], num_upsample_filters[0], (2, 1, 1), stride=1), - BatchNorm1d(num_upsample_filters[0]), - nn.ReLU(), - spconv.ToDense(), - Squeeze() - ) # [1, 200, 176] - - # [5, 200, 176] - self.block2 = spconv.SparseSequential( - SpConv3d( - num_filters[0], num_filters[1], 3, stride=[2, layer_strides[1], layer_strides[1]], padding=[0, 1, 1]), - BatchNorm1d(num_filters[1]), - nn.ReLU()) - - for i in range(layer_nums[1]): - self.block2.add(SubMConv3d( - num_filters[1], num_filters[1], 3, padding=1, indice_key="subm1")) - self.block2.add(BatchNorm1d(num_filters[1])) - self.block2.add(nn.ReLU()) - # [2, 100, 88] - self.deconv2 = spconv.SparseSequential( - SpConv3d( - num_filters[1], num_filters[1], (2, 1, 1), stride=1), - BatchNorm1d(num_filters[1]), - nn.ReLU(), - spconv.ToDense(), - Squeeze(), - ConvTranspose2d( - num_filters[1], - num_upsample_filters[1], - upsample_strides[1], - stride=upsample_strides[1]), - BatchNorm2d(num_upsample_filters[1]), - nn.ReLU() - ) # [1, 200, 176] - - self.block3 = spconv.SparseSequential( - SpConv3d( - num_filters[1], num_filters[2], [2, 3, 3], stride=[1, layer_strides[2], layer_strides[2]], padding=[0, 1, 1]), - BatchNorm1d(num_filters[2]), - nn.ReLU()) - for i in range(layer_nums[2]): - self.block3.add(SubMConv3d( - num_filters[2], num_filters[2], 3, padding=1, indice_key="subm2")) - self.block3.add(BatchNorm1d(num_filters[2])) - self.block3.add(nn.ReLU()) - - - self.deconv3 = Sequential( - spconv.ToDense(), - Squeeze(), - ConvTranspose2d( - num_filters[2], - num_upsample_filters[2], - upsample_strides[2], - stride=upsample_strides[2]), - BatchNorm2d(num_upsample_filters[2]), - nn.ReLU(), - ) # [1, 200, 176] - self.post = Sequential( - Conv2d( - sum(num_upsample_filters), - 128, - 3, - stride=1,padding=1), - BatchNorm2d(128), - nn.ReLU(), - Conv2d( - 128, - 64, - 3, - stride=1,padding=1), - BatchNorm2d(64), + block = Sequential( + nn.ZeroPad2d(1), + Conv2d(inplanes, planes, 3, stride=stride), + BatchNorm2d(planes), nn.ReLU(), + ) + for j in range(num_blocks): + block.add(Conv2d(planes, planes, 3, padding=1)) + block.add(BatchNorm2d(planes)) + block.add(nn.ReLU()) - ) # [1, 200, 176] - if encode_background_as_zeros: - num_cls = num_anchor_per_loc * num_class - else: - num_cls = num_anchor_per_loc * (num_class + 1) - '''self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1) - self.conv_box = nn.Conv2d( - sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1) - if use_direction_classifier: - self.conv_dir_cls = nn.Conv2d( - sum(num_upsample_filters), num_anchor_per_loc * 2, 1) - ''' - self.conv_cls = nn.Conv2d(64, num_cls, 1) - self.conv_box = nn.Conv2d( - 64, num_anchor_per_loc * box_code_size, 1) - if use_direction_classifier: - self.conv_dir_cls = nn.Conv2d( - 64, num_anchor_per_loc * 2, 1) - - - def forward(self, voxel_features, coors, batch_size): - coors = coors.int() - sx = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size) - b1 = self.block1(sx) - b2 = self.block2(b1) - b3 = self.block3(b2) - # print(b1.sparity, b2.sparity, b3.sparity) - up1 = self.deconv1(b1) - up2 = self.deconv2(b2) - up3 = self.deconv3(b3) - x = torch.cat([up1, up2, up3], dim=1) - x = self.post(x) - # out = self.to_dense(out).squeeze(2) - # print("debug1") - box_preds = self.conv_box(x) - cls_preds = self.conv_cls(x) - box_preds = box_preds.permute(0, 2, 3, 1).contiguous() - cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous() - ret_dict = { - "box_preds": box_preds, - "cls_preds": cls_preds, - } - if self._use_direction_classifier: - dir_cls_preds = self.conv_dir_cls(x) - dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous() - ret_dict["dir_cls_preds"] = dir_cls_preds - - return ret_dict + return block, planes diff --git a/second/pytorch/models/voxelnet.py b/second/pytorch/models/voxelnet.py index b4ec7f58..470d1348 100644 --- a/second/pytorch/models/voxelnet.py +++ b/second/pytorch/models/voxelnet.py @@ -3,6 +3,7 @@ from functools import reduce import numpy as np +import sparseconvnet as scn import torch from torch import nn from torch.nn import functional as F @@ -40,7 +41,7 @@ class LossNormType(Enum): NormByNumPositives = "norm_by_num_positives" NormByNumExamples = "norm_by_num_examples" NormByNumPosNeg = "norm_by_num_pos_neg" - + DontNorm = "dont_norm" class VoxelNet(nn.Module): def __init__(self, @@ -89,6 +90,7 @@ def __init__(self, loc_loss_ftor=None, cls_loss_ftor=None, measure_time=False, + voxel_generator=None, name='voxelnet'): super().__init__() self.name = name @@ -106,13 +108,18 @@ def __init__(self, self._use_bev = use_bev self._num_input_features = num_input_features self._box_coder = target_assigner.box_coder + self._use_rc_net = use_rc_net self._lidar_only = lidar_only self.target_assigner = target_assigner + self.voxel_generator = voxel_generator self._pos_cls_weight = pos_cls_weight self._neg_cls_weight = neg_cls_weight self._encode_rad_error_by_sin = encode_rad_error_by_sin self._loss_norm_type = loss_norm_type self._dir_loss_ftor = WeightedSoftmaxClassificationLoss() + self._vox_loss_ftor = WeightedSoftmaxClassificationLoss() + self._rc_loss_ftor = WeightedSigmoidClassificationLoss() + self._diff_loc_loss_ftor = WeightedSmoothL1LocalizationLoss() self._loc_loss_ftor = loc_loss_ftor self._cls_loss_ftor = cls_loss_ftor @@ -140,7 +147,7 @@ def __init__(self, else: num_rpn_input_filters = middle_num_filters_d2[-1] - if use_sparse_rpn: # don't use this. just for fun. + if use_sparse_rpn: self.sparse_rpn = rpn.SparseRPN( output_shape, # num_input_features=vfe_num_filters[-1], @@ -184,6 +191,7 @@ def __init__(self, rpn_class_dict = { "RPN": rpn.RPN, "RPNV2": rpn.RPNV2, + "ResNetRPN": rpn.ResNetRPN, } rpn_class = rpn_class_dict[rpn_class_name] self.rpn = rpn_class( @@ -201,7 +209,15 @@ def __init__(self, use_bev=use_bev, use_groupnorm=use_groupnorm, num_groups=num_groups, - box_code_size=target_assigner.box_coder.code_size) + box_code_size=target_assigner.box_coder.code_size, + use_rc_net=use_rc_net) + if use_voxel_classifier: + self.voxel_classifier = SegmentCNN(output_shape, use_norm) + self.vox_acc = metrics.Accuracy(dim=-1) + self.vox_precision = metrics.Precision(dim=-1) + self.vox_recall = metrics.Recall(dim=-1) + else: + self.voxel_classifier = None self.rpn_acc = metrics.Accuracy( dim=-1, encode_background_as_zeros=encode_background_as_zeros) @@ -272,12 +288,9 @@ def forward(self, example): # features: [num_voxels, max_num_points_per_voxel, 7] # num_points: [num_voxels] # coors: [num_voxels, 4] - # t = time.time() self.start_timer("voxel_feature_extractor") voxel_features = self.voxel_feature_extractor(voxels, num_points) self.end_timer("voxel_feature_extractor") - # torch.cuda.synchronize() - # print("vfe time", time.time() - t) if self._use_sparse_rpn: preds_dict = self.sparse_rpn(voxel_features, coors, batch_size_dev) @@ -292,6 +305,9 @@ def forward(self, example): else: preds_dict = self.rpn(spatial_features) self.end_timer("rpn forward") + if self.voxel_classifier is not None: + preds_dict["voxel_logits"] = self.voxel_classifier( + voxel_features, coors, batch_size_dev) # preds_dict["voxel_features"] = voxel_features # preds_dict["spatial_features"] = spatial_features box_preds = preds_dict["box_preds"] @@ -342,6 +358,37 @@ def forward(self, example): dir_logits, dir_targets, weights=weights) dir_loss = dir_loss.sum() / batch_size_dev loss += dir_loss * self._direction_loss_weight + unlabeled_training = False + if unlabeled_training: + diff_loc_loss_weight = 6.0 + match_indices_num = example['match_indices_num'] + if match_indices_num[0] != 0: + match_indices = example['match_indices'] + actual_batch_size = batch_size_dev // 2 + diff_loc_loss = 0 + idx = 0 + box_preds = box_preds.view(batch_size_dev, -1, + self._box_coder.code_size) + for i in range(actual_batch_size): + match_indices_batch = match_indices[ + idx:idx + match_indices_num[i]] + match_indices_batch = match_indices_batch.view(2, -1) + # raise ValueError("test") + lfs = box_preds[2 * i, match_indices_batch[0]] + rfs = box_preds[2 * i + 1, match_indices_batch[1]] + lfs_t = reg_targets[2 * i, match_indices_batch[0]] + rfs_t = reg_targets[2 * i + 1, match_indices_batch[1]] + err = lfs - rfs + err_t = lfs_t - rfs_t + err, err_t = add_sin_difference(err, err_t) + diff_loc_loss += self._diff_loc_loss_ftor( + err.view(1, -1, 7), err_t.view(1, -1, 7)) + idx += match_indices_num[i] + positives = (labels > 0).type_as(diff_loc_loss) + num_pos = torch.clamp(positives.sum() / 2, min=1.0) + diff_loc_loss_reduced = diff_loc_loss.sum( + ) * diff_loc_loss_weight / num_pos + loss += diff_loc_loss_reduced return { "loss": loss, "cls_loss": cls_loss, @@ -356,22 +403,22 @@ def forward(self, example): } else: self.start_timer("predict") - res = self.predict_v2(example, preds_dict) + with torch.no_grad(): + res = self.predict(example, preds_dict) self.end_timer("predict") return res - def predict_v2(self, example, preds_dict): - t = time.time() + def predict(self, example, preds_dict): batch_size = example['anchors'].shape[0] + if "metadata" in example: + meta_list = example["metadata"] + else: + meta_list = [None] * batch_size batch_anchors = example["anchors"].view(batch_size, -1, 7) - batch_rect = example["rect"] - batch_Trv2c = example["Trv2c"] - batch_P2 = example["P2"] if "anchors_mask" not in example: batch_anchors_mask = [None] * batch_size else: batch_anchors_mask = example["anchors_mask"].view(batch_size, -1) - batch_imgidx = example['image_idx'] t = time.time() batch_box_preds = preds_dict["box_preds"] @@ -393,21 +440,17 @@ def predict_v2(self, example, preds_dict): batch_dir_preds = [None] * batch_size predictions_dicts = [] - for box_preds, cls_preds, dir_preds, rect, Trv2c, P2, img_idx, a_mask in zip( - batch_box_preds, batch_cls_preds, batch_dir_preds, batch_rect, - batch_Trv2c, batch_P2, batch_imgidx, batch_anchors_mask): + for box_preds, cls_preds, dir_preds, a_mask, meta in zip( + batch_box_preds, batch_cls_preds, batch_dir_preds, + batch_anchors_mask, meta_list): if a_mask is not None: box_preds = box_preds[a_mask] cls_preds = cls_preds[a_mask] box_preds = box_preds.float() cls_preds = cls_preds.float() - rect = rect.float() - Trv2c = Trv2c.float() - P2 = P2.float() if self._use_direction_classifier: if a_mask is not None: dir_preds = dir_preds[a_mask] - # print(dir_preds.shape) dir_labels = torch.max(dir_preds, dim=-1)[1] if self._encode_background_as_zeros: # this don't support softmax @@ -471,6 +514,7 @@ def predict_v2(self, example, preds_dict): total_scores.shape[0], device=total_scores.device, dtype=torch.long) + else: top_scores, top_labels = torch.max(total_scores, dim=-1) @@ -502,7 +546,6 @@ def predict_v2(self, example, preds_dict): post_max_size=self._nms_post_max_size, iou_threshold=self._nms_iou_threshold, ) - else: selected = [] # if selected is not None: @@ -526,41 +569,53 @@ def predict_v2(self, example, preds_dict): final_box_preds = box_preds final_scores = scores final_labels = label_preds - final_box_preds_camera = box_torch_ops.box_lidar_to_camera( - final_box_preds, rect, Trv2c) - locs = final_box_preds_camera[:, :3] - dims = final_box_preds_camera[:, 3:6] - angles = final_box_preds_camera[:, 6] - camera_box_origin = [0.5, 1.0, 0.5] - box_corners = box_torch_ops.center_to_corner_box3d( - locs, dims, angles, camera_box_origin, axis=1) - box_corners_in_image = box_torch_ops.project_to_image( - box_corners, P2) - # box_corners_in_image: [N, 8, 2] - minxy = torch.min(box_corners_in_image, dim=1)[0] - maxxy = torch.max(box_corners_in_image, dim=1)[0] - box_2d_preds = torch.cat([minxy, maxxy], dim=1) # predictions predictions_dict = { - "bbox": box_2d_preds, - "box3d_camera": final_box_preds_camera, "box3d_lidar": final_box_preds, "scores": final_scores, "label_preds": label_preds, - "image_idx": img_idx, + "metadata": meta, } else: dtype = batch_box_preds.dtype device = batch_box_preds.device predictions_dict = { - "bbox": torch.zeros([0, 4], dtype=dtype, device=device), - "box3d_camera": torch.zeros([0, 7], dtype=dtype, device=device), "box3d_lidar": torch.zeros([0, 7], dtype=dtype, device=device), "scores": torch.zeros([0], dtype=dtype, device=device), "label_preds": torch.zeros([0, 4], dtype=top_labels.dtype, device=device), - "image_idx": img_idx, + "metadata": meta, } predictions_dicts.append(predictions_dict) + if "calib" in example: + batch_rect = example["calib"]["rect"] + batch_Trv2c = example["calib"]["Trv2c"] + batch_P2 = example["calib"]["P2"] + for pred_dict, rect, Trv2c, P2 in zip(predictions_dicts, + batch_rect, batch_Trv2c, batch_P2): + final_box_preds = pred_dict["box3d_lidar"] + if final_box_preds.shape[0] != 0: + final_box_preds[:, 2] -= final_box_preds[:, 5] / 2 + final_box_preds_camera = box_torch_ops.box_lidar_to_camera( + final_box_preds, rect, Trv2c) + + locs = final_box_preds_camera[:, :3] + dims = final_box_preds_camera[:, 3:6] + angles = final_box_preds_camera[:, 6] + camera_box_origin = [0.5, 1.0, 0.5] + box_corners = box_torch_ops.center_to_corner_box3d( + locs, dims, angles, camera_box_origin, axis=1) + box_corners_in_image = box_torch_ops.project_to_image( + box_corners, P2) + # box_corners_in_image: [N, 8, 2] + minxy = torch.min(box_corners_in_image, dim=1)[0] + maxxy = torch.max(box_corners_in_image, dim=1)[0] + box_2d_preds = torch.cat([minxy, maxxy], dim=1) + pred_dict["bbox"] = box_2d_preds + pred_dict["box3d_camera"] = final_box_preds_camera + else: + pred_dict["bbox"] = torch.zeros([0, 4], dtype=dtype, device=device) + pred_dict["box3d_camera"] = torch.zeros([0, 7], dtype=dtype, device=device) + return predictions_dicts @@ -570,6 +625,10 @@ def metrics_to_float(self): self.rpn_cls_loss.float() self.rpn_loc_loss.float() self.rpn_total_loss.float() + if self.voxel_classifier is not None: + self.vox_acc.float() + self.vox_precision.float() + self.vox_recall.float() def update_metrics(self, cls_loss, @@ -601,6 +660,22 @@ def update_metrics(self, for i, thresh in enumerate(self.rpn_metrics.thresholds): ret[f"prec@{int(thresh*100)}"] = float(prec[i]) ret[f"rec@{int(thresh*100)}"] = float(recall[i]) + if self.voxel_classifier is not None: + vox_acc = self.vox_acc( + vox_labels, vox_preds.view(-1, 2), + weights=vox_weights > 0).numpy()[0] + vox_prec = self.vox_precision( + vox_labels, vox_preds.view(-1, 2), + weights=vox_weights > 0).numpy()[0] + vox_recall = self.vox_recall( + vox_labels, vox_preds.view(-1, 2), + weights=vox_weights > 0).numpy()[0] + + ret.update({ + "vox_acc": float(vox_acc), + "vox_prec": float(vox_prec), + "vox_recall": float(vox_recall), + }) return ret def clear_metrics(self): @@ -609,6 +684,11 @@ def clear_metrics(self): self.rpn_cls_loss.clear() self.rpn_loc_loss.clear() self.rpn_total_loss.clear() + if self.voxel_classifier is not None: + self.vox_acc.clear() + self.vox_precision.clear() + self.vox_recall.clear() + pass @staticmethod def convert_norm_to_float(net): @@ -701,6 +781,10 @@ def prepare_loss_weights(labels, normalizer = torch.clamp(normalizer, min=1.0) reg_weights /= normalizer[:, 0:1, 0] cls_weights /= cls_normalizer + elif loss_norm_type == LossNormType.DontNorm: # support ghm loss + pos_normalizer = positives.sum(1, keepdim=True).type(dtype) + reg_weights /= torch.clamp(pos_normalizer, min=1.0) + # pass else: raise ValueError( f"unknown loss norm type. available: {list(LossNormType)}") diff --git a/second/pytorch/train.py b/second/pytorch/train.py index 313d5224..6a3f5324 100644 --- a/second/pytorch/train.py +++ b/second/pytorch/train.py @@ -10,8 +10,9 @@ import torch from google.protobuf import text_format from tensorboardX import SummaryWriter - +import copy import torchplus +from second.core import box_np_ops import second.data.kitti_common as kitti from second.builder import target_assigner_builder, voxel_builder from second.data.preprocess import merge_second_batch @@ -21,7 +22,7 @@ second_builder) from second.utils.eval import get_coco_eval_result, get_official_eval_result from second.utils.progress_bar import ProgressBar - +from second.utils.log_tool import metric_to_str, flat_nested_json_dict def _get_pos_neg_loss(cls_loss, labels): # cls_loss: [N, num_anchors, num_class] @@ -40,37 +41,16 @@ def _get_pos_neg_loss(cls_loss, labels): return cls_pos_loss, cls_neg_loss -def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""): - for k, v in json_dict.items(): - if isinstance(v, dict): - _flat_nested_json_dict(v, flatted, sep, start + sep + k) - else: - flatted[start + sep + k] = v - - -def flat_nested_json_dict(json_dict, sep=".") -> dict: - """flat a nested json-like dict. this function make shadow copy. - """ - flatted = {} - for k, v in json_dict.items(): - if isinstance(v, dict): - _flat_nested_json_dict(v, flatted, sep, k) - else: - flatted[k] = v - return flatted - - def example_convert_to_torch(example, dtype=torch.float32, device=None) -> dict: device = device or torch.device("cuda:0") example_torch = {} float_names = [ - "voxels", "anchors", "reg_targets", "reg_weights", "bev_map", "rect", - "Trv2c", "P2" + "voxels", "anchors", "reg_targets", "reg_weights", "bev_map" ] - for k, v in example.items(): if k in float_names: + # slow when directly provide fp32 data with dtype=torch.half example_torch[k] = torch.tensor(v, dtype=torch.float32, device=device).to(dtype) elif k in ["coordinates", "labels", "num_points"]: example_torch[k] = torch.tensor( @@ -78,10 +58,25 @@ def example_convert_to_torch(example, dtype=torch.float32, elif k in ["anchors_mask"]: example_torch[k] = torch.tensor( v, dtype=torch.uint8, device=device) + elif k == "calib": + calib = {} + for k1, v1 in v.items(): + calib[k1] = torch.tensor(v1, dtype=dtype, device=device).to(dtype) + example_torch[k] = calib else: example_torch[k] = v return example_torch +def build_network(model_cfg, measure_time=False): + voxel_generator = voxel_builder.build(model_cfg.voxel_generator) + bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]] + box_coder = box_coder_builder.build(model_cfg.box_coder) + target_assigner_cfg = model_cfg.target_assigner + target_assigner = target_assigner_builder.build(target_assigner_cfg, + bv_range, box_coder) + class_names = target_assigner.classes + net = second_builder.build(model_cfg, voxel_generator, target_assigner, measure_time=measure_time) + return net def train(config_path, model_dir, @@ -90,51 +85,47 @@ def train(config_path, display_step=50, summary_step=5, pickle_result=True, - patchs=None): + resume=False): """train a VoxelNet model specified by a config file. """ if create_folder: if pathlib.Path(model_dir).exists(): model_dir = torchplus.train.create_folder(model_dir) - patchs = patchs or [] model_dir = pathlib.Path(model_dir) + if not resume and model_dir.exists(): + raise ValueError("model dir exists and you don't specify resume.") model_dir.mkdir(parents=True, exist_ok=True) if result_path is None: result_path = model_dir / 'results' config_file_bkp = "pipeline.config" - config = pipeline_pb2.TrainEvalPipelineConfig() - with open(config_path, "r") as f: - proto_str = f.read() - text_format.Merge(proto_str, config) - for patch in patchs: - patch = "config." + patch - exec(patch) - shutil.copyfile(config_path, str(model_dir / config_file_bkp)) + if isinstance(config_path, str): + # directly provide a config object. this usually used + # when you want to train with several different parameters in + # one script. + config = pipeline_pb2.TrainEvalPipelineConfig() + with open(config_path, "r") as f: + proto_str = f.read() + text_format.Merge(proto_str, config) + else: + config = config_path + proto_str = text_format.MessageToString(config, indent=2) + with (model_dir / config_file_bkp).open("w") as f: + f.write(proto_str) + input_cfg = config.train_input_reader eval_input_cfg = config.eval_input_reader model_cfg = config.model.second train_cfg = config.train_config - - ###################### - # BUILD VOXEL GENERATOR - ###################### - voxel_generator = voxel_builder.build(model_cfg.voxel_generator) - ###################### - # BUILD TARGET ASSIGNER - ###################### - bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]] - box_coder = box_coder_builder.build(model_cfg.box_coder) - target_assigner_cfg = model_cfg.target_assigner - target_assigner = target_assigner_builder.build(target_assigner_cfg, - bv_range, box_coder) + net = build_network(model_cfg).cuda() + if train_cfg.enable_mixed_precision: + net.half() + net.metrics_to_float() + net.convert_norm_to_float(net) + target_assigner = net.target_assigner + voxel_generator = net.voxel_generator class_names = target_assigner.classes - ###################### - # BUILD NET - ###################### - center_limit_range = model_cfg.post_center_limit_range - net = second_builder.build(model_cfg, voxel_generator, target_assigner) - net.cuda() + # net_train = torch.nn.DataParallel(net).cuda() print("num_trainable parameters:", len(list(net.parameters()))) # for n, p in net.named_parameters(): @@ -146,13 +137,10 @@ def train(config_path, torchplus.train.try_restore_latest_checkpoints(model_dir, [net]) gstep = net.get_global_step() - 1 optimizer_cfg = train_cfg.optimizer - if train_cfg.enable_mixed_precision: - net.half() - net.metrics_to_float() - net.convert_norm_to_float(net) loss_scale = train_cfg.loss_scale_factor mixed_optimizer = optimizer_builder.build(optimizer_cfg, net, mixed=train_cfg.enable_mixed_precision, loss_scale=loss_scale) optimizer = mixed_optimizer + center_limit_range = model_cfg.post_center_limit_range """ if train_cfg.enable_mixed_precision: mixed_optimizer = torchplus.train.MixedPrecisionWrapper( @@ -171,7 +159,6 @@ def train(config_path, ###################### # PREPARE INPUT ###################### - dataset = input_reader_builder.build( input_cfg, model_cfg, @@ -184,7 +171,6 @@ def train(config_path, training=False, voxel_generator=voxel_generator, target_assigner=target_assigner) - def _worker_init_fn(worker_id): time_seed = np.array(time.time(), dtype=np.int32) np.random.seed(time_seed + worker_id) @@ -317,10 +303,9 @@ def _worker_init_fn(worker_id): # mixed_optimizer.param_groups[0]['lr']) metrics["lr"] = float( optimizer.lr) - - metrics["image_idx"] = example['image_idx'][0] + if "image_info" in example['metadata'][0]: + metrics["image_idx"] = example['metadata'][0]["image_info"]['image_idx'] training_detail.append(metrics) - flatted_metrics = flat_nested_json_dict(metrics) flatted_summarys = flat_nested_json_dict(metrics, "/") """ for k, v in flatted_summarys.items(): @@ -330,19 +315,7 @@ def _worker_init_fn(worker_id): else: writer.add_scalar(k, v, global_step) """ - metrics_str_list = [] - for k, v in flatted_metrics.items(): - if isinstance(v, float): - metrics_str_list.append(f"{k}={v:.3}") - elif isinstance(v, (list, tuple)): - if v and isinstance(v[0], float): - v_str = ', '.join([f"{e:.3}" for e in v]) - metrics_str_list.append(f"{k}=[{v_str}]") - else: - metrics_str_list.append(f"{k}={v}") - else: - metrics_str_list.append(f"{k}={v}") - log_str = ', '.join(metrics_str_list) + log_str = metric_to_str(metrics) print(log_str, file=logf) print(log_str) ckpt_elasped_time = time.time() - ckpt_start_time @@ -371,15 +344,9 @@ def _worker_init_fn(worker_id): prog_bar.start((len(eval_dataset) + eval_input_cfg.batch_size - 1) // eval_input_cfg.batch_size) for example in iter(eval_dataloader): example = example_convert_to_torch(example, float_dtype) - if pickle_result: - dt_annos += predict_kitti_to_anno( - net, example, class_names, center_limit_range, - model_cfg.lidar_input) - else: - _predict_kitti_to_file(net, example, result_path_step, - class_names, center_limit_range, - model_cfg.lidar_input) - + dt_annos += predict_to_kitti_label( + net, example, class_names, center_limit_range, + model_cfg.lidar_input) prog_bar.print_bar() sec_per_ex = len(eval_dataset) / (time.time() - t) @@ -388,24 +355,18 @@ def _worker_init_fn(worker_id): print( f'generate label finished({sec_per_ex:.2f}/s). start eval:', file=logf) - gt_annos = [ - info["annos"] for info in eval_dataset.dataset.kitti_infos - ] - if not pickle_result: - dt_annos = kitti.get_label_annos(result_path_step) - # result = get_official_eval_result_v2(gt_annos, dt_annos, class_names) - # print(json.dumps(result, indent=2), file=logf) - result = get_official_eval_result(gt_annos, dt_annos, class_names) - print(result, file=logf) - print(result) - writer.add_text('eval_result', json.dumps(result, indent=2), global_step) - result = get_coco_eval_result(gt_annos, dt_annos, class_names) - print(result, file=logf) - print(result) + result_official, result_coco = eval_dataset.dataset.evaluation(dt_annos) + print(result_official) + print(result_official, file=logf) + print(result_coco) + print(result_coco, file=logf) if pickle_result: with open(result_path_step / "result.pkl", 'wb') as f: pickle.dump(dt_annos, f) - writer.add_text('eval_result', result, global_step) + else: + kitti_anno_to_label_file(dt_annos, result_path_step) + writer.add_text('eval_result', result_official, global_step) + writer.add_text('eval_result coco', result_coco, global_step) net.train() except Exception as e: torchplus.train.save_models(model_dir, [net, optimizer], @@ -418,159 +379,126 @@ def _worker_init_fn(worker_id): logf.close() -def _predict_kitti_to_file(net, - example, - result_save_path, - class_names, - center_limit_range=None, - lidar_input=False): - batch_image_shape = example['image_shape'] - batch_imgidx = example['image_idx'] - predictions_dicts = net(example) - # t = time.time() - for i, preds_dict in enumerate(predictions_dicts): - image_shape = batch_image_shape[i] - img_idx = preds_dict["image_idx"] - if preds_dict["bbox"] is not None or preds_dict["bbox"].size.numel(): - box_2d_preds = preds_dict["bbox"].data.cpu().numpy() - box_preds = preds_dict["box3d_camera"].data.cpu().numpy() - scores = preds_dict["scores"].data.cpu().numpy() - box_preds_lidar = preds_dict["box3d_lidar"].data.cpu().numpy() - # write pred to file - box_preds = box_preds[:, [0, 1, 2, 4, 5, 3, - 6]] # lhw->hwl(label file format) - label_preds = preds_dict["label_preds"].data.cpu().numpy() - # label_preds = np.zeros([box_2d_preds.shape[0]], dtype=np.int32) - result_lines = [] - for box, box_lidar, bbox, score, label in zip( - box_preds, box_preds_lidar, box_2d_preds, scores, - label_preds): - if not lidar_input: - if bbox[0] > image_shape[1] or bbox[1] > image_shape[0]: - continue - if bbox[2] < 0 or bbox[3] < 0: - continue - # print(img_shape) - if center_limit_range is not None: - limit_range = np.array(center_limit_range) - if (np.any(box_lidar[:3] < limit_range[:3]) - or np.any(box_lidar[:3] > limit_range[3:])): - continue - bbox[2:] = np.minimum(bbox[2:], image_shape[::-1]) - bbox[:2] = np.maximum(bbox[:2], [0, 0]) - result_dict = { - 'name': class_names[int(label)], - 'alpha': -np.arctan2(-box_lidar[1], box_lidar[0]) + box[6], - 'bbox': bbox, - 'location': box[:3], - 'dimensions': box[3:6], - 'rotation_y': box[6], - 'score': score, - } - result_line = kitti.kitti_result_line(result_dict) - result_lines.append(result_line) - else: - result_lines = [] - result_file = f"{result_save_path}/{kitti.get_image_index_str(img_idx)}.txt" - result_str = '\n'.join(result_lines) - with open(result_file, 'w') as f: - f.write(result_str) - - -def predict_kitti_to_anno(net, +def predict_to_kitti_label(net, example, class_names, center_limit_range=None, - lidar_input=False, - global_set=None): - batch_image_shape = example['image_shape'] - batch_imgidx = example['image_idx'] + lidar_input=False): predictions_dicts = net(example) - # t = time.time() + limit_range = None + if center_limit_range is not None: + limit_range = np.array(center_limit_range) annos = [] for i, preds_dict in enumerate(predictions_dicts): - image_shape = batch_image_shape[i] - img_idx = preds_dict["image_idx"] - if preds_dict["bbox"] is not None or preds_dict["bbox"].size.numel() != 0: - box_2d_preds = preds_dict["bbox"].detach().cpu().numpy() - box_preds = preds_dict["box3d_camera"].detach().cpu().numpy() - scores = preds_dict["scores"].detach().cpu().numpy() - box_preds_lidar = preds_dict["box3d_lidar"].detach().cpu().numpy() - # write pred to file - label_preds = preds_dict["label_preds"].detach().cpu().numpy() - # label_preds = np.zeros([box_2d_preds.shape[0]], dtype=np.int32) - anno = kitti.get_start_result_anno() - num_example = 0 - for box, box_lidar, bbox, score, label in zip( - box_preds, box_preds_lidar, box_2d_preds, scores, - label_preds): - if not lidar_input: - if bbox[0] > image_shape[1] or bbox[1] > image_shape[0]: - continue - if bbox[2] < 0 or bbox[3] < 0: - continue - # print(img_shape) - if center_limit_range is not None: - limit_range = np.array(center_limit_range) - if (np.any(box_lidar[:3] < limit_range[:3]) - or np.any(box_lidar[:3] > limit_range[3:])): - continue - bbox[2:] = np.minimum(bbox[2:], image_shape[::-1]) - bbox[:2] = np.maximum(bbox[:2], [0, 0]) - anno["name"].append(class_names[int(label)]) - anno["truncated"].append(0.0) - anno["occluded"].append(0) - anno["alpha"].append(-np.arctan2(-box_lidar[1], box_lidar[0]) + - box[6]) - anno["bbox"].append(bbox) - anno["dimensions"].append(box[3:6]) - anno["location"].append(box[:3]) - anno["rotation_y"].append(box[6]) - if global_set is not None: - for i in range(100000): - if score in global_set: - score -= 1 / 100000 - else: - global_set.add(score) - break - anno["score"].append(score) - - num_example += 1 - if num_example != 0: - anno = {n: np.stack(v) for n, v in anno.items()} - annos.append(anno) + box3d_lidar = preds_dict["box3d_lidar"].detach().cpu().numpy() + box3d_camera = None + scores = preds_dict["scores"].detach().cpu().numpy() + label_preds = preds_dict["label_preds"].detach().cpu().numpy() + if "box3d_camera" in preds_dict: + box3d_camera = preds_dict["box3d_camera"].detach().cpu().numpy() + bbox = None + if "bbox" in preds_dict: + bbox = preds_dict["bbox"].detach().cpu().numpy() + anno = kitti.get_start_result_anno() + num_example = 0 + for j in range(box3d_lidar.shape[0]): + if limit_range is not None: + if (np.any(box3d_lidar[j, :3] < limit_range[:3]) + or np.any(box3d_lidar[j, :3] > limit_range[3:])): + continue + if "bbox" in preds_dict: + assert "image_shape" in preds_dict["metadata"]["image"] + image_shape = preds_dict["metadata"]["image"]["image_shape"] + if bbox[j, 0] > image_shape[1] or bbox[j, 1] > image_shape[0]: + continue + if bbox[j, 2] < 0 or bbox[j, 3] < 0: + continue + bbox[j, 2:] = np.minimum(bbox[j, 2:], image_shape[::-1]) + bbox[j, :2] = np.maximum(bbox[j, :2], [0, 0]) + anno["bbox"].append(bbox[j]) + # convert center format to kitti format + # box3d_lidar[j, 2] -= box3d_lidar[j, 5] / 2 + anno["alpha"].append(-np.arctan2(-box3d_lidar[j, 1], box3d_lidar[j, 0]) + + box3d_camera[j, 6]) + anno["dimensions"].append(box3d_camera[j, 3:6]) + anno["location"].append(box3d_camera[j, :3]) + anno["rotation_y"].append(box3d_camera[j, 6]) else: - annos.append(kitti.empty_result_anno()) + # bbox's height must higher than 25, otherwise filtered during eval + anno["bbox"].append(np.array([0, 0, 50, 50])) + # note that if you use raw lidar data to eval, + # you will get strange performance because + # in standard KITTI eval, instance with small bbox height + # will be filtered. but it is impossible to filter + # boxes when using raw data. + anno["alpha"].append(0.0) + anno["dimensions"].append(box3d_lidar[j, 3:6]) + anno["location"].append(box3d_lidar[j, :3]) + anno["rotation_y"].append(box3d_lidar[j, 6]) + + anno["name"].append(class_names[int(label_preds[j])]) + anno["truncated"].append(0.0) + anno["occluded"].append(0) + anno["score"].append(scores[j]) + + num_example += 1 + if num_example != 0: + anno = {n: np.stack(v) for n, v in anno.items()} + annos.append(anno) else: annos.append(kitti.empty_result_anno()) num_example = annos[-1]["name"].shape[0] - annos[-1]["image_idx"] = np.array( - [img_idx] * num_example, dtype=np.int64) + annos[-1]["metadata"] = preds_dict["metadata"] return annos +def kitti_anno_to_label_file(annos, folder): + folder = pathlib.Path(folder) + for anno in annos: + image_idx = anno["metadata"]["image"]["image_idx"] + label_lines = [] + for j in range(anno["bbox"].shape[0]): + label_dict = { + 'name': anno["name"][j], + 'alpha': anno["alpha"][j], + 'bbox': anno["bbox"][j], + 'location': anno["location"][j], + 'dimensions': anno["dimensions"][j], + 'rotation_y': anno["rotation_y"][j], + 'score': anno["score"][j], + } + label_line = kitti.kitti_result_line(label_dict) + label_lines.append(label_line) + label_file = folder / f"{kitti.get_image_index_str(image_idx)}.txt" + label_str = '\n'.join(label_lines) + with open(label_file, 'w') as f: + f.write(label_str) + + def evaluate(config_path, - model_dir, + model_dir=None, result_path=None, - predict_test=False, ckpt_path=None, ref_detfile=None, pickle_result=True, measure_time=False, batch_size=None): - model_dir = pathlib.Path(model_dir) - if predict_test: - result_name = 'predict_test' - else: - result_name = 'eval_results' + result_name = 'eval_results' if result_path is None: + model_dir = pathlib.Path(model_dir) result_path = model_dir / result_name else: result_path = pathlib.Path(result_path) - config = pipeline_pb2.TrainEvalPipelineConfig() - with open(config_path, "r") as f: - proto_str = f.read() - text_format.Merge(proto_str, config) + if isinstance(config_path, str): + # directly provide a config object. this usually used + # when you want to eval with several different parameters in + # one script. + config = pipeline_pb2.TrainEvalPipelineConfig() + with open(config_path, "r") as f: + proto_str = f.read() + text_format.Merge(proto_str, config) + else: + config = config_path input_cfg = config.eval_input_reader model_cfg = config.model.second @@ -580,26 +508,22 @@ def evaluate(config_path, ###################### # BUILD VOXEL GENERATOR ###################### - voxel_generator = voxel_builder.build(model_cfg.voxel_generator) - bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]] - box_coder = box_coder_builder.build(model_cfg.box_coder) - target_assigner_cfg = model_cfg.target_assigner - target_assigner = target_assigner_builder.build(target_assigner_cfg, - bv_range, box_coder) + net = build_network(model_cfg, measure_time=measure_time).cuda() + if train_cfg.enable_mixed_precision: + net.half() + print("half inference!") + net.metrics_to_float() + net.convert_norm_to_float(net) + target_assigner = net.target_assigner + voxel_generator = net.voxel_generator class_names = target_assigner.classes - net = second_builder.build(model_cfg, voxel_generator, target_assigner, measure_time=measure_time) - net.cuda() - if ckpt_path is None: + assert model_dir is not None torchplus.train.try_restore_latest_checkpoints(model_dir, [net]) else: torchplus.train.restore(ckpt_path, net) - if train_cfg.enable_mixed_precision: - net.half() - print("half inference!") - net.metrics_to_float() - net.convert_norm_to_float(net) + batch_size = batch_size or input_cfg.batch_size eval_dataset = input_reader_builder.build( input_cfg, @@ -625,7 +549,6 @@ def evaluate(config_path, result_path_step.mkdir(parents=True, exist_ok=True) t = time.time() dt_annos = [] - global_set = None print("Generate output labels...") bar = ProgressBar() bar.start((len(eval_dataset) + batch_size - 1) // batch_size) @@ -641,14 +564,9 @@ def evaluate(config_path, if measure_time: torch.cuda.synchronize() prep_example_times.append(time.time() - t1) - - if pickle_result: - dt_annos += predict_kitti_to_anno( + dt_annos += predict_to_kitti_label( net, example, class_names, center_limit_range, - model_cfg.lidar_input, global_set) - else: - _predict_kitti_to_file(net, example, result_path_step, class_names, - center_limit_range, model_cfg.lidar_input) + model_cfg.lidar_input) # print(json.dumps(net.middle_feature_extractor.middle_conv.sparity_dict)) bar.print_bar() if measure_time: @@ -661,18 +579,16 @@ def evaluate(config_path, print(f"avg prep time: {np.mean(prep_times) * 1000:.3f} ms") for name, val in net.get_avg_time_dict().items(): print(f"avg {name} time = {val * 1000:.3f} ms") - if not predict_test: - gt_annos = [info["annos"] for info in eval_dataset.dataset.kitti_infos] - if not pickle_result: - dt_annos = kitti.get_label_annos(result_path_step) - result = get_official_eval_result(gt_annos, dt_annos, class_names) - # print(json.dumps(result, indent=2)) - print(result) - result = get_coco_eval_result(gt_annos, dt_annos, class_names) - print(result) - if pickle_result: - with open(result_path_step / "result.pkl", 'wb') as f: - pickle.dump(dt_annos, f) + if pickle_result: + with open(result_path_step / "result.pkl", 'wb') as f: + pickle.dump(dt_annos, f) + else: + kitti_anno_to_label_file(dt_annos, result_path_step) + + result_official, result_coco = eval_dataset.dataset.evaluation(dt_annos) + if result_official is not None: + print(result_official) + print(result_coco) def save_config(config_path, save_path): diff --git a/second/script.py b/second/script.py new file mode 100644 index 00000000..433aec42 --- /dev/null +++ b/second/script.py @@ -0,0 +1,47 @@ +from second.pytorch.train import train, evaluate +from google.protobuf import text_format +from second.protos import pipeline_pb2 +from pathlib import Path +from second.utils import config_tool + + +def train_multi_rpn_layer_num(): + config_path = "./configs/car.lite.config" + model_root = Path.home() / "second_test" # don't forget to change this. + config = pipeline_pb2.TrainEvalPipelineConfig() + with open(config_path, "r") as f: + proto_str = f.read() + text_format.Merge(proto_str, config) + input_cfg = config.eval_input_reader + model_cfg = config.model.second + layer_nums = [2, 4, 7, 9] + for l in layer_nums: + model_dir = str(model_root / f"car_lite_L{l}") + model_cfg.rpn.layer_nums[:] = [l] + train(config, model_dir) + + +def eval_multi_threshold(): + config_path = "./configs/car.fhd.config" + ckpt_name = "/path/to/your/model_ckpt" # don't forget to change this. + assert "/path/to/your" not in ckpt_name + config = pipeline_pb2.TrainEvalPipelineConfig() + with open(config_path, "r") as f: + proto_str = f.read() + text_format.Merge(proto_str, config) + model_cfg = config.model.second + threshs = [0.3] + for thresh in threshs: + model_cfg.nms_score_threshold = thresh + # don't forget to change this. + result_path = Path.home() / f"second_test_eval_{thresh:.2f}" + evaluate( + config, + result_path=result_path, + ckpt_path=str(ckpt_name), + batch_size=1, + measure_time=True) + + +if __name__ == "__main__": + eval_multi_threshold() \ No newline at end of file diff --git a/second/simple-inference.ipynb b/second/simple-inference.ipynb new file mode 100644 index 00000000..87ca4a30 --- /dev/null +++ b/second/simple-inference.ipynb @@ -0,0 +1,269 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pickle\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from google.protobuf import text_format\n", + "from second.utils import simplevis, config_tool\n", + "from second.pytorch.train import build_network\n", + "from second.protos import pipeline_pb2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read Config file" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "config_path = \"/home/yy/deeplearning/second.pytorch/second/configs/car.lite.config\"\n", + "config = pipeline_pb2.TrainEvalPipelineConfig()\n", + "with open(config_path, \"r\") as f:\n", + " proto_str = f.read()\n", + " text_format.Merge(proto_str, config)\n", + "input_cfg = config.eval_input_reader\n", + "model_cfg = config.model.second\n", + "\n", + "# set detection range\n", + "config_tool.change_detection_range(model_cfg, [-50, -50, 50, 50])\n", + "\n", + "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "device = torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Network, Target Assigner and Voxel Generator" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 41 2000 2000]\n" + ] + } + ], + "source": [ + "ckpt_path = \"/home/yy/pretrained_models_v1.5/car_lite/voxelnet-15500.tckpt\"\n", + "net = build_network(model_cfg).to(device).eval()\n", + "net.load_state_dict(torch.load(ckpt_path))\n", + "target_assigner = net.target_assigner\n", + "voxel_generator = net.voxel_generator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate Anchors" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "downsample_factor = config_tool.get_downsample_factor(model_cfg)\n", + "grid_size = voxel_generator.grid_size\n", + "feature_map_size = grid_size[:2] // downsample_factor\n", + "feature_map_size = [*feature_map_size, 1][::-1]\n", + "\n", + "anchors = target_assigner.generate_anchors(feature_map_size)[\"anchors\"]\n", + "anchors = torch.tensor(anchors, dtype=torch.float32, device=device)\n", + "anchors = anchors.view(1, -1, 7)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read KITTI infos\n", + "you can load your custom point cloud." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "info_path = input_cfg.kitti_info_path\n", + "root_path = Path(input_cfg.kitti_root_path)\n", + "with open(info_path, 'rb') as f:\n", + " infos = pickle.load(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Point Cloud, Generate Voxels" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(84129, 5, 4)\n" + ] + } + ], + "source": [ + "info = infos[564]\n", + "v_path = info[\"point_cloud\"]['velodyne_path']\n", + "v_path = str(root_path / v_path)\n", + "points = np.fromfile(\n", + " v_path, dtype=np.float32, count=-1).reshape([-1, 4])\n", + "\n", + "voxels, coords, num_points = voxel_generator.generate(points, max_voxels=90000)\n", + "print(voxels.shape)\n", + "# add batch idx to coords\n", + "coords = np.pad(coords, ((0, 0), (1, 0)), mode='constant', constant_values=0)\n", + "voxels = torch.tensor(voxels, dtype=torch.float32, device=device)\n", + "coords = torch.tensor(coords, dtype=torch.int32, device=device)\n", + "num_points = torch.tensor(num_points, dtype=torch.int32, device=device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Detection" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "example = {\n", + " \"anchors\": anchors,\n", + " \"voxels\": voxels,\n", + " \"num_points\": num_points,\n", + " \"coordinates\": coords,\n", + "}\n", + "pred = net(example)[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple Vis" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "boxes_lidar = pred[\"box3d_lidar\"].detach().cpu().numpy()\n", + "vis_voxel_size = [0.1, 0.1, 0.1]\n", + "vis_point_range = [-50, -30, -3, 50, 30, 1]\n", + "bev_map = simplevis.point_to_vis_bev(points, vis_voxel_size, vis_point_range)\n", + "bev_map = simplevis.draw_box_in_bev(bev_map, vis_point_range, boxes_lidar, [0, 255, 0], 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAADsCAYAAACWscopAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJztfXecFEX6/lO7SxIzd9whIkEJkjPIiXASlg0sYZclswmWnNPCguSMgOSg5CDxCxygiJ4e8PMIShZFQRAQRI6c2d15f3/MVNPd0zPTM9MzPbtTz36ez1RXV1e/09Pbb9f71vsWIyIICAgICAQ3QswWQEBAQEDAfAhlICAgICAglIGAgICAgFAGAgICAgIQykBAQEBAAEIZCAgICAjAR8qAMdaEMXaGMXaWMZbmi3MICAgICBgHZnScAWMsFMDPABoBuAzgMIC2RHTa0BMJCAgICBgGX4wMagI4S0S/EtFTAJ8CaOaD8wgICAgIGARfKIPCAC7Jti/b6gQEBAQEAhRhPuiTadTZ2aIYY6kAUm2b1Xwgh4CAgB/x+uuv4/Lly9b/5u81GjiqF/AG/yOivxrRkS9GBpcBFJFtvw7giroRES0moupEVN0HMggIBDyGDBliSD8jR440pB9vcfnyZcyYMQP4zkEDR/UC3uA3w3oiIkMJ62jjVwDFAeQGcBxAORfHkKCgr3jt2jXTZXDGIUOGmC4DAJo4caIxfRFow4YNmvVmf8ccyO+MenYbPjIgokwAvQDsBvAjgA1E9IPR5xEQ0IsLFy6YLYJTTJ06FUOHDvW6nw8++MCr44cPH+61DByMaVmLBQIZvvAZgIh2Adjli74FBNzFtWvXNOv79esnlWfNmuX1efr16+dxP1OmTPH6/GPHjvW6D4EghtFmIg9NS2YPtQRzMA8ePGi6DP7kqFGjzJXBZg6aMWOGZr2goTTMTGR40JknYIyZL4SAgIAxIGDjpo1o1aqVXb3mXEMBb/C9UZNwRG4igRyPrKws0849YMAAr45PSxPZXAT8BLNNRMJMJOhrXr161XQZBgwYYMp5R48e7f/zOjIHCTORL2iYmch0RSCUgaCvmZmZaboMnAMHDvT42GHDhnl0nLcKYfLkyaZfN0GHFMpAUFAvTXeoquiNQjCLQiEELIUDWUBAL44fP45KlSqZLYYunDhxAhUrVjRbDIHsA+FAFhDQixIlSqBXr14SAxkVK1bUnV7C3SCxMWPGeCKSQJBAjAwEcjyuXLmC1157zZC++vTpo1k/e/ZsQ/r3BMOHD8fEiRN1tR0zZgxGjRrlY4kE/AjDRgam+wuEz0DQ13z69KnpMuihNzmK0tPTdbcdO3asV3JOmTLF9GslKFE4kAUF9fLrr7825bx9+/b16DhPlYI/FYJgwFAoA0FBvTx+/LjpMnhCT5TCiBEj/CafGCEEBIUyEBTUy8ePH5sug6ccOnSoT/s3w2T08ccfu2xjl9fIxmnTppn+mwQYhTIQFNRLs8xEWuzfv79Hx/lSKYwbN86r48UIwVSKOAMBAb24fPkyJk2ahHnz5qFnz566cu3zNnPmzPGJTAMGDLCuCuYG0tLSMHnyZJ/IM27cOI9XTPvyyy/xwgsvoFatWgZLpY3p06dj0KBBfjlXNoCYTSQoqJcbN270y3l69+7tVnuz8hU5orcjBEc8efIkXbx4kS5evEg3b95U8NKlS6Z/72xOYSYSFNTLhw8f+vV87iqFnMKpU6eaLkMQUigDQUG9/OWXX0w5b58+fdxqnx1zFqkpFILfKZSBoKBe5mRThN5Mph988IHuPsePH++VTEIh+JVCGQgK6mWgxBn069dPd9tBgwbpbusLhSCYbSiUgaCgXt6/f990GeTUqxTcUQi+oDcjBDE68BvF1FIBAb24d+8eXnjhBbPFsEP//v0xc+ZMs8UQyN4QKawFBPRi5cqVhvQjT4OtpieYOXMm+vfvr6vt4MGDdbXTm9bancyl48eP191WjqlTp3p0nIBJMNtEJMxEgu5y7NixmnP0e/ToQT169LCr//nnn/0qX69evdw+JtBiDtT01qks6DMKn4Fg8HDOnDmKbUeKwNHx/lYGnL6IN9CbvM6dDKaC2Zr+UwYAlgL4E8ApWd2rAPYA+MX2+YqtngGYDeAsgBMAqgplIGgEN2/eTHPnziXA8QjA0Ru5WcqA012l4M94g9GjR7vVXowQAo5+VQbvAagKpTKYCiDNVk4DMMVWjgTwGaxKoTaAg0IZCHrDzz//nKZPn66o69mzp+7j09LS6MyZM4bI0rNnT7fOraa7QWj+olAI2Zr+NRMBKAalMjgDoJCtXAjAGVt5EYC2Wu2EMhA0giNHjnS6v1u3btS9e3dF3bFjx/wqoyuF4emiN4KCGjRdGdxW7b9l+9wB4F1Z/VcAqgtlIOhLduvWzel+o0YG7tITR7IWncUbpKWluTzenwveCPqdAasMdsJeGVRz0GcqgO9sNPuCCmYjcnOLKyXAuWHDBtNk9YdCMIpjxoxxq/2ECRNMvxcEzVcGwkwk6BeqV7xKS0uzMwO54rZt20z/Hs6cyHojkgcPHuxwn54Rgh66qxAETad/I5AZY8UA7CCi8rbtaQBuENFkxlgagFeJaAhjLApAL1gdybUAzCaimjr6dy2EQFBiwYIF6N69OwBgypQpGDp0KHr06IH58+fr7mPjxo348ssvAQCLFi1Ct27dsHDhQnTr1k2x0A1jjL+c2C2A4875nKF3796aC+YYEY08bNgwTJo0yas+fI1p06bpDqAT0AX/LW4DYB2AqwAyAFwGkAKgAKwmoF9sn6/a2jIA8wCcA3ASOvwFYmQg6A7HjBmja0ZP165dpbKeNXe9paPpro7ozcwiZ7EGepPWCeYYiqAzweClI5OLIx/CpEmT/C6jN1NQvY1GFgohqCiUgWDw8dixY3ZTS+UKYM+ePbRp0yYCQGvWrKG9e/fS3r17afPmzabJ7I1SEBTUQaEMBIOPx44do/Hjx1PXrl3tRgEPHjyw4/3792nWrFmKOAO5+cgZ3XVSe0t31jrwJ4VDOeApUlgLBB/Onz+Pjz76CE+ePMGCBQsU+3799Vfky5cPAFCoUCFcvXpV2leoUCGfydSjRw8AxjmYjUR6ejomTJhgthgCvoX/HMhiZCBoJvfu3SuVjx07RhMnTnTLUWsGvZWvf//+dnU5YX1kQZ/QsJGBWM9AIKDx+PFjxXZGRobdtM+uXbuia9eu/hTLKfgooWfPnh4dr7XOwYcffohBgwZ5LZuAgCMIM5FAQGPv3r147733pO3hw4fj3r17mDNnDrp27YpFixbp6ocrC3mcgR7wGAc11GYqZ+jVqxfmzp2ru72AgBsQZiLB4OCePXvs6vQ6gfW2M4rdu3f3u+NZMOgpZhMJBgdnzZollV0tbJ+ammq6vJye+A1ENlNBDyiUgWBw8KOPPpLKRiZG8/eoQVDQRxRTSwUE9MId34K74D4Fd3wIAgIGQvgMBIOH06ZNk8rDhw/XtYykmSYjd/0GgboCmmC2oBgZCAQfzp07h6VLl/oskKpbt26Kbb0zjgQETIQYGQgGBz/88EOpzFMjuOucTU1N9dlIQcwgEjSZwoEsGHw8evQoAdbkb4E0c0hOoRgE/UxhJhIIPowbNw4PHjzArVu3PHYId+nSBQCwZMkSr2SRB7EJBCDkTxTmsFVOgDATCQYXx48fL5X1rn0MgLp06eIX+bp166aQy9s1CQQ9oPxPXW+2bL6jGBkIBC+6d++OjIwMfPzxxx73wUcI6mUv9cLZiKBv37746KOPPJZNwAsQ7EcCWnU5B4aNDIQyEMjx6Ny5MwB4pTycQctkNHjwYBQsWFCs9+tvCGXgOcw2EQkzkaA7TEtL8+r4lJQUSklJ8aoPZ7OTuKlo0qRJNHHiRNOvV9BRyyQkzES6KFJYC2QrZGRkuH1MSkqKVP7kk0/wySefeCXD4sWLsXjxYkVdamoqAGtsQo8ePfDo0SOEhIh/r4AAg/WxKeAU4m4VCGhs2rRJsf3hhx+63Qd/+KekpCgUgzdITU2VFAAAhXJgjOHu3btu+SAEchb++OMPs0VwG8JnIBDQuHHjBgoUKAAAGD16NEaPHm1Iv8nJyQCUTuNPPvkEnTt3xscff4zOnTtrPsydTUnVyoE0depUDBkyxBCZBXSCEEx+A+EzEAwOXrlyhaZOnWq6HI6oNXW1W7duNGjQIAKUeZUE/cTg8hsY5jMIg4BAACN37txgjGHy5MlIS0szrN+kpCQwxrB06VKHbfj0UznUIwP5dmpqKhYvXoyQkBC88sormDJlCp48eWKYzAICvoRLnwFjrAhj7GvG2I+MsR8YY31t9a8yxvYwxn6xfb5iq2eMsdmMsbOMsROMsaq+/hICORuDBw/Gw4cPvepj5MiRAJTJ6FyZSLXenvg0VQ65wuB+g5CQEPz5558YOnQoRowY4ZXcAgZBOJFdQ4cJpxCAqrbyCwB+BlAWwFQAabb6NABTbOVIAJ/BevlrAzgozESCnvL69et2dSNHjnQ6xXTIkCF2dcOHDyfAmtfISPk6d+4smYoCNV9S0NGRSShnmor8ZyYioqsArtrK9xhjPwIoDKAZgPq2ZisAfANgqK1+JVmf8gcYYy8zxgrZ+gl4rFu3Dm3btjVbDAEnGDduHABgyJAhmDp1qlSfnp4OwD6SeMqUKXj8+DH69++PR48eITExEcuXL/dKBmeBbF27dsVrr72GvHnzYujQoV6dR8AD8FGAnxzGc+fORa9evfxzMh/CramljLFiAKoAOAjgb/wBb/ssaGtWGMAl2WGXbXXZAi+//LLZIgjIoGXK4Q99+Tz+IUOGYMKECZgwYQLGjx8v1U+cOBFZWVkICQlBrly5AMBrRQBYlYBcEXTu3FkyExUoUABPnjxBVlaWQlmZhXXr1pktQmDAR8ohJygCAHBnxs/zAL4H0NK2fVu1/5btcyeAd2X1XwGoptFfKoDvbDR7qCUYoPzzzz89Om7EiBEEgMaNG0ejRo3SFbmcnJwsRSh7GqnMTUYTJkygyZMnS/UzZsygTZs2KdouWLDA59dv3759pv+GpjBnmoS06N/1DADkArAbwABZ3RkAhWR+hTO28iIAbbXaCZ+BoLu8du0aAaDBgwcr6rkPgD/01eT709LSaNiwYTR8+HAaOHAgAaDExESv5ercuTN17tzZ4X5+fs6TJ08S8EwBLFq0yK3zLVu2TCqvXLlSsW/58uWm/06CptF/WUuZ1QC7AsBNIuonq58G4AYRTWaMpQF4lYiGMMaiAPSC1ZFcC8BsIqrp4hzOhRAIWly7dg0TJ05Enjx57EwuAwcORO7cuTFp0iRFfXp6umSmmTlzJtLS0pCZmQnGGJ4+fSplFE1MTATwzMeg/nQErXQWPFiNY/z48ciTJw/CwsIQGhqKPn36YNasWejXT/oXwtKlS5GRkYGwsDCkpKRgzZo1ePr0KZKSklxel5UrV6JTp0529V999RVeeOEF1KxZExs2bEB8fLzm8ceOHUPlypVdnkcg4OG/oDMA78KqgU4AOGZjJIACsJqAfrF9vmprzwDMA3AOwEkA1cVsIkFP+ccffyi2+/Xrp9geMGCAFOAFgIYOHUoTJkyg7du30+HDh+n06dNSO97GiJEBp3p0wM1EY8eOJQA0ffp0mjVrFgGgefPmEQCaO3eu4m2e1wPP3vKXLl1KAOiTTz6hjz/+2OH5He3bsGGDZv25c+do3bp1fv0N169fb/p9ZDQ/++wz02WwUSx7KRgcVCuDAQMGUN++fe3qACiUwr///W+6ePEiHTlyRLPfhIQEj2Vy5E+QK4bx48fTtGnTaMaMGfTRRx8p2skf/osXL9Z8OMvNQq7IFcLVq1dp+/btdPToUbp16xbt2rWLDh06pOk3uHr1Kt24cYP27Nlj+m+cHZkTlYFIVCcQ0CAiRUK4GTNm2C0cU6lSJQDA9OnTcfLkSSxbtgx//vkniAi//PKLZr8rVqxAQkKCITLy5HfcTNS5c2fcv38fFosFAwYMkMxOCxcuBAD07NlTksFisUhRyosWLZJMUHpMRRxPnz4FAOzYsQNVqlTBxYsX8eTJE0RGRuLkyZO4e/euov3y5ctx7do1nDt3Dm+++SYOHTqE8+fPe/r1gxIRERFmi2A4hDIQCHgsXrxYmtc/bNgwKfHbihUrsGLFCuTLlw8AcPr0aYSGhqJWrVqoUaMG7ty5gxIlSuCbb77BhQsX7PpdsWIFEhMTkZiYiKSkJF0P4OTkZPmI1g7cd/D888/DYrEAAPr06YMFCxZI23PnzgUAJCQkICwsDJmZmQCs8QmeZDrlx3Tp0gVFihRB2bJlcfDgQRw6dAiFChVCVFSU1Hbu3LlITEzE66+/jhdffBE//PADAKB48eKKPletWuW2HBzffPONx8dq4cSJEzh48KChfQpowGwTkTATCTrj1atXac2aNU7b9OzZk3799VcCrDbx3377jS5cuCCZiAYNGiRFJaempnplIlIzOTmZkpOTCVCaj8aPH0+zZs2iOXPm0OzZswmwziRavHgxAVa/wfz586VtzmXLlkn+Ark/QD77aMmSJZqy7N+/X7pmR48epUOHDhEA2rp1q51vYdGiRbRu3Tpau3YtHTp0iPbu3SuZjFatWuXz3/XTTz81/d7KIRRrIAsEB65cuYLhw4drBorxCOTevXsjJiYGRYsWxf379wEAu3fvRtWqVfHVV19h6tSp6NevH0JCQnD//n0pOIzPJpL3rZ5hJMeyZcukMk+BLU90l5KSgk8++QQpKSkoWrQonn/+eRARcufODSJCnz59sGzZMjx69AgZGRno27cvlixZIpnClixZgr/+9a9o3ry57uuzaNEiEJEi55JeLFu2TBoNcbnVcDRryRssX75cus4CXsOw2UTCTCSgGzdv3jRl0Q4tRdC/f388fvwYAJArVy5s374dpUqVwvXr17Fx40bcvHkT4eHh0kP9yZMnePLkiWIRGiLSVAR837JlyxTmoKSkJKnN0qVLpX3JyclITk6W7P158+ZF7ty5MWDAAAwcOBC9e/dGnz59pD6ePn2Kvn37YtGiRWCMKcw8WVlZ0vm4j2HOnDma12XWrFno2rUrunXrJrWdP3++84tpw+LFiyVFsHTpUkmRqdNrGK0IAOt15oo1J0ZHr1271mwRPIJQBj7ChAkTADx7c5w7dy62bNmCyZMnY968ebh+/Tr+7//+T2p/4sQJAECPHj2c9svbeYrjx4/j+PHjirpjx465PK558+Z49dVX8fe//12qu3r1Ku7du4fr16/j7t27uHPnDu7cuYPbt2/j9u3buHXrFm7evImbN2/ixo0buHHjBv73v//hypUruH79On7//XdcvnwZly9fxqVLl3Dp0iUcOXIER44ckfJDjRgxQtOWP3PmTMlxCjx7YIaHhyMzMxPTpk0DAGRmZqJv377Imzev9PDmDzg9o2L1CIExplAaiYmJdoph3rx5ePLkCWbOnCnJJl/0pl+/fpg1axYYY1I8BGB9kMfGxmLhwoVYsGCB4m1/9uzZ+OijjzBr1izMmDEDs2fPVsQsdOvWDbNnz7a7f2bPnq35veROeT7KSUlJscvK6iskJSVhx44dOTIPWLt27cwWwSMEpJkoMzMTYWHPcug9ffoUuXPn9rtcAlY0a9YM27ZtU9Q1b94cW7du9ai/+Ph4bNiwwa6+devWWL9+Pdq0aYNPP/1U+nQE/jYLAN27d0dISAhCQ0ORmZmJkJAQ5MmTR8pLlJWVhY8++ghdunTBkydPsHLlSmk2EZ9ZtGLFCt3fISEhQcqNtGzZMin5XVJSEkJDQ1GiRAk899xzICL0798fgPXBTETo27cvpk6dipCQEOTPnx8Wi+XZ9L6QEBCRIt/NzJkzQUQYMGAApk2bhsGDB0v71IFs8mP69++PGTNmYMCAAYp9vG7OnDno3bu3w++4atUqdOzYUfc1CRZMmTIlkBIQGmYmCkhlIBA4aNq0Kf71r38p6rSUgzPExsZi8+bNAIBWrVph48aNdm24guAKgX9ydO/eHQsWLJC2k5OTHS5M07t3b+TOnRtZWVl4+vQpQkJCYLFY8PjxYyxdulR68Hfq1AkrV66UjktKSrIzHWlBrkT4NmNMsoUvX74cY8eORVhYGHLnzo1BgwZJD+AZM2YgNDQUjx49gsVikfwJuXPnRkZGhqTEMjIyYLFYkJGRgdDQUAwaNAgffvghBg4ciGnTpsFisWDo0KGYNGmSFOVMRNL+wYMHY9y4cQgLC0P+/PklMxUH70uN+fPnuxydyrF06VJpZCFgCoQyEPA9oqOjsWPHDkVdTEwMtm/frruPFi1aSOYwuVKQQ60gHI0cACje4BMTE5E3b15kZWVJnw8ePECePHlgsVikB+vjx4+RkZEhPawtFos0MpA/xAGrf0Ief8DNRPL/E3mMAj+ey8YYw3PPPYeXXnoJkyZNwqRJk5AnTx6EhITgyZMn0ihg+PDhmDRpEkJDQ5GRkSFlYp00aZLUJjQ0FNeuXcNf/vIXWCwWXLx4EcWKFYPFYsHvv/+OAgUKIG/evAgLC5OOGTlyJNLT0xEWFoZ79+5hxowZmDhxIoYPHy7JP3jwYBQoUABEhFy5cmHQoEHSPvkoxBncVRo5AQG6nrVYA1nQezZp0sSt9k2bNrWri4mJcdi+efPmUrlly5aabeLi4hTbrVq10iVLu3btCAC1b9+eAFCHDh2kfZ06dZK2O3XqRIA14lhe1voErKkqONXnTExMpISEBKl9p06dqGPHjpr9JSYm0sSJE2nkyJE0YcIEGjduHI0bN47S09Np5MiR1LdvX0pPTyfgWeS0syypfJ+jNlprMcsX27l58ybt37/f8MV9BJXUWljJxxTpKAR9x8jISF3ttJQDZ7NmzaRyixYtNMuxsbGKY9SKwV22b9/eTkl07NjRTjHwhzgA6eHOlYX6YS/vX96O98HbyMv8WLWSSUpKUvSXlJQkxSioP72hXDHIFYJYiS1HUigDwcCiXDHIRwvy0YG8LFcKgL1i0MM2bdoottu2bSuV27VrJykGrgw6duxI7du3V2xzZcE/5SMMPeRKRa0UHCkSrRGHWknIFYJcSahHBTwXkifrLgiaw/79+xvdpwg6Ewh8yP0L8jKfiST3J+hB69atQUQKR7N8H2MMn376qTRdcd26dWjfvj0AgIiwdu1aaXvNmjVo3769VN+hQwe78xGR3dTS1atXa8qmnnXDj5M7qPmUVsaYlAqDMaYIZhMIHPTp08fh1FxvwKcWGwThQM5RIPhtvVYzoDUjST41tWXLltiyZYu0Ly4uDps2bVK0lzuZtWYkOXM6A0Dbtm0VzmC+1rW6Tq48APs1Dhwpg3bt2mHt2rWSUuDTMhljWLlyJTp27IhVq1bZKQTA6ggPCQmRZkep10YQEHACoQwCClx6Tx7oXBHIr0BOUAwOrgmflupsVOBo+qmjtgAU7Vu3bm0VQTaKYIzZPeABKEYScsjbS1+JyC7uoW3btlK8gXoksWbNGoVCkCMtLQ23b9/GwoULnU6TFTAW6inKetGrVy8pwWCAQaSj8Av0Wu0YvH+AMxl5v97CEwukUX06uB48PsGZeUhLEcTFxQF49vCPj49HXFwcNm7cCMaYtKJXfHw8iEjKEBofH4/169fDYrFIZiWLxSI91Nu0aaMI+uLIysqS6vh+RxlF16xZI52Pl4kI7du3x6pVq2CxWOzMSHfu3JFSSCxduhTDhw/HBx98YNd3eno6xowZ4/BaKeCp1TmI4IkiAKwZBHjq8ZwKMTLQgjdv+p6cy9l5DJZl7ty5iuhW3bI4kGP06NEYPXo0GjVqhD179mDMmDHYv38/9uzZg4EDB+LD6R+6lP23337DqFGjpPn63B+gNv3ExsYiLCwMWVlZsFgskmmpVatW0oN806ZNiI2NlR7cFotFyv8jN/ls2LBBGoFwXwQAp6Ymfi7eh1o5rF+/Hm3btpX64iMG7qtgjGH16tXo0KEDXnrpJWRmZiJ37tzS8phhYWEICQnB2LFjAQAffPABGGOwWCwIDQ2VrvXo0aPt5Bo1ahRKlCiBfPnyOVzqUhM53ETpDKmpqYpcVdkUYmTgcxj0D+LUOejkH1FyMKlHC16qTYeKwBUcjH5Gjx6Njz/+GC1btpTq9uzZYz1E4026adOmdnVFixaVImRbtWolvcEzxtCqVSvExsYCADZv3oyMjAxkZGRI+XxiY2NhsVgkcjksFov0dg9YTTh822KxSKOKuLg4ZGZmYsOGDS4VAWAdtWzcuBEbNmyQzsHPHR8fL20DUJxfPlpYvXo1bt++jQcPHuDx48d49OgRnjx5AiKSFMH06dOl/rOyshRrEmiBX2uetqVNmzYA4LPcP8eOHbNbNAfQXh/an3Ant5KnikAebHfq1CmP+ghEBJwy4At9BDJmzJih2B43bpzDtg0aNDDmpGozUgChc+fO6N69u129PJEch9qRHBsbi/j4eFStWhXAs4caVwrcDMQVAmMM27Ztw7Zt29C0aVNs3rwZmzdvRlZWFrZs2SI97HkZgPSw5mVu/mnZsiU2bdqkGRWtBxs3bsSmTZsUSkaufPj5+IOZK4q2bdtKD/olS5YgKysLhQoVwoMHDzBs2DAMGzYMgwYNQmZmJiwWCyZOnIgtW7YgPT1dkdVUjgIFCtj5QnyJypUr48UXX5S2R4wYAQCaabCNgN6UF5463t1RIvPnz5fu9/Lly3t0vkBEwCmDatWqmS2CS7gK1ZfjjTfecLzTwRt/WlqaZvIxxXEeQCsVsjcJt5YsWYIpU6YAeGaLHTVqFEaMGIG0tDRNZaDG5s2bFW/kGzZsQGxsLDZu3IjY2Fjpgc0fdJs3b0azZs0QExOjePjJfRFEJM1WkvsP5D6AzMxMxQwmb7BlyxbNcxCR5K+QK6N169Zh3bp1Ut3ChQvx/fffY9asWbh9+zbu3buHfv364eHDh3jw4AEGDRqE+/fvIzMzE5mZmejSpYvi/PHx8VJq8YcPHyoSCBqZIlqey2jYsGGKfTyduB5oTeN1BU8c7O4sHequEvHU9xDQMDvgLCCDzsi+Tr0Iu8M/b8/roB8erFKjRg2KiorS3afDlBOy+QMdAAAgAElEQVQO/t59911d/UZERCjlcPanOk5ejo6OJuBZEFrz5s0pJiaGWrRoQc2aNaPY2Fhq0aIFNW3alFq2bElRUVHUvHlzKcLZWTqMmJgYiomJoebNmyvoy3snLi5OYqtWrahVq1ZSZHV8fDwBoNatW1Pr1q2lMmANoGvTpg21b9+e2rRpIwXM8cA6dXtOnr6jVatWdtd8/fr10opiPLCNy6C+19X9OqK7QXmOyCPEtagO2POG6oA+X9KkCG8RgexT6vnT0U+FChWMkcFW5yi/j9vfw7a/Tp06xl0XWbt//vOfTrcbNGig2b9ccamVWEREhKRI+D6ujPinOj2Gs3QZvmLz5s2pZcuWFBsbKzEuLo5atmypSLchVxJqBcL3833ysnyfnHZ17vxp9KFWDq1bt1YoEnm0t6D/2bVrV14WysATlipVyqvjS5cu7RO5Klas6LyNm0pID2vXrm1X98477zg9Rj5qeO+996Ry3bp17err16/vljyNGjVyuK9x48Z2dVojHndGTL4kH4Hw0Q5X4vJPeTqOli1bKshTc8hTdKjTdci3eb+uUnp4m/tJParQGk2oU4Rw+kJ58FQgQU7DlEHA+QzUKFasmGF9/fzzz14df+bMGYMkUcLl6mVy57EHqF27tl3dgQMH7Or++9//2tW9++67Unn//v1See/evc/Ek9nuef0333zjkayO0LhxY8U2YwwRERGIiIhAZGQkoqKisHPnTqnsavaNr8GnnvLZTTzKOjY2Flu2bJGc4nzb23P5A+rZVlrndeS49sR34cq3sGrVKrf7zAno2rWrT/oNeGVgNEqXLu3V8W+//bbuthUqVNDdtnLlyi7bOHOu16hRw+E+rQc/ANSpU8e1YDogVwxa+Oc//6nYbtiwoVv9a83p16rj9Tt37sTOnTvdOocc3uYKkjtw5bEN8um3agXAv8+WLVuk2U3OZjl5OgPKSPh6xpKj1B/BDvkSqkbCpTJgjOVljB1ijB1njP3AGBtjqy/OGDvIGPuFMbaeMZbbVp/Htn3Wtr+YNwJeuHDBm8Pt4O3b/Y8//oiyZcvqanvy5End/epZh/j77793uO/w4cO6z8Wh541SPhrwFHrO4+nbLX+I7ty5UxodeAt3ZqE4Ao+wtplBsXnzZhCR4iGuVmjOHvDq68Mjsp21cQQeOBcskK9Z7Q/4ax1po6FnZPAEwPtEVAlAZQBNGGO1AUwBMJOISgK4BYBPME4BcIuI3gIw09YuR+H06dO621asWNGHkuiH1ijg//2//2dXJzcLuYP69etL5ffff99pW/lDi5t/nD3InI0CtBAdHe30/P6C/O2/ZcuWihXf1DDqTV9LSTiDXDG4FbnsJXhCQAH3YZqZiKy4b9vMZSMBeB8ATy25AkBzW7mZbRu2/Q2Yl0bN4sWLe3O4qXDpD3ATzsxBtWrVcrjv22+/1dW/q5FAvXr1pLJcAWj9xDzg7t///reinkcoewNnysFfNnQ9cCaTOjOrO3B0rLM+3TkfT/ZnBNq1a2dYX/6CEaNDX8E0MxEAMMZCGWPHAPwJYA+AcwBuExEPF74MoLCtXBjAJQCw7b8DoICRQnuLMmXKGNKPOz4Bo+DMHHTw4EG3+6tbt67LNnIF8J///Ecqyx9wX3/9tVRWP/ydgfexe/duAEB4eLhif2RkJBhj+Pzzzx0eH0gPfzW4mUgOo+WV9+fuyMAsrFmzxul+dWI/NXgqcD3gOa/cQSCvMWGqA5mIsoioMoDXAdQEoOVF5Xe91p1u9x/BGEtljH3HGPtOr7CBBnd8AlWqVNHVzt8R2Pv27bOre++993QdK1cAnoIrAQ6tB+Vnn33m8Hith62A8eApNZzBV3mQzIA3IwNfr0Vh6siAg4huA/gGQG0ALzPGwmy7XgdwxVa+DKAIANj2vwTgpkZfi4moOhmUcc8d/PTTT/4+JY4ePaqrnTMnMYczc5AzeOoPkI8G3MFXX33l0XEAEBERoasdVx5NmzbFjh07ANjnPwoEEJE0m4grsEB4i5enC9eTqM8TrF271if9+hKBPDLwFfTMJvorY+xlWzkfgIYAfgTwNQB+NycA2GYrb7dtw7b/3+Tl69v58+e9OdynMMNB7Mwc9M4773jdvy/MLo0aNVJsq81BehEZGSmV5TOH+C1mdnyBHI6mknJnsfzfQsupHAjKQiB4oGdkUAjA14yxEwAOA9hDRDsADAUwgDF2FlafAM9d+wmAArb6AQDSjBfbOxjlMwCMdxD7EnqniboaCajjBrTgbiyBI+zatcuuLiIiQlIEO3fulEYEHNHR0V4pBSPeClu2bAkikpQAVwzqh76WQuBKQO87FG+n10GsdxU5M5HdA8qy4/TSMFcNiOgEADuDNxH9Cqv/QF3/GEBwTWQOIGhFERuB+vXrGx5VrEaTJk0kR7GWn4CbjrRGLjt27EB0dLT06Snmz5/v9UySFi1aAIAUZcwf1lxueR0AhflIa4QAeDfzyBPwNSUcoW3btoZmRBUwH0EXgQyY4zNwB9Wr+8aNotdnoOVAlisCTxzH6rdcV2+9TZo0cdiPq2N37NjhcfCZfOESb8BHBdwkxNdXkMcT8LUYeJkf541VVc+x8tgCRwFo6qml6n61FIEj5aA1tdRVnIGr2UTqNaVdISEhwXUjFbwJVvO1E9knMDtJnT8T1ZnFSpUq6W5brVo1j89Tq1Ytn8gvT0rnL4aHh9vV8eR0DtNyazAyMtKvcvP02zxRnTxhXYsWLaRkdPJj5PtbtGghJZzTm6VWnQhPTa3Ed5zqTKrufFetRHV6U2EbQaPSafuKKSkp/jhP8CSqywk4fvy47rZ6ZhM5grtxBlpRyf/4xz/s6uS5h+SjBnn8gbzM4SoS2Rm03nB5naOYAw6evC4iIgKMMURGRiIyMhLR0dG6Zyl5ipCQEMVCN/yzWbNm0sI76rxEvB3fLz+Og5uetKDVXn6MPCWGGnLzk9qXoB41qLe1TEmOzEtaU1O9nYrqbu4iV6MNLXgyouAwewlQt2H2qCAnjwy8Ws9Ag85GDTVq1HC4z5sRwz/+8Q+3j3EnfXXDhg0V2zyVtVZKa61U1nLKRwEREREUGRmpWAchMjJSWkzHF4yJiaFmzZpJnwAUC/HoWVhHPpKQt5envNba1qKzNvJRhLxsRBpsd0cYzsgX+RF0yJw7MihcuLDrRtkA5cqVcysozRWqVKnicaI6rRGDVqyCVqpref4i+ajBmf9B7WhWjxrks5H4WzSHo7dcR3Vy7Nq1CxEREWjSpAk+++wz7Nq1SzomIiICRISsrCyH/ghv0LRpU4WMRISYmBhs3boV27ZtQ7NmzbB161Y0a9ZMYvPmze364Ut1yvtSl7W2tcBzIWnBUd+uciTpcWQbOVtJT4xCdkx3EYgIOGXw+++/my2CIfjhhx8M7U9v0JpeaCkIdaprdcyCXDHIp6lqmZbkaS7UU1XlDx916gpHyqBhw4Z2OY3UaxwA1llIRITw8HA0adIEjDFpPWJOV2Ymd8GnsfK1jokI27dvx/bt2xETE4OmTZti27ZtiImJUcjBFWFMTIyiP54C21Nl0KxZM5f75IrCmdKQx0o4qtOaAeVoVpSWw9pVHiRX5iR3g9pEkjwHMNtElJPNRGbTG2e0mloro2mRr4amdy1lTr40Jv98//33FfvVS2VqmZcaN26s6Xjm9ep94eHhLk1PrshNT1FRURQVFUXR0dESeZumTZs6XIIzOjpa1/KcfK1n9aeztq7qHO339TrRntCfjulsxpxrJvIWRYoUMayvEiVKGNaXL+Aq+lnLrKSV+0hrKqs6O6qjBXLUpiU+YnB3HQQtp6sc8u0GDRooths1agQiwhdffIHdu3ejcePGaNSoERo1aoTGjRuDMWbXX+PGjaW8SFojDFdo0qSJ5KDmowEuJ/8u0dHRiIyMxL/+9S8QkRQMJ4+DICL861//chkkp74+27dv12wXHR2t2MfPpdWem7bU++WL8wDQNGepoeXgduT01ho1uFpjwVXcg7+Qk0cVOU4ZXLp0yZB+ihUrhl9//VVX2zfffNOtvr1dbY3Dk+hnLQXx3XfKXIHVq1dX+CA8WUWNQ+9qamoloH54y30LRCTlPWrQoAH27NmDL7/8UtF2z549koLg5huLxSI9+L/44gvFWgpccajTZqghNz/xB7/c7GOxWLBr1y5YLBYpMjoqKkra5m35w5/HQxCRItWGo+uj9rFw8P70mJN4W705nNTKAbA3RWmZmhyZn7T8EtkhKhpwnW01OyPHKQOj4M4Ka+fOndPdtmTJkl6vtqZ3pTVPoVYOeldR405p+WjB0cNLDfXDTn6cPPq5fv36Cj+DPBne+++/jwYNGkiKgX9yZcEf2HzUwB/kahnlowo5w8PD7ZQA91HI64BnaTS4E1ter7Usp9zZ7ez6OGrDv4O6X63gO0d9yEcnrqK4t23bZlen5atwNKpwNlU2OyAnZWiVYLa/QPgM/EOtaa4VK1Z0eVyVKlWkctWqVRX7PPFJvPPOO1K5Tp06Ulk9hbVu3bpSWR70Vq9ePc0y9zVwvv/++9SgQQNq0KCBwv/A6xo2bCgRgKLcqFEju6mtjRs3VpDXcV+E2l/By/IAufDwcKfbzsinx+ptr9WO9+GqTotaU3L1+Dp8QVfTXz1hfHy8R8e1adPGlGsgo2E+A+bsbcRfYIx5LMTf//53/PHHH17LULhwYd0zmYoUKWKYOcodlCpVCj///LMhfZUrV85uxlOFChUU02ErVqyoMEVVqVJFmtVUuXJlxbrN1apVQ0hIiPWmYgyHDx9G9erVERoaqpi5VLt2bRz4r3PTkiaYNeCNiKTzEBFCQ0OlT8D6hvzNN99Iq6xxmz5gnbnEzUDy+57nDOKy80815N9P/cnbf/7554iIiEBWVhZCQ0OlHEvctKTezsrKko6Xz3KS52mKiIiAxWLB7t27pfqIiAjN/E3y4/ixztaDgLP/PE+T1/qiTxliY2PtTU0+PidHeno6JkyYYFyHsPohvDA/fU8GLQMQEMrglVdeodu3b3t07N/+9jdcu3bNMFmkB703l8WHC2+VLl3aazMTh/rh7wmqV68umZVq1Kihz6REQIeOHbB69WrUqVMHUVFRSE9Px7vvvov9+/ejbt26aN68OR4/fowHDx7gwoUL0vTBevXqSVNVtZLnNWjQQDIdNWzYEF9++aXCD0BE+PLLLyV/gfr+lz/YuYNZHpfw+eefS+YiuRL5/PPPERkZiZCQEGlfZmYmcuXKJSkGXpeVlQXAPhmf+kHu0T2oa2mpZ5gzdw7u3r2L9PR0TJkyBUOHDtV1mpiYGMnp3KJFC6V/gIDJUybjwIED9v4GJ7K0jG1pF52t91j1946LizMsuV+nTp2wcuVKdOhgvWeHDx+OiRMn6jp2/PjxGDFihCFyOIBhyiAgfAZyRSBfV1cPjFQEgMoBzTykG1Db/0uWLOmZ4DaUK1dOs15riU4tRVCpUiWn21WrVlVsy/0LXBGo10nWCmbjPplvv/0W6enpAJ7NQNq3bx+ef/55vPLKK8jIyEChQoWkfuUxC998842udNrAs4c8VwTcuczrQ0JCEBJi/XfgCiI8PBzh4eHSNmMMERER2L17t6QI5G/63E/w8ssvI1++fChYsCBeeuklFClSBK+//joKFSqE1157DUWLFkWePHnsZOSKQK58Bg0ehNFjRtvdY6tWr5LK4U3CNe89ReoNBowaPUrRx4qVK9C7d2/p+jtSBPJZRxzy2UdajuK0tDRs3boVixcvVu5w8n/jVBG4Ot6H4EnxePoLuSKYPHmy02N9rAiMhdn+goD1GZB77QcNGkQAqH379rqPcZS2oVSpUuZ/fyfU8hXUq1dPstvfunXLzoavvrZyf4G8rMVhw4Y5vX7qc/GYhEaNGkl+AG7nl8cbyG3//Fhux+fk6SzktnW+zdNdAKCoqChq164dJSYmUnJyMqWkpFBSUhJ17tyZUlNTCQD16NFD+uzWrZvmd1Ik4yPr57hx46T6ZcuW0erVqxXHTJ8+Xfe9rNdHoJf8O8k5ZcoUXeky3Kab/5NBQsN8BqYrAq4Mli1bJv3TVKpUyS7oyO/0oTL47bffCHAdCOSIZcqUcdlGyzmst07NypUr65KLP1S7dOni1bXlzuOFCxcSAOrcubPDto4UgtwZLP/kikDt4HXk3JXXRUZGSvmPoqKipO3IyEhq1qwZxcfHU+fOnalr164EgPr160e9e/emfv360fHjx+mnn35y+x5MT0+nsWPHSnXLli1z/54x+CF67Ngxj8+1YsUKz89tkjK4e/euKefVyZynDDIzM82+qIbceHqUQd++fal169ZOZ2OULl3aJ99LbzpttQJQzyTyx7WdO3cuAdYHolYaba4I+Cd/gdB6kVCPAuQPfz4CUI8E5A97/vDnkca836ioKGrevDkNGzZMkbJ44MCB9MEHH9DgwYPtZOnZs6fm91UrKpD1fhozZoyi3apVqxTbEyZMMOR6+/q33blzJwGg9evX+/S+CTLmPGXA6U7GS7Nuam+oN6Ojr5QBoP9N36fX1oM/ZwpBPjLgZiJ5ygr1g1Y9FVRtCtIiVwbqunbt2kmjoZ49e1JqaqrCDNSrVy/q27cv9e7d2+W1UaTM8OTPwOvt79/XZ/16810Cn4YpA5fLXvobvl5aUTcYrJfa02Md4IUXXtDVBZ8xVLZsWZw+fdpDQayoWrUqjhw5Im3Lp4RyqKeKAsqppIbCyfXhM4oAa7K7ffv2oW7duggJCcF/9v7HzonMnb589bX3339fmk3EGJNmFPGZQRxffPGFwlFrsVg0p5MCkCKDbS8uEqKiokBEyJs3L5YsWQIAmDdvHgCgS5cuUru5c+c6/sIqhIaGokmTJggLC0NWZBYsFgvCwsI8XrkNgM8drH47n7+/R7DB7FGBemQg6Jh67Ps5mY5GjVrmIXUiO8DeUazerzYRAVCYhuQ+A05nK6kNGTJEKo8ZM4Y+/PBDXd9THuDly/UXBHMEc66ZSNAxfa0MTDcfyag2CTlyFGtRnuFU7yI5amexmnL/gdY+f1wToRgENSiUQU5j2bJlve7D24e5O2s1+4rO1ltWp7lWlwGlIlDPKOJUjxC0RgnOyEcPXHn4QhmoJxcIRSDogEIZBDIDPU4gO1NLIahnD2mZiDjVisGRMnCWN0itDHidUd9RrQjU255OSRbMkRTKIJBZsmRJ02VQU2vUEEhmIXeox1QE2CsF+bY84ZyaekcKrkxLRlMoAUEN+l8ZAAgFcBTADtt2cQAHAfwCYD2A3Lb6PLbts7b9xYJNGfiSWtlH5XT2gJdnIA0kuloVrWPHjnZz7R2NDLRiDLRWRVO3UWcl5SMDreAzOdVBab66Rs4UQSCuTCboN/o/ayljbACA6gBeJKJoxtgGAFuI6FPG2EIAx4loAWOsB4CKRNSNMdYGQAsicrrIKWOMJk+ejLS0NKmua9euWLRokS7ZvEVqair279+P06dPo3z58vj0009Rvnx5XLlyBfv27cOtW7dQrFgx3LlzB0WLFsXzzz+Pb7/9FhUrVkT+/Pnx008/AbAu9BIeHo67d+8CsOYw4VM6y5Urh6VLl6JQoUIgIuzYsQOtW7fmyhApKSlITk5GnTp18Pnnn+PFF1+Upjr++uuvaNeundT2xRdfxIEDB/DHH3/gxo0byJ8/PwoWLIgZM2aAMYb+/fsjOjoa4eHhCAkJcZ610gN88cUXAICCBQuiZMmSuHbtGgoWLAgAkowyRe/R58OHD3H9+nXpnH/5y19AROjRo4eUwyc8PFwxZVSdobNRo0YICQlRtGnUqJHTtZT5d1MjPDxcKsv7a9GiBXLlyoXQ0FCEhYXZ/YOtW7cO8fHxyJUrFywWC9atW6fot23btiCyZl3lmSsHDBiAF198Eb/99hssFgtWrFiBLl264MmTJyCyrpvA2yYlJUnrKISFhcFisWD58uWa30EgR8KwRHV6RwWvA/gKwPsAdsA64/d/AMJs+98BsNtW3g3gHVs5zNaOGTUy+O9//+u0bubMmQSAZs2a5bXWvXfvntla3ye8fPky3bp1i27evEmXLl2iixcv0o0bN+j+/ft07949iXfv3qW7d+/S9evX6cKFC6bLzX9brbd0/lbepEkTunz5MgGg8+fP0+nTpx3217hxY81RgjPKRwkxMTHUpk0b6tixIyUlJREASkxMpISEBLf67NSpE3Xs2JH69Okj1Y0fP54GDhxIwLP0Hl27dpVStgCQzsnJ93fu3Jm6d+9u+u8l6Bf610wEYBOAagDqw6oM/gLgrGx/EQCnbOVTAF6X7TsH4C/CTCToLadMmaLYvnr1Kv3222905coVAp7FBBw+fJiOHDnisdlGvYiNJ5SnptBix44dpXKHDh0IsMYi8PQV48aNU7Tn+Y7klCsG+TZPHpeYmOjX34fn58qJDOAXQ/8pAwDRAObbyvVhVQZ/hb0yOGkr/wB7ZVBAo99UAN/ZaPYFFczGvHjxIgGgc+fOEWBVCidPnqRvv/1WeovnD3ajEiByJ3OnTp0IsCbTS0lJoeTkZEpKSlIk6+vYsaPiwczLbdu2leo6depEEyZMoMmTJ9udS60I1BlPHWVABXQkDQwSbtq0yavj+QtHANKvymASgMsALgD4A8BDAGtgkplIMHjpapaPOm/QsGHD6Pfff3do4pIvjemsX74MptZIoV27dpScnCxt84cvHxmo384TEhKk7KXt2rWT6pOSkmjy5Mn04YcfOjRxctNP7969HSa7A6AwNwFQyOdLbty4kSZMmEAzZsww/V4BXCfwu337tu6+PFEGf/zxhz++pzlTS2EbGdjKGwG0sZUXAuhhK/cEsNBWbgNgg45+Tb9xBAOX0dHRFBERQc2aNSMA1LJlS+mzadOm9NNPP9HFixeluf5xcXHSsb///judO3dO+jRKJrkZyVNzjHz93HHjxtHGjRtp/vz5uo/v27evVHamHNRMTk423KewceNG0+8Tf/D+/fuKzwBgQCiDEgAOwTqFdCOAPLb6vLbts7b9JYQyEPSUMTEx1Lx5c83FUmJjYykqKoqio6Pp1KlT0pvY1atXJX8CYH2r01IE//znP92KWXDmbFYrBG4+kjuTO3TooBgNANbRy8iRI2nixIm6FMHFixcln4B6BOCIWg8uLR/E1q1bdfV3/Phxu7qvv/6avv76awJAGzZsIMCqiPX09/DhQ13tTpw4Ydh95c6ogPPAgQOGnd9AiqAzweBhXFwcxcfHSwohLi5OmlvPlYXaRMR56tQpySxTr149Q+TRMhc5MsUkJydTYmKiQgn069dP0YYrgqlTp3osEx8l9OrVi3r27Km5ApmcXbt2pT59+kgzlrTkcoeLFi0iwDoLypPj9So2MxigSoDT/3EGvgRjzHwhBAIerVq1QkhICNavXw8AaNasGRhjmDt3LjIzM1GsWDGp7ZEjR+zWa9YCX0OZp8DWQoMGDQBASo0thzrWQQudO3cGEUlxAmFhYQgLC8Mbb7yBUaNGSe3cWZDeW6SmpiJPnjyYM2cO+vfvj5kzZ/rlvM7Qu3dvzJkzx9A+b968iVdffdXQPgMM/o0zECMDQTO4adMmSkhIoHnz5rl13JkzZxTbzpLfeUotkxGfItquXTtpppB8xMBjEA4fPqyYFjp16lTd6a2NYiDOMtKz+I+gHYWZSDDnc9myZTRixAjatGkTffDBBwQ8c7rGx8dTfHy82316qxgczTxq3bq1YpsrBk75tFOzryunHnNSdiY3WV26dMl0WXxIYSYSyHmYNWsW/vvf/0pmIJ6O5ObNm3jhhRfQq1cvAEDr1q3tTEVbt27V7LNevXrYsGED/va3v2nuf++996SyfKUzednZ6nuNGzd2mMIiISEBK1as0NzXoUMHrF692mG/vkZKSgo++eQT084vYBiEmUgw59GROWjZsmU0Z84caZuvI92yZUvFNFItvvvuuy4T4XlKV/EJgoJ+oDATCeY8yiNy5eQzceLj4+nmzZu6+oqKirLLYWSEUtA7FVVQ0E8UykAwZ/Hu3buKICxHbN++PQFW34FcefBANDn15CaqW7euRF7H/Qrvvfce1atXz7ApqYKCPqBQBoLBS64QXNGfZhyjch4JCrpJoQwEcza1RgmHDh1SfDpSCv5aoB4A1a9f3/RrFSw8ePAg/fbbb5SZmUkAaM+ePWSxWMhisdDNmzcVJsQDBw5IBKzpz2/fvk23b9+mO3fuKDht2jS6e/eudBzvg0eE379/nyZNmkTr1q2jAwcO0Nq1a+mzzz6T2v/xxx+0a9cuu0yzaq5atcoX10UoA8GcQx4h7Ijnzp2THvxHjx7V1WeTJk2oVq1apn83QePIlYCggoYpgzAICJiMMmXKONzXunVrvPnmm9J2lSpV0KFDBzDGEBcXh3Xr1iE6OhodOnSQ2kRFRSEjIwMHDx6U6mrVqgXAOmX0wIEDbsv47rvvAgD279/vsE39+vWdTkMVEAhkiDgDgYDGlStX8Nprr0nb+/btw40bN7B7927kyZMHs2bNsjsmLi4OmzZt0tV/7dq1pbJWnMG3337rqegCBiMzMxNhYeL9VQURZyCYM6g3Y6WcWhk3ORMSEighIYFq1KhBNWrU8Jnc8tlHnL5Ie5GT+PDhQ3r06BE9evRIqnv8+LFUfvLkCQGgp0+fah4vzESaFBHIAgJ6UaNGDQDWt/1Dhw551EedOnUAaI8U6tati3379nkuYA5FVlYWQkNDDetPjAw0IUYGgjmP58+f9+v5atasSTVr1qRatWpJzubatWtT7dq1dR3/j3/8w/Rr5k/yxYPMokhkp0kxm0gw+/P27dt05swZOnPmjK6AMy3yaaaOWK1aNUNkfeedd6hOnTq62/sqBUYw8+DBg6bLEIAUZiKBnIczZ87g0aNHqFy5slf9vPfee3jttdfw6aefau6vXt06qmaM4fDhw2F+O6oAAA5dSURBVJptatWqpZiNpBd16tQRTmcBf0KYiQRzNrdv30579+51+7iYmBiKiYnxu7zvvPOOyzbZ2awUHh7ucJ9w7JpKYSYSzN68fv26rnYbN250WzFoLUvpC+r1LbhjXjKdjv6cHCOUgakUZiKB4MOmTZtQsGBBAMp1CByhcuXKOHbsmGHnr1mzJgB4PCNJDtPNSY7+45iDenK8z+hZQwJuQZiJBIObo0ePthstuEprwVmlShWqWrWqRPX+atWqUfXq1RV0Vz5nowZHJiV/mpGGDBmi2G7YsKG0qL0mnYwOsrKyTL8fgphiZCAgwHHo0CHkzp0blStXRpUqVXD06FHTZKlZs6YhI4eAA8HhyMBisSAkJMSv4ghIECMDQcFAoZ5I5+yaNK9hw4bWspORgcViMV3OIKZhIwOhzgX8ivv37/vlPBUrVvRp/3x6KgCH01OBZ9HPnkxT9RdmzJjhcN+XX37pR0kEzIQwEwkEDSpVqgTA8cL3avNS1apVFdtHjhzRfa7q1avju+++80RM/8Dd/zhHjmUADx8+xHPPPeeVOAIew79mIgAXAJwEcAy2YQmAVwHsAfCL7fMVWz0DMBvAWQAnAFQVZiLBYKGeiGdXZiVnJqVAnKbqSbJBQcPo3zgDWJXBX1R1UwGk2cppAKbYypEAPoNVKdQGcFAoA0EA9ODBA9Nl8BW1ZiUFC4UyMJUB4TNoBmCFrbwCQHNZ/Uqy4gCAlxljhbw4j0AOQf78+f1+zgoVKqBChQqG91u5cmVF2gxXJqRq1aoZLoNXkD9OvITc1CaQjaFzZHAewBEA3wNItdXdVrW5ZfvcAeBdWf1XAKpr9JkK4DsbzdaugtmIu3fvNl0GI+kqjsEXM5GGDRtGY8eOtW7r+XPSl7+zzQoq6Hcz0Wu2z4IAjgN4D46VwU7YK4NqwkwkKBg4HDZsmP72LpSBfLEaQb/Tv2YiIrpi+/wTwP8BqAngGjf/2D7/tDW/DKCI7PDXAVzRcx4BgZwAPmvJFdSzlfyJSZMmGdaXMBPlDLhUBoyx/IyxF3gZQGMApwBsB5Bga5YAYJutvB1AJ2ZFbQB3iOiq4ZILCMjw9ttvSwSAcuXK+V0GHttw/PhxXe3dmaoqIOBruIwzYIyVgHU0AABhANYS0QTGWAEAGwC8AeAigFZEdJNZXxPmAmgC4CGAJCJyOuFaxBnkbNy5cwcvvfSSoX12794dCxYs8OjY8uXLS2Wtt1rGGIgIJ0+edNgHd0o7a+MIRifQ8zkITuMMHj9+jLx58/pNHAEFRDoKweDmuHHjTJfBLOpJf+GS7v456Uu+qL2g3ykS1QkIZCdUqFDBo1GEGoEY2fznn39KqcUF/A7DRgYiN5FAtsGQIUMQGxtrthhugZukjFAEVatWDThFAMBwE6CASTDbRCTMRIKuuHr1asX2Z599Rm+99ZbpcjljuXLldLetVKmS6fJ6wydPnpguQxBTLHspGJz85ptv7OreeustKlmyJJUqVUqqK126NJUpU4bKlCnjF7nKli1LZcuWdfu4ihUrumxTpUoVn8qemppq+u8q6DGFz0BAwBfgU1N//PFHXe3LlCmDn376yZciBTyePn2K3Llzmy1GsEL4DASCAytXrvTr+X788UeniqB06dKKbV8qAr3BawICRkAoA4GARqdOnfxynooVK+paEOfMmTMen6Ns2bJuyaMneK1KlSoey2MURARyzoAwEwkEPcqXL49Tp04p6oyaCsrx9ttv6zY9ZTdkZGQgV65cZosRrBBmIoHggKdRxu5ArQgAY6aCcpQpU8YUReCv3EdXrojUYzkBYmQgkC1RvHhxnD9/3vB+Bw4ciLNnz+LRo0f44osvvOqrVKlS+Pnnn3W1LVeuHH744QevzicQlDBsZCCUgUC2R/HixaUyt1/zz3PnzuGtt95S1P3yyy8u+xw5ciTGjRvnsUwlS5bUdR5PYLQJC/L/PmH+z24QykBAwEhERUUBAHbu3GmyJM6h5d/wFj179pQWtJ82dZq+g4TSCBQIn4GAgFGIi4tDvnz58Morr5gtiksYrQgA4LnnnsPf//53TJs2zfqQd0WBnAmzo49FBLJgILB9+/ZeHZ+amkrt2rVz6xh5xLQr6olu1hPNbAhJuX3y5EnTf78gpn9XOhMQyIkoU6aMVI6JiZHKv//+u9t9lSpVCmvXrsX8+fORlpaG3r17O23/1ltv6XYuv/322zh9+rTTNhUqVMCJEyd0yyvHoEGDMH36dI+OBUSiupyCMLMFEBAwCzt37kS+fPlw5MgRXLx4EefOnUPevHlx9OhRt/viNvcePXpIdaNHj8bo0aM12589e1ZXv3qnpepxKDtaVIeIMGjQIF3yCORgmG0iEmYiQbN57tw5r45PSUnR1a548eISjZTfnQypdvT0T9bHjRs3TP8Ng5jCTCQQfChcuDAKFy7sdT8lSpSQygcPHkS+fPm86k8ruKt8+fIoXry4Ytrr+fPnJRoFvfEJDvMc6XEYu3AiZwfHu4AOmD0qECMDQU9ZuHBhqVykSBHFvjfeeIOKFi1KRYsW9akMiYmJUrlYsWKG9BnoazWoefPmTdNlCGIaNjIQPgOBbAu5o/fSpUuKfRcvXvS6/2LFiim2bS8uCsid0BcuXPD6nIDSn6AVV5CT8xwJmAehDAQEHEDPwz1//vyGnOvNN9/EuXPn7Oq14grcUQS+CFJTQ++sKIEAh9kmImEmEsyuNGIFshIlSmjWe7JqmppeOZbd4KVLl0z/LYKYwoEsIGA2WrVq5fGx3LH866+/au53FVegB/5KfPf888/75TwCvoVQBgICHsITExH3Q7iaUdSkSRNPRDIFX331ldkiCBgAXcqAMfYyY2wTY+wnxtiPjLF3GGOvMsb2MMZ+sX2+YmvLGGOzGWNnGWMnGGP+SaouIOBHpKSkoG/fvm4fp8cPUapUKTz//POIi4tDbGysB9L5F7Vq1TJbBAEDoNeB/BGAz4kojjGWG8BzAIYD+IqIJjPG0gCkARgKIAJASRtrAVhg+xQQyDH461//6rO+f/75Z786ZQ1PiS2QPaHDufsigPOwpbuW1Z8BUMhWLgTgjK28CEBbrXbCgSyYU/jGG2/4rO9GjRp5dFzp0qXdPqZ8+fJey3v79m3Tf48gpl8dyCUAXAewjDF2lDH2MWMsP4C/EdFVALB9FrS1LwxAPun7sq1OQEA3ChQoYLYIDpGYmCjFMbz++uuG9l2iRAlcvXrVo2PPnDnj9jGnTp3y/DFkw/379z2SVyCwoMdMFAagKoDeRHSQMfYRrCYhR9DKeE52jRhLBZBq23wCwLeTobMP/gLgf2YLYTZu3LgBBOi1WL58uVS+fPmyoX07ml0EX14LL9coMFoh6kBA3hcmobRRHelRBpcBXCaig7btTbAqg2uMsUJEdJUxVgjAn7L2RWTHvw7AbsVsIloMYDEAMMa+I4NW68nuENfiGcS1eAZxLZ5BXItnYIx9Z1RfLs1ERPQHgEuMMa6BGgA4DWA7gARbXQKAbbbydgCdbLOKagO4w81JAgICAgKBCb2ziXoDWGObSfQrgCRYFckGxlgKgIsAeATOLgCRAM4CeGhrKyAgICAQwNClDIjoGACtYVkDjbYEoKebcix2s31OhrgWzyCuxTOIa/EM4lo8g2HXgmllYhQQEBAQCC6IdBQCAgICAuYrA8ZYE8bYGVv6CmdTVnMEGGNFGGNf29J6/MAY62urD8r0HoyxUFv8yg7bdnHG2EHbdVhv81OBMZbHtn3Wtr+YmXIbDZHy5RkYY/1t/xunGGPrGGN5g+m+YIwtZYz9yRg7Jatz+15gjCXY2v/CGEvQOpccpioDxlgogHmwprAoC6AtY6ysmTL5AZkABhLR2wBqA+hp+85psKb3KAngKzyL5ZCn90iFNb1HTkJfAPIE/VMAzLRdh1sAUmz1KQBuEdFbAGba2uUk8JQvZQBUgvWaBN09wRgrDKAPgOpEVB5AKIA2CK77YjkAdaZCt+4FxtirAEbBmgqoJoBRXIE4hMnrGLwDYLdsexiAYWbKZMI12AagEQxM75FdCGsMylcA3gewA9bwp/8BCFPfHwB2A3jHVg6ztWNmyO2D6+DzlC/ZhXiWweBV2++8A0B4sN0XAIoBOOXpvQCgLYBFsnpFOy2abSYK6tQVtiFtFQAHEZzpPWYBGALAYtsuAOA2EWXatuXfVboOtv13bO1zAkTKFxuI6HcA02Gdrn4V1t/5ewTnfSGHu/eC2/eI2cpAV+qKnAjG2PMANgPoR0R3nTXVqMv214gxFg3gTyL6Xl6t0ZR07Mvu4ClfFhBRFQAPYEDKl+wImymjGYDiAF4DkB9WU4gawXBf6IGj7+/2dTFbGehKXZHTwBjLBasiWENEW2zV12xpPeBJeo9siH8AiGGMXQDwKaymolkAXmaM8fgX+XeVroNt/0sAbvpTYB9CK+VLVQTfPQEADQGcJ6LrRJQBYAuAOgjO+0IOd+8Ft+8Rs5XBYQAlbTMFcsPqKNpuskw+BWOMAfgEwI9ENEO2K6jSexDRMCJ6nYiKwfq7/5uI2gP4GkCcrZn6OvDrE2drnyPeAEmkfJHjIoDajLHnbP8r/FoE3X2hgrv3wm4AjRljr9hGW41tdY4RAI6SSAA/AzgHIN1sefzwfd+Fdbh2AsAxGyNhtXN+BeAX2+ertvYM1hlX5wCchHWWhenfw+BrUh/ADlu5BIBDsKYz2Qggj60+r237rG1/CbPlNvgaVAbwne2+2ArglWC9JwCMAfATrJmMVwHIE0z3BYB1sPpLMmB9w0/x5F4AkGy7LmcBJLk6r4hAFhAQEBAw3UwkICAgIBAAEMpAQEBAQEAoAwEBAQEBoQwEBAQEBCCUgYCAgIAAhDIQEBAQEIBQBgICAgICEMpAQEBAQADA/wdG3uMHNjPJYQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow(bev_map)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/second/utils/config_tool.py b/second/utils/config_tool.py new file mode 100644 index 00000000..962e66b1 --- /dev/null +++ b/second/utils/config_tool.py @@ -0,0 +1,59 @@ +# This file contains some config modification function. +# some functions should be only used for KITTI dataset. + +from google.protobuf import text_format +from second.protos import pipeline_pb2, second_pb2 +from pathlib import Path +import numpy as np + + +def change_detection_range(model_config, new_range): + assert len(new_range) == 4, "you must provide a list such as [-50, -50, 50, 50]" + old_pc_range = list(model_config.voxel_generator.point_cloud_range) + old_pc_range[:2] = new_range[:2] + old_pc_range[3:5] = new_range[2:] + model_config.voxel_generator.point_cloud_range[:] = old_pc_range + for anchor_generator in model_config.target_assigner.anchor_generators: + a_type = anchor_generator.WhichOneof('anchor_generator') + if a_type == "anchor_generator_range": + a_cfg = anchor_generator.anchor_generator_range + old_a_range = list(a_cfg.anchor_ranges) + old_a_range[:2] = new_range[:2] + old_a_range[3:5] = new_range[2:] + a_cfg.anchor_ranges[:] = old_a_range + elif a_type == "anchor_generator_stride": + a_cfg = anchor_generator.anchor_generator_stride + old_offset = list(a_cfg.offsets) + stride = list(a_cfg.strides) + old_offset[0] = new_range[0] + stride[0] / 2 + old_offset[1] = new_range[1] + stride[1] / 2 + a_cfg.offsets[:] = old_offset + else: + raise ValueError("unknown") + old_post_range = list(model_config.post_center_limit_range) + old_post_range[:2] = new_range[:2] + old_post_range[3:5] = new_range[2:] + model_config.post_center_limit_range[:] = old_post_range + +def get_downsample_factor(model_config): + downsample_factor = np.prod(model_config.rpn.layer_strides) + if len(model_config.rpn.upsample_strides) > 0: + downsample_factor /= model_config.rpn.upsample_strides[-1] + downsample_factor *= model_config.middle_feature_extractor.downsample_factor + downsample_factor = int(downsample_factor) + assert downsample_factor > 0 + return downsample_factor + + +if __name__ == "__main__": + config_path = "/home/yy/deeplearning/deeplearning/mypackages/second/configs/car.lite.1.config" + config = pipeline_pb2.TrainEvalPipelineConfig() + + with open(config_path, "r") as f: + proto_str = f.read() + text_format.Merge(proto_str, config) + + change_detection_range(config, [-50, -50, 50, 50]) + proto_str = text_format.MessageToString(config, indent=2) + print(proto_str) + diff --git a/second/utils/eval.py b/second/utils/eval.py index bdf6265a..5867c6d0 100644 --- a/second/utils/eval.py +++ b/second/utils/eval.py @@ -37,7 +37,10 @@ def get_thresholds(scores: np.ndarray, num_gt, num_sample_pts=41): def clean_data(gt_anno, dt_anno, current_class, difficulty): - CLASS_NAMES = ['car', 'pedestrian', 'cyclist', 'van', 'person_sitting', 'car', 'tractor', 'trailer'] + CLASS_NAMES = [ + 'car', 'pedestrian', 'cyclist', 'van', 'person_sitting', 'car', + 'tractor', 'trailer' + ] MIN_HEIGHT = [40, 25, 25] MAX_OCCLUSION = [0, 1, 2] MAX_TRUNCATION = [0.15, 0.3, 0.5] @@ -101,11 +104,11 @@ def image_box_overlap(boxes, query_boxes, criterion=-1): qbox_area = ((query_boxes[k, 2] - query_boxes[k, 0]) * (query_boxes[k, 3] - query_boxes[k, 1])) for n in range(N): - iw = (min(boxes[n, 2], query_boxes[k, 2]) - - max(boxes[n, 0], query_boxes[k, 0])) + iw = (min(boxes[n, 2], query_boxes[k, 2]) - max( + boxes[n, 0], query_boxes[k, 0])) if iw > 0: - ih = (min(boxes[n, 3], query_boxes[k, 3]) - - max(boxes[n, 1], query_boxes[k, 1])) + ih = (min(boxes[n, 3], query_boxes[k, 3]) - max( + boxes[n, 1], query_boxes[k, 1])) if ih > 0: if criterion == -1: ua = ( @@ -128,15 +131,27 @@ def bev_box_overlap(boxes, qboxes, criterion=-1): @numba.jit(nopython=True, parallel=True) -def d3_box_overlap_kernel(boxes, qboxes, rinc, criterion=-1): - # ONLY support overlap in CAMERA, not lider. +def d3_box_overlap_kernel(boxes, + qboxes, + rinc, + criterion=-1, + z_axis=1, + z_center=1.0): + """ + z_axis: the z (height) axis. + z_center: unified z (height) center of box. + """ N, K = boxes.shape[0], qboxes.shape[0] for i in range(N): for j in range(K): if rinc[i, j] > 0: - iw = (min(boxes[i, 1], qboxes[j, 1]) - max( - boxes[i, 1] - boxes[i, 4], qboxes[j, 1] - qboxes[j, 4])) - + min_z = min( + boxes[i, z_axis] + boxes[i, z_axis + 3] * (1 - z_center), + qboxes[j, z_axis] + qboxes[j, z_axis + 3] * (1 - z_center)) + max_z = max( + boxes[i, z_axis] - boxes[i, z_axis + 3] * z_center, + qboxes[j, z_axis] - qboxes[j, z_axis + 3] * z_center) + iw = min_z - max_z if iw > 0: area1 = boxes[i, 3] * boxes[i, 4] * boxes[i, 5] area2 = qboxes[j, 3] * qboxes[j, 4] * qboxes[j, 5] @@ -154,10 +169,14 @@ def d3_box_overlap_kernel(boxes, qboxes, rinc, criterion=-1): rinc[i, j] = 0.0 -def d3_box_overlap(boxes, qboxes, criterion=-1): - rinc = rotate_iou_gpu_eval(boxes[:, [0, 2, 3, 5, 6]], - qboxes[:, [0, 2, 3, 5, 6]], 2) - d3_box_overlap_kernel(boxes, qboxes, rinc, criterion) +def d3_box_overlap(boxes, qboxes, criterion=-1, z_axis=1, z_center=1.0): + """kitti camera format z_axis=1. + """ + bev_axes = list(range(7)) + bev_axes.pop(z_axis + 3) + bev_axes.pop(z_axis) + rinc = rotate_iou_gpu_eval(boxes[:, bev_axes], qboxes[:, bev_axes], 2) + d3_box_overlap_kernel(boxes, qboxes, rinc, criterion, z_axis, z_center) return rinc @@ -312,8 +331,8 @@ def fused_compute_statistics(overlaps, dc_num = 0 for i in range(gt_nums.shape[0]): for t, thresh in enumerate(thresholds): - overlap = overlaps[dt_num:dt_num + dt_nums[i], gt_num: - gt_num + gt_nums[i]] + overlap = overlaps[dt_num:dt_num + dt_nums[i], gt_num:gt_num + + gt_nums[i]] gt_data = gt_datas[gt_num:gt_num + gt_nums[i]] dt_data = dt_datas[dt_num:dt_num + dt_nums[i]] @@ -342,14 +361,20 @@ def fused_compute_statistics(overlaps, dc_num += dc_nums[i] -def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50): +def calculate_iou_partly(gt_annos, + dt_annos, + metric, + num_parts=50, + z_axis=1, + z_center=1.0): """fast iou algorithm. this function can be used independently to - do result analysis. Must be used in CAMERA coordinate system. + do result analysis. Args: gt_annos: dict, must from get_label_annos() in kitti_common.py dt_annos: dict, must from get_label_annos() in kitti_common.py metric: eval type. 0: bbox, 1: bev, 2: 3d num_parts: int. a parameter for fast calculate algorithm + z_axis: height axis. kitti camera use 1, lidar use 2. """ assert len(gt_annos) == len(dt_annos) total_dt_num = np.stack([len(a["name"]) for a in dt_annos], 0) @@ -358,7 +383,8 @@ def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50): split_parts = get_split_parts(num_examples, num_parts) parted_overlaps = [] example_idx = 0 - + bev_axes = list(range(3)) + bev_axes.pop(z_axis) for num_part in split_parts: gt_annos_part = gt_annos[example_idx:example_idx + num_part] dt_annos_part = dt_annos[example_idx:example_idx + num_part] @@ -368,34 +394,35 @@ def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50): overlap_part = image_box_overlap(gt_boxes, dt_boxes) elif metric == 1: loc = np.concatenate( - [a["location"][:, [0, 2]] for a in gt_annos_part], 0) + [a["location"][:, bev_axes] for a in gt_annos_part], 0) dims = np.concatenate( - [a["dimensions"][:, [0, 2]] for a in gt_annos_part], 0) + [a["dimensions"][:, bev_axes] for a in gt_annos_part], 0) rots = np.concatenate([a["rotation_y"] for a in gt_annos_part], 0) - gt_boxes = np.concatenate( - [loc, dims, rots[..., np.newaxis]], axis=1) + gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], + axis=1) loc = np.concatenate( - [a["location"][:, [0, 2]] for a in dt_annos_part], 0) + [a["location"][:, bev_axes] for a in dt_annos_part], 0) dims = np.concatenate( - [a["dimensions"][:, [0, 2]] for a in dt_annos_part], 0) + [a["dimensions"][:, bev_axes] for a in dt_annos_part], 0) rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0) - dt_boxes = np.concatenate( - [loc, dims, rots[..., np.newaxis]], axis=1) - overlap_part = bev_box_overlap(gt_boxes, dt_boxes).astype( - np.float64) + dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], + axis=1) + overlap_part = bev_box_overlap(gt_boxes, + dt_boxes).astype(np.float64) elif metric == 2: loc = np.concatenate([a["location"] for a in gt_annos_part], 0) dims = np.concatenate([a["dimensions"] for a in gt_annos_part], 0) rots = np.concatenate([a["rotation_y"] for a in gt_annos_part], 0) - gt_boxes = np.concatenate( - [loc, dims, rots[..., np.newaxis]], axis=1) + gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], + axis=1) loc = np.concatenate([a["location"] for a in dt_annos_part], 0) dims = np.concatenate([a["dimensions"] for a in dt_annos_part], 0) rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0) - dt_boxes = np.concatenate( - [loc, dims, rots[..., np.newaxis]], axis=1) - overlap_part = d3_box_overlap(gt_boxes, dt_boxes).astype( - np.float64) + dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], + axis=1) + overlap_part = d3_box_overlap( + gt_boxes, dt_boxes, z_axis=z_axis, + z_center=z_center).astype(np.float64) else: raise ValueError("unknown metric") parted_overlaps.append(overlap_part) @@ -410,8 +437,9 @@ def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50): gt_box_num = total_gt_num[example_idx + i] dt_box_num = total_dt_num[example_idx + i] overlaps.append( - parted_overlaps[j][gt_num_idx:gt_num_idx + gt_box_num, - dt_num_idx:dt_num_idx + dt_box_num]) + parted_overlaps[j][gt_num_idx:gt_num_idx + + gt_box_num, dt_num_idx:dt_num_idx + + dt_box_num]) gt_num_idx += gt_box_num dt_num_idx += dt_box_num example_idx += num_part @@ -450,102 +478,6 @@ def _prepare_data(gt_annos, dt_annos, current_class, difficulty): total_dc_num, total_num_valid_gt) -def eval_class(gt_annos, - dt_annos, - current_class, - difficulty, - metric, - min_overlap, - compute_aos=False, - num_parts=50): - """Kitti eval. Only support 2d/bev/3d/aos eval for now. - Args: - gt_annos: dict, must from get_label_annos() in kitti_common.py - dt_annos: dict, must from get_label_annos() in kitti_common.py - current_class: int, 0: car, 1: pedestrian, 2: cyclist - difficulty: int. eval difficulty, 0: easy, 1: normal, 2: hard - metric: eval type. 0: bbox, 1: bev, 2: 3d - min_overlap: float, min overlap. official: - [[0.7, 0.5, 0.5], [0.7, 0.5, 0.5], [0.7, 0.5, 0.5]] - format: [metric, class]. choose one from matrix above. - num_parts: int. a parameter for fast calculate algorithm - - Returns: - dict of recall, precision and aos - """ - assert len(gt_annos) == len(dt_annos) - num_examples = len(gt_annos) - split_parts = get_split_parts(num_examples, num_parts) - thresholdss = [] - rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts) - overlaps, parted_overlaps, total_dt_num, total_gt_num = rets - rets = _prepare_data(gt_annos, dt_annos, current_class, difficulty) - (gt_datas_list, dt_datas_list, ignored_gts, ignored_dets, dontcares, - total_dc_num, total_num_valid_gt) = rets - - for i in range(len(gt_annos)): - rets = compute_statistics_jit( - overlaps[i], - gt_datas_list[i], - dt_datas_list[i], - ignored_gts[i], - ignored_dets[i], - dontcares[i], - metric, - min_overlap=min_overlap, - thresh=0.0, - compute_fp=False) - tp, fp, fn, similarity, thresholds = rets - thresholdss += thresholds.tolist() - thresholdss = np.array(thresholdss) - thresholds = get_thresholds(thresholdss, total_num_valid_gt) - thresholds = np.array(thresholds) - pr = np.zeros([len(thresholds), 4]) - idx = 0 - for j, num_part in enumerate(split_parts): - gt_datas_part = np.concatenate(gt_datas_list[idx:idx + num_part], 0) - dt_datas_part = np.concatenate(dt_datas_list[idx:idx + num_part], 0) - dc_datas_part = np.concatenate(dontcares[idx:idx + num_part], 0) - ignored_dets_part = np.concatenate(ignored_dets[idx:idx + num_part], 0) - ignored_gts_part = np.concatenate(ignored_gts[idx:idx + num_part], 0) - fused_compute_statistics( - parted_overlaps[j], - pr, - total_gt_num[idx:idx + num_part], - total_dt_num[idx:idx + num_part], - total_dc_num[idx:idx + num_part], - gt_datas_part, - dt_datas_part, - dc_datas_part, - ignored_gts_part, - ignored_dets_part, - metric, - min_overlap=min_overlap, - thresholds=thresholds, - compute_aos=compute_aos) - idx += num_part - N_SAMPLE_PTS = 41 - precision = np.zeros([N_SAMPLE_PTS]) - recall = np.zeros([N_SAMPLE_PTS]) - aos = np.zeros([N_SAMPLE_PTS]) - for i in range(len(thresholds)): - recall[i] = pr[i, 0] / (pr[i, 0] + pr[i, 2]) - precision[i] = pr[i, 0] / (pr[i, 0] + pr[i, 1]) - if compute_aos: - aos[i] = pr[i, 3] / (pr[i, 0] + pr[i, 1]) - for i in range(len(thresholds)): - precision[i] = np.max(precision[i:]) - recall[i] = np.max(recall[i:]) - if compute_aos: - aos[i] = np.max(aos[i:]) - ret_dict = { - "recall": recall, - "precision": precision, - "orientation": aos, - } - return ret_dict - - def eval_class_v3(gt_annos, dt_annos, current_classes, @@ -553,6 +485,8 @@ def eval_class_v3(gt_annos, metric, min_overlaps, compute_aos=False, + z_axis=1, + z_center=1.0, num_parts=50): """Kitti eval. support 2d/bev/3d/aos eval. support 0.5:0.05:0.95 coco AP. Args: @@ -573,7 +507,13 @@ def eval_class_v3(gt_annos, num_examples = len(gt_annos) split_parts = get_split_parts(num_examples, num_parts) - rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts) + rets = calculate_iou_partly( + dt_annos, + gt_annos, + metric, + num_parts, + z_axis=z_axis, + z_center=z_center) overlaps, parted_overlaps, total_dt_num, total_gt_num = rets N_SAMPLE_PTS = 41 num_minoverlap = len(min_overlaps) @@ -656,30 +596,6 @@ def eval_class_v3(gt_annos, return ret_dict -def do_eval(gt_annos, dt_annos, current_class, min_overlaps, - compute_aos=False): - - mAP_bbox = [] - mAP_aos = [] - for i in range(3): # i=difficulty - ret = eval_class(gt_annos, dt_annos, current_class, i, 0, - min_overlaps[0], compute_aos) - mAP_bbox.append(get_mAP(ret["precision"])) - if compute_aos: - mAP_aos.append(get_mAP(ret["orientation"])) - mAP_bev = [] - for i in range(3): - ret = eval_class(gt_annos, dt_annos, current_class, i, 1, - min_overlaps[1]) - mAP_bev.append(get_mAP(ret["precision"])) - mAP_3d = [] - for i in range(3): - ret = eval_class(gt_annos, dt_annos, current_class, i, 2, - min_overlaps[2]) - mAP_3d.append(get_mAP(ret["precision"])) - return mAP_bbox, mAP_bev, mAP_3d, mAP_aos - - def get_mAP_v2(prec): sums = 0 for i in range(0, prec.shape[-1], 4): @@ -692,33 +608,68 @@ def do_eval_v2(gt_annos, current_classes, min_overlaps, compute_aos=False, - difficultys = [0, 1, 2]): + difficultys=(0, 1, 2), + z_axis=1, + z_center=1.0): # min_overlaps: [num_minoverlap, metric, num_class] - ret = eval_class_v3(gt_annos, dt_annos, current_classes, difficultys, 0, - min_overlaps, compute_aos) + ret = eval_class_v3( + gt_annos, + dt_annos, + current_classes, + difficultys, + 0, + min_overlaps, + compute_aos, + z_axis=z_axis, + z_center=z_center) # ret: [num_class, num_diff, num_minoverlap, num_sample_points] mAP_bbox = get_mAP_v2(ret["precision"]) mAP_aos = None if compute_aos: mAP_aos = get_mAP_v2(ret["orientation"]) - ret = eval_class_v3(gt_annos, dt_annos, current_classes, difficultys, 1, - min_overlaps) + ret = eval_class_v3( + gt_annos, + dt_annos, + current_classes, + difficultys, + 1, + min_overlaps, + z_axis=z_axis, + z_center=z_center) mAP_bev = get_mAP_v2(ret["precision"]) - ret = eval_class_v3(gt_annos, dt_annos, current_classes, difficultys, 2, - min_overlaps) + ret = eval_class_v3( + gt_annos, + dt_annos, + current_classes, + difficultys, + 2, + min_overlaps, + z_axis=z_axis, + z_center=z_center) mAP_3d = get_mAP_v2(ret["precision"]) return mAP_bbox, mAP_bev, mAP_3d, mAP_aos -def do_coco_style_eval(gt_annos, dt_annos, current_classes, overlap_ranges, - compute_aos): +def do_coco_style_eval(gt_annos, + dt_annos, + current_classes, + overlap_ranges, + compute_aos, + z_axis=1, + z_center=1.0): # overlap_ranges: [range, metric, num_class] min_overlaps = np.zeros([10, *overlap_ranges.shape[1:]]) for i in range(overlap_ranges.shape[1]): for j in range(overlap_ranges.shape[2]): min_overlaps[:, i, j] = np.linspace(*overlap_ranges[:, i, j]) mAP_bbox, mAP_bev, mAP_3d, mAP_aos = do_eval_v2( - gt_annos, dt_annos, current_classes, min_overlaps, compute_aos) + gt_annos, + dt_annos, + current_classes, + min_overlaps, + compute_aos, + z_axis=z_axis, + z_center=z_center) # ret: [num_class, num_diff, num_minoverlap] mAP_bbox = mAP_bbox.mean(-1) mAP_bev = mAP_bev.mean(-1) @@ -736,64 +687,21 @@ def print_str(value, *arg, sstream=None): print(value, *arg, file=sstream) return sstream.getvalue() - -def get_official_eval_result_v1(gt_annos, dt_annos, current_class): - mAP_0_7 = np.array([[0.7, 0.5, 0.5, 0.7, 0.5], [0.7, 0.5, 0.5, 0.7, 0.5], - [0.7, 0.5, 0.5, 0.7, 0.5]]) - mAP_0_5 = np.array([[0.7, 0.5, 0.5, 0.7, - 0.5], [0.5, 0.25, 0.25, 0.5, 0.25], - [0.5, 0.25, 0.25, 0.5, 0.25]]) - mAP_list = [mAP_0_7, mAP_0_5] - class_to_name = { - 0: 'Car', - 1: 'Pedestrian', - 2: 'Cyclist', - 3: 'Van', - 4: 'Person_sitting', - } - name_to_class = {v: n for n, v in class_to_name.items()} - if isinstance(current_class, str): - current_class = name_to_class[current_class] - result = '' - # check whether alpha is valid - compute_aos = False - for anno in dt_annos: - if anno['alpha'].shape[0] != 0: - if anno['alpha'][0] != -10: - compute_aos = True - break - for mAP in mAP_list: - # mAP threshold matrix: [num_minoverlap, metric, class] - mAPbbox, mAPbev, mAP3d, mAPaos = do_eval( - gt_annos, dt_annos, current_class, mAP[:, current_class], - compute_aos) - # mAP: [num_class, num_diff, num_minoverlap] - result += print_str( - (f"{class_to_name[current_class]} " - "AP@{:.2f}, {:.2f}, {:.2f}:".format(*mAP[:, current_class]))) - result += print_str((f"bbox AP:{mAPbbox[0]:.2f}, " - f"{mAPbbox[1]:.2f}, " - f"{mAPbbox[2]:.2f}")) - result += print_str((f"bev AP:{mAPbev[0]:.2f}, " - f"{mAPbev[1]:.2f}, " - f"{mAPbev[2]:.2f}")) - result += print_str((f"3d AP:{mAP3d[0]:.2f}, " - f"{mAP3d[1]:.2f}, " - f"{mAP3d[2]:.2f}")) - if compute_aos: - result += print_str((f"aos AP:{mAPaos[0]:.2f}, " - f"{mAPaos[1]:.2f}, " - f"{mAPaos[2]:.2f}")) - - return result - - -def get_official_eval_result(gt_annos, dt_annos, current_classes, difficultys=[0, 1, 2]): - overlap_0_7 = np.array([[0.7, 0.5, 0.5, 0.7, - 0.5, 0.7, 0.7, 0.7], [0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7], +def get_official_eval_result(gt_annos, + dt_annos, + current_classes, + difficultys=[0, 1, 2], + z_axis=1, + z_center=1.0): + """ + gt_annos and dt_annos must contains following keys: + [bbox, location, dimensions, rotation_y, score] + """ + overlap_0_7 = np.array([[0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7], + [0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7], [0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7]]) - overlap_0_5 = np.array([[0.7, 0.5, 0.5, 0.7, - 0.5, 0.5, 0.5, 0.5], [0.5, 0.25, 0.25, 0.5, 0.25, 0.5, 0.5, 0.5], + overlap_0_5 = np.array([[0.7, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.5], + [0.5, 0.25, 0.25, 0.5, 0.25, 0.5, 0.5, 0.5], [0.5, 0.25, 0.25, 0.5, 0.25, 0.5, 0.5, 0.5]]) min_overlaps = np.stack([overlap_0_7, overlap_0_5], axis=0) # [2, 3, 5] class_to_name = { @@ -826,7 +734,14 @@ def get_official_eval_result(gt_annos, dt_annos, current_classes, difficultys=[0 compute_aos = True break mAPbbox, mAPbev, mAP3d, mAPaos = do_eval_v2( - gt_annos, dt_annos, current_classes, min_overlaps, compute_aos, difficultys) + gt_annos, + dt_annos, + current_classes, + min_overlaps, + compute_aos, + difficultys, + z_axis=z_axis, + z_center=z_center) for j, curcls in enumerate(current_classes): # mAP threshold array: [num_minoverlap, metric, class] # mAP result: [num_class, num_diff, num_minoverlap] @@ -850,7 +765,12 @@ def get_official_eval_result(gt_annos, dt_annos, current_classes, difficultys=[0 return result -def get_coco_eval_result(gt_annos, dt_annos, current_classes): + +def get_coco_eval_result(gt_annos, + dt_annos, + current_classes, + z_axis=1, + z_center=1.0): class_to_name = { 0: 'Car', 1: 'Pedestrian', @@ -880,7 +800,6 @@ def get_coco_eval_result(gt_annos, dt_annos, current_classes): 5: [0.5, 0.95, 10], 6: [0.5, 0.95, 10], 7: [0.5, 0.95, 10], - } name_to_class = {v: n for n, v in class_to_name.items()} @@ -895,7 +814,8 @@ def get_coco_eval_result(gt_annos, dt_annos, current_classes): current_classes = current_classes_int overlap_ranges = np.zeros([3, 3, len(current_classes)]) for i, curcls in enumerate(current_classes): - overlap_ranges[:, :, i] = np.array(class_to_range[curcls])[:, np.newaxis] + overlap_ranges[:, :, i] = np.array( + class_to_range[curcls])[:, np.newaxis] result = '' # check whether alpha is valid compute_aos = False @@ -905,26 +825,32 @@ def get_coco_eval_result(gt_annos, dt_annos, current_classes): compute_aos = True break mAPbbox, mAPbev, mAP3d, mAPaos = do_coco_style_eval( - gt_annos, dt_annos, current_classes, overlap_ranges, compute_aos) + gt_annos, + dt_annos, + current_classes, + overlap_ranges, + compute_aos, + z_axis=z_axis, + z_center=z_center) for j, curcls in enumerate(current_classes): # mAP threshold array: [num_minoverlap, metric, class] # mAP result: [num_class, num_diff, num_minoverlap] o_range = np.array(class_to_range[curcls])[[0, 2, 1]] o_range[1] = (o_range[2] - o_range[0]) / (o_range[1] - 1) - result += print_str( - (f"{class_to_name[curcls]} " - "coco AP@{:.2f}:{:.2f}:{:.2f}:".format(*o_range))) + result += print_str((f"{class_to_name[curcls]} " + "coco AP@{:.2f}:{:.2f}:{:.2f}:".format(*o_range))) result += print_str((f"bbox AP:{mAPbbox[j, 0]:.2f}, " - f"{mAPbbox[j, 1]:.2f}, " - f"{mAPbbox[j, 2]:.2f}")) + f"{mAPbbox[j, 1]:.2f}, " + f"{mAPbbox[j, 2]:.2f}")) result += print_str((f"bev AP:{mAPbev[j, 0]:.2f}, " - f"{mAPbev[j, 1]:.2f}, " - f"{mAPbev[j, 2]:.2f}")) + f"{mAPbev[j, 1]:.2f}, " + f"{mAPbev[j, 2]:.2f}")) result += print_str((f"3d AP:{mAP3d[j, 0]:.2f}, " - f"{mAP3d[j, 1]:.2f}, " - f"{mAP3d[j, 2]:.2f}")) + f"{mAP3d[j, 1]:.2f}, " + f"{mAP3d[j, 2]:.2f}")) if compute_aos: result += print_str((f"aos AP:{mAPaos[j, 0]:.2f}, " - f"{mAPaos[j, 1]:.2f}, " - f"{mAPaos[j, 2]:.2f}")) + f"{mAPaos[j, 1]:.2f}, " + f"{mAPaos[j, 2]:.2f}")) return result + diff --git a/second/utils/log_tool.py b/second/utils/log_tool.py new file mode 100644 index 00000000..8352ae67 --- /dev/null +++ b/second/utils/log_tool.py @@ -0,0 +1,35 @@ +def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""): + for k, v in json_dict.items(): + if isinstance(v, dict): + _flat_nested_json_dict(v, flatted, sep, start + sep + k) + else: + flatted[start + sep + k] = v + + +def flat_nested_json_dict(json_dict, sep=".") -> dict: + """flat a nested json-like dict. this function make shadow copy. + """ + flatted = {} + for k, v in json_dict.items(): + if isinstance(v, dict): + _flat_nested_json_dict(v, flatted, sep, k) + else: + flatted[k] = v + return flatted + +def metric_to_str(metrics, sep='.'): + flatted_metrics = flat_nested_json_dict(metrics, sep) + metrics_str_list = [] + for k, v in flatted_metrics.items(): + if isinstance(v, float): + metrics_str_list.append(f"{k}={v:.3}") + elif isinstance(v, (list, tuple)): + if v and isinstance(v[0], float): + v_str = ', '.join([f"{e:.3}" for e in v]) + metrics_str_list.append(f"{k}=[{v_str}]") + else: + metrics_str_list.append(f"{k}={v}") + else: + metrics_str_list.append(f"{k}={v}") + return ', '.join(metrics_str_list) + diff --git a/second/utils/simplevis.py b/second/utils/simplevis.py new file mode 100644 index 00000000..9f2cef70 --- /dev/null +++ b/second/utils/simplevis.py @@ -0,0 +1,191 @@ +import cv2 +import numba +import numpy as np + +from second.core import box_np_ops + + +@numba.jit(nopython=True) +def _points_to_bevmap_reverse_kernel( + points, + voxel_size, + coors_range, + coor_to_voxelidx, + # coors_2d, + bev_map, + height_lowers, + # density_norm_num=16, + with_reflectivity=False, + max_voxels=40000): + # put all computations to one loop. + # we shouldn't create large array in main jit code, otherwise + # reduce performance + N = points.shape[0] + ndim = 3 + ndim_minus_1 = ndim - 1 + grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size + # np.round(grid_size) + # grid_size = np.round(grid_size).astype(np.int64)(np.int32) + grid_size = np.round(grid_size, 0, grid_size).astype(np.int32) + height_slice_size = voxel_size[-1] + coor = np.zeros(shape=(3, ), dtype=np.int32) # DHW + voxel_num = 0 + failed = False + for i in range(N): + failed = False + for j in range(ndim): + c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j]) + if c < 0 or c >= grid_size[j]: + failed = True + break + coor[ndim_minus_1 - j] = c + if failed: + continue + voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]] + if voxelidx == -1: + voxelidx = voxel_num + if voxel_num >= max_voxels: + break + voxel_num += 1 + coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx + # coors_2d[voxelidx] = coor[1:] + bev_map[-1, coor[1], coor[2]] += 1 + height_norm = bev_map[coor[0], coor[1], coor[2]] + incomimg_height_norm = ( + points[i, 2] - height_lowers[coor[0]]) / height_slice_size + if incomimg_height_norm > height_norm: + bev_map[coor[0], coor[1], coor[2]] = incomimg_height_norm + if with_reflectivity: + bev_map[-2, coor[1], coor[2]] = points[i, 3] + # return voxel_num + + +def points_to_bev(points, + voxel_size, + coors_range, + with_reflectivity=False, + density_norm_num=16, + max_voxels=40000): + """convert kitti points(N, 4) to a bev map. return [C, H, W] map. + this function based on algorithm in points_to_voxel. + takes 5ms in a reduced pointcloud with voxel_size=[0.1, 0.1, 0.8] + + Args: + points: [N, ndim] float tensor. points[:, :3] contain xyz points and + points[:, 3] contain reflectivity. + voxel_size: [3] list/tuple or array, float. xyz, indicate voxel size + coors_range: [6] list/tuple or array, float. indicate voxel range. + format: xyzxyz, minmax + with_reflectivity: bool. if True, will add a intensity map to bev map. + Returns: + bev_map: [num_height_maps + 1(2), H, W] float tensor. + `WARNING`: bev_map[-1] is num_points map, NOT density map, + because calculate density map need more time in cpu rather than gpu. + if with_reflectivity is True, bev_map[-2] is intensity map. + """ + if not isinstance(voxel_size, np.ndarray): + voxel_size = np.array(voxel_size, dtype=points.dtype) + if not isinstance(coors_range, np.ndarray): + coors_range = np.array(coors_range, dtype=points.dtype) + voxelmap_shape = (coors_range[3:] - coors_range[:3]) / voxel_size + voxelmap_shape = tuple(np.round(voxelmap_shape).astype(np.int32).tolist()) + voxelmap_shape = voxelmap_shape[::-1] # DHW format + coor_to_voxelidx = -np.ones(shape=voxelmap_shape, dtype=np.int32) + # coors_2d = np.zeros(shape=(max_voxels, 2), dtype=np.int32) + bev_map_shape = list(voxelmap_shape) + bev_map_shape[0] += 1 + height_lowers = np.linspace( + coors_range[2], coors_range[5], voxelmap_shape[0], endpoint=False) + if with_reflectivity: + bev_map_shape[0] += 1 + bev_map = np.zeros(shape=bev_map_shape, dtype=points.dtype) + _points_to_bevmap_reverse_kernel(points, voxel_size, coors_range, + coor_to_voxelidx, bev_map, height_lowers, + with_reflectivity, max_voxels) + # print(voxel_num) + return bev_map + + +def point_to_vis_bev(points, + voxel_size=None, + coors_range=None, + max_voxels=80000): + if voxel_size is None: + voxel_size = (0.1, 0.1, 0.1) + if coors_range is None: + coors_range = (-50, -50, -3, 50, 50, 1) + voxel_size[2] = coors_range[5] - coors_range[2] + bev_map = points_to_bev( + points, voxel_size, coors_range, max_voxels=max_voxels) + height_map = (bev_map[0] * 255).astype(np.uint8) + return cv2.cvtColor(height_map, cv2.COLOR_GRAY2RGB) + + +def cv2_draw_lines(img, lines, colors, thickness, line_type=cv2.LINE_8): + lines = lines.astype(np.int32) + for line, color in zip(lines, colors): + color = list(int(c) for c in color) + cv2.line(img, (line[0], line[1]), (line[2], line[3]), color, thickness) + return img + + +def cv2_draw_text(img, locs, labels, colors, thickness, line_type=cv2.LINE_8): + locs = locs.astype(np.int32) + font_line_type = cv2.LINE_8 + font = cv2.FONT_ITALIC + font = cv2.FONT_HERSHEY_DUPLEX + font = cv2.FONT_HERSHEY_PLAIN + font = cv2.FONT_HERSHEY_SIMPLEX + for loc, label, color in zip(locs, labels, colors): + color = list(int(c) for c in color) + cv2.putText(img, label, tuple(loc), font, 0.7, color, thickness, + font_line_type, False) + return img + + +def draw_box_in_bev(img, + coors_range, + boxes, + color, + thickness=1, + labels=None, + label_color=None): + """ + Args: + boxes: center format. + """ + coors_range = np.array(coors_range) + bev_corners = box_np_ops.center_to_corner_box2d( + boxes[:, [0, 1]], boxes[:, [3, 4]], boxes[:, 6]) + bev_corners -= coors_range[:2] + bev_corners *= np.array( + img.shape[:2])[::-1] / (coors_range[3:5] - coors_range[:2]) + standup = box_np_ops.corner_to_standup_nd(bev_corners) + text_center = standup[:, 2:] + text_center[:, 1] -= (standup[:, 3] - standup[:, 1]) / 2 + + bev_lines = np.concatenate( + [bev_corners[:, [0, 2, 3]], bev_corners[:, [1, 3, 0]]], axis=2) + bev_lines = bev_lines.reshape(-1, 4) + colors = np.tile(np.array(color).reshape(1, 3), [bev_lines.shape[0], 1]) + colors = colors.astype(np.int32) + img = cv2_draw_lines(img, bev_lines, colors, thickness) + if labels is not None: + if label_color is None: + label_color = colors + else: + label_color = np.tile( + np.array(label_color).reshape(1, 3), [bev_lines.shape[0], 1]) + label_color = label_color.astype(np.int32) + + img = cv2_draw_text(img, text_center, labels, label_color, + thickness * 2) + return img + +def kitti_vis(points, boxes, labels=None): + vis_voxel_size = [0.1, 0.1, 0.1] + vis_point_range = [0, -30, -3, 64, 30, 1] + bev_map = point_to_vis_bev(points, vis_voxel_size, vis_point_range) + bev_map = draw_box_in_bev(bev_map, vis_point_range, boxes, [0, 255, 0], 2, labels) + + return bev_map \ No newline at end of file