Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
arnocandel committed Jan 8, 2014
1 parent 9c05ffe commit 775b7b1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ public static void main(String[] args) throws Exception {
ls[i].momentum_stable = .99f;
//ls[i].l1 = .005f;
ls[i].init(ls, i);
if (i>=1) ls[i].randomize(new java.util.Random(), 1.0f);
}
return ls;
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/hex/Layer.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public final void init(Layer[] ls, int index) {
init(ls, index, true, 0, new MersenneTwisterRNG(MersenneTwisterRNG.SEEDS));
}

public void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
protected void init(Layer[] ls, int index, boolean weights, long step, Random rand) {
_a = new float[units];
_e = new float[units];
_previous = ls[index - 1];
Expand All @@ -121,6 +121,8 @@ public void init(Layer[] ls, int index, boolean weights, long step, Random rand)
*/
// cf. http://machinelearning.wustl.edu/mlpapers/paper_files/AISTATS2010_GlorotB10.pdf
public void randomize(Random rng, float prefactor) {
if (_w == null) return;

if (initial_weight_distribution == InitialWeightDistribution.UniformAdaptive) {
final float range = prefactor * (float)Math.sqrt(6. / (_previous.units + units));
for( int i = 0; i < _w.length; i++ )
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/hex/NeuralNet.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ void startTrain() {
case Tanh:
ls[i + 1] = new Layer.Tanh(hidden[i]);
break;
case TanhWithDropout:
ls[i + 1] = new Layer.TanhDropout(hidden[i]);
break;
case Rectifier:
ls[i + 1] = new Layer.Rectifier(hidden[i]);
break;
Expand All @@ -169,6 +172,7 @@ void startTrain() {
ls[i + 1].momentum_stable = (float) momentum_stable;
ls[i + 1].l1 = (float) l1;
ls[i + 1].l2 = (float) l2;
ls[i + 1].max_w2 = max_w2;
ls[i + 1].loss = loss;
}
if( classification )
Expand Down

0 comments on commit 775b7b1

Please sign in to comment.