Skip to content

Commit

Permalink
errorEquals for approximate neutral changes. v1 NormalForm
Browse files Browse the repository at this point in the history
  • Loading branch information
dvitel committed Jun 2, 2021
1 parent 0749849 commit 2565c5e
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 86 deletions.
5 changes: 5 additions & 0 deletions ecj/src/main/java/ec/domain/regression/Benchmarks.java
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,11 @@ public static boolean equals(double v1, double v2) {
return Math.abs(v1 - v2) < epsilon;
}

public static boolean errorEquals(double v1, double v2) {
double epsilon = 1.0 / (Benchmarks.precision / 100) / 2.0;
return Math.abs(v1 - v2) < epsilon;
}

public static double maxValue;
public static double minValue;
public static int precision;
Expand Down
10 changes: 6 additions & 4 deletions ecj/src/main/java/ec/domain/regression/meta/FactorMul.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,20 @@ public GPNode apply(EvolutionState state, int thread, Map<GPNode, GPNode> bindin
double factor2 = 1;
long integralPart = Math.round(num);
if (Benchmarks.equals(integralPart, num)) { //factoring int;
double factor = state.random[thread].nextDouble() * Math.abs(num) / 2 + 1;
double factor = state.random[thread].nextDouble() * Math.abs(num) / 2 + 1.1;
long factorLong = Math.round(factor);

while (true) { //iterate through all numbers from factor up to num and search for first num % cur == 0;
while (factorLong >= 1) { //iterate through all numbers from factor up to num and search for first num % cur == 0;
if (integralPart % factorLong == 0) break;
factorLong--; //eventually we hit 1;
}
factor1 = factorLong;
factor2 = integralPart / factorLong;
} else { //factoring float number
double factor =
Benchmarks.round(state.random[thread].nextDouble() * (Benchmarks.maxValue - Benchmarks.minValue) + Benchmarks.minValue);
double factor = 0;
while (Benchmarks.equals(factor, 0)) {
factor = Benchmarks.round(state.random[thread].nextDouble() * (Benchmarks.maxValue - Benchmarks.minValue) + Benchmarks.minValue);;
}
factor1 = factor;
factor2 = num / factor1;
}
Expand Down
53 changes: 20 additions & 33 deletions ecj/src/main/java/ec/domain/regression/strategy/NormalForm.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package ec.domain.regression.strategy;

import ec.domain.regression.ArithRuleProvider;
import ec.domain.regression.Num;
import ec.gp.transform.RuleProvider;
import ec.gp.transform.meta.Meta;
import ec.gp.transform.strategy.On;
import ec.gp.transform.strategy.ProxyStrategy;
import ec.gp.transform.strategy.Strategy;

Expand Down Expand Up @@ -29,41 +33,24 @@ public class NormalForm extends ProxyStrategy implements CommutUnifierMixin {
public NormalForm(RuleProvider rules) {
super(Strategy.All(
Strategy.Fixpoint( // sum(a_i * x^i)
// Strategy.FirstMatch(rules.r("k + n -> [k + n]")),
// Strategy.FirstMatch(rules.r("k * n -> [k * n]")),
Strategy.FirstMatch(rules.r("x * (y + z) -> x * y + x * z"))
// term: x * x * 7 * (1 + 2) * 9 -> Mul [ x, x, 7, Add [1, 2], 9]

// Strategy.FirstMatch(rules.r("x + x -> 2 * x")),
// Strategy.FirstMatch(rules.r("n * x + x -> eval(n + 1) * x")),
// Strategy.FirstMatch(rules.r("n * x + k * x -> eval(n + k) * x"))

// 3 * x * x * x + 2 * x * x * x -> 5 * x * x * x
),
Strategy.FirstMatch(rules.r("x -> flat(x)")),
Strategy.FirstMatch(rules.r("x -> sort(x)")),
Strategy.FirstMatch(rules.r("x -> groupVars(x)")),
Strategy.Fixpoint(
Strategy.FirstSeqMatch(rules.r("0 * x -> 0")),
Strategy.FirstSeqMatch(rules.r("k + n -> [k + n]")), //literal, seq (after sort)
Strategy.FirstSeqMatch(rules.r("k * n -> [k * n]")), //literal, seq (after sort)
Strategy.FirstSeqMatch(rules.r("x + x -> 2 * x")), //literal (after sort, before sort - it should have commut), seq
Strategy.FirstSeqMatch(rules.r("n * x + x -> eval(n + 1) * x")),
Strategy.FirstSeqMatch(rules.r("n * x + k * x -> eval(n + k) * x"))
),
// Strategy.On(Strategy.Depth(2,
// Strategy.All(
// Strategy.FirstMatch(rules.r("x -> flat(x)")),
Strategy.FirstMatch(rules.r("x -> sort(x)")),
Strategy.Fixpoint(
// Strategy.Any(
// Strategy.FirstMatch(rules.r("x * 0 -> 0")),
Strategy.FirstMatch(rules.r("k + n -> [k + n]")), //literal, seq (after sort)
Strategy.FirstMatch(rules.r("k * n -> [k * n]")), //literal, seq (after sort)
Strategy.FirstMatch(rules.r("x + x -> x * 2")), //literal (after sort, before sort - it should have commut), seq
Strategy.FirstMatch(rules.r("x + x * n -> x * eval(n + 1)")),
Strategy.FirstMatch(rules.r("x * n + x * k -> x * eval(n + k)"))
// )
)
// Strategy.FirstMatch(rules.r("x -> unflat(x)"))
// )
// ))




//


Strategy.FirstMatch(rules.r("x -> unflat(x)")),
Strategy.AnyMatch(2, rules.r("k -> fadd(k)")),
On(On.MissingERC(Num::hasNoZero, "z", Strategy.FirstSeqMatch(rules.r("x -> x + 0"))))
// Strategy.FirstMatch(rules.r("k -> fmul(k)"))
// ),
));
}

Expand Down
28 changes: 19 additions & 9 deletions ecj/src/main/java/ec/gp/transform/CommutUnifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,23 @@ protected List<List<GPNode>> shuffle(List<List<GPNode>> permutations, MersenneTw
}

protected boolean unify(GPNode left, GPNode right, Map<GPNode, GPNode> bindings, MersenneTwisterFast rand, boolean exact) {
if (nodesEquivalent(left, right, true)) {
if (!exact && commutativeNodes != null && commutativeNodes.contains(left.getClass())) {
if (nodesEquivalent(left, right, true)) {
{
boolean unified = true;
Map<GPNode, GPNode> localBinds = new HashMap<GPNode, GPNode>(bindings);
for (int i = 0; i < left.children.length; i++) {
if (!unify(left.children[i], right.children[i], localBinds, rand, exact)) {
unified = false;
break;
}
}
if (unified) {
bindings.putAll(localBinds);
return true;
}
}
if (!exact && commutativeNodes != null && commutativeNodes.contains(left.getClass())
&& (right.children.length <= 5)) {
for (List<GPNode> perm: shuffle(permute(new ArrayList<>(List.of(right.children))), rand)) {
boolean continueOuter = false;
Map<GPNode, GPNode> localBinds = new HashMap<GPNode, GPNode>(bindings);
Expand All @@ -70,13 +85,8 @@ protected boolean unify(GPNode left, GPNode right, Map<GPNode, GPNode> bindings,
bindings.putAll(localBinds);
return true;
}
} else {
for (int i = 0; i < left.children.length; i++) {
if (!unify(left.children[i], right.children[i], bindings, rand, exact))
return false;
}
return true;
}
}
return false;
}
else if (left.getClass().equals(Var.class) && ((Var)left).matches(right)) {
//NOTE: we DO NOT implement full unification here - important
Expand Down
2 changes: 1 addition & 1 deletion ecj/src/main/java/ec/gp/transform/NeutralMutator.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public int produce(int min, int max, int subpopulation, ArrayList<Individual> in
// problem.evaluate(state, i, subpopulation, thread);

KozaFitness fitessAfter = (KozaFitness)i.fitness;
if (!Benchmarks.equals(fitnessBefore.standardizedFitness(), fitessAfter.standardizedFitness())) {
if (!Benchmarks.errorEquals(fitnessBefore.standardizedFitness(), fitessAfter.standardizedFitness())) {

state.output.message("\n---------------------");
for (TransformApplication a: res.appliedTransform) {
Expand Down
2 changes: 1 addition & 1 deletion ecj/src/main/java/ec/gp/transform/meta/Unflatten.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ else if (node.children.length < chNum) {
newChildren[i] = node.children[i];
unflatten(newChildren[i], emptyNodeBuilder);
}
for (int i = 0; i < chNum; i++) {
for (int i = node.children.length; i < chNum; i++) {
newChildren[i] = emptyNodeBuilder.apply(node.getClass());
}
node.children = newChildren;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public boolean getLastTestResult() {
}

public ConditionalStrategy(Condition condition, Strategy child) {
super(child);
this.child = child;
this.condition = condition;
}
Expand Down
2 changes: 1 addition & 1 deletion ecj/src/main/java/ec/gp/transform/strategy/On.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public static CaseBuilder WithBreak(CaseBuilder caseBuilder) {
public void apply(EvolutionState state, int thread, StrategyResult res) throws ReplaceFailed {
TreeStats stats = TreeStats.collectFor(res.gpNode);
for (CaseBuilder oneCase: cases) {
Strategy strategy = oneCase.build(stats);
Strategy strategy = oneCase.build(stats).withUnifier(unifier).withNodeBuilder(nodeBuilder);
strategy.apply(state, thread, res);
if ((oneCase instanceof WithBreakCaseBuilder) && (strategy instanceof ConditionalStrategy) && ((ConditionalStrategy)strategy).getLastTestResult()) {
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ public StrategyCollection(Strategy... strategies) {
public Strategy withStats(StrategyStats newStats) {
super.withStats(newStats);
Arrays.stream(strategies).forEach(s -> {
if (s.stats == null) s.withStats(newStats);
else s.withStats(new StrategyStats() {
@Override
void add(String transform) {
s.stats.add(transform);
newStats.add(transform);
}
});
// if (s.stats == null)
s.withStats(newStats);
// else s.withStats(new StrategyStats() {
// @Override
// void add(String transform) {
// s.stats.add(transform);
// newStats.add(transform);
// }
// });
});
return this;
}
Expand Down
85 changes: 56 additions & 29 deletions ecj/src/test/java/ec/domain/regression/MatchTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ec.gp.transform.NaiveUnifier;
import ec.gp.transform.Transform;
import ec.gp.transform.Unifier;
import ec.gp.transform.meta.DefaultNodeBuilder;
import ec.gp.transform.meta.Meta;
import ec.gp.transform.meta.NodeBuilder;
import ec.gp.transform.strategy.ReplaceFailed;
Expand Down Expand Up @@ -55,7 +56,7 @@ public void setUp()
state.output.setThrowsErrors(true);
state.random = new MersenneTwisterFast[] { new MersenneTwisterFast() };
Benchmarks.precision = 10000;
nodeBuilder = new NodeBuilder();
nodeBuilder = new DefaultNodeBuilder();
rules = new ArithRuleProvider();
rules.withNodeBuilder(nodeBuilder);
}
Expand Down Expand Up @@ -711,41 +712,67 @@ public void CheckListUnifier() throws ReplaceFailed {

@Test
public void CheckGroupVars() throws ReplaceFailed {
GPNode term =
nodeBuilder.create(Add.class,
new Num(2),
nodeBuilder.create(Mul.class,
new Var("x"),
new Var("x"),
new Num(3.3)
),
nodeBuilder.create(Mul.class,
new Var("x"),
new Num(3)
),
nodeBuilder.create(Mul.class,
new Num(2),
new Var("x"),
new Var("x"),
new Num(2)
),
new Var("x"),
new Num(3),
nodeBuilder.create(Mul.class,
new Num(2),
new Var("x"),
GPNode term = //( 0.5097 + x) * ( x + 4.339)
nodeBuilder.create(Mul.class,
nodeBuilder.create(Add.class,
new Num(0.5),
new Var("x")
),
new Num(1.2)
);
),
nodeBuilder.create(Add.class,
new Var("x"),
new Num(4.3)
)
);
// nodeBuilder.create(Add.class,
// new Num(2),
// nodeBuilder.create(Mul.class,
// new Var("x"),
// new Var("x"),
// new Num(3.3)
// ),
// nodeBuilder.create(Mul.class,
// new Var("x"),
// new Num(3)
// ),
// nodeBuilder.create(Mul.class,
// new Num(2),
// new Var("x"),
// new Var("x"),
// new Num(2)
// ),
// new Var("x"),
// new Num(3),
// nodeBuilder.create(Mul.class,
// new Num(2),
// new Var("x"),
// new Var("x")
// ),
// new Num(1.2)
// );

term.printRootedTreeForHumans(state, Output.ALL_MESSAGE_LOGS, 0, 0);
state.output.message(" ");

Strategy group =
Strategy.FirstMatch(rules.r("x -> groupVars(x)"))
Strategy.All(
Strategy.Fixpoint( // sum(a_i * x^i)
Strategy.FirstMatch(rules.r("x * (y + z) -> x * y + x * z"))
),
Strategy.FirstMatch(rules.r("x -> flat(x)")),
Strategy.FirstMatch(rules.r("x -> sort(x)")),
Strategy.FirstMatch(rules.r("x -> groupVars(x)")),
Strategy.Fixpoint(
Strategy.FirstSeqMatch(rules.r("k + n -> [k + n]")), //literal, seq (after sort)
Strategy.FirstSeqMatch(rules.r("k * n -> [k * n]")), //literal, seq (after sort)
Strategy.FirstSeqMatch(rules.r("x + x -> 2 * x")), //literal (after sort, before sort - it should have commut), seq
Strategy.FirstSeqMatch(rules.r("n * x + x -> eval(n + 1) * x")),
Strategy.FirstSeqMatch(rules.r("n * x + k * x -> eval(n + k) * x"))
)
)
.withNodeBuilder(nodeBuilder)
.withUnifier(new CommutUnifier());
.withUnifier(new CommutUnifier() {
{ commutativeNodes = Set.of(Add.class, Mul.class); }
});

StrategyResult res = new StrategyResult(term);

Expand Down

0 comments on commit 2565c5e

Please sign in to comment.