forked from facebookresearch/Detectron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_net.py
executable file
·128 lines (112 loc) · 3.84 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python2
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Train a network with Detectron."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import logging
import numpy as np
import pprint
import sys
from caffe2.python import workspace
from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list
from core.test_engine import run_inference
from utils.logging import setup_logging
import utils.c2
import utils.train
# Import-time setup: load additional Caffe2 operator libraries via the
# utils.c2 helpers (contrib ops, then Detectron's own ops). By their names
# these make ops available that the Detectron model code relies on —
# NOTE(review): confirm exact behavior in utils/c2.py.
utils.c2.import_contrib_ops()
utils.c2.import_detectron_ops()
# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
cv2.ocl.setUseOpenCL(False)
def parse_args(argv=None):
    """Parse the command-line arguments for Detectron training.

    Args:
        argv: list of argument strings to parse, excluding the program name.
            When None (the default), ``sys.argv[1:]`` is used, so existing
            callers behave exactly as before. Pass an explicit list to drive
            the parser programmatically (e.g. from tests or another launcher).

    Returns:
        argparse.Namespace with attributes:
            cfg_file (str or None): path given via ``--cfg``.
            multi_gpu_testing (bool): True if ``--multi-gpu-testing`` given.
            skip_test (bool): True if ``--skip-test`` given.
            opts (list of str): every remaining token after the options,
                intended as ``KEY VALUE`` config overrides.

    If no arguments at all are supplied, the usage message is printed and
    the process exits with status 1.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Train a network with Detectron'
    )
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='Config file for training (and optionally testing)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--multi-gpu-testing',
        dest='multi_gpu_testing',
        help='Use cfg.NUM_GPUS GPUs for inference',
        action='store_true'
    )
    parser.add_argument(
        '--skip-test',
        dest='skip_test',
        help='Do not test the final model',
        action='store_true'
    )
    # REMAINDER collects everything after the recognised options verbatim,
    # so config overrides such as `TRAIN.WEIGHTS w.pkl` pass through as-is.
    parser.add_argument(
        'opts',
        help='See lib/core/config.py for all options',
        default=None,
        nargs=argparse.REMAINDER
    )
    # Running with no arguments is almost certainly a mistake; show usage.
    if not argv:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(argv)
def main():
    """Initialize Caffe2, build the config, train, and optionally test.

    Arguments come from ``parse_args()``. The config is built by merging the
    ``--cfg`` file (if given) and then the trailing ``KEY VALUE`` list, and
    is finalized by ``assert_and_infer_cfg()``. Training runs through
    ``utils.train.train_model()``. Unless ``--skip-test`` was passed, the
    checkpoint stored under the ``'final'`` key of its result is then
    evaluated with ``test_model``.
    """
    # Caffe2 global flags: log level 0 and GPU memory tracking enabled.
    c2_init_argv = [
        'caffe2', '--caffe2_log_level=0', '--caffe2_gpu_memory_tracking=1'
    ]
    workspace.GlobalInit(c2_init_argv)

    # Logging first, so argument and config dumps below are captured.
    logger = setup_logging(__name__)
    logging.getLogger('roi_data.loader').setLevel(logging.INFO)

    args = parse_args()
    logger.info('Called with args:')
    logger.info(args)

    # Config file is merged first, then the command-line KEY VALUE list.
    if args.cfg_file is not None:
        merge_cfg_from_file(args.cfg_file)
    if args.opts is not None:
        merge_cfg_from_list(args.opts)
    assert_and_infer_cfg()
    logger.info('Training with config:')
    logger.info(pprint.pformat(cfg))

    # Seeding numpy does NOT make training deterministic in general: some
    # sources of non-determinism (such as certain non-deterministic cuDNN
    # functions) cannot be removed with a reasonable execution-speed tradeoff.
    np.random.seed(cfg.RNG_SEED)

    checkpoints = utils.train.train_model()

    if args.skip_test:
        return
    final_model = checkpoints['final']
    test_model(final_model, args.multi_gpu_testing, args.opts)
def test_model(model_file, multi_gpu_testing, opts=None):
    """Run inference on a model with expected-results checking enabled.

    Args:
        model_file: model to evaluate; forwarded as the first positional
            argument of ``run_inference``.
        multi_gpu_testing: forwarded as ``run_inference``'s
            ``multi_gpu_testing`` flag.
        opts: accepted for call-site compatibility but NOT used by this
            function (it is not forwarded anywhere).
    """
    # Clear the Caffe2 workspace (and the memory it holds) before inference.
    workspace.ResetWorkspace()
    inference_kwargs = dict(
        multi_gpu_testing=multi_gpu_testing,
        check_expected_results=True,
    )
    run_inference(model_file, **inference_kwargs)
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()