Skip to content

Commit

Permalink
Minor optimizations to custom forces
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed Sep 19, 2011
1 parent 0ee958f commit 84f5e00
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 65 deletions.
6 changes: 6 additions & 0 deletions libraries/lepton/src/ParsedExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return ExpressionTreeNode(new Operation::Divide(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::RECIPROCAL) // (1/a)*b = b/a
return ExpressionTreeNode(new Operation::Divide(), children[1], children[0].getChildren()[0]);
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Square(), children[0]); // x*x = square(x)
if (children[0].getOperation().getId() == Operation::SQUARE && children[0].getChildren()[0] == children[1])
return ExpressionTreeNode(new Operation::Cube(), children[1]); // x*x*x = cube(x)
if (children[1].getOperation().getId() == Operation::SQUARE && children[1].getChildren()[0] == children[0])
return ExpressionTreeNode(new Operation::Cube(), children[0]); // x*x*x = cube(x)
break;
}
case Operation::DIVIDE:
Expand Down
38 changes: 24 additions & 14 deletions platforms/opencl/src/OpenCLExpressionUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2011 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -48,26 +48,33 @@ string OpenCLExpressionUtilities::intToString(int value) {

string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams);
}

string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps;
vector<pair<ExpressionTreeNode, string> > temps = variables;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams, allExpressions);
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
}

void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams,
const vector<ParsedExpression>& allExpressions) {
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams, allExpressions);
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions);
string name = prefix+intToString(temps.size());
bool hasRecordedNode = false;

Expand All @@ -77,13 +84,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
break;
case Operation::VARIABLE:
{
map<string, string>::const_iterator iter = variables.find(node.getOperation().getName());
if (iter == variables.end())
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
out << iter->second;
break;
}
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
case Operation::CUSTOM:
{
int i;
Expand Down Expand Up @@ -145,8 +146,17 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
break;
case Operation::DIVIDE:
out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
{
bool haveReciprocal = false;
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
haveReciprocal = true;
out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
}
if (!haveReciprocal)
out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
break;
}
case Operation::POWER:
out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
Expand Down
19 changes: 16 additions & 3 deletions platforms/opencl/src/OpenCLExpressionUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2011 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -49,13 +49,26 @@ class OPENMM_EXPORT OpenCLExpressionUtilities {
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
*/
static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
/**
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
*/
static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
/**
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
Expand All @@ -76,7 +89,7 @@ class OPENMM_EXPORT OpenCLExpressionUtilities {
class FunctionPlaceholder;
private:
static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, const std::map<std::string, std::string>& variables,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams,
const std::vector<Lepton::ParsedExpression>& allExpressions);
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
Expand Down
Loading

0 comments on commit 84f5e00

Please sign in to comment.