Skip to content

Commit

Permalink
implemented nesterov trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
mkroutikov committed Oct 28, 2014
1 parent 3eba1b9 commit d763882
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions demo/trainers.html
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
trainer_defs.push({learning_rate:LR, method: 'adagrad', eps: 1e-6, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'windowgrad', eps: 1e-6, ro: 0.95, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:1.0, method: 'adadelta', eps: 1e-6, ro:0.95, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'nesterov', momentum: 0.9, batch_size:BS, l2_decay:L2});\n\
\n\
// names for all trainers above\n\
legend = ['sgd', 'sgd+momentum', 'adagrad', 'windowgrad', 'adadelta'];\n\
Expand Down
7 changes: 6 additions & 1 deletion src/convnet_trainers.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
this.l1_decay = typeof options.l1_decay !== 'undefined' ? options.l1_decay : 0.0;
this.l2_decay = typeof options.l2_decay !== 'undefined' ? options.l2_decay : 0.0;
this.batch_size = typeof options.batch_size !== 'undefined' ? options.batch_size : 1;
this.method = typeof options.method !== 'undefined' ? options.method : 'sgd'; // sgd/adagrad/adadelta/windowgrad
this.method = typeof options.method !== 'undefined' ? options.method : 'sgd'; // sgd/adagrad/adadelta/windowgrad/nesterov

this.momentum = typeof options.momentum !== 'undefined' ? options.momentum : 0.9;
this.ro = typeof options.ro !== 'undefined' ? options.ro : 0.95; // used in adadelta
Expand Down Expand Up @@ -99,6 +99,11 @@
var dx = - Math.sqrt((xsumi[j] + this.eps)/(gsumi[j] + this.eps)) * gij;
xsumi[j] = this.ro * xsumi[j] + (1-this.ro) * dx * dx; // yes, xsum lags behind gsum by 1.
p[j] += dx;
} else if(this.method === 'nesterov') {
var dx = gsumi[j];
gsumi[j] = gsumi[j] * this.momentum + this.learning_rate * gij;
dx = self.momentum * dx - (1.0 + this.momentum) * gsumi[j];
p[j] += dx;
} else {
// assume SGD
if(this.momentum > 0.0) {
Expand Down

0 comments on commit d763882

Please sign in to comment.