[SPARK-19595][SQL] Support json array in from_json
## What changes were proposed in this pull request?

This PR proposes two changes:

**Return `null` in `from_json` with `StructType` as the schema when the input is a JSON array with multiple elements.**

Currently, `from_json` reads only a single row when the input is a JSON array. So the code below:

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val schema = StructType(StructField("a", IntegerType) :: Nil)
Seq(("""[{"a": 1}, {"a": 2}]""")).toDF("struct").select(from_json(col("struct"), schema)).show()
```
prints

```
+--------------------+
|jsontostruct(struct)|
+--------------------+
|                 [1]|
+--------------------+
```

This PR changes this to return `null` if the schema is `StructType` and the input is a JSON array with multiple elements:

```
+--------------------+
|jsontostruct(struct)|
+--------------------+
|                null|
+--------------------+
```
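
For completeness, the remaining `StructType` edge cases behave as follows. This is a sketch mirroring the new unit tests in `JsonExpressionsSuite`; the `show()` results noted in the comments are illustrative:

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val schema = StructType(StructField("a", IntegerType) :: Nil)

// A json array with a single element still parses to a single row: shows [1].
Seq("""[{"a": 1}]""").toDF("struct").select(from_json(col("struct"), schema)).show()
// An empty json array yields null.
Seq("[]").toDF("struct").select(from_json(col("struct"), schema)).show()
// An empty json object yields a single row with a null field: shows [null].
Seq("{ }").toDF("struct").select(from_json(col("struct"), schema)).show()
```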

**Support json arrays in `from_json` with `ArrayType` as the schema.**

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
Seq(("""[{"a": 1}, {"a": 2}]""")).toDF("array").select(from_json(col("array"), schema)).show()
```

prints

```
+-------------------+
|jsontostruct(array)|
+-------------------+
|         [[1], [2]]|
+-------------------+
```
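
The array path likewise covers single objects and empty inputs. Again a sketch mirroring the new tests, with illustrative `show()` results in the comments:

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))

// A single json object parses to a one-element array: shows [[1]].
Seq("""{"a": 1}""").toDF("array").select(from_json(col("array"), schema)).show()
// An empty json array parses to an empty array: shows [].
Seq("[ ]").toDF("array").select(from_json(col("array"), schema)).show()
// An empty json object parses to a one-element array holding a row with
// a null field: shows [[null]].
Seq("{ }").toDF("array").select(from_json(col("array"), schema)).show()
```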

## How was this patch tested?

Unit tests in `JsonExpressionsSuite` and `JsonFunctionsSuite`, Python doctests, and manual tests.
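
For example, `JsonFunctionsSuite` verifies that schemas other than a struct or an array of structs are rejected at analysis time. A condensed sketch of that test (assumes a ScalaTest context for `intercept`):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val df = Seq("""{"a" 1}""").toDF("value")
// ArrayType(StringType) is neither a struct nor an array of structs.
val e = intercept[AnalysisException] {
  df.select(from_json(col("value"), ArrayType(StringType)))
}
assert(e.getMessage.contains(
  "Input schema array<string> must be a struct or an array of structs."))
```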

Author: hyukjinkwon <[email protected]>

Closes apache#16929 from HyukjinKwon/disallow-array.
HyukjinKwon authored and brkyvz committed Mar 5, 2017
1 parent 80d5338 commit 369a148
Showing 5 changed files with 186 additions and 17 deletions.
11 changes: 8 additions & 3 deletions python/pyspark/sql/functions.py
@@ -1773,11 +1773,11 @@ def json_tuple(col, *fields):
@since(2.1)
def from_json(col, schema, options={}):
"""
Parses a column containing a JSON string into a [[StructType]] with the
specified schema. Returns `null`, in the case of an unparseable string.
Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]]
with the specified schema. Returns `null`, in the case of an unparseable string.
:param col: string column in json format
:param schema: a StructType to use when parsing the json column
:param schema: a StructType or ArrayType to use when parsing the json column
:param options: options to control parsing. accepts the same options as the json datasource
>>> from pyspark.sql.types import *
@@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}):
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=1))]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[Row(a=1)])]
"""

sc = SparkContext._active_spark_context
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.ParseModes
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -480,23 +480,45 @@ case class JsonTuple(children: Seq[Expression])
}

/**
* Converts an json input string to a [[StructType]] with the specified schema.
* Converts a JSON input string to a [[StructType]] or [[ArrayType]] with the specified schema.
*/
case class JsonToStruct(
schema: StructType,
schema: DataType,
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

def this(schema: StructType, options: Map[String, String], child: Expression) =
def this(schema: DataType, options: Map[String, String], child: Expression) =
this(schema, options, child, None)

override def checkInputDataTypes(): TypeCheckResult = schema match {
case _: StructType | ArrayType(_: StructType, _) =>
super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
}

@transient
lazy val rowSchema = schema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
}

// This converts parsed rows to the desired output by the given schema.
@transient
lazy val converter = schema match {
case _: StructType =>
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
case ArrayType(_: StructType, _) =>
(rows: Seq[InternalRow]) => new GenericArrayData(rows)
}

@transient
lazy val parser =
new JacksonParser(
schema,
rowSchema,
new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))

override def dataType: DataType = schema
@@ -505,11 +527,32 @@ case class JsonToStruct(
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(json: Any): Any = {
// When input is,
// - `null`: `null`.
// - invalid json: `null`.
// - empty string: `null`.
//
// When the schema is array,
// - json array: `Array(Row(...), ...)`
// - json object: `Array(Row(...))`
// - empty json array: `Array()`.
// - empty json object: `Array(Row(null))`.
//
// When the schema is a struct,
// - json object/array with single element: `Row(...)`
// - json array with multiple elements: `null`
// - empty json array: `null`.
// - empty json object: `Row(null)`.

// We need `null` if the input string is an empty string. `JacksonParser` can
// deal with this but produces `Nil`.
if (json.toString.trim.isEmpty) return null

try {
parser.parse(
converter(parser.parse(
json.asInstanceOf[UTF8String],
CreateJacksonParser.utf8String,
identity[UTF8String]).headOption.orNull
identity[UTF8String]))
} catch {
case _: SparkSQLJsonProcessingException => null
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -22,7 +22,7 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}

test("from_json - input=array, schema=array, output=array") {
val input = """[{"a": 1}, {"a": 2}]"""
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(1) :: InternalRow(2) :: Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=object, schema=array, output=array of single row") {
val input = """{"a": 1}"""
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(1) :: Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty array, schema=array, output=empty array") {
val input = "[ ]"
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty object, schema=array, output=array of single row with null") {
val input = "{ }"
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(null) :: Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=array of single object, schema=struct, output=single row") {
val input = """[{"a": 1}]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = InternalRow(1)
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=array, schema=struct, output=null") {
val input = """[{"a": 1}, {"a": 2}]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = null
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty array, schema=struct, output=null") {
val input = """[]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = null
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty object, schema=struct, output=single row with null") {
val input = """{ }"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = InternalRow(null)
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json null input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
52 changes: 47 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2973,7 +2973,22 @@ object functions {
* @group collection_funcs
* @since 2.1.0
*/
def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
def from_json(e: Column, schema: StructType, options: Map[String, String]): Column =
from_json(e, schema.asInstanceOf[DataType], options)

/**
* (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
* @param options options to control how the json is parsed. accepts the same options as the
* json data source.
*
* @group collection_funcs
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
JsonToStruct(schema, options, e.expr)
}

@@ -2992,6 +3007,21 @@
def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column =
from_json(e, schema, options.asScala.toMap)

/**
* (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
* @param options options to control how the json is parsed. accepts the same options as the
* json data source.
*
* @group collection_funcs
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
from_json(e, schema, options.asScala.toMap)

/**
* Parses a column containing a JSON string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
@@ -3006,8 +3036,21 @@
from_json(e, schema, Map.empty[String, String])

/**
* Parses a column containing a JSON string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
* Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
*
* @group collection_funcs
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType): Column =
from_json(e, schema, Map.empty[String, String])

/**
* Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string as a json string
@@ -3016,8 +3059,7 @@
* @since 2.1.0
*/
def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options)

from_json(e, DataType.fromJson(schema), options)
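// For reference, a usage sketch (not part of this commit): with the
// `asInstanceOf[StructType]` cast gone, this overload accepts any `DataType`
// encoded as JSON, so an array schema round-trips via `DataType.json` and
// `DataType.fromJson`:
//
//   val schemaString = ArrayType(StructType(StructField("a", IntegerType) :: Nil)).json
//   df.select(from_json(col("value"), schemaString, new java.util.HashMap[String, String]()))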

/**
* (Scala-specific) Converts a column containing a `StructType` into a JSON string with the
sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.functions.{from_json, struct, to_json}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType}
import org.apache.spark.sql.types._

class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null) :: Nil)
}

test("from_json invalid schema") {
val df = Seq("""{"a" 1}""").toDS()
val schema = ArrayType(StringType)
val message = intercept[AnalysisException] {
df.select(from_json($"value", schema))
}.getMessage

assert(message.contains(
"Input schema array<string> must be a struct or an array of structs."))
}

test("from_json array support") {
val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS()
val schema = ArrayType(
StructType(
StructField("a", IntegerType) ::
StructField("b", StringType) :: Nil))

checkAnswer(
df.select(from_json($"value", schema)),
Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
}

test("to_json") {
val df = Seq(Tuple1(Tuple1(1))).toDF("a")

