Add numerical offset specification to Cox PH modeling.
aboyoun committed Oct 29, 2014
1 parent d93ac13 commit 049b799
Showing 6 changed files with 133 additions and 38 deletions.
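
For orientation (not part of the commit itself): an offset in a Cox proportional hazards model is a term that enters the linear predictor with its coefficient fixed at 1 rather than estimated, so the hazard for row i takes the standard form

    \lambda(t \mid x_i) = \lambda_0(t)\,\exp\!\Big( \sum_j \beta_j x_{ij} + \sum_k o_{ik} \Big)

where the o_{ik} are the values of the offset columns and only the \beta_j are fit. The changes below thread such offset columns from the R interface (h2o.coxph) through to the Java back end (hex.CoxPH).
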
22 changes: 13 additions & 9 deletions R/h2o-package/R/Algorithms.R
@@ -23,27 +23,28 @@ h2o.coxph.control <- function(lre = 9, iter.max = 20, ...)

list(lre = lre, iter.max = as.integer(iter.max))
}
h2o.coxph <- function(x, y, data, key = "", weights, ties = c("efron", "breslow"),
init = 0, control = h2o.coxph.control(...), ...)
h2o.coxph <- function(x, y, data, key = "", weights = NULL, offset = NULL,
ties = c("efron", "breslow"), init = 0,
control = h2o.coxph.control(...), ...)
{
if (!is(data, "H2OParsedData"))
stop("'data' must be an H2O parsed dataset")

cnames <- colnames(data)
if (!is.character(x) || !all(x %in% cnames))
if (!is.character(x) || length(x) == 0L || !all(x %in% cnames))
stop("'x' must be a character vector specifying column names from 'data'")

ny <- length(y)
if (!is.character(y) || ny < 2L || ny > 3L || !all(y %in% cnames))
stop("'y' must be a character vector of column names from 'data' ",
"specifying a (start, stop, event) triplet or (stop, event) couplet")

useWeights <- !missing(weights)
if (useWeights) {
if (!is.character(weights) || length(weights) != 1L || !(weights %in% cnames))
stop("'weights' must be missing or a character string specifying a column name from 'data'")
} else
weights <- NULL
if (!is.null(weights) &&
(!is.character(weights) || length(weights) != 1L || !(weights %in% cnames)))
stop("'weights' must be NULL or a character string specifying a column name from 'data'")

if (!is.null(offset) && (!is.character(offset) || !all(offset %in% cnames)))
stop("'offset' must be NULL or a character vector specifying a column names from 'data'")

if (!is.character(key) && length(key) == 1L)
stop("'key' must be a character string")
@@ -65,6 +66,7 @@ h2o.coxph <- function(x, y, data, key = "", weights, ties = c("efron", "breslow"
event_column = y[ny],
x_columns = match(x, cnames) - 1L,
weights_column = weights,
offset_columns = if (is.null(offset)) offset else match(offset, cnames) - 1L,
ties = ties,
init = init,
lre_min = control$lre,
@@ -86,6 +88,8 @@ h2o.coxph <- function(x, y, data, key = "", weights, ties = c("efron", "breslow"
means = structure(c(unlist(res[[3L]]$x_mean_cat),
unlist(res[[3L]]$x_mean_num)),
names = coef_names),
means.offset = structure(unlist(res[[3L]]$mean_offset),
names = unlist(res[[3L]]$offset_names)),
method = ties,
n = res[[3L]]$n,
nevent = res[[3L]]$total_event,
3 changes: 2 additions & 1 deletion R/h2o-package/R/Classes.R
@@ -146,7 +146,8 @@ survfit.H2OCoxPHModel <-
function(formula, newdata, conf.int = 0.95,
conf.type = c("log", "log-log", "plain", "none"), ...) {
if (missing(newdata))
newdata <- as.data.frame(as.list(formula@model$means))
newdata <- as.data.frame(c(as.list(formula@model$means),
as.list(formula@model$means.offset)))
if (is.data.frame(newdata))
capture.output(newdata <- as.h2o(formula@data@h2o, newdata, header = TRUE))
conf.type <- match.arg(conf.type)
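
As a practical reading of the Classes.R change above: when survfit is called on an H2OCoxPHModel without newdata, the default prediction row is now built from the covariate means together with the offset means. A minimal sketch of the equivalence, assuming fit is an H2OCoxPHModel returned by h2o.coxph() with an offset specified:

    # Default call: survival curve evaluated at the model's mean covariate
    # and mean offset values
    surv.default  <- survfit(fit)

    # Explicit construction of the same one-row newdata, mirroring the code above
    surv.explicit <- survfit(fit,
                             newdata = as.data.frame(c(as.list(fit@model$means),
                                                       as.list(fit@model$means.offset))))
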
7 changes: 5 additions & 2 deletions R/h2o-package/man/h2o.coxph.Rd
@@ -20,8 +20,9 @@ H2O: Cox Proportional Hazards Models
Fit a Cox Proportional Hazards Model.
}
\usage{
h2o.coxph(x, y, data, key = "", weights, ties = c("efron", "breslow"),
init = 0, control = h2o.coxph.control(...), ...)
h2o.coxph(x, y, data, key = "", weights = NULL, offset = NULL,
ties = c("efron", "breslow"), init = 0,
control = h2o.coxph.control(...), ...)

h2o.coxph.control(lre = 9, iter.max = 20, ...)

@@ -46,6 +47,8 @@ h2o.coxph.control(lre = 9, iter.max = 20, ...)
If none is given, a key will automatically be generated.}
\item{weights}{An optional character string representing the case weights in
the model.}
\item{offset}{An optional character vector representing the offset terms in
the model.}
\item{ties}{A character string denoting which approximation method for
handling ties should be used in the partial likelihood;
one of either \code{"efron"} or \code{"breslow"}.}
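
To illustrate the documented interface, here is a usage sketch in the spirit of the unit tests below; it assumes a parsed H2O frame bladder.h2o with the same columns as survival::bladder:

    library(h2o)
    library(survival)

    # H2O Cox PH fit with two predictors and one numeric offset column
    fit.h2o <- h2o.coxph(x = c("enum", "rx"), y = c("stop", "event"),
                         data = bladder.h2o, offset = "size")

    # Reference fit: the survival package expresses the same model with an offset() term
    fit.ref <- coxph(Surv(stop, event) ~ enum + rx + offset(size), data = bladder)
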
30 changes: 30 additions & 0 deletions R/tests/testdir_algos/coxph/runit_CoxPH_bladder.R
@@ -35,6 +35,21 @@ test.CoxPH.bladder <- function(conn) {
bladder.coxph <- coxph(Surv(stop, event) ~ enum + rx + number + size, data = bladder, weights = bladder$id)
checkCoxPHModel(bladder.coxph.h2o, bladder.coxph)

Log.info("H2O Cox PH Model of bladder Data Set using Efron's Approximation; 2 predictors and 1 offset\n")
bladder.coxph.h2o <-
h2o.coxph(x = c("enum", "rx"), y = c("stop", "event"), data = bladder.h2o, key = "bladmod.h2o",
offset = "size")
bladder.coxph <- coxph(Surv(stop, event) ~ enum + rx + offset(size), data = bladder)
checkCoxPHModel(bladder.coxph.h2o, bladder.coxph)

Log.info("H2O Cox PH Model of bladder Data Set using Efron's Approximation; 2 predictors and 2 offsets\n")
bladder.coxph.h2o <-
h2o.coxph(x = c("enum", "rx"), y = c("stop", "event"), data = bladder.h2o, key = "bladmod.h2o",
offset = c("number", "size"), weights = "id")
bladder.coxph <- coxph(Surv(stop, event) ~ enum + rx + offset(number) + offset(size), data = bladder,
weights = bladder$id)
checkCoxPHModel(bladder.coxph.h2o, bladder.coxph)

Log.info("H2O Cox PH Model of bladder Data Set using Efron's Approximation; init = 0.2\n")
bladder.coxph.h2o <-
h2o.coxph(x = "size", y = c("stop", "event"), data = bladder.h2o,
@@ -81,6 +96,21 @@ test.CoxPH.bladder <- function(conn) {
coxph(Surv(stop, event) ~ enum + rx + number + size, data = bladder, weights = bladder$id, ties = "breslow")
checkCoxPHModel(bladder.coxph.h2o, bladder.coxph)

Log.info("H2O Cox PH Model of bladder Data Set using Breslow's Approximation; 2 predictors and 1 offset\n")
bladder.coxph.h2o <-
h2o.coxph(x = c("enum", "rx"), y = c("stop", "event"), data = bladder.h2o, key = "bladmod.h2o",
offset = "size", ties = "breslow")
bladder.coxph <- coxph(Surv(stop, event) ~ enum + rx + offset(size), data = bladder, ties = "breslow")
checkCoxPHModel(bladder.coxph.h2o, bladder.coxph)

Log.info("H2O Cox PH Model of bladder Data Set using Breslow's Approximation; 2 predictors and 2 offsets\n")
bladder.coxph.h2o <-
h2o.coxph(x = c("enum", "rx"), y = c("stop", "event"), data = bladder.h2o, key = "bladmod.h2o",
offset = c("number", "size"), weights = "id", ties = "breslow")
bladder.coxph <- coxph(Surv(stop, event) ~ enum + rx + offset(number) + offset(size), data = bladder,
weights = bladder$id, ties = "breslow")
checkCoxPHModel(bladder.coxph.h2o, bladder.coxph, tolerance = 1e-7)

Log.info("H2O Cox PH Model of bladder Data Set using Breslow's Approximation; init = 0.2\n")
bladder.coxph.h2o <-
h2o.coxph(x = "size", y = c("stop", "event"), data = bladder.h2o,
31 changes: 31 additions & 0 deletions R/tests/testdir_algos/coxph/runit_CoxPH_heart.R
@@ -34,6 +34,21 @@ test.CoxPH.heart <- function(conn) {
heart.coxph <- coxph(Surv(start, stop, event) ~ transplant + age + year + surgery, data = heart, weights = heart$id)
checkCoxPHModel(heart.coxph.h2o, heart.coxph)

Log.info("H2O Cox PH Model of heart Data Set using Efron's Approximation; 2 predictors and 1 offset\n")
heart.coxph.h2o <-
h2o.coxph(x = c("transplant", "surgery"), y = c("start", "stop", "event"), data = heart.h2o,
key = "heartmod.h2o", offset = "age")
heart.coxph <- coxph(Surv(start, stop, event) ~ transplant + surgery + offset(age), data = heart)
checkCoxPHModel(heart.coxph.h2o, heart.coxph)

Log.info("H2O Cox PH Model of heart Data Set using Efron's Approximation; 2 predictors and 2 offsets\n")
heart.coxph.h2o <-
h2o.coxph(x = c("transplant", "surgery"), y = c("start", "stop", "event"), data = heart.h2o,
key = "heartmod.h2o", weights = "id", offset = c("age", "year"))
heart.coxph <- coxph(Surv(start, stop, event) ~ transplant + surgery + offset(age) + offset(year), data = heart,
weights = heart$id)
checkCoxPHModel(heart.coxph.h2o, heart.coxph)

Log.info("H2O Cox PH Model of heart Data Set using Efron's Approximation; init = 0.05\n")
heart.coxph.h2o <-
h2o.coxph(x = "age", y = c("start", "stop", "event"), data = heart.h2o,
@@ -81,6 +96,22 @@ test.CoxPH.heart <- function(conn) {
weights = heart$id, ties = "breslow")
checkCoxPHModel(heart.coxph.h2o, heart.coxph)

Log.info("H2O Cox PH Model of heart Data Set using Breslow's Approximation; 2 predictors and 1 offset\n")
heart.coxph.h2o <-
h2o.coxph(x = c("transplant", "surgery"), y = c("start", "stop", "event"), data = heart.h2o,
key = "heartmod.h2o", offset = "age", ties = "breslow")
heart.coxph <- coxph(Surv(start, stop, event) ~ transplant + surgery + offset(age), data = heart,
ties = "breslow")
checkCoxPHModel(heart.coxph.h2o, heart.coxph)

Log.info("H2O Cox PH Model of heart Data Set using Breslow's Approximation; 2 predictors and 2 offsets\n")
heart.coxph.h2o <-
h2o.coxph(x = c("transplant", "surgery"), y = c("start", "stop", "event"), data = heart.h2o,
key = "heartmod.h2o", weights = "id", offset = c("age", "year"), ties = "breslow")
heart.coxph <- coxph(Surv(start, stop, event) ~ transplant + surgery + offset(age) + offset(year), data = heart,
weights = heart$id, ties = "breslow")
checkCoxPHModel(heart.coxph.h2o, heart.coxph)

Log.info("H2O Cox PH Model of heart Data Set using Breslow's Approximation; init = 0.05\n")
heart.coxph.h2o <-
h2o.coxph(x = "age", y = c("start", "stop", "event"), data = heart.h2o,
78 changes: 52 additions & 26 deletions src/main/java/hex/CoxPH.java
@@ -38,6 +38,9 @@ public class CoxPH extends Job {
@API(help="Weights Column", required=false, filter=CoxPHVecSelect.class, json=true)
public Vec weights_column = null;

@API(help="Offset Columns", required=false, filter=CoxPHMultiVecSelect.class, json=true)
public int[] offset_columns;

@API(help="Method for Handling Ties", required=true, filter=Default.class, json=true)
public CoxPHTies ties = CoxPHTies.efron;

@@ -120,6 +123,10 @@ public static class CoxPHModel extends Model implements Job.Progress {
double[] x_mean_cat;
@API(help = "x weighted mean vector for numeric variables")
double[] x_mean_num;
@API(help = "unweighted mean vector for numeric offsets")
double[] mean_offset;
@API(help = "names of offsets")
String[] offset_names;
@API(help = "n")
long n;
@API(help = "number of rows with missing values")
@@ -175,15 +182,17 @@ public String[] classNames() {

@Override
protected float[] score0(double[] data, float[] preds) {
final int n_time = time.length;
final int n_coef = coef.length;
final int n_cats = data_info._cats;
final int n_num = data_info._nums;
final int n_data = n_cats + n_num;
final int numStart = data_info.numStart();
boolean catsAllNA = true;
boolean catsHasNA = false;
boolean numsHasNA = false;
final int n_offsets = (parameters.offset_columns == null) ? 0 : parameters.offset_columns.length;
final int n_time = time.length;
final int n_coef = coef.length;
final int n_cats = data_info._cats;
final int n_nums = data_info._nums;
final int n_data = n_cats + n_nums;
final int n_full = n_coef + n_offsets;
final int numStart = data_info.numStart();
boolean catsAllNA = true;
boolean catsHasNA = false;
boolean numsHasNA = false;
for (int j = 0; j < n_cats; ++j) {
catsAllNA &= Double.isNaN(data[j]);
catsHasNA |= Double.isNaN(data[j]);
@@ -194,19 +203,21 @@ protected float[] score0(double[] data, float[] preds) {
for (int i = 1; i <= 2 * n_time; ++i)
preds[i] = Float.NaN;
} else {
double[] full_data = MemoryManager.malloc8d(n_coef);
double[] full_data = MemoryManager.malloc8d(n_full);
for (int j = 0; j < n_cats; ++j)
if (Double.isNaN(data[j])) {
final int kst = data_info._catOffsets[j];
final int klen = data_info._catOffsets[j+1] - kst;
System.arraycopy(x_mean_cat, kst, full_data, kst, klen);
} else if (data[j] != 0)
full_data[data_info._catOffsets[j] + (int) (data[j] - 1)] = 1;
for (int j = 0; j < n_num; ++j)
for (int j = 0; j < n_nums; ++j)
full_data[numStart + j] = data[n_cats + j] - data_info._normSub[j];
double logRisk = 0;
for (int j = 0; j < n_coef; ++j)
logRisk += full_data[j] * coef[j];
for (int j = n_coef; j < full_data.length; ++j)
logRisk += full_data[j];
final double risk = Math.exp(logRisk);
for (int t = 0; t < n_time; ++t)
preds[t + 1] = (float) (risk * cumhaz_0[t]);
@@ -229,8 +240,11 @@ protected float[] score0(double[] data, float[] preds) {
protected void initStats(final Frame source, final DataInfo dinfo) {
n = source.numRows();
data_info = dinfo;
final int n_coef = data_info.fullN();
coef_names = data_info.coefNames();
final int n_offsets = (parameters.offset_columns == null) ? 0 : parameters.offset_columns.length;
final int n_coef = data_info.fullN() - n_offsets;
final String[] coefNames = data_info.coefNames();
coef_names = new String[n_coef];
System.arraycopy(coefNames, 0, coef_names, 0, n_coef);
coef = MemoryManager.malloc8d(n_coef);
exp_coef = MemoryManager.malloc8d(n_coef);
exp_neg_coef = MemoryManager.malloc8d(n_coef);
@@ -239,6 +253,11 @@ protected void initStats(final Frame source, final DataInfo dinfo) {
gradient = MemoryManager.malloc8d(n_coef);
hessian = malloc2DArray(n_coef, n_coef);
var_coef = malloc2DArray(n_coef, n_coef);
x_mean_cat = MemoryManager.malloc8d(n_coef - (data_info._nums - n_offsets));
x_mean_num = MemoryManager.malloc8d(data_info._nums - n_offsets);
mean_offset = MemoryManager.malloc8d(n_offsets);
offset_names = new String[n_offsets];
System.arraycopy(coefNames, n_coef, offset_names, 0, n_offsets);

final Vec start_column = source.vec(source.numCols() - 3);
final Vec stop_column = source.vec(source.numCols() - 2);
@@ -259,12 +278,11 @@ protected void initStats(final Frame source, final DataInfo dinfo) {
protected void calcCounts(final CoxPHTask coxMR) {
n_missing = n - coxMR.n;
n = coxMR.n;
x_mean_cat = coxMR.sumWeightedCatX.clone();
for (int j = 0; j < x_mean_cat.length; j++)
x_mean_cat[j] /= coxMR.sumWeights;
x_mean_num = coxMR._dinfo._normSub.clone();
x_mean_cat[j] = coxMR.sumWeightedCatX[j] / coxMR.sumWeights;
for (int j = 0; j < x_mean_num.length; j++)
x_mean_num[j] += coxMR.sumWeightedNumX[j] / coxMR.sumWeights;
x_mean_num[j] = coxMR._dinfo._normSub[j] + coxMR.sumWeightedNumX[j] / coxMR.sumWeights;
System.arraycopy(coxMR._dinfo._normSub, x_mean_num.length, mean_offset, 0, mean_offset.length);
int nz = 0;
for (int t = 0; t < coxMR.countEvents.length; ++t) {
total_event += coxMR.countEvents[t];
@@ -571,15 +589,16 @@ protected void init() {
n_resp++;
if (start_column != null)
n_resp++;
final DataInfo dinfo = new DataInfo(source, n_resp, true, false, DataInfo.TransformType.DEMEAN);
final DataInfo dinfo = new DataInfo(source, n_resp, false, false, DataInfo.TransformType.DEMEAN);
model = new CoxPHModel(this, dest(), source._key, source, null);
model.initStats(source, dinfo);
}

@Override
protected void execImpl() {
final DataInfo dinfo = model.data_info;
final int n_coef = dinfo.fullN();
final int n_offsets = (model.parameters.offset_columns == null) ? 0 : model.parameters.offset_columns.length;
final int n_coef = dinfo.fullN() - n_offsets;
final double[] step = MemoryManager.malloc8d(n_coef);
final double[] oldCoef = MemoryManager.malloc8d(n_coef);
final double[] newCoef = MemoryManager.malloc8d(n_coef);
@@ -594,7 +613,7 @@ protected void execImpl() {
for (int i = 0; i <= iter_max; ++i) {
model.iter = i;

final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model.min_time, n_time,
final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model.min_time, n_time, n_offsets,
has_start_column, has_weights_column).doAll(dinfo._adaptedFrame);

final double newLoglik = model.calcLoglik(coxMR);
@@ -645,16 +664,19 @@ private Frame getSubframe() {
final boolean use_start_column = (start_column != null);
final boolean use_weights_column = (weights_column != null);
final int x_ncol = x_columns.length;
int ncol = x_ncol + 2;
final int offset_ncol = offset_columns == null ? 0 : offset_columns.length;
int ncol = x_ncol + offset_ncol + 2;
if (use_weights_column)
ncol++;
if (use_start_column)
ncol++;
final String[] names = new String[ncol];
for (int j = 0; j < x_ncol; ++j)
names[j] = source.names()[x_columns[j]];
for (int j = 0; j < offset_ncol; ++j)
names[x_ncol + j] = source.names()[offset_columns[j]];
if (use_weights_column)
names[x_ncol] = source.names()[source.find(weights_column)];
names[x_ncol + offset_ncol] = source.names()[source.find(weights_column)];
if (use_start_column)
names[ncol - 3] = source.names()[source.find(start_column)];
names[ncol - 2] = source.names()[source.find(stop_column)];
@@ -666,6 +688,7 @@ protected static class CoxPHTask extends FrameTask<CoxPHTask> {
private final double[] _beta;
private final int _n_time;
private final long _min_time;
private final int _n_offsets;
private final boolean _has_start_column;
private final boolean _has_weights_column;

@@ -688,19 +711,20 @@ protected static class CoxPHTask extends FrameTask<CoxPHTask> {
protected double[][][] rcumsumXXRisk;

CoxPHTask(Key jobKey, DataInfo dinfo, final double[] beta, final long min_time, final int n_time,
final boolean has_start_column, final boolean has_weights_column) {
final int n_offsets, final boolean has_start_column, final boolean has_weights_column) {
super(jobKey, dinfo);
_beta = beta;
_n_time = n_time;
_min_time = min_time;
_n_offsets = n_offsets;
_has_start_column = has_start_column;
_has_weights_column = has_weights_column;
}

@Override
protected void chunkInit(){
final int n_coef = _beta.length;
sumWeightedCatX = MemoryManager.malloc8d(n_coef - _dinfo._nums);
sumWeightedCatX = MemoryManager.malloc8d(n_coef - (_dinfo._nums - _n_offsets));
sumWeightedNumX = MemoryManager.malloc8d(_dinfo._nums);
sizeRiskSet = MemoryManager.malloc8d(_n_time);
sizeCensored = MemoryManager.malloc8d(_n_time);
@@ -736,8 +760,10 @@ protected void processRow(long gid, double [] nums, int ncats, int [] cats, doub
double logRisk = 0;
for (int j = 0; j < ncats; ++j)
logRisk += _beta[cats[j]];
for (int j = 0; j < nums.length; ++j)
for (int j = 0; j < nums.length - _n_offsets; ++j)
logRisk += nums[j] * _beta[numStart + j];
for (int j = nums.length - _n_offsets; j < nums.length; ++j)
logRisk += nums[j];
final double risk = weight * Math.exp(logRisk);
logRisk *= weight;
if (event > 0) {
@@ -757,7 +783,7 @@ protected void processRow(long gid, double [] nums, int ncats, int [] cats, doub
rcumsumRisk[t2] += risk;
}

final int ntotal = ncats + nums.length;
final int ntotal = ncats + (nums.length - _n_offsets);
final int numStartIter = numStart - ncats;
for (int jit = 0; jit < ntotal; ++jit) {
final boolean jIsCat = jit < ncats;
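
Schematically, the scoring path in score0 above forms the linear predictor from the estimated coefficients plus the offset columns with an implicit coefficient of 1, then scales the baseline cumulative hazard by the resulting risk. This is a reading of the code, with numeric terms mean-centered per the DEMEAN transform:

    \eta_i = \sum_{j=1}^{n_{\mathrm{coef}}} \beta_j \tilde{x}_{ij}
             + \sum_{k=1}^{n_{\mathrm{offsets}}} \tilde{o}_{ik},
    \qquad \mathrm{preds}[t{+}1] = \exp(\eta_i)\,\hat{\Lambda}_0(t)

where tildes denote mean-subtracted numeric values and \hat{\Lambda}_0 is the stored baseline cumulative hazard (cumhaz_0). CoxPHTask.processRow follows the same convention during fitting, which is why the offset columns contribute to the risk but not to the gradient or Hessian dimensions.
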
