
Commit

Update dl4j to match new op executioner api change (remove scalar/transform along dim)
AlexDBlack committed Apr 4, 2016
1 parent c174ddd commit b8930ef
Showing 3 changed files with 4 additions and 56 deletions.
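
For orientation, the snippet below is a minimal sketch of the call-site change this commit applies at each softmax site in the diff, assuming the OpExecutioner/OpFactory API shown in the changed lines; the class name OpExecutionerCallSiteSketch, the helper applySoftmax, and the preSig variable are illustrative and are not part of the commit.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OpExecutionerCallSiteSketch {

    // Hypothetical helper showing the updated call pattern.
    static INDArray applySoftmax(INDArray preSig) {
        // Before this commit, transform ops were executed "along dimension",
        // e.g. execAndReturn(createTransform("softmax", preSig), 0); that
        // overload is removed by the new op executioner API.
        // After this commit, the transform op is executed without a dimension argument:
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform("softmax", preSig));
    }
}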
Changed file 1 of 3
@@ -244,7 +244,7 @@ public Pair<INDArray,INDArray> sampleHiddenGivenVisible(INDArray v) {
break;
}
case SOFTMAX: {
- h1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", h1Mean), 0);
+ h1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", h1Mean));
break;
}
case BINARY: {
@@ -294,7 +294,7 @@ public Pair<INDArray,INDArray> sampleVisibleGivenHidden(INDArray h) {
break;
}
case SOFTMAX: {
- v1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", v1Mean), 0);
+ v1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", v1Mean));
break;
}
case BINARY: {
@@ -342,7 +342,7 @@ public INDArray propUp(INDArray v,boolean training) {
case BINARY:
return sigmoid(preSig);
case SOFTMAX:
- return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preSig), 0);
+ return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preSig));
default:
throw new IllegalStateException("Hidden unit type should either be binary, gaussian, or rectified linear");
}
@@ -382,7 +382,7 @@ public INDArray propDown(INDArray h) {
case BINARY:
return sigmoid(vMean);
case SOFTMAX:
- return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", vMean), 0);
+ return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", vMean));
default:
throw new IllegalStateException("Visible unit type should either be binary or gaussian");
}
Changed file 2 of 3
@@ -657,45 +657,6 @@ public List<INDArray> feedForward() {
return feedForward(false);
}


-    /**
-     * Compute input linear transformation (z)
-     * Compute activations (applies activation transformation to z)
-     *
-     * @return a pair of activations and corresponding derivatives
-     */
-    public Pair<List<INDArray>,List<INDArray>> feedForwardActivationsAndDerivatives(boolean training) {
-        INDArray currInput = input;
-
-        List<INDArray> activations = new ArrayList<>();
-        List<INDArray> derivatives = new ArrayList<>();
-        activations.add(currInput);
-
-        for (int i = 0; i < layers.length; i++) {
-            currInput = zFromPrevLayer(i, currInput,training); // w*x+b for each layer
-            //special case: row wise softmax
-            if (layers[i].conf().getLayer().getActivationFunction().equals("softmax"))
-                activations.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", currInput.dup()), 1));
-            else
-                activations.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(layerWiseConfigurations.getConf(i).getLayer().getActivationFunction(), currInput)));
-        }
-
-        currInput = this.input;
-        for (int i = 0; i < layers.length; i++) {
-            currInput = zFromPrevLayer(i, currInput,training); // w*x+b for each layer
-            INDArray dup = currInput.dup();
-            //special case: row wise softmax
-            if (layers[i].conf().getLayer().getActivationFunction().equals("softmax"))
-                derivatives.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(layerWiseConfigurations.getConf(i).getLayer().getActivationFunction(), dup).derivative(), 1));
-            else
-                derivatives.add(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(layerWiseConfigurations.getConf(i).getLayer().getActivationFunction(), dup).derivative()));
-        }
-        // Duplicating last layer derivative to keep pair list equal
-        derivatives.add(derivatives.get(layers.length - 1));
-        return new Pair<>(activations, derivatives);
-    }


/**
* Compute activations from input to output of the output layer
*
Changed file 3 of 3
@@ -279,19 +279,6 @@ public void testGradientWithAsList(){
net2.gradient();
}


-    @Test
-    public void testFeedForwardActivationsAndDerivatives(){
-        MultiLayerNetwork network = new MultiLayerNetwork(getConf());
-        network.init();
-        DataSet data = new IrisDataSetIterator(1,150).next();
-        network.fit(data);
-        Pair result = network.feedForwardActivationsAndDerivatives(false);
-        List<INDArray> first = (List) result.getFirst();
-        List<INDArray> second = (List) result.getSecond();
-        assertEquals(first.size(), second.size());
-    }

/**
* This test intended only to test activateSelectedLayers method, it does not involves fully-working AutoEncoder.
*/
