diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index fbd4fb960..3a0b2ead7 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -2,6 +2,7 @@ import math import torch from torch import nn +import torch.nn.functional as F from torch_scatter import scatter from torch_geometric.nn import MessagePassing