Skip to content

Commit

Permalink
Added python binding for bulk recommendation
Browse files Browse the repository at this point in the history
  • Loading branch information
falaki committed Jan 5, 2014
1 parent dfe57fa commit 8d0c2f7
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable {
return new Rating(user, product, rating)
}

/**
 * Deserialize an (Int, Int) pair packed by Python's `_serialize_tuple`:
 * two native-byte-order Int32 values, 8 bytes total.
 *
 * @param tupleBytes exactly 8 bytes in native byte order
 * @return the two decoded ints as a tuple, in write order
 * @throws IllegalArgumentException if the array is not exactly 8 bytes,
 *         instead of an opaque BufferUnderflowException (too short) or
 *         silently ignoring trailing garbage (too long)
 */
private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
  require(tupleBytes.length == 8,
    s"Tuple byte array must have exactly 8 bytes, but got ${tupleBytes.length}")
  val bb = ByteBuffer.wrap(tupleBytes)
  // Native order to match numpy's default int32 layout on the Python side.
  bb.order(ByteOrder.nativeOrder())
  val v1 = bb.getInt()
  val v2 = bb.getInt()
  (v1, v2)
}

/**
 * Serialize a [[Rating]] as three native-byte-order doubles
 * (user, product, rating), 24 bytes total — the float64 layout that
 * Python's `_deserialize_rating` reads back into a numpy array.
 */
private[spark] def serializeRating(rate: Rating): Array[Byte] = {
  val buf = ByteBuffer.allocate(24).order(ByteOrder.nativeOrder())
  // user and product are Ints; widen to Double so all three fields share
  // one uniform 8-byte slot on the Python side.
  buf.putDouble(rate.user.toDouble)
  buf.putDouble(rate.product.toDouble)
  buf.putDouble(rate.rating)
  buf.array()
}

/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.api.python.PythonMLLibAPI

import org.jblas._
import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.api.java.JavaRDD


/**
* Model representing the result of matrix factorization.
Expand Down Expand Up @@ -65,6 +67,12 @@ class MatrixFactorizationModel(
}
}

/**
 * Bulk-prediction bridge for the Python API.
 *
 * Each input element is an 8-byte (user, product) pair packed by Python's
 * `_serialize_tuple`; each output element is a 24-byte serialized Rating.
 * The PythonMLLibAPI instance is captured by the map closures, which is
 * safe to ship to executors because it extends Serializable.
 */
def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
  val api = new PythonMLLibAPI()
  val pairs = usersProductsJRDD.rdd.map(api.unpackTuple)
  predict(pairs).map(api.serializeRating)
}

// TODO: Figure out what other good bulk prediction methods would look like.
// Probably want a way to get the top users for a product or vice-versa.
}
10 changes: 10 additions & 0 deletions python/pyspark/mllib/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba

def _deserialize_rating(ba):
ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
return ar.copy()

def _serialize_tuple(t):
ba = bytearray(8)
intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
intpart[0], intpart[1] = t
return ba

def _test():
import doctest
globs = globals().copy()
Expand Down
10 changes: 9 additions & 1 deletion python/pyspark/mllib/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
_get_initial_weights, _serialize_rating, _regression_train_wrapper
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \
_serialize_tuple, _deserialize_rating
from pyspark.serializers import BatchedSerializer
from pyspark.rdd import RDD

class MatrixFactorizationModel(object):
"""A matrix factorisation model trained by regularized alternating
Expand All @@ -45,6 +48,11 @@ def __del__(self):
def predict(self, user, product):
    """Predict the rating of a single user for a single product by
    delegating to the underlying JVM model object."""
    return self._java_model.predict(user, product)

def predictAll(self, usersProducts):
    """Predict ratings for an RDD of (user, product) pairs in bulk.

    Each pair is serialized to 8 bytes, handed to the JVM model's
    predictJavaRDD, and the resulting byte-array RDD is wrapped so each
    element deserializes back into a [user, product, rating] array.
    """
    serialized = _get_unmangled_rdd(usersProducts, _serialize_tuple)
    result_jrdd = self._java_model.predictJavaRDD(serialized._jrdd)
    deserializer = BatchedSerializer(_deserialize_rating,
                                     self._context._batchSize)
    return RDD(result_jrdd, self._context, deserializer)

class ALS(object):
@classmethod
def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
Expand Down

0 comments on commit 8d0c2f7

Please sign in to comment.