Skip to content
This repository has been archived by the owner on Apr 10, 2019. It is now read-only.

Commit

Permalink
Activated 'float' math context
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 29, 2017
1 parent 33f3d9b commit d6530ec
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/main/java/org/jpmml/tensorflow/DNNEstimator.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.common.primitives.Floats;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Entity;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
Expand All @@ -51,7 +52,8 @@ protected NeuralNetwork encodeNeuralNetwork(TensorFlowEncoder encoder){
SavedModel savedModel = getSavedModel();

NeuralNetwork neuralNetwork = new NeuralNetwork()
.setActivationFunction(NeuralNetwork.ActivationFunction.RECTIFIER);
.setActivationFunction(NeuralNetwork.ActivationFunction.RECTIFIER)
.setMathContext(MathContext.FLOAT);

List<NodeDef> biasAdds = Lists.newArrayList(savedModel.getInputs(getHead(), "BiasAdd"));

Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/jpmml/tensorflow/LinearClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ public RegressionModel encodeModel(TensorFlowEncoder encoder){
if(regressionTables.size() == 1){
categories = Arrays.asList("0", "1");

RegressionTable activeRegressionTable = regressionTables.get(0)
.setTargetCategory(categories.get(1));

RegressionTable passiveRegressionTable = new RegressionTable(0)
.setTargetCategory(categories.get(0));

regressionModel.addRegressionTables(passiveRegressionTable);

RegressionTable activeRegressionTable = regressionTables.get(0)
.setTargetCategory(categories.get(1));
} else

if(regressionTables.size() > 2){
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/jpmml/tensorflow/LinearEstimator.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;

import com.google.common.primitives.Floats;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CMatrixUtil;
Expand Down Expand Up @@ -143,7 +144,8 @@ public RegressionModel encodeRegressionModel(TensorFlowEncoder encoder){
}
}

RegressionModel regressionModel = new RegressionModel();
RegressionModel regressionModel = new RegressionModel()
.setMathContext(MathContext.FLOAT);

for(Equation equation : equations){
RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(equation.getFeatures(), equation.getCoefficients(), equation.getIntercept());
Expand Down
3 changes: 1 addition & 2 deletions src/test/java/org/jpmml/tensorflow/LinearRegressorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
*/
package org.jpmml.tensorflow;

import org.jpmml.evaluator.PMMLEquivalence;
import org.junit.Test;

public class LinearRegressorTest extends EstimatorTest {

public LinearRegressorTest(){
super(new PMMLEquivalence(1e-6, 1e-6));
super(new TensorFlowEquivalence(2));
}

@Test
Expand Down
47 changes: 47 additions & 0 deletions src/test/java/org/jpmml/tensorflow/TensorFlowEquivalence.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-TensorFlow
*
* JPMML-TensorFlow is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-TensorFlow is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.tensorflow;

import org.dmg.pmml.DataType;
import org.jpmml.evaluator.Computable;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.RealNumberEquivalence;
import org.jpmml.evaluator.TypeUtil;

public class TensorFlowEquivalence extends RealNumberEquivalence {

public TensorFlowEquivalence(int tolerance){
super(tolerance);
}

@Override
public boolean doEquivalent(Object expected, Object actual){

if(actual instanceof Computable){
actual = EvaluatorUtil.decode(actual);
} // End if

if(actual instanceof Number){
actual = TypeUtil.parseOrCast(DataType.FLOAT, actual);
expected = TypeUtil.parseOrCast(DataType.FLOAT, expected);
}

return super.doEquivalent(expected, actual);
}
}

0 comments on commit d6530ec

Please sign in to comment.