diff --git a/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/AttributeOptimizerBattery.java b/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/AttributeOptimizerBattery.java index f5c6c50e..f306534f 100644 --- a/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/AttributeOptimizerBattery.java +++ b/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/AttributeOptimizerBattery.java @@ -24,7 +24,7 @@ public class AttributeOptimizerBattery extends VisitorBattery { public AttributeOptimizerBattery(){ add(ValueParser.class); + add(NodeScoreParser.class); add(TargetCategoryParser.class); - add(NodeScoreOptimizer.class); } } \ No newline at end of file diff --git a/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/NodeScoreOptimizer.java b/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/NodeScoreParser.java similarity index 56% rename from pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/NodeScoreOptimizer.java rename to pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/NodeScoreParser.java index 70490dae..6a86fead 100644 --- a/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/NodeScoreOptimizer.java +++ b/pmml-evaluator/src/main/java/org/jpmml/evaluator/visitors/NodeScoreParser.java @@ -18,12 +18,18 @@ */ package org.jpmml.evaluator.visitors; +import org.dmg.pmml.DataType; import org.dmg.pmml.MathContext; import org.dmg.pmml.MiningFunction; +import org.dmg.pmml.PMMLObject; import org.dmg.pmml.VisitorAction; import org.dmg.pmml.tree.DecisionTree; import org.dmg.pmml.tree.Node; +import org.dmg.pmml.tree.PMMLAttributes; import org.dmg.pmml.tree.TreeModel; +import org.jpmml.evaluator.MissingAttributeException; +import org.jpmml.evaluator.TypeCheckException; +import org.jpmml.evaluator.TypeUtil; import org.jpmml.model.visitors.AbstractVisitor; import org.jpmml.model.visitors.Resettable; @@ -32,7 +38,7 @@ * A Visitor that pre-parses the score attribute of regression-type tree models. *

*/ -public class NodeScoreOptimizer extends AbstractVisitor implements Resettable { +public class NodeScoreParser extends AbstractVisitor implements Resettable { private MathContext mathContext = null; @@ -42,6 +48,28 @@ public void reset(){ this.mathContext = null; } + @Override + public void pushParent(PMMLObject parent){ + super.pushParent(parent); + + if(parent instanceof TreeModel){ + TreeModel treeModel = (TreeModel)parent; + + this.mathContext = treeModel.getMathContext(); + } + } + + @Override + public PMMLObject popParent(){ + PMMLObject parent = super.popParent(); + + if(parent instanceof TreeModel){ + this.mathContext = null; + } + + return parent; + } + @Override public VisitorAction visit(DecisionTree decisionTree){ throw new UnsupportedOperationException(); @@ -50,21 +78,15 @@ public VisitorAction visit(DecisionTree decisionTree){ @Override public VisitorAction visit(TreeModel treeModel){ MiningFunction miningFunction = treeModel.getMiningFunction(); + if(miningFunction == null){ + throw new MissingAttributeException(treeModel, PMMLAttributes.TREEMODEL_MININGFUNCTION); + } - if(miningFunction != null){ - - switch(miningFunction){ - case REGRESSION: - this.mathContext = treeModel.getMathContext(); - break; - default: - this.mathContext = null; - break; - } - } // End if - - if(this.mathContext == null){ - return VisitorAction.SKIP; + switch(miningFunction){ + case REGRESSION: + break; + default: + return VisitorAction.SKIP; } return super.visit(treeModel); @@ -72,31 +94,39 @@ public VisitorAction visit(TreeModel treeModel){ @Override public VisitorAction visit(Node node){ - MathContext mathContext = this.mathContext; - if(mathContext != null && node.hasScore()){ + if(node.hasScore()){ Object score = node.getScore(); if(score instanceof String){ - String stringScore = (String)score; - - try { - switch(mathContext){ - case DOUBLE: - node.setScore(Double.parseDouble(stringScore)); - break; - case FLOAT: - node.setScore(Float.parseFloat(stringScore)); - break; - default: - break; - } - } catch(NumberFormatException nfe){ - // Ignored - } + score = parseScore(score); + + node.setScore(score); } } return super.visit(node); } + + private Object parseScore(Object score){ + + if(score == null){ + return score; + } + + try { + switch(this.mathContext){ + case DOUBLE: + return TypeUtil.parseOrCast(DataType.DOUBLE, score); + case FLOAT: + return TypeUtil.parseOrCast(DataType.FLOAT, score); + default: + break; + } + } catch(IllegalArgumentException | TypeCheckException e){ + // Ignored + } + + return score; + } } \ No newline at end of file diff --git a/pmml-evaluator/src/test/java/org/jpmml/evaluator/visitors/NodeScoreOptimizerTest.java b/pmml-evaluator/src/test/java/org/jpmml/evaluator/visitors/NodeScoreParserTest.java similarity index 96% rename from pmml-evaluator/src/test/java/org/jpmml/evaluator/visitors/NodeScoreOptimizerTest.java rename to pmml-evaluator/src/test/java/org/jpmml/evaluator/visitors/NodeScoreParserTest.java index 29e25a30..94a24e6f 100644 --- a/pmml-evaluator/src/test/java/org/jpmml/evaluator/visitors/NodeScoreOptimizerTest.java +++ b/pmml-evaluator/src/test/java/org/jpmml/evaluator/visitors/NodeScoreParserTest.java @@ -34,7 +34,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; -public class NodeScoreOptimizerTest { +public class NodeScoreParserTest { @Test public void parseAndIntern(){ @@ -54,7 +54,7 @@ public void parseAndIntern(){ .setMathContext(MathContext.FLOAT); VisitorBattery visitorBattery = new VisitorBattery(); - visitorBattery.add(NodeScoreOptimizer.class); + visitorBattery.add(NodeScoreParser.class); visitorBattery.add(FloatInterner.class); visitorBattery.applyTo(treeModel);