Skip to content

Commit

Permalink
r480: fast multi-threading working!
Browse files Browse the repository at this point in the history
  • Loading branch information
lh3 committed Feb 27, 2017

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 7ae16f2 commit 9ab2476
Showing 5 changed files with 119 additions and 96 deletions.
4 changes: 2 additions & 2 deletions examples/rnn-bit.c
Original file line number Diff line number Diff line change
@@ -81,7 +81,7 @@ static void train(kann_t *ann, bit_data_t *d, float lr, int mini_size, int max_e

ua = kann_unroll(ann, d->ulen);
kann_set_batch_size(ua, mini_size);
kann_set_mt(ua, n_threads, mini_size);
kann_mt(ua, n_threads, mini_size);
kann_feed_bind(ua, KANN_F_IN, 0, x);
kann_feed_bind(ua, KANN_F_TRUTH, 0, y);
kann_switch(ua, 1);
@@ -99,7 +99,7 @@ static void train(kann_t *ann, bit_data_t *d, float lr, int mini_size, int max_e
}
}
cost += kann_cost(ua, 0, 1) * d->ulen * mini_size;
// n_cerr += kann_class_error(ua);
n_cerr += kann_class_error(ua);
//kad_check_grad(ua->n, ua->v, ua->n-1);
kann_RMSprop(n_var, lr, 0, 0.9f, ua->g, ua->x, r);
tot += d->ulen * mini_size;
6 changes: 3 additions & 3 deletions examples/textgen.c
Original file line number Diff line number Diff line change
@@ -161,7 +161,7 @@ float tg_perplexity(kann_t *ann, const tg_data_t *tg)
return (float)exp(-loss / (tg->len - 1));
}

void tg_train(kann_t *ann, const tg_data_t *tg, float lr, int ulen, int mbs, int max_epoch, float grad_clip, const char *fn, int batch_len, int use_mini, int use_para)
void tg_train(kann_t *ann, const tg_data_t *tg, float lr, int ulen, int mbs, int max_epoch, float grad_clip, const char *fn, int batch_len, int use_mini, int use_para, int n_threads)
{
int i, epoch, k, n_var, n_char, real_mbs = use_mini? mbs : 1;
float **x, **y, *r, *g;
@@ -179,6 +179,7 @@ void tg_train(kann_t *ann, const tg_data_t *tg, float lr, int ulen, int mbs, int
g = (float*)calloc(n_var, sizeof(float));

ua = kann_unroll(ann, ulen);
kann_mt(ua, n_threads, mbs);
kann_switch(ua, 1);
kann_set_batch_size(ua, real_mbs);
kann_feed_bind(ua, KANN_F_IN, 0, x);
@@ -349,8 +350,7 @@ int main(int argc, char *argv[])
tg = tg_init(argv[optind]);
fprintf(stderr, "Read %d paragraphs and %d characters; alphabet size %d\n", tg->n_para, tg->len, tg->n_char);
if (!ann) ann = model_gen(model, tg->n_char, n_h_layers, n_h_neurons, h_dropout, use_norm);
if (n_threads > 1) kann_set_mt(ann, n_threads, mbs);
tg_train(ann, tg, lr, ulen, mbs, max_epoch, grad_clip, fn_out, batch_len, use_batch, use_para);
tg_train(ann, tg, lr, ulen, mbs, max_epoch, grad_clip, fn_out, batch_len, use_batch, use_para, n_threads);
free(tg->data); free(tg);
} else tg_gen(stdout, ann, temp, len_gen, c2i, prefix);

188 changes: 102 additions & 86 deletions kann.c
Original file line number Diff line number Diff line change
@@ -112,8 +112,7 @@ kann_t *kann_unroll(kann_t *a, int len)

void kann_delete_unrolled(kann_t *a)
{
// Free a network produced by kann_unroll(): tear down multi-threading
// state (a->mt) first, then the node array, then the struct itself.
// NOTE(review): the next three lines are the old and new sides of this
// commit's diff shown together — kann_mt(a, 0, 0) replaces the removed
// extern declaration and kann_mt_destroy(a) call.
extern void kann_mt_destroy(kann_t*);
if (a && a->mt) kann_mt_destroy(a);
if (a && a->mt) kann_mt(a, 0, 0);
if (a && a->v) kad_delete(a->n, a->v);
free(a);
}
@@ -163,7 +162,7 @@ int kann_feed_dim(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
return k == 1? n : k == 0? -1 : -2;
}

float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
static float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
{
int i_cost;
float cost;
@@ -174,29 +173,91 @@ float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
return cost;
}

int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label)
{
	int i, n_marked = 0;
	/* Flag every node whose external flag and label match the query. */
	for (i = 0; i < a->n; ++i) {
		kad_node_t *p = a->v[i];
		if (chk_flg(p->ext_flag, ext_flag) && chk_lbl(p->ext_label, ext_label)) {
			p->tmp = 1;
			++n_marked;
		}
	}
	/* Evaluate only the part of the graph the flagged nodes depend on. */
	kad_eval_marked(a->n, a->v);
	return n_marked; /* number of nodes that matched */
}

void kann_rnn_start(kann_t *a)
{
// Prepare the network for step-by-step (one sample at a time) application:
// force batch size 1, then alias each node's buffer with its p->pre
// counterpart so state carries over between successive calls.
// NOTE(review): p->pre presumably links a recurrent node to its state from
// the previous time step — confirm against kautodiff.
int i;
kann_set_batch_size(a, 1);
for (i = 0; i < a->n; ++i) {
kad_node_t *p = a->v[i];
if (p->pre) { // NB: BE CAREFUL of the interaction between kann_rnn_start() and kann_set_batch_size()
kad_node_t *q = p->pre;
if (q->x) memcpy(p->x, q->x, kad_len(p) * sizeof(float)); // carry over existing state values
else memset(p->x, 0, kad_len(p) * sizeof(float)); // no prior state: start from zeros
q->x = p->x; // alias: q and p now share one buffer, so p's output becomes q's next input
}
}
}

void kann_rnn_end(kann_t *a)
{
// End step-by-step application by re-syncing the graph nodes with the
// collated external arrays (values a->x, gradients a->g, constants a->c).
// NOTE(review): presumably this undoes the buffer aliasing installed by
// kann_rnn_start() — confirm against kad_ext_sync() in kautodiff.
kad_ext_sync(a->n, a->v, a->x, a->g, a->c);
}

static int kann_class_error_core(const kann_t *ann)
{
// Count classification errors for the current batch: for every cross-entropy
// cost node in the graph, compare the argmax of the prediction against the
// argmax of the truth, per sample. Returns the error count, or -1 when the
// truth values do not look like class labels (not a distribution in [0,1]
// summing to exactly 1).
int i, j, k, n, off, n_err = 0, is_class = 1;
for (i = 0; i < ann->n; ++i) {
kad_node_t *p = ann->v[i];
if ((p->op == 13 || p->op == 22) && p->n_child == 2 && p->n_d == 0) { // ce_bin or ce_multi
kad_node_t *x = p->child[0], *t = p->child[1]; // x: prediction; t: truth
n = kad_len(t) / t->d[0]; // values per sample; d[0] is the batch dimension
for (j = off = 0; j < t->d[0]; ++j, off += n) { // one pass per sample in the batch
float t_sum = 0.0f, t_min = 1.0f, t_max = 0.0f, x_max = 0.0f, x_min = 1.0f;
int x_max_k = -1, t_max_k = -1; // argmax indices for prediction and truth
for (k = 0; k < n; ++k) {
float xk = x->x[off+k], tk = t->x[off+k];
t_sum += tk;
t_min = t_min < tk? t_min : tk;
x_min = x_min < xk? x_min : xk;
if (t_max < tk) t_max = tk, t_max_k = k;
if (x_max < xk) x_max = xk, x_max_k = k;
}
// Exact float compare is deliberate here: one-hot labels sum to exactly 1.0f.
if (t_sum - 1.0f == 0 && t_min >= 0.0f && x_min >= 0.0f && x_max <= 1.0f)
n_err += (x_max_k != t_max_k); // misclassified if argmaxes differ
else is_class = 0; // truth is not a class distribution; result is meaningless
}
}
}
return is_class? n_err : -1;
}

/*************************
* @@MT: multi-threading *
*************************/

#ifdef HAVE_PTHREAD
#include <pthread.h>

struct mtaux_t;

typedef struct {
typedef struct { // per-worker data
kann_t *a;
float cost;
int action;
pthread_t tid;
struct mtaux_t *g;
} mtaux1_t;

typedef struct mtaux_t {
typedef struct mtaux_t { // cross-worker data
int n_threads, max_batch_size;
int cal_grad, cost_label;
volatile int n_idle;
volatile int n_idle; // we will be busy waiting on this, so volatile necessary
pthread_mutex_t mtx;
pthread_cond_t cv;
mtaux1_t *mt;
} mtaux_t;

static void *mt_worker(void *data)
static void *mt_worker(void *data) // pthread worker
{
mtaux1_t *mt1 = (mtaux1_t*)data;
mtaux_t *mt = mt1->g;
@@ -216,10 +277,32 @@ static void *mt_worker(void *data)
pthread_exit(0);
}

void kann_set_mt(kann_t *ann, int n_threads, int max_batch_size)
static void mt_destroy(mtaux_t *mt) // de-allocate an entire mtaux_t struct
{
int i;
// Under the lock, tell every worker (slots 1..n_threads-1) to exit and
// wake them all. NOTE(review): action = -1 appears to be the termination
// signal handled in mt_worker() — its body is not fully shown here.
pthread_mutex_lock(&mt->mtx);
mt->n_idle = 0;
for (i = 1; i < mt->n_threads; ++i) mt->mt[i].action = -1;
pthread_cond_broadcast(&mt->cv);
pthread_mutex_unlock(&mt->mtx);
for (i = 1; i < mt->n_threads; ++i) pthread_join(mt->mt[i].tid, 0); // wait until each worker has exited
// Slot 0 has no worker thread of its own, but its network clone is freed too.
for (i = 0; i < mt->n_threads; ++i) kann_delete(mt->mt[i].a);
free(mt->mt);
pthread_cond_destroy(&mt->cv);
pthread_mutex_destroy(&mt->mtx);
free(mt);
}

void kann_mt(kann_t *ann, int n_threads, int max_batch_size)
{
mtaux_t *mt;
int i, k;

if (n_threads <= 1) {
if (ann->mt) mt_destroy((mtaux_t*)ann->mt);
ann->mt = 0;
return;
}
if (n_threads > max_batch_size) n_threads = max_batch_size;
if (n_threads <= 1) return;

@@ -240,25 +323,6 @@ void kann_set_mt(kann_t *ann, int n_threads, int max_batch_size)
ann->mt = mt;
}

// NOTE(review): these lines are the removed side of this commit's diff —
// the public kann_mt_destroy() is superseded by the static mt_destroy()
// above plus kann_mt(ann, 0, 0); the teardown logic is otherwise the same.
void kann_mt_destroy(kann_t *ann)
{
int i;
mtaux_t *mt = (mtaux_t*)ann->mt;
if (ann->mt == 0) return; // multi-threading never enabled; nothing to do
// Signal all workers to exit, wake them, then reap threads and free state.
pthread_mutex_lock(&mt->mtx);
mt->n_idle = 0;
for (i = 1; i < mt->n_threads; ++i) mt->mt[i].action = -1;
pthread_cond_broadcast(&mt->cv);
pthread_mutex_unlock(&mt->mtx);
for (i = 1; i < mt->n_threads; ++i) pthread_join(mt->mt[i].tid, 0);
for (i = 0; i < mt->n_threads; ++i) kann_delete(mt->mt[i].a);
free(mt->mt);
pthread_cond_destroy(&mt->cv);
pthread_mutex_destroy(&mt->mtx);
free(mt);
ann->mt = 0;
}

float kann_cost(kann_t *a, int cost_label, int cal_grad)
{
mtaux_t *mt = (mtaux_t*)a->mt;
@@ -298,69 +362,21 @@ float kann_cost(kann_t *a, int cost_label, int cal_grad)
}
return cost;
}
#else
void kann_set_mt(kann_t *ann, int n_threads, int max_batch_size) {}
void kann_mt_destroy(kann_t *ann) {}
float kann_cost(kann_t *a, int cost_label, int cal_grad) { return kann_cost_core(a, cost_label, cal_grad); }
#endif

// NOTE(review): removed side of the diff — kann_eval() is unchanged, only
// moved earlier in kann.c (before the multi-threading section) by this commit.
int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label)
{
// Mark nodes matching ext_flag/ext_label, evaluate only what they need,
// and return how many nodes were marked.
int i, k;
for (i = k = 0; i < a->n; ++i)
if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
++k, a->v[i]->tmp = 1;
kad_eval_marked(a->n, a->v);
return k;
}

// NOTE(review): removed side of the diff — kann_rnn_start() is unchanged,
// only moved earlier in kann.c by this commit.
void kann_rnn_start(kann_t *a)
{
// Force batch size 1 and alias each node's buffer with its p->pre
// counterpart so state carries over between successive single-step calls.
int i;
kann_set_batch_size(a, 1);
for (i = 0; i < a->n; ++i) {
kad_node_t *p = a->v[i];
if (p->pre) { // NB: BE CAREFUL of the interaction between kann_rnn_start() and kann_set_batch_size()
kad_node_t *q = p->pre;
if (q->x) memcpy(p->x, q->x, kad_len(p) * sizeof(float)); // carry over existing state
else memset(p->x, 0, kad_len(p) * sizeof(float)); // no prior state: zero-initialize
q->x = p->x; // alias the two buffers
}
}
}

// NOTE(review): removed side of the diff — kann_rnn_end() is unchanged,
// only moved earlier in kann.c by this commit.
void kann_rnn_end(kann_t *a)
{
// Re-sync graph nodes with the collated external arrays (values, gradients,
// constants) after step-by-step application.
kad_ext_sync(a->n, a->v, a->x, a->g, a->c);
}

int kann_class_error(const kann_t *ann)
{
int i, j, k, n, off, n_err = 0, is_class = 1;
for (i = 0; i < ann->n; ++i) {
kad_node_t *p = ann->v[i];
if ((p->op == 13 || p->op == 22) && p->n_child == 2 && p->n_d == 0) { // ce_bin or ce_multi
kad_node_t *x = p->child[0], *t = p->child[1];
n = kad_len(t) / t->d[0];
for (j = off = 0; j < t->d[0]; ++j, off += n) {
float t_sum = 0.0f, t_min = 1.0f, t_max = 0.0f, x_max = 0.0f, x_min = 1.0f;
int x_max_k = -1, t_max_k = -1;
for (k = 0; k < n; ++k) {
float xk = x->x[off+k], tk = t->x[off+k];
t_sum += tk;
t_min = t_min < tk? t_min : tk;
x_min = x_min < xk? x_min : xk;
if (t_max < tk) t_max = tk, t_max_k = k;
if (x_max < xk) x_max = xk, x_max_k = k;
}
if (t_sum - 1.0f == 0 && t_min >= 0.0f && x_min >= 0.0f && x_max <= 1.0f)
n_err += (x_max_k != t_max_k);
else is_class = 0;
}
}
}
return is_class? n_err : -1;
mtaux_t *mt = (mtaux_t*)ann->mt;
int i, n_err = 0;
if (mt == 0) return kann_class_error_core(ann);
for (i = 0; i < mt->n_threads; ++i)
n_err += kann_class_error_core(mt->mt[i].a);
return n_err;
}
#else
void kann_mt(kann_t *ann, int n_threads, int max_batch_size) {}
float kann_cost(kann_t *a, int cost_label, int cal_grad) { return kann_cost_core(a, cost_label, cal_grad); }
float kann_class_error(const kann_t *a) { return kann_class_error_core(a); }
#endif

/***********************
*** @@IO: model I/O ***
15 changes: 11 additions & 4 deletions kann.h
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
#ifndef KANN_H
#define KANN_H

#define KANN_VERSION "r475"
#define KANN_VERSION "r480"

#define KANN_F_IN 0x1 // input
#define KANN_F_OUT 0x2 // output
@@ -48,7 +48,7 @@ typedef struct {
int n; // number of nodes in the computational graph
kad_node_t **v; // list of nodes
float *x, *g, *c; // collated variable values, gradients and constant values
void *mt;
void *mt; // auxiliary data for multi-threading; NULL if multi-threading disabled
} kann_t;

extern int kann_verbose;
@@ -94,6 +94,15 @@ kann_t *kann_clone(kann_t *a, int batch_size);
void kann_delete(kann_t *a); // delete a network generated by kann_new() or kann_layer_final()
void kann_delete_unrolled(kann_t *a); // delete a network generated by kann_unroll()

/**
* Enable/disable multi-threading (requiring pthread)
*
* @param ann network
* @param n_threads number of threads; <=1 to completely disable multi-threading
* @param max_batch_size max mini-batch size; shall be no smaller than n_threads
*/
void kann_mt(kann_t *ann, int n_threads, int max_batch_size);

/**
* Bind float arrays to feed nodes
*
@@ -174,8 +183,6 @@ void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g,

float kann_grad_clip(float thres, int n, float *g);

void kann_set_mt(kann_t *ann, int n_threads, int max_batch_size);

// common layers
kad_node_t *kann_layer_input(int n1);
kad_node_t *kann_layer_linear(kad_node_t *in, int n1);
2 changes: 1 addition & 1 deletion kautodiff.h
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
#ifndef KANN_AUTODIFF_H
#define KANN_AUTODIFF_H

#define KAD_VERSION "r475"
#define KAD_VERSION "r480"

#include <stdio.h>
#include <stdint.h>

0 comments on commit 9ab2476

Please sign in to comment.