Skip to content

Commit

Permalink
Add support for visualization for DeepLearning neuron layers. Might n…
Browse files Browse the repository at this point in the history
…ot work on Java7.
  • Loading branch information
arnocandel committed May 7, 2014
1 parent 61ac0c5 commit fb36313
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 21 deletions.
10 changes: 4 additions & 6 deletions experiments/src/main/java/hex/MnistCanvas.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package hex;

import hex.Layer.VecsInput;
import water.fvec.Vec;

import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.util.Random;

import javax.swing.*;

import water.fvec.Vec;

public class MnistCanvas extends Canvas {
static final int PIXELS = 784, EDGE = 28;

Expand Down Expand Up @@ -139,9 +137,9 @@ public JPanel init() {
double w = layer._w[o * layer._previous._a.length + i];
w = ((w - mean) / sigma) * 200;
if( w >= 0 )
start[i] = ((int) Math.min(+w, 255)) << 8;
start[i] = ((int) Math.min(+w, 255)) << 8; //GREEN
else
start[i] = ((int) Math.min(-w, 255)) << 16;
start[i] = ((int) Math.min(-w, 255)) << 16; //RED
}

BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB);
Expand Down
33 changes: 22 additions & 11 deletions experiments/src/main/java/samples/NeuralNetMnistPretrain.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import hex.Layer;
import hex.Layer.VecSoftmax;
import hex.Layer.VecsInput;
import hex.MnistCanvas;
import hex.NeuralNet;
import hex.Trainer;
import samples.expert.NeuralNetMnist;
import water.fvec.Vec;

import javax.swing.*;

public class NeuralNetMnistPretrain extends NeuralNetMnist {
public static void main(String[] args) throws Exception {
Class job = Class.forName(Thread.currentThread().getStackTrace()[1].getClassName());
Expand Down Expand Up @@ -46,8 +49,7 @@ public static void main(String[] args) throws Exception {
}

@Override protected void startTraining(Layer[] ls) {
// pretrain for
int pretrain_epochs = 4;
int pretrain_epochs = 2;
preTrain(ls, pretrain_epochs);

// actual run
Expand All @@ -56,7 +58,16 @@ public static void main(String[] args) throws Exception {
if (epochs > 0) {
// _trainer = new Trainer.Direct(ls, epochs, self());
_trainer = new Trainer.Threaded(ls, epochs, self(), -1);
//_trainer = new Trainer.MapReduce(ls, epochs, self());
// Basic visualization of images and weights

JFrame frame = new JFrame("H2O Training");
frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
MnistCanvas canvas = new MnistCanvas(_trainer);
frame.setContentPane(canvas.init());
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);//_trainer = new Trainer.MapReduce(ls, epochs, self());

_trainer.start();
_trainer.join();
}
Expand Down Expand Up @@ -101,14 +112,14 @@ final private void preTrain(Layer[] ls, int index, int epochs) {

_trainer = new Trainer.Direct(pre, epochs, self());

// // Basic visualization of images and weights
// JFrame frame = new JFrame("H2O");
// frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
// MnistCanvas canvas = new MnistCanvas(_trainer);
// frame.setContentPane(canvas.init());
// frame.pack();
// frame.setLocationRelativeTo(null);
// frame.setVisible(true);
// Basic visualization of images and weights
JFrame frame = new JFrame("H2O Pre-Training");
frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
MnistCanvas canvas = new MnistCanvas(_trainer);
frame.setContentPane(canvas.init());
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);

_trainer.start();
_trainer.join();
Expand Down
15 changes: 11 additions & 4 deletions h2o-samples/src/main/java/samples/expert/DeepLearningMnist.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package samples.expert;

import static samples.expert.DeepLearningVisualization.visualize;
import static water.util.MRUtils.sampleFrame;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import water.Job;
import water.TestUtil;
import water.UKV;
import water.fvec.Frame;
import water.util.Log;

Expand Down Expand Up @@ -39,12 +42,12 @@ public static void main(String[] args) throws Exception {
DeepLearning p = new DeepLearning();
// Hinton parameters -> should lead to ~1 % test error after a few dozen million samples
p.seed = seed;
p.hidden = new int[]{1024,1024,2048};
// p.hidden = new int[]{128,128,256};
// p.hidden = new int[]{1024,1024,2048};
p.hidden = new int[]{128,128,256};
p.activation = DeepLearning.Activation.RectifierWithDropout;
p.loss = DeepLearning.Loss.CrossEntropy;
p.input_dropout_ratio = 0.2;
p.epochs = 10000;
p.epochs = 10;
p.l1 = 1e-5;
p.l2 = 0;

Expand Down Expand Up @@ -84,11 +87,15 @@ public static void main(String[] args) throws Exception {
p.score_interval = 30;
p.variable_importances = false;
p.fast_mode = true; //to match old NeuralNet behavior
p.ignore_const_cols = true; //to match old NeuralNet behavior
// p.ignore_const_cols = true;
p.ignore_const_cols = false; //to match old NeuralNet behavior and to have images look straight
p.shuffle_training_data = false;
p.force_load_balance = true;
p.replicate_training_data = true;
p.quiet_mode = false;
p.invoke();

visualize((DeepLearningModel) UKV.get(p.dest()));
}

}
3 changes: 3 additions & 0 deletions src/main/java/hex/deeplearning/DeepLearningModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,9 @@ public boolean generateHTML(String title, StringBuilder sb) {
return true;
}

// optional JFrame creation for visualization of weights
// DeepLearningVisualization.visualize(this);

final String mse_format = "%g";
// final String cross_entropy_format = "%2.6f";

Expand Down
114 changes: 114 additions & 0 deletions src/main/resources/samples/expert/DeepLearningVisualization.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package samples.expert;

import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningTask;
import hex.deeplearning.Neurons;

import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;

public class DeepLearningVisualization extends Canvas {

static int _level = 1;
Neurons[] _neurons;

public DeepLearningVisualization(Neurons[] neurons) {
_neurons = neurons;
}

public JPanel init() {
JToolBar bar = new JToolBar();
bar.add(new JButton("refresh") {
@Override protected void fireActionPerformed(ActionEvent event) {
DeepLearningVisualization.this.repaint();
}
});
bar.add(new JButton("++") {
@Override protected void fireActionPerformed(ActionEvent event) {
if (_level < _neurons.length-2) _level++;
}
});
bar.add(new JButton("--") {
@Override protected void fireActionPerformed(ActionEvent event) {
if (_level > 1) _level--;
}
});
JPanel pane = new JPanel();
BorderLayout bord = new BorderLayout();
pane.setLayout(bord);
pane.add("North", bar);
setSize(1024, 1024);
pane.add(this);
return pane;
}

@Override public void paint(Graphics g) {
Neurons layer = _neurons[_level];
int edge = 56, pad = 10;
final int EDGE = (int) Math.ceil(Math.sqrt(layer._previous._a.size()));
assert (layer._previous._a.size() <= EDGE * EDGE);

int offset = pad;
int buf = EDGE + pad + pad;
double mean = 0;
long n = layer._w.size();
for (int i = 0; i < n; i++)
mean += layer._w.raw()[i];
mean /= layer._w.size();
double sigma = 0;
for (int i = 0; i < layer._w.size(); i++) {
double d = layer._w.raw()[i] - mean;
sigma += d * d;
}
sigma = Math.sqrt(sigma / (layer._w.size() - 1));

for (int o = 0; o < layer._a.size(); o++) {
if (o % 10 == 0) {
offset = pad;
buf += pad + edge;
}

int[] pic = new int[EDGE * EDGE];
for (int i = 0; i < layer._previous._a.size(); i++) {
double w = layer._w.get(o, i);
w = ((w - mean) / sigma) * 200;
if (w >= 0)
pic[i] = ((int) Math.min(+w, 255)) << 8; //GREEN
else
pic[i] = ((int) Math.min(-w, 255)) << 16; //RED
}

BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB);
WritableRaster r = out.getRaster();
r.setDataElements(0, 0, EDGE, EDGE, pic);

BufferedImage resized = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB);
Graphics2D g2 = resized.createGraphics();
try {
g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC);
g2.clearRect(0, 0, edge, edge);
g2.drawImage(out, 0, 0, edge, edge, null);
} finally {
g2.dispose();
}
g.drawImage(resized, buf, offset, null);

offset += pad + edge;
}
}

static JFrame frame = new JFrame("H2O Deep Learning");
static public void visualize(final DeepLearningModel dlm) {
Neurons[] neurons = DeepLearningTask.makeNeuronsForTesting(dlm.model_info());
frame.dispose();
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
DeepLearningVisualization canvas = new DeepLearningVisualization(neurons);
frame.setContentPane(canvas.init());
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);
}
}

0 comments on commit fb36313

Please sign in to comment.