This repository was archived by the owner on Nov 21, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
/
Copy pathtest_loader.py
121 lines (102 loc) · 4.07 KB
/
test_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import logging
import unittest
import unittest.mock as mock
from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python import muji
from caffe2.python import workspace
from detectron.core.config import assert_and_infer_cfg
from detectron.core.config import cfg
from detectron.roi_data.loader import RoIDataLoader
import detectron.utils.logging as logging_utils
def get_roidb_blobs(roidb):
    """Fake ``get_minibatch``: stack the ``'data'`` array of every roidb entry.

    Args:
        roidb: sequence of dicts, each holding an array under ``'data'``.

    Returns:
        ``(blobs, valid)`` where ``blobs`` is ``{'data': stacked}``. ``stacked``
        is the entries' arrays joined along a new leading axis by ``np.stack``.
        ``valid`` is always ``True``.
    """
    stacked = np.stack([item['data'] for item in roidb])
    return {'data': stacked}, True
def get_net(data_loader, name):
    """Build a ``'dag'`` net that dequeues ``data_loader``'s blobs on each GPU.

    For every GPU id in ``range(cfg.NUM_GPUS)``, and under the ``gpu_<id>``
    name scope and that GPU's device scope, this:

    - creates each of the loader's output blobs in the workspace, using the
      scoped name;
    - appends one ``DequeueBlobs`` op, which reads from the loader's blobs
      queue into those output names.

    Args:
        data_loader: loader exposing ``get_output_names()`` and
            ``_blobs_queue_name``.
        name: name passed to ``core.Net``.

    Returns:
        The constructed ``core.Net``. Its protobuf is logged at INFO level.
    """
    output_names = data_loader.get_output_names()
    queue_name = data_loader._blobs_queue_name
    net = core.Net(name)
    net.type = 'dag'
    for gpu_id in range(cfg.NUM_GPUS):
        scope = 'gpu_{}'.format(gpu_id)
        with core.NameScope(scope), core.DeviceScope(muji.OnGPU(gpu_id)):
            # Pre-create each output blob under this GPU's name scope.
            for output_name in output_names:
                workspace.CreateBlob(core.ScopedName(output_name))
            net.DequeueBlobs(queue_name, output_names)
    logging.getLogger(__name__).info("Protobuf:\n" + str(net.Proto()))
    return net
def get_roidb_sample_data(sample_data):
    """Return a fake roidb whose entries are all ``{'data': sample_data}``.

    The length comes from one ``np.random.randint(4, 10)`` draw, so it is
    between 4 and 9 inclusive and follows the global NumPy RNG state. Each
    entry is a distinct dict.
    """
    num_entries = np.random.randint(4, 10)
    return [{'data': sample_data} for _ in range(num_entries)]
def create_loader_and_network(sample_data, name):
    """Create a started ``RoIDataLoader`` on fake data plus its dequeue net.

    Args:
        sample_data: array stored under ``'data'`` in every fake roidb entry.
        name: name for the dequeue ``core.Net`` built by ``get_net``.

    Returns:
        ``(loader, net)``: the loader, already started with
        ``prefill=False``, and the net that dequeues from it. The caller
        must call ``loader.shutdown()``.
    """
    roidb = get_roidb_sample_data(sample_data)
    loader = RoIDataLoader(roidb)
    # Honour the caller's name. It was previously hard-coded to
    # 'dequeue_net_train', so the test net silently got the train name.
    net = get_net(loader, name)
    loader.register_sigint_handler()
    loader.start(prefill=False)
    return loader, net
def run_net(net):
    """Run ``net`` once, then fetch the ``data`` blob from the ``gpu_0`` scope.

    The blob name is built with ``core.ScopedName`` inside the ``gpu_0`` name
    scope, with a CUDA device 0 scope active.
    """
    workspace.RunNetOnce(net)
    device = core.DeviceOption(caffe2_pb2.CUDA, 0)
    with core.NameScope('gpu_{}'.format(0)), core.DeviceScope(device):
        fetched = workspace.FetchBlob(core.ScopedName('data'))
    return fetched
class TestRoIDataLoader(unittest.TestCase):
    """Tests for running several ``RoIDataLoader`` instances at once."""

    @mock.patch(
        'detectron.roi_data.loader.get_minibatch_blob_names',
        return_value=[u'data']
    )
    @mock.patch(
        'detectron.roi_data.loader.get_minibatch',
        side_effect=get_roidb_blobs
    )
    def test_two_parallel_loaders(self, _1, _2):
        """Two concurrently started loaders each feed their own net.

        Each loader's roidb holds a different array, and each fetched
        ``data`` blob must match the array given to its own loader.
        ``_1`` and ``_2`` are the ``get_minibatch`` and
        ``get_minibatch_blob_names`` mocks; stacked ``mock.patch``
        decorators pass their mocks bottom-up.
        """
        train_data = np.random.rand(2, 3, 3).astype(np.float32)
        train_loader, train_net = create_loader_and_network(
            train_data, 'dequeue_net_train')
        # try/finally makes sure the loaders are shut down even if a later
        # step or assertion fails. Shutdown happens inside this method, so
        # the mock.patch'es above are still active while the loaders stop.
        # Order matches the original: test loader first, then train loader.
        try:
            test_data = np.random.rand(2, 4, 4).astype(np.float32)
            test_loader, test_net = create_loader_and_network(
                test_data, 'dequeue_net_test')
            try:
                for _ in range(5):
                    data = run_net(train_net)
                    self.assertEqual(data[0].tolist(), train_data.tolist())
                    data = run_net(test_net)
                    self.assertEqual(data[0].tolist(), test_data.tolist())
            finally:
                test_loader.shutdown()
        finally:
            train_loader.shutdown()
if __name__ == '__main__':
    # Initialise Caffe2 first.
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])

    # DEBUG logging for this module; the loader module is kept at INFO.
    logger = logging_utils.setup_logging(__name__)
    logger.setLevel(logging.DEBUG)
    loader_logger = logging.getLogger('detectron.roi_data.loader')
    loader_logger.setLevel(logging.INFO)

    # Seed NumPy from the config and set the options this test relies on,
    # then let the config finalise itself.
    np.random.seed(cfg.RNG_SEED)
    cfg.NUM_GPUS = 2
    cfg.TRAIN.ASPECT_GROUPING = False
    assert_and_infer_cfg()

    unittest.main()