forked from ppwwyyxx/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_registry.py
45 lines (33 loc) · 1.21 KB
/
test_registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Copyright (c) Facebook, Inc. and its affiliates.
import unittest
import torch
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.utils.registry import _convert_target_to_string, locate
class A:
    """Outer holder class used by ``TestLocate.test_inside_class``.

    Exists only so that ``A.B`` is a *nested* class: its ``__name__`` is
    ``"B"`` but its ``__qualname__`` is ``"A.B"``, so a string conversion
    that used ``__name__`` alone could not be located again.
    """

    class B:
        """Empty nested class whose qualified name is ``A.B``."""

        pass


class TestLocate(unittest.TestCase):
    """Round-trip tests for ``_convert_target_to_string`` and ``locate``.

    Most tests convert an object to its string name and then check that
    ``locate`` resolves that string back to the identical object.
    """

    def _test_obj(self, obj):
        """Assert that ``obj`` survives a ``_convert_target_to_string`` -> ``locate`` round trip.

        Identity (``assertIs``) is required, not mere equality.
        """
        name = _convert_target_to_string(obj)
        newobj = locate(name)
        # Include the intermediate name so a failure shows where resolution went wrong.
        self.assertIs(
            obj, newobj, msg=f"locate({name!r}) returned {newobj!r}, expected {obj!r}"
        )

    def test_basic(self):
        """A top-level class imported from a detectron2 package round-trips."""
        self._test_obj(GeneralizedRCNN)

    def test_inside_class(self):
        """A nested class round-trips."""
        # requires using __qualname__ instead of __name__
        self._test_obj(A.B)

    def test_builtin(self):
        """Builtin function and builtin type round-trip."""
        for builtin in (len, dict):
            # subTest reports which builtin failed instead of stopping at the first one.
            with self.subTest(obj=builtin):
                self._test_obj(builtin)

    def test_pytorch_optim(self):
        """A torch optimizer class round-trips."""
        # pydoc.locate does not work for it
        self._test_obj(torch.optim.SGD)

    def test_failure(self):
        """An unresolvable name is expected to raise ImportError."""
        with self.assertRaises(ImportError):
            locate("asdf")

    def test_compress_target(self):
        """The converted name uses the shorter public import path."""
        from detectron2.data.transforms import RandomCrop

        name = _convert_target_to_string(RandomCrop)
        # name shouldn't contain 'augmentation_impl' (presumably the internal
        # module RandomCrop is defined in before being re-exported; TODO confirm)
        self.assertEqual(name, "detectron2.data.transforms.RandomCrop")
        self.assertIs(RandomCrop, locate(name))