forked from yxli2123/LoftQ
-
Notifications
You must be signed in to change notification settings - Fork 0
/
entropy_test_1d_dct_nn.py
78 lines (69 loc) · 2.75 KB
/
entropy_test_1d_dct_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from gact.utils import get_dct_matrix
from gact.utils import get_dct_matrix
from gact.memory_efficient_function import per_block_quantization
def naive_adjustment(x, input_shape, quantization_shape=64):
    """Regroup a block-ordered tensor back into ``input_shape``.

    The tensor is viewed as a grid of square
    ``quantization_shape x quantization_shape`` blocks, the "block column"
    axis is swapped with the "row inside block" axis, and the result is
    reshaped to ``input_shape``.

    Args:
        x: Tensor supporting ``reshape`` and ``transpose`` (e.g. a torch
            tensor) whose element count matches ``input_shape``.
        input_shape: Target shape; its last two entries are divided by
            ``quantization_shape`` to get the number of blocks per axis.
        quantization_shape: Side length of one square block.

    Returns:
        The rearranged tensor with shape ``input_shape``.
    """
    # presumably this undoes the block-major layout produced by the
    # per-block quantization step; verify against that helper.
    blocks_per_col = input_shape[-2] // quantization_shape
    blocks_per_row = input_shape[-1] // quantization_shape
    blocked = x.reshape(
        -1, blocks_per_col, blocks_per_row, quantization_shape, quantization_shape
    )
    # Swap axes 2 and 3 so the layout becomes
    # (batch, block_row, row_in_block, block_col, col_in_block);
    # the original author's example shape here was [32, 2, 64, 12, 64].
    interleaved = blocked.transpose(2, 3)
    return interleaved.reshape(input_shape)
def preprocess_quantization(x, input_shape, quant_shape):
    """Run ``per_block_quantization`` on ``x`` and return its two results.

    Args:
        x: Data to quantize; forwarded unchanged to the helper.
        input_shape: Shape argument forwarded to the helper.
        quant_shape: Block-size argument forwarded to the helper.

    Returns:
        A ``(quantized, quant_state)`` tuple, exactly as produced by
        ``per_block_quantization``.
    """
    quantized, state = per_block_quantization(x, input_shape, quant_shape)
    return quantized, state
def zigzag(matrix, size=8):
    """Return the top-left ``size x size`` entries of ``matrix`` in zigzag order.

    The traversal walks the anti-diagonals starting from ``matrix[0][0]``.
    Even-numbered diagonals are read from bottom-left to top-right and
    odd-numbered diagonals from top-right to bottom-left.

    Args:
        matrix: 2-D indexable (``matrix[r][c]``) with at least ``size`` rows
            and ``size`` columns.
        size: Side length of the square block to traverse. Defaults to 8,
            which reproduces the previously hard-coded 8x8 behaviour.

    Returns:
        A 1-D ``numpy`` array holding the ``size * size`` visited entries.

    Raises:
        IndexError: If ``matrix`` is smaller than ``size x size``.
    """
    ordered = []
    for diag in range(2 * size - 1):
        # Only offsets that keep both indices inside [0, size) are valid.
        first = max(0, diag - size + 1)
        last = min(diag, size - 1)
        for j in range(first, last + 1):
            if diag % 2 == 0:
                ordered.append(matrix[diag - j][j])
            else:
                ordered.append(matrix[j][diag - j])
    return np.array(ordered)
def shannon_entropy(vector):
    """Return the Shannon entropy, in bits, of the values in ``vector``.

    The empirical probability of each distinct value is its occurrence
    count divided by the total number of samples, and the entropy is
    ``-sum(p * log2(p))`` over those probabilities.

    Args:
        vector: Array-like of samples. Inputs of any dimensionality are
            treated as one flat collection of values.

    Returns:
        The entropy as a numpy floating-point scalar.
    """
    # Count how many times each distinct value occurs.
    _, counts = np.unique(vector, return_counts=True)
    # Build the probability distribution. Normalise by the total sample
    # count rather than len(vector): np.unique flattens its input, so
    # len() (the first-axis length) is wrong for multi-dimensional data.
    probabilities = counts / counts.sum()
    # Shannon entropy of that distribution.
    return -np.sum(probabilities * np.log2(probabilities))
if __name__ == '__main__':
    # Load the saved tensor, run it through block quantization and the
    # layout adjustment (both with block size 64), then move it to NumPy.
    weights = torch.load('/home/yujin-wa20/projects/LoftQ/output/mistral/base_model.model.model.layers.15.self_attn.k_proj.lora_A.default.pt')[0]
    weights, _ = preprocess_quantization(weights, weights.shape, 64)
    weights = naive_adjustment(weights, weights.shape, 64)
    weights = weights.cpu().detach().numpy()
    n_rows, n_cols = weights.shape

    # channel_table[b, c, k] holds the k-th rounded transform coefficient of
    # the b-th 64-row segment in column c.
    channel_table = np.zeros((n_rows // 64, n_cols, 64))
    # presumably a 64x64 DCT basis matrix; confirm against gact.utils.
    dct_basis = get_dct_matrix(64)

    # Transform every 64-element column segment independently (1-D DCT).
    for block_idx, row_start in enumerate(range(0, n_rows, 64)):
        for col in range(n_cols):
            segment = weights[row_start:row_start + 64, col:col + 1]
            coeffs = np.dot(dct_basis, segment)
            # Coefficients are only rounded to integers; no further
            # quantization and no zigzag reordering is applied here.
            channel_table[block_idx][col] = np.round(coeffs).flatten()

    # Shannon entropy of each coefficient index ("channel") taken over all
    # segments and columns.
    per_channel_entropy = []
    for i in range(64):
        entropy = shannon_entropy(channel_table[:, :, i].flatten())
        print(f'Channel {i}\'s entropy: {entropy}')
        per_channel_entropy.append(entropy)

    # Bar chart of per-channel entropy, written to entropy.png.
    plt.bar(np.arange(64), per_channel_entropy)
    plt.xlabel('Channel')
    plt.ylabel('Entropy')
    # NOTE(review): the title says "V layer" but the tensor loaded above is a
    # k_proj LoRA-A weight; confirm which label is intended.
    plt.title('Entropy (V layer)')
    plt.savefig('entropy.png')