[SPARK-17279][SQL] better error message for exceptions during ScalaUDF execution

## What changes were proposed in this pull request?

If `ScalaUDF` throws an exception while executing user code, it's sometimes hard for users to figure out what went wrong, especially when they use the Spark shell. An example:
```
org.apache.spark.SparkException: Job aborted due to stage failure: Task 12 in stage 325.0 failed 4 times, most recent failure: Lost task 12.3 in stage 325.0 (TID 35622, 10.0.207.202): java.lang.NullPointerException
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
...
```
We should catch these exceptions and rethrow them with a better error message, so that it's clear the exception happened inside a Scala UDF.

This PR also does some cleanup in `ScalaUDF` and adds a unit test suite for it.
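
The fix, in a minimal sketch (illustrative only, not this patch's exact code; `wrapUdfException` and the sample signature string are hypothetical): run the user function, catch anything it throws, and rethrow it wrapped in a `SparkException` whose message names the UDF and its signature.

```scala
import org.apache.spark.SparkException

// Hypothetical helper showing the catch-and-rethrow pattern this patch
// applies in both ScalaUDF.eval and the generated code path.
def wrapUdfException[T](udfDescription: String)(body: => T): T =
  try body catch {
    case e: Exception =>
      throw new SparkException(
        s"Failed to execute user defined function($udfDescription)", e)
  }

// The NPE below now surfaces with a descriptive top-level message and the
// original NullPointerException preserved as the cause.
wrapUdfException("anonfun$1: (string) => string") {
  (null: String).toLowerCase
}
```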

## How was this patch tested?

The new test suite.

Author: Wenchen Fan <[email protected]>

Closes apache#14850 from cloud-fan/npe.
cloud-fan committed Sep 6, 2016
1 parent 6d86403 commit 8d08f43
Showing 3 changed files with 86 additions and 22 deletions.
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

@@ -510,18 +510,18 @@ class ALSSuite
       (1, 1L, 1d, 0, 0L, 0d, 5.0)
     ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
     withClue("fit should fail when ids exceed integer range. ") {
-      assert(intercept[IllegalArgumentException] {
+      assert(intercept[SparkException] {
         als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
-      assert(intercept[IllegalArgumentException] {
+      }.getCause.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
         als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
-      assert(intercept[IllegalArgumentException] {
+      }.getCause.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
         als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
-      assert(intercept[IllegalArgumentException] {
+      }.getCause.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
         als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
+      }.getCause.getMessage.contains("was out of Integer range"))
     }
     withClue("transform should fail when ids exceed integer range. ") {
       val model = als.fit(df)
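
These assertions change because ALS's id-range check throws inside a UDF: the original `IllegalArgumentException` now arrives wrapped in a `SparkException`, with the informative message on the cause. A standalone sketch of the unwrap pattern (no Spark job involved; the message strings are illustrative):

```scala
import org.apache.spark.SparkException

// The original exception is preserved as the cause of the wrapper, so
// assertions look one level down the chain.
val wrapped = new SparkException(
  "Failed to execute user defined function",
  new IllegalArgumentException("value 3000000000 was out of Integer range"))

assert(wrapped.getCause.getMessage.contains("was out of Integer range"))
```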

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types.DataType
@@ -994,20 +995,15 @@ case class ScalaUDF(
       ctx: CodegenContext,
       ev: ExprCode): ExprCode = {

-    ctx.references += this
-
-    val scalaUDFClassName = classOf[ScalaUDF].getName
+    val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
     val converterClassName = classOf[Any => Any].getName
     val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
     val expressionClassName = classOf[Expression].getName

     // Generate codes used to convert the returned value of user-defined functions to Catalyst type
     val catalystConverterTerm = ctx.freshName("catalystConverter")
-    val catalystConverterTermIdx = ctx.references.size - 1
     ctx.addMutableState(converterClassName, catalystConverterTerm,
       s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
-        s".createToCatalystConverter((($scalaUDFClassName)references" +
-        s"[$catalystConverterTermIdx]).dataType());")
+        s".createToCatalystConverter($scalaUDF.dataType());")

     val resultTerm = ctx.freshName("result")
@@ -1019,10 +1015,8 @@
     val funcClassName = s"scala.Function${children.size}"

     val funcTerm = ctx.freshName("udf")
-    val funcExpressionIdx = ctx.references.size - 1
     ctx.addMutableState(funcClassName, funcTerm,
-      s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" +
-        s"[$funcExpressionIdx]).userDefinedFunc());")
+      s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")

     // codegen for children expressions
     val evals = children.map(_.genCode(ctx))
@@ -1039,9 +1033,16 @@
       (convert, argTerm)
     }.unzip

-    val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
-      s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
-      s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
+    val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
+    val callFunc =
+      s"""
+         ${ctx.boxedType(dataType)} $resultTerm = null;
+         try {
+           $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
+         } catch (Exception e) {
+           throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
+         }
+       """

     ev.copy(code = s"""
       $evalCode
@@ -1057,5 +1058,20 @@

   private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)

-  override def eval(input: InternalRow): Any = converter(f(input))
+  lazy val udfErrorMessage = {
+    val funcCls = function.getClass.getSimpleName
+    val inputTypes = children.map(_.dataType.simpleString).mkString(", ")
+    s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})"
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val result = try {
+      f(input)
+    } catch {
+      case e: Exception =>
+        throw new SparkException(udfErrorMessage, e)
+    }
+
+    converter(result)
+  }
 }
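
The user-visible effect, sketched below (a hypothetical spark-shell session; assumes a `SparkSession` named `spark`, and the anonymous function name in the message varies by compiler):

```scala
import org.apache.spark.sql.functions.udf

val toLower = udf((s: String) => s.toLowerCase)  // throws NPE on null input
val df = spark.sql("SELECT CAST(NULL AS STRING) AS s")
df.select(toLower(df("s"))).collect()
// Previously: a bare java.lang.NullPointerException pointing into generated
// code. Now: org.apache.spark.SparkException: Failed to execute user defined
// function(anonfun$1: (string) => string)
// Caused by: java.lang.NullPointerException
```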

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

}
