r363: added layer normalization
Not absolutely sure if the implementation is correct. Tested it on a toy
example; it seems OK.
lh3 committed Feb 10, 2017
1 parent 8831f3c commit bcedf6c
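The commit message mentions a test on a toy example. A minimal standalone check in that spirit (an illustrative sketch, not the author's actual test; the file name, loss, and constants are made up) compares the analytic gradient used in kad_op_layernorm against central differences:

/* ln_check.c -- illustrative finite-difference check for the layer
 * normalization gradient. Build with: cc -o ln_check ln_check.c -lm */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define N 8

/* forward pass, same math as kad_op_layernorm with tiny == 0:
 * y = (x - mean(x)) / std(x); also returns 1/std */
static void ln_forward(const float *x, float *y, int n, float *std_inv)
{
    int i;
    double s;
    float avg;
    for (i = 0, s = 0.0; i < n; ++i) s += x[i];
    avg = (float)(s / n);
    for (i = 0; i < n; ++i) y[i] = x[i] - avg;
    for (i = 0, s = 0.0; i < n; ++i) s += y[i] * y[i];
    *std_inv = (float)(1.0 / sqrt(s / n));
    for (i = 0; i < n; ++i) y[i] *= *std_inv;
}

int main(void)
{
    /* toy loss L = sum_i c[i] * y[i], so dL/dy[i] = c[i] */
    float x[N], y[N], c[N], g[N], std_inv;
    double s, t;
    const double eps = 1e-3;
    int i, j;
    srand(11);
    for (i = 0; i < N; ++i) {
        x[i] = (float)rand() / RAND_MAX - 0.5f;
        c[i] = (float)rand() / RAND_MAX - 0.5f;
    }
    ln_forward(x, y, N, &std_inv);
    /* analytic gradient: same formula as the KAD_BACKWARD branch */
    for (i = 0, s = t = 0.0; i < N; ++i)
        s += c[i], t += y[i] * c[i];
    s /= N, t /= N;
    for (i = 0; i < N; ++i)
        g[i] = std_inv * (float)(c[i] - s - y[i] * t);
    /* numeric gradient by central differences; should match g[] closely */
    for (j = 0; j < N; ++j) {
        float xp[N], yp[N], si;
        double lp, lm;
        for (i = 0; i < N; ++i) xp[i] = x[i];
        xp[j] = (float)(x[j] + eps);
        ln_forward(xp, yp, N, &si);
        for (i = 0, lp = 0.0; i < N; ++i) lp += c[i] * yp[i];
        xp[j] = (float)(x[j] - eps);
        ln_forward(xp, yp, N, &si);
        for (i = 0, lm = 0.0; i < N; ++i) lm += c[i] * yp[i];
        printf("%d: analytic %g, numeric %g\n", j, g[j], (lp - lm) / (2.0 * eps));
    }
    return 0;
}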
Showing 2 changed files with 51 additions and 16 deletions.
kautodiff.c: 48 additions, 15 deletions
@@ -138,6 +138,7 @@ KAD_FUNC_OP1(kad_tanh, 7)
 KAD_FUNC_OP1(kad_relu, 8)
 KAD_FUNC_OP1(kad_1minus, 11)
 KAD_FUNC_OP1(kad_softmax, 14)
+KAD_FUNC_OP1(kad_layernorm, 32)

 /////////// Convolution ///////////

@@ -1511,6 +1512,40 @@ int kad_op_ce_multi(kad_node_t *p, int action)
     return 0;
 }

+/////////// Normalization ///////////
+
+int kad_op_layernorm(kad_node_t *p, int action)
+{
+    static const float tiny = 0;
+    int i, n;
+    kad_node_t *q = p->child[0].p;
+    n = kad_len(q);
+    if (action == KAD_SYNC_DIM) {
+        kad_sync_dim1(p, q);
+    } else if (action == KAD_ALLOC) {
+        p->child[0].t = malloc(1 * sizeof(float));
+    } else if (action == KAD_FORWARD) {
+        float avg, std_inv;
+        double s;
+        for (i = 0, s = 0.0; i < n; ++i) s += q->x[i];
+        avg = (float)(s / n);
+        for (i = 0; i < n; ++i) p->x[i] = q->x[i] - avg;
+        for (i = 0, s = 0.0; i < n; ++i) s += p->x[i] * p->x[i];
+        std_inv = (float)(1.0 / sqrt(s / n + tiny));
+        for (i = 0; i < n; ++i) p->x[i] *= std_inv;
+        *(float*)p->child[0].t = std_inv;
+    } else if (action == KAD_BACKWARD && kad_is_back(q)) {
+        float std_inv = *(float*)p->child[0].t;
+        double s, t;
+        for (i = 0, s = t = 0.0; i < n; ++i)
+            s += p->g[i], t += p->x[i] * p->g[i];
+        s /= n, t /= n;
+        for (i = 0; i < n; ++i)
+            q->g[i] += std_inv * (p->g[i] - s - p->x[i] * t);
+    }
+    return 0;
+}
+
 /////////// Activation functions ///////////

 int kad_op_sigm(kad_node_t *p, int action)
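For reference, with p->x holding the normalized output y, p->child[0].t caching 1/σ, and p->g the upstream gradient, the two passes of kad_op_layernorm compute (ε is the `tiny` constant, set to 0 here):

    y_i = \frac{x_i - \mu}{\sigma}, \qquad
    \mu = \frac{1}{n} \sum_{i=1}^{n} x_i, \qquad
    \sigma = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 + \epsilon}

    \frac{\partial L}{\partial x_j} = \frac{1}{\sigma} \left( g_j - \frac{1}{n} \sum_i g_i - y_j \cdot \frac{1}{n} \sum_i g_i y_i \right)

In the KAD_BACKWARD branch, s and t are exactly the two means above, so q->g[i] += std_inv * (p->g[i] - s - p->x[i] * t) is the standard layer normalization gradient (without a learned gain and bias, which this operator does not include).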
@@ -1523,10 +1558,9 @@ int kad_op_sigm(kad_node_t *p, int action)
     } else if (action == KAD_FORWARD) {
         for (i = 0; i < n; ++i)
             p->x[i] = 1.0f / (1.0f + expf(-q->x[i]));
-    } else if (action == KAD_BACKWARD) {
-        if (kad_is_back(q))
-            for (i = 0; i < n; ++i)
-                q->g[i] += p->g[i] * (p->x[i] * (1.0f - p->x[i]));
+    } else if (action == KAD_BACKWARD && kad_is_back(q)) {
+        for (i = 0; i < n; ++i)
+            q->g[i] += p->g[i] * (p->x[i] * (1.0f - p->x[i]));
     }
     return 0;
 }
@@ -1547,10 +1581,9 @@ int kad_op_tanh(kad_node_t *p, int action)
                 p->x[i] = (1.0f - y) / (1.0f + y);
             }
         }
-    } else if (action == KAD_BACKWARD) {
-        if (kad_is_back(q))
-            for (i = 0; i < n; ++i)
-                q->g[i] += p->g[i] * (1.0f - p->x[i] * p->x[i]);
+    } else if (action == KAD_BACKWARD && kad_is_back(q)) {
+        for (i = 0; i < n; ++i)
+            q->g[i] += p->g[i] * (1.0f - p->x[i] * p->x[i]);
     }
     return 0;
 }
@@ -1565,11 +1598,10 @@ int kad_op_relu(kad_node_t *p, int action)
     } else if (action == KAD_FORWARD) {
         for (i = 0; i < n; ++i)
             p->x[i] = q->x[i] > 0.0f? q->x[i] : 0.0f;
-    } else if (action == KAD_BACKWARD) {
-        if (kad_is_back(q))
-            for (i = 0; i < n; ++i)
-                if (q->x[i] > 0.0f)
-                    q->g[i] += p->g[i];
+    } else if (action == KAD_BACKWARD && kad_is_back(q)) {
+        for (i = 0; i < n; ++i)
+            if (q->x[i] > 0.0f)
+                q->g[i] += p->g[i];
     }
     return 0;
 }
@@ -2101,7 +2133,8 @@ kad_op_f kad_op_list[KAD_MAX_OP] = {
     kad_op_avg1d,     // 28: 1D average pooling (for 1D ConvNet)
     kad_op_mse,       // 29: mean square error
     kad_op_reshape,   // 30
-    kad_op_concat     // 31
+    kad_op_concat,    // 31
+    kad_op_layernorm  // 32: layer normalization
 };

 /**************************
@@ -2119,7 +2152,7 @@ void kad_print_graph(FILE *fp, int n, kad_node_t **v)
 {
     static const char *op[] = { 0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "switch", "ce_multi", "softmax",
         "dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
-        "avg1d", "mse", "reshape", "concat" };
+        "avg1d", "mse", "reshape", "concat", "layernorm" };
     int i, j;
     for (i = 0; i < n; ++i) v[i]->tmp = i;
     for (i = 0; i < n; ++i) {
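Taken together, the kautodiff.c hunks show everything a new unary operator needs inside the library. As a hypothetical template (the op id 32 and the names above come from this commit; kad_myop below is made up):

/* adding a unary operator with id K to kautodiff (illustrative checklist):
 *   KAD_FUNC_OP1(kad_myop, K)                  -- user-facing node constructor
 *   int kad_op_myop(kad_node_t *p, int action) -- handles KAD_SYNC_DIM,
 *                      KAD_ALLOC, KAD_FORWARD and KAD_BACKWARD as needed
 *   kad_op_list[K] entry pointing to kad_op_myop -- dispatch table
 *   "myop" appended to op[] in kad_print_graph   -- debug/print name
 *   kad_node_t *kad_myop(kad_node_t *x);         -- prototype in kautodiff.h */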
kautodiff.h: 3 additions, 1 deletion
@@ -27,7 +27,7 @@
 #ifndef KANN_AUTODIFF_H
 #define KANN_AUTODIFF_H

-#define KAD_VERSION "r392"
+#define KAD_VERSION "r393"

 #include <stdio.h>
 #include <stdint.h>
@@ -196,6 +196,8 @@ kad_node_t *kad_softmax(kad_node_t *x);// f_i(x_1,...,x_n) = exp(x_i) / \sum_j exp(x_j)
 kad_node_t *kad_1minus(kad_node_t *x); // f(x) = 1 - x
 kad_node_t *kad_log(kad_node_t *x);    // f(x) = log(x)

+kad_node_t *kad_layernorm(kad_node_t *x); // layer normalization
+
 // operators taking an indefinite number of operands (e.g. pooling)
 kad_node_t *kad_avg(int n, kad_node_t **x); // f(x_1,...,x_n) = \sum_i x_i/n (mean pooling)
 kad_node_t *kad_max(int n, kad_node_t **x); // f(x_1,...,x_n) = max{x_1,...,x_n} (max pooling)
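A sketch of how the new node might be wired into a graph. This assumes the existing kautodiff constructors kad_cmul, kad_add and kad_relu, and that x, w and b are kad_node_t* handles the caller has already created; it is illustrative, not code from the commit:

/* layer-normalize the pre-activation of a dense layer, then apply ReLU;
 * x (input), w (weights) and b (bias) are assumed to exist already */
kad_node_t *t;
t = kad_add(kad_cmul(x, w), b);  /* affine transform */
t = kad_relu(kad_layernorm(t));  /* normalize, then activate */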
