Skip to content

Commit

Permalink
Add uri to get_predicts in TextSet (intel#1821)
Browse files Browse the repository at this point in the history
* add uri in get predicts

* style

* update scala example

* update doc
hkvision authored Dec 6, 2019
1 parent 14eaefd commit 15a0c10
Showing 6 changed files with 37 additions and 25 deletions.
24 changes: 14 additions & 10 deletions pyzoo/test/zoo/feature/text/test_text_set.py
Original file line number Diff line number Diff line change
@@ -125,11 +125,13 @@ def test_local_textset_integration(self):
loaded_predicts = loaded_res_set.get_predicts()
assert len(predicts) == len(loaded_predicts)

for i in range(0, len(predicts)):
assert len(predicts[i]) == 1
assert len(loaded_predicts[i]) == 1
assert predicts[i][0].shape == (5, )
assert np.allclose(predicts[i][0], loaded_predicts[i][0])
for i in range(0, len(predicts)): # (uri, prediction)
assert not predicts[i][0]
assert not loaded_predicts[i][0] # uri is not recorded and thus None
assert len(predicts[i][1]) == 1
assert len(loaded_predicts[i][1]) == 1
assert predicts[i][1][0].shape == (5, )
assert np.allclose(predicts[i][1][0], loaded_predicts[i][1][0])
shutil.rmtree(tmp_log_dir)
shutil.rmtree(tmp_checkpoint_path)
os.remove(tmp_path)
@@ -173,9 +175,10 @@ def test_distributed_textset_integration(self):
model.fit(transformed, batch_size=2, nb_epoch=2)
res_set = model.predict(transformed, batch_per_thread=2)
predicts = res_set.get_predicts().collect()
for predict in predicts:
assert len(predict) == 1
assert predict[0].shape == (5, )
for predict in predicts: # (uri, prediction)
assert not predict[0] # uri is not recorded and thus None
assert len(predict[1]) == 1
assert predict[1][0].shape == (5, )

tmp_path = create_tmp_path() + ".bigdl"
model.save_model(tmp_path, over_write=True)
@@ -192,7 +195,7 @@ def test_read_local(self):
assert len(local_set.get_texts()) == 3
assert local_set.get_labels() == [0, 0, 1]
assert local_set.get_samples() == [None, None, None]
assert local_set.get_predicts() == [None, None, None]
assert local_set.get_predicts() == [(uri, None) for uri in local_set.get_uris()]

def test_read_distributed(self):
distributed_set = TextSet.read(self.path, self.sc, 4)
@@ -201,7 +204,8 @@ def test_read_distributed(self):
assert len(distributed_set.get_texts().collect()) == 3
assert sorted(distributed_set.get_labels().collect()) == [0, 0, 1]
assert distributed_set.get_samples().collect() == [None, None, None]
assert distributed_set.get_predicts().collect() == [None, None, None]
assert distributed_set.get_predicts().collect() ==\
[(uri, None) for uri in distributed_set.get_uris().collect()]

def test_read_csv_parquet(self):
text_set = TextSet.read_csv(self.qa_path + "/question_corpus.csv", self.sc)
Original file line number Diff line number Diff line change
@@ -82,7 +82,9 @@ def predict(record):
predicts = predict_set.get_predicts().collect()
print("Probability distributions of top-5:")
for p in predicts:
for k, v in sorted(enumerate(p[0]), key=lambda x: x[1])[:5]:
(uri, probs) = p
print("Predictions for " + uri + ": ")
for k, v in sorted(enumerate(probs[0]), key=lambda x: x[1])[:5]:
print(labels[k] + " " + str(v))

lines.foreachRDD(predict)
4 changes: 3 additions & 1 deletion pyzoo/zoo/examples/textclassification/text_classification.py
Original file line number Diff line number Diff line change
@@ -81,7 +81,9 @@
predicts = predict_set.get_predicts().take(5)
print("Probability distributions of the first five texts in the validation set:")
for predict in predicts:
print(predict)
(uri, probs) = predict
print("Prediction for " + uri + ": ")
print(probs)

if options.output_path:
model.save_model(options.output_path + "/text_classifier.model")
13 changes: 7 additions & 6 deletions pyzoo/zoo/feature/text/text_set.py
Original file line number Diff line number Diff line change
@@ -167,17 +167,18 @@ def get_labels(self):

def get_predicts(self):
"""
Get the prediction results of a TextSet (if any).
If a text hasn't been predicted by a model, its corresponding position will be None.
Get the prediction results (if any) combined with uris (if any) of a TextSet.
If a text doesn't have a uri, its corresponding uri will be None.
If a text hasn't been predicted by a model, its corresponding prediction will be None.
:return: List of list of numpy array for LocalTextSet.
RDD of list of numpy array for DistributedTextSet.
:return: List of (uri, prediction as a list of numpy array) for LocalTextSet.
RDD of (uri, prediction as a list of numpy array) for DistributedTextSet.
"""
predicts = callZooFunc(self.bigdl_type, "textSetGetPredicts", self.value)
if isinstance(predicts, RDD):
return predicts.map(lambda predict: _process_predict_result(predict))
return predicts.map(lambda predict: (predict[0], _process_predict_result(predict[1])))
else:
return [_process_predict_result(predict) for predict in predicts]
return [(predict[0], _process_predict_result(predict[1])) for predict in predicts]

def get_samples(self):
"""
Original file line number Diff line number Diff line change
@@ -127,7 +127,10 @@ object TextClassification {

val predictSet = model.predict(valTextSet, batchPerThread = param.partitionNum)
println("Probability distributions of the first five texts in the validation set:")
predictSet.toDistributed().rdd.take(5).map(_.getPredict.toTensor).foreach(println)
predictSet.toDistributed().rdd.take(5).foreach(feature => {
println("Prediction for " + feature.getURI + ": ")
println(feature.getPredict.toTensor)
})
if (param.outputPath.isDefined) {
val outputPath = param.outputPath.get
model.saveModel(outputPath + "/text_classifier.model")
Original file line number Diff line number Diff line change
@@ -209,24 +209,24 @@ class PythonTextFeature[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pyth
textSet.rdd.map(_.getLabel).toJavaRDD()
}

def textSetGetPredicts(textSet: LocalTextSet): JList[JList[JTensor]] = {
def textSetGetPredicts(textSet: LocalTextSet): JList[JList[Any]] = {
textSet.array.map{feature =>
if (feature.contains(TextFeature.predict)) {
activityToJTensors(feature[Activity](TextFeature.predict))
List[Any](feature.getURI, activityToJTensors(feature[Activity](TextFeature.predict))).asJava
}
else {
null
List[Any](feature.getURI, null).asJava
}
}.toList.asJava
}

def textSetGetPredicts(textSet: DistributedTextSet): JavaRDD[JList[JTensor]] = {
def textSetGetPredicts(textSet: DistributedTextSet): JavaRDD[JList[Any]] = {
textSet.rdd.map{feature =>
if (feature.contains(TextFeature.predict)) {
activityToJTensors(feature[Activity](TextFeature.predict))
List[Any](feature.getURI, activityToJTensors(feature[Activity](TextFeature.predict))).asJava
}
else {
null
List[Any](feature.getURI, null).asJava
}
}.toJavaRDD()
}

0 comments on commit 15a0c10

Please sign in to comment.