Skip to content

Commit

Permalink
Fixes to Python API for ATMForce (openmm#4319)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored Nov 16, 2023
1 parent 2ab3b77 commit 43f571d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions wrappers/python/src/swig_doxygen/swigInputConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,4 +521,5 @@
("ATMForce", "getDefaultUmax") : ('unit.kilojoule_per_mole', ()),
("ATMForce", "getDefaultUbcore") : ('unit.kilojoule_per_mole', ()),
("ATMForce", "getDefaultAcore") : (None, ()),
("ATMForce", "getParticleParameters") : (None, ("unit.nanometer", "unit.nanometer")),
}
11 changes: 9 additions & 2 deletions wrappers/python/src/swig_doxygen/swig_lib/python/typemaps.i
Original file line number Diff line number Diff line change
Expand Up @@ -537,16 +537,23 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec
%typemap(argout) const std::vector<Vec3>& {
}


/* The following typemaps handle the ways a Vec3 can be returned from a function. */
%typemap(out, fragment="Vec3_to_PyVec3") Vec3 {
$result = Vec3_to_PyVec3(*$1);
}


%typemap(out, fragment="Vec3_to_PyVec3") const Vec3& {
$result = Vec3_to_PyVec3(*$1);
}

%typemap(in, numinputs=0) Vec3& OUTPUT (Vec3 temp) {
$1 = &temp;
}

%typemap(argout, fragment="Vec3_to_PyVec3") Vec3& OUTPUT {
%append_output(Vec3_to_PyVec3(*$1));
}

/* Convert C++ (Vec3&, Vec3&, Vec3&) object to python tuple or tuples */
%typemap(argout, fragment="Vec3_to_PyVec3") (Vec3& a, Vec3& b, Vec3& c) {
PyObject* pyVec1 = Vec3_to_PyVec3(*$1);
Expand Down
17 changes: 17 additions & 0 deletions wrappers/python/tests/TestAPIUnits.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,23 @@ def testCustomManyParticleForce(self):
self.assertEqual(force.getParticleParameters(1)[0][0], 20)
self.assertEqual(force.getParticleParameters(2)[0][0], 30*4.184)

def testATMForce(self):
"""Tests the ATMForce API features"""
force = ATMForce(0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.6, 0.8, -1.0);
force.addParticle(Vec3(1, 2, 3), Vec3(4, 5, 6))
self.assertEqual(0.1, force.getGlobalParameterDefaultValue(0))
self.assertEqual(0.2, force.getGlobalParameterDefaultValue(1))
self.assertEqual(0.3, force.getGlobalParameterDefaultValue(2))
self.assertEqual(0.4, force.getGlobalParameterDefaultValue(3))
self.assertEqual(0.5, force.getGlobalParameterDefaultValue(4))
self.assertEqual(0.7, force.getGlobalParameterDefaultValue(5))
self.assertEqual(0.6, force.getGlobalParameterDefaultValue(6))
self.assertEqual(0.8, force.getGlobalParameterDefaultValue(7))
self.assertEqual(-1.0, force.getGlobalParameterDefaultValue(8))
d1, d0 = force.getParticleParameters(0)
self.assertEqual(Vec3(1, 2, 3)*nanometers, d1)
self.assertEqual(Vec3(4, 5, 6)*nanometers, d0)

def testDrudeForce(self):
""" Tests the DrudeForce API features """
force = DrudeForce()
Expand Down

0 comments on commit 43f571d

Please sign in to comment.