Commit

Fix multi head attention
PhilippThoelke committed May 13, 2021
1 parent a8fb9b5 commit bbcb480
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions torchmdnet/models/torchmd_t.py
@@ -76,7 +76,7 @@ def __init__(self, hidden_channels=128, num_layers=6, num_rbf=50, rbf_type='expn
 
         self.attention_layers = nn.ModuleList()
         for _ in range(num_layers):
-            layer = MultiHeadAttention(hidden_channels, num_rbf, distance_influence,
+            layer = MultiHeadAttention(hidden_channels, num_rbf, distance_influence, num_heads,
                                        act_class, attn_act_class, cutoff_lower, cutoff_upper)
             self.attention_layers.append(layer)
 
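This hunk only threads the new num_heads argument from the representation model into every attention layer. Assuming the package is installed at this commit, constructing such a stack looks roughly like the sketch below; the hyperparameter values and activation classes are placeholders, in TorchMD_T they come from the model's constructor arguments.

import torch.nn as nn
from torchmdnet.models.torchmd_t import MultiHeadAttention

# Placeholder hyperparameters for illustration only.
hidden_channels, num_rbf, num_heads, num_layers = 128, 50, 8, 6
distance_influence = 'both'
cutoff_lower, cutoff_upper = 0.0, 5.0
act_class, attn_act_class = nn.SiLU, nn.SiLU  # stand-ins for the resolved activation classes

attention_layers = nn.ModuleList()
for _ in range(num_layers):
    layer = MultiHeadAttention(hidden_channels, num_rbf, distance_influence, num_heads,
                               act_class, attn_act_class, cutoff_lower, cutoff_upper)
    attention_layers.append(layer)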
@@ -123,13 +123,19 @@ def __repr__(self):
 
 
 class MultiHeadAttention(MessagePassing):
-    def __init__(self, hidden_channels, num_rbf, distance_influence,
+    def __init__(self, hidden_channels, num_rbf, distance_influence, num_heads,
                  activation, attn_activation, cutoff_lower, cutoff_upper):
-        super(MultiHeadAttention, self).__init__(aggr='add')
+        super(MultiHeadAttention, self).__init__(aggr='add', node_dim=0)
+        assert hidden_channels % num_heads == 0, (f'The number of hidden channels ({hidden_channels}) '
+                                                  f'must be evenly divisible by the number of '
+                                                  f'attention heads ({num_heads})')
 
         self.distance_influence = distance_influence
+        self.num_heads = num_heads
+        self.head_dim = hidden_channels // num_heads
+
         self.layernorm = nn.LayerNorm(hidden_channels)
-        self.activation = activation()
+        self.act = activation()
         self.attn_activation = attn_activation()
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
 
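The core of this hunk is the head split: hidden_channels is divided into num_heads independent heads of size head_dim = hidden_channels // num_heads, which is why the new assert demands divisibility, and node_dim=0 tells PyTorch Geometric's MessagePassing to scatter along the first axis now that node features carry an extra head dimension. A minimal standalone sketch of the reshape, with hypothetical sizes (not the module itself):

import torch

# Hypothetical sizes for illustration only.
num_atoms, hidden_channels, num_heads = 5, 128, 8
head_dim = hidden_channels // num_heads  # 16; the assert guarantees 128 % 8 == 0

x = torch.randn(num_atoms, hidden_channels)
# Split the feature dimension so each head works on its own head_dim-sized slice.
x_heads = x.reshape(-1, num_heads, head_dim)
print(x_heads.shape)  # torch.Size([5, 8, 16])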
@@ -165,22 +171,24 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.dv_proj.weight)
         self.dv_proj.bias.data.fill_(0)
 
-    def forward(self, x, edge_index, edge_weight, edge_attr):
+    def forward(self, x, edge_index, r_ij, f_ij):
+        head_shape = (-1, self.num_heads, self.head_dim)
+
         x_norm = self.layernorm(x)
-        q = self.q_proj(x_norm)
-        k = self.k_proj(x_norm)
-        v = self.v_proj(x_norm)
+        q = self.q_proj(x_norm).reshape(head_shape)
+        k = self.k_proj(x_norm).reshape(head_shape)
+        v = self.v_proj(x_norm).reshape(head_shape)
 
-        dk = self.activation(self.dk_proj(edge_attr)) if self.dk_proj else 1.0
-        dv = self.activation(self.dv_proj(edge_attr)) if self.dv_proj else 1.0
+        dk = self.act(self.dk_proj(f_ij)).reshape(head_shape) if self.dk_proj else 1.0
+        dv = self.act(self.dv_proj(f_ij)).reshape(head_shape) if self.dv_proj else 1.0
 
-        out = self.propagate(edge_index, q=q, k=k, v=v, dk=dk, dv=dv, r_ij=edge_weight)
-        out = self.o_proj(out)
+        out = self.propagate(edge_index, q=q, k=k, v=v, dk=dk, dv=dv, r_ij=r_ij)
+        out = self.o_proj(out.reshape(-1, self.num_heads * self.head_dim))
         return x + out
 
     def message(self, q_i, k_j, v_j, dk, dv, r_ij):
-        attn = (q_i * k_j * dk).sum(dim=1)
-        attn = self.attn_activation(attn) * self.cutoff(r_ij)
+        attn = (q_i * k_j * dk).sum(dim=-1)
+        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)
         v_j = v_j * dv
-        v_j = v_j * attn.unsqueeze(1)
+        v_j = v_j * attn.unsqueeze(2)
         return v_j
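For reference, a self-contained sketch of what message now computes per edge, with hypothetical tensors and SiLU standing in for self.attn_activation: the receiving node's query, the sending node's key, and the distance filter dk are multiplied element-wise and reduced over head_dim only, giving one attention weight per edge and per head; the scalar cutoff is broadcast across heads, and the weights then scale each head's value vector before aggregation.

import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only.
num_edges, num_heads, head_dim = 7, 8, 16

q_i = torch.randn(num_edges, num_heads, head_dim)  # query features of the receiving node i
k_j = torch.randn(num_edges, num_heads, head_dim)  # key features of the sending node j
v_j = torch.randn(num_edges, num_heads, head_dim)  # value features of the sending node j
dk = torch.randn(num_edges, num_heads, head_dim)   # distance-based key filter (dk_proj of f_ij)
dv = torch.randn(num_edges, num_heads, head_dim)   # distance-based value filter (dv_proj of f_ij)
cutoff = torch.rand(num_edges)                     # scalar cutoff per edge, as from CosineCutoff(r_ij)

attn = (q_i * k_j * dk).sum(dim=-1)                # (num_edges, num_heads): reduce over head_dim only
attn = F.silu(attn) * cutoff.unsqueeze(1)          # broadcast the scalar cutoff across heads

out = v_j * dv                                     # apply the distance-based value filter
out = out * attn.unsqueeze(2)                      # weight each head's value vector
print(out.shape)                                   # torch.Size([7, 8, 16])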
