Skip to content

Commit

Permalink
[SPARK-42526][ML] Add Classifier.getNumClasses back
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add Classifier.getNumClasses back

### Why are the changes needed?
some famous libraries like `xgboost` happen to depend on this method, even though it is not a public API
so it should be nice to make xgboost integration better.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
update mima

Closes apache#40119 from zhengruifeng/ml_add_classifier_get_num_class.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 22, 2023
1 parent 054522b commit a6098be
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,25 @@ abstract class Classifier[
M <: ClassificationModel[FeaturesType, M]]
extends Predictor[FeaturesType, E, M] with ClassifierParams {

/**
* Get the number of classes. This looks in column metadata first, and if that is missing,
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
* by finding the maximum label value.
*
* Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
* such as in `extractLabeledPoints()`.
*
* @param dataset Dataset which contains a column [[labelCol]]
* @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
* is specified in the metadata, then maxNumClasses is ignored.
* @return number of classes
* @throws IllegalArgumentException if metadata does not specify numClasses, and the
* actual numClasses exceeds maxNumClasses
*/
protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses)
}

/** @group setParam */
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses = getNumClasses(dataset, $(labelCol))
val numClasses = getNumClasses(dataset)

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses = getNumClasses(dataset, $(labelCol))
val numClasses = getNumClasses(dataset)

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
Expand Down
2 changes: 0 additions & 2 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances"),

Expand Down

0 comments on commit a6098be

Please sign in to comment.