test_kernels.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import unittest

from openfold.model.primitives import _attention
from openfold.utils.kernel.attention_core import attention_core
from tests.config import consts


class TestAttentionCore(unittest.TestCase):
    def test_attention_core_forward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32
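        # Random activations and a binary mask, moved to the GPU for the CUDA kernel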
        q = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        k = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        v = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        # Additive bias: 0 where the mask is 1, -1e9 where it is 0
        mask_bias = (1e9 * (mask - 1))[..., None, None, :].to(dtype)
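        # The fused CUDA kernel should match the pure-PyTorch reference implementation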
        out_repro = attention_core(q, k, v, mask_bias, None)
        out_gt = _attention(q, k, v, [mask_bias])

        self.assertTrue(torch.max(torch.abs(out_repro - out_gt)) < consts.eps)

    def test_attention_core_backward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32
        q = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        k = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        v = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        # Additive bias: 0 where the mask is 1, -1e9 where it is 0
        mask_bias = (1e9 * (mask - 1))[..., None, None, :].to(dtype)
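        # Clone the inputs so the kernel and the reference each build their own graph;
        # retain_grad() lets us read gradients from these non-leaf clones after backward()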
        def clone(t):
            t = t.clone()
            if t.requires_grad:
                t.retain_grad()
            return t
        q_repro = clone(q)
        k_repro = clone(k)
        v_repro = clone(v)
        out_repro = attention_core(
            q_repro, k_repro, v_repro, mask_bias, None
        )
        loss_repro = torch.mean(out_repro)
        loss_repro.backward()
        q_gt = clone(q)
        k_gt = clone(k)
        v_gt = clone(v)
        out_gt = _attention(
            q_gt, k_gt, v_gt, [mask_bias]
        )
        loss_gt = torch.mean(out_gt)
        loss_gt.backward()
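        # Gradients flowing through the CUDA kernel should match the reference within eps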
        pairs = zip([q_repro, k_repro, v_repro], [q_gt, k_gt, v_gt])
        for t_repro, t_gt in pairs:
            self.assertTrue(
                torch.max(torch.abs(t_repro.grad - t_gt.grad)) < consts.eps
            )

if __name__ == '__main__':
    unittest.main()