Skip to content

Commit

Permalink
Fixed issues in applying patches (openmm#4279)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored Oct 23, 2023
1 parent 10c909d commit 32a5f30
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion wrappers/python/openmm/app/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def registerResidueTemplate(self, template):

def registerPatch(self, patch):
"""Register a new patch that can be applied to templates."""
patch.index = len(self._patches)
self._patches[patch.name] = patch

def registerTemplatePatch(self, residue, patch, patchResidueIndex):
Expand Down Expand Up @@ -792,6 +793,17 @@ def __init__(self, node, atomIndices):
else:
self.excludeWith = self.atoms[0]

def __eq__(self, other):
if not isinstance(other, ForceField._VirtualSiteData):
return False
if self.type != other.type or self.index != other.index or self.atoms != other.atoms or self.excludeWith != other.excludeWith:
return False
if self.type in ('average2', 'average3', 'outOfPlane'):
return self.weights == other.weights
elif self.type == 'localCoords':
return self.originWeights == other.originWeights and self.xWeights == other.xWeights and self.yWeights == other.yWeights and self.localPos == other.localPos
return False

class _PatchData(object):
"""Inner class used to encapsulate data about a patch definition."""
def __init__(self, name, numResidues):
Expand All @@ -807,6 +819,10 @@ def __init__(self, name, numResidues):
self.allAtomNames = set()
self.virtualSites = [[] for i in range(numResidues)]
self.attributes = {}
self.index = None

def __lt__(self, other):
return self.index < other.index

def createPatchedTemplates(self, templates):
"""Apply this patch to a set of templates, creating new modified ones."""
Expand Down Expand Up @@ -1582,7 +1598,7 @@ def _applyPatchesToMatchResidues(forcefield, data, residues, templateForResidue,
patchedTemplates = {}
for name, template in forcefield._templates.items():
if name in forcefield._templatePatches:
patches = [forcefield._patches[patchName] for patchName, patchResidueIndex in forcefield._templatePatches[name] if forcefield._patches[patchName].numResidues == 1]
patches = sorted([forcefield._patches[patchName] for patchName, patchResidueIndex in forcefield._templatePatches[name] if forcefield._patches[patchName].numResidues == 1])
if len(patches) > 0:
newTemplates = []
patchedTemplates[name] = newTemplates
Expand Down

0 comments on commit 32a5f30

Please sign in to comment.