Skip to content

Commit

Permalink
min(), max(), and abs() can appear in expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed Apr 16, 2010
1 parent 5bca8d1 commit a5e272e
Show file tree
Hide file tree
Showing 17 changed files with 163 additions and 21 deletions.
71 changes: 70 additions & 1 deletion libraries/lepton/include/lepton/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class LEPTON_EXPORT Operation {
*/
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC, STEP, SQUARE, CUBE, RECIPROCAL,
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS};
/**
* Get the name of this Operation.
*/
Expand Down Expand Up @@ -148,6 +148,9 @@ class LEPTON_EXPORT Operation {
class AddConstant;
class MultiplyConstant;
class PowerConstant;
class Min;
class Max;
class Abs;
};

class Operation::Constant : public Operation {
Expand Down Expand Up @@ -972,6 +975,72 @@ class Operation::PowerConstant : public Operation {
double value;
};

class Operation::Min : public Operation {
public:
Min() {
}
std::string getName() const {
return "min";
}
Id getId() const {
return MIN;
}
int getNumArguments() const {
return 2;
}
Operation* clone() const {
return new Min();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::min(args[0], args[1]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};

class Operation::Max : public Operation {
public:
Max() {
}
std::string getName() const {
return "max";
}
Id getId() const {
return MAX;
}
int getNumArguments() const {
return 2;
}
Operation* clone() const {
return new Max();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::max(args[0], args[1]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};

class Operation::Abs : public Operation {
public:
Abs() {
}
std::string getName() const {
return "abs";
}
Id getId() const {
return ABS;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Abs();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::abs(args[0]);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};

} // namespace Lepton

#endif /*LEPTON_OPERATION_H_*/
36 changes: 31 additions & 5 deletions libraries/lepton/src/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,35 @@ ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<
}

ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(value),
ExpressionTreeNode(new Operation::PowerConstant(value-1),
children[0])),
childDerivs[0]));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(value),
ExpressionTreeNode(new Operation::PowerConstant(value-1),
children[0])),
childDerivs[0]);
}

ExpressionTreeNode Operation::Min::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(),
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
return ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0],
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
}

ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(),
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
return ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1],
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
}

ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(), children[0]);
return ExpressionTreeNode(new Operation::Multiply(),
childDerivs[0],
ExpressionTreeNode(new Operation::AddConstant(-1),
ExpressionTreeNode(new Operation::MultiplyConstant(2), step)));
}
9 changes: 9 additions & 0 deletions libraries/lepton/src/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL;
opMap["min"] = Operation::MIN;
opMap["max"] = Operation::MAX;
opMap["abs"] = Operation::ABS;
}
string trimmed = name.substr(0, name.size()-1);

Expand Down Expand Up @@ -373,6 +376,12 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Cube();
case Operation::RECIPROCAL:
return new Operation::Reciprocal();
case Operation::MIN:
return new Operation::Min();
case Operation::MAX:
return new Operation::Max();
case Operation::ABS:
return new Operation::Abs();
default:
throw Exception("Parse error: unknown function");
}
Expand Down
2 changes: 1 addition & 1 deletion openmmapi/include/openmm/CustomAngleForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
*/

Expand Down
2 changes: 1 addition & 1 deletion openmmapi/include/openmm/CustomBondForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
*/

Expand Down
2 changes: 1 addition & 1 deletion openmmapi/include/openmm/CustomExternalForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
*/

Expand Down
6 changes: 3 additions & 3 deletions openmmapi/include/openmm/CustomGBForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ namespace OpenMM {
* custom->addComputedValue("I", "step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);"
* "U=r+sr2;"
* "C=2*(1/or1-1/L)*step(sr2-r-or1);"
* "L=step(or1-D)*or1+(1-step(or1-D))*D;"
* "D=step(r-sr2)*(r-sr2)+(1-step(r-sr2))*(sr2-r);"
* "L=max(or1, D);"
* "D=abs(r-sr2);"
* "sr2 = scale2*or2;"
* "or1 = radius1-0.009; or2 = radius2-0.009", CustomGBForce::ParticlePairNoExclusions);
* custom->addComputedValue("B", "1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius);"
Expand Down Expand Up @@ -127,7 +127,7 @@ namespace OpenMM {
* particular piece of the computation.
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. In expressions for
* particle pair calculations, the names of per-particle parameters and computed values
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
Expand Down
2 changes: 1 addition & 1 deletion openmmapi/include/openmm/CustomHbondForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
Expand Down
2 changes: 1 addition & 1 deletion openmmapi/include/openmm/CustomNonbondedForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. The names of per-particle parameters
* have the suffix "1" or "2" appended to them to indicate the values for the two interacting particles. As seen in the above example,
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
Expand Down
2 changes: 1 addition & 1 deletion openmmapi/include/openmm/CustomTorsionForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace OpenMM {
* </pre></tt>
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, step. All trigonometric functions
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, step. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
*/

Expand Down
3 changes: 2 additions & 1 deletion platforms/cuda/src/kernels/cudatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ enum CudaNonbondedMethod

enum ExpressionOp {
VARIABLE0 = 0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, MULTIPLY, DIVIDE, ADD, SUBTRACT, POWER, MULTIPLY_CONSTANT, POWER_CONSTANT, ADD_CONSTANT,
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, STEP, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC
GLOBAL, CONSTANT, CUSTOM, CUSTOM_DERIV, NEGATE, RECIPROCAL, SQRT, EXP, LOG, SQUARE, CUBE, STEP, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SINH, COSH, TANH, ERF, ERFC,
MIN, MAX, ABS
};

template<int SIZE>
Expand Down
9 changes: 9 additions & 0 deletions platforms/cuda/src/kernels/gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,15 @@ static Expression<SIZE> createExpression(gpuContext gpu, const string& expressio
exp.op[i] = POWER_CONSTANT;
exp.arg[i] = (float) dynamic_cast<const Operation::PowerConstant*>(&op)->getValue();
break;
case Operation::MIN:
exp.op[i] = MIN;
break;
case Operation::MAX:
exp.op[i] = MAX;
break;
case Operation::ABS:
exp.op[i] = ABS;
break;
}
}
return exp;
Expand Down
13 changes: 12 additions & 1 deletion platforms/cuda/src/kernels/kEvaluateExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,20 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
else if (op == ERF) {
STACK(stackPointer) = erf(STACK(stackPointer));
}
else /*if (op == ERFC)*/ {
else if (op == ERFC) {
STACK(stackPointer) = erfc(STACK(stackPointer));
}
else if (op == MIN) {
float temp = STACK(stackPointer);
STACK(stackPointer) = min(temp, STACK(--stackPointer));
}
else if (op == MAX) {
float temp = STACK(stackPointer);
STACK(stackPointer) = max(temp, STACK(--stackPointer));
}
else /*if (op == ABS)*/ {
STACK(stackPointer) = fabs(STACK(stackPointer));
}
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions platforms/opencl/src/OpenCLExpressionUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,15 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(exponent) << ")";
break;
}
case Operation::MIN:
out << "min(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::MAX:
out << "max(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::ABS:
out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
Expand Down
4 changes: 2 additions & 2 deletions platforms/opencl/tests/TestOpenCLCustomGBForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ void testOBC(GBSAOBCForce::NonbondedMethod obcMethod, CustomGBForce::NonbondedMe
custom->addComputedValue("I", "step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);"
"U=r+sr2;"
"C=2*(1/or1-1/L)*step(sr2-r-or1);"
"L=step(or1-D)*or1+(1-step(or1-D))*D;"
"D=step(r-sr2)*(r-sr2)+(1-step(r-sr2))*(sr2-r);"
"L=max(or1, D);"
"D=abs(r-sr2);"
"sr2 = scale2*or2;"
"or1 = radius1-0.009; or2 = radius2-0.009", CustomGBForce::ParticlePairNoExclusions);
custom->addComputedValue("B", "1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius);"
Expand Down
4 changes: 2 additions & 2 deletions platforms/reference/tests/TestReferenceCustomGBForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ void testOBC(GBSAOBCForce::NonbondedMethod obcMethod, CustomGBForce::NonbondedMe
custom->addComputedValue("I", "step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);"
"U=r+sr2;"
"C=2*(1/or1-1/L)*step(sr2-r-or1);"
"L=step(or1-D)*or1+(1-step(or1-D))*D;"
"D=step(r-sr2)*(r-sr2)+(1-step(r-sr2))*(sr2-r);"
"L=max(or1, D);"
"D=abs(r-sr2);"
"sr2 = scale2*or2;"
"or1 = radius1-0.009; or2 = radius2-0.009", CustomGBForce::ParticlePairNoExclusions);
custom->addComputedValue("B", "1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius);"
Expand Down
8 changes: 8 additions & 0 deletions tests/TestParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ int main() {
verifyEvaluation("x*w; w = 5", 3.0, 1.0, 15.0);
verifyEvaluation("a+b^2;a=x-b;b=3*y", 2.0, 3.0, 74.0);
verifyEvaluation("erf(x)+erfc(x)", 2.0, 3.0, 1.0);
verifyEvaluation("min(3, x)", 2.0, 3.0, 2.0);
verifyEvaluation("min(y, 5)", 2.0, 3.0, 3.0);
verifyEvaluation("max(x, y)", 2.0, 3.0, 3.0);
verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0);
verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0);
verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4");
Expand Down Expand Up @@ -229,6 +234,9 @@ int main() {
verifyDerivative("recip(x)", "-1/x^2");
verifyDerivative("square(x)", "2*x");
verifyDerivative("cube(x)", "3*x^2");
verifyDerivative("min(x, 2*x)", "step(x-2*x)*2+(1-step(x-2*x))*1");
verifyDerivative("max(5, x^2)", "(1-step(5-x^2))*2*x");
verifyDerivative("abs(3*x)", "step(3*x)*3+(1-step(3*x))*-3");
testCustomFunction("custom(x, y)/2", "x*y");
testCustomFunction("custom(x^2, 1)+custom(2, y-1)", "2*x^2+4*(y-1)");
cout << Parser::parse("2*3*x").optimize() << endl;
Expand Down

0 comments on commit a5e272e

Please sign in to comment.