Skip to content

Commit

Permalink
Optimized Context creation with complex CustomIntegrators (openmm#4191)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored Aug 16, 2023
1 parent ac6133b commit 065e34a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 11 deletions.
4 changes: 2 additions & 2 deletions openmmapi/include/openmm/internal/CustomIntegratorUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2015-2016 Stanford University and the Authors. *
* Portions copyright (c) 2015-2023 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -81,7 +81,7 @@ class OPENMM_EXPORT CustomIntegratorUtilities {
static bool usesVariable(const Lepton::ExpressionTreeNode& node, const std::string& variable);
static void enumeratePaths(int firstStep, std::vector<int> steps, std::vector<int> jumps, const std::vector<int>& blockEnd,
const std::vector<CustomIntegrator::ComputationType>& stepType, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy,
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth, const std::vector<bool>& isSignificant);
static void analyzeForceComputationsForPath(std::vector<int>& steps, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy,
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
static void validateDerivatives(const Lepton::ExpressionTreeNode& node, const std::vector<std::string>& derivNames);
Expand Down
30 changes: 22 additions & 8 deletions openmmapi/src/CustomIntegratorUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2015-2019 Stanford University and the Authors. *
* Portions copyright (c) 2015-2023 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -188,6 +188,20 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
if (blockStart.size() > 0)
throw OpenMMException("CustomIntegrator: Missing EndBlock");

// Identify whether each block contains any operation that either invalidates forces,
// or requires forces or energy. These are the ones that are significant for the
// analysis that follows.

vector<bool> isSignificant(numSteps, false);
for (int step = 0; step < numSteps; step++) {
if (stepType[step] == CustomIntegrator::IfBlockStart || stepType[step] == CustomIntegrator::WhileBlockStart)
for (int i = step; i < blockEnd[step]; i++)
if (needsForces[i] || needsEnergy[i] || invalidatesForces[i]) {
isSignificant[step] = true;
break;
}
}

// If a step requires either forces or energy, and a later step will require the other one, it's most efficient
// to compute both at the same time. Figure out whether we should do that. In principle it's easy: step through
// the sequence of computations and see if the other one is used before the next time they get invalidated.
Expand All @@ -211,7 +225,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
int numBlocks = blockEnd.size();
for (int i = 0; i < numBlocks; i++)
blockEnd.push_back(blockEnd[i]+numSteps);
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, alwaysInvalidatesForces, forceGroup, computeBoth);
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, alwaysInvalidatesForces, forceGroup, computeBoth, isSignificant);

// Make sure calls to deriv() all valid.

Expand All @@ -224,7 +238,7 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,

void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, vector<int> jumps, const vector<int>& blockEnd,
const vector<CustomIntegrator::ComputationType>& stepType, const vector<bool>& needsForces, const vector<bool>& needsEnergy,
const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth) {
const vector<bool>& invalidatesForces, const vector<int>& forceGroup, vector<bool>& computeBoth, const vector<bool>& isSignificant) {
int step = firstStep;
int numSteps = stepType.size();
while (step < 2*numSteps) {
Expand All @@ -237,23 +251,23 @@ void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps,
jumps[step] = -1;
step = nextStep;
}
else if (stepType[index] == CustomIntegrator::IfBlockStart) {
else if (stepType[index] == CustomIntegrator::IfBlockStart && isSignificant[index]) {
// Consider skipping the block.

enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth, isSignificant);

// Continue on to execute the block.

step++;
}
else if (stepType[index] == CustomIntegrator::WhileBlockStart && jumps[step] != -2) {
else if (stepType[index] == CustomIntegrator::WhileBlockStart && jumps[step] != -2 && isSignificant[index]) {
// Consider skipping the block.

enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
enumeratePaths(blockEnd[step]+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth, isSignificant);

// Consider executing the block once.

enumeratePaths(step+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
enumeratePaths(step+1, steps, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth, isSignificant);

// Continue on to execute the block twice.

Expand Down
71 changes: 70 additions & 1 deletion tests/TestCustomIntegrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. *
* Portions copyright (c) 2008-2023 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand All @@ -33,6 +33,7 @@
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/internal/CustomIntegratorUtilities.h"
#include "openmm/Context.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
Expand Down Expand Up @@ -1211,6 +1212,73 @@ void testSaveParameters() {
ASSERT_EQUAL_VEC(b1[0], b3[0], 1e-6);
}

void testAnalyzeComputations() {
System system;
system.addParticle(1.0);
CustomBondForce* bond = new CustomBondForce("scale*r");
bond->addGlobalParameter("scale", 2.0);
bond->setForceGroup(1);
system.addForce(bond);

// Create a complex integrator with lots of nested blocks and steps that use or invalidate
// forces or energies.

CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("color", 1.5);
integrator.addPerDofVariable("z", 0);
integrator.addComputeGlobal("color", "energy"); // 0
integrator.beginIfBlock("color > 1.0"); // 1
integrator.addComputeGlobal("scale", "energy0"); // 2
integrator.endBlock(); // 3
integrator.beginIfBlock("scale < color"); // 4
integrator.addComputePerDof("v", "x"); // 5
integrator.endBlock(); // 6
integrator.addComputePerDof("z", "f1"); // 7
integrator.beginWhileBlock("energy2 > 0"); // 8
integrator.beginIfBlock("color = 1"); // 9
integrator.addComputePerDof("v", "2*z"); // 10
integrator.endBlock(); // 11
integrator.beginIfBlock("color = 2"); // 12
integrator.addComputeGlobal("color", "color+1"); // 13
integrator.addUpdateContextState(); // 14
integrator.endBlock(); // 15
integrator.endBlock(); // 16
integrator.addComputePerDof("x", "x+f"); // 17

// Call analyzeComputations() and see if the results are what we expect.

Context context(system, integrator, platform);
ContextImpl* contextImpl = *reinterpret_cast<ContextImpl**>(&context);
vector<vector<Lepton::ParsedExpression> > expressions;
vector<CustomIntegratorUtilities::Comparison> comparisons;
vector<int> blockEnd, forceGroup;
vector<bool> invalidatesForces, needsForces, needsEnergy, computeBoth;
map<string, Lepton::CustomFunction*> functions;
CustomIntegratorUtilities::analyzeComputations(*contextImpl, integrator, expressions, comparisons, blockEnd, invalidatesForces,
needsForces, needsEnergy, computeBoth, forceGroup, functions);
ASSERT_EQUAL(3, blockEnd[1]);
ASSERT_EQUAL(6, blockEnd[4]);
ASSERT_EQUAL(16, blockEnd[8]);
ASSERT_EQUAL(11, blockEnd[9]);
ASSERT_EQUAL(15, blockEnd[12]);
for (int i = 0; i < integrator.getNumComputations(); i++) {
ASSERT_EQUAL(i == 2 || i == 14 || i == 17, invalidatesForces[i]);
ASSERT_EQUAL(i == 7 || i == 17, needsForces[i]);
ASSERT_EQUAL(i == 0 || i == 2 || i == 8, needsEnergy[i]);
ASSERT_EQUAL(i == 17, computeBoth[i]);
if (needsForces[i] || needsEnergy[i]) {
int group = -1;
if (i == 2)
group = 0;
else if (i == 7)
group = 1;
else if (i == 8)
group = 2;
ASSERT_EQUAL(group, forceGroup[i]);
}
}
}

void runPlatformTests();

int main(int argc, char* argv[]) {
Expand Down Expand Up @@ -1241,6 +1309,7 @@ int main(int argc, char* argv[]) {
testInitialTemperature();
testCheckpoint();
testSaveParameters();
testAnalyzeComputations();
runPlatformTests();
}
catch(const exception& e) {
Expand Down

0 comments on commit 065e34a

Please sign in to comment.