From 27bcb65729dcfa4284d2170999739b3ab4d29405 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Tue, 2 Nov 2021 13:26:09 -0700 Subject: [PATCH] updateParametersInContext() can change tabulated functions (#3307) * updateParametersInContext() can change tabulated functions * Fixed error in building C wrappers * updateParametersInContext() can change tabulated functions for CustomCentroidBondForce * CustomNonbondedForce can update tabulated functions * CustomGBForce can update tabulated functions * CustomManyParticleForce can update tabulated functions * CustomHbondForce can update tabulated functions --- .../include/openmm/CustomCentroidBondForce.h | 9 +- .../include/openmm/CustomCompoundBondForce.h | 9 +- openmmapi/include/openmm/CustomGBForce.h | 11 +- openmmapi/include/openmm/CustomHbondForce.h | 11 +- .../include/openmm/CustomManyParticleForce.h | 9 +- .../include/openmm/CustomNonbondedForce.h | 11 +- openmmapi/include/openmm/TabulatedFunction.h | 11 +- openmmapi/src/TabulatedFunction.cpp | 63 ++++- .../include/openmm/common/CommonKernels.h | 18 +- platforms/common/src/CommonKernels.cpp | 195 +++++++++---- platforms/cpu/include/CpuKernels.h | 7 +- platforms/cpu/src/CpuKernels.cpp | 142 ++++++++-- .../reference/include/ReferenceKernels.h | 16 ++ platforms/reference/src/ReferenceKernels.cpp | 260 ++++++++++++++---- tests/TestCustomCentroidBondForce.h | 12 + tests/TestCustomCompoundBondForce.h | 25 ++ tests/TestCustomGBForce.h | 31 ++- tests/TestCustomHbondForce.h | 9 + tests/TestCustomManyParticleForce.h | 9 + tests/TestCustomNonbondedForce.h | 15 + wrappers/generateWrappers.py | 16 +- 21 files changed, 709 insertions(+), 180 deletions(-) diff --git a/openmmapi/include/openmm/CustomCentroidBondForce.h b/openmmapi/include/openmm/CustomCentroidBondForce.h index cfe500046b..ffb2125c1e 100644 --- a/openmmapi/include/openmm/CustomCentroidBondForce.h +++ b/openmmapi/include/openmm/CustomCentroidBondForce.h @@ -358,15 +358,16 @@ class OPENMM_EXPORT CustomCentroidBondForce : public Force { */ const std::string& getTabulatedFunctionName(int index) const; /** - * Update the per-bond parameters in a Context to match those stored in this Force object. This method provides + * Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides * an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext() * to copy them over to the Context. * - * This method has several limitations. The only information it updates is the values of per-bond parameters. - * All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing + * This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated + * functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing * the Context. Neither the definitions of groups nor the set of groups involved in a bond can be changed, nor can new - * bonds be added. + * bonds be added. Also, while the tabulated values of a function can change, everything else about it (its dimensions, + * the data range) must not be changed. */ void updateParametersInContext(Context& context); /** diff --git a/openmmapi/include/openmm/CustomCompoundBondForce.h b/openmmapi/include/openmm/CustomCompoundBondForce.h index 956ba475a3..740491d63b 100644 --- a/openmmapi/include/openmm/CustomCompoundBondForce.h +++ b/openmmapi/include/openmm/CustomCompoundBondForce.h @@ -338,14 +338,15 @@ class OPENMM_EXPORT CustomCompoundBondForce : public Force { */ void setFunctionParameters(int index, const std::string& name, const std::vector& values, double min, double max); /** - * Update the per-bond parameters in a Context to match those stored in this Force object. This method provides + * Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides * an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext() * to copy them over to the Context. * - * This method has several limitations. The only information it updates is the values of per-bond parameters. - * All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing - * the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added. + * This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated + * functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing + * the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added. Also, while the + * tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed. */ void updateParametersInContext(Context& context); /** diff --git a/openmmapi/include/openmm/CustomGBForce.h b/openmmapi/include/openmm/CustomGBForce.h index 7109c775c8..42dedaed7a 100644 --- a/openmmapi/include/openmm/CustomGBForce.h +++ b/openmmapi/include/openmm/CustomGBForce.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2008-2016 Stanford University and the Authors. * + * Portions copyright (c) 2008-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -544,14 +544,15 @@ class OPENMM_EXPORT CustomGBForce : public Force { */ void setFunctionParameters(int index, const std::string& name, const std::vector& values, double min, double max); /** - * Update the per-particle parameters in a Context to match those stored in this Force object. This method provides + * Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides * an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext() * to copy them over to the Context. * - * This method has several limitations. The only information it updates is the values of per-particle parameters. - * All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing - * the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones. + * This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated + * functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing + * the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones. While + * the tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed. */ void updateParametersInContext(Context& context); /** diff --git a/openmmapi/include/openmm/CustomHbondForce.h b/openmmapi/include/openmm/CustomHbondForce.h index 67e4c54a43..10b20fe5ea 100644 --- a/openmmapi/include/openmm/CustomHbondForce.h +++ b/openmmapi/include/openmm/CustomHbondForce.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2008-2014 Stanford University and the Authors. * + * Portions copyright (c) 2008-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -443,15 +443,16 @@ class OPENMM_EXPORT CustomHbondForce : public Force { */ void setFunctionParameters(int index, const std::string& name, const std::vector& values, double min, double max); /** - * Update the per-donor and per-acceptor parameters in a Context to match those stored in this Force object. This method + * Update the per-donor and per-acceptor parameters and tabulated functions in a Context to match those stored in this Force object. This method * provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * Simply call setDonorParameters() and setAcceptorParameters() to modify this object's parameters, then call * updateParametersInContext() to copy them over to the Context. * - * This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters. - * All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only + * This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters and tabulated + * functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only * be changed by reinitializing the Context. The set of particles involved in a donor or acceptor cannot be changed, nor can - * new donors or acceptors be added. + * new donors or acceptors be added. While the tabulated values of a function can change, everything else about it (its dimensions, + * the data range) must not be changed. */ void updateParametersInContext(Context& context); /** diff --git a/openmmapi/include/openmm/CustomManyParticleForce.h b/openmmapi/include/openmm/CustomManyParticleForce.h index 2d19079f2d..7af2d43066 100644 --- a/openmmapi/include/openmm/CustomManyParticleForce.h +++ b/openmmapi/include/openmm/CustomManyParticleForce.h @@ -480,15 +480,16 @@ class OPENMM_EXPORT CustomManyParticleForce : public Force { */ const std::string& getTabulatedFunctionName(int index) const; /** - * Update the per-particle parameters in a Context to match those stored in this Force object. This method provides + * Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides * an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext() * to copy them over to the Context. * - * This method has several limitations. The only information it updates is the values of per-particle parameters. - * All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can + * This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated + * functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can * only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change - * the parameters of existing ones. + * the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions, + * the data range) must not be changed. */ void updateParametersInContext(Context& context); /** diff --git a/openmmapi/include/openmm/CustomNonbondedForce.h b/openmmapi/include/openmm/CustomNonbondedForce.h index 004a5e29be..a4fed87a4a 100644 --- a/openmmapi/include/openmm/CustomNonbondedForce.h +++ b/openmmapi/include/openmm/CustomNonbondedForce.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2008-2016 Stanford University and the Authors. * + * Portions copyright (c) 2008-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -490,15 +490,16 @@ class OPENMM_EXPORT CustomNonbondedForce : public Force { */ void setInteractionGroupParameters(int index, const std::set& set1, const std::set& set2); /** - * Update the per-particle parameters in a Context to match those stored in this Force object. This method provides + * Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides * an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext() * to copy them over to the Context. * - * This method has several limitations. The only information it updates is the values of per-particle parameters. - * All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can + * This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated + * functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can * only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change - * the parameters of existing ones. + * the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions, + * the data range) must not be changed. */ void updateParametersInContext(Context& context); /** diff --git a/openmmapi/include/openmm/TabulatedFunction.h b/openmmapi/include/openmm/TabulatedFunction.h index 31d7980757..2805ec4da8 100644 --- a/openmmapi/include/openmm/TabulatedFunction.h +++ b/openmmapi/include/openmm/TabulatedFunction.h @@ -65,9 +65,12 @@ class OPENMM_EXPORT TabulatedFunction { virtual TabulatedFunction* Copy() const = 0; /** * Get the periodicity status of the tabulated function. - * */ bool getPeriodic() const; + virtual bool operator==(const TabulatedFunction& other) const = 0; + virtual bool operator!=(const TabulatedFunction& other) const { + return !(*this == other); + } protected: bool periodic; }; @@ -114,6 +117,7 @@ class OPENMM_EXPORT Continuous1DFunction : public TabulatedFunction { * @deprecated This will be removed in a future release. */ Continuous1DFunction* Copy() const; + bool operator==(const TabulatedFunction& other) const; private: std::vector values; double min, max; @@ -176,6 +180,7 @@ class OPENMM_EXPORT Continuous2DFunction : public TabulatedFunction { * @deprecated This will be removed in a future release. */ Continuous2DFunction* Copy() const; + bool operator==(const TabulatedFunction& other) const; private: std::vector values; int xsize, ysize; @@ -254,6 +259,7 @@ class OPENMM_EXPORT Continuous3DFunction : public TabulatedFunction { * @deprecated This will be removed in a future release. */ Continuous3DFunction* Copy() const; + bool operator==(const TabulatedFunction& other) const; private: std::vector values; int xsize, ysize, zsize; @@ -291,6 +297,7 @@ class OPENMM_EXPORT Discrete1DFunction : public TabulatedFunction { * @deprecated This will be removed in a future release. */ Discrete1DFunction* Copy() const; + bool operator==(const TabulatedFunction& other) const; private: std::vector values; }; @@ -335,6 +342,7 @@ class OPENMM_EXPORT Discrete2DFunction : public TabulatedFunction { * @deprecated This will be removed in a future release. */ Discrete2DFunction* Copy() const; + bool operator==(const TabulatedFunction& other) const; private: int xsize, ysize; std::vector values; @@ -383,6 +391,7 @@ class OPENMM_EXPORT Discrete3DFunction : public TabulatedFunction { * @deprecated This will be removed in a future release. */ Discrete3DFunction* Copy() const; + bool operator==(const TabulatedFunction& other) const; private: int xsize, ysize, zsize; std::vector values; diff --git a/openmmapi/src/TabulatedFunction.cpp b/openmmapi/src/TabulatedFunction.cpp index 0233b0ace5..1f32250af1 100644 --- a/openmmapi/src/TabulatedFunction.cpp +++ b/openmmapi/src/TabulatedFunction.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2014 Stanford University and the Authors. * + * Portions copyright (c) 2014-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -75,6 +75,15 @@ Continuous1DFunction* Continuous1DFunction::Copy() const { return new Continuous1DFunction(new_vec, min, max); } +bool Continuous1DFunction::operator==(const TabulatedFunction& other) const { + const Continuous1DFunction* fn = dynamic_cast(&other); + if (fn == NULL) + return false; + if (fn->min != min || fn->max != max) + return false; + return (fn->values == values); +} + Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector& values, double xmin, double xmax, double ymin, double ymax, bool periodic) { this->periodic = periodic; setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); @@ -120,6 +129,19 @@ Continuous2DFunction* Continuous2DFunction::Copy() const { return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax); } +bool Continuous2DFunction::operator==(const TabulatedFunction& other) const { + const Continuous2DFunction* fn = dynamic_cast(&other); + if (fn == NULL) + return false; + if (fn->xsize != xsize || fn->ysize != ysize) + return false; + if (fn->xmin != xmin || fn->xmax != xmax) + return false; + if (fn->ymin != ymin || fn->ymax != ymax) + return false; + return (fn->values == values); +} + Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) { this->periodic = periodic; setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); @@ -173,6 +195,20 @@ Continuous3DFunction* Continuous3DFunction::Copy() const { return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax); } +bool Continuous3DFunction::operator==(const TabulatedFunction& other) const { + const Continuous3DFunction* fn = dynamic_cast(&other); + if (fn == NULL) + return false; + if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize) + return false; + if (fn->xmin != xmin || fn->xmax != xmax) + return false; + if (fn->ymin != ymin || fn->ymax != ymax) + return false; + if (fn->zmin != zmin || fn->zmax != zmax) + return false; + return (fn->values == values); +} Discrete1DFunction::Discrete1DFunction(const vector& values) { this->values = values; @@ -193,6 +229,13 @@ Discrete1DFunction* Discrete1DFunction::Copy() const { return new Discrete1DFunction(new_vec); } +bool Discrete1DFunction::operator==(const TabulatedFunction& other) const { + const Discrete1DFunction* fn = dynamic_cast(&other); + if (fn == NULL) + return false; + return (fn->values == values); +} + Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector& values) { if (values.size() != xsize*ysize) throw OpenMMException("Discrete2DFunction: incorrect number of values"); @@ -222,6 +265,15 @@ Discrete2DFunction* Discrete2DFunction::Copy() const { return new Discrete2DFunction(xsize, ysize, new_vec); } +bool Discrete2DFunction::operator==(const TabulatedFunction& other) const { + const Discrete2DFunction* fn = dynamic_cast(&other); + if (fn == NULL) + return false; + if (fn->xsize != xsize || fn->ysize != ysize) + return false; + return (fn->values == values); +} + Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector& values) { if (values.size() != xsize*ysize*zsize) throw OpenMMException("Discrete3DFunction: incorrect number of values"); @@ -253,3 +305,12 @@ Discrete3DFunction* Discrete3DFunction::Copy() const { new_vec[i] = values[i]; return new Discrete3DFunction(xsize, ysize, zsize, new_vec); } + +bool Discrete3DFunction::operator==(const TabulatedFunction& other) const { + const Discrete3DFunction* fn = dynamic_cast(&other); + if (fn == NULL) + return false; + if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize) + return false; + return (fn->values == values); +} diff --git a/platforms/common/include/openmm/common/CommonKernels.h b/platforms/common/include/openmm/common/CommonKernels.h index 8664625f64..07ec7bc305 100644 --- a/platforms/common/include/openmm/common/CommonKernels.h +++ b/platforms/common/include/openmm/common/CommonKernels.h @@ -529,7 +529,8 @@ class CommonCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondFor ComputeArray globals; std::vector globalParamNames; std::vector globalParamValues; - std::vector tabulatedFunctions; + std::vector tabulatedFunctionArrays; + std::map tabulatedFunctions; const System& system; }; @@ -577,7 +578,8 @@ class CommonCalcCustomCentroidBondForceKernel : public CalcCustomCentroidBondFor ComputeArray groupForces, bondGroups, centerPositions; std::vector globalParamNames; std::vector globalParamValues; - std::vector tabulatedFunctions; + std::vector tabulatedFunctionArrays; + std::map tabulatedFunctions; std::vector groupForcesArgs; ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel; const System& system; @@ -628,7 +630,8 @@ class CommonCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKern std::vector interactionGroupArgs; std::vector globalParamNames; std::vector globalParamValues; - std::vector tabulatedFunctions; + std::vector tabulatedFunctionArrays; + std::map tabulatedFunctions; double longRangeCoefficient; std::vector longRangeCoefficientDerivs; bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList; @@ -728,7 +731,8 @@ class CommonCalcCustomGBForceKernel : public CalcCustomGBForceKernel { ComputeArray longEnergyDerivs, globals, valueBuffers; std::vector globalParamNames; std::vector globalParamValues; - std::vector tabulatedFunctions; + std::vector tabulatedFunctionArrays; + std::map tabulatedFunctions; std::vector pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue; const System& system; ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel; @@ -785,7 +789,8 @@ class CommonCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel { ComputeArray acceptorExclusions; std::vector globalParamNames; std::vector globalParamValues; - std::vector tabulatedFunctions; + std::vector tabulatedFunctionArrays; + std::map tabulatedFunctions; const System& system; ComputeKernel donorKernel, acceptorKernel; }; @@ -836,7 +841,8 @@ class CommonCalcCustomManyParticleForceKernel : public CalcCustomManyParticleFor ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors; std::vector globalParamNames; std::vector globalParamValues; - std::vector tabulatedFunctions; + std::vector tabulatedFunctionArrays; + std::map tabulatedFunctions; const System& system; ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel; ComputeEvent event; diff --git a/platforms/common/src/CommonKernels.cpp b/platforms/common/src/CommonKernels.cpp index eb91115bd4..78525bda3c 100644 --- a/platforms/common/src/CommonKernels.cpp +++ b/platforms/common/src/CommonKernels.cpp @@ -35,6 +35,7 @@ #include "openmm/internal/CustomCompoundBondForceImpl.h" #include "openmm/internal/CustomHbondForceImpl.h" #include "openmm/internal/CustomManyParticleForceImpl.h" +#include "openmm/serialization/XmlSerializer.h" #include "CommonKernelSources.h" #include "lepton/CustomFunction.h" #include "lepton/ExpressionTreeNode.h" @@ -1289,16 +1290,17 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c map functions; vector > functionDefinitions; vector functionList; - tabulatedFunctions.resize(force.getNumTabulatedFunctions()); + tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions()); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { functionList.push_back(&force.getTabulatedFunction(i)); string name = force.getTabulatedFunctionName(i); + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); - tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); - tabulatedFunctions[i].upload(f); - string arrayName = cc.getBondedUtilities().addArgument(tabulatedFunctions[i], width == 1 ? "float" : "float"+cc.intToString(width)); + tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction"); + tabulatedFunctionArrays[i].upload(f); + string arrayName = cc.getBondedUtilities().addArgument(tabulatedFunctionArrays[i], width == 1 ? "float" : "float"+cc.intToString(width)); functionDefinitions.push_back(make_pair(name, arrayName)); } @@ -1397,9 +1399,9 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp throw OpenMMException("updateParametersInContext: The number of bonds has changed"); if (numBonds == 0) return; - + // Record the per-bond parameters. - + vector > paramVector(numBonds); vector particles; vector parameters; @@ -1410,9 +1412,21 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp paramVector[i][j] = (float) parameters[j]; } params->setParameterValues(paramVector); - + + // See if any tabulated functions have changed. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + int width; + vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); + tabulatedFunctionArrays[i].upload(f); + } + } + // Mark that the current reordering may be invalid. - + cc.invalidateMolecules(info); } @@ -1535,17 +1549,18 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c vector > functionDefinitions; vector functionList; stringstream extraArgs; - tabulatedFunctions.resize(force.getNumTabulatedFunctions()); + tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions()); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { functionList.push_back(&force.getTabulatedFunction(i)); string name = force.getTabulatedFunctionName(i); + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); string arrayName = "table"+cc.intToString(i); functionDefinitions.push_back(make_pair(name, arrayName)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); - tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); - tabulatedFunctions[i].upload(f); + tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction"); + tabulatedFunctionArrays[i].upload(f); extraArgs << ", GLOBAL const float"; if (width > 1) extraArgs << width; @@ -1667,7 +1682,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c groupForcesKernel->addArg(); // Periodic box information will be set just before it is executed. if (needEnergyParamDerivs) groupForcesKernel->addArg(); // Deriv buffer hasn't been created yet. - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) groupForcesKernel->addArg(function); if (globals.isInitialized()) groupForcesKernel->addArg(globals); @@ -1714,9 +1729,9 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp throw OpenMMException("updateParametersInContext: The number of bonds has changed"); if (numBonds == 0) return; - + // Record the per-bond parameters. - + vector > paramVector(numBonds); vector particles; vector parameters; @@ -1727,9 +1742,21 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp paramVector[i][j] = (float) parameters[j]; } params->setParameterValues(paramVector); - + + // See if any tabulated functions have changed. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + int width; + vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); + tabulatedFunctionArrays[i].upload(f); + } + } + // Mark that the current reordering may be invalid. - + cc.invalidateMolecules(info); } @@ -1868,18 +1895,19 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons vector > functionDefinitions; vector functionList; vector tableTypes; - tabulatedFunctions.resize(force.getNumTabulatedFunctions()); + tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions()); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { functionList.push_back(&force.getTabulatedFunction(i)); string name = force.getTabulatedFunctionName(i); + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); string arrayName = prefix+"table"+cc.intToString(i); functionDefinitions.push_back(make_pair(name, arrayName)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); - tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); - tabulatedFunctions[i].upload(f); - cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctions[i], arrayName, "float", width)); + tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction"); + tabulatedFunctionArrays[i].upload(f); + cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width)); if (width == 1) tableTypes.push_back("float"); else @@ -2166,7 +2194,7 @@ void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon stringstream args; for (int i = 0; i < (int) buffers.size(); i++) args<<", GLOBAL const "<addArg(); // Periodic box information will be set just before it is executed. for (auto& parameter : params->getParameterInfos()) interactionGroupKernel->addArg(parameter.getArray()); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) interactionGroupKernel->addArg(function); if (globals.isInitialized()) interactionGroupKernel->addArg(globals); @@ -2342,18 +2370,30 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& paramVector[i][j] = (float) parameters[j]; } params->setParameterValues(paramVector); - + // If necessary, recompute the long range correction. - + if (forceCopy != NULL) { longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force); CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool()); hasInitializedLongRangeCorrection = false; *forceCopy = force; } - + + // See if any tabulated functions have changed. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + int width; + vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); + tabulatedFunctionArrays[i].upload(f); + } + } + // Mark that the current reordering may be invalid. - + cc.invalidateMolecules(info); } @@ -2679,18 +2719,19 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo vector > functionDefinitions; vector functionList; stringstream tableArgs; - tabulatedFunctions.resize(force.getNumTabulatedFunctions()); + tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions()); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { functionList.push_back(&force.getTabulatedFunction(i)); string name = force.getTabulatedFunctionName(i); + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); string arrayName = prefix+"table"+cc.intToString(i); functionDefinitions.push_back(make_pair(name, arrayName)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); - tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); - tabulatedFunctions[i].upload(f); - nb.addArgument(ComputeParameterInfo(tabulatedFunctions[i], arrayName, "float", width)); + tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction"); + tabulatedFunctionArrays[i].upload(f); + nb.addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width)); tableArgs << ", GLOBAL const float"; if (width > 1) tableArgs << width; @@ -3510,7 +3551,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include } for (auto& d : dValue0dParam) pairValueKernel->addArg(d); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) pairValueKernel->addArg(function); perParticleValueKernel->addArg(cc.getPosq()); perParticleValueKernel->addArg(valueBuffers); @@ -3529,7 +3570,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include for (int j = 0; j < dValuedParam[i]->getParameterInfos().size(); j++) perParticleValueKernel->addArg(dValuedParam[i]->getParameterInfos()[j].getArray()); } - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) perParticleValueKernel->addArg(function); pairEnergyKernel->addArg(useLong ? cc.getLongForceBuffer() : cc.getForceBuffers()); pairEnergyKernel->addArg(cc.getEnergyBuffer()); @@ -3570,7 +3611,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include pairEnergyKernel->addArg(buffer.getArray()); if (needEnergyParamDerivs) pairEnergyKernel->addArg(cc.getEnergyParamDerivBuffer()); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) pairEnergyKernel->addArg(function); perParticleEnergyKernel->addArg(cc.getEnergyBuffer()); perParticleEnergyKernel->addArg(cc.getPosq()); @@ -3595,7 +3636,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include perParticleEnergyKernel->addArg(longEnergyDerivs); if (needEnergyParamDerivs) perParticleEnergyKernel->addArg(cc.getEnergyParamDerivBuffer()); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) perParticleEnergyKernel->addArg(function); if (needParameterGradient || needEnergyParamDerivs) { gradientChainRuleKernel->addArg(cc.getPosq()); @@ -3614,7 +3655,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include for (auto& buffer : d->getParameterInfos()) gradientChainRuleKernel->addArg(buffer.getArray()); } - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) gradientChainRuleKernel->addArg(function); } } @@ -3653,9 +3694,9 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context int numParticles = force.getNumParticles(); if (numParticles != cc.getNumAtoms()) throw OpenMMException("updateParametersInContext: The number of particles has changed"); - + // Record the per-particle parameters. - + vector > paramVector(cc.getPaddedNumAtoms(), vector(force.getNumPerParticleParameters(), 0)); vector parameters; for (int i = 0; i < numParticles; i++) { @@ -3664,9 +3705,21 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context paramVector[i][j] = (float) parameters[j]; } params->setParameterValues(paramVector); - + + // See if any tabulated functions have changed. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + int width; + vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); + tabulatedFunctionArrays[i].upload(f); + } + } + // Mark that the current reordering may be invalid. - + cc.invalidateMolecules(info); } @@ -3880,17 +3933,18 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu vector > functionDefinitions; vector functionList; stringstream tableArgs; - tabulatedFunctions.resize(force.getNumTabulatedFunctions()); + tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions()); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { functionList.push_back(&force.getTabulatedFunction(i)); string name = force.getTabulatedFunctionName(i); + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); string arrayName = "table"+cc.intToString(i); functionDefinitions.push_back(make_pair(name, arrayName)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); - tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); - tabulatedFunctions[i].upload(f); + tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction"); + tabulatedFunctionArrays[i].upload(f); tableArgs << ", GLOBAL const float"; if (width > 1) tableArgs << width; @@ -4132,7 +4186,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl donorKernel->addArg(parameter.getArray()); for (auto& parameter : acceptorParams->getParameterInfos()) donorKernel->addArg(parameter.getArray()); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) donorKernel->addArg(function); if (cc.getSupports64BitGlobalAtomics()) acceptorKernel->addArg(cc.getLongForceBuffer()); @@ -4153,7 +4207,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl acceptorKernel->addArg(parameter.getArray()); for (auto& parameter : acceptorParams->getParameterInfos()) acceptorKernel->addArg(parameter.getArray()); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) acceptorKernel->addArg(function); } setPeriodicBoxArgs(cc, donorKernel, cc.getSupports64BitGlobalAtomics() ? 6 : 7); @@ -4172,9 +4226,9 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont throw OpenMMException("updateParametersInContext: The number of donors has changed"); if (numAcceptors != force.getNumAcceptors()) throw OpenMMException("updateParametersInContext: The number of acceptors has changed"); - + // Record the per-donor parameters. - + if (numDonors > 0) { vector > donorParamVector(numDonors); vector parameters; @@ -4187,9 +4241,9 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont } donorParams->setParameterValues(donorParamVector); } - + // Record the per-acceptor parameters. - + if (numAcceptors > 0) { vector > acceptorParamVector(numAcceptors); vector parameters; @@ -4202,9 +4256,21 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont } acceptorParams->setParameterValues(acceptorParamVector); } - + + // See if any tabulated functions have changed. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + int width; + vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); + tabulatedFunctionArrays[i].upload(f); + } + } + // Mark that the current reordering may be invalid. - + cc.invalidateMolecules(info); } @@ -4280,17 +4346,18 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c vector > functionDefinitions; vector functionList; stringstream tableArgs; - tabulatedFunctions.resize(force.getNumTabulatedFunctions()); + tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions()); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { functionList.push_back(&force.getTabulatedFunction(i)); string name = force.getTabulatedFunctionName(i); + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); string arrayName = "table"+cc.intToString(i); functionDefinitions.push_back(make_pair(name, arrayName)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); - tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); - tabulatedFunctions[i].upload(f); + tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction"); + tabulatedFunctionArrays[i].upload(f); tableArgs << ", GLOBAL const float"; if (width > 1) tableArgs << width; @@ -4593,7 +4660,7 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo forceKernel->addArg(globals); for (auto& parameter : params->getParameterInfos()) forceKernel->addArg(parameter.getArray()); - for (auto& function : tabulatedFunctions) + for (auto& function : tabulatedFunctionArrays) forceKernel->addArg(function); if (nonbondedMethod != NoCutoff) { @@ -4695,9 +4762,9 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp int numParticles = force.getNumParticles(); if (numParticles != cc.getNumAtoms()) throw OpenMMException("updateParametersInContext: The number of particles has changed"); - + // Record the per-particle parameters. - + vector > paramVector(numParticles); vector parameters; int type; @@ -4708,9 +4775,21 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp paramVector[i][j] = (float) parameters[j]; } params->setParameterValues(paramVector); - + + // See if any tabulated functions have changed. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + int width; + vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); + tabulatedFunctionArrays[i].upload(f); + } + } + // Mark that the current reordering may be invalid. - + cc.invalidateMolecules(info); } diff --git a/platforms/cpu/include/CpuKernels.h b/platforms/cpu/include/CpuKernels.h index e87990bb7c..3c28fb6658 100644 --- a/platforms/cpu/include/CpuKernels.h +++ b/platforms/cpu/include/CpuKernels.h @@ -320,7 +320,8 @@ class CpuCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel * @param force the CustomNonbondedForce to copy the parameters from */ void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force); -private: +private: + void createInteraction(const CustomNonbondedForce& force); CpuPlatform::PlatformData& data; int numParticles; std::vector > particleParamArray; @@ -333,6 +334,7 @@ class CpuCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel std::vector parameterNames, globalParameterNames, energyParamDerivNames; std::vector, std::set > > interactionGroups; std::vector longRangeCoefficientDerivs; + std::map tabulatedFunctions; NonbondedMethod nonbondedMethod; CpuCustomNonbondedForce* nonbonded; }; @@ -410,6 +412,7 @@ class CpuCalcCustomGBForceKernel : public CalcCustomGBForceKernel { */ void copyParametersToContext(ContextImpl& context, const CustomGBForce& force); private: + void createInteraction(const CustomGBForce& force); CpuPlatform::PlatformData& data; int numParticles; bool isPeriodic; @@ -421,6 +424,7 @@ class CpuCalcCustomGBForceKernel : public CalcCustomGBForceKernel { std::vector particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames; std::vector valueTypes; std::vector energyTypes; + std::map tabulatedFunctions; NonbondedMethod nonbondedMethod; }; @@ -463,6 +467,7 @@ class CpuCalcCustomManyParticleForceKernel : public CalcCustomManyParticleForceK std::vector > particleParamArray; CpuCustomManyParticleForce* ixn; std::vector globalParameterNames; + std::map tabulatedFunctions; NonbondedMethod nonbondedMethod; }; diff --git a/platforms/cpu/src/CpuKernels.cpp b/platforms/cpu/src/CpuKernels.cpp index ccd5dbd218..5d8468a56a 100644 --- a/platforms/cpu/src/CpuKernels.cpp +++ b/platforms/cpu/src/CpuKernels.cpp @@ -45,6 +45,7 @@ #include "openmm/internal/ContextImpl.h" #include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/vectorize.h" +#include "openmm/serialization/XmlSerializer.h" #include "lepton/CompiledExpression.h" #include "lepton/CustomFunction.h" #include "lepton/Operation.h" @@ -868,7 +869,6 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C // Build the arrays. - int numParameters = force.getNumPerParticleParameters(); particleParamArray.resize(numParticles); for (int i = 0; i < numParticles; ++i) force.getParticleParameters(i, particleParamArray[i]); @@ -882,10 +882,41 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C switchingDistance = force.getSwitchingDistance(); } + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Record information for the long range correction. + + if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) { + forceCopy = new CustomNonbondedForce(force); + hasInitializedLongRangeCorrection = false; + } + else { + longRangeCoefficient = 0.0; + hasInitializedLongRangeCorrection = true; + } + + // Record the interaction groups. + + for (int i = 0; i < force.getNumInteractionGroups(); i++) { + set set1, set2; + force.getInteractionGroupParameters(i, set1, set2); + interactionGroups.push_back(make_pair(set1, set2)); + } + data.isPeriodic |= (nonbondedMethod == CutoffPeriodic); + + // Create the interaction. + + createInteraction(force); +} + +void CpuCalcCustomNonbondedForceKernel::createInteraction(const CustomNonbondedForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Parse the various expressions used to calculate the force. @@ -893,7 +924,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize(); Lepton::CompiledExpression energyExpression = expression.createCompiledExpression(); Lepton::CompiledExpression forceExpression = expression.differentiate("r").createCompiledExpression(); - for (int i = 0; i < numParameters; i++) + for (int i = 0; i < force.getNumPerParticleParameters(); i++) parameterNames.push_back(force.getPerParticleParameterName(i)); for (int i = 0; i < force.getNumGlobalParameters(); i++) { globalParameterNames.push_back(force.getGlobalParameterName(i)); @@ -907,7 +938,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C } set variables; variables.insert("r"); - for (int i = 0; i < numParameters; i++) { + for (int i = 0; i < force.getNumPerParticleParameters(); i++) { variables.insert(parameterNames[i]+"1"); variables.insert(parameterNames[i]+"2"); } @@ -918,26 +949,9 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C for (auto& function : functions) delete function.second; - - // Record information for the long range correction. - - if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) { - forceCopy = new CustomNonbondedForce(force); - hasInitializedLongRangeCorrection = false; - } - else { - longRangeCoefficient = 0.0; - hasInitializedLongRangeCorrection = true; - } - - // Record the interaction groups. - - for (int i = 0; i < force.getNumInteractionGroups(); i++) { - set set1, set2; - force.getInteractionGroupParameters(i, set1, set2); - interactionGroups.push_back(make_pair(set1, set2)); - } - data.isPeriodic |= (nonbondedMethod == CutoffPeriodic); + + // Create the object that computes the interaction. + nonbonded = new CpuCustomNonbondedForce(energyExpression, forceExpression, parameterNames, exclusions, energyParamDerivExpressions, data.threads); if (interactionGroups.size() > 0) nonbonded->setInteractionGroups(interactionGroups); @@ -1011,6 +1025,22 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con hasInitializedLongRangeCorrection = true; *forceCopy = force; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete nonbonded; + nonbonded = NULL; + createInteraction(force); + } } CpuCalcGBSAOBCForceKernel::~CpuCalcGBSAOBCForceKernel() { @@ -1101,11 +1131,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB // Build the arrays. - int numPerParticleParameters = force.getNumPerParticleParameters(); particleParamArray.resize(numParticles); for (int i = 0; i < numParticles; ++i) force.getParticleParameters(i, particleParamArray[i]); - for (int i = 0; i < numPerParticleParameters; i++) + for (int i = 0; i < force.getNumPerParticleParameters(); i++) particleParameterNames.push_back(force.getPerParticleParameterName(i)); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParameterNames.push_back(force.getGlobalParameterName(i)); @@ -1113,15 +1142,30 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB nonbondedCutoff = force.getCutoffDistance(); if (nonbondedMethod != NoCutoff) neighborList = new CpuNeighborList(4); + data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic); + + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the interaction. + createInteraction(force); +} + +void CpuCalcCustomGBForceKernel::createInteraction(const CustomGBForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Parse the expressions for computed values. + valueTypes.clear(); + valueNames.clear(); + energyParamDerivNames.clear(); vector > valueDerivExpressions(force.getNumComputedValues()); vector > valueGradientExpressions(force.getNumComputedValues()); vector > valueParamDerivExpressions(force.getNumComputedValues()); @@ -1132,7 +1176,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB particleVariables.insert("x"); particleVariables.insert("y"); particleVariables.insert("z"); - for (int i = 0; i < numPerParticleParameters; i++) { + for (int i = 0; i < force.getNumPerParticleParameters(); i++) { particleVariables.insert(particleParameterNames[i]); pairVariables.insert(particleParameterNames[i]+"1"); pairVariables.insert(particleParameterNames[i]+"2"); @@ -1171,6 +1215,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB // Parse the expressions for energy terms. + energyTypes.clear(); vector > energyDerivExpressions(force.getNumEnergyTerms()); vector > energyGradientExpressions(force.getNumEnergyTerms()); vector > energyParamDerivExpressions(force.getNumEnergyTerms()); @@ -1208,7 +1253,6 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ixn = new CpuCustomGBForce(numParticles, exclusions, valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions, valueNames, valueTypes, energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes, particleParameterNames, data.threads); - data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic); } double CpuCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { @@ -1247,6 +1291,22 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c for (int j = 0; j < numParameters; j++) particleParamArray[i][j] = static_cast(parameters[j]); } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete ixn; + ixn = NULL; + createInteraction(force); + } } CpuCalcCustomManyParticleForceKernel::~CpuCalcCustomManyParticleForceKernel() { @@ -1266,6 +1326,14 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons } for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParameterNames.push_back(force.getGlobalParameterName(i)); + + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the interaction. + ixn = new CpuCustomManyParticleForce(force, data.threads); nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod()); cutoffDistance = force.getCutoffDistance(); @@ -1303,6 +1371,22 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& for (int j = 0; j < numParameters; j++) particleParamArray[i][j] = static_cast(parameters[j]); } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete ixn; + ixn = NULL; + ixn = new CpuCustomManyParticleForce(force, data.threads); + } } CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() { diff --git a/platforms/reference/include/ReferenceKernels.h b/platforms/reference/include/ReferenceKernels.h index 96e51a134e..5caa6d53c2 100644 --- a/platforms/reference/include/ReferenceKernels.h +++ b/platforms/reference/include/ReferenceKernels.h @@ -682,6 +682,7 @@ class ReferenceCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceK */ void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force); private: + void createExpressions(const CustomNonbondedForce& force); int numParticles; std::vector > particleParamArray; double nonbondedCutoff, switchingDistance, periodicBoxSize[3], longRangeCoefficient; @@ -695,6 +696,7 @@ class ReferenceCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceK std::vector parameterNames, globalParameterNames, energyParamDerivNames; std::vector, std::set > > interactionGroups; std::vector longRangeCoefficientDerivs; + std::map tabulatedFunctions; NonbondedMethod nonbondedMethod; NeighborList* neighborList; }; @@ -768,6 +770,7 @@ class ReferenceCalcCustomGBForceKernel : public CalcCustomGBForceKernel { */ void copyParametersToContext(ContextImpl& context, const CustomGBForce& force); private: + void createExpressions(const CustomGBForce& force); int numParticles; bool isPeriodic; std::vector > particleParamArray; @@ -784,6 +787,7 @@ class ReferenceCalcCustomGBForceKernel : public CalcCustomGBForceKernel { std::vector > energyGradientExpressions; std::vector > energyParamDerivExpressions; std::vector energyTypes; + std::map tabulatedFunctions; NonbondedMethod nonbondedMethod; NeighborList* neighborList; }; @@ -861,13 +865,16 @@ class ReferenceCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel { */ void copyParametersToContext(ContextImpl& context, const CustomHbondForce& force); private: + void createInteraction(const CustomHbondForce& force); int numDonors, numAcceptors, numParticles; bool isPeriodic; + std::vector > donorParticles, acceptorParticles; std::vector > donorParamArray, acceptorParamArray; double nonbondedCutoff; ReferenceCustomHbondIxn* ixn; std::vector > exclusions; std::vector globalParameterNames; + std::map tabulatedFunctions; }; /** @@ -902,10 +909,15 @@ class ReferenceCalcCustomCentroidBondForceKernel : public CalcCustomCentroidBond */ void copyParametersToContext(ContextImpl& context, const CustomCentroidBondForce& force); private: + void createInteraction(const CustomCentroidBondForce& force); int numBonds, numParticles; + std::vector > bondGroups; + std::vector > groupAtoms; + std::vector > normalizedWeights; std::vector > bondParamArray; ReferenceCustomCentroidBondIxn* ixn; std::vector globalParameterNames, energyParamDerivNames; + std::map tabulatedFunctions; bool usePeriodic; Vec3* boxVectors; }; @@ -942,10 +954,13 @@ class ReferenceCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBond */ void copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force); private: + void createInteraction(const CustomCompoundBondForce& force); int numBonds; + std::vector > bondParticles; std::vector > bondParamArray; ReferenceCustomCompoundBondIxn* ixn; std::vector globalParameterNames, energyParamDerivNames; + std::map tabulatedFunctions; bool usePeriodic; Vec3* boxVectors; }; @@ -987,6 +1002,7 @@ class ReferenceCalcCustomManyParticleForceKernel : public CalcCustomManyParticle std::vector > particleParamArray; ReferenceCustomManyParticleIxn* ixn; std::vector globalParameterNames; + std::map tabulatedFunctions; NonbondedMethod nonbondedMethod; }; diff --git a/platforms/reference/src/ReferenceKernels.cpp b/platforms/reference/src/ReferenceKernels.cpp index 5caf8bc5ab..6c6cd00de7 100644 --- a/platforms/reference/src/ReferenceKernels.cpp +++ b/platforms/reference/src/ReferenceKernels.cpp @@ -80,6 +80,7 @@ #include "openmm/internal/NonbondedForceImpl.h" #include "openmm/Integrator.h" #include "openmm/OpenMMException.h" +#include "openmm/serialization/XmlSerializer.h" #include "SimTKOpenMMUtilities.h" #include "lepton/CustomFunction.h" #include "lepton/Operation.h" @@ -1151,7 +1152,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c // Build the arrays. - int numParameters = force.getNumPerParticleParameters(); particleParamArray.resize(numParticles); for (int i = 0; i < numParticles; ++i) force.getParticleParameters(i, particleParamArray[i]); @@ -1167,10 +1167,40 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c switchingDistance = force.getSwitchingDistance(); } + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the expressions. + + createExpressions(force); + + // Record information for the long range correction. + + if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) { + forceCopy = new CustomNonbondedForce(force); + hasInitializedLongRangeCorrection = false; + } + else { + longRangeCoefficient = 0.0; + hasInitializedLongRangeCorrection = true; + } + + // Record the interaction groups. + + for (int i = 0; i < force.getNumInteractionGroups(); i++) { + set set1, set2; + force.getInteractionGroupParameters(i, set1, set2); + interactionGroups.push_back(make_pair(set1, set2)); + } +} + +void ReferenceCalcCustomNonbondedForceKernel::createExpressions(const CustomNonbondedForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Parse the various expressions used to calculate the force. @@ -1178,7 +1208,12 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize(); energyExpression = expression.createCompiledExpression(); forceExpression = expression.differentiate("r").createCompiledExpression(); - for (int i = 0; i < numParameters; i++) + parameterNames.clear(); + globalParameterNames.clear(); + globalParamValues.clear(); + energyParamDerivNames.clear(); + energyParamDerivExpressions.clear(); + for (int i = 0; i < force.getNumPerParticleParameters(); i++) parameterNames.push_back(force.getPerParticleParameterName(i)); for (int i = 0; i < force.getNumGlobalParameters(); i++) { globalParameterNames.push_back(force.getGlobalParameterName(i)); @@ -1191,7 +1226,7 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c } set variables; variables.insert("r"); - for (int i = 0; i < numParameters; i++) { + for (int i = 0; i < force.getNumPerParticleParameters(); i++) { variables.insert(parameterNames[i]+"1"); variables.insert(parameterNames[i]+"2"); } @@ -1202,25 +1237,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c for (auto& function : functions) delete function.second; - - // Record information for the long range correction. - - if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) { - forceCopy = new CustomNonbondedForce(force); - hasInitializedLongRangeCorrection = false; - } - else { - longRangeCoefficient = 0.0; - hasInitializedLongRangeCorrection = true; - } - - // Record the interaction groups. - - for (int i = 0; i < force.getNumInteractionGroups(); i++) { - set set1, set2; - force.getInteractionGroupParameters(i, set1, set2); - interactionGroups.push_back(make_pair(set1, set2)); - } } double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { @@ -1300,6 +1316,19 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp hasInitializedLongRangeCorrection = true; *forceCopy = force; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) + createExpressions(force); } ReferenceCalcGBSAOBCForceKernel::~ReferenceCalcGBSAOBCForceKernel() { @@ -1395,11 +1424,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu // Build the arrays. - int numPerParticleParameters = force.getNumPerParticleParameters(); particleParamArray.resize(numParticles); for (int i = 0; i < numParticles; ++i) force.getParticleParameters(i, particleParamArray[i]); - for (int i = 0; i < numPerParticleParameters; i++) + for (int i = 0; i < force.getNumPerParticleParameters(); i++) particleParameterNames.push_back(force.getPerParticleParameterName(i)); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParameterNames.push_back(force.getGlobalParameterName(i)); @@ -1410,14 +1438,32 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu else neighborList = new NeighborList(); + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the expressions. + + createExpressions(force); +} + +void ReferenceCalcCustomGBForceKernel::createExpressions(const CustomGBForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Parse the expressions for computed values. + valueExpressions.clear(); + valueTypes.clear(); + valueNames.clear(); + energyParamDerivNames.clear(); + valueDerivExpressions.clear(); + valueGradientExpressions.clear(); + valueParamDerivExpressions.clear(); valueDerivExpressions.resize(force.getNumComputedValues()); valueGradientExpressions.resize(force.getNumComputedValues()); valueParamDerivExpressions.resize(force.getNumComputedValues()); @@ -1426,7 +1472,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu particleVariables.insert("x"); particleVariables.insert("y"); particleVariables.insert("z"); - for (int i = 0; i < numPerParticleParameters; i++) { + for (int i = 0; i < force.getNumPerParticleParameters(); i++) { particleVariables.insert(particleParameterNames[i]); pairVariables.insert(particleParameterNames[i]+"1"); pairVariables.insert(particleParameterNames[i]+"2"); @@ -1465,6 +1511,11 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu // Parse the expressions for energy terms. + energyExpressions.clear(); + energyTypes.clear(); + energyDerivExpressions.clear(); + energyGradientExpressions.clear(); + energyParamDerivExpressions.clear(); energyDerivExpressions.resize(force.getNumEnergyTerms()); energyGradientExpressions.resize(force.getNumEnergyTerms()); energyParamDerivExpressions.resize(force.getNumEnergyTerms()); @@ -1540,6 +1591,19 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont for (int j = 0; j < numParameters; j++) particleParamArray[i][j] = parameters[j]; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) + createExpressions(force); } ReferenceCalcCustomExternalForceKernel::~ReferenceCalcCustomExternalForceKernel() { @@ -1637,8 +1701,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const // Build the arrays. - vector > donorParticles(numDonors); - int numDonorParameters = force.getNumPerDonorParameters(); + donorParticles.resize(numDonors); donorParamArray.resize(numDonors); for (int i = 0; i < numDonors; ++i) { int d1, d2, d3; @@ -1647,8 +1710,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const donorParticles[i].push_back(d2); donorParticles[i].push_back(d3); } - vector > acceptorParticles(numAcceptors); - int numAcceptorParameters = force.getNumPerAcceptorParameters(); + acceptorParticles.resize(numAcceptors); acceptorParamArray.resize(numAcceptors); for (int i = 0; i < numAcceptors; ++i) { int a1, a2, a3; @@ -1657,13 +1719,25 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const acceptorParticles[i].push_back(a2); acceptorParticles[i].push_back(a3); } - NonbondedMethod nonbondedMethod = CalcCustomHbondForceKernel::NonbondedMethod(force.getNonbondedMethod()); + for (int i = 0; i < force.getNumGlobalParameters(); i++) + globalParameterNames.push_back(force.getGlobalParameterName(i)); nonbondedCutoff = force.getCutoffDistance(); + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the interaction. + + createInteraction(force); +} + +void ReferenceCalcCustomHbondForceKernel::createInteraction(const CustomHbondForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Parse the expression and create the object used to calculate the interaction. @@ -1674,13 +1748,12 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals); vector donorParameterNames; vector acceptorParameterNames; - for (int i = 0; i < numDonorParameters; i++) + for (int i = 0; i < force.getNumPerDonorParameters(); i++) donorParameterNames.push_back(force.getPerDonorParameterName(i)); - for (int i = 0; i < numAcceptorParameters; i++) + for (int i = 0; i < force.getNumPerAcceptorParameters(); i++) acceptorParameterNames.push_back(force.getPerAcceptorParameterName(i)); - for (int i = 0; i < force.getNumGlobalParameters(); i++) - globalParameterNames.push_back(force.getGlobalParameterName(i)); ixn = new ReferenceCustomHbondIxn(donorParticles, acceptorParticles, energyExpression, donorParameterNames, acceptorParameterNames, distances, angles, dihedrals); + NonbondedMethod nonbondedMethod = CalcCustomHbondForceKernel::NonbondedMethod(force.getNonbondedMethod()); isPeriodic = (nonbondedMethod == CutoffPeriodic); if (nonbondedMethod != NoCutoff) ixn->setUseCutoff(nonbondedCutoff); @@ -1733,6 +1806,22 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c for (int j = 0; j < numAcceptorParameters; j++) acceptorParamArray[i][j] = parameters[j]; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete ixn; + ixn = NULL; + createInteraction(force); + } } ReferenceCalcCustomCentroidBondForceKernel::~ReferenceCalcCustomCentroidBondForceKernel() { @@ -1746,23 +1835,32 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system // Build the arrays. int numGroups = force.getNumGroups(); - vector > groupAtoms(numGroups); + groupAtoms.resize(numGroups); vector ignored; for (int i = 0; i < numGroups; i++) force.getGroupParameters(i, groupAtoms[i], ignored); - vector > normalizedWeights; CustomCentroidBondForceImpl::computeNormalizedWeights(force, system, normalizedWeights); numBonds = force.getNumBonds(); - vector > bondGroups(numBonds); - int numBondParameters = force.getNumPerBondParameters(); + bondGroups.resize(numBonds); bondParamArray.resize(numBonds); for (int i = 0; i < numBonds; ++i) force.getBondParameters(i, bondGroups[i], bondParamArray[i]); + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the interaction. + + createInteraction(force); +} + +void ReferenceCalcCustomCentroidBondForceKernel::createInteraction(const CustomCentroidBondForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Create implementations of point functions. @@ -1773,9 +1871,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system // Parse the expression and create the object used to calculate the interaction. + int numGroups = force.getNumGroups(); Lepton::ParsedExpression energyExpression = CustomCentroidBondForceImpl::prepareExpression(force, functions); vector bondParameterNames; - for (int i = 0; i < numBondParameters; i++) + for (int i = 0; i < force.getNumPerBondParameters(); i++) bondParameterNames.push_back(force.getPerBondParameterName(i)); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParameterNames.push_back(force.getGlobalParameterName(i)); @@ -1830,6 +1929,22 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context for (int j = 0; j < numParameters; j++) bondParamArray[i][j] = params[j]; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete ixn; + ixn = NULL; + createInteraction(force); + } } ReferenceCalcCustomCompoundBondForceKernel::~ReferenceCalcCustomCompoundBondForceKernel() { @@ -1843,16 +1958,26 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system // Build the arrays. numBonds = force.getNumBonds(); - vector > bondParticles(numBonds); - int numBondParameters = force.getNumPerBondParameters(); + bondParticles.resize(numBonds); bondParamArray.resize(numBonds); for (int i = 0; i < numBonds; ++i) force.getBondParameters(i, bondParticles[i], bondParamArray[i]); + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the interaction. + + createInteraction(force); +} + +void ReferenceCalcCustomCompoundBondForceKernel::createInteraction(const CustomCompoundBondForce& force) { // Create custom functions for the tabulated functions. map functions; - for (int i = 0; i < force.getNumFunctions(); i++) + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); // Create implementations of point functions. @@ -1865,7 +1990,7 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system Lepton::ParsedExpression energyExpression = CustomCompoundBondForceImpl::prepareExpression(force, functions); vector bondParameterNames; - for (int i = 0; i < numBondParameters; i++) + for (int i = 0; i < force.getNumPerBondParameters(); i++) bondParameterNames.push_back(force.getPerBondParameterName(i)); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParameterNames.push_back(force.getGlobalParameterName(i)); @@ -1920,6 +2045,22 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context for (int j = 0; j < numParameters; j++) bondParamArray[i][j] = params[j]; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete ixn; + ixn = NULL; + createInteraction(force); + } } ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForceKernel() { @@ -1928,7 +2069,6 @@ ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForc } void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) { - // Build the arrays. numParticles = system.getNumParticles(); @@ -1939,6 +2079,14 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system } for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParameterNames.push_back(force.getGlobalParameterName(i)); + + // Record the tabulated functions for future reference. + + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) + tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); + + // Create the interaction. + ixn = new ReferenceCustomManyParticleIxn(force); nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod()); cutoffDistance = force.getCutoffDistance(); @@ -1977,6 +2125,22 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context for (int j = 0; j < numParameters; j++) particleParamArray[i][j] = parameters[j]; } + + // See if any tabulated functions have changed. + + bool changed = false; + for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { + string name = force.getTabulatedFunctionName(i); + if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { + tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); + changed = true; + } + } + if (changed) { + delete ixn; + ixn = NULL; + ixn = new ReferenceCustomManyParticleIxn(force); + } } ReferenceCalcGayBerneForceKernel::~ReferenceCalcGayBerneForceKernel() { diff --git a/tests/TestCustomCentroidBondForce.h b/tests/TestCustomCentroidBondForce.h index 1ffb520b27..a8505ec866 100644 --- a/tests/TestCustomCentroidBondForce.h +++ b/tests/TestCustomCentroidBondForce.h @@ -205,6 +205,18 @@ void testComplexFunction(bool byGroups) { for (int i = 0; i < numParticles; i++) ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], TOL); } + + // Try updating the tabulated function. + + for (int i = 0; i < table.size(); i++) + table[i] *= 0.5; + dynamic_cast(compound->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10); + dynamic_cast(centroid->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10); + compound->updateParametersInContext(context); + centroid->updateParametersInContext(context); + State state1 = context.getState(State::Energy, false, 1<<0); + State state2 = context.getState(State::Energy, false, 1<<1); + ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), TOL); } void testCustomWeights() { diff --git a/tests/TestCustomCompoundBondForce.h b/tests/TestCustomCompoundBondForce.h index ffaa0bc17d..977c7b8ebc 100644 --- a/tests/TestCustomCompoundBondForce.h +++ b/tests/TestCustomCompoundBondForce.h @@ -212,6 +212,31 @@ void testContinuous2DFunction() { ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05); } } + + // Try updating the tabulated function. + + for (int i = 0; i < table.size(); i++) + table[i] *= 0.5; + Continuous2DFunction& fn = dynamic_cast(forceField->getTabulatedFunction(0)); + fn.setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax); + forceField->updateParametersInContext(context); + for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) { + for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) { + positions[0] = Vec3(x, y, 1.5); + context.setPositions(positions); + State state = context.getState(State::Forces | State::Energy); + const vector& forces = state.getForces(); + double energy = 1; + Vec3 force(0, 0, 0); + if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) { + energy = 0.5*sin(0.25*x)*cos(0.33*y)+1; + force[0] = 0.5*(-0.25*cos(0.25*x)*cos(0.33*y)); + force[1] = 0.5*0.3*sin(0.25*x)*sin(0.33*y); + } + ASSERT_EQUAL_VEC(force, forces[0], 0.1); + ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05); + } + } } void testContinuous3DFunction() { diff --git a/tests/TestCustomGBForce.h b/tests/TestCustomGBForce.h index 8aed47b8bf..6198c151fd 100644 --- a/tests/TestCustomGBForce.h +++ b/tests/TestCustomGBForce.h @@ -7,7 +7,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2008-2016 Stanford University and the Authors. * + * Portions copyright (c) 2008-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -254,7 +254,7 @@ void testMembrane() { double norm = 0.0; for (int i = 0; i < (int) forces.size(); ++i) norm += forces[i].dot(forces[i]); - norm = std::sqrt(norm); + norm = sqrt(norm); const double stepSize = 1e-2; double step = 0.5*stepSize/norm; vector positions2(numParticles), positions3(numParticles); @@ -283,7 +283,7 @@ void testTabulatedFunction() { force->addParticle(vector()); vector table; for (int i = 0; i < 21; i++) - table.push_back(std::sin(0.25*i)); + table.push_back(sin(0.25*i)); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); system.addForce(force); Context context(system, integrator, platform); @@ -296,8 +296,8 @@ void testTabulatedFunction() { context.setPositions(positions); State state = context.getState(State::Forces | State::Energy); const vector& forces = state.getForces(); - double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); - double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; + double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0)); + double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0; ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); @@ -308,7 +308,22 @@ void testTabulatedFunction() { positions[1] = Vec3(x, 0, 0); context.setPositions(positions); State state = context.getState(State::Energy); - double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; + double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0; + ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); + } + + // Try updating the tabulated function. + + for (int i = 0; i < table.size(); i++) + table[i] *= 0.5; + dynamic_cast(force->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0); + force->updateParametersInContext(context); + for (int i = 1; i < 20; i++) { + double x = 0.25*i+1.0; + positions[1] = Vec3(x, 0, 0); + context.setPositions(positions); + State state = context.getState(State::Energy); + double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0; ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); } } @@ -385,7 +400,7 @@ void testPositionDependence() { double norm = 0.0; for (int i = 0; i < (int) forces.size(); ++i) norm += forces[i].dot(forces[i]); - norm = std::sqrt(norm); + norm = sqrt(norm); const double stepSize = 1e-3; double step = 0.5*stepSize/norm; vector positions2(2), positions3(2); @@ -455,7 +470,7 @@ void testExclusions() { double norm = 0.0; for (int i = 0; i < (int) forces.size(); ++i) norm += forces[i].dot(forces[i]); - norm = std::sqrt(norm); + norm = sqrt(norm); if (norm > 0) { const double stepSize = 1e-3; double step = stepSize/norm; diff --git a/tests/TestCustomHbondForce.h b/tests/TestCustomHbondForce.h index e7491c91d1..7947fb78d9 100644 --- a/tests/TestCustomHbondForce.h +++ b/tests/TestCustomHbondForce.h @@ -223,6 +223,15 @@ void testCustomFunctions() { ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL); ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL); ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL); + + // Try updating the tabulated function. + + for (int i = 0; i < function.size(); i++) + function[i] *= 0.5; + dynamic_cast(custom->getTabulatedFunction(0)).setFunctionParameters(function, 0, 10); + custom->updateParametersInContext(context); + state = context.getState(State::Energy); + ASSERT_EQUAL_TOL(0.5*(0.1*2+0.1*2), state.getPotentialEnergy(), TOL); } void test2DFunction() { diff --git a/tests/TestCustomManyParticleForce.h b/tests/TestCustomManyParticleForce.h index 8186dbbf9d..f635a20a00 100644 --- a/tests/TestCustomManyParticleForce.h +++ b/tests/TestCustomManyParticleForce.h @@ -516,6 +516,15 @@ void testTabulatedFunctions() { expectedEnergy += 0.5*(r12+r13+r23)*(c[i]+c[j]+c[k]); } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); + + // Try updating the tabulated function. + + for (int i = 0; i < values.size(); i++) + values[i] *= 0.5; + dynamic_cast(force->getTabulatedFunction(1)).setFunctionParameters(numParticles, numParticles, numParticles, values); + force->updateParametersInContext(context); + state = context.getState(State::Energy); + ASSERT_EQUAL_TOL(0.5*expectedEnergy, state.getPotentialEnergy(), 1e-5); } void testTypeFilters() { diff --git a/tests/TestCustomNonbondedForce.h b/tests/TestCustomNonbondedForce.h index 002813d43a..7169db1886 100644 --- a/tests/TestCustomNonbondedForce.h +++ b/tests/TestCustomNonbondedForce.h @@ -355,6 +355,21 @@ void testContinuous1DFunction() { double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0; ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); } + + // Try updating the tabulated function. + + for (int i = 0; i < table.size(); i++) + table[i] *= 0.5; + dynamic_cast(forceField->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0); + forceField->updateParametersInContext(context); + for (int i = 1; i < 20; i++) { + double x = 0.25*i+1.0; + positions[1] = Vec3(x, 0, 0); + context.setPositions(positions); + State state = context.getState(State::Energy); + double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0; + ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); + } } void testPeriodicContinuous1DFunction() { diff --git a/wrappers/generateWrappers.py b/wrappers/generateWrappers.py index 7c88b35e1a..864c279992 100644 --- a/wrappers/generateWrappers.py +++ b/wrappers/generateWrappers.py @@ -78,7 +78,21 @@ def __init__(self, inputDirname, output): 'std::vector OpenMM::NoseHooverChain::getYoshidaSuzukiWeights', 'const std::vector& OpenMM::NoseHooverIntegrator::getAllThermostatedIndividualParticles', 'const std::vector >& OpenMM::NoseHooverIntegrator::getAllThermostatedPairs', - 'virtual void OpenMM::NoseHooverIntegrator::stateChanged' + 'virtual void OpenMM::NoseHooverIntegrator::stateChanged', + 'virtual bool OpenMM::TabulatedFunction::operator==', + 'bool OpenMM::Continuous1DFunction::operator==', + 'bool OpenMM::Continuous2DFunction::operator==', + 'bool OpenMM::Continuous3DFunction::operator==', + 'bool OpenMM::Discrete1DFunction::operator==', + 'bool OpenMM::Discrete2DFunction::operator==', + 'bool OpenMM::Discrete3DFunction::operator==', + 'virtual bool OpenMM::TabulatedFunction::operator!=', + 'bool OpenMM::Continuous1DFunction::operator!=', + 'bool OpenMM::Continuous2DFunction::operator!=', + 'bool OpenMM::Continuous3DFunction::operator!=', + 'bool OpenMM::Discrete1DFunction::operator!=', + 'bool OpenMM::Discrete2DFunction::operator!=', + 'bool OpenMM::Discrete3DFunction::operator!=' ] self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy'] self.nodeByID={}