Skip to content

Commit

Permalink
[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise …
Browse files Browse the repository at this point in the history
…error on bad input

In the Python API for Gaussian Mixture Model, predict() and predictSoft() methods should raise an error when the input argument is not an RDD.

Author: FlytxtRnD <[email protected]>

Closes apache#6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD
  • Loading branch information
FlytxtRnD authored and jkbradley committed May 15, 2015
1 parent f96b85a commit 8f4aaba
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def predict(self, x):
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))

def predictSoft(self, x):
"""
Expand All @@ -225,6 +228,9 @@ def predictSoft(self, x):
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self._weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))


class GaussianMixture(object):
Expand Down

0 comments on commit 8f4aaba

Please sign in to comment.