From b1e50a35ffd533e247e1274c4e03e1cf2c599389 Mon Sep 17 00:00:00 2001 From: Patrick Aboyoun Date: Wed, 22 Oct 2014 22:12:49 -0700 Subject: [PATCH] Improve performance of Cox PH using Efron's approximation for handling ties. --- src/main/java/hex/CoxPH.java | 134 ++++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 58 deletions(-) diff --git a/src/main/java/hex/CoxPH.java b/src/main/java/hex/CoxPH.java index 8f5ae19215..c8baa7887c 100644 --- a/src/main/java/hex/CoxPH.java +++ b/src/main/java/hex/CoxPH.java @@ -214,12 +214,13 @@ protected float[] score0(double[] data, float[] preds) { for (int t = 0; t < n_time; ++t) preds[t + 1] = (float) (risk * cumhaz_0[t]); for (int t = 0; t < n_time; ++t) { + final double cumhaz_0_t = cumhaz_0[t]; double var_cumhaz_2_t = 0; for (int j = 0; j < n_coef; ++j) { double sum = 0; for (int k = 0; k < n_coef; ++k) - sum += var_coef[j][k] * (full_data[k] * cumhaz_0[t] - var_cumhaz_2[k][t]); - var_cumhaz_2_t += (full_data[j] * cumhaz_0[t] - var_cumhaz_2[j][t]) * sum; + sum += var_coef[j][k] * (full_data[k] * cumhaz_0_t - var_cumhaz_2[t][k]); + var_cumhaz_2_t += (full_data[j] * cumhaz_0_t - var_cumhaz_2[t][j]) * sum; } preds[t + 1 + n_time] = (float) (risk * Math.sqrt(var_cumhaz_1[t] + var_cumhaz_2_t)); } @@ -255,7 +256,7 @@ protected void initStats(Frame source, DataInfo dinfo) { n_censor = MemoryManager.malloc8d(n_time); cumhaz_0 = MemoryManager.malloc8d(n_time); var_cumhaz_1 = MemoryManager.malloc8d(n_time); - var_cumhaz_2 = malloc2DArray(n_coef, n_time); + var_cumhaz_2 = malloc2DArray(n_time, n_coef); } protected void calcCounts(CoxPHTask coxMR) { @@ -294,22 +295,27 @@ protected double calcLoglik(CoxPHTask coxMR) { switch (parameters.ties) { case efron: for (int t = coxMR.sizeEvents.length - 1; t >= 0; --t) { - if (coxMR.sizeEvents[t] > 0) { - final double avgSize = coxMR.sizeEvents[t] / coxMR.countEvents[t]; - newLoglik += coxMR.sumLogRiskEvents[t]; + final double sizeEvents_t = coxMR.sizeEvents[t]; + if (sizeEvents_t > 0) { + final long countEvents_t = coxMR.countEvents[t]; + final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t]; + final double sumRiskEvents_t = coxMR.sumRiskEvents[t]; + final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; + final double avgSize = sizeEvents_t / countEvents_t; + newLoglik += sumLogRiskEvents_t; for (int j = 0; j < n_coef; ++j) - gradient[j] += coxMR.sumXEvents[j][t]; - for (long e = 0; e < coxMR.countEvents[t]; ++e) { - final double frac = ((double) e) / ((double) coxMR.countEvents[t]); - final double term = coxMR.rcumsumRisk[t] - frac * coxMR.sumRiskEvents[t]; + gradient[j] += coxMR.sumXEvents[t][j]; + for (long e = 0; e < countEvents_t; ++e) { + final double frac = ((double) e) / ((double) countEvents_t); + final double term = rcumsumRisk_t - frac * sumRiskEvents_t; newLoglik -= avgSize * Math.log(term); for (int j = 0; j < n_coef; ++j) { - final double djTerm = coxMR.rcumsumXRisk[j][t] - frac * coxMR.sumXRiskEvents[j][t]; + final double djTerm = coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]; final double djLogTerm = djTerm / term; gradient[j] -= avgSize * djLogTerm; for (int k = 0; k < n_coef; ++k) { - final double dkTerm = coxMR.rcumsumXRisk[k][t] - frac * coxMR.sumXRiskEvents[k][t]; - final double djkTerm = coxMR.rcumsumXXRisk[j][k][t] - frac * coxMR.sumXXRiskEvents[j][k][t]; + final double dkTerm = coxMR.rcumsumXRisk[t][k] - frac * coxMR.sumXRiskEvents[t][k]; + final double djkTerm = coxMR.rcumsumXXRisk[t][j][k] - frac * coxMR.sumXXRiskEvents[t][j][k]; hessian[j][k] -= avgSize * (djkTerm / term - (djLogTerm * (dkTerm / term))); } } @@ -319,17 +325,20 @@ protected double calcLoglik(CoxPHTask coxMR) { break; case breslow: for (int t = coxMR.sizeEvents.length - 1; t >= 0; --t) { - if (coxMR.sizeEvents[t] > 0) { - newLoglik += coxMR.sumLogRiskEvents[t]; - newLoglik -= coxMR.sizeEvents[t] * Math.log(coxMR.rcumsumRisk[t]); + final double sizeEvents_t = coxMR.sizeEvents[t]; + if (sizeEvents_t > 0) { + final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t]; + final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; + newLoglik += sumLogRiskEvents_t; + newLoglik -= sizeEvents_t * Math.log(rcumsumRisk_t); for (int j = 0; j < n_coef; ++j) { - final double dlogTerm = coxMR.rcumsumXRisk[j][t] / coxMR.rcumsumRisk[t]; - gradient[j] += coxMR.sumXEvents[j][t]; - gradient[j] -= coxMR.sizeEvents[t] * dlogTerm; + final double dlogTerm = coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t; + gradient[j] += coxMR.sumXEvents[t][j]; + gradient[j] -= sizeEvents_t * dlogTerm; for (int k = 0; k < n_coef; ++k) - hessian[j][k] -= coxMR.sizeEvents[t] * - (((coxMR.rcumsumXXRisk[j][k][t] / coxMR.rcumsumRisk[t]) - - (dlogTerm * (coxMR.rcumsumXRisk[k][t] / coxMR.rcumsumRisk[t])))); + hessian[j][k] -= sizeEvents_t * + (((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk_t) - + (dlogTerm * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk_t)))); } } } @@ -386,21 +395,26 @@ protected void calcCumhaz_0(CoxPHTask coxMR) { switch (parameters.ties) { case efron: for (int t = 0; t < coxMR.sizeEvents.length; ++t) { - if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) { - final double avgSize = coxMR.sizeEvents[t] / coxMR.countEvents[t]; + final double sizeEvents_t = coxMR.sizeEvents[t]; + final double sizeCensored_t = coxMR.sizeCensored[t]; + if (sizeEvents_t > 0 || sizeCensored_t > 0) { + final long countEvents_t = coxMR.countEvents[t]; + final double sumRiskEvents_t = coxMR.sumRiskEvents[t]; + final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; + final double avgSize = sizeEvents_t / countEvents_t; cumhaz_0[nz] = 0; var_cumhaz_1[nz] = 0; for (int j = 0; j < n_coef; ++j) - var_cumhaz_2[j][nz] = 0; - for (long e = 0; e < coxMR.countEvents[t]; ++e) { - final double frac = ((double) e) / ((double) coxMR.countEvents[t]); - final double haz = 1 / (coxMR.rcumsumRisk[t] - frac * coxMR.sumRiskEvents[t]); + var_cumhaz_2[nz][j] = 0; + for (long e = 0; e < countEvents_t; ++e) { + final double frac = ((double) e) / ((double) countEvents_t); + final double haz = 1 / (rcumsumRisk_t - frac * sumRiskEvents_t); final double haz_sq = haz * haz; cumhaz_0[nz] += avgSize * haz; var_cumhaz_1[nz] += avgSize * haz_sq; for (int j = 0; j < n_coef; ++j) - var_cumhaz_2[j][nz] += - avgSize * ((coxMR.rcumsumXRisk[j][t] - frac * coxMR.sumXRiskEvents[j][t]) * haz_sq); + var_cumhaz_2[nz][j] += + avgSize * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * haz_sq); } nz++; } @@ -408,11 +422,15 @@ protected void calcCumhaz_0(CoxPHTask coxMR) { break; case breslow: for (int t = 0; t < coxMR.sizeEvents.length; ++t) { - if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) { - cumhaz_0[nz] = coxMR.sizeEvents[t] / coxMR.rcumsumRisk[t]; - var_cumhaz_1[nz] = coxMR.sizeEvents[t] / (coxMR.rcumsumRisk[t] * coxMR.rcumsumRisk[t]); + final double sizeEvents_t = coxMR.sizeEvents[t]; + final double sizeCensored_t = coxMR.sizeCensored[t]; + if (sizeEvents_t > 0 || sizeCensored_t > 0) { + final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; + final double cumhaz_0_nz = sizeEvents_t / rcumsumRisk_t; + cumhaz_0[nz] = cumhaz_0_nz; + var_cumhaz_1[nz] = sizeEvents_t / (rcumsumRisk_t * rcumsumRisk_t); for (int j = 0; j < n_coef; ++j) - var_cumhaz_2[j][nz] = (coxMR.rcumsumXRisk[j][t] / coxMR.rcumsumRisk[t]) * cumhaz_0[nz]; + var_cumhaz_2[nz][j] = (coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t) * cumhaz_0_nz; nz++; } } @@ -425,7 +443,7 @@ protected void calcCumhaz_0(CoxPHTask coxMR) { cumhaz_0[t] = cumhaz_0[t - 1] + cumhaz_0[t]; var_cumhaz_1[t] = var_cumhaz_1[t - 1] + var_cumhaz_1[t]; for (int j = 0; j < n_coef; ++j) - var_cumhaz_2[j][t] = var_cumhaz_2[j][t - 1] + var_cumhaz_2[j][t]; + var_cumhaz_2[t][j] = var_cumhaz_2[t - 1][j] + var_cumhaz_2[t][j]; } } @@ -449,7 +467,7 @@ public Frame makeSurvfit(Key key, double x_new) { // FIXME surv.set(t, Math.exp(-cumhaz_1)); } for (int t = 0; t < n_time; ++t) { - final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[j][t]; + final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[t][j]; se_cumhaz.set(t, risk * Math.sqrt(var_cumhaz_1[t] + (gamma * var_coef[j][j] * gamma))); } final Frame fr = new Frame(key, new String[] {"time", "cumhaz", "se_cumhaz", "surv"}, vecs); @@ -557,14 +575,14 @@ protected void execImpl() { final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model.min_time, n_time, use_start_column, use_weights_column).doAll(dinfo._adaptedFrame); - if (i == 0) - model.calcCounts(coxMR); - final double newLoglik = model.calcLoglik(coxMR); if (newLoglik > oldLoglik) { model.calcModelStats(newCoef, newLoglik); model.calcCumhaz_0(coxMR); + if (i == 0) + model.calcCounts(coxMR); + if (newLoglik == 0) model.lre = - Math.log10(Math.abs(oldLoglik - newLoglik)); else @@ -667,11 +685,11 @@ protected void chunkInit(){ sumRiskEvents = MemoryManager.malloc8d(_n_time); sumLogRiskEvents = MemoryManager.malloc8d(_n_time); rcumsumRisk = MemoryManager.malloc8d(_n_time); - sumXEvents = malloc2DArray(n_coef, _n_time); - sumXRiskEvents = malloc2DArray(n_coef, _n_time); - rcumsumXRisk = malloc2DArray(n_coef, _n_time); - sumXXRiskEvents = malloc3DArray(n_coef, n_coef, _n_time); - rcumsumXXRisk = malloc3DArray(n_coef, n_coef, _n_time); + sumXEvents = malloc2DArray(_n_time, n_coef); + sumXRiskEvents = malloc2DArray(_n_time, n_coef); + rcumsumXRisk = malloc2DArray(_n_time, n_coef); + sumXXRiskEvents = malloc3DArray(_n_time, n_coef, n_coef); + rcumsumXXRisk = malloc3DArray(_n_time, n_coef, n_coef); } @Override @@ -723,14 +741,14 @@ protected void processRow(long gid, double [] nums, int ncats, int [] cats, doub final double x1 = jIsCat ? 1.0 : nums[jit - ncats]; final double xRisk = x1 * risk; if (event > 0) { - sumXEvents[j][t2] += weight * x1; - sumXRiskEvents[j][t2] += xRisk; + sumXEvents[t2][j] += weight * x1; + sumXRiskEvents[t2][j] += xRisk; } if (_use_start_column) { for (int t = t1; t <= t2; ++t) - rcumsumXRisk[j][t] += xRisk; + rcumsumXRisk[t][j] += xRisk; } else { - rcumsumXRisk[j][t2] += xRisk; + rcumsumXRisk[t2][j] += xRisk; } for (int kit = 0; kit < ntotal; ++kit) { final boolean kIsCat = kit < ncats; @@ -738,12 +756,12 @@ protected void processRow(long gid, double [] nums, int ncats, int [] cats, doub final double x2 = kIsCat ? 1.0 : nums[kit - ncats]; final double xxRisk = x2 * xRisk; if (event > 0) - sumXXRiskEvents[j][k][t2] += xxRisk; + sumXXRiskEvents[t2][j][k] += xxRisk; if (_use_start_column) { for (int t = t1; t <= t2; ++t) - rcumsumXXRisk[j][k][t] += xxRisk; + rcumsumXXRisk[t][j][k] += xxRisk; } else { - rcumsumXXRisk[j][k][t2] += xxRisk; + rcumsumXXRisk[t2][j][k] += xxRisk; } } } @@ -775,14 +793,14 @@ protected void postGlobal() { for (int t = rcumsumRisk.length - 2; t >= 0; --t) rcumsumRisk[t] += rcumsumRisk[t + 1]; - for (int j = 0; j < rcumsumXRisk.length; ++j) - for (int t = rcumsumXRisk[j].length - 2; t >= 0; --t) - rcumsumXRisk[j][t] += rcumsumXRisk[j][t + 1]; + for (int t = rcumsumXRisk.length - 2; t >= 0; --t) + for (int j = 0; j < rcumsumXRisk[t].length; ++j) + rcumsumXRisk[t][j] += rcumsumXRisk[t + 1][j]; - for (int j = 0; j < rcumsumXXRisk.length; ++j) - for (int k = 0; k < rcumsumXXRisk[j].length; ++k) - for (int t = rcumsumXXRisk[j][k].length - 2; t >= 0; --t) - rcumsumXXRisk[j][k][t] += rcumsumXXRisk[j][k][t + 1]; + for (int t = rcumsumXXRisk.length - 2; t >= 0; --t) + for (int j = 0; j < rcumsumXXRisk[t].length; ++j) + for (int k = 0; k < rcumsumXXRisk[t][j].length; ++k) + rcumsumXXRisk[t][j][k] += rcumsumXXRisk[t + 1][j][k]; } } }