Skip to content

Commit

Permalink
Speed up check_compatibility for relevant writer method (mosdef-hub#700)
Browse files Browse the repository at this point in the history
* init commit, standard simplify_check for relevant write methods

* adapt new speed up for lammps

* revert changes in lammpsdata write in favor of new/separate PR

* remove simplify_check from topology save function

* better if condition for _check_single_potential

* implement better criteria to for first check in _check_single_potential
  • Loading branch information
daico007 authored Dec 13, 2022
1 parent fe321e2 commit 5085640
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 31 deletions.
14 changes: 2 additions & 12 deletions gmso/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,18 +1158,6 @@ def get_index(self, member):

return index

def _reindex_connection_types(self, ref):
"""Re-generate the indices of the connection types in the topology."""
if ref not in self._index_refs:
raise GMSOError(
f"cannot reindex {ref}. It should be one of "
f"{ANGLE_TYPE_DICT}, {BOND_TYPE_DICT}, "
f"{ANGLE_TYPE_DICT}, {DIHEDRAL_TYPE_DICT}, {IMPROPER_TYPE_DICT},"
f"{PAIRPOTENTIAL_TYPE_DICT}"
)
for i, ref_member in enumerate(self._set_refs[ref].keys()):
self._index_refs[ref][ref_member] = i

def get_forcefield(self):
"""Get an instance of gmso.ForceField out of this topology
Expand Down Expand Up @@ -1399,6 +1387,8 @@ def save(self, filename, overwrite=False, **kwargs):
**kwargs:
The arguments to specific file savers listed below(as extensions):
* json: types, update, indent
* gro: precision
* lammps/lammpsdata: atom_style
"""
if not isinstance(filename, Path):
filename = Path(filename).resolve()
Expand Down
4 changes: 2 additions & 2 deletions gmso/formats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from .formats_registry import LoadersRegistry, SaversRegistry
from .gro import read_gro, write_gro
from .gsd import write_gsd
from .json import save_json
from .json import write_json
from .lammpsdata import write_lammpsdata
from .mcf import write_mcf
from .mol2 import from_mol2
from .mol2 import read_mol2
from .top import write_top
from .xyz import read_xyz, write_xyz

Expand Down
2 changes: 1 addition & 1 deletion gmso/formats/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _from_json(json_dict):


@saves_as(".json")
def save_json(top, filename, **kwargs):
def write_json(top, filename, **kwargs):
"""Save the topology as a JSON file.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion gmso/formats/mol2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@loads_as(".mol2")
def from_mol2(filename, site_type="atom"):
def read_mol2(filename, site_type="atom"):
"""Read in a TRIPOS mol2 file format into a gmso topology object.
Creates a Topology from a mol2 file structure. This will read in the
Expand Down
24 changes: 19 additions & 5 deletions gmso/formats/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,23 @@


@saves_as(".top")
def write_top(top, filename, top_vars=None, simplify_check=False):
"""Write a gmso.core.Topology object to a GROMACS topology (.TOP) file."""
pot_types = _validate_compatibility(top, simplify_check)
def write_top(top, filename, top_vars=None):
"""Write a gmso.core.Topology object to a GROMACS topology (.TOP) file.
Parameters
----------
top : gmso.Topology
A typed Topology Object
filename : str
Path of the output file
Notes
-----
See https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html for
a full description of the top file format. This method is a work in progress and do not currently
support the full GROMACS specs.
"""
pot_types = _validate_compatibility(top)
top_vars = _get_top_vars(top, top_vars)

# Sanity checks
Expand Down Expand Up @@ -262,9 +276,9 @@ def _accepted_potentials():
return accepted_potentials


def _validate_compatibility(top, simplify_check):
def _validate_compatibility(top):
"""Check compatability of topology object with GROMACS TOP format."""
pot_types = check_compatibility(top, _accepted_potentials(), simplify_check)
pot_types = check_compatibility(top, _accepted_potentials())
return pot_types


Expand Down
25 changes: 15 additions & 10 deletions gmso/utils/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from gmso.exceptions import EngineIncompatibilityError


def check_compatibility(topology, accepted_potentials, simplify_check=False):
def check_compatibility(topology, accepted_potentials):
"""
Compare the potentials in a topology against a list of accepted potential templates.
Expand All @@ -15,8 +15,7 @@ def check_compatibility(topology, accepted_potentials, simplify_check=False):
The topology whose potentials to check.
accepted_potentials: list
A list of gmso.Potential objects to check against
simplify_check : bool, optional, default=False
Simplify the sympy expression check, aka, only compare the expression string
Returns
-------
potential_forms_dict: dict
Expand All @@ -30,7 +29,8 @@ def check_compatibility(topology, accepted_potentials, simplify_check=False):
# filter_by=PotentialFilters.UNIQUE_NAME_CLASS
):
potential_form = _check_single_potential(
atom_type, accepted_potentials, simplify_check
atom_type,
accepted_potentials,
)
if not potential_form:
raise EngineIncompatibilityError(
Expand All @@ -43,7 +43,8 @@ def check_compatibility(topology, accepted_potentials, simplify_check=False):
# filter_by=PotentialFilters.UNIQUE_NAME_CLASS
):
potential_form = _check_single_potential(
connection_type, accepted_potentials, simplify_check
connection_type,
accepted_potentials,
)
if not potential_form:
raise EngineIncompatibilityError(
Expand All @@ -55,14 +56,18 @@ def check_compatibility(topology, accepted_potentials, simplify_check=False):
return potential_forms_dict


def _check_single_potential(potential, accepted_potentials, simplify_check):
def _check_single_potential(potential, accepted_potentials):
"""Check to see if a single given potential is in the list of accepted potentials."""
ind_var = potential.independent_variables
u_dims = {para.units.dimensions for para in potential.parameters.values()}
for ref in accepted_potentials:
if ref.independent_variables == potential.independent_variables:
if simplify_check:
if str(ref.expression) == str(potential.expression):
return {potential: ref.name}
ref_ind_var = ref.independent_variables
ref_u_dims = set(ref.expected_parameters_dimensions.values())
if len(ind_var) == len(ref_ind_var) and u_dims == ref_u_dims:
if str(ref.expression) == str(potential.expression):
return {potential: ref.name}
else:
print("Simpify", ref, potential)
if sympy.simplify(ref.expression - potential.expression) == 0:
return {potential: ref.name}
return False

0 comments on commit 5085640

Please sign in to comment.