Skip to content

Commit

Permalink
Add Mnist test case for new NN, and fix compilation of old NN samples.
Browse files Browse the repository at this point in the history
  • Loading branch information
arnocandel committed Feb 5, 2014
1 parent bcaf699 commit 22bfea3
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 24 deletions.
35 changes: 29 additions & 6 deletions experiments/src/main/java/hex/Histograms.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package hex;

import java.util.ArrayList;
import java.util.concurrent.*;

import javafx.application.Platform;
import javafx.beans.value.ChangeListener;
import javafx.beans.value.ObservableValue;
Expand All @@ -14,11 +11,21 @@
import javafx.scene.Scene;
import javafx.scene.chart.LineChart;
import javafx.scene.chart.NumberAxis;
import javafx.scene.control.*;
import javafx.scene.layout.*;
import javafx.scene.control.Button;
import javafx.scene.control.CheckBox;
import javafx.scene.control.ScrollPane;
import javafx.scene.control.ToolBar;
import javafx.scene.layout.BorderPane;
import javafx.scene.layout.HBox;
import javafx.scene.layout.VBox;
import javafx.stage.Stage;

import javax.swing.SwingUtilities;
import javax.swing.*;
import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public class Histograms extends LineChart {
private static final int SLICES = 64;
Expand Down Expand Up @@ -118,6 +125,22 @@ public Histograms(String title, float[] data) {
_instances.add(this);
}

public Histograms(String title, double[] data) {
super(new NumberAxis(), new NumberAxis());
_data = new float[data.length];
for (int i=0; i<data.length; ++i) _data[i] = (float)data[i];

ObservableList<Series<Float, Float>> series = FXCollections.observableArrayList();
for( int i = 0; i < SLICES; i++ )
_list.add(new Data<Float, Float>(0f, 0f));
series.add(new LineChart.Series<Float, Float>(title, _list));
setData(series);
setPrefWidth(600);
setPrefHeight(250);

_instances.add(this);
}

static void refresh() {
for( Histograms h : _instances ) {
if( h._data != null ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public static void main(String[] args) throws Exception {
preTrain(ls);

//_trainer = new Trainer.Direct(ls, 0, self());
_trainer = new Trainer.Threaded(ls, 0, self());
_trainer = new Trainer.Threaded(ls, 0, self(), -1);
//_trainer = new Trainer.MapReduce(ls, 0, self());

_trainer.start();
Expand Down
2 changes: 1 addition & 1 deletion experiments/src/main/java/samples/NeuralNetViz.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public static void main(String[] args) throws Exception {

protected void startTraining(Layer[] ls) {
//_trainer = new Trainer.MapReduce(ls, 0, self());
_trainer = new Trainer.Threaded(ls, 0, self());
_trainer = new Trainer.Threaded(ls, 0, self(), -1);

// Basic visualization of images and weights
JFrame frame = new JFrame("H2O");
Expand Down
69 changes: 69 additions & 0 deletions h2o-samples/src/main/java/samples/NeuralNetMnist2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package samples;

import hex.FrameTask;
import hex.nn.NN;
import water.Job;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.util.Log;

/**
* Runs a neural network on the MNIST dataset.
*/
public class NeuralNetMnist2 extends Job {
public static void main(String[] args) throws Exception {
Class job = NeuralNetMnist2.class;
samples.launchers.CloudLocal.launch(job, 2);
// samples.launchers.CloudProcess.launch(job, 4);
//samples.launchers.CloudConnect.launch(job, "localhost:54321");
// samples.launchers.CloudRemote.launchIPs(job, "192.168.1.171", "192.168.1.172", "192.168.1.173", "192.168.1.174", "192.168.1.175");
// samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161", "192.168.1.163", "192.168.1.164");
// samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161");
//samples.launchers.CloudRemote.launchEC2(job, 4);
}

@Override protected Status exec() {
final long seed = 0xC0FFEE;

Log.info("Parsing data.");
// Frame trainf = TestUtil.parseFromH2OFolder("smalldata/mnist/train10x.csv.gz");
Frame trainf = TestUtil.parseFromH2OFolder("smalldata/mnist/train.csv.gz");
Frame testf = TestUtil.parseFromH2OFolder("smalldata/mnist/test.csv.gz");
Log.info("Done.");

NN p = new NN();
// Hinton parameters -> should lead to ~1 % test error after ~ 10M training points
p.seed = seed;
//p.hidden = new int[]{1024,1024,2048};
p.hidden = new int[]{128,128,256};
p.rate = 0.003;
p.activation = NN.Activation.RectifierWithDropout;
p.loss = NN.Loss.CrossEntropy;
p.input_dropout_ratio = 0.2;
p.max_w2 = 15;
p.epochs = 200;
p.rate_annealing = 1e-6;
p.l1 = 1e-5;
p.l2 = 0;
p.momentum_stable = 0.99;
p.momentum_start = 0.5;
p.momentum_ramp = 1800000;
p.initial_weight_distribution = NN.InitialWeightDistribution.UniformAdaptive;
// p.initial_weight_scale = 0.01
p.classification = true;
p.diagnostics = false;
p.validation = testf;
p.source = trainf;
p.response = trainf.lastVec();
p.ignored_cols = null;
p.destination_key = Key.make("mnist.model");

Frame fr = FrameTask.DataInfo.prepareFrame(p.source, p.response, p.ignored_cols, true);
p._dinfo = new FrameTask.DataInfo(fr, 1, true);

p.initModel();
p.trainModel(true);
return Status.Running;
}
}
10 changes: 6 additions & 4 deletions h2o-samples/src/main/java/samples/launchers/CloudRemote.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package samples.launchers;

import java.util.Arrays;

import water.Job;
import water.deploy.*;
import water.deploy.Cloud;
import water.deploy.EC2;
import water.deploy.VM;
import water.util.Log;

import java.util.Arrays;

/**
* Builds a remote cluster. H2O jar, or classes from current workspace, are deployed through rsync.
* <nl>
Expand Down Expand Up @@ -61,7 +63,7 @@ public static void launch(Cloud cloud, Class<? extends Job> job) throws Exceptio
cloud.clientRSyncExcludes.add("lib/javassist");
cloud.clientRSyncExcludes.add("**/*-sources.jar");

String java = "-ea -Xmx60G -Dh2o.debug";
String java = "-ea -Xmx20G -Dh2o.debug";
String node = "-mainClass " + UserCode.class.getName() + " " + (job != null ? job.getName() : null) + " -beta";
cloud.start(java.split(" "), node.split(" "));
}
Expand Down
10 changes: 4 additions & 6 deletions src/main/java/hex/nn/Neurons.java
Original file line number Diff line number Diff line change
Expand Up @@ -430,19 +430,17 @@ else if( !training && dropout != null )
for( int u = 0; u < _a.length; u++ ) {
//(d/dx)(max(0,x)) = 1 if x > 0, otherwise 0

// short-cut: set gradient to 0
// AND
// no need to update the weights since there's no momenta, no l1 and no l2
// no need to update the weights if there are no momenta and l1=0 and l2=0
if (_wm == null && l1 == 0.0 && l2 == 0.0) {
if( _a[u] > 0 ) { // don't use >= (faster this way: lots of zeros)
final double g = _e[u]; // * 1.0 (from derivative of rectifier)
bprop(u, g, r, m);
}
// otherwise g = _e[u] * 0.0 = 0 and we don't allow other contributions by (and to) weights and momenta
}
// TODO: might always want to use this version (faster)
// if we have momenta or l1 or l2, then EVEN for g=0, there will be contributions to the weight updates
// Note: this is slower than always doing the shortcut above, and might not affect the accuracy much
else {
final double g = _a[u] > 0 ? _e[u] : 0; // * 1.0 (from derivative of rectifier)
final double g = _a[u] > 0 ? _e[u] : 0;
bprop(u, g, r, m);
}
}
Expand Down
14 changes: 8 additions & 6 deletions src/main/java/water/deploy/Cloud.java
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package water.deploy;

import java.io.File;
import java.io.Serializable;
import java.util.*;

import water.*;
import water.Boot;
import water.H2O;
import water.H2O.FlatFileEntry;
import water.TestUtil;
import water.deploy.VM.Params;
import water.deploy.VM.Watchdog;
import water.util.Log;
import water.util.Utils;

import java.io.File;
import java.io.Serializable;
import java.util.*;

/**
* Deploys and starts a remote cluster.
* <nl>
Expand All @@ -30,7 +32,7 @@ public class Cloud {
public final Set<String> fannedRSyncExcludes = new HashSet<String>();

/** Port for all remote machines. */
public static final int PORT = 54321;
public static final int PORT = 54423;
public static final int FORWARDED_LOCAL_PORT = 54321;
/**
* To avoid configuring remote machines, a JVM can be sent through rsync with H2O. By default,
Expand Down

0 comments on commit 22bfea3

Please sign in to comment.