Skip to content

Commit

Permalink
[jvm-packages] Expose prediction feature contribution on the Java side (dmlc#2441)
Browse files Browse the repository at this point in the history

* Exposed prediction feature contribution on the Java side

* was not supplying the newly added argument

* Exposed from Scala-side as well

* formatting (keep declaration in one line unless exceeding 100 chars)
  • Loading branch information
edi-bice authored and CodingCat committed Jun 28, 2017
1 parent d01a310 commit 2911597
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public void update(DMatrix dtrain, int iter) throws XGBoostError {
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
  // outputMargin=true: a customized objective computes gradients from raw scores.
  // Leaf-index and feature-contribution outputs are disabled here.
  float[][] predicts = this.predict(dtrain, true, 0, false, false);
  List<float[]> gradients = obj.getGradient(predicts, dtrain);
  boost(dtrain, gradients.get(0), gradients.get(1));
}
Expand Down Expand Up @@ -219,19 +219,24 @@ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eva
* @param outputMargin output margin
* @param treeLimit limit number of trees, 0 means all trees.
* @param predLeaf whether to output the index of the leaf each record falls into
* @param predContribs whether to output per-feature contributions (plus bias) instead of predictions
* @return predict results
*/
private synchronized float[][] predict(DMatrix data,
boolean outputMargin,
int treeLimit,
boolean predLeaf) throws XGBoostError {
boolean predLeaf,
boolean predContribs) throws XGBoostError {
// Option bitmask forwarded to the native predictor: 1 = margin, 2 = leaf, 4 = contributions.
int optionMask = 0;
if (outputMargin) {
optionMask = 1;
}
if (predLeaf) {
// NOTE(review): plain assignment clobbers the margin bit set above; if these flags
// are ever meant to combine, this should be optionMask |= 2 — confirm intent.
optionMask = 2;
}
if (predContribs) {
// NOTE(review): likewise overrides any earlier flag rather than OR-ing it in.
optionMask = 4;
}
// Native call fills rawPredicts[0] with a flat result buffer.
float[][] rawPredicts = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
Expand All @@ -256,7 +261,19 @@ private synchronized float[][] predict(DMatrix data,
* @throws XGBoostError
*/
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
  // predLeaf=true: native predictor returns the index of the leaf each record
  // falls into, one value per tree; contributions disabled.
  return this.predict(data, false, treeLimit, true, false);
}

/**
 * Output feature contributions toward predictions of given data
 *
 * @param data The input data.
 * @param treeLimit Number of trees to include, 0 means all trees.
 * @return The feature contributions and bias.
 * @throws XGBoostError native error
 */
public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError {
  // predLeaf=false, predContribs=true: only the contribution flag is intended here.
  // (The original passed predLeaf=true as well; it only behaved because predict()
  // lets the contribution flag override the leaf flag — the leaf flag was never meant.)
  return this.predict(data, false, treeLimit, false, true);
}

/**
Expand All @@ -267,7 +284,7 @@ public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data) throws XGBoostError {
  // Plain predictions: no margin output, all trees, leaf/contribution outputs off.
  return this.predict(data, false, 0, false, false);
}

/**
Expand All @@ -278,7 +295,7 @@ public float[][] predict(DMatrix data) throws XGBoostError {
* @return predict results
*/
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
  // All trees (treeLimit=0); leaf and contribution outputs disabled.
  return this.predict(data, outputMargin, 0, false, false);
}

/**
Expand All @@ -290,7 +307,7 @@ public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError
* @return predict results
*/
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
  // Standard prediction path: leaf and contribution outputs disabled.
  return this.predict(data, outputMargin, treeLimit, false, false);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,23 @@ class Booster private[xgboost4j](private var booster: JBooster)
* @throws XGBoostError native error
*/
@throws(classOf[XGBoostError])
def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
  // Delegates to the underlying Java Booster; returns leaf indices per tree for each row.
  booster.predictLeaf(data.jDMatrix, treeLimit)
}

/**
* Output feature contributions toward predictions of given data.
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return The feature contributions and bias, one array per row.
* @throws XGBoostError native error
*/
@throws(classOf[XGBoostError])
def predictContrib(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
// Delegates to the underlying Java Booster implementation.
booster.predictContrib(data.jDMatrix, treeLimit)
}

/**
* save model to modelPath
*
Expand Down

0 comments on commit 2911597

Please sign in to comment.