Skip to content

Commit

Permalink
fix seg training.
Browse files Browse the repository at this point in the history
  • Loading branch information
donnyyou committed Mar 9, 2019
1 parent a670c06 commit cfddf23
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def str2bool(v):
dest='network:checkpoints_name', help='The name of checkpoint model.')
parser.add_argument('--backbone', default=None, type=str,
dest='network:backbone', help='The base network of model.')
parser.add_argument('--bn_type', default=None, type=str,
dest='network:bn_type', help='The BN type of the network.')
parser.add_argument('--norm_type', default=None, type=str,
dest='network:norm_type', help='The BN type of the network.')
parser.add_argument('--multi_grid', default=None, nargs='+', type=int,
dest='network:multi_grid', help='The multi_grid for resnet backbone.')
parser.add_argument('--pretrained', type=str, default=None,
Expand Down Expand Up @@ -148,8 +148,8 @@ def str2bool(v):
if configer.get('gpu') is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu_id) for gpu_id in configer.get('gpu'))

if configer.get('network', 'bn_type') is None:
configer.update(['network', 'bn_type'], 'torchbn')
if configer.get('network', 'norm_type') is None:
configer.update(['network', 'norm_type'], 'batchnorm')

project_dir = os.path.dirname(os.path.realpath(__file__))
configer.add(['project_dir'], project_dir)
Expand Down
4 changes: 2 additions & 2 deletions methods/tools/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def init(runner):
runner.runner_state['max_performance'] = 0
runner.runner_state['min_val_loss'] = 0
if runner.configer.get('phase') == 'train':
assert len(runner.configer.get('gpu')) > 1 or runner.configer.get('network', 'bn_type') == 'torchbn'
assert len(runner.configer.get('gpu')) > 1 or runner.configer.get('network', 'norm_type') == 'batchnorm'

Log.info('BN Type is {}.'.format(runner.configer.get('network', 'bn_type')))
Log.info('BN Type is {}.'.format(runner.configer.get('network', 'norm_type')))

@staticmethod
def train(runner):
Expand Down
2 changes: 1 addition & 1 deletion utils/helpers/mask_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
try:
import pycocotools.mask as mask_util
except ImportError:
Log.error('pycocotools ImportError.')
print('pycocotools ImportError.')


class MaskHelper(object):
Expand Down

0 comments on commit cfddf23

Please sign in to comment.