Skip to content

Commit

Permalink
Fix forward pass issue with MLN.score(DataSet)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Mar 16, 2016
1 parent 56f9a4d commit e9d7a0a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1629,10 +1629,11 @@ public double score(DataSet data, boolean training){
boolean hasMaskArray = data.hasMaskArrays();
if(hasMaskArray) setLayerMaskArrays(data.getFeaturesMaskArray(),data.getLabelsMaskArray());
// activation for output layer is calculated in computeScore
feedForwardToLayer(layers.length - 2, data.getFeatureMatrix(),training);
List<INDArray> activations = feedForwardToLayer(layers.length - 2, data.getFeatureMatrix(),training);
setLabels(data.getLabels());
if( getOutputLayer() instanceof BaseOutputLayer ){
BaseOutputLayer<?> ol = (BaseOutputLayer<?>)getOutputLayer();
ol.setInput(activations.get(activations.size()-1)); //Feedforward doesn't include output layer for efficiency
ol.setLabels(data.getLabels());
ol.computeScore(calcL1(),calcL2(), training);
this.score = ol.score();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,30 @@ public void testScoreExamples(){
}
}

@Test
public void testDataSetScore(){

Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.regularization(false)
.learningRate(1.0)
.weightInit(WeightInit.XAVIER)
.seed(12345L)
.list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation("sigmoid").build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax").nIn(3).nOut(3).build())
.pretrain(false).backprop(true)
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

INDArray in = Nd4j.create(new double[]{1.0,2.0,3.0,4.0});
INDArray out = Nd4j.create(new double[]{1,0,0});

double score = net.score(new DataSet(in,out));
}

@Test
@Ignore
public void testCid() throws Exception {
Expand Down

0 comments on commit e9d7a0a

Please sign in to comment.