Skip to content

Commit

Permalink
Optimizations to reporters (openmm#4330)
Browse files Browse the repository at this point in the history
* Optimizations to reporters

* Removed unneeded imports
  • Loading branch information
peastman authored Nov 30, 2023
1 parent 127a373 commit f6e6b6e
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 16 deletions.
5 changes: 3 additions & 2 deletions wrappers/python/openmm/app/dcdfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ def writeModel(self, positions, unitCellDimensions=None, periodicBoxVectors=None
raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions):
positions = positions.value_in_unit(nanometers)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
file = self._file

Expand Down
2 changes: 1 addition & 1 deletion wrappers/python/openmm/app/dcdreporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def report(self, simulation, state):
self._out, simulation.topology, simulation.integrator.getStepSize(),
simulation.currentStep, self._reportInterval, self._append
)
self._dcd.writeModel(state.getPositions(), periodicBoxVectors=state.getPeriodicBoxVectors())
self._dcd.writeModel(state.getPositions(asNumpy=True), periodicBoxVectors=state.getPeriodicBoxVectors())

def __del__(self):
self._out.close()
5 changes: 3 additions & 2 deletions wrappers/python/openmm/app/pdbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,10 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=None, keepIds=Fa
raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions):
positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
nonHeterogens = PDBFile._standardResidues[:]
nonHeterogens.remove('HOH')
Expand Down
8 changes: 4 additions & 4 deletions wrappers/python/openmm/app/pdbreporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ def report(self, simulation, state):
topology = self._subsetTopology

#PDBFile will convert to angstroms so do it here first instead
positions = state.getPositions().value_in_unit(angstroms)
positions = state.getPositions(asNumpy=True).value_in_unit(angstroms)
positions = [positions[i] for i in self._atomSubset]

else:
topology = simulation.topology
positions = state.getPositions()
positions = state.getPositions(asNumpy=True)

if self._nextModel == 0:
PDBFile.writeHeader(topology, self._out)
Expand Down Expand Up @@ -202,12 +202,12 @@ def report(self, simulation, state):
topology = self._subsetTopology

#PDBFile will convert to angstroms so do it here first instead
positions = state.getPositions().value_in_unit(angstroms)
positions = state.getPositions(asNumpy=True).value_in_unit(angstroms)
positions = [positions[i] for i in self._atomSubset]

else:
topology = simulation.topology
positions = state.getPositions()
positions = state.getPositions(asNumpy=True)

if self._nextModel == 0:
PDBxFile.writeHeader(topology, self._out)
Expand Down
5 changes: 3 additions & 2 deletions wrappers/python/openmm/app/pdbxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,10 @@ def writeModel(topology, positions, file=sys.stdout, modelIndex=1, keepIds=False
raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions):
positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
nonHeterogens = PDBFile._standardResidues[:]
nonHeterogens.remove('HOH')
Expand Down
7 changes: 3 additions & 4 deletions wrappers/python/openmm/app/xtcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
get_xtc_nframes,
get_xtc_natoms,
)
import numpy as np
import os
from openmm import Vec3
from openmm.unit import nanometers, picoseconds, is_quantity, norm
import math
import tempfile
import shutil

Expand Down Expand Up @@ -92,11 +90,12 @@ def writeModel(self, positions, unitCellDimensions=None, periodicBoxVectors=None
raise ValueError("The number of positions must match the number of atoms")
if is_quantity(positions):
positions = positions.value_in_unit(nanometers)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError(
"Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
)
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError(
"Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
)
Expand Down
2 changes: 1 addition & 1 deletion wrappers/python/openmm/app/xtcreporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ def report(self, simulation, state):
self._append,
)
self._xtc.writeModel(
state.getPositions(), periodicBoxVectors=state.getPeriodicBoxVectors()
state.getPositions(asNumpy=True), periodicBoxVectors=state.getPeriodicBoxVectors()
)

0 comments on commit f6e6b6e

Please sign in to comment.