Skip to content

Commit

Permalink
NN progress on Dropout & pre-training
Browse files Browse the repository at this point in the history
  • Loading branch information
cypof committed Nov 9, 2013
1 parent 5613569 commit 535ae92
Show file tree
Hide file tree
Showing 8 changed files with 352 additions and 272 deletions.
24 changes: 4 additions & 20 deletions experiments/src/main/java/hex/Histogram.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package hex;

import hex.Histogram;
import hex.Layer;

import java.util.ArrayList;
Expand Down Expand Up @@ -58,26 +57,11 @@ public static void build(final Layer[] ls) {
VBox v = new VBox();
for( int i = ls.length - 1; i > 0; i-- ) {
HBox h = new HBox();
h.getChildren().add(new Histogram("Layer " + i + " A", ls[i]._a));
h.getChildren().add(new Histogram("E", ls[i]._e));
v.getChildren().add(h);

h = new HBox();
h.getChildren().add(new Histogram("Layer " + i + " W", ls[i]._w));
h.getChildren().add(new Histogram("B", ls[i]._b));
h.getChildren().add(new Histogram("A", ls[i]._a));
h.getChildren().add(new Histogram("E", ls[i]._e));
v.getChildren().add(h);

h = new HBox();
h.getChildren().add(new Histogram("Layer " + i + " W S", ls[i]._wSpeed));
h.getChildren().add(new Histogram("W B", ls[i]._bSpeed));
v.getChildren().add(h);

if( ls[i]._v != null ) {
h = new HBox();
h.getChildren().add(new Histogram("Layer " + i + " V", ls[i]._v));
h.getChildren().add(new Histogram("Gradient " + i + " V", ls[i]._gv));
v.getChildren().add(h);
}
}
Stage stage = new Stage();
BorderPane root = new BorderPane();
Expand Down Expand Up @@ -105,8 +89,8 @@ public void changed(ObservableValue<? extends Boolean> ov, Boolean old_val, Bool
root.setCenter(scroll);
Scene scene = new Scene(root);
stage.setScene(scene);
stage.setWidth(1500);
stage.setHeight(1100);
stage.setWidth(2450);
stage.setHeight(1500);
stage.show();

scene.getWindow().onCloseRequestProperty().addListener(new ChangeListener() {
Expand Down
2 changes: 1 addition & 1 deletion h2o-samples/src/main/java/samples/MapReduce.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public static void main(String[] args) throws Exception {
samples.launchers.CloudLocal.launch(1, MapReduce.class);
// samples.launchers.CloudProcess.launch(2, MapReduce.class);
// samples.launchers.CloudConnect.launch("localhost:54321", MapReduce.class);
// samples.launchers.CloudRemote.launchIPs(MapReduce.class);
// samples.launchers.CloudRemote.launchDefaultIPs(MapReduce.class);
}

@Override protected void exec() {
Expand Down
9 changes: 4 additions & 5 deletions h2o-samples/src/main/java/samples/NeuralNetMnist.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class NeuralNetMnist extends Job {
public static void main(String[] args) throws Exception {
samples.launchers.CloudLocal.launch(1, NeuralNetMnist.class);
// samples.launchers.CloudProcess.launch(4, NeuralNetMnist.class);
// samples.launchers.CloudRemote.launchIPs(NeuralNetMnist.class);
// samples.launchers.CloudRemote.launchDefaultIPs(NeuralNetMnist.class);
// samples.launchers.CloudConnect.launch("localhost:54321", NeuralNetMnist.class);
}

Expand All @@ -40,11 +40,10 @@ protected Layer[] build(Vec[] data, Vec labels, VecsInput inputStats, VecSoftmax
ls[0] = new VecsInput(data, inputStats);
ls[1] = new Tanh(500);
ls[2] = new VecSoftmax(labels, outputStats);
ls[1].rate = .05f;
ls[2].rate = .02f;
for( int i = 0; i < ls.length; i++ ) {
ls[i].l2 = .0001f;
ls[i].rate_annealing = 1 / 2e6f;
ls[i].rate = .005f;
ls[i].rate_annealing = 1 / 1e6f;
ls[i].l2 = .001f;
ls[i].init(ls, i);
}
return ls;
Expand Down
33 changes: 18 additions & 15 deletions h2o-samples/src/main/java/samples/NeuralNetMnistDeep.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,27 @@
*/
public class NeuralNetMnistDeep extends NeuralNetMnist {
public static void main(String[] args) throws Exception {
samples.launchers.CloudLocal.launch(1, NeuralNetMnistDeep.class);
//samples.launchers.CloudLocal.launch(1, NeuralNetMnistDeep.class);
// samples.launchers.CloudProcess.launch(4, NeuralNetMnistDeep.class);
// samples.launchers.CloudConnect.launch("localhost:54321", NeuralNetMnistDeep.class);
// samples.launchers.CloudRemote.launchIPs(NeuralNetMnistDeep.class);
//samples.launchers.CloudRemote.launchIPs(NeuralNetMnistDeep.class, "192.168.1.163");
//samples.launchers.CloudRemote.launchIPs(NeuralNetMnistDeep.class, "192.168.1.162");
samples.launchers.CloudRemote.launchDefaultIPs(NeuralNetMnistDeep.class);
}

@Override protected Layer[] build(Vec[] data, Vec labels, VecsInput inputStats, VecSoftmax outputStats) {
Layer[] ls = new Layer[5];
ls[0] = new VecsInput(data, inputStats);
for( int i = 1; i < ls.length - 1; i++ ) {
ls[i] = new Tanh(500);
ls[i].rate = .05f;
}
for( int i = 1; i < ls.length - 1; i++ )
// TODO Work on Rectifier
//ls[i] = new TanhDropout(1000);
ls[i] = new Tanh(1000);
ls[ls.length - 1] = new VecSoftmax(labels, outputStats);
ls[ls.length - 1].rate = .02f;
for( int i = 0; i < ls.length; i++ ) {
ls[i].l2 = .0001f;
ls[i].rate_annealing = 1 / 1e5f;
ls[i].rate = .005f;
ls[i].rate_annealing = 1 / 1e6f;
ls[i].l2 = .001f;
//ls[i].dropout = .5f;
ls[i].init(ls, i);
}
return ls;
Expand Down Expand Up @@ -60,24 +63,24 @@ protected void preTrain(Layer[] ls, int index) {
for( int i = 1; i < index; i++ ) {
pre[i] = new Tanh(ls[i].units);
pre[i].rate = 0;
pre[i].l2 = .0001f;
pre[i].l2 = .01f;
Layer.shareWeights(ls[i], pre[i]);
}
// Auto-encoder is a tanh and a reverse tanh on top
// Auto-encoder is a layer and a reverse layer on top
pre[index] = new Tanh(ls[index].units);
pre[index].rate = .01f;
pre[index].l2 = .0001f;
pre[index].rate = .001f;
pre[index].l2 = 1f;
pre[index + 1] = new TanhPrime(ls[index - 1].units);
pre[index + 1].rate = .001f;
pre[index + 1].l2 = .0001f;
pre[index + 1].l2 = 1f;
Layer.shareWeights(ls[index], pre[index]);
Layer.shareWeights(ls[index], pre[index + 1]);

for( int i = 0; i < pre.length; i++ )
pre[i].init(pre, i, false, 0);

Trainer.Direct trainer = new Trainer.Direct(pre, this);
trainer.samples = 10000;
trainer.samples = 1000;
trainer.run();
}
}
21 changes: 14 additions & 7 deletions h2o-samples/src/main/java/samples/launchers/CloudRemote.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package samples.launchers;

import java.util.Arrays;

import water.Job;
import water.deploy.*;
import water.util.Log;
Expand All @@ -12,8 +14,8 @@
*/
public class CloudRemote {
public static void main(String[] args) throws Exception {
// launchEC2(null);
launchIPs(null);
launchEC2(null);
// launchDefaultIPs(null);
}

/**
Expand All @@ -26,16 +28,21 @@ public static void launchEC2(Class<? extends Job> job) throws Exception {
launch(c, job);
}

public static void launchDefaultIPs(Class<? extends Job> job) throws Exception {
launchIPs(job, //
"192.168.1.161", //
"192.168.1.162", //
"192.168.1.163", //
"192.168.1.164");
}

/**
* The current user is assumed to have ssh access (key-pair, no password) to the remote machines.
* H2O will be deployed to '~/h2o_rsync/'.
*/
public static void launchIPs(Class<? extends Job> job) throws Exception {
public static void launchIPs(Class<? extends Job> job, String... ips) throws Exception {
Cloud cloud = new Cloud();
cloud.publicIPs.add("192.168.1.161");
cloud.publicIPs.add("192.168.1.162");
cloud.publicIPs.add("192.168.1.163");
cloud.publicIPs.add("192.168.1.164");
cloud.publicIPs.addAll(Arrays.asList(ips));
launch(cloud, job);
}

Expand Down
Loading

0 comments on commit 535ae92

Please sign in to comment.