Skip to content

Commit

Permalink
Refactored and renamed NodeScoreOptimizer class to NodeScoreParser
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 23, 2020
1 parent c7cb468 commit a5dde6f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class AttributeOptimizerBattery extends VisitorBattery {

public AttributeOptimizerBattery(){
add(ValueParser.class);
add(NodeScoreParser.class);
add(TargetCategoryParser.class);
add(NodeScoreOptimizer.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
*/
package org.jpmml.evaluator.visitors;

import org.dmg.pmml.DataType;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.tree.DecisionTree;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.PMMLAttributes;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.TypeCheckException;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.model.visitors.Resettable;

Expand All @@ -32,7 +38,7 @@
* A Visitor that pre-parses the score attribute of regression-type tree models.
* </p>
*/
public class NodeScoreOptimizer extends AbstractVisitor implements Resettable {
public class NodeScoreParser extends AbstractVisitor implements Resettable {

private MathContext mathContext = null;

Expand All @@ -42,6 +48,28 @@ public void reset(){
this.mathContext = null;
}

@Override
public void pushParent(PMMLObject parent){
super.pushParent(parent);

if(parent instanceof TreeModel){
TreeModel treeModel = (TreeModel)parent;

this.mathContext = treeModel.getMathContext();
}
}

@Override
public PMMLObject popParent(){
PMMLObject parent = super.popParent();

if(parent instanceof TreeModel){
this.mathContext = null;
}

return parent;
}

@Override
public VisitorAction visit(DecisionTree decisionTree){
throw new UnsupportedOperationException();
Expand All @@ -50,53 +78,55 @@ public VisitorAction visit(DecisionTree decisionTree){
@Override
public VisitorAction visit(TreeModel treeModel){
MiningFunction miningFunction = treeModel.getMiningFunction();
if(miningFunction == null){
throw new MissingAttributeException(treeModel, PMMLAttributes.TREEMODEL_MININGFUNCTION);
}

if(miningFunction != null){

switch(miningFunction){
case REGRESSION:
this.mathContext = treeModel.getMathContext();
break;
default:
this.mathContext = null;
break;
}
} // End if

if(this.mathContext == null){
return VisitorAction.SKIP;
switch(miningFunction){
case REGRESSION:
break;
default:
return VisitorAction.SKIP;
}

return super.visit(treeModel);
}

@Override
public VisitorAction visit(Node node){
MathContext mathContext = this.mathContext;

if(mathContext != null && node.hasScore()){
if(node.hasScore()){
Object score = node.getScore();

if(score instanceof String){
String stringScore = (String)score;

try {
switch(mathContext){
case DOUBLE:
node.setScore(Double.parseDouble(stringScore));
break;
case FLOAT:
node.setScore(Float.parseFloat(stringScore));
break;
default:
break;
}
} catch(NumberFormatException nfe){
// Ignored
}
score = parseScore(score);

node.setScore(score);
}
}

return super.visit(node);
}

private Object parseScore(Object score){

if(score == null){
return score;
}

try {
switch(this.mathContext){
case DOUBLE:
return TypeUtil.parseOrCast(DataType.DOUBLE, score);
case FLOAT:
return TypeUtil.parseOrCast(DataType.FLOAT, score);
default:
break;
}
} catch(IllegalArgumentException | TypeCheckException e){
// Ignored
}

return score;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;

public class NodeScoreOptimizerTest {
public class NodeScoreParserTest {

@Test
public void parseAndIntern(){
Expand All @@ -54,7 +54,7 @@ public void parseAndIntern(){
.setMathContext(MathContext.FLOAT);

VisitorBattery visitorBattery = new VisitorBattery();
visitorBattery.add(NodeScoreOptimizer.class);
visitorBattery.add(NodeScoreParser.class);
visitorBattery.add(FloatInterner.class);

visitorBattery.applyTo(treeModel);
Expand Down

0 comments on commit a5dde6f

Please sign in to comment.