[SPARK-16348][ML][MLLIB][PYTHON] Use full classpaths for pyspark ML JVM calls

## What changes were proposed in this pull request?

Issue: Referring to a JVM class by its short name (e.g. `jvm.MLSerDe`) only resolves if the containing package has already been imported into the Py4J JVM view, so omitting the full classpath can break JVM calls from pyspark in some environments.

This PR: Changed all uses of `jvm.X` in pyspark.ml and pyspark.mllib to use the full classpath for `X` (e.g. `jvm.org.apache.spark.ml.python.MLSerDe`).
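
To illustrate the failure mode, here is a sketch that is not part of the patch; it assumes an active `SparkContext` named `sc`:

```python
# Illustrative sketch only: why short JVM names are fragile in Py4J.
# `sc._jvm` is a Py4J JVMView; assumes an active SparkContext `sc`.
from py4j.java_gateway import java_import

jvm = sc._jvm

# Fully qualified access resolves regardless of what was imported before:
ser_de = jvm.org.apache.spark.ml.python.MLSerDe

# Short-name access resolves only after the class (or its package) has been
# imported into this JVMView, which is not guaranteed in every environment:
java_import(jvm, "org.apache.spark.ml.python.MLSerDe")
ser_de_short = jvm.MLSerDe
```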

## How was this patch tested?

Existing unit tests.  Manual testing in an environment where this was an issue.

Author: Joseph K. Bradley <[email protected]>

Closes apache#14023 from jkbradley/SPARK-16348.
jkbradley committed Jul 6, 2016
1 parent 59f9c1b commit fdde7d0
Showing 8 changed files with 28 additions and 26 deletions.
10 changes: 5 additions & 5 deletions python/pyspark/ml/common.py
```diff
@@ -63,7 +63,7 @@ def _to_java_object_rdd(rdd):
     RDD is serialized in batch or not.
     """
     rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
-    return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
+    return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)


 def _py2java(sc, obj):
@@ -82,7 +82,7 @@ def _py2java(sc, obj):
         pass
     else:
         data = bytearray(PickleSerializer().dumps(obj))
-        obj = sc._jvm.MLSerDe.loads(data)
+        obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
     return obj


@@ -95,17 +95,17 @@ def _java2py(sc, r, encoding="bytes"):
             clsName = 'JavaRDD'

         if clsName == 'JavaRDD':
-            jrdd = sc._jvm.MLSerDe.javaToPython(r)
+            jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
             return RDD(jrdd, sc)

         if clsName == 'Dataset':
             return DataFrame(r, SQLContext.getOrCreate(sc))

         if clsName in _picklable_classes:
-            r = sc._jvm.MLSerDe.dumps(r)
+            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
         elif isinstance(r, (JavaArray, JavaList)):
             try:
-                r = sc._jvm.MLSerDe.dumps(r)
+                r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
             except Py4JJavaError:
                 pass  # not pickable
```
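As a hedged usage sketch (not from the commit), the internal helpers changed above round-trip a Python object through the JVM; this assumes an active `SparkContext` `sc` and the Spark 2.x package layout:

```python
# Hedged sketch: round-trip a Python object through the ML pickle bridge.
# Assumes an active SparkContext `sc`; these helpers are internal to pyspark.
from pyspark.ml.common import _py2java, _java2py
from pyspark.ml.linalg import Vectors

v = Vectors.dense([1.0, 2.0, 3.0])
jv = _py2java(sc, v)          # pickled in Python, loaded by MLSerDe on the JVM
roundtrip = _java2py(sc, jv)  # dumped by MLSerDe, unpickled back in Python
assert roundtrip == v
```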
8 changes: 4 additions & 4 deletions python/pyspark/ml/tests.py
```diff
@@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):

     def _test_serialize(self, v):
         self.assertEqual(v, ser.loads(ser.dumps(v)))
-        jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
-        nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
+        jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
+        nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
         self.assertEqual(v, nv)
         vs = [v] * 100
-        jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
-        nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
+        jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs)))
+        nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs)))
         self.assertEqual(vs, nvs)

     def test_serialize(self):
```
5 changes: 3 additions & 2 deletions python/pyspark/mllib/clustering.py
```diff
@@ -507,7 +507,7 @@ def load(cls, sc, path):
           Path to where the model is stored.
         """
         model = cls._load_java(sc, path)
-        wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
+        wrapper = sc._jvm.org.apache.spark.mllib.api.python.GaussianMixtureModelWrapper(model)
         return cls(wrapper)


@@ -638,7 +638,8 @@ def load(cls, sc, path):
         Load a model from the given path.
         """
         model = cls._load_java(sc, path)
-        wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model)
+        wrapper =\
+            sc._jvm.org.apache.spark.mllib.api.python.PowerIterationClusteringModelWrapper(model)
         return PowerIterationClusteringModel(wrapper)
```
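For context, the wrapper classes touched above sit behind the public model `save`/`load` APIs. A hedged usage sketch, with a hypothetical save path:

```python
# Hedged sketch of the public API backed by GaussianMixtureModelWrapper.
# Assumes an active SparkContext `sc`; the path below is hypothetical.
from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel

data = sc.parallelize([[-0.1], [-0.05], [9.0], [9.1]])
model = GaussianMixture.train(data, k=2)
model.save(sc, "/tmp/gmm-model")
loaded = GaussianMixtureModel.load(sc, "/tmp/gmm-model")  # uses the wrapper internally
```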
10 changes: 5 additions & 5 deletions python/pyspark/mllib/common.py
```diff
@@ -66,7 +66,7 @@ def _to_java_object_rdd(rdd):
     RDD is serialized in batch or not.
     """
     rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
-    return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
+    return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)


 def _py2java(sc, obj):
@@ -85,7 +85,7 @@ def _py2java(sc, obj):
         pass
     else:
         data = bytearray(PickleSerializer().dumps(obj))
-        obj = sc._jvm.SerDe.loads(data)
+        obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data)
     return obj


@@ -98,17 +98,17 @@ def _java2py(sc, r, encoding="bytes"):
             clsName = 'JavaRDD'

         if clsName == 'JavaRDD':
-            jrdd = sc._jvm.SerDe.javaToPython(r)
+            jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
             return RDD(jrdd, sc)

         if clsName == 'Dataset':
             return DataFrame(r, SQLContext.getOrCreate(sc))

         if clsName in _picklable_classes:
-            r = sc._jvm.SerDe.dumps(r)
+            r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
         elif isinstance(r, (JavaArray, JavaList)):
             try:
-                r = sc._jvm.SerDe.dumps(r)
+                r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
             except Py4JJavaError:
                 pass  # not pickable
```
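The mllib helper mirrors the ML one; a hedged sketch of converting a Python RDD into an RDD of Java objects via the internal helper changed above (assumes an active `SparkContext` `sc`):

```python
# Hedged sketch: hand a pickled Python RDD to the JVM as Java objects.
# Assumes an active SparkContext `sc`; _to_java_object_rdd is internal.
from pyspark.mllib.common import _to_java_object_rdd

rdd = sc.parallelize([(1, "a"), (2, "b")])
jrdd = _to_java_object_rdd(rdd)  # batches are unpacked by SerDe.pythonToJava
print(jrdd.first())              # a Py4J handle to the first Java object
```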
2 changes: 1 addition & 1 deletion python/pyspark/mllib/feature.py
```diff
@@ -553,7 +553,7 @@ def load(cls, sc, path):
         """
         jmodel = sc._jvm.org.apache.spark.mllib.feature \
             .Word2VecModel.load(sc._jsc.sc(), path)
-        model = sc._jvm.Word2VecModelWrapper(jmodel)
+        model = sc._jvm.org.apache.spark.mllib.api.python.Word2VecModelWrapper(jmodel)
         return Word2VecModel(model)
```
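A hedged sketch of the public load path fixed above (the model path and query word are hypothetical):

```python
# Hedged sketch: loading a previously saved Word2Vec model.
# Assumes an active SparkContext `sc`; path and word are hypothetical.
from pyspark.mllib.feature import Word2VecModel

model = Word2VecModel.load(sc, "/tmp/word2vec-model")
synonyms = model.findSynonyms("spark", 5)  # top 5 (word, similarity) pairs
```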
2 changes: 1 addition & 1 deletion python/pyspark/mllib/fpm.py
```diff
@@ -64,7 +64,7 @@ def load(cls, sc, path):
         Load a model from the given path.
         """
         model = cls._load_java(sc, path)
-        wrapper = sc._jvm.FPGrowthModelWrapper(model)
+        wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(model)
         return FPGrowthModel(wrapper)
```
2 changes: 1 addition & 1 deletion python/pyspark/mllib/recommendation.py
```diff
@@ -207,7 +207,7 @@ def rank(self):
     def load(cls, sc, path):
         """Load a model from the given path"""
         model = cls._load_java(sc, path)
-        wrapper = sc._jvm.MatrixFactorizationModelWrapper(model)
+        wrapper = sc._jvm.org.apache.spark.mllib.api.python.MatrixFactorizationModelWrapper(model)
         return MatrixFactorizationModel(wrapper)
```
15 changes: 8 additions & 7 deletions python/pyspark/mllib/tests.py
```diff
@@ -150,12 +150,12 @@ class VectorTests(MLlibTestCase):

     def _test_serialize(self, v):
         self.assertEqual(v, ser.loads(ser.dumps(v)))
-        jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
-        nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
+        jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
+        nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
         self.assertEqual(v, nv)
         vs = [v] * 100
-        jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
-        nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
+        jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
+        nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
         self.assertEqual(vs, nvs)

     def test_serialize(self):
@@ -1650,16 +1650,17 @@ class ALSTests(MLlibTestCase):

     def test_als_ratings_serialize(self):
         r = Rating(7, 1123, 3.14)
-        jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
-        nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
+        jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
+        nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
         self.assertEqual(r.user, nr.user)
         self.assertEqual(r.product, nr.product)
         self.assertAlmostEqual(r.rating, nr.rating, 2)

     def test_als_ratings_id_long_error(self):
         r = Rating(1205640308657491975, 50233468418, 1.0)
         # rating user id exceeds max int value, should fail when pickled
-        self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+        self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
+                          bytearray(ser.dumps(r)))


 class HashingTFTest(MLlibTestCase):
```
