einsum_test.py
import torch
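
# Each test checks that a torch.einsum contraction matches an equivalent
# permute/reshape + torch.matmul formulation; the printed sum of absolute
# differences should be ~0 (up to float32 rounding). The recipe (my summary,
# not stated in the original file): permute shared "batch" labels to the
# leading dims, put the contracted label in the inner matmul position, and
# permute/reshape the result back afterwards.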


def test0():
    # einsum "bnmc,bkmc->bmnk": contract over c, with b and m as batch dims.
    q = torch.randn(1, 3, 4, 5)  # (b, n, m, c)
    k = torch.randn(1, 6, 4, 5)  # (b, k, m, c)
    aw = torch.einsum("bnmc,bkmc->bmnk", q, k)
    print(aw.shape)  # (1, 4, 3, 6)
    # Same contraction via matmul: (b, m, n, c) @ (b, m, c, k) -> (b, m, n, k).
    q_transposed = q.permute(0, 2, 3, 1)  # (b, m, c, n)
    k_transposed = k.permute(0, 2, 3, 1)  # (b, m, c, k)
    result = torch.matmul(q_transposed.transpose(-2, -1), k_transposed)
    print(result.shape)  # (1, 4, 3, 6)
    print(torch.sum(torch.abs(aw - result)))  # should be ~0


def test1():
    # einsum "bmnk,bkmc->bnmc": contract over k, with b and m as batch dims.
    aw = torch.randn(1, 4, 3, 6)  # (b, m, n, k)
    v = torch.randn(1, 6, 4, 5)  # (b, k, m, c)
    x = torch.einsum("bmnk,bkmc->bnmc", aw, v)
    print(x.shape)  # (1, 3, 4, 5)
    # Same via matmul: (b, m, n, k) @ (b, m, k, c) -> (b, m, n, c), then swap n and m.
    v_permuted = v.permute(0, 2, 1, 3)  # (b, m, k, c)
    x_matmul = torch.matmul(aw, v_permuted).permute(0, 2, 1, 3)  # (b, n, m, c)
    print(x_matmul.shape)  # (1, 3, 4, 5)
    print(torch.sum(torch.abs(x - x_matmul)))  # should be ~0


def test2():
    # einsum "bmchw,bnmc->bmhwn": contract over c, with b and m as batch dims;
    # h and w ride along on the first operand, so flatten them into one matmul dim.
    embed = torch.randn(1, 4, 5, 7, 8)  # (b, m, c, h, w)
    guide = torch.randn(1, 6, 4, 5)  # (b, n, m, c)
    aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide)
    print(aw.shape)  # (1, 4, 7, 8, 6)
    guide_permuted = guide.permute(0, 2, 3, 1)  # (b, m, c, n)
    embed_permuted = embed.permute(0, 1, 3, 4, 2)  # (b, m, h, w, c)
    shape = embed_permuted.shape
    embed_permuted = embed_permuted.reshape(shape[0], shape[1], -1, shape[-1])  # (b, m, h*w, c)
    # (b, m, h*w, c) @ (b, m, c, n) -> (b, m, h*w, n), then unflatten h*w.
    x = torch.matmul(embed_permuted, guide_permuted).reshape(*shape[:4], -1)
    print(x.shape)  # (1, 4, 7, 8, 6)
    print(torch.sum(torch.abs(aw - x)))  # should be ~0


def test3():
    # einsum "bchw,bkc->bkhw": a 1x1-convolution-like projection over channels.
    x = torch.randn(1, 3, 7, 8)  # (b, c, h, w)
    w = torch.randn(1, 9, 3)  # (b, k, c)
    res0 = torch.einsum("bchw,bkc->bkhw", x, w)
    print(res0.shape)  # (1, 9, 7, 8)
    shape = x.shape
    x_permuted = x.permute(0, 2, 3, 1).reshape(shape[0], -1, shape[1])  # (b, h*w, c)
    w_permuted = w.permute(0, 2, 1)  # (b, c, k)
    # (b, h*w, c) @ (b, c, k) -> (b, h*w, k), then unflatten h*w and move k forward.
    res = torch.matmul(x_permuted, w_permuted).reshape(shape[0], *shape[-2:], -1).permute(0, 3, 1, 2)
    print(res.shape)  # (1, 9, 7, 8)
    print(torch.sum(torch.abs(res0 - res)))  # should be ~0
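

# Hypothetical helper (not in the original file): a minimal sketch of a
# pass/fail check with an explicit tolerance, instead of eyeballing the
# printed error sums. torch.allclose is the standard PyTorch call for this.
def check_close(a, b, atol=1e-5):
    assert torch.allclose(a, b, atol=atol), "einsum and matmul results differ"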


if __name__ == "__main__":
    test0()
    test1()
    test2()
    test3()