import copy
import os
import pickle
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
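
# Map the parametrized storage names to the torch dtypes used as the
# packing container for the 4-bit data (`quant_storage`).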
storage = {
    "uint8": torch.uint8,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
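    # Round-trip a quantized Linear4bit through state_dict and
    # torch.save/torch.load, checking weights, quant_state, forward
    # outputs, and the on-disk compression ratio.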
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
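    # Wrap the fp16 weights in Params4bit; the actual 4-bit quantization
    # happens when the parameter is moved to the GPU below.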
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
    linear_q = linear_q.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
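    # After popping weight and bias, `sd` holds only the serialized
    # quantization statistics that `from_prequantized` expects.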
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # loading weights from state_dict:
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # MATCHING
    a, b = linear_q.weight, linear_q2.weight

    # Quantizing original layer with specified quant_storage type
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

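    # The original and the state_dict round-tripped weights must match exactly.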
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

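    # Compare the quantization metadata field by field; `state2` carries the
    # second-level statistics used when compress_statistics is enabled.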
    q0 = a.quant_state
    q1 = b.quant_state
    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
        for attr in ("code", "dtype", "blocksize", "absmax"):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert a.device == c.device
    assert a.dtype == c.dtype
    assert torch.equal(a, b)
    assert torch.equal(a, c)

    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
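    # linear_q3 is the torch.save/torch.load round-trip of linear_q; its
    # forward output is checked against the others below.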
    linear_q3 = torch_load_from_buffer(bytes_4bit)

    # Test moving to CPU and back to GPU
    linear_q2.to("cpu")
    linear_q2.to(device)
    d = linear_qs(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    d = linear_q3(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = (
            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        )
        assert size_ratio < target_compression, ratio_error_msg


def test_copy_param():
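    # A shallow copy must share both the quant_state object and the
    # underlying tensor storage with the original parameter.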
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

    shallow_copy_param = copy.copy(param)
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


def test_deepcopy_param():
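    # A deep copy must duplicate both the quant_state and the storage, and
    # must not mutate the original parameter's attributes in the process.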
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
    dict_keys_before = set(param.__dict__.keys())

    copy_param = copy.deepcopy(param)
    dict_keys_after = set(param.__dict__.keys())
    dict_keys_copy = set(copy_param.__dict__.keys())

    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()

    # there was a bug where deepcopy would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_copy


def test_params4bit_real_serialization():
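    # Pickle round-trip of a quantized Params4bit: data, flags, and the full
    # quant_state must survive serialization unchanged.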
    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
    dict_keys_before = set(original_param.__dict__.keys())

    original_param.cuda(0)  # move to CUDA to trigger quantization

    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)
    dict_keys_after = set(original_param.__dict__.keys())
    dict_keys_deserialized = set(deserialized_param.__dict__.keys())

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad == deserialized_param.requires_grad == False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state

    # there was a bug where serialization would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_deserialized