Skip to content

Commit

Permalink
[Example][Bugfix] Fix dimenet example (dmlc#4219)
Browse files Browse the repository at this point in the history
Co-authored-by: Xin Yao <[email protected]>
  • Loading branch information
chang-l and yaox12 authored Jul 7, 2022
1 parent 9ee7ced commit 28b0904
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/pytorch/dimenet/modules/bessel_basis_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def __init__(self,
self.reset_params()

def reset_params(self):
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
with torch.no_grad():
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
self.frequencies.requires_grad_()

def forward(self, g):
d_scaled = g.edata['d'] / self.cutoff
# Necessary for proper broadcasting behaviour
d_scaled = torch.unsqueeze(d_scaled, -1)
d_cutoff = self.envelope(d_scaled)
g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled)
return g
return g

0 comments on commit 28b0904

Please sign in to comment.