forked from h2oai/h2o-2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for visualization for DeepLearning neuron layers. Might n…
…ot work on Java7.
- Loading branch information
1 parent
61ac0c5
commit fb36313
Showing
5 changed files
with
154 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
src/main/resources/samples/expert/DeepLearningVisualization.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package samples.expert; | ||
|
||
import hex.deeplearning.DeepLearningModel; | ||
import hex.deeplearning.DeepLearningTask; | ||
import hex.deeplearning.Neurons; | ||
|
||
import javax.swing.*; | ||
import java.awt.*; | ||
import java.awt.event.ActionEvent; | ||
import java.awt.image.BufferedImage; | ||
import java.awt.image.WritableRaster; | ||
|
||
public class DeepLearningVisualization extends Canvas { | ||
|
||
static int _level = 1; | ||
Neurons[] _neurons; | ||
|
||
public DeepLearningVisualization(Neurons[] neurons) { | ||
_neurons = neurons; | ||
} | ||
|
||
public JPanel init() { | ||
JToolBar bar = new JToolBar(); | ||
bar.add(new JButton("refresh") { | ||
@Override protected void fireActionPerformed(ActionEvent event) { | ||
DeepLearningVisualization.this.repaint(); | ||
} | ||
}); | ||
bar.add(new JButton("++") { | ||
@Override protected void fireActionPerformed(ActionEvent event) { | ||
if (_level < _neurons.length-2) _level++; | ||
} | ||
}); | ||
bar.add(new JButton("--") { | ||
@Override protected void fireActionPerformed(ActionEvent event) { | ||
if (_level > 1) _level--; | ||
} | ||
}); | ||
JPanel pane = new JPanel(); | ||
BorderLayout bord = new BorderLayout(); | ||
pane.setLayout(bord); | ||
pane.add("North", bar); | ||
setSize(1024, 1024); | ||
pane.add(this); | ||
return pane; | ||
} | ||
|
||
@Override public void paint(Graphics g) { | ||
Neurons layer = _neurons[_level]; | ||
int edge = 56, pad = 10; | ||
final int EDGE = (int) Math.ceil(Math.sqrt(layer._previous._a.size())); | ||
assert (layer._previous._a.size() <= EDGE * EDGE); | ||
|
||
int offset = pad; | ||
int buf = EDGE + pad + pad; | ||
double mean = 0; | ||
long n = layer._w.size(); | ||
for (int i = 0; i < n; i++) | ||
mean += layer._w.raw()[i]; | ||
mean /= layer._w.size(); | ||
double sigma = 0; | ||
for (int i = 0; i < layer._w.size(); i++) { | ||
double d = layer._w.raw()[i] - mean; | ||
sigma += d * d; | ||
} | ||
sigma = Math.sqrt(sigma / (layer._w.size() - 1)); | ||
|
||
for (int o = 0; o < layer._a.size(); o++) { | ||
if (o % 10 == 0) { | ||
offset = pad; | ||
buf += pad + edge; | ||
} | ||
|
||
int[] pic = new int[EDGE * EDGE]; | ||
for (int i = 0; i < layer._previous._a.size(); i++) { | ||
double w = layer._w.get(o, i); | ||
w = ((w - mean) / sigma) * 200; | ||
if (w >= 0) | ||
pic[i] = ((int) Math.min(+w, 255)) << 8; //GREEN | ||
else | ||
pic[i] = ((int) Math.min(-w, 255)) << 16; //RED | ||
} | ||
|
||
BufferedImage out = new BufferedImage(EDGE, EDGE, BufferedImage.TYPE_INT_RGB); | ||
WritableRaster r = out.getRaster(); | ||
r.setDataElements(0, 0, EDGE, EDGE, pic); | ||
|
||
BufferedImage resized = new BufferedImage(edge, edge, BufferedImage.TYPE_INT_RGB); | ||
Graphics2D g2 = resized.createGraphics(); | ||
try { | ||
g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC); | ||
g2.clearRect(0, 0, edge, edge); | ||
g2.drawImage(out, 0, 0, edge, edge, null); | ||
} finally { | ||
g2.dispose(); | ||
} | ||
g.drawImage(resized, buf, offset, null); | ||
|
||
offset += pad + edge; | ||
} | ||
} | ||
|
||
static JFrame frame = new JFrame("H2O Deep Learning"); | ||
static public void visualize(final DeepLearningModel dlm) { | ||
Neurons[] neurons = DeepLearningTask.makeNeuronsForTesting(dlm.model_info()); | ||
frame.dispose(); | ||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); | ||
DeepLearningVisualization canvas = new DeepLearningVisualization(neurons); | ||
frame.setContentPane(canvas.init()); | ||
frame.pack(); | ||
frame.setLocationRelativeTo(null); | ||
frame.setVisible(true); | ||
} | ||
} |