Skip to content

Commit

Permalink
Improve performance of Cox PH using Efron's approximation for handling ties.
Browse files Browse the repository at this point in the history
  • Loading branch information
aboyoun committed Oct 23, 2014
1 parent f249955 commit b1e50a3
Showing 1 changed file with 76 additions and 58 deletions.
134 changes: 76 additions & 58 deletions src/main/java/hex/CoxPH.java
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,13 @@ protected float[] score0(double[] data, float[] preds) {
for (int t = 0; t < n_time; ++t)
preds[t + 1] = (float) (risk * cumhaz_0[t]);
for (int t = 0; t < n_time; ++t) {
final double cumhaz_0_t = cumhaz_0[t];
double var_cumhaz_2_t = 0;
for (int j = 0; j < n_coef; ++j) {
double sum = 0;
for (int k = 0; k < n_coef; ++k)
sum += var_coef[j][k] * (full_data[k] * cumhaz_0[t] - var_cumhaz_2[k][t]);
var_cumhaz_2_t += (full_data[j] * cumhaz_0[t] - var_cumhaz_2[j][t]) * sum;
sum += var_coef[j][k] * (full_data[k] * cumhaz_0_t - var_cumhaz_2[t][k]);
var_cumhaz_2_t += (full_data[j] * cumhaz_0_t - var_cumhaz_2[t][j]) * sum;
}
preds[t + 1 + n_time] = (float) (risk * Math.sqrt(var_cumhaz_1[t] + var_cumhaz_2_t));
}
Expand Down Expand Up @@ -255,7 +256,7 @@ protected void initStats(Frame source, DataInfo dinfo) {
n_censor = MemoryManager.malloc8d(n_time);
cumhaz_0 = MemoryManager.malloc8d(n_time);
var_cumhaz_1 = MemoryManager.malloc8d(n_time);
var_cumhaz_2 = malloc2DArray(n_coef, n_time);
var_cumhaz_2 = malloc2DArray(n_time, n_coef);
}

protected void calcCounts(CoxPHTask coxMR) {
Expand Down Expand Up @@ -294,22 +295,27 @@ protected double calcLoglik(CoxPHTask coxMR) {
switch (parameters.ties) {
case efron:
for (int t = coxMR.sizeEvents.length - 1; t >= 0; --t) {
if (coxMR.sizeEvents[t] > 0) {
final double avgSize = coxMR.sizeEvents[t] / coxMR.countEvents[t];
newLoglik += coxMR.sumLogRiskEvents[t];
final double sizeEvents_t = coxMR.sizeEvents[t];
if (sizeEvents_t > 0) {
final long countEvents_t = coxMR.countEvents[t];
final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t];
final double sumRiskEvents_t = coxMR.sumRiskEvents[t];
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
final double avgSize = sizeEvents_t / countEvents_t;
newLoglik += sumLogRiskEvents_t;
for (int j = 0; j < n_coef; ++j)
gradient[j] += coxMR.sumXEvents[j][t];
for (long e = 0; e < coxMR.countEvents[t]; ++e) {
final double frac = ((double) e) / ((double) coxMR.countEvents[t]);
final double term = coxMR.rcumsumRisk[t] - frac * coxMR.sumRiskEvents[t];
gradient[j] += coxMR.sumXEvents[t][j];
for (long e = 0; e < countEvents_t; ++e) {
final double frac = ((double) e) / ((double) countEvents_t);
final double term = rcumsumRisk_t - frac * sumRiskEvents_t;
newLoglik -= avgSize * Math.log(term);
for (int j = 0; j < n_coef; ++j) {
final double djTerm = coxMR.rcumsumXRisk[j][t] - frac * coxMR.sumXRiskEvents[j][t];
final double djTerm = coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j];
final double djLogTerm = djTerm / term;
gradient[j] -= avgSize * djLogTerm;
for (int k = 0; k < n_coef; ++k) {
final double dkTerm = coxMR.rcumsumXRisk[k][t] - frac * coxMR.sumXRiskEvents[k][t];
final double djkTerm = coxMR.rcumsumXXRisk[j][k][t] - frac * coxMR.sumXXRiskEvents[j][k][t];
final double dkTerm = coxMR.rcumsumXRisk[t][k] - frac * coxMR.sumXRiskEvents[t][k];
final double djkTerm = coxMR.rcumsumXXRisk[t][j][k] - frac * coxMR.sumXXRiskEvents[t][j][k];
hessian[j][k] -= avgSize * (djkTerm / term - (djLogTerm * (dkTerm / term)));
}
}
Expand All @@ -319,17 +325,20 @@ protected double calcLoglik(CoxPHTask coxMR) {
break;
case breslow:
for (int t = coxMR.sizeEvents.length - 1; t >= 0; --t) {
if (coxMR.sizeEvents[t] > 0) {
newLoglik += coxMR.sumLogRiskEvents[t];
newLoglik -= coxMR.sizeEvents[t] * Math.log(coxMR.rcumsumRisk[t]);
final double sizeEvents_t = coxMR.sizeEvents[t];
if (sizeEvents_t > 0) {
final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t];
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
newLoglik += sumLogRiskEvents_t;
newLoglik -= sizeEvents_t * Math.log(rcumsumRisk_t);
for (int j = 0; j < n_coef; ++j) {
final double dlogTerm = coxMR.rcumsumXRisk[j][t] / coxMR.rcumsumRisk[t];
gradient[j] += coxMR.sumXEvents[j][t];
gradient[j] -= coxMR.sizeEvents[t] * dlogTerm;
final double dlogTerm = coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t;
gradient[j] += coxMR.sumXEvents[t][j];
gradient[j] -= sizeEvents_t * dlogTerm;
for (int k = 0; k < n_coef; ++k)
hessian[j][k] -= coxMR.sizeEvents[t] *
(((coxMR.rcumsumXXRisk[j][k][t] / coxMR.rcumsumRisk[t]) -
(dlogTerm * (coxMR.rcumsumXRisk[k][t] / coxMR.rcumsumRisk[t]))));
hessian[j][k] -= sizeEvents_t *
(((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk_t) -
(dlogTerm * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk_t))));
}
}
}
Expand Down Expand Up @@ -386,33 +395,42 @@ protected void calcCumhaz_0(CoxPHTask coxMR) {
switch (parameters.ties) {
case efron:
for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) {
final double avgSize = coxMR.sizeEvents[t] / coxMR.countEvents[t];
final double sizeEvents_t = coxMR.sizeEvents[t];
final double sizeCensored_t = coxMR.sizeCensored[t];
if (sizeEvents_t > 0 || sizeCensored_t > 0) {
final long countEvents_t = coxMR.countEvents[t];
final double sumRiskEvents_t = coxMR.sumRiskEvents[t];
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
final double avgSize = sizeEvents_t / countEvents_t;
cumhaz_0[nz] = 0;
var_cumhaz_1[nz] = 0;
for (int j = 0; j < n_coef; ++j)
var_cumhaz_2[j][nz] = 0;
for (long e = 0; e < coxMR.countEvents[t]; ++e) {
final double frac = ((double) e) / ((double) coxMR.countEvents[t]);
final double haz = 1 / (coxMR.rcumsumRisk[t] - frac * coxMR.sumRiskEvents[t]);
var_cumhaz_2[nz][j] = 0;
for (long e = 0; e < countEvents_t; ++e) {
final double frac = ((double) e) / ((double) countEvents_t);
final double haz = 1 / (rcumsumRisk_t - frac * sumRiskEvents_t);
final double haz_sq = haz * haz;
cumhaz_0[nz] += avgSize * haz;
var_cumhaz_1[nz] += avgSize * haz_sq;
for (int j = 0; j < n_coef; ++j)
var_cumhaz_2[j][nz] +=
avgSize * ((coxMR.rcumsumXRisk[j][t] - frac * coxMR.sumXRiskEvents[j][t]) * haz_sq);
var_cumhaz_2[nz][j] +=
avgSize * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * haz_sq);
}
nz++;
}
}
break;
case breslow:
for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) {
cumhaz_0[nz] = coxMR.sizeEvents[t] / coxMR.rcumsumRisk[t];
var_cumhaz_1[nz] = coxMR.sizeEvents[t] / (coxMR.rcumsumRisk[t] * coxMR.rcumsumRisk[t]);
final double sizeEvents_t = coxMR.sizeEvents[t];
final double sizeCensored_t = coxMR.sizeCensored[t];
if (sizeEvents_t > 0 || sizeCensored_t > 0) {
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
final double cumhaz_0_nz = sizeEvents_t / rcumsumRisk_t;
cumhaz_0[nz] = cumhaz_0_nz;
var_cumhaz_1[nz] = sizeEvents_t / (rcumsumRisk_t * rcumsumRisk_t);
for (int j = 0; j < n_coef; ++j)
var_cumhaz_2[j][nz] = (coxMR.rcumsumXRisk[j][t] / coxMR.rcumsumRisk[t]) * cumhaz_0[nz];
var_cumhaz_2[nz][j] = (coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t) * cumhaz_0_nz;
nz++;
}
}
Expand All @@ -425,7 +443,7 @@ protected void calcCumhaz_0(CoxPHTask coxMR) {
cumhaz_0[t] = cumhaz_0[t - 1] + cumhaz_0[t];
var_cumhaz_1[t] = var_cumhaz_1[t - 1] + var_cumhaz_1[t];
for (int j = 0; j < n_coef; ++j)
var_cumhaz_2[j][t] = var_cumhaz_2[j][t - 1] + var_cumhaz_2[j][t];
var_cumhaz_2[t][j] = var_cumhaz_2[t - 1][j] + var_cumhaz_2[t][j];
}
}

Expand All @@ -449,7 +467,7 @@ public Frame makeSurvfit(Key key, double x_new) { // FIXME
surv.set(t, Math.exp(-cumhaz_1));
}
for (int t = 0; t < n_time; ++t) {
final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[j][t];
final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[t][j];
se_cumhaz.set(t, risk * Math.sqrt(var_cumhaz_1[t] + (gamma * var_coef[j][j] * gamma)));
}
final Frame fr = new Frame(key, new String[] {"time", "cumhaz", "se_cumhaz", "surv"}, vecs);
Expand Down Expand Up @@ -557,14 +575,14 @@ protected void execImpl() {
final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model.min_time, n_time,
use_start_column, use_weights_column).doAll(dinfo._adaptedFrame);

if (i == 0)
model.calcCounts(coxMR);

final double newLoglik = model.calcLoglik(coxMR);
if (newLoglik > oldLoglik) {
model.calcModelStats(newCoef, newLoglik);
model.calcCumhaz_0(coxMR);

if (i == 0)
model.calcCounts(coxMR);

if (newLoglik == 0)
model.lre = - Math.log10(Math.abs(oldLoglik - newLoglik));
else
Expand Down Expand Up @@ -667,11 +685,11 @@ protected void chunkInit(){
sumRiskEvents = MemoryManager.malloc8d(_n_time);
sumLogRiskEvents = MemoryManager.malloc8d(_n_time);
rcumsumRisk = MemoryManager.malloc8d(_n_time);
sumXEvents = malloc2DArray(n_coef, _n_time);
sumXRiskEvents = malloc2DArray(n_coef, _n_time);
rcumsumXRisk = malloc2DArray(n_coef, _n_time);
sumXXRiskEvents = malloc3DArray(n_coef, n_coef, _n_time);
rcumsumXXRisk = malloc3DArray(n_coef, n_coef, _n_time);
sumXEvents = malloc2DArray(_n_time, n_coef);
sumXRiskEvents = malloc2DArray(_n_time, n_coef);
rcumsumXRisk = malloc2DArray(_n_time, n_coef);
sumXXRiskEvents = malloc3DArray(_n_time, n_coef, n_coef);
rcumsumXXRisk = malloc3DArray(_n_time, n_coef, n_coef);
}

@Override
Expand Down Expand Up @@ -723,27 +741,27 @@ protected void processRow(long gid, double [] nums, int ncats, int [] cats, doub
final double x1 = jIsCat ? 1.0 : nums[jit - ncats];
final double xRisk = x1 * risk;
if (event > 0) {
sumXEvents[j][t2] += weight * x1;
sumXRiskEvents[j][t2] += xRisk;
sumXEvents[t2][j] += weight * x1;
sumXRiskEvents[t2][j] += xRisk;
}
if (_use_start_column) {
for (int t = t1; t <= t2; ++t)
rcumsumXRisk[j][t] += xRisk;
rcumsumXRisk[t][j] += xRisk;
} else {
rcumsumXRisk[j][t2] += xRisk;
rcumsumXRisk[t2][j] += xRisk;
}
for (int kit = 0; kit < ntotal; ++kit) {
final boolean kIsCat = kit < ncats;
final int k = kIsCat ? cats[kit] : numStartIter + kit;
final double x2 = kIsCat ? 1.0 : nums[kit - ncats];
final double xxRisk = x2 * xRisk;
if (event > 0)
sumXXRiskEvents[j][k][t2] += xxRisk;
sumXXRiskEvents[t2][j][k] += xxRisk;
if (_use_start_column) {
for (int t = t1; t <= t2; ++t)
rcumsumXXRisk[j][k][t] += xxRisk;
rcumsumXXRisk[t][j][k] += xxRisk;
} else {
rcumsumXXRisk[j][k][t2] += xxRisk;
rcumsumXXRisk[t2][j][k] += xxRisk;
}
}
}
Expand Down Expand Up @@ -775,14 +793,14 @@ protected void postGlobal() {
for (int t = rcumsumRisk.length - 2; t >= 0; --t)
rcumsumRisk[t] += rcumsumRisk[t + 1];

for (int j = 0; j < rcumsumXRisk.length; ++j)
for (int t = rcumsumXRisk[j].length - 2; t >= 0; --t)
rcumsumXRisk[j][t] += rcumsumXRisk[j][t + 1];
for (int t = rcumsumXRisk.length - 2; t >= 0; --t)
for (int j = 0; j < rcumsumXRisk[t].length; ++j)
rcumsumXRisk[t][j] += rcumsumXRisk[t + 1][j];

for (int j = 0; j < rcumsumXXRisk.length; ++j)
for (int k = 0; k < rcumsumXXRisk[j].length; ++k)
for (int t = rcumsumXXRisk[j][k].length - 2; t >= 0; --t)
rcumsumXXRisk[j][k][t] += rcumsumXXRisk[j][k][t + 1];
for (int t = rcumsumXXRisk.length - 2; t >= 0; --t)
for (int j = 0; j < rcumsumXXRisk[t].length; ++j)
for (int k = 0; k < rcumsumXXRisk[t][j].length; ++k)
rcumsumXXRisk[t][j][k] += rcumsumXXRisk[t + 1][j][k];
}
}
}
Expand Down

0 comments on commit b1e50a3

Please sign in to comment.