Skip to content

Commit

Permalink
changed SGD to direct update in test-lstm2
Browse files: browse the repository at this point in the history
  • Loading branch information
Tom committed Jan 29, 2016
1 parent 15a8aca commit ae9e80b
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions test-lstm2.cc
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,5 @@
// Test case for copying parameters and states in/out of the network.

// whether to run the copied test case
#define COPIED
// whether to run the direct test case
#undef DIRECT

#include <assert.h>
#include <math.h>
#include <iostream>
Expand Down Expand Up @@ -81,16 +76,16 @@ double test_net(Network net) {
// die(): report a fatal error and terminate the program.
// Writes "FATAL <file> <line>" to cerr, then calls abort(). The comma
// expression ends in `true` so that the macro has a boolean value and can
// appear on the right-hand side of `||`, e.g. `get_states(...) || die();`.
#define die() (cerr<<"FATAL "<<__FILE__<<" "<<__LINE__<<"\n",abort(),true)

int main(int argc, char **argv) {
auto factory = []{
float learning_rate = 0.01;
auto factory = [&]{
Network net = make_net("lstm1",
{{"ninput", 1}, {"nhidden", 4}, {"noutput", 2}, {"gpu", -1}});
net->setLearningRate(1e-1, 0.0);
net->setLearningRate(learning_rate, 0.0);
return net;
};
Network net = factory();
print("training 1:4:2 network to learn delay");
vector<float> states;
vector<float> weights;
Eigen::Tensor<float, 1> states, weights, derivs;
for (int i = 0; i < ntrain; i++) {
Sequence xs, ys;
gentest(xs, ys);
Expand All @@ -99,14 +94,13 @@ int main(int argc, char **argv) {
clear_derivs(net);
clear_state_derivs(net);

#ifdef COPIED
int nstates = n_states(net);
int nweights = n_params(net);
states.resize(nstates);
weights.resize(nweights);
derivs.resize(nweights);
get_states(net, states.data(), nstates) || die();
get_params(net, weights.data(), nweights) || die();
#endif

#ifdef DIRECT
set_targets(net, ys);
Expand All @@ -115,7 +109,6 @@ int main(int argc, char **argv) {
sgd_update(net);
#endif

#ifdef COPIED
net = factory();
set_states(net, states.data(), nstates) || die();
set_params(net, weights.data(), nweights) || die();
Expand All @@ -124,8 +117,12 @@ int main(int argc, char **argv) {
set_targets(net, ys);
net->backward();
if(i==0) {cerr<<"COPIED:\n";network_detail(net);}
sgd_update(net);
#endif

// perform stochastic gradient descent on
// the externalized weights instead of sgd_update(net)
get_derivs(net, derivs.data(), nweights);
weights = weights + derivs * Float(0.01);
set_params(net, weights.data(), nweights);
}
// network_detail(net);
double merr0 = test_net(net);
Expand Down

0 comments on commit ae9e80b

Please sign in to comment.