Skip to content

Commit b1e50a3

Browse files
committed
Improve performance of Cox PH using Efron's approximation for handling ties.
1 parent f249955 commit b1e50a3

File tree

1 file changed

+76
-58
lines changed

1 file changed

+76
-58
lines changed

src/main/java/hex/CoxPH.java

+76-58
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,13 @@ protected float[] score0(double[] data, float[] preds) {
214214
for (int t = 0; t < n_time; ++t)
215215
preds[t + 1] = (float) (risk * cumhaz_0[t]);
216216
for (int t = 0; t < n_time; ++t) {
217+
final double cumhaz_0_t = cumhaz_0[t];
217218
double var_cumhaz_2_t = 0;
218219
for (int j = 0; j < n_coef; ++j) {
219220
double sum = 0;
220221
for (int k = 0; k < n_coef; ++k)
221-
sum += var_coef[j][k] * (full_data[k] * cumhaz_0[t] - var_cumhaz_2[k][t]);
222-
var_cumhaz_2_t += (full_data[j] * cumhaz_0[t] - var_cumhaz_2[j][t]) * sum;
222+
sum += var_coef[j][k] * (full_data[k] * cumhaz_0_t - var_cumhaz_2[t][k]);
223+
var_cumhaz_2_t += (full_data[j] * cumhaz_0_t - var_cumhaz_2[t][j]) * sum;
223224
}
224225
preds[t + 1 + n_time] = (float) (risk * Math.sqrt(var_cumhaz_1[t] + var_cumhaz_2_t));
225226
}
@@ -255,7 +256,7 @@ protected void initStats(Frame source, DataInfo dinfo) {
255256
n_censor = MemoryManager.malloc8d(n_time);
256257
cumhaz_0 = MemoryManager.malloc8d(n_time);
257258
var_cumhaz_1 = MemoryManager.malloc8d(n_time);
258-
var_cumhaz_2 = malloc2DArray(n_coef, n_time);
259+
var_cumhaz_2 = malloc2DArray(n_time, n_coef);
259260
}
260261

261262
protected void calcCounts(CoxPHTask coxMR) {
@@ -294,22 +295,27 @@ protected double calcLoglik(CoxPHTask coxMR) {
294295
switch (parameters.ties) {
295296
case efron:
296297
for (int t = coxMR.sizeEvents.length - 1; t >= 0; --t) {
297-
if (coxMR.sizeEvents[t] > 0) {
298-
final double avgSize = coxMR.sizeEvents[t] / coxMR.countEvents[t];
299-
newLoglik += coxMR.sumLogRiskEvents[t];
298+
final double sizeEvents_t = coxMR.sizeEvents[t];
299+
if (sizeEvents_t > 0) {
300+
final long countEvents_t = coxMR.countEvents[t];
301+
final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t];
302+
final double sumRiskEvents_t = coxMR.sumRiskEvents[t];
303+
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
304+
final double avgSize = sizeEvents_t / countEvents_t;
305+
newLoglik += sumLogRiskEvents_t;
300306
for (int j = 0; j < n_coef; ++j)
301-
gradient[j] += coxMR.sumXEvents[j][t];
302-
for (long e = 0; e < coxMR.countEvents[t]; ++e) {
303-
final double frac = ((double) e) / ((double) coxMR.countEvents[t]);
304-
final double term = coxMR.rcumsumRisk[t] - frac * coxMR.sumRiskEvents[t];
307+
gradient[j] += coxMR.sumXEvents[t][j];
308+
for (long e = 0; e < countEvents_t; ++e) {
309+
final double frac = ((double) e) / ((double) countEvents_t);
310+
final double term = rcumsumRisk_t - frac * sumRiskEvents_t;
305311
newLoglik -= avgSize * Math.log(term);
306312
for (int j = 0; j < n_coef; ++j) {
307-
final double djTerm = coxMR.rcumsumXRisk[j][t] - frac * coxMR.sumXRiskEvents[j][t];
313+
final double djTerm = coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j];
308314
final double djLogTerm = djTerm / term;
309315
gradient[j] -= avgSize * djLogTerm;
310316
for (int k = 0; k < n_coef; ++k) {
311-
final double dkTerm = coxMR.rcumsumXRisk[k][t] - frac * coxMR.sumXRiskEvents[k][t];
312-
final double djkTerm = coxMR.rcumsumXXRisk[j][k][t] - frac * coxMR.sumXXRiskEvents[j][k][t];
317+
final double dkTerm = coxMR.rcumsumXRisk[t][k] - frac * coxMR.sumXRiskEvents[t][k];
318+
final double djkTerm = coxMR.rcumsumXXRisk[t][j][k] - frac * coxMR.sumXXRiskEvents[t][j][k];
313319
hessian[j][k] -= avgSize * (djkTerm / term - (djLogTerm * (dkTerm / term)));
314320
}
315321
}
@@ -319,17 +325,20 @@ protected double calcLoglik(CoxPHTask coxMR) {
319325
break;
320326
case breslow:
321327
for (int t = coxMR.sizeEvents.length - 1; t >= 0; --t) {
322-
if (coxMR.sizeEvents[t] > 0) {
323-
newLoglik += coxMR.sumLogRiskEvents[t];
324-
newLoglik -= coxMR.sizeEvents[t] * Math.log(coxMR.rcumsumRisk[t]);
328+
final double sizeEvents_t = coxMR.sizeEvents[t];
329+
if (sizeEvents_t > 0) {
330+
final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t];
331+
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
332+
newLoglik += sumLogRiskEvents_t;
333+
newLoglik -= sizeEvents_t * Math.log(rcumsumRisk_t);
325334
for (int j = 0; j < n_coef; ++j) {
326-
final double dlogTerm = coxMR.rcumsumXRisk[j][t] / coxMR.rcumsumRisk[t];
327-
gradient[j] += coxMR.sumXEvents[j][t];
328-
gradient[j] -= coxMR.sizeEvents[t] * dlogTerm;
335+
final double dlogTerm = coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t;
336+
gradient[j] += coxMR.sumXEvents[t][j];
337+
gradient[j] -= sizeEvents_t * dlogTerm;
329338
for (int k = 0; k < n_coef; ++k)
330-
hessian[j][k] -= coxMR.sizeEvents[t] *
331-
(((coxMR.rcumsumXXRisk[j][k][t] / coxMR.rcumsumRisk[t]) -
332-
(dlogTerm * (coxMR.rcumsumXRisk[k][t] / coxMR.rcumsumRisk[t]))));
339+
hessian[j][k] -= sizeEvents_t *
340+
(((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk_t) -
341+
(dlogTerm * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk_t))));
333342
}
334343
}
335344
}
@@ -386,33 +395,42 @@ protected void calcCumhaz_0(CoxPHTask coxMR) {
386395
switch (parameters.ties) {
387396
case efron:
388397
for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
389-
if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) {
390-
final double avgSize = coxMR.sizeEvents[t] / coxMR.countEvents[t];
398+
final double sizeEvents_t = coxMR.sizeEvents[t];
399+
final double sizeCensored_t = coxMR.sizeCensored[t];
400+
if (sizeEvents_t > 0 || sizeCensored_t > 0) {
401+
final long countEvents_t = coxMR.countEvents[t];
402+
final double sumRiskEvents_t = coxMR.sumRiskEvents[t];
403+
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
404+
final double avgSize = sizeEvents_t / countEvents_t;
391405
cumhaz_0[nz] = 0;
392406
var_cumhaz_1[nz] = 0;
393407
for (int j = 0; j < n_coef; ++j)
394-
var_cumhaz_2[j][nz] = 0;
395-
for (long e = 0; e < coxMR.countEvents[t]; ++e) {
396-
final double frac = ((double) e) / ((double) coxMR.countEvents[t]);
397-
final double haz = 1 / (coxMR.rcumsumRisk[t] - frac * coxMR.sumRiskEvents[t]);
408+
var_cumhaz_2[nz][j] = 0;
409+
for (long e = 0; e < countEvents_t; ++e) {
410+
final double frac = ((double) e) / ((double) countEvents_t);
411+
final double haz = 1 / (rcumsumRisk_t - frac * sumRiskEvents_t);
398412
final double haz_sq = haz * haz;
399413
cumhaz_0[nz] += avgSize * haz;
400414
var_cumhaz_1[nz] += avgSize * haz_sq;
401415
for (int j = 0; j < n_coef; ++j)
402-
var_cumhaz_2[j][nz] +=
403-
avgSize * ((coxMR.rcumsumXRisk[j][t] - frac * coxMR.sumXRiskEvents[j][t]) * haz_sq);
416+
var_cumhaz_2[nz][j] +=
417+
avgSize * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * haz_sq);
404418
}
405419
nz++;
406420
}
407421
}
408422
break;
409423
case breslow:
410424
for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
411-
if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) {
412-
cumhaz_0[nz] = coxMR.sizeEvents[t] / coxMR.rcumsumRisk[t];
413-
var_cumhaz_1[nz] = coxMR.sizeEvents[t] / (coxMR.rcumsumRisk[t] * coxMR.rcumsumRisk[t]);
425+
final double sizeEvents_t = coxMR.sizeEvents[t];
426+
final double sizeCensored_t = coxMR.sizeCensored[t];
427+
if (sizeEvents_t > 0 || sizeCensored_t > 0) {
428+
final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
429+
final double cumhaz_0_nz = sizeEvents_t / rcumsumRisk_t;
430+
cumhaz_0[nz] = cumhaz_0_nz;
431+
var_cumhaz_1[nz] = sizeEvents_t / (rcumsumRisk_t * rcumsumRisk_t);
414432
for (int j = 0; j < n_coef; ++j)
415-
var_cumhaz_2[j][nz] = (coxMR.rcumsumXRisk[j][t] / coxMR.rcumsumRisk[t]) * cumhaz_0[nz];
433+
var_cumhaz_2[nz][j] = (coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t) * cumhaz_0_nz;
416434
nz++;
417435
}
418436
}
@@ -425,7 +443,7 @@ protected void calcCumhaz_0(CoxPHTask coxMR) {
425443
cumhaz_0[t] = cumhaz_0[t - 1] + cumhaz_0[t];
426444
var_cumhaz_1[t] = var_cumhaz_1[t - 1] + var_cumhaz_1[t];
427445
for (int j = 0; j < n_coef; ++j)
428-
var_cumhaz_2[j][t] = var_cumhaz_2[j][t - 1] + var_cumhaz_2[j][t];
446+
var_cumhaz_2[t][j] = var_cumhaz_2[t - 1][j] + var_cumhaz_2[t][j];
429447
}
430448
}
431449

@@ -449,7 +467,7 @@ public Frame makeSurvfit(Key key, double x_new) { // FIXME
449467
surv.set(t, Math.exp(-cumhaz_1));
450468
}
451469
for (int t = 0; t < n_time; ++t) {
452-
final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[j][t];
470+
final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[t][j];
453471
se_cumhaz.set(t, risk * Math.sqrt(var_cumhaz_1[t] + (gamma * var_coef[j][j] * gamma)));
454472
}
455473
final Frame fr = new Frame(key, new String[] {"time", "cumhaz", "se_cumhaz", "surv"}, vecs);
@@ -557,14 +575,14 @@ protected void execImpl() {
557575
final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model.min_time, n_time,
558576
use_start_column, use_weights_column).doAll(dinfo._adaptedFrame);
559577

560-
if (i == 0)
561-
model.calcCounts(coxMR);
562-
563578
final double newLoglik = model.calcLoglik(coxMR);
564579
if (newLoglik > oldLoglik) {
565580
model.calcModelStats(newCoef, newLoglik);
566581
model.calcCumhaz_0(coxMR);
567582

583+
if (i == 0)
584+
model.calcCounts(coxMR);
585+
568586
if (newLoglik == 0)
569587
model.lre = - Math.log10(Math.abs(oldLoglik - newLoglik));
570588
else
@@ -667,11 +685,11 @@ protected void chunkInit(){
667685
sumRiskEvents = MemoryManager.malloc8d(_n_time);
668686
sumLogRiskEvents = MemoryManager.malloc8d(_n_time);
669687
rcumsumRisk = MemoryManager.malloc8d(_n_time);
670-
sumXEvents = malloc2DArray(n_coef, _n_time);
671-
sumXRiskEvents = malloc2DArray(n_coef, _n_time);
672-
rcumsumXRisk = malloc2DArray(n_coef, _n_time);
673-
sumXXRiskEvents = malloc3DArray(n_coef, n_coef, _n_time);
674-
rcumsumXXRisk = malloc3DArray(n_coef, n_coef, _n_time);
688+
sumXEvents = malloc2DArray(_n_time, n_coef);
689+
sumXRiskEvents = malloc2DArray(_n_time, n_coef);
690+
rcumsumXRisk = malloc2DArray(_n_time, n_coef);
691+
sumXXRiskEvents = malloc3DArray(_n_time, n_coef, n_coef);
692+
rcumsumXXRisk = malloc3DArray(_n_time, n_coef, n_coef);
675693
}
676694

677695
@Override
@@ -723,27 +741,27 @@ protected void processRow(long gid, double [] nums, int ncats, int [] cats, doub
723741
final double x1 = jIsCat ? 1.0 : nums[jit - ncats];
724742
final double xRisk = x1 * risk;
725743
if (event > 0) {
726-
sumXEvents[j][t2] += weight * x1;
727-
sumXRiskEvents[j][t2] += xRisk;
744+
sumXEvents[t2][j] += weight * x1;
745+
sumXRiskEvents[t2][j] += xRisk;
728746
}
729747
if (_use_start_column) {
730748
for (int t = t1; t <= t2; ++t)
731-
rcumsumXRisk[j][t] += xRisk;
749+
rcumsumXRisk[t][j] += xRisk;
732750
} else {
733-
rcumsumXRisk[j][t2] += xRisk;
751+
rcumsumXRisk[t2][j] += xRisk;
734752
}
735753
for (int kit = 0; kit < ntotal; ++kit) {
736754
final boolean kIsCat = kit < ncats;
737755
final int k = kIsCat ? cats[kit] : numStartIter + kit;
738756
final double x2 = kIsCat ? 1.0 : nums[kit - ncats];
739757
final double xxRisk = x2 * xRisk;
740758
if (event > 0)
741-
sumXXRiskEvents[j][k][t2] += xxRisk;
759+
sumXXRiskEvents[t2][j][k] += xxRisk;
742760
if (_use_start_column) {
743761
for (int t = t1; t <= t2; ++t)
744-
rcumsumXXRisk[j][k][t] += xxRisk;
762+
rcumsumXXRisk[t][j][k] += xxRisk;
745763
} else {
746-
rcumsumXXRisk[j][k][t2] += xxRisk;
764+
rcumsumXXRisk[t2][j][k] += xxRisk;
747765
}
748766
}
749767
}
@@ -775,14 +793,14 @@ protected void postGlobal() {
775793
for (int t = rcumsumRisk.length - 2; t >= 0; --t)
776794
rcumsumRisk[t] += rcumsumRisk[t + 1];
777795

778-
for (int j = 0; j < rcumsumXRisk.length; ++j)
779-
for (int t = rcumsumXRisk[j].length - 2; t >= 0; --t)
780-
rcumsumXRisk[j][t] += rcumsumXRisk[j][t + 1];
796+
for (int t = rcumsumXRisk.length - 2; t >= 0; --t)
797+
for (int j = 0; j < rcumsumXRisk[t].length; ++j)
798+
rcumsumXRisk[t][j] += rcumsumXRisk[t + 1][j];
781799

782-
for (int j = 0; j < rcumsumXXRisk.length; ++j)
783-
for (int k = 0; k < rcumsumXXRisk[j].length; ++k)
784-
for (int t = rcumsumXXRisk[j][k].length - 2; t >= 0; --t)
785-
rcumsumXXRisk[j][k][t] += rcumsumXXRisk[j][k][t + 1];
800+
for (int t = rcumsumXXRisk.length - 2; t >= 0; --t)
801+
for (int j = 0; j < rcumsumXXRisk[t].length; ++j)
802+
for (int k = 0; k < rcumsumXXRisk[t][j].length; ++k)
803+
rcumsumXXRisk[t][j][k] += rcumsumXXRisk[t + 1][j][k];
786804
}
787805
}
788806
}

0 commit comments

Comments
 (0)