Skip to content

Commit

Permalink
Added Rating deserializer
Browse files Browse the repository at this point in the history
  • Loading branch information
falaki committed Jan 6, 2014
1 parent 11a93fb commit 04132ea
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,14 @@ class MatrixFactorizationModel(
}
}

def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
/**
* Predict the rating of many users for many products.
* This is a Java stub for python predictAll()
*
* @param usersProductsJRDD A JavaRDD with serialized tuples (user, product)
* @return JavaRDD of serialized Rating objects.
*/
def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val pythonAPI = new PythonMLLibAPI()
val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
Expand Down
21 changes: 18 additions & 3 deletions python/pyspark/mllib/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext

from pyspark.serializers import Serializer
import struct

# Double vector format:
#
# [8-byte 1] [8-byte length] [length*8 bytes of data]
Expand Down Expand Up @@ -213,9 +216,21 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba

def _deserialize_rating(ba):
ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
return ar.copy()
class RatingDeserializer(Serializer):
def loads(self, stream):
length = struct.unpack("!i", stream.read(4))[0]
ba = stream.read(length)
res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
return int(res[0]), int(res[1]), res[2]

def load_stream(self, stream):
while True:
try:
yield self.loads(stream)
except struct.error:
return
except EOFError:
return

def _serialize_tuple(t):
ba = bytearray(8)
Expand Down

0 comments on commit 04132ea

Please sign in to comment.