Clamping NN activations in addition to the weights

cypof committed Dec 2, 2013
1 parent 702ea20 commit 8d0c68a
Showing 5 changed files with 49 additions and 37 deletions.
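The substantive change is in src/main/java/hex/Layer.java: the forward pass of the Maxout, Rectifier, and RectifierDropout layers now rescales the activation vector whenever its maximum exceeds 1, mirroring the clamping already applied to the weights. As a standalone sketch (the method name is illustrative, not from the repository), the added logic amounts to:

// Sketch of the activation clamp this commit adds to fprop():
// if any activation exceeds 1, scale the layer so its max becomes 1.
static void clampActivations(float[] a) {
    float max = 0;
    for (int o = 0; o < a.length; o++)
        if (max < a[o])
            max = a[o];
    if (max > 1)
        for (int o = 0; o < a.length; o++)
            a[o] /= max;
}

A second change threads a shared Random through Layer.init instead of constructing a fresh MersenneTwisterRNG inside each layer; see the notes below.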
@@ -22,17 +22,17 @@

import javax.swing.SwingUtilities;

-public class Histogram extends LineChart {
+public class Histograms extends LineChart {
private static final int SLICES = 64;

-private static final ArrayList<Histogram> _instances = new ArrayList<Histogram>();
+private static final ArrayList<Histograms> _instances = new ArrayList<Histograms>();
private static final ScheduledExecutorService _executor = Executors.newSingleThreadScheduledExecutor();
private static CheckBox _auto;

private final float[] _data;
private final ObservableList<Data<Float, Float>> _list = FXCollections.observableArrayList();

-static void init() {
+public static void init() {
final CountDownLatch latch = new CountDownLatch(1);
SwingUtilities.invokeLater(new Runnable() {
public void run() {
@@ -57,10 +57,10 @@ public static void build(final Layer[] ls) {
VBox v = new VBox();
for( int i = ls.length - 1; i > 0; i-- ) {
HBox h = new HBox();
h.getChildren().add(new Histogram("Layer " + i + " W", ls[i]._w));
h.getChildren().add(new Histogram("B", ls[i]._b));
h.getChildren().add(new Histogram("A", ls[i]._a));
h.getChildren().add(new Histogram("E", ls[i]._e));
h.getChildren().add(new Histograms("Layer " + i + " W", ls[i]._w));
h.getChildren().add(new Histograms("B", ls[i]._b));
h.getChildren().add(new Histograms("A", ls[i]._a));
h.getChildren().add(new Histograms("E", ls[i]._e));
v.getChildren().add(h);
}
Stage stage = new Stage();
@@ -103,7 +103,7 @@ public void changed(ObservableValue<? extends Boolean> ov, Boolean old_val, Bool
});
}

-public Histogram(String title, float[] data) {
+public Histograms(String title, float[] data) {
super(new NumberAxis(), new NumberAxis());
_data = data;

@@ -119,7 +119,7 @@ public Histogram(String title, float[] data) {
}

static void refresh() {
-for( Histogram h : _instances ) {
+for( Histograms h : _instances ) {
if( h._data != null ) {
float[] data = h._data.clone();
float min = Float.MAX_VALUE, max = Float.MIN_VALUE;
10 changes: 5 additions & 5 deletions experiments/src/main/java/hex/MnistCanvas.java
@@ -43,8 +43,8 @@ public JPanel init() {
});
bar.add(new JButton("histo") {
@Override protected void fireActionPerformed(ActionEvent event) {
-Histogram.initFromSwingThread();
-Histogram.build(_trainer.layers());
+Histograms.initFromSwingThread();
+Histograms.build(_trainer.layers());
}
});
JPanel pane = new JPanel();
@@ -135,9 +135,9 @@ public JPanel init() {
buf += pad + edge;
}

-int[] start = new int[layer._in._a.length];
-for( int i = 0; i < layer._in._a.length; i++ ) {
-double w = layer._w[o * layer._in._a.length + i];
+int[] start = new int[layer._previous._a.length];
+for( int i = 0; i < layer._previous._a.length; i++ ) {
+double w = layer._w[o * layer._previous._a.length + i];
w = ((w - mean) / sigma) * 200;
if( w >= 0 )
start[i] = ((int) Math.min(+w, 255)) << 8;
49 changes: 31 additions & 18 deletions src/main/java/hex/Layer.java
@@ -63,10 +63,10 @@ static abstract class Training {
transient Training _training;

public final void init(Layer[] ls, int index) {
-init(ls, index, true, 0);
+init(ls, index, true, 0, new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS));
}

-public void init(Layer[] ls, int index, boolean weights, long step) {
+public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
_a = new float[units];
_e = new float[units];
_previous = ls[index - 1];
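Threading the generator through init replaces the per-layer new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS) calls removed further down. Since SEEDS is a fixed seed constant, each layer previously drew an identical pseudo-random stream, so layers of equal shape started with identical weights. A minimal sketch of the difference, using java.util.Random to stand in for MersenneTwisterRNG:

import java.util.Random;

public class SharedRngDemo {
    public static void main(String[] args) {
        // Before: each layer re-seeded its own generator from the same constant,
        // so two equally-sized layers drew identical initial weights.
        Random layerA = new Random(42), layerB = new Random(42);
        System.out.println(layerA.nextFloat() == layerB.nextFloat()); // true

        // After: one seeded generator is shared across layers; runs stay
        // reproducible, but each layer consumes a distinct slice of the stream.
        Random shared = new Random(42);
        System.out.println(shared.nextFloat() == shared.nextFloat()); // false
    }
}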
@@ -187,7 +187,7 @@ private final void adjust(int i, float[] w, float[] prev, float[] init, float[]
public static abstract class Input extends Layer {
protected long _pos, _len;

-@Override public void init(Layer[] ls, int index, boolean weights, long step) {
+@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
_a = new float[units];
}

@@ -572,11 +572,10 @@ public Tanh(int units) {
this.units = units;
}

-@Override public void init(Layer[] ls, int index, boolean weights, long step) {
-super.init(ls, index, weights, step);
+@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
+super.init(ls, index, weights, step, rand);
if( weights ) {
// C.f. deeplearning.net tutorial
-Random rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
float min = (float) -Math.sqrt(6. / (_previous.units + units));
float max = (float) +Math.sqrt(6. / (_previous.units + units));
for( int i = 0; i < _w.length; i++ )
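The ±sqrt(6 / (fan-in + fan-out)) bounds kept here are the uniform Glorot/Xavier range from the deeplearning.net tutorial cited in the comment. The fill loop is truncated in this view; assuming a uniform draw between the bounds, it presumably has the shape:

// Assumed shape of the truncated loop: uniform Glorot/Xavier initialization.
static void glorotUniform(float[] w, int fanIn, int fanOut, java.util.Random rand) {
    float bound = (float) Math.sqrt(6. / (fanIn + fanOut));
    for (int i = 0; i < w.length; i++)
        w[i] = -bound + rand.nextFloat() * 2 * bound; // in [-bound, +bound)
}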
@@ -622,8 +621,8 @@ public TanhPrime(int units) {
this.units = units;
}

-@Override public void init(Layer[] ls, int index, boolean weights, long step) {
-super.init(ls, index, weights, step);
+@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
+super.init(ls, index, weights, step, rand);
// Auto encoder has its own bias vector
_b = new float[units];
}
@@ -667,10 +666,9 @@ public Maxout(int units) {
this.units = units;
}

-@Override public void init(Layer[] ls, int index, boolean weights, long step) {
-super.init(ls, index, weights, step);
+@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
+super.init(ls, index, weights, step, rand);
if( weights ) {
-// Random rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
// int count = Math.min(15, _previous.units);
// //float min = -.1f, max = +.1f;
// float min = -1f, max = +1f;
@@ -681,7 +679,6 @@
// _w[w] = rand(rand, min, max);
// }
// }
-Random rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
float min = (float) -Math.sqrt(6. / (_previous.units + units));
float max = (float) +Math.sqrt(6. / (_previous.units + units));
for( int i = 0; i < _w.length; i++ )
@@ -697,6 +694,7 @@
_bits = new byte[units / 8 + 1];
}
_rand.nextBytes(_bits);
+float max = 0;
for( int o = 0; o < _a.length; o++ ) {
_a[o] = 0;
boolean b = (_bits[o >> 3] & (1 << o)) != 0;
@@ -707,8 +705,13 @@
_a[o] += _b[o];
if( !training )
_a[o] *= .5f;
+if( max < _a[o] )
+max = _a[o];
}
}
+if( max > 1 )
+for( int o = 0; o < _a.length; o++ )
+_a[o] /= max;
}

@Override protected void bprop() {
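The _a[o] *= .5f branch above is the usual test-time correction for dropout with drop probability 1/2: during training half the units are zeroed by the bit mask, so at evaluation every unit fires and its activation is halved to match the training-time expectation. In isolation (a sketch, not repository code):

// Dropout with p(drop) = 0.5: zero dropped units while training,
// halve all units at test time so the expected output matches training.
static float dropoutScale(float activation, boolean training, boolean keep) {
    if (training)
        return keep ? activation : 0f;
    return activation * 0.5f;
}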
@@ -732,10 +735,9 @@ public Rectifier(int units) {
this.units = units;
}

-@Override public void init(Layer[] ls, int index, boolean weights, long step) {
-super.init(ls, index, weights, step);
+@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
+super.init(ls, index, weights, step, rand);
if( weights ) {
-// Random rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
// int count = Math.min(15, _previous.units);
// float min = -.1f, max = +.1f;
// //float min = -1f, max = +1f;
@@ -746,7 +748,6 @@
// _w[w] = rand(rand, min, max);
// }
// }
-Random rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
float min = (float) -Math.sqrt(6. / (_previous.units + units));
float max = (float) +Math.sqrt(6. / (_previous.units + units));
for( int i = 0; i < _w.length; i++ )
@@ -758,14 +759,20 @@
}

@Override protected void fprop(boolean training) {
+float max = 0;
for( int o = 0; o < _a.length; o++ ) {
_a[o] = 0;
for( int i = 0; i < _previous._a.length; i++ )
_a[o] += _w[o * _previous._a.length + i] * _previous._a[i];
_a[o] += _b[o];
if( _a[o] < 0 )
_a[o] = 0;
+if( max < _a[o] )
+max = _a[o];
}
+if( max > 1 )
+for( int o = 0; o < _a.length; o++ )
+_a[o] /= max;
}

@Override protected void bprop() {
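Because every unit is divided by the same layer maximum, the clamp preserves ordering and relative magnitudes; it only bounds the dynamic range fed to the next layer. A worked example of the added lines, with hypothetical activations:

// Hypothetical ReLU outputs: max is 4 > 1, so the layer rescales
// to { 0.125, 1.0, 0.5 }; ordering and ratios are preserved.
float[] a = { 0.5f, 4.0f, 2.0f };
float max = 0;
for (int o = 0; o < a.length; o++)
    if (max < a[o])
        max = a[o];
if (max > 1)
    for (int o = 0; o < a.length; o++)
        a[o] /= max;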
@@ -799,6 +806,7 @@ public RectifierDropout(int units) {
_bits = new byte[units / 8 + 1];
}
_rand.nextBytes(_bits);
+float max = 0;
for( int o = 0; o < _a.length; o++ ) {
_a[o] = 0;
boolean b = (_bits[o >> 3] & (1 << o)) != 0;
@@ -810,8 +818,13 @@
_a[o] = 0;
else if( !training )
_a[o] *= .5f;
+if( max < _a[o] )
+max = _a[o];
}
}
+if( max > 1 )
+for( int o = 0; o < _a.length; o++ )
+_a[o] /= max;
}
}

@@ -823,8 +836,8 @@ public RectifierPrime(int units) {
this.units = units;
}

-@Override public void init(Layer[] ls, int index, boolean weights, long step) {
-super.init(ls, index, weights, step);
+@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
+super.init(ls, index, weights, step, rand);
// Auto encoder has its own bias vector
_b = new float[units];
for( int i = 0; i < _b.length; i++ )
5 changes: 2 additions & 3 deletions src/main/java/hex/NeuralNet.java
@@ -230,7 +230,7 @@ public static Errors eval(Layer[] ls, Input input, Output output, long n, long[]
clones[y] = ls[y].clone();
clones[clones.length - 1] = output;
for( int y = 0; y < clones.length; y++ )
-clones[y].init(clones, y, false, 0);
+clones[y].init(clones, y, false, 0, null);
Layer.shareWeights(ls, clones);
return eval(clones, n, cm);
}
@@ -278,7 +278,6 @@ private static boolean correct(Layer[] ls, Errors e, long[][] confusion) {
float max = out[0];
int idx = 0;
for( int o = 1; o < out.length; o++ ) {
-assert !Double.isNaN(out[o]);
if( out[o] > max ) {
max = out[o];
idx = o;
@@ -406,7 +405,7 @@ public static class NeuralNetModel extends Model {
for( int y = 0; y < clones.length; y++ ) {
clones[y]._w = weights[y];
clones[y]._b = biases[y];
-clones[y].init(clones, y, false, 0);
+clones[y].init(clones, y, false, 0, null);
}
((Input) clones[0])._pos = rowInChunk;
for( int i = 0; i < clones.length; i++ )
4 changes: 2 additions & 2 deletions src/main/java/hex/Trainer.java
@@ -164,7 +164,7 @@ public Threaded(Layer[] ls, double epochs, final Key job) {
for( int y = 0; y < clones.length; y++ )
clones[y] = ls[y].clone();
for( int y = 0; y < clones.length; y++ ) {
-clones[y].init(clones, y, false, 0);
+clones[y].init(clones, y, false, 0, null);
clones[y]._training = new Training() {
@Override long processed() {
return _processed.get();
@@ -431,7 +431,7 @@ static class DescentChunk extends NodeTask {
else
clones[clones.length - 1] = new ChunkLinear(_cs[_cs.length - 1], (VecLinear) output);
for( int y = 0; y < clones.length; y++ ) {
-clones[y].init(clones, y, false, _node._total);
+clones[y].init(clones, y, false, _node._total, null);
clones[y]._w = _node._ws[y];
clones[y]._b = _node._bs[y];
clones[y]._wm = _node._wm[y];