r395: layernorm working and converging faster
lh3 committed Feb 11, 2017
1 parent d3b3c39 commit 5541e7b
Showing 5 changed files with 89 additions and 51 deletions.
6 changes: 4 additions & 2 deletions examples/rnn-bit.c
@@ -86,7 +86,7 @@ static void train(kann_t *ann, bit_data_t *d, float lr, int mini_size, int max_e
for (epoch = 0; epoch < max_epoch; ++epoch) {
double cost = 0.0;
int tot = 0, n_cerr = 0;
for (j = 0; j < d->n; j += mini_size) {
for (j = 0; j < d->n - mini_size; j += mini_size) {
int i, b, k;
for (k = 0; k < d->ulen; ++k) {
for (b = 0; b < mini_size; ++b) {
@@ -98,6 +98,7 @@ static void train(kann_t *ann, bit_data_t *d, float lr, int mini_size, int max_e
}
cost += kann_cost(ua, 0, 1) * d->ulen * mini_size;
n_cerr += kann_class_error(ua);
//kad_check_grad(ua->n, ua->v, ua->n-1);
kann_RMSprop(n_var, lr, 0, 0.9f, ua->g, ua->x, r);
tot += d->ulen * mini_size;
}
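(A note on the new loop bound, under the assumption — consistent with the inner loops above — that each iteration always fills a full mini-batch of mini_size sequences: when d->n is not a multiple of mini_size, the old bound j < d->n would presumably let the final iteration index past the end of the data. For example, with d->n = 1000 and mini_size = 64 the last batch would start at j = 960 and need sequences 960..1023; with j < d->n - mini_size the last batch starts at j = 896 and the trailing remainder is simply skipped.)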
@@ -132,6 +133,7 @@ int main(int argc, char *argv[])
fprintf(stderr, "Usage: rnn-bit [options] <in.txt>\n");
return 1;
}
kad_trap_fe();
kann_srand(seed);
if (fn_in) ann = kann_load(fn_in);

@@ -142,7 +144,7 @@ int main(int argc, char *argv[])
kad_node_t *t;
t = kann_layer_input(d->n_in);
for (i = 0; i < n_h_layers; ++i)
t = kann_layer_gru(t, n_h_neurons, 1);
t = kann_layer_gru(t, n_h_neurons, KANN_RNN_VAR_H0|KANN_RNN_NORM);
ann = kann_new(kann_layer_cost(t, 2, KANN_C_CEM), 0);
}
train(ann, d, lr, mini_size, max_epoch, fn_out);
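A minimal sketch of the new flag argument, for reference: the third parameter of kann_layer_gru() is now a bit mask rather than the old var_h0 boolean, so the trainable initial state and layer normalization can be selected independently (the same flags apply to kann_layer_rnn() and kann_layer_lstm()):

	t = kann_layer_gru(t, n_h_neurons, 0);                              /* constant h0, no normalization */
	t = kann_layer_gru(t, n_h_neurons, KANN_RNN_VAR_H0);                /* pre-r395 behaviour: h0 is a trainable variable */
	t = kann_layer_gru(t, n_h_neurons, KANN_RNN_VAR_H0|KANN_RNN_NORM);  /* trainable h0 plus layer normalization, as used here */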
61 changes: 42 additions & 19 deletions kann.c
@@ -328,14 +328,18 @@ kad_node_t *kann_new_weight(int n_row, int n_col)
return w;
}

kad_node_t *kann_new_bias(int n)
kad_node_t *kann_new_vec(int n, float x)
{
kad_node_t *b;
int i;
b = kad_var(0, 0, 1, n);
b->x = (float*)calloc(n, sizeof(float));
for (i = 0; i < n; ++i) b->x[i] = x;
return b;
}

kad_node_t *kann_new_bias(int n) { return kann_new_vec(n, 0.0f); }

kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col)
{
kad_node_t *w;
@@ -381,86 +385,105 @@ kad_node_t *kann_layer_dropout(kad_node_t *t, float r)
return kad_switch(2, x);
}

kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int var_h0)
kad_node_t *kann_layer_layernorm(kad_node_t *in)
{
int n0;
kad_node_t *alpha, *beta;
n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
alpha = kann_new_vec(n0, 1.0f);
beta = kann_new_vec(n0, 0.0f);
return kad_add(kad_mul(kad_stdnorm(in), alpha), beta);
}

static kad_node_t *kann_cmul_norm(kad_node_t *x, kad_node_t *w)
{
return kann_layer_layernorm(kad_cmul(x, w));
}
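For reference, kann_layer_layernorm() above computes, for an input row x of length n0 (alpha initialized to all ones and beta to all zeros, so the layer starts out as plain standardization):

	\mu = \frac{1}{n_0} \sum_i x_i, \quad \sigma = \sqrt{\frac{1}{n_0} \sum_i (x_i - \mu)^2}, \quad y_i = \alpha_i \frac{x_i - \mu}{\sigma} + \beta_i

i.e. layer normalization in the sense of Ba et al. (2016), with kad_stdnorm() supplying the (x - mu)/sigma part and alpha/beta acting as the learned gain and bias.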

kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag)
{
int n0;
kad_node_t *h0, *w, *u, *b, *out;
kad_node_t *(*cmul)(kad_node_t*, kad_node_t*) = (rnn_flag & KANN_RNN_NORM)? kann_cmul_norm : kad_cmul;

n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
h0 = var_h0? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
h0->x = (float*)calloc(n1, sizeof(float));
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
out = kad_tanh(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
out = kad_tanh(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
out->pre = h0;
return out;
}
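With KANN_RNN_NORM set, cmul above is kann_cmul_norm(), so the recurrence being built is

	h_t = \tanh\big(\mathrm{LN}(W x_t) + \mathrm{LN}(U h_{t-1}) + b\big)

where each LN is a separate layer-normalization layer with its own gain and bias; without the flag it reduces to the plain tanh RNN h_t = tanh(W x_t + U h_{t-1} + b).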

kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int var_h0)
kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag)
{
int j, n0;
int n0;
kad_node_t *i, *f, *o, *g, *w, *u, *b, *h0, *c0, *c, *out;
kad_node_t *(*cmul)(kad_node_t*, kad_node_t*) = (rnn_flag & KANN_RNN_NORM)? kann_cmul_norm : kad_cmul;

n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
h0 = var_h0? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
h0->x = (float*)calloc(n1, sizeof(float));
c0 = var_h0? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
c0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
c0->x = (float*)calloc(n1, sizeof(float));

// i = sigm(x_t * W_i + h_{t-1} * U_i + b_i)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
i = kad_sigm(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
i = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
// f = sigm(x_t * W_f + h_{t-1} * U_f + b_f)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
for (j = 0; j < n1; ++j) b->x[j] = 1.0f; // see Jozefowicz et al on using a large bias
f = kad_sigm(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
b = kann_new_vec(n1, 1.0f); // see Jozefowicz et al on using a large bias
f = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
// o = sigm(x_t * W_o + h_{t-1} * U_o + b_o)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
o = kad_sigm(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
o = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
// g = tanh(x_t * W_g + h_{t-1} * U_g + b_g)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
g = kad_tanh(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
g = kad_tanh(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
// c_t = c_{t-1} # f + g # i
c = kad_add(kad_mul(f, c0), kad_mul(g, i)); // can't be kad_mul(c0, f)!!!
c->pre = c0;
// h_t = tanh(c_t) # o
if (rnn_flag & KANN_RNN_NORM) c = kann_layer_layernorm(c); // see Ba et al (2016) about how to apply layer normalization to LSTM
out = kad_mul(kad_tanh(c), o);
out->pre = h0;
return out;
}
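Spelled out, with KANN_RNN_NORM set the cell above follows the layer-normalized LSTM of Ba et al. (2016): every input-to-hidden and hidden-to-hidden product is normalized separately, and the cell state is normalized once more before the output tanh:

	i_t = \sigma\big(\mathrm{LN}(W_i x_t) + \mathrm{LN}(U_i h_{t-1}) + b_i\big)   (f_t, o_t analogously with \sigma, g_t with \tanh; b_f initialized to 1)
	c_t = f_t \odot c_{t-1} + i_t \odot g_t
	h_t = o_t \odot \tanh\big(\mathrm{LN}(c_t)\big)

Without the flag every LN is the identity and the standard LSTM is recovered.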

kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int var_h0)
kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag)
{
int n0;
kad_node_t *r, *z, *w, *u, *b, *s, *h0, *out;
kad_node_t *(*cmul)(kad_node_t*, kad_node_t*) = (rnn_flag & KANN_RNN_NORM)? kann_cmul_norm : kad_cmul;

n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
h0 = var_h0? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
h0->x = (float*)calloc(n1, sizeof(float));

// z = sigm(x_t * W_z + h_{t-1} * U_z + b_z)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
z = kad_sigm(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
z = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
// r = sigm(x_t * W_r + h_{t-1} * U_r + b_r)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
r = kad_sigm(kad_add(kad_add(kad_cmul(in, w), kad_cmul(h0, u)), b));
r = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
// s = tanh(x_t * W_s + (h_{t-1} # r) * U_s + b_s)
w = kann_new_weight(n1, n0);
u = kann_new_weight(n1, n1);
b = kann_new_bias(n1);
s = kad_tanh(kad_add(kad_add(kad_cmul(in, w), kad_cmul(kad_mul(r, h0), u)), b)); // can't be kad_mul(h0, r)!!!
s = kad_tanh(kad_add(kad_add(cmul(in, w), cmul(kad_mul(r, h0), u)), b)); // can't be kad_mul(h0, r)!!!
// h_t = z # h_{t-1} + (1 - z) # s
out = kad_add(kad_mul(kad_1minus(z), s), kad_mul(z, h0));
out->pre = h0;
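Likewise, the GRU graph above corresponds to (LN applied only when KANN_RNN_NORM is set):

	z_t = \sigma\big(\mathrm{LN}(W_z x_t) + \mathrm{LN}(U_z h_{t-1}) + b_z\big)
	r_t = \sigma\big(\mathrm{LN}(W_r x_t) + \mathrm{LN}(U_r h_{t-1}) + b_r\big)
	s_t = \tanh\big(\mathrm{LN}(W_s x_t) + \mathrm{LN}(U_s (r_t \odot h_{t-1})) + b_s\big)
	h_t = z_t \odot h_{t-1} + (1 - z_t) \odot s_t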
13 changes: 9 additions & 4 deletions kann.h
@@ -27,7 +27,7 @@
#ifndef KANN_H
#define KANN_H

#define KANN_VERSION "r390"
#define KANN_VERSION "r395"

// #define NO_ATOMIC_BUILTIN // use this for VC++

@@ -43,6 +43,10 @@

#define KANN_L_TEMP_INV (-1)

#define KANN_RNN_VAR_H0 0x1 // take the initial hidden values as variables
#define KANN_RNN_NORM 0x2 // apply layer normalization


#include "kautodiff.h"

typedef struct {
@@ -184,9 +188,10 @@ float kann_grad_clip(float thres, int n, float *g);
kad_node_t *kann_layer_input(int n1);
kad_node_t *kann_layer_linear(kad_node_t *in, int n1);
kad_node_t *kann_layer_dropout(kad_node_t *t, float r);
kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int var_h0);
kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int var_h0);
kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int var_h0);
kad_node_t *kann_layer_layernorm(kad_node_t *in);
kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag);
kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag);
kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag);
kad_node_t *kann_layer_conv2d(kad_node_t *in, int n_flt, int k_rows, int k_cols, int stride, int pad);
kad_node_t *kann_layer_max2d(kad_node_t *in, int k_rows, int k_cols, int stride, int pad);
kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type);
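A minimal usage sketch for the updated declarations (a hypothetical example, not part of this commit; n_in, n_out and the width 128 are placeholders):

	kad_node_t *t = kann_layer_input(n_in);
	t = kann_layer_lstm(t, 128, KANN_RNN_VAR_H0 | KANN_RNN_NORM); /* layer-normalized LSTM with trainable h0/c0 */
	t = kann_layer_dropout(t, 0.2f);
	kann_t *ann = kann_new(kann_layer_cost(t, n_out, KANN_C_CEM), 0);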
56 changes: 32 additions & 24 deletions kautodiff.c
@@ -138,7 +138,7 @@ KAD_FUNC_OP1(kad_tanh, 7)
KAD_FUNC_OP1(kad_relu, 8)
KAD_FUNC_OP1(kad_1minus, 11)
KAD_FUNC_OP1(kad_softmax, 14)
KAD_FUNC_OP1(kad_layernorm, 32)
KAD_FUNC_OP1(kad_stdnorm, 32)

/////////// Convolution ///////////

@@ -1514,34 +1514,42 @@ int kad_op_ce_multi(kad_node_t *p, int action)

/////////// Normalization ///////////

int kad_op_layernorm(kad_node_t *p, int action)
int kad_op_stdnorm(kad_node_t *p, int action)
{
static const float tiny = 0;
int i, n;
int i, j, n, m;
kad_node_t *q = p->child[0].p;
n = kad_len(q);
assert(q->n_d > 0);
if (q->n_d == 1) m = 1, n = kad_len(q);
else m = q->d[0], n = kad_len(q) / m;
if (action == KAD_SYNC_DIM) {
kad_sync_dim1(p, q);
} else if (action == KAD_ALLOC) {
p->child[0].t = malloc(1 * sizeof(float));
p->child[0].t = realloc(p->child[0].t, m * sizeof(float));
} else if (action == KAD_FORWARD) {
float avg, std_inv;
double s;
for (i = 0, s = 0.0; i < n; ++i) s += q->x[i];
avg = (float)(s / n);
for (i = 0; i < n; ++i) p->x[i] = q->x[i] - avg;
for (i = 0, s = 0.0; i < n; ++i) s += p->x[i] * p->x[i];
std_inv = (float)(1.0 / sqrt(s / n + tiny));
for (i = 0; i < n; ++i) p->x[i] *= std_inv;
*(float*)p->child[0].t = std_inv;
float *si = (float*)p->child[0].t;
for (j = 0; j < m; ++j) {
float *px = &p->x[j * n], *qx = &q->x[j * n];
float avg, std_inv;
double s;
for (i = 0, s = 0.0; i < n; ++i) s += qx[i];
avg = (float)(s / n);
for (i = 0; i < n; ++i) px[i] = qx[i] - avg;
for (i = 0, s = 0.0; i < n; ++i) s += px[i] * px[i];
std_inv = s == 0.0? 1.0f : (float)(1.0 / sqrt(s / n));
for (i = 0; i < n; ++i) px[i] *= std_inv;
si[j] = std_inv;
}
} else if (action == KAD_BACKWARD && kad_is_back(q)) {
float std_inv = *(float*)p->child[0].t;
double s, t;
for (i = 0, s = t = 0.0; i < n; ++i)
s += p->g[i], t += p->x[i] * p->g[i];
s /= n, t /= n;
for (i = 0; i < n; ++i)
q->g[i] += std_inv * (p->g[i] - s - p->x[i] * t);
float *si = (float*)p->child[0].t;
for (j = 0; j < m; ++j) {
float *pg = &p->g[j * n], *qg = &q->g[j * n], *px = &p->x[j * n], std_inv = si[j];
double s, t;
for (i = 0, s = t = 0.0; i < n; ++i)
s += pg[i], t += px[i] * pg[i];
s /= n, t /= n;
for (i = 0; i < n; ++i)
qg[i] += std_inv * (pg[i] - s - px[i] * t);
}
}
return 0;
}
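For reference, per row of length n the op implements (y is the standardized output stored in p->x, g the incoming gradient p->g, and 1/\sigma the per-row factor cached in p->child[0].t):

	y_i = \frac{x_i - \mu}{\sigma}, \quad \mu = \frac{1}{n} \sum_j x_j, \quad \sigma = \sqrt{\frac{1}{n} \sum_j (x_j - \mu)^2}   (1/\sigma forced to 1 for a constant row)
	\frac{\partial L}{\partial x_i} \mathrel{+}= \frac{1}{\sigma} \Big( g_i - \frac{1}{n} \sum_j g_j - y_i \frac{1}{n} \sum_j g_j y_j \Big)

which is the std_inv * (pg[i] - s - px[i] * t) accumulation in the backward branch; the main change from the previous version is that both passes now loop over the m rows of a mini-batch instead of treating the whole tensor as a single row, with the per-row 1/\sigma values kept in an m-element scratch buffer.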
@@ -2134,7 +2142,7 @@ kad_op_f kad_op_list[KAD_MAX_OP] = {
kad_op_mse, // 29: mean square error
kad_op_reshape, // 30
kad_op_concat, // 31
kad_op_layernorm // 32: layer normalization
kad_op_stdnorm // 32: layer normalization
};

/**************************
@@ -2152,7 +2160,7 @@ void kad_print_graph(FILE *fp, int n, kad_node_t **v)
{
static const char *op[] = { 0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "switch", "ce_multi", "softmax",
"dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
"avg1d", "mse", "reshape", "concat", "layernorm" };
"avg1d", "mse", "reshape", "concat", "stdnorm" };
int i, j;
for (i = 0; i < n; ++i) v[i]->tmp = i;
for (i = 0; i < n; ++i) {
4 changes: 2 additions & 2 deletions kautodiff.h
@@ -27,7 +27,7 @@
#ifndef KANN_AUTODIFF_H
#define KANN_AUTODIFF_H

#define KAD_VERSION "r393"
#define KAD_VERSION "r395"

#include <stdio.h>
#include <stdint.h>
@@ -196,7 +196,7 @@ kad_node_t *kad_softmax(kad_node_t *x);// f_i(x_1,...,x_n) = exp(x_i) / \sum_j e
kad_node_t *kad_1minus(kad_node_t *x); // f(x) = 1 - x
kad_node_t *kad_log(kad_node_t *x); // f(x) = log(x)

kad_node_t *kad_layernorm(kad_node_t *x); // layer normalization
kad_node_t *kad_stdnorm(kad_node_t *x); // layer normalization

// operators taking an indefinite number of operands (e.g. pooling)
kad_node_t *kad_avg(int n, kad_node_t **x); // f(x_1,...,x_n) = \sum_i x_i/n (mean pooling)
