Skip to content

Commit 7c96f20

Browse files
committed
adding c stubs
1 parent 8d25f03 commit 7c96f20

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

c/Makefile

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Build rules for the aa.x benchmark (single source file: aa.cc).
#
# PUB_PREFIX and THEANO_BLAS_LDFLAGS are not defined in this file; they must
# be supplied by the environment or on the make command line.
#   PUB_PREFIX          - install prefix that holds GSL (include/ and lib/)
#   THEANO_BLAS_LDFLAGS - linker flags selecting the BLAS implementation

# 'clean' names an action, not a file, so it must run even if a file
# called 'clean' exists.
.PHONY : clean

# The include path mirrors the library path so headers and libraries come
# from the same prefix (as in the build commands quoted at the top of aa.cc).
aa.x : aa.cc
	g++ -O3 -ffast-math -I${PUB_PREFIX}/include aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl ${THEANO_BLAS_LDFLAGS}

# -f keeps 'make clean' from failing when aa.x has not been built yet.
clean :
	rm -f aa.x

c/aa.cc

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
*
3+
* g++ -O2 -ffast-math -I$PUB_PREFIX/include aa.cc -o aa.x -lgsl -lgslcblas
4+
*
5+
* g++ -O2 -ffast-math -I$PUB_PREFIX/include aa.cc -o aa.x -L$PUB_PREFIX/lib -lgsl -lcblas -lgoto -lgfortran
6+
*
7+
* ./aa.x 10 5 7 1000
8+
*
9+
* */
10+
#include <cassert>
11+
#include <cstdlib>
12+
#include <cstdio>
13+
#include <cmath>
14+
#include <gsl/gsl_rng.h>
15+
#include <gsl/gsl_blas.h>
16+
17+
#include <time.h>
18+
#include <sys/time.h>
19+
20+
/*
 * Convert a timeval into a single floating-point number of seconds.
 *
 * tv: timestamp to convert; it is dereferenced unconditionally, so it must
 *     not be null.
 *
 * Returns tv_sec plus tv_usec scaled from microseconds to seconds.
 */
double pytime(const struct timeval * tv)
{
    const double whole_seconds = (double) tv->tv_sec;
    const double microseconds = (double) tv->tv_usec;

    return whole_seconds + microseconds / 1000000.0;
}
24+
25+
int main(int argc, char **argv)
26+
{
27+
assert(argc == 5);
28+
29+
int neg = strtol(argv[1], 0, 0);
30+
int nout = strtol(argv[2], 0, 0);
31+
int nin = nout;
32+
int nhid = strtol(argv[3], 0, 0);
33+
int niter = strtol(argv[4], 0, 0);
34+
double lr = 0.01;
35+
gsl_rng * rng = gsl_rng_alloc (gsl_rng_taus);
36+
gsl_rng_set(rng, 234);
37+
38+
39+
gsl_matrix * x = gsl_matrix_alloc(neg, nin);
40+
gsl_matrix * w = gsl_matrix_alloc(nin, nhid);
41+
gsl_vector * a = gsl_vector_alloc(nhid);
42+
gsl_vector * b = gsl_vector_alloc(nout);
43+
gsl_matrix * xw = gsl_matrix_alloc(neg, nhid);
44+
gsl_matrix * hid = gsl_matrix_alloc(neg, nhid);
45+
gsl_matrix * hidwt = gsl_matrix_alloc(neg, nout);
46+
gsl_matrix * g_hidwt = gsl_matrix_alloc(neg, nout);
47+
gsl_matrix * g_hid = gsl_matrix_alloc(neg, nhid);
48+
gsl_matrix * g_w = gsl_matrix_alloc(nout, nhid);
49+
gsl_vector * g_b = gsl_vector_alloc(nout);
50+
51+
for (int i = 0; i < neg*nout; ++i) x->data[i] = (gsl_rng_uniform(rng) -0.5)*1.5;
52+
for (int i = 0; i < nout*nhid; ++i) w->data[i] = gsl_rng_uniform(rng);
53+
for (int i = 0; i < nhid; ++i) a->data[i] = 0.0;
54+
for (int i = 0; i < nout; ++i) b->data[i] = 0.0;
55+
56+
//
57+
//
58+
//
59+
//
60+
61+
struct timeval tv0, tv1;
62+
63+
struct timeval tdot0, tdot1;
64+
double time_of_dot = 0.0;
65+
66+
gettimeofday(&tv0, 0);
67+
double err = 0.0;
68+
for (int iter = 0; iter < niter; ++iter)
69+
{
70+
gettimeofday(&tdot0, 0);
71+
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, w, 0.0, xw);
72+
gettimeofday(&tdot1, 0);
73+
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
74+
75+
for (int i = 0; i < neg; ++i)
76+
for (int j = 0; j < nhid; ++j)
77+
{
78+
double act = xw->data[i*nhid+j] + a->data[j];
79+
hid->data[i*nhid+j] = tanh(act);
80+
}
81+
82+
gettimeofday(&tdot0, 0);
83+
gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, hid, w, 0.0, hidwt);
84+
gettimeofday(&tdot1, 0);
85+
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
86+
87+
for (int i = 0; i < nout; ++i) g_b->data[i] = 0.0;
88+
err = 0.0;
89+
for (int i = 0; i < neg; ++i)
90+
for (int j = 0; j < nout; ++j)
91+
{
92+
double act = hidwt->data[i*nout+j] + b->data[j];
93+
double out = tanh(act);
94+
double g_out = out - x->data[i*nout+j];
95+
err += g_out * g_out;
96+
g_hidwt->data[i*nout+j] = g_out * (1.0 - out*out);
97+
g_b->data[j] += g_hidwt->data[i*nout+j];
98+
}
99+
for (int i = 0; i < nout; ++i) b->data[i] -= lr * g_b->data[i];
100+
101+
if (1)
102+
{
103+
gettimeofday(&tdot0, 0);
104+
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, g_hidwt, w, 0.0, g_hid);
105+
gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, g_hidwt, hid, 0.0, g_w);
106+
gettimeofday(&tdot1, 0);
107+
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
108+
109+
110+
for (int i = 0; i < neg; ++i)
111+
for (int j = 0; j < nhid; ++j)
112+
{
113+
g_hid->data[i*nhid+j] *= (1.0 - hid->data[i*nhid+j] * hid->data[i*nhid+j]);
114+
a->data[j] -= lr * g_hid->data[i*nhid+j];
115+
}
116+
117+
gettimeofday(&tdot0, 0);
118+
gsl_blas_dgemm(CblasTrans, CblasNoTrans, -lr, x, g_hid, 1.0, w);
119+
gettimeofday(&tdot1, 0);
120+
time_of_dot += pytime(&tdot1) - pytime(&tdot0);
121+
for (int i = 0; i < nout*nhid; ++i) w->data[i] -= lr * g_w->data[i];
122+
}
123+
124+
}
125+
gettimeofday(&tv1, 0);
126+
127+
double total_time = pytime(&tv1) - pytime(&tv0);
128+
fprintf(stdout, "took = %lfs to get err %lf\n", total_time, 0.5 * err);
129+
fprintf(stdout, "... of which %.2lfs was spent in dgemm (fraction: %.2lf)\n", time_of_dot, time_of_dot / total_time);
130+
//skip freeing
131+
return 0;
132+
}
133+

0 commit comments

Comments
 (0)