Skip to content

Commit

Permalink
[Refactor] Camera keys (open-mmlab#805)
Browse files Browse the repository at this point in the history
* replace all cam_intrinsic

* revert cam_intrinsic in test

* remove rect and fix lint
  • Loading branch information
filaPro authored Aug 6, 2021
1 parent d2fb7ab commit 9f0b01c
Show file tree
Hide file tree
Showing 13 changed files with 38 additions and 43 deletions.
4 changes: 2 additions & 2 deletions mmdet3d/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def show_proj_det_result_meshlab(data,
img_metas=data['img_metas'][0][0],
show=show)
elif box_mode == Box3DMode.CAM:
if 'cam_intrinsic' not in data['img_metas'][0][0]:
if 'cam2img' not in data['img_metas'][0][0]:
raise NotImplementedError(
'camera intrinsic matrix is not provided')

Expand All @@ -434,7 +434,7 @@ def show_proj_det_result_meshlab(data,
img,
None,
show_bboxes,
data['img_metas'][0][0]['cam_intrinsic'],
data['img_metas'][0][0]['cam2img'],
out_dir,
file_name,
box_mode='camera',
Expand Down
14 changes: 7 additions & 7 deletions mmdet3d/core/visualizer/image_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def draw_depth_bbox3d_on_img(bboxes3d,

def draw_camera_bbox3d_on_img(bboxes3d,
raw_img,
cam_intrinsic,
cam2img,
img_metas,
color=(0, 255, 0),
thickness=1):
Expand All @@ -172,7 +172,7 @@ def draw_camera_bbox3d_on_img(bboxes3d,
bboxes3d (:obj:`CameraInstance3DBoxes`, shape=[M, 7]):
3d bbox in camera coordinate system to visualize.
raw_img (numpy.array): The numpy array of image.
cam_intrinsic (dict): Camera intrinsic matrix,
cam2img (dict): Camera intrinsic matrix,
denoted as `K` in depth bbox coordinate system.
img_metas (dict): Useless here.
color (tuple[int]): The color to draw bboxes. Default: (0, 255, 0).
Expand All @@ -181,16 +181,16 @@ def draw_camera_bbox3d_on_img(bboxes3d,
from mmdet3d.core.bbox import points_cam2img

img = raw_img.copy()
cam_intrinsic = copy.deepcopy(cam_intrinsic)
cam2img = copy.deepcopy(cam2img)
corners_3d = bboxes3d.corners
num_bbox = corners_3d.shape[0]
points_3d = corners_3d.reshape(-1, 3)
if not isinstance(cam_intrinsic, torch.Tensor):
cam_intrinsic = torch.from_numpy(np.array(cam_intrinsic))
cam_intrinsic = cam_intrinsic.reshape(3, 3).float().cpu()
if not isinstance(cam2img, torch.Tensor):
cam2img = torch.from_numpy(np.array(cam2img))
cam2img = cam2img.reshape(3, 3).float().cpu()

# project to 2d to get image coords (uv)
uv_origin = points_cam2img(points_3d, cam_intrinsic)
uv_origin = points_cam2img(points_3d, cam2img)
uv_origin = (uv_origin - 1).round()
imgfov_pts_2d = uv_origin[..., :2].reshape(num_bbox, 8, 2).numpy()

Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/datasets/nuscenes_mono_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def show(self, results, out_dir, show=True, pipeline=None):
img,
gt_bboxes,
pred_bboxes,
img_metas['cam_intrinsic'],
img_metas['cam2img'],
out_dir,
file_name,
box_mode='camera',
Expand Down
23 changes: 10 additions & 13 deletions mmdet3d/datasets/pipelines/formating.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,17 @@ class Collect3D(object):
- 'pad_shape': image shape after padding
- 'lidar2img': transform from lidar to image
- 'depth2img': transform from depth to image
- 'cam2img': transform from camera to image
- 'pcd_horizontal_flip': a boolean indicating if point cloud is \
flipped horizontally
- 'pcd_vertical_flip': a boolean indicating if point cloud is \
flipped vertically
- 'box_mode_3d': 3D box mode
- 'box_type_3d': 3D box type
- 'img_norm_cfg': a dict of normalization information:
- mean: per channel mean subtraction
- std: per channel std divisor
- to_rgb: bool indicating if bgr was converted to rgb
- 'rect': rectification matrix
- 'Trv2c': transformation from velodyne to camera coordinate
- 'P2': transformation betweeen cameras
- 'pcd_trans': point cloud transformations
- 'sample_idx': sample index
- 'pcd_scale_factor': point cloud scale factor
Expand All @@ -125,22 +122,22 @@ class Collect3D(object):
keys (Sequence[str]): Keys of results to be collected in ``data``.
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ('filename', 'ori_shape', 'img_shape', 'lidar2img', \
'pad_shape', 'scale_factor', 'flip', 'pcd_horizontal_flip', \
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', \
'img_norm_cfg', 'rect', 'Trv2c', 'P2', 'pcd_trans', \
Default: ('filename', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'flip',
'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
'box_type_3d', 'img_norm_cfg', 'pcd_trans',
'sample_idx', 'pcd_scale_factor', 'pcd_rotation', 'pts_filename')
"""

def __init__(self,
keys,
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'pad_shape', 'scale_factor', 'flip',
'cam_intrinsic', 'pcd_horizontal_flip',
'depth2img', 'cam2img', 'pad_shape',
'scale_factor', 'flip', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'rect', 'Trv2c', 'P2', 'pcd_trans',
'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
'pts_filename', 'transformation_3d_flow')):
'img_norm_cfg', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation', 'pts_filename',
'transformation_3d_flow')):
self.keys = keys
self.meta_keys = meta_keys

Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __call__(self, results):
dict: The dict contains loaded image and meta information.
"""
super().__call__(results)
results['cam_intrinsic'] = results['img_info']['cam_intrinsic']
results['cam2img'] = results['img_info']['cam_intrinsic']
return results


Expand Down
5 changes: 2 additions & 3 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,10 @@ def random_flip_data_3d(self, input_dict, direction='horizontal'):
w - input_dict['centers2d'][..., 0]
# need to modify the horizontal position of camera center
# along u-axis in the image (flip like centers2d)
# ['cam_intrinsic'][0][2] = c_u
# ['cam2img'][0][2] = c_u
# see more details and examples at
# https://github.com/open-mmlab/mmdetection3d/pull/744
input_dict['cam_intrinsic'][0][2] = \
w - input_dict['cam_intrinsic'][0][2]
input_dict['cam2img'][0][2] = w - input_dict['cam2img'][0][2]

def __call__(self, input_dict):
"""Call function to flip points, values in the ``bbox3d_fields`` and \
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/models/dense_heads/fcos_mono3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def _get_bboxes_single(self,
Returns:
tuples[Tensor]: Predicted 3D boxes, scores, labels and attributes.
"""
view = np.array(input_meta['cam_intrinsic'])
view = np.array(input_meta['cam2img'])
scale_factor = input_meta['scale_factor']
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
Expand Down
7 changes: 3 additions & 4 deletions mmdet3d/models/detectors/single_stage_mono3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,10 @@ def show_results(self, data, result, out_dir):
if isinstance(data['img_metas'][0], DC):
img_filename = data['img_metas'][0]._data[0][batch_id][
'filename']
cam_intrinsic = data['img_metas'][0]._data[0][batch_id][
'cam_intrinsic']
cam2img = data['img_metas'][0]._data[0][batch_id]['cam2img']
elif mmcv.is_list_of(data['img_metas'][0], dict):
img_filename = data['img_metas'][0][batch_id]['filename']
cam_intrinsic = data['img_metas'][0][batch_id]['cam_intrinsic']
cam2img = data['img_metas'][0][batch_id]['cam2img']
else:
ValueError(
f"Unsupported data type {type(data['img_metas'][0])} "
Expand All @@ -211,7 +210,7 @@ def show_results(self, data, result, out_dir):
img,
None,
pred_bboxes,
cam_intrinsic,
cam2img,
out_dir,
file_name,
'camera',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def test_load_image_from_file_mono_3d():
img_info=dict(filename=filename, cam_intrinsic=cam_intrinsic.copy()))
results = load_image_from_file_mono_3d(input_dict)
assert results['img'].shape == (900, 1600, 3)
assert np.all(results['cam_intrinsic'] == cam_intrinsic)
assert np.all(results['cam2img'] == cam_intrinsic)

repr_str = repr(load_image_from_file_mono_3d)
expected_repr_str = 'LoadImageFromFileMono3D(to_float32=False, ' \
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models/test_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ def test_fcos3d():
attr_labels = [torch.randint(0, 9, [3], device='cuda')]
img_metas = [
dict(
cam_intrinsic=[[1260.8474446004698, 0.0, 807.968244525554],
[0.0, 1260.8474446004698, 495.3344268742088],
[0.0, 0.0, 1.0]],
cam2img=[[1260.8474446004698, 0.0, 807.968244525554],
[0.0, 1260.8474446004698, 495.3344268742088],
[0.0, 0.0, 1.0]],
scale_factor=np.array([1., 1., 1., 1.], dtype=np.float32),
box_type_3d=CameraInstance3DBoxes)
]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models/test_heads/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,9 +1088,9 @@ def test_fcos_mono3d_head():
attr_labels = [torch.randint(0, 9, [3], device='cuda') for i in range(2)]
img_metas = [
dict(
cam_intrinsic=[[1260.8474446004698, 0.0, 807.968244525554],
[0.0, 1260.8474446004698, 495.3344268742088],
[0.0, 0.0, 1.0]],
cam2img=[[1260.8474446004698, 0.0, 807.968244525554],
[0.0, 1260.8474446004698, 495.3344268742088],
[0.0, 0.0, 1.0]],
scale_factor=np.array([1., 1., 1., 1.], dtype=np.float32),
box_type_3d=CameraInstance3DBoxes) for i in range(2)
]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_runtime/test_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def test_show_result_meshlab():
torch.tensor(
[[6.4495, -3.9097, -1.7409, 1.5063, 3.1819, 1.4716, 1.8782]]))
img = np.random.randn(1, 3, 384, 1280)
cam_intrinsic = np.array([[100.0, 0.0, 50.0], [0.0, 100.0, 50.0],
[0.0, 0.0, 1.0]])
cam2img = np.array([[100.0, 0.0, 50.0], [0.0, 100.0, 50.0],
[0.0, 0.0, 1.0]])
img_meta = dict(
filename=filename,
pcd_horizontal_flip=False,
Expand All @@ -199,7 +199,7 @@ def test_show_result_meshlab():
box_type_3d=CameraInstance3DBoxes,
pcd_trans=np.array([0., 0., 0.]),
pcd_scale_factor=1.0,
cam_intrinsic=cam_intrinsic)
cam2img=cam2img)
data = dict(
points=[[torch.tensor(points)]], img_metas=[[img_meta]], img=[img])
result = [
Expand Down
2 changes: 1 addition & 1 deletion tools/misc/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def show_proj_bbox_img(idx,
img,
gt_bboxes,
None,
img_metas['cam_intrinsic'],
img_metas['cam2img'],
out_dir,
filename,
box_mode='camera',
Expand Down

0 comments on commit 9f0b01c

Please sign in to comment.