Skip to content

Commit

Permalink
test faster rcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Feb 18, 2024
1 parent ce5684f commit b98cc0d
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 40 deletions.
13 changes: 4 additions & 9 deletions mmdet2trt/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from argparse import ArgumentParser
from pathlib import Path
import logging

import torch

from .mmdet2trt import mmdet2trt
Expand All @@ -18,7 +19,8 @@ def _parse_args():
parser.add_argument('config', help='Path to a mmdet Config file')
parser.add_argument('checkpoint', help='Path to a mmdet Checkpoint file')
parser.add_argument(
'--output', default=None,
'--output',
default=None,
help='Path where tensorrt model will be saved')
parser.add_argument(
'--fp16', action='store_true', help='Enable fp16 inference')
Expand Down Expand Up @@ -81,13 +83,6 @@ def _parse_args():
choices=['VERBOSE', 'INFO', 'WARNING', 'ERROR'],
help='TensorRT logging level.',
)
parser.add_argument(
'--output-names',
nargs=4,
type=str,
default=['num_detections', 'boxes', 'scores', 'classes'],
help='Names for the output nodes of the created TRTModule',
)
args = parser.parse_args()
return args

Expand Down
20 changes: 0 additions & 20 deletions mmdet2trt/converters/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,6 @@ def convert_AnchorGeneratorDynamic(ctx):
'ag_' + str(id(module)), stride=stride)
else:
print('no base_anchors in {}'.format(ag.generator))
# scales = ag.scales.detach().cpu().numpy().astype(np.float32)
# ratios = ag.ratios.detach().cpu().numpy().astype(np.float32)
# scale_major = ag.scale_major
# ctr = ag.ctr
# if ctr is None:
# # center_x = -1
# # center_y = -1
# center_x = 0
# center_y = 0
# else:
# center_x, center_y = ag.ctr

# plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)),
# base_size=base_size,
# stride=stride,
# scales=scales,
# ratios=ratios,
# scale_major=scale_major,
# center_x=center_x,
# center_y=center_y)

custom_layer = ctx.network.add_plugin_v2(
inputs=[input_trt, base_anchors_trt], plugin=plugin)
Expand Down
4 changes: 2 additions & 2 deletions mmdet2trt/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RPNHeadWraper(nn.Module):
def __init__(self, module):
super(RPNHeadWraper, self).__init__()
self.module = module
self.anchor_generator = build_wrapper(self.module.anchor_generator)
self.prior_generator = build_wrapper(self.module.prior_generator)
self.bbox_coder = build_wrapper(self.module.bbox_coder)

self.test_cfg = module.test_cfg
Expand All @@ -29,7 +29,7 @@ def forward(self, feat, x):

cls_scores, bbox_preds = module(feat)

mlvl_anchors = self.anchor_generator(
mlvl_anchors = self.prior_generator(
cls_scores, device=cls_scores[0].device)

mlvl_scores = []
Expand Down
4 changes: 2 additions & 2 deletions mmdet2trt/models/dense_heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class YOLOV3HeadWraper(nn.Module):
def __init__(self, module):
super(YOLOV3HeadWraper, self).__init__()
self.module = module
self.anchor_generator = build_wrapper(self.module.anchor_generator)
self.prior_generator = build_wrapper(self.module.prior_generator)
self.bbox_coder = build_wrapper(self.module.bbox_coder)
self.featmap_strides = module.featmap_strides
self.num_attrib = module.num_attrib
Expand All @@ -34,7 +34,7 @@ def forward(self, feats, x):

pred_maps_list = module(feats)[0]

multi_lvl_anchors = self.anchor_generator(
multi_lvl_anchors = self.prior_generator(
pred_maps_list, device=pred_maps_list[0].device)

multi_lvl_bboxes = []
Expand Down
8 changes: 5 additions & 3 deletions mmdet2trt/structures/bbox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .bbox_overlaps import bbox_overlaps_batched
from .transforms import batched_distance2bbox, batched_bbox_cxcywh_to_xyxy
from .transforms import batched_bbox_cxcywh_to_xyxy, batched_distance2bbox

__all__ = ['batched_distance2bbox', 'batched_bbox_cxcywh_to_xyxy',
'bbox_overlaps_batched']
__all__ = [
'batched_distance2bbox', 'batched_bbox_cxcywh_to_xyxy',
'bbox_overlaps_batched'
]
3 changes: 1 addition & 2 deletions mmdet2trt/structures/bbox/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch

import mmdet2trt.ops.util_ops as mm2trt_util
import torch


def batched_distance2bbox(points, distance, max_shape=None):
Expand Down
3 changes: 1 addition & 2 deletions tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,12 @@ def inference_test(trt_model,

image_path = osp.join(test_folder, file_name)
image = mmcv.imread(image_path)
image = mmcv.imconvert(image, 'bgr', 'rgb')

result = inference_detector(wrap_model, image)

visualizer.add_datasample(
'result',
image,
mmcv.imconvert(image, 'bgr', 'rgb'),
data_sample=result,
draw_gt=False,
show=False,
Expand Down

0 comments on commit b98cc0d

Please sign in to comment.