Commit

Fixed RectifierDropout bug, finished per-weight acceleration

cypof committed Dec 6, 2013
1 parent 897b315 commit 629776c

Showing 8 changed files with 69 additions and 276 deletions.
10 changes: 5 additions & 5 deletions experiments/src/main/java/hex/Histograms.java
@@ -1,7 +1,5 @@
package hex;

import hex.Layer;

import java.util.ArrayList;
import java.util.concurrent.*;

@@ -58,9 +56,11 @@ public static void build(final Layer[] ls) {
for( int i = ls.length - 1; i > 0; i-- ) {
HBox h = new HBox();
h.getChildren().add(new Histograms("Layer " + i + " W", ls[i]._w));
h.getChildren().add(new Histograms("B", ls[i]._b));
h.getChildren().add(new Histograms("A", ls[i]._a));
h.getChildren().add(new Histograms("E", ls[i]._e));
h.getChildren().add(new Histograms("Bias", ls[i]._b));
h.getChildren().add(new Histograms("Activity", ls[i]._a));
h.getChildren().add(new Histograms("Error", ls[i]._e));
h.getChildren().add(new Histograms("Momentum", ls[i]._wm));
h.getChildren().add(new Histograms("Per weight", ls[i]._wp));
v.getChildren().add(h);
}
Stage stage = new Stage();
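Each Histograms node bins one of the layer's float arrays (weights, biases, activations, errors, and now the momentum and per-weight buffers) into counts for display. A minimal, self-contained sketch of that binning step, assuming a fixed bin count and a hypothetical helper name rather than the actual Histograms internals:

// Hypothetical helper: bin a layer array into counts for a histogram view.
static int[] bin(float[] values, int bins) {
  float min = Float.POSITIVE_INFINITY, max = Float.NEGATIVE_INFINITY;
  for( float v : values ) {
    if( v < min ) min = v;
    if( v > max ) max = v;
  }
  int[] counts = new int[bins];
  float span = Math.max(max - min, Float.MIN_NORMAL);
  for( float v : values ) {
    int b = (int) ((v - min) / span * (bins - 1));
    counts[b]++;
  }
  return counts;
}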
144 changes: 53 additions & 91 deletions src/main/java/hex/Layer.java
@@ -1,7 +1,6 @@
package hex;

import hex.rng.MersenneTwisterRNG;
import hex.rng.XorShiftRNG;

import java.util.Random;

@@ -45,12 +44,12 @@ public abstract class Layer extends Iced {
@API(help = "Momentum value once ramp is over")
public float momentum_stable;

// TODO
public float _perWeight;
public float _perWeightAnnealing;
@API(help = "Per-weight rates")
public boolean per_weight;

// Weights, biases, activity, error
// TODO hold transients only for current two layers
// TODO extract transients & code in separate one-shot trees to avoid cloning
protected transient float[] _w, _b, _a, _e;

// Momentum for weights and biases
@@ -60,8 +59,8 @@ public abstract class Layer extends Iced {
protected transient float[] _wp, _bp;

// Previous and input layers
transient Layer _previous;
transient Input _input;
protected transient Layer _previous;
protected transient Input _input;

/**
* Start of refactoring in specification & running data, for layers and trainers.
@@ -89,21 +88,14 @@ public void init(Layer[] ls, int index, boolean weights, long step, Random rand)
_wm = new float[_w.length];
_bm = new float[_b.length];
}
}

if( _perWeight != 0 ) {
// _wInit = new float[_w.length];
// _wMult = new float[_w.length];
// for( int i = 0; i < _w.length; i++ ) {
// _wInit[i] = _w[i];
// _wMult[i] = 1;
// }
// _bInit = new float[_b.length];
// _bMult = new float[_b.length];
// for( int i = 0; i < _b.length; i++ ) {
// _bInit[i] = _b[i];
// _bMult[i] = 1;
// }
if( per_weight ) {
_wp = new float[_w.length];
_bp = new float[_b.length];
for( int i = 0; i < _wp.length; i++ )
_wp[i] = 1;
for( int i = 0; i < _bp.length; i++ )
_bp[i] = 1;
}
}
}

@@ -124,6 +116,21 @@ protected final void bprop(int u, float g, float r, float m) {
if( _previous._e != null )
_previous._e[i] += g * _w[w];
float d = g * _previous._a[i] - _w[w] * l2 - Math.signum(_w[w]) * l1;
if( _wp != null && d != 0 ) {
boolean sign = _wp[w] >= 0;
float mult = Math.abs(_wp[w]);
// If the gradient kept its sign, increase
if( (d >= 0) == sign )
mult += .05f;
else {
if( mult > 1 )
mult *= .95f;
else
sign = !sign;
}
d *= mult;
_wp[w] = sign ? mult : -mult;
}
if( _wm != null ) {
_wm[w] *= m;
_wm[w] = d = _wm[w] + d;
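The block above is the finished per-weight acceleration: _wp holds one signed multiplier per weight, the sign remembering the direction of the previous update and the magnitude acting as an adaptive rate factor. When the new delta agrees with the stored sign the factor grows additively (+.05f); on a sign change it either decays (×.95f) or, once at or below 1, simply flips sign. A standalone sketch of the same rule, outside the Layer class (illustrative names, not the H2O API):

// Per-weight rate adaptation, delta-bar-delta style, mirroring Layer.bprop above.
static float accelerate(float[] perWeight, int w, float d) {
  if( d == 0 )
    return 0;
  boolean sign = perWeight[w] >= 0;     // direction of the previous update
  float mult = Math.abs(perWeight[w]);  // current per-weight factor, starts at 1
  if( (d >= 0) == sign )
    mult += .05f;                       // gradient kept its sign: accelerate
  else if( mult > 1 )
    mult *= .95f;                       // sign flipped: decelerate
  else
    sign = !sign;                       // factor already small: just flip direction
  perWeight[w] = sign ? mult : -mult;
  return d * mult;                      // scaled update, before momentum is applied
}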
@@ -161,39 +168,6 @@ public float momentum(long n) {
return m;
}

private final void adjust(int i, float[] w, float[] prev, float[] init, float[] mult) {
float coef = 1;

if( init != null ) {
float g = w[i] - init[i];
boolean sign = g > 0;
boolean last = mult[i] > 0;
coef = Math.abs(mult[i]);
// If the gradient kept its sign, increase
if( sign == last ) {
if( coef < 4 )
coef += _perWeight;
} else
coef *= 1 - _perWeight;
mult[i] = sign ? coef : -coef;
w[i] = init[i] + coef * g;
}

if( prev != null ) {
// Nesterov's Accelerated Gradient
// float v = (w[i] - prev[i]) * _m;
// prev[i] = w[i];
// w[i] += coef * v;
// if( w == _w )
// _wSpeed[i] = v;
// else
// _bSpeed[i] = v;
}

if( init != null )
init[i] = w[i];
}

public static abstract class Input extends Layer {
protected long _pos, _len;

@@ -258,8 +232,8 @@ public VecsInput(Vec[] vecs, VecsInput train) {
}

static int categories(Vec vec) {
String [] dom = vec.domain();
return dom == null?1:dom.length-1;
String[] dom = vec.domain();
return dom == null ? 1 : dom.length - 1;
}

static int expand(Vec[] vecs) {
@@ -497,9 +471,9 @@ public ChunkSoftmax(Chunk chunk, VecSoftmax stats) {
momentum_start = stats.momentum_start;
momentum_stable = stats.momentum_stable;
momentum_ramp = stats.momentum_ramp;
_perWeight = stats._perWeight;
_perWeightAnnealing = stats._perWeightAnnealing;
l1 = stats.l1;
l2 = stats.l2;
per_weight = stats.per_weight;
}

@Override protected int target() {
@@ -569,9 +543,9 @@ public ChunkLinear(Chunk chunk, VecLinear stats) {
momentum_start = stats.momentum_start;
momentum_stable = stats.momentum_stable;
momentum_ramp = stats.momentum_ramp;
_perWeight = stats._perWeight;
_perWeightAnnealing = stats._perWeightAnnealing;
l1 = stats.l1;
l2 = stats.l2;
per_weight = stats.per_weight;
}

@Override float[] target() {
@@ -758,20 +732,20 @@ public Rectifier(int units) {
@Override public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
super.init(ls, index, weights, step, rand);
if( weights ) {
int count = Math.min(15, _previous.units);
float min = -.1f, max = +.1f;
//float min = -1f, max = +1f;
for( int o = 0; o < units; o++ ) {
for( int n = 0; n < count; n++ ) {
int i = rand.nextInt(_previous.units);
int w = o * _previous.units + i;
_w[w] = rand(rand, min, max);
}
}
// float min = (float) -Math.sqrt(6. / (_previous.units + units));
// float max = (float) +Math.sqrt(6. / (_previous.units + units));
// for( int i = 0; i < _w.length; i++ )
// _w[i] = rand(rand, min, max);
// int count = Math.min(15, _previous.units);
// float min = -.1f, max = +.1f;
// //float min = -1f, max = +1f;
// for( int o = 0; o < units; o++ ) {
// for( int n = 0; n < count; n++ ) {
// int i = rand.nextInt(_previous.units);
// int w = o * _previous.units + i;
// _w[w] = rand(rand, min, max);
// }
// }
float min = (float) -Math.sqrt(6. / (_previous.units + units));
float max = (float) +Math.sqrt(6. / (_previous.units + units));
for( int i = 0; i < _w.length; i++ )
_w[i] = rand(rand, min, max);

// for( int i = 0; i < _w.length; i++ )
// _w[i] = rand(rand, -.01f, .01f);
@@ -781,20 +755,14 @@ public Rectifier(int units) {
}
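The sparse ±0.1 initialization is replaced by a uniform range of ±sqrt(6 / (fan_in + fan_out)), which matches the Glorot/Xavier uniform heuristic. A self-contained sketch of that initializer (helper name is illustrative):

import java.util.Random;

// Glorot/Xavier uniform initialization: each weight drawn from U(-b, +b)
// with b = sqrt(6 / (fanIn + fanOut)), as in the init above.
static void glorotUniform(float[] w, int fanIn, int fanOut, Random rand) {
  float bound = (float) Math.sqrt(6. / (fanIn + fanOut));
  for( int i = 0; i < w.length; i++ )
    w[i] = rand.nextFloat() * 2 * bound - bound;
}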

@Override protected void fprop(boolean training) {
float max = 0;
for( int o = 0; o < _a.length; o++ ) {
_a[o] = 0;
for( int i = 0; i < _previous._a.length; i++ )
_a[o] += _w[o * _previous._a.length + i] * _previous._a[i];
_a[o] += _b[o];
if( _a[o] < 0 )
_a[o] = 0;
if( max < _a[o] )
max = _a[o];
}
if( max > 1 )
for( int o = 0; o < _a.length; o++ )
_a[o] /= max;
}
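With the activation max-normalization lines apparently dropped by this hunk, the forward pass is a plain rectifier: each output is max(0, W·x + b). A compact standalone sketch (illustrative names):

// Plain rectifier forward pass: a[o] = max(0, sum_i w[o*n + i] * x[i] + b[o]).
static void reluForward(float[] a, float[] w, float[] b, float[] x) {
  for( int o = 0; o < a.length; o++ ) {
    float sum = b[o];
    for( int i = 0; i < x.length; i++ )
      sum += w[o * x.length + i] * x[i];
    a[o] = sum > 0 ? sum : 0;
  }
}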

@Override protected void bprop() {
@@ -803,7 +771,6 @@ public Rectifier(int units) {
float r = rate(processed) * (1 - m);
for( int u = 0; u < _a.length; u++ ) {
float g = _e[u];
int todo_test_ge;
if( _a[u] > 0 )
bprop(u, g, r, m);
}
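The backward pass propagates error only through units that were active, since the derivative of max(0, x) is 1 for x > 0 and 0 otherwise. A minimal sketch of that gating (illustrative, not the H2O API):

// Rectifier backward gate: error flows only through units that fired.
static void reluBackwardGate(float[] a, float[] e) {
  for( int u = 0; u < a.length; u++ )
    if( a[u] <= 0 )
      e[u] = 0;  // inactive unit contributes no gradient
}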
@@ -823,15 +790,13 @@ public RectifierDropout(int units) {

@Override protected void fprop(boolean training) {
if( _rand == null ) {
//_rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
_rand = new XorShiftRNG(0x9afa938d554f4e76L);
_bits = new byte[units / 8 + 1];
_rand = new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS);
_bits = new byte[(units + 7) / 8];
}
_rand.nextBytes(_bits);
float max = 0;
for( int o = 0; o < _a.length; o++ ) {
_a[o] = 0;
boolean b = (_bits[o >> 3] & (1 << o)) != 0;
boolean b = (_bits[o / 8] & (1 << (o % 8))) != 0;
if( !training || b ) {
for( int i = 0; i < _previous._a.length; i++ )
_a[o] += _w[o * _previous._a.length + i] * _previous._a[i];
@@ -840,13 +805,8 @@ public RectifierDropout(int units) {
_a[o] = 0;
else if( !training )
_a[o] *= .5f;
if( max < _a[o] )
max = _a[o];
}
}
if( max > 1 )
for( int o = 0; o < _a.length; o++ )
_a[o] /= max;
}
}
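This is the RectifierDropout fix from the commit title. The old mask test (_bits[o >> 3] & (1 << o)) shifts by the raw unit index, but a Java int shift only uses the low five bits while a byte holds eight, so units past the first byte tested the wrong bit (often just sign extension). The new test indexes the byte with o / 8 and the bit with o % 8, and the mask array is now sized with the exact ceiling (units + 7) / 8. A standalone sketch of the corrected decision, applied to an already computed activation for brevity (illustrative names; units are kept with probability ~0.5 during training and halved at test time to compensate):

import java.util.Random;

// Corrected dropout mask, mirroring the fixed RectifierDropout.fprop.
static byte[] newMask(int units, Random rand) {
  byte[] bits = new byte[(units + 7) / 8];  // exact ceiling, one bit per unit
  rand.nextBytes(bits);                     // each bit ~ Bernoulli(0.5)
  return bits;
}

static float dropout(float activation, byte[] bits, int unit, boolean training) {
  boolean keep = (bits[unit / 8] & (1 << (unit % 8))) != 0;
  if( training )
    return keep ? activation : 0;  // drop roughly half the units while training
  return activation * .5f;         // test time: keep all, scale by the keep rate
}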

@@ -907,6 +867,8 @@ public static void shareWeights(Layer src, Layer dst) {
dst._b = src._b;
dst._wm = src._wm;
dst._bm = src._bm;
dst._wp = src._wp;
dst._bp = src._bp;
}

public static void shareWeights(Layer[] src, Layer[] dst) {
80 changes: 0 additions & 80 deletions src/main/java/hex/Plot.java

This file was deleted.
