forked from torchmd/torchmd-net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimize.py
94 lines (75 loc) · 3.04 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
from typing import Optional, Tuple
import torch as pt
from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors
from .models.model import TorchMD_Net
from .models.torchmd_gn import TorchMD_GN
class TorchMD_GN_optimized(pt.nn.Module):
    """Drop-in replacement for TorchMD_GN that runs the continuous-filter
    convolutions with the optimized ``CFConv`` kernel from NNPOps.

    Only the subset of TorchMD_GN configurations that this wrapper can map
    onto ``CFConv`` is accepted; every other configuration is rejected in
    ``__init__`` with a ``ValueError``. The filter-network weights of each
    interaction block are handed to ``CFConv`` once, at construction time.
    """

    def __init__(self, model):
        """Wrap ``model`` (expected to be a TorchMD_GN instance).

        Raises:
            ValueError: if ``model`` uses an option the optimized path does
                not support (non-Gaussian or trainable RBFs, a non-"ssp"
                activation, neighbor embedding, a non-zero lower cutoff, or
                an aggregation other than "add").
        """
        # Reject unsupported configurations before building anything.
        if model.rbf_type != "gauss":
            raise ValueError('Only rbf_type="gauss" is supported')
        if model.trainable_rbf:
            raise ValueError("trainable_rbf=True is not supported")
        if model.activation != "ssp":
            raise ValueError('Only activation="ssp" is supported')
        if model.neighbor_embedding:
            raise ValueError("neighbor_embedding=True is not supported")
        if model.cutoff_lower != 0.0:
            raise ValueError("Only cutoff_lower=0.0 is supported")
        if model.aggr != "add":
            raise ValueError('Only aggr="add" is supported')

        super().__init__()
        self.model = model
        self.neighbors = CFConvNeighbors(self.model.cutoff_upper)

        # CFConv takes a single Gaussian width; derive it from the spacing of
        # the first two RBF centres. NOTE(review): this assumes the offsets
        # are evenly spaced — confirm against the distance expansion.
        offset = self.model.distance_expansion.offset
        width = offset[1] - offset[0]

        # A ModuleList (rather than a plain list) registers the CFConv layers
        # as submodules, so nn.Module machinery (.modules(), .to(),
        # .train()/.eval(), TorchScript) can see them.
        self.convs = pt.nn.ModuleList(
            [
                CFConv(
                    gaussianWidth=width,
                    activation="ssp",
                    # NOTE(review): .T presumably converts nn.Linear's
                    # (out, in) weight layout to the one CFConv expects.
                    weights1=inter.mlp[0].weight.T,
                    biases1=inter.mlp[0].bias,
                    weights2=inter.mlp[2].weight.T,
                    biases2=inter.mlp[2].bias,
                )
                for inter in self.model.interactions
            ]
        )

    def forward(
        self,
        z: pt.Tensor,
        pos: pt.Tensor,
        batch: pt.Tensor,
        box: Optional[pt.Tensor] = None,
        q: Optional[pt.Tensor] = None,
        s: Optional[pt.Tensor] = None,
    ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]:
        """Compute atomic features for a single, non-periodic system.

        Args:
            z: Atom types, fed to the wrapped model's embedding.
            pos: Atom positions, used to build the neighbor list.
            batch: Per-atom system index; every entry must be 0.
            box: Must be None; periodic boxes are not supported.
            q: Accepted for signature compatibility; ignored.
            s: Accepted for signature compatibility; ignored.

        Returns:
            ``(x, None, z, pos, batch)``, where ``x`` is the feature tensor
            after the last interaction block.

        Raises:
            ValueError: if ``batch`` has a non-zero entry or ``box`` is given.
        """
        # Explicit checks instead of asserts: asserts are stripped under
        # ``python -O``, and because the neighbor list is built from ``pos``
        # alone, batched input would then silently be merged into one system.
        if not bool(pt.all(batch == 0)):
            raise ValueError(
                "Batched input is not supported: all entries of batch must be 0"
            )
        if box is not None:
            raise ValueError("Box is not supported")

        x = self.model.embedding(z)
        self.neighbors.build(pos)
        for inter, conv in zip(self.model.interactions, self.convs):
            # Reuse the wrapped block's linear layers and activation; only the
            # convolution itself is delegated to the fused CFConv kernel.
            y = inter.conv.lin1(x)
            y = conv(self.neighbors, pos, y)
            y = inter.conv.lin2(y)
            y = inter.act(y)
            x = x + inter.lin(y)
        return x, None, z, pos, batch

    def __repr__(self):
        return "Optimized: " + repr(self.model)
def optimize(model):
    """Replace a TorchMD_Net's representation model with its optimized version.

    The model is modified in place and also returned for convenience.

    Args:
        model: A ``TorchMD_Net`` whose ``representation_model`` is a
            ``TorchMD_GN``.

    Returns:
        The same ``model`` object, with ``representation_model`` replaced by
        a ``TorchMD_GN_optimized`` wrapper around the original.

    Raises:
        ValueError: if ``model`` is not a ``TorchMD_Net``, if its
            representation model is not a ``TorchMD_GN``, or if
            ``TorchMD_GN_optimized`` rejects that model's configuration.
    """
    # Explicit guard instead of assert so the check survives ``python -O``.
    if not isinstance(model, TorchMD_Net):
        raise ValueError(
            f"Unsupported model! Expected a TorchMD_Net, got {type(model).__name__}."
        )
    if not isinstance(model.representation_model, TorchMD_GN):
        raise ValueError("Unsupported model! Only TorchMD_GN is supported.")
    model.representation_model = TorchMD_GN_optimized(model.representation_model)
    return model