Skip to content

Commit

Permalink
Distributed NN training
Browse files Browse the repository at this point in the history
  • Loading branch information
cypof committed Sep 14, 2013
1 parent 5b828fc commit 279de27
Show file tree
Hide file tree
Showing 28 changed files with 1,341 additions and 413 deletions.
84 changes: 84 additions & 0 deletions experiments/src/main/java/hex/IrisDist.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package hex;

import hex.Layer.FrameInput;
import hex.Layer.Input;
import hex.Layer.Softmax;
import hex.NeuralNet.Error;
import hex.NeuralNet.NeuralNetScore;
import hex.NeuralNet.Weights;

import java.text.DecimalFormat;

import water.Sandbox;
import water.util.Log;

/**
 * Distributed neural-network training experiment on the Iris data set.
 * Boots a 2-node local H2O cloud, trains a 4-7-3 Tanh/Softmax network with the
 * distributed {@code Trainer.MR2} trainer, and logs train/test error once per second.
 * Runs until the process is killed (the monitoring loop is intentionally infinite).
 */
public class IrisDist extends NeuralNetIrisTest {
  /** Formats elapsed seconds in the progress log line. */
  static final DecimalFormat _format = new DecimalFormat("0.000");

  public static void main(String[] args) throws Exception {
    // Re-launch inside the H2O boot classloader; actual work starts in UserCode.userMain.
    water.Boot.main(UserCode.class, "-beta");
  }

  public static class UserCode {
    public static void userMain(String[] args) throws Exception {
      // Spin up a 2-node local cloud so MR2 actually trains across nodes.
      Sandbox.localCloud(2, true, args);
      IrisDist test = new IrisDist();
      test.run();
    }
  }

  /**
   * Loads the Iris frames, builds and randomizes the 3-layer network, starts the
   * distributed trainer, then loops forever printing throughput and error metrics.
   */
  public void run() {
    load();
    // Re-chunk so the training frame is split across cloud nodes.
    _train = Trainer.reChunk(_train);

    Layer[] ls = new Layer[3];
    ls[0] = new FrameInput(_train);
    ls[1] = new Layer.Tanh();
    ls[2] = new Softmax();
    ls[1]._rate = 0.01f;
    ls[2]._rate = 0.01f;
    ls[1]._l2 = .001f;
    ls[2]._l2 = .001f;
    ls[0].init(null, 4);   // 4 input features
    ls[1].init(ls[0], 7);  // 7 hidden units
    ls[2].init(ls[1], 3);  // 3 output classes
    for( int i = 1; i < ls.length; i++ )
      ls[i].randomize();

    // Alternative trainers kept for experimentation:
    // final Trainer.Direct trainer = new Trainer.Direct(ls);
    // Trainer.Threaded trainer = new Trainer.Threaded(ls, 1000, 1);
    //final Trainer trainer = new Trainer.MR(ls, 0);
    //Trainer.MRAsync trainer = new Trainer.MRAsync(ls, 0);
    Trainer.MR2 trainer = new Trainer.MR2(ls, 0);
    // final Trainer trainer = new Trainer.OpenCL(_ls);
    trainer.start();

    long start = System.nanoTime();
    long lastTime = start;
    long lastItems = 0;
    for( ;; ) {
      try {
        Thread.sleep(1000);
      } catch( InterruptedException e ) {
        // Restore the interrupt status before rethrowing so callers/pools can
        // still observe that this thread was interrupted.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }

      // Score clones on train and test so scoring does not disturb the live weights.
      Layer[] clones1 = Layer.clone(ls, _train);
      Error trainE = NeuralNetScore.eval(clones1, NeuralNet.EVAL_ROW_COUNT);
      Layer[] clones2 = Layer.clone(ls, _test);
      Error testE = NeuralNetScore.eval(clones2, NeuralNet.EVAL_ROW_COUNT);
      long time = System.nanoTime();
      double delta = (time - lastTime) / 1e9; // seconds since last report
      double total = (time - start) / 1e9;    // seconds since training started
      lastTime = time;
      long steps = trainer.steps();
      int ps = (int) ((steps - lastItems) / delta); // steps per second this interval

      lastItems = steps;
      String m = _format.format(total) + "s, " + steps + " steps (" + (ps) + "/s) ";
      m += "train: " + trainE + ", test: " + testE;
      Log.info(m);
    }
  }
}
127 changes: 127 additions & 0 deletions experiments/src/main/java/hex/IrisMisc.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package hex;

import hex.Layer.FrameInput;
import hex.Layer.Input;
import hex.Layer.Softmax;
import hex.NeuralNet.Error;
import hex.NeuralNet.NeuralNetScore;
import hex.NeuralNet.Weights;

import java.text.DecimalFormat;

import water.Sandbox;
import water.util.Log;

public class IrisMisc extends NeuralNetIrisTest {
static final DecimalFormat _format = new DecimalFormat("0.000");

public static void main(String[] args) throws Exception {
water.Boot.main(UserCode.class, "-beta");
}

public static class UserCode {
public static void userMain(String[] args) throws Exception {
Sandbox.localCloud(1, true, args);
final IrisMisc test1 = new IrisMisc();
final IrisMisc test2 = new IrisMisc();

Thread t1 = new Thread() {
public void run() {
test1.run();
}
};
Thread t2 = new Thread() {
public void run() {
test2.run();
}
};

t1.start();
// t2.start();
sync(test1, test2);
}
}

public void run() {
load();

Layer[] ls = new Layer[3];
ls[0] = new FrameInput(_train);
ls[1] = new Layer.Tanh();
ls[2] = new Softmax();
ls[1]._rate = 0.99f;
ls[2]._rate = 0.99f;
ls[1]._l2 = .001f;
ls[2]._l2 = .001f;
ls[0].init(null, 4);
ls[1].init(ls[0], 7);
ls[2].init(ls[1], 3);
for( int i = 1; i < ls.length; i++ )
ls[i].randomize();

final Trainer.Direct trainer = new Trainer.Direct(ls);
// Trainer.Threaded trainer = new Trainer.Threaded(ls, 1000, 1);
//final Trainer trainer = new Trainer.MR(ls, 0);
//Trainer.MRAsync trainer = new Trainer.MRAsync(ls, 0);
//Trainer.MR2 trainer = new Trainer.MR2(ls, 0);
// final Trainer trainer = new Trainer.OpenCL(_ls);

// Basic visualization of images and weights
// JFrame frame = new JFrame("H2O");
// frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
// MnistCanvas canvas = new MnistCanvas(trainer);
// frame.setContentPane(canvas.init());
// frame.pack();
// frame.setLocationRelativeTo(null);
// frame.setVisible(true);

//trainer.start();
Input input = (Input) ls[0];
for( int s = 0; s < 1000000; s++ ) {
trainer.step();
input.move();
}
Weights a = Weights.get(ls, true);
eval("a", ls);

for( int s = 0; s < 100000; s++ ) {
trainer.step();
input.move();
}
Weights b = Weights.get(ls, true);
eval("b", ls);

for( int s = 0; s < 100000; s++ ) {
trainer.step();
input.move();
}
Weights c = Weights.get(ls, true);
eval("c", ls);

b.set(ls);
eval("b", ls);
Weights w = Weights.get(ls, true);
for( int y = 1; y < ls.length; y++ ) {
for( int i = 0; i < ls[y]._w.length; i++ )
w._ws[y][i] += b._ws[y][i] - a._ws[y][i];
for( int i = 0; i < ls[y]._b.length; i++ )
w._bs[y][i] += b._bs[y][i] - a._bs[y][i];
}
w.set(ls);
eval("w", ls);

Log.info("Done!");
System.exit(0);
}

void eval(String tag, Layer[] ls) {
Layer[] clones1 = Layer.clone(ls, _train);
Error trainE = NeuralNetScore.eval(clones1, NeuralNet.EVAL_ROW_COUNT);
Layer[] clones2 = Layer.clone(ls, _test);
Error testE = NeuralNetScore.eval(clones2, NeuralNet.EVAL_ROW_COUNT);
Log.info(tag + ": train: " + trainE + ", test: " + testE);
}

private static void sync(IrisMisc test1, IrisMisc test2) {
}
}
4 changes: 2 additions & 2 deletions experiments/src/main/java/hex/Mnist8mSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import hex.Mnist8m.TestInput;
import hex.Mnist8m.Train8mInput;
import hex.Trainer.ThreadedTrainers;
import hex.Trainer.Threaded;
import water.deploy.VM;

public class Mnist8mSearch {
Expand All @@ -29,7 +29,7 @@ public void run() throws Exception {
for( int i = 0; i < ls.length; i++ )
ls[i].init(false);

Trainer trainer = new ThreadedTrainers(ls);
Trainer trainer = new Threaded(ls);
//trainer._batches = Mnist8m.COUNT / trainer._batch;
trainer._batches = 10;
search.run(ls[1], ls[2]);
Expand Down
54 changes: 16 additions & 38 deletions experiments/src/main/java/hex/MnistCanvas.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,15 @@

import javax.swing.*;

import water.H2O;
import water.Sample07_NeuralNet_Mnist;

public class MnistCanvas extends Canvas {
static NeuralNetTest _test;
static final int PIXELS = 784, EDGE = 28;

static Random _rand = new Random();
static int _level = 1;
Trainer _trainer;

public static void main(String[] args) throws Exception {
water.Boot.main(UserCode.class, args);
}

public static class UserCode {
// Entry point can be called 'main' or 'userMain'
public static void userMain(String[] args) throws Exception {
H2O.main(args);
Sample07_NeuralNet_Mnist mnist = new Sample07_NeuralNet_Mnist();
mnist.init();
_test = mnist;

// Basic visualization of images and weights
JFrame frame = new JFrame("H2O");
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
MnistCanvas canvas = new MnistCanvas();
frame.setContentPane(canvas.init());
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);

mnist.run();
}
MnistCanvas(Trainer trainer) {
_trainer = trainer;
}

JPanel init() {
Expand All @@ -63,7 +41,7 @@ JPanel init() {
bar.add(new JButton("histo") {
@Override protected void fireActionPerformed(ActionEvent event) {
Histogram.initFromSwingThread();
Histogram.build(_test._trainer.layers());
Histogram.build(_trainer.layers());
}
});
JPanel pane = new JPanel();
Expand All @@ -76,24 +54,25 @@ JPanel init() {
}

@Override public void paint(Graphics g) {
water.fvec.Frame frame = ((FrameInput) _test._ls[0])._frame;
Layer[] ls = _trainer.layers();
water.fvec.Frame frame = ((FrameInput) ls[0])._frame;
int edge = 56, pad = 10;
int rand = _rand.nextInt((int) frame.numRows());

// Side
{
BufferedImage in = new BufferedImage(Sample07_NeuralNet_Mnist.EDGE, Sample07_NeuralNet_Mnist.EDGE, BufferedImage.TYPE_INT_RGB);
BufferedImage in = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB);
WritableRaster r = in.getRaster();

// Input
int[] pix = new int[Sample07_NeuralNet_Mnist.PIXELS];
int[] pix = new int[PIXELS];
for( int i = 0; i < pix.length; i++ )
pix[i] = (int) (frame._vecs[i].at8(rand));
r.setDataElements(0, 0, Sample07_NeuralNet_Mnist.EDGE, Sample07_NeuralNet_Mnist.EDGE, pix);
r.setDataElements(0, 0, EDGE, EDGE, pix);
g.drawImage(in, pad, pad, null);

// Labels
g.drawString("" + frame._vecs[Sample07_NeuralNet_Mnist.PIXELS].at8(rand), 10, 50);
g.drawString("" + frame._vecs[PIXELS].at8(rand), 10, 50);
g.drawString("RBM " + _level, 10, 70);
}

Expand Down Expand Up @@ -132,8 +111,8 @@ JPanel init() {
// }

// Weights
int buf = Sample07_NeuralNet_Mnist.EDGE + pad + pad;
Layer layer = _test._trainer.layers()[_level];
int buf = EDGE + pad + pad;
Layer layer = ls[_level];
double mean = 0;
int n = layer._w.length;
for( int i = 0; i < n; i++ )
Expand Down Expand Up @@ -162,10 +141,9 @@ JPanel init() {
start[i] = ((int) Math.min(-w, 255)) << 16;
}

BufferedImage out = new BufferedImage(Sample07_NeuralNet_Mnist.EDGE, Sample07_NeuralNet_Mnist.EDGE,
BufferedImage.TYPE_INT_RGB);
BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB);
WritableRaster r = out.getRaster();
r.setDataElements(0, 0, Sample07_NeuralNet_Mnist.EDGE, Sample07_NeuralNet_Mnist.EDGE, start);
r.setDataElements(0, 0, EDGE, EDGE, start);

BufferedImage resized = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB);
Graphics2D g2 = resized.createGraphics();
Expand Down
Loading

0 comments on commit 279de27

Please sign in to comment.