Skip to content

Commit

Permalink
using x_traj in erk adj
Browse files Browse the repository at this point in the history
  • Loading branch information
giaf committed Oct 8, 2017
1 parent 5e2c977 commit 06a371c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 62 deletions.
5 changes: 3 additions & 2 deletions include/hpipm_d_erk_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ struct d_erk_workspace
void *ode_args; // pointer to ode args
struct d_erk_arg *erk_arg; // erk arg
double *K; // internal variables
double *x; // states and forward sensitivities
double *x_for; // states and forward sensitivities
double *x_traj; // states at all steps
double *l; // adjoint sensitivities
double *p; // parameter
double *xt; // temporary states and forward sensitivities
double *x_tmp; // temporary states and forward sensitivities
double *adj_in;
double *adj_tmp;
int nx; // number of states
Expand Down
7 changes: 4 additions & 3 deletions ocp_nlp/d_ocp_nlp_aux.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void d_cvt_erk_int_to_ocp_qp(int n, struct d_erk_workspace *erk_ws, struct d_ocp
int nf = erk_ws->nf;
int nX = nx*(1+nf);

double *x = erk_ws->x;
double *x = erk_ws->x_for;
// if(adj_sens!=0 & erk_ws->erk_arg->adj_sens!=0)
// x = erk_ws->x + nX*erk_ws->erk_arg->steps;

Expand Down Expand Up @@ -91,9 +91,10 @@ void d_cvt_erk_int_to_ocp_qp_rhs(int n, struct d_erk_workspace *erk_ws, struct d

int nX = nx*(1+nf);

// double *x = erk_ws->x;
// double *x = erk_ws->x_for;
// if(erk_ws->erk_arg->adj_sens!=0)
double *x = erk_ws->x + nX*erk_ws->erk_arg->steps;
// double *x = erk_ws->x_for + nX*erk_ws->erk_arg->steps;
double *x = erk_ws->x_for;

struct d_strvec sl;
d_create_strvec(nu+nx, &sl, erk_ws->l);
Expand Down
130 changes: 73 additions & 57 deletions sim_core/d_erk_int.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,19 @@ int d_memsize_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int
int size = 0;

size += 1*np*sizeof(double); // p
size += 1*nX*sizeof(double); // x_for
size += 1*ns*nX*sizeof(double); // K
size += 1*nX*sizeof(double); // x_tmp
if(na_max>0)
{
size += 1*nX*(steps+1)*sizeof(double); // x
size += 1*ns*nX*steps*sizeof(double); // K
size += 1*nX*ns*sizeof(double); // xt
// size += 1*nX*(steps+1)*sizeof(double); // x_traj XXX
size += 1*nx*(ns*steps+1)*sizeof(double); // x_traj
// size += 1*ns*nX*steps*sizeof(double); // K
// size += 1*nX*ns*sizeof(double); // x_tmp
size += 1*nf_max*(steps+1)*sizeof(double); // l // XXX *na_max ???
size += 1*(nx+nf_max)*sizeof(double); // adj_in // XXX *na_max ???
size += 1*nf_max*ns*sizeof(double); // adj_tmp // XXX *na_max ???
}
else
{
size += 1*nX*sizeof(double); // x
size += 1*ns*nX*sizeof(double); // K
size += 1*nX*sizeof(double); // xt
}

return size;

Expand Down Expand Up @@ -97,17 +95,29 @@ void d_create_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int
ws->p = d_ptr;
d_ptr += np;
//
ws->x_for = d_ptr;
d_ptr += nX;
//
ws->K = d_ptr;
d_ptr += ns*nX;
//
ws->x_tmp = d_ptr;
d_ptr += nX;
//
if(na_max>0)
{
//
ws->x = d_ptr;
d_ptr += nX*(steps+1);
// ws->x_for = d_ptr;
// d_ptr += nX*(steps+1);
//
ws->x_traj = d_ptr;
d_ptr += nx*(ns*steps+1);
//
ws->K = d_ptr;
d_ptr += ns*nX*steps;
// ws->K = d_ptr;
// d_ptr += ns*nX*steps;
//
ws->xt = d_ptr;
d_ptr += nX*ns;
// ws->x_tmp = d_ptr;
// d_ptr += nX*ns;
//
ws->l = d_ptr;
d_ptr += nf_max*(steps+1);
Expand All @@ -118,18 +128,6 @@ void d_create_erk_int(struct d_erk_arg *erk_arg, int nx, int np, int nf_max, int
ws->adj_tmp = d_ptr;
d_ptr += nf_max*ns;
}
else
{
//
ws->K = d_ptr;
d_ptr += ns*nX;
//
ws->x = d_ptr;
d_ptr += nX;
//
ws->xt = d_ptr;
d_ptr += nX;
}


ws->memsize = d_memsize_erk_int(erk_arg, nx, np, nf_max, na_max);
Expand Down Expand Up @@ -169,15 +167,15 @@ void d_init_erk_int(int nf, int na, double *x0, double *p0, double *fs0, double

int steps = ws->erk_arg->steps;

double *x = ws->x;
double *x_for = ws->x_for;
double *p = ws->p;
double *l = ws->l;

for(ii=0; ii<nx; ii++)
x[ii] = x0[ii];
x_for[ii] = x0[ii];

for(ii=0; ii<nx*nf; ii++)
x[nx+ii] = fs0[ii];
x_for[nx+ii] = fs0[ii];

for(ii=0; ii<np; ii++)
p[ii] = p0[ii];
Expand Down Expand Up @@ -237,10 +235,11 @@ void d_erk_int(struct d_erk_workspace *ws)
int nf = ws->nf;
int na = ws->na;
double *K0 = ws->K;
double *x0 = ws->x;
double *x1 = ws->x;
double *x0 = ws->x_for;
double *x1 = ws->x_for;
double *x_traj = ws->x_traj;
double *p = ws->p;
double *xt = ws->xt;
double *x_tmp = ws->x_tmp;
double *adj_in = ws->adj_in;
double *adj_tmp = ws->adj_tmp;

Expand All @@ -253,7 +252,7 @@ void d_erk_int(struct d_erk_workspace *ws)

struct d_strvec sxt; // XXX
struct d_strvec sK; // XXX
sxt.pa = xt; // XXX
sxt.pa = x_tmp; // XXX

int ii, jj, step, ss;
double t, a, b;
Expand All @@ -267,20 +266,27 @@ void d_erk_int(struct d_erk_workspace *ws)
// TODO no need to save the entire [x Su Sx] & sens, but only [x] & sens !!!

t = 0.0; // TODO plus time of multiple-shooting stage !!!
if(na>0)
{
x_traj = ws->x_traj;
for(ii=0; ii<nx; ii++)
x_traj[ii] = x0[ii];
x_traj += nx;
}
for(step=0; step<steps; step++)
{
if(na>0)
{
x0 = ws->x + step*nX;
x1 = ws->x + (step+1)*nX;
for(ii=0; ii<nX; ii++)
x1[ii] = x0[ii];
K0 = ws->K + ns*step*nX;
}
// if(na>0)
// {
// x0 = ws->x_for + step*nX;
// x1 = ws->x_for + (step+1)*nX;
// for(ii=0; ii<nX; ii++)
// x1[ii] = x0[ii];
// K0 = ws->K + ns*step*nX;
// }
for(ss=0; ss<ns; ss++)
{
for(ii=0; ii<nX; ii++)
xt[ii] = x0[ii];
x_tmp[ii] = x0[ii];
for(ii=0; ii<ss; ii++)
{
a = A_rk[ss+ns*ii];
Expand All @@ -292,11 +298,17 @@ void d_erk_int(struct d_erk_workspace *ws)
daxpy_libstr(nX, a, &sK, 0, &sxt, 0, &sxt, 0); // XXX
#else
for(jj=0; jj<nX; jj++)
xt[jj] += a*K0[jj+ii*(nX)];
x_tmp[jj] += a*K0[jj+ii*(nX)];
#endif
}
}
ws->vde_for(t+h*C_rk[ss], xt, p, ws->ode_args, K0+ss*(nX));
if(na>0)
{
for(ii=0; ii<nx; ii++)
x_traj[ii] = x_tmp[ii];
x_traj += nx;
}
ws->vde_for(t+h*C_rk[ss], x_tmp, p, ws->ode_args, K0+ss*(nX));
}
for(ss=0; ss<ns; ss++)
{
Expand All @@ -311,29 +323,33 @@ void d_erk_int(struct d_erk_workspace *ws)

if(na>0)
{
x_traj = ws->x_traj + nx*ns*steps;
t = steps*h; // TODO plus time of multiple-shooting stage !!!
for(step=steps-1; step>=0; step--)
{
l0 = ws->l + step*nA;
l1 = ws->l + (step+1)*nA;
x0 = ws->x + step*nX;
x0 = ws->x_for + step*nX;
K0 = ws->K + ns*step*nX; // XXX save all x insead !!!
// TODO save all x instead of K !!!
for(ss=ns-1; ss>=0; ss--)
{
// x
for(ii=0; ii<nx; ii++)
adj_in[0+ii] = x0[ii];
for(ii=0; ii<ss; ii++)
{
a = A_rk[ss+ns*ii];
if(a!=0)
{
a *= h;
for(jj=0; jj<nx; jj++)
adj_in[0+jj] += a*K0[jj+ii*(nX)];
}
}
adj_in[0+ii] = x_traj[ii];
x_traj -= nx;
// for(ii=0; ii<nx; ii++)
// adj_in[0+ii] = x0[ii];
// for(ii=0; ii<ss; ii++)
// {
// a = A_rk[ss+ns*ii];
// if(a!=0)
// {
// a *= h;
// for(jj=0; jj<nx; jj++)
// adj_in[0+jj] += a*K0[jj+ii*(nX)];
// }
// }
// l
b = h*B_rk[ss];
for(ii=0; ii<nx; ii++)
Expand Down

0 comments on commit 06a371c

Please sign in to comment.