Skip to content

Commit

Permalink
[SPARK-15738][PYSPARK][ML] Adding Pyspark ml RFormula __str__ method …
Browse files Browse the repository at this point in the history
…similar to Scala API

## What changes were proposed in this pull request?
Adds `__str__` to RFormula and RFormulaModel, showing the set formula param and the resolved formula, respectively. This is already present in the Scala API and was found to be missing from PySpark during the Spark 2.0 coverage review.

## How was this patch tested?
Ran the pyspark-ml tests locally.

Author: Bryan Cutler <[email protected]>

Closes apache#13481 from BryanCutler/pyspark-ml-rformula_str-SPARK-15738.
  • Loading branch information
BryanCutler authored and yanboliang committed Jun 10, 2016
1 parent 254bc8c commit 7d7a0a5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class RFormula(override val uid: String)

// Delegates to defaultCopy, passing along the supplied extra ParamMap.
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)

override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
}

@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,19 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
* @param hasIntercept whether the formula specifies fitting with an intercept.
*/
private[ml] case class ResolvedRFormula(
label: String, terms: Seq[Seq[String]], hasIntercept: Boolean)
label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) {

/**
 * Renders this resolved formula as
 * `ResolvedRFormula(label=..., terms=[...], hasIntercept=...)`.
 *
 * Terms are comma-separated inside square brackets. A term made up of more than
 * one name is additionally wrapped in braces (for example `{a,b}`); any other
 * term is written as the plain concatenation of its elements.
 */
override def toString: String = {
  // Brace-wrap only terms with more than one element.
  def render(term: Seq[String]): String =
    if (term.length > 1) term.mkString("{", ",", "}") else term.mkString

  val renderedTerms = terms.map(render).mkString("[", ",", "]")
  s"ResolvedRFormula(label=$label, terms=$renderedTerms, hasIntercept=$hasIntercept)"
}
}

/**
* R formula terms. See the R formula docs here for more information:
Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2528,6 +2528,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
True
>>> loadedRF.getLabelCol() == rf.getLabelCol()
True
>>> str(loadedRF)
'RFormula(y ~ x + s) (uid=...)'
>>> modelPath = temp_path + "/rFormulaModel"
>>> model.save(modelPath)
>>> loadedModel = RFormulaModel.load(modelPath)
Expand All @@ -2542,6 +2544,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|0.0|0.0| a|[0.0,1.0]| 0.0|
+---+---+---+---------+-----+
...
>>> str(loadedModel)
'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)'
.. versionadded:: 1.5.0
"""
Expand Down Expand Up @@ -2586,6 +2590,10 @@ def getFormula(self):
# Wraps the given Java model object in an RFormulaModel and returns it.
def _create_model(self, java_model):
return RFormulaModel(java_model)

def __str__(self):
    """
    Return ``RFormula(<formula>) (uid=<uid>)``.

    The formula portion is left empty when the ``formula`` param is not defined.
    """
    if self.isDefined(self.formula):
        formula_text = self.getFormula()
    else:
        formula_text = ""
    return "RFormula(%s) (uid=%s)" % (formula_text, self.uid)


class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
Expand All @@ -2597,6 +2605,10 @@ class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
.. versionadded:: 1.5.0
"""

def __str__(self):
    """
    Return ``RFormulaModel(<resolved formula>) (uid=<uid>)``.

    The resolved formula text is whatever the wrapped Java object returns for
    ``resolvedFormula``.
    """
    return "RFormulaModel(%s) (uid=%s)" % (self._call_java("resolvedFormula"), self.uid)


@inherit_doc
class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable,
Expand Down

0 comments on commit 7d7a0a5

Please sign in to comment.