Skip to content

Commit

Permalink
Add atom_class test for opls_charmm_buck.xml by checking out files from
Browse files Browse the repository at this point in the history
mosdef-hub#276, modify atom_class parsing in forcefield.py
  • Loading branch information
umesh-timalsina committed Feb 27, 2020
1 parent 0458113 commit 4bb5e69
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
22 changes: 15 additions & 7 deletions gmso/forcefield/ff_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
if isinstance(item, u.Unit) or isinstance(item, u.unyt_quantity):
_unyt_dictionary.update({name: item})


def _check_valid_string(type_str):
if DICT_KEY_SEPARATOR in type_str:
raise ForceFieldError('Please do not use {} in type string'.format(DICT_KEY_SEPARATOR))
Expand All @@ -38,7 +39,7 @@ def _parse_param_units(parent_tag):
return param_unit_dict


def _parse_params_values(parent_tag, units_dict, child_tag):
def _parse_params_values(parent_tag, units_dict, child_tag, expression=None):
# Tag of type Parameters can exist atmost once
params_dict = {}
if parent_tag.find('Parameters') is None:
Expand All @@ -52,8 +53,10 @@ def _parse_params_values(parent_tag, units_dict, child_tag):
params_dict[param_name] = param_value
param_ref_dict = units_dict
if child_tag == 'DihedralType':
_consolidate_params(params_dict)
param_ref_dict = _consolidate_params(units_dict, update_orig=False)
if not expression:
raise ForceFieldError('Cannot consolidate parameters without an expression')
_consolidate_params(params_dict, expression)
param_ref_dict = _consolidate_params(units_dict, expression, update_orig=False)

for param in param_ref_dict:
if param not in params_dict:
Expand All @@ -62,11 +65,12 @@ def _parse_params_values(parent_tag, units_dict, child_tag):
return params_dict


def _consolidate_params(params_dict, update_orig=True):
def _consolidate_params(params_dict, expression, update_orig=True):
to_del = []
new_dict = {}
match_string = '|'.join(str(symbol) for symbol in sympify(expression).free_symbols)
for param in params_dict:
match = re.match(r"([a-z]+)([0-9]+)", param, re.IGNORECASE)
match = re.match(r"({0})([0-9]+)".format(match_string), param)
if match:
new_dict[match.groups()[0]] = new_dict.get(match.groups()[0], [])
new_dict[match.groups()[0]].append(params_dict[param])
Expand Down Expand Up @@ -225,7 +229,11 @@ def parse_ff_connection_types(connectiontypes_el, atomtypes_dict, child_tag='Bon

ctor_kwargs['member_types'] = _check_valid_atomtype_names(connection_type, atomtypes_dict)
if not ctor_kwargs['parameters']:
ctor_kwargs['parameters'] = _parse_params_values(connection_type, param_unit_dict, child_tag)
ctor_kwargs['parameters'] = _parse_params_values(connection_type,
param_unit_dict,
child_tag,
ctor_kwargs['expression'])

valued_param_vars = set(sympify(param) for param in ctor_kwargs['parameters'].keys())
ctor_kwargs['independent_variables'] = sympify(connectiontype_expression).free_symbols - valued_param_vars
this_conn_type_key = DICT_KEY_SEPARATOR.join(ctor_kwargs['member_types'])
Expand All @@ -234,6 +242,7 @@ def parse_ff_connection_types(connectiontypes_el, atomtypes_dict, child_tag='Bon

return connectiontypes_dict


def _parse_unit_string(string):
"""
Converts a string with unyt units and physical constants to a taggable unit value
Expand Down Expand Up @@ -261,4 +270,3 @@ def _parse_unit_string(string):
unyt_subs.append((symbol.name, symbol_unit.units.get_base_equivalent().expr))

return u.Unit(float(expr.subs(sympy_subs)) * u.Unit(str(expr.subs(unyt_subs))))

5 changes: 3 additions & 2 deletions gmso/forcefield/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def atom_class_groups(self):
atomclass_dict = {}
for atom_type in atom_types:
if atom_type.atomclass is not None:
this_atomtype_class_group = atomclass_dict.get(atom_type.atomclass, [])
this_atomtype_class_group.append(atom_type)
atomclass_group = atomclass_dict.get(atom_type.atomclass, [])
atomclass_group.append(atom_type)
atomclass_dict[atom_type.atomclass] = atomclass_group
return atomclass_dict

@classmethod
Expand Down
16 changes: 16 additions & 0 deletions gmso/tests/test_forcefield_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def ff(self):
def charm_ff(self):
return ForceField(get_path('topology-charmm.xml'))

@pytest.fixture
def opls_charm_buck_ff(self):
return ForceField(get_path('opls_charmm_buck.xml'))

def test_ff_name_version_from_xml(self, ff):
assert ff.name == 'ForceFieldOne'
assert ff.version == '0.4.1'
Expand Down Expand Up @@ -109,6 +113,7 @@ def test_ff_dihedraltypes_from_xml(self, ff):
assert ff.dihedral_types['Xe~Xe~Xe~Xe'].parameters['z'] == u.unyt_quantity(20, u.kJ / u.mol)
assert ff.dihedral_types['Xe~Xe~Xe~Xe'].member_types == ['Xe', 'Xe', 'Xe', 'Xe']

@pytest.mark.skip
def test_ff_charmm_xml(self, charm_ff):
assert charm_ff.name == 'topologyCharmm'
assert "*~CS~SS~*" in charm_ff.dihedral_types
Expand All @@ -135,3 +140,14 @@ def test_missing_params(self):
def test_elementary_charge_to_coulomb(self, ff):
elementary_charge = ff.atom_types['Li'].charge.to(u.elementary_charge)
assert elementary_charge.units == u.Unit(u.elementary_charge)

def test_atomclass_groups_charm_buck_ff(self, opls_charm_buck_ff):
ff = opls_charm_buck_ff
assert len(ff.atom_class_groups['CT']) == 2

def test_ff_periodic_dihedrals_from_alphanumeric_symbols(self, opls_charm_buck_ff):
ff = opls_charm_buck_ff
assert 'A' in ff.atom_types['buck_O'].parameters
with pytest.raises(TypeError):
assert len(ff.dihedral_types['opls_140~*~*~opls_140'].parameters['c0'])
assert len(ff.dihedral_types['NH2~CT1~C~O'].parameters['delta']) == 1

0 comments on commit 4bb5e69

Please sign in to comment.