Skip to content

Commit

Permalink
[SPARK-23138][ML][DOC] Multiclass logistic regression summary example…
Browse files Browse the repository at this point in the history
… and user guide

## What changes were proposed in this pull request?

User guide and examples are updated to reflect multiclass logistic regression summary which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139).

I did not make a separate summary example, but added the summary code to the multiclass example that already existed. I don't see the need for a separate example for the summary.

## How was this patch tested?

Docs and examples only. Ran all examples locally using spark-submit.

Author: sethah <[email protected]>

Closes apache#20332 from sethah/multiclass_summary_example.
  • Loading branch information
sethah authored and Nick Pentreath committed Jan 30, 2018
1 parent 8b98324 commit 5056877
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 33 deletions.
22 changes: 11 additions & 11 deletions docs/ml-classification-regression.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark
The `spark.ml` implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics which are stored as `DataFrame` in
`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
`LogisticRegressionSummary` are annotated `@transient` and hence
only available on the driver.

<div class="codetabs">
Expand All @@ -97,10 +97,9 @@ only available on the driver.
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides a summary for a
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
This will likely change when multiclass classification is supported.
In the case of binary classification, certain additional metrics are
available, e.g. ROC curve. The binary summary can be accessed via the
`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).

Continuing the earlier example:

Expand All @@ -111,10 +110,9 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
provides a summary for a
[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
Support for multiclass model summaries will be added in the future.
In the case of binary classification, certain additional metrics are
available, e.g. ROC curve. The binary summary can be accessed via the
`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).

Continuing the earlier example:

Expand All @@ -125,7 +123,8 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
provides a summary for a
[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future.
In the case of binary classification, certain additional metrics are
available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary).

Continuing the earlier example:

Expand Down Expand Up @@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
**Examples**

The following example shows how to train a multiclass logistic regression
model with elastic net regularization.
model with elastic net regularization, as well as extract the multiclass
training summary for evaluating the model.

<div class="codetabs">

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
package org.apache.spark.examples.ml;

// $example on$
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
Expand Down Expand Up @@ -50,29 +49,23 @@ public static void main(String[] args) {
// $example on$
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
// example
LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary();

// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
}

// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
// classification problem.
BinaryLogisticRegressionSummary binarySummary =
(BinaryLogisticRegressionSummary) trainingSummary;

// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
Dataset<Row> roc = binarySummary.roc();
Dataset<Row> roc = trainingSummary.roc();
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());
System.out.println(trainingSummary.areaUnderROC());

// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
Dataset<Row> fMeasure = trainingSummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
Expand Down Expand Up @@ -48,6 +49,67 @@ public static void main(String[] args) {
// Print the coefficients and intercept for multinomial logistic regression
System.out.println("Coefficients: \n"
+ lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector());
LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();

// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
}

// for multiclass, we can inspect metrics on a per-label basis
System.out.println("False positive rate by label:");
int i = 0;
double[] fprLabel = trainingSummary.falsePositiveRateByLabel();
for (double fpr : fprLabel) {
System.out.println("label " + i + ": " + fpr);
i++;
}

System.out.println("True positive rate by label:");
i = 0;
double[] tprLabel = trainingSummary.truePositiveRateByLabel();
for (double tpr : tprLabel) {
System.out.println("label " + i + ": " + tpr);
i++;
}

System.out.println("Precision by label:");
i = 0;
double[] precLabel = trainingSummary.precisionByLabel();
for (double prec : precLabel) {
System.out.println("label " + i + ": " + prec);
i++;
}

System.out.println("Recall by label:");
i = 0;
double[] recLabel = trainingSummary.recallByLabel();
for (double rec : recLabel) {
System.out.println("label " + i + ": " + rec);
i++;
}

System.out.println("F-measure by label:");
i = 0;
double[] fLabel = trainingSummary.fMeasureByLabel();
for (double f : fLabel) {
System.out.println("label " + i + ": " + f);
i++;
}

double accuracy = trainingSummary.accuracy();
double falsePositiveRate = trainingSummary.weightedFalsePositiveRate();
double truePositiveRate = trainingSummary.weightedTruePositiveRate();
double fMeasure = trainingSummary.weightedFMeasure();
double precision = trainingSummary.weightedPrecision();
double recall = trainingSummary.weightedRecall();
System.out.println("Accuracy: " + accuracy);
System.out.println("FPR: " + falsePositiveRate);
System.out.println("TPR: " + truePositiveRate);
System.out.println("F-measure: " + fMeasure);
System.out.println("Precision: " + precision);
System.out.println("Recall: " + recall);
// $example off$

spark.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,44 @@
# Print the coefficients and intercept for multinomial logistic regression
print("Coefficients: \n" + str(lrModel.coefficientMatrix))
print("Intercept: " + str(lrModel.interceptVector))

trainingSummary = lrModel.summary

# Obtain the objective per iteration
objectiveHistory = trainingSummary.objectiveHistory
print("objectiveHistory:")
for objective in objectiveHistory:
print(objective)

# for multiclass, we can inspect metrics on a per-label basis
print("False positive rate by label:")
for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):
print("label %d: %s" % (i, rate))

print("True positive rate by label:")
for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):
print("label %d: %s" % (i, rate))

print("Precision by label:")
for i, prec in enumerate(trainingSummary.precisionByLabel):
print("label %d: %s" % (i, prec))

print("Recall by label:")
for i, rec in enumerate(trainingSummary.recallByLabel):
print("label %d: %s" % (i, rec))

print("F-measure by label:")
for i, f in enumerate(trainingSummary.fMeasureByLabel()):
print("label %d: %s" % (i, f))

accuracy = trainingSummary.accuracy
falsePositiveRate = trainingSummary.weightedFalsePositiveRate
truePositiveRate = trainingSummary.weightedTruePositiveRate
fMeasure = trainingSummary.weightedFMeasure()
precision = trainingSummary.weightedPrecision
recall = trainingSummary.weightedRecall
print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
% (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))
# $example off$

spark.stop()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.ml.classification.LogisticRegression
// $example off$
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
Expand Down Expand Up @@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample {
// $example on$
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
// example
val trainingSummary = lrModel.summary
val trainingSummary = lrModel.binarySummary

// Obtain the objective per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(loss => println(loss))

// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]

// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
val roc = binarySummary.roc
val roc = trainingSummary.roc
roc.show()
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")

// Set the model threshold to maximize F-Measure
val fMeasure = binarySummary.fMeasureByThreshold
val fMeasure = trainingSummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
.select("threshold").head().getDouble(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample {
// Print the coefficients and intercept for multinomial logistic regression
println(s"Coefficients: \n${lrModel.coefficientMatrix}")
println(s"Intercepts: \n${lrModel.interceptVector}")

val trainingSummary = lrModel.summary

// Obtain the objective per iteration
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(println)

// for multiclass, we can inspect metrics on a per-label basis
println("False positive rate by label:")
trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
println(s"label $label: $rate")
}

println("True positive rate by label:")
trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
println(s"label $label: $rate")
}

println("Precision by label:")
trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) =>
println(s"label $label: $prec")
}

println("Recall by label:")
trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) =>
println(s"label $label: $rec")
}


println("F-measure by label:")
trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) =>
println(s"label $label: $f")
}

val accuracy = trainingSummary.accuracy
val falsePositiveRate = trainingSummary.weightedFalsePositiveRate
val truePositiveRate = trainingSummary.weightedTruePositiveRate
val fMeasure = trainingSummary.weightedFMeasure
val precision = trainingSummary.weightedPrecision
val recall = trainingSummary.weightedRecall
println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" +
s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall")
// $example off$

spark.stop()
Expand Down

0 comments on commit 5056877

Please sign in to comment.