Skip to content

Commit

Permalink
minor fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
donnyyou committed Jun 17, 2019
1 parent 8ef3ebb commit ed1e922
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 12 deletions.
1 change: 0 additions & 1 deletion hypes/det/voc/ssd512_vgg16_voc_det.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
"log_format": "%(asctime)s %(levelname)-7s %(message)s",
"rewrite": true
},

"solver": {
"lr": {
"metric": "epoch",
Expand Down
2 changes: 0 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def str2bool(v):
dest='logging:log_to_file', help='Whether to write logging into files.')

# *********** Params for test or submission. **********
parser.add_argument('--test_img', default=None, type=str,
dest='test:test_img', help='The test path of image.')
parser.add_argument('--test_dir', default=None, type=str,
dest='test:test_dir', help='The test directory of images.')
parser.add_argument('--root_dir', default=None, type=str,
Expand Down
11 changes: 3 additions & 8 deletions methods/tools/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,11 @@ def test(runner):
runner.configer.get('network', 'checkpoints_name'),
runner.configer.get('test', 'out_dir'))

test_img = runner.configer.get('test', 'test_img')
test_dir = runner.configer.get('test', 'test_dir')
if test_img is None and test_dir is None:
Log.error('test_img & test_dir not exists.')
if test_dir is None:
Log.error('test_dir not given!!!')
exit(1)

if test_img is not None:
runner.test_img(test_img, out_dir)

if test_dir is not None:
runner.test(test_dir, out_dir)
runner.test(test_dir, out_dir)

Log.info('Testing end...')
2 changes: 1 addition & 1 deletion methods/tools/runner_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch.nn as nn
from torch.nn.parallel.scatter_gather import gather as torch_gather

from extensions.tools.parallel.data_parallel import DataParallelModel
from utils.tools.logger import Logger as Log


Expand All @@ -33,6 +32,7 @@ def _make_parallel(runner, net):
if len(runner.configer.get('gpu')) == 1 or len(range(torch.cuda.device_count())) == 1:
runner.configer.update(['network', 'gathered'], True)

from extensions.tools.parallel.data_parallel import DataParallelModel
return DataParallelModel(net, gather_=runner.configer.get('network', 'gathered'))

@staticmethod
Expand Down

0 comments on commit ed1e922

Please sign in to comment.