-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtest_binarize.py
129 lines (97 loc) · 3.8 KB
/
test_binarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import copy
import unittest
import torch
import torch.nn as nn
from bnn import BConfig, prepare_binary_model
from bnn.layers import Conv2d, Linear
from bnn.ops import (
BasicInputBinarizer,
BasicScaleBinarizer,
XNORWeightBinarizer
)
class Flatten(nn.Module):
    """Collapse every dimension after the leading one into a single axis.

    Wrapping the reshape in a module lets it sit inside ``nn.Sequential``.
    ``Tensor.view`` is used, so the input must be view-compatible
    (e.g. contiguous).
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
class BinarizerTestCase(unittest.TestCase):
    """Tests for ``prepare_binary_model``.

    Covers layer conversion, per-layer config overrides and state-dict
    round-tripping.
    """

    def setUp(self) -> None:
        # Fresh fixtures for every test. Each test deep-copies a fixture before
        # converting it, so prepare_binary_model never mutates these originals.
        self.linear_layer = nn.Linear(10, 3)
        self.conv_layer = nn.Conv2d(3, 16, 1, 1)
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(16, 3),  # child name '8' inside the Sequential
        )
        self.input1 = torch.rand(1, 10)
        self.input2 = torch.rand(1, 3, 8, 8)
        self.random_bconfig = BConfig(
            activation_pre_process=BasicInputBinarizer,
            activation_post_process=BasicScaleBinarizer,
            weight_pre_process=XNORWeightBinarizer,
        )

    def weight_reset(self, m):
        """Re-initialise ``m`` in place if it is an ``nn.Conv2d`` or ``nn.Linear``.

        Intended for ``Module.apply``, so that a freshly prepared model does
        not start from the weights stored in a previously saved state dict.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            m.reset_parameters()

    def test_single_linear_layer(self):
        """A bare nn.Linear is converted to exactly bnn ``Linear``."""
        # deepcopy: a shallow copy of a Module shares its parameter/submodule dicts.
        model = copy.deepcopy(self.linear_layer)
        model = prepare_binary_model(model, bconfig=self.random_bconfig)
        self.assertIs(type(model), Linear)

    def test_single_conv2d_layer(self):
        """A bare nn.Conv2d is converted to exactly bnn ``Conv2d``."""
        model = copy.deepcopy(self.conv_layer)
        model = prepare_binary_model(model, bconfig=self.random_bconfig)
        self.assertIs(type(model), Conv2d)

    def test_many_layers(self):
        """Every Conv2d/Linear inside a multi-layer model is converted."""
        model = copy.deepcopy(self.net)
        model = prepare_binary_model(model, bconfig=self.random_bconfig)
        modules = list(model.modules())
        n_conv = sum(isinstance(m, Conv2d) for m in modules)
        n_linear = sum(isinstance(m, Linear) for m in modules)
        self.assertEqual(n_conv, 2)
        self.assertEqual(n_linear, 1)
        # The converted network must still run end to end: (1, 3, 8, 8) -> (1, 3).
        out = model(self.input2.clone())
        self.assertEqual(tuple(out.shape), (1, 3))

    def test_skip_binarization(self):
        """A per-layer override can keep one layer's pre/post ops as Identity."""
        model = copy.deepcopy(self.net)
        fp32_config = BConfig(
            activation_pre_process=nn.Identity,
            activation_post_process=nn.Identity,
            weight_pre_process=nn.Identity,
        )
        # '8' is the Sequential child name of the final nn.Linear.
        model = prepare_binary_model(
            model,
            bconfig=self.random_bconfig,
            custom_config_layers_name={'8': fp32_config},
        )
        cnt_conv, cnt_linear = 0, 0
        for module in model.modules():
            if isinstance(module, Conv2d):
                cnt_conv += 1
            elif isinstance(module, Linear):
                # Only count the Linear that received the Identity override.
                if isinstance(module.activation_pre_process, nn.Identity):
                    cnt_linear += 1
        self.assertEqual(cnt_conv, 2)
        self.assertEqual(cnt_linear, 1)

    def test_save_load_state_dict(self):
        """A state dict saved from one prepared model reproduces its output in another."""
        model = copy.deepcopy(self.net)
        x = self.input2.clone()
        model = prepare_binary_model(model, bconfig=self.random_bconfig)
        out1 = model(x)
        binary_state_dict = model.state_dict()

        # Build a second prepared model with re-initialised weights, then load.
        model = copy.deepcopy(self.net)
        model.apply(self.weight_reset)
        model = prepare_binary_model(model, bconfig=self.random_bconfig)
        model.load_state_dict(binary_state_dict)
        out2 = model(x)
        self.assertTrue(torch.equal(out1, out2))
class OpsTestCase(unittest.TestCase):
    """Unit tests for the individual binarizer ops imported from ``bnn.ops``."""

    def setUp(self) -> None:
        # Mixed-sign values with no exact zeros, so torch.sign gives only +1/-1 here.
        self.input = torch.tensor([0.3, 0.1, -2, -0.001, 0.01])
        self.conv_layer = nn.Conv2d(3, 16, 1, 1)

    def test_basic_input_binarizer(self):
        """BasicInputBinarizer must match torch.sign element-wise."""
        binarizer = BasicInputBinarizer()
        actual = binarizer(self.input.clone())
        expected = torch.sign(self.input.clone())
        self.assertTrue(torch.equal(actual, expected))

    def test_BasicScaleBinarizer(self):
        """Smoke test: BasicScaleBinarizer can be built around a Conv2d."""
        # NOTE(review): no assertion; this only checks that construction does not raise.
        layer = copy.copy(self.conv_layer)
        _ = BasicScaleBinarizer(layer)

    def test_XNORWeightBinarizer(self):
        """Smoke test: XNORWeightBinarizer can be constructed with no arguments."""
        # NOTE(review): no assertion; this only checks that construction does not raise.
        _ = XNORWeightBinarizer()
if __name__ == '__main__':
    # Run the tests when the file is executed directly as a script.
    unittest.main(module='__main__')