From fb36313addcbc8449920b9d0fb5d9e2b4ff8593e Mon Sep 17 00:00:00 2001 From: Arno Candel Date: Tue, 6 May 2014 17:44:20 -0700 Subject: [PATCH] Add support for visualization for DeepLearning neuron layers. Might not work on Java7. --- .../src/main/java/hex/MnistCanvas.java | 10 +- .../java/samples/NeuralNetMnistPretrain.java | 33 +++-- .../samples/expert/DeepLearningMnist.java | 15 ++- .../hex/deeplearning/DeepLearningModel.java | 3 + .../expert/DeepLearningVisualization.java | 114 ++++++++++++++++++ 5 files changed, 154 insertions(+), 21 deletions(-) create mode 100644 src/main/resources/samples/expert/DeepLearningVisualization.java diff --git a/experiments/src/main/java/hex/MnistCanvas.java b/experiments/src/main/java/hex/MnistCanvas.java index afb6d763bf..6c4f6e8871 100644 --- a/experiments/src/main/java/hex/MnistCanvas.java +++ b/experiments/src/main/java/hex/MnistCanvas.java @@ -1,17 +1,15 @@ package hex; import hex.Layer.VecsInput; +import water.fvec.Vec; +import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; import java.awt.image.BufferedImage; import java.awt.image.WritableRaster; import java.util.Random; -import javax.swing.*; - -import water.fvec.Vec; - public class MnistCanvas extends Canvas { static final int PIXELS = 784, EDGE = 28; @@ -139,9 +137,9 @@ public JPanel init() { double w = layer._w[o * layer._previous._a.length + i]; w = ((w - mean) / sigma) * 200; if( w >= 0 ) - start[i] = ((int) Math.min(+w, 255)) << 8; + start[i] = ((int) Math.min(+w, 255)) << 8; //GREEN else - start[i] = ((int) Math.min(-w, 255)) << 16; + start[i] = ((int) Math.min(-w, 255)) << 16; //RED } BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB); diff --git a/experiments/src/main/java/samples/NeuralNetMnistPretrain.java b/experiments/src/main/java/samples/NeuralNetMnistPretrain.java index 1d3953ac4b..8af1cfb2fa 100644 --- a/experiments/src/main/java/samples/NeuralNetMnistPretrain.java +++ b/experiments/src/main/java/samples/NeuralNetMnistPretrain.java @@ -4,11 +4,14 @@ import hex.Layer; import hex.Layer.VecSoftmax; import hex.Layer.VecsInput; +import hex.MnistCanvas; import hex.NeuralNet; import hex.Trainer; import samples.expert.NeuralNetMnist; import water.fvec.Vec; +import javax.swing.*; + public class NeuralNetMnistPretrain extends NeuralNetMnist { public static void main(String[] args) throws Exception { Class job = Class.forName(Thread.currentThread().getStackTrace()[1].getClassName()); @@ -46,8 +49,7 @@ public static void main(String[] args) throws Exception { } @Override protected void startTraining(Layer[] ls) { - // pretrain for - int pretrain_epochs = 4; + int pretrain_epochs = 2; preTrain(ls, pretrain_epochs); // actual run @@ -56,7 +58,16 @@ public static void main(String[] args) throws Exception { if (epochs > 0) { // _trainer = new Trainer.Direct(ls, epochs, self()); _trainer = new Trainer.Threaded(ls, epochs, self(), -1); - //_trainer = new Trainer.MapReduce(ls, epochs, self()); + // Basic visualization of images and weights + + JFrame frame = new JFrame("H2O Training"); + frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); + MnistCanvas canvas = new MnistCanvas(_trainer); + frame.setContentPane(canvas.init()); + frame.pack(); + frame.setLocationRelativeTo(null); + frame.setVisible(true);//_trainer = new Trainer.MapReduce(ls, epochs, self()); + _trainer.start(); _trainer.join(); } @@ -101,14 +112,14 @@ final private void preTrain(Layer[] ls, int index, int epochs) { _trainer = new Trainer.Direct(pre, epochs, self()); -// // Basic visualization of images and weights -// JFrame frame = new JFrame("H2O"); -// frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); -// MnistCanvas canvas = new MnistCanvas(_trainer); -// frame.setContentPane(canvas.init()); -// frame.pack(); -// frame.setLocationRelativeTo(null); -// frame.setVisible(true); + // Basic visualization of images and weights + JFrame frame = new JFrame("H2O Pre-Training"); + frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); + MnistCanvas canvas = new MnistCanvas(_trainer); + frame.setContentPane(canvas.init()); + frame.pack(); + frame.setLocationRelativeTo(null); + frame.setVisible(true); _trainer.start(); _trainer.join(); diff --git a/h2o-samples/src/main/java/samples/expert/DeepLearningMnist.java b/h2o-samples/src/main/java/samples/expert/DeepLearningMnist.java index fcd0d58970..104b54ee1c 100644 --- a/h2o-samples/src/main/java/samples/expert/DeepLearningMnist.java +++ b/h2o-samples/src/main/java/samples/expert/DeepLearningMnist.java @@ -1,9 +1,12 @@ package samples.expert; +import static samples.expert.DeepLearningVisualization.visualize; import static water.util.MRUtils.sampleFrame; import hex.deeplearning.DeepLearning; +import hex.deeplearning.DeepLearningModel; import water.Job; import water.TestUtil; +import water.UKV; import water.fvec.Frame; import water.util.Log; @@ -39,12 +42,12 @@ public static void main(String[] args) throws Exception { DeepLearning p = new DeepLearning(); // Hinton parameters -> should lead to ~1 % test error after a few dozen million samples p.seed = seed; - p.hidden = new int[]{1024,1024,2048}; -// p.hidden = new int[]{128,128,256}; +// p.hidden = new int[]{1024,1024,2048}; + p.hidden = new int[]{128,128,256}; p.activation = DeepLearning.Activation.RectifierWithDropout; p.loss = DeepLearning.Loss.CrossEntropy; p.input_dropout_ratio = 0.2; - p.epochs = 10000; + p.epochs = 10; p.l1 = 1e-5; p.l2 = 0; @@ -84,11 +87,15 @@ public static void main(String[] args) throws Exception { p.score_interval = 30; p.variable_importances = false; p.fast_mode = true; //to match old NeuralNet behavior - p.ignore_const_cols = true; //to match old NeuralNet behavior +// p.ignore_const_cols = true; + p.ignore_const_cols = false; //to match old NeuralNet behavior and to have images look straight p.shuffle_training_data = false; p.force_load_balance = true; p.replicate_training_data = true; p.quiet_mode = false; p.invoke(); + + visualize((DeepLearningModel) UKV.get(p.dest())); } + } diff --git a/src/main/java/hex/deeplearning/DeepLearningModel.java b/src/main/java/hex/deeplearning/DeepLearningModel.java index b80a7f568f..0617207c1e 100644 --- a/src/main/java/hex/deeplearning/DeepLearningModel.java +++ b/src/main/java/hex/deeplearning/DeepLearningModel.java @@ -915,6 +915,9 @@ public boolean generateHTML(String title, StringBuilder sb) { return true; } + // optional JFrame creation for visualization of weights +// DeepLearningVisualization.visualize(this); + final String mse_format = "%g"; // final String cross_entropy_format = "%2.6f"; diff --git a/src/main/resources/samples/expert/DeepLearningVisualization.java b/src/main/resources/samples/expert/DeepLearningVisualization.java new file mode 100644 index 0000000000..44606eb350 --- /dev/null +++ b/src/main/resources/samples/expert/DeepLearningVisualization.java @@ -0,0 +1,114 @@ +package samples.expert; + +import hex.deeplearning.DeepLearningModel; +import hex.deeplearning.DeepLearningTask; +import hex.deeplearning.Neurons; + +import javax.swing.*; +import java.awt.*; +import java.awt.event.ActionEvent; +import java.awt.image.BufferedImage; +import java.awt.image.WritableRaster; + +public class DeepLearningVisualization extends Canvas { + + static int _level = 1; + Neurons[] _neurons; + + public DeepLearningVisualization(Neurons[] neurons) { + _neurons = neurons; + } + + public JPanel init() { + JToolBar bar = new JToolBar(); + bar.add(new JButton("refresh") { + @Override protected void fireActionPerformed(ActionEvent event) { + DeepLearningVisualization.this.repaint(); + } + }); + bar.add(new JButton("++") { + @Override protected void fireActionPerformed(ActionEvent event) { + if (_level < _neurons.length-2) _level++; + } + }); + bar.add(new JButton("--") { + @Override protected void fireActionPerformed(ActionEvent event) { + if (_level > 1) _level--; + } + }); + JPanel pane = new JPanel(); + BorderLayout bord = new BorderLayout(); + pane.setLayout(bord); + pane.add("North", bar); + setSize(1024, 1024); + pane.add(this); + return pane; + } + + @Override public void paint(Graphics g) { + Neurons layer = _neurons[_level]; + int edge = 56, pad = 10; + final int EDGE = (int) Math.ceil(Math.sqrt(layer._previous._a.size())); + assert (layer._previous._a.size() <= EDGE * EDGE); + + int offset = pad; + int buf = EDGE + pad + pad; + double mean = 0; + long n = layer._w.size(); + for (int i = 0; i < n; i++) + mean += layer._w.raw()[i]; + mean /= layer._w.size(); + double sigma = 0; + for (int i = 0; i < layer._w.size(); i++) { + double d = layer._w.raw()[i] - mean; + sigma += d * d; + } + sigma = Math.sqrt(sigma / (layer._w.size() - 1)); + + for (int o = 0; o < layer._a.size(); o++) { + if (o % 10 == 0) { + offset = pad; + buf += pad + edge; + } + + int[] pic = new int[EDGE * EDGE]; + for (int i = 0; i < layer._previous._a.size(); i++) { + double w = layer._w.get(o, i); + w = ((w - mean) / sigma) * 200; + if (w >= 0) + pic[i] = ((int) Math.min(+w, 255)) << 8; //GREEN + else + pic[i] = ((int) Math.min(-w, 255)) << 16; //RED + } + + BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB); + WritableRaster r = out.getRaster(); + r.setDataElements(0, 0, EDGE, EDGE, pic); + + BufferedImage resized = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB); + Graphics2D g2 = resized.createGraphics(); + try { + g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC); + g2.clearRect(0, 0, edge, edge); + g2.drawImage(out, 0, 0, edge, edge, null); + } finally { + g2.dispose(); + } + g.drawImage(resized, buf, offset, null); + + offset += pad + edge; + } + } + + static JFrame frame = new JFrame("H2O Deep Learning"); + static public void visualize(final DeepLearningModel dlm) { + Neurons[] neurons = DeepLearningTask.makeNeuronsForTesting(dlm.model_info()); + frame.dispose(); + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + DeepLearningVisualization canvas = new DeepLearningVisualization(neurons); + frame.setContentPane(canvas.init()); + frame.pack(); + frame.setLocationRelativeTo(null); + frame.setVisible(true); + } +}