forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-8246] [SQL] Implement get_json_object
This is based on apache#7485, thanks to NathanHowell. Tests were copied from Hive, but do not seem to be super comprehensive. I've generally replicated Hive's unusual behavior rather than following a JSONPath reference, except for one case (as noted in the comments). I don't know if there is a way of fully replicating Hive's behavior without a slower TreeNode implementation, so I've erred on the side of performance instead. Author: Davies Liu <[email protected]> Author: Yin Huai <[email protected]> Author: Nathan Howell <[email protected]> Closes apache#7901 from davies/get_json_object and squashes the following commits: 3ace9b9 [Davies Liu] Merge branch 'get_json_object' of github.com:davies/spark into get_json_object 98766fc [Davies Liu] Merge branch 'master' of github.com:apache/spark into get_json_object a7dc6d0 [Davies Liu] Update JsonExpressionsSuite.scala c818519 [Yin Huai] new results. 18ce26b [Davies Liu] fix tests 6ac29fb [Yin Huai] Golden files. 25eebef [Davies Liu] use HiveQuerySuite e0ac6ec [Yin Huai] Golden answer files. 940c060 [Davies Liu] tweat code style 44084c5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into get_json_object 9192d09 [Nathan Howell] Match Hive’s behavior for unwrapping arrays of one element 8dab647 [Nathan Howell] [SPARK-8246] [SQL] Implement get_json_object
- Loading branch information
Showing
17 changed files
with
613 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
309 changes: 309 additions & 0 deletions
309
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,309 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import java.io.{StringWriter, ByteArrayOutputStream} | ||
|
||
import com.fasterxml.jackson.core._ | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback | ||
import org.apache.spark.sql.types.{StringType, DataType} | ||
import org.apache.spark.unsafe.types.UTF8String | ||
|
||
import scala.util.parsing.combinator.RegexParsers | ||
|
||
/**
 * One step of a parsed JSON path. The path parser in this file produces a flat list of these,
 * where a structural marker (`Key` or `Subscript`) precedes the operand it applies to.
 */
private[this] sealed trait PathInstruction
private[this] object PathInstruction {
  // Operands: what to select at the current position.

  /** Select the object field called `name`. */
  private[expressions] case class Named(name: String) extends PathInstruction

  /** Select the array element at position `index`. */
  private[expressions] case class Index(index: Long) extends PathInstruction

  /** Select every child at the current position. */
  private[expressions] case object Wildcard extends PathInstruction

  // Structural markers: what kind of container the next operand is applied to.

  /** Emitted before a `Named` operand; the evaluator consumes it at an object start. */
  private[expressions] case object Key extends PathInstruction

  /** Emitted before an `Index` or `Wildcard` operand taken from a `[...]` subscript. */
  private[expressions] case object Subscript extends PathInstruction
}
|
||
/**
 * How the path evaluator should write a matched value to the output generator.
 */
private[this] sealed trait WriteStyle
private[this] object WriteStyle {
  /** Nested arrays are spliced element by element into the enclosing output array. */
  private[expressions] case object FlattenStyle extends WriteStyle

  /** Values are copied as regular JSON; used for matches found below an array wildcard. */
  private[expressions] case object QuotedStyle extends WriteStyle

  /** Initial style: a matched string is written without its surrounding quotes. */
  private[expressions] case object RawStyle extends WriteStyle
}
|
||
/**
 * Parser for the subset of JSONPath accepted by `get_json_object`.
 *
 * A path is the root marker `$` followed by any number of nodes, where a node is one of:
 *  - `.name` or `['name']`  : a named child
 *  - `.*` or `['*']`        : a child wildcard
 *  - `[123]` or `[*]`       : an array subscript
 *
 * The result of a successful parse is a flat list of [[PathInstruction]]s.
 */
private[this] object JsonPathParser extends RegexParsers {
  import PathInstruction._

  /** The leading `$` that every path must start with. */
  def root: Parser[Char] = '$'

  /**
   * A non-negative array index. A digit run that does not fit in a `Long` is a parse failure
   * (so the whole path is rejected) instead of raising a `NumberFormatException`.
   */
  def long: Parser[Long] =
    "\\d+".r ^^ (digits => scala.util.Try(digits.toLong)) ^? {
      // fully qualified: a bare `Success` would resolve to the parser result type here
      case scala.util.Success(value) => value
    }

  // parse `[*]` and `[123]` subscripts
  def subscript: Parser[List[PathInstruction]] =
    for {
      operand <- '[' ~> ('*' ^^^ Wildcard | long ^^ Index) <~ ']'
    } yield {
      Subscript :: operand :: Nil
    }

  // parse `.name` or `['name']` child expressions
  def named: Parser[List[PathInstruction]] =
    for {
      // the bracket form uses literal `['` / `']` delimiters around a regex for the name;
      // the name may not contain a single quote or a question mark
      name <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']"
    } yield {
      Key :: Named(name) :: Nil
    }

  // child wildcards: `.*` or `['*']`
  def wildcard: Parser[List[PathInstruction]] =
    (".*" | "['*']") ^^^ List(Wildcard)

  // wildcard is tried first so that `.*` and `['*']` are not consumed as named children
  def node: Parser[List[PathInstruction]] =
    wildcard |
      named |
      subscript

  /** A complete path: the root followed by zero or more nodes, consuming all of the input. */
  val expression: Parser[List[PathInstruction]] = {
    phrase(root ~> rep(node) ^^ (x => x.flatten))
  }

  /**
   * Parses a JSON path string.
   *
   * @param str the path text, e.g. `$.store.book[0].title`
   * @return the instruction list, or `None` if `str` is not a valid path
   */
  def parse(str: String): Option[List[PathInstruction]] = {
    this.parseAll(expression, str) match {
      case Success(result, _) =>
        Some(result)

      case NoSuccess(_, _) =>
        None
    }
  }
}
|
||
/**
 * Holds the Jackson factory shared by all evaluations of the `GetJsonObject` expression.
 */
private[this] object GetJsonObject {
  private val jsonFactory = {
    val factory = new JsonFactory()
    // Enabled for Hive compatibility
    factory.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS)
    factory
  }
}
|
||
/**
 * Extracts json object from a json string based on json path specified, and returns json string
 * of the extracted json object. It will return null if the input json string is invalid.
 *
 * Null is also returned when the json input is null, when the path is null or cannot be parsed,
 * and when the path matches nothing in the document.
 *
 * @param json expression producing the JSON document as a string
 * @param path expression producing the JSON path as a string
 */
case class GetJsonObject(json: Expression, path: Expression)
  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {

  import GetJsonObject._
  import PathInstruction._
  import WriteStyle._
  import com.fasterxml.jackson.core.JsonToken._

  override def left: Expression = json
  override def right: Expression = path
  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
  override def dataType: DataType = StringType
  override def prettyName: String = "get_json_object"

  // parsed once and reused by eval whenever the path expression is foldable
  @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String])

  override def eval(input: InternalRow): Any = {
    val jsonStr = json.eval(input).asInstanceOf[UTF8String]
    if (jsonStr == null) {
      // the path is deliberately not evaluated when there is no document
      null
    } else {
      val parsed = if (path.foldable) {
        parsedPath
      } else {
        parsePath(path.eval(input).asInstanceOf[UTF8String])
      }

      parsed match {
        case Some(instructions) => extract(jsonStr, instructions)
        case None => null
      }
    }
  }

  /**
   * Runs `instructions` against `jsonStr` and returns the extracted text, or null when nothing
   * matched or Jackson reported a processing error. The parser and the generator are closed
   * on every exit path, including when evaluation throws.
   */
  private def extract(jsonStr: UTF8String, instructions: List[PathInstruction]): UTF8String = {
    try {
      val parser = jsonFactory.createParser(jsonStr.getBytes)
      try {
        val output = new ByteArrayOutputStream()
        val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8)
        val matched = try {
          parser.nextToken()
          evaluatePath(parser, generator, RawStyle, instructions)
        } finally {
          // closing also flushes the generator into `output`
          generator.close()
        }
        if (matched) {
          UTF8String.fromBytes(output.toByteArray)
        } else {
          null
        }
      } finally {
        parser.close()
      }
    } catch {
      case _: JsonProcessingException => null
    }
  }

  /** Parses a path string into instructions; a null or invalid path yields `None`. */
  private def parsePath(path: UTF8String): Option[List[PathInstruction]] = {
    if (path != null) {
      JsonPathParser.parse(path.toString)
    } else {
      None
    }
  }

  /**
   * Builds a function that advances to the desired array index and evaluates `f` there.
   * The parser must be positioned on the first token after START_ARRAY when the returned
   * function is applied. The function returns false if the array ends before the index is
   * reached, otherwise the result of `f`; in the latter case the rest of the array is consumed.
   */
  private def arrayIndex(p: JsonParser, f: () => Boolean): Long => Boolean = {
    case _ if p.getCurrentToken == END_ARRAY =>
      // terminate, nothing has been written
      false

    case 0 =>
      // we've reached the desired index
      val dirty = f()

      while (p.nextToken() != END_ARRAY) {
        // advance the token stream to the end of the array
        p.skipChildren()
      }

      dirty

    case i if i > 0 =>
      // skip this token and evaluate the next
      p.skipChildren()
      p.nextToken()
      arrayIndex(p, f)(i - 1)
  }

  /**
   * Evaluate a list of JsonPath instructions, returning a bool that indicates if any leaf nodes
   * have been written to the generator
   *
   * @param p parser positioned on the token the remaining path applies to
   * @param g generator receiving matched values
   * @param style how matched values are written (raw, quoted or flattened)
   * @param path instructions still to be applied; `Nil` means the current token is a match
   */
  private def evaluatePath(
      p: JsonParser,
      g: JsonGenerator,
      style: WriteStyle,
      path: List[PathInstruction]): Boolean = {
    (p.getCurrentToken, path) match {
      case (VALUE_STRING, Nil) if style == RawStyle =>
        // there is no array wildcard or slice parent, emit this string without quotes
        if (p.hasTextCharacters) {
          g.writeRaw(p.getTextCharacters, p.getTextOffset, p.getTextLength)
        } else {
          g.writeRaw(p.getText)
        }
        true

      case (START_ARRAY, Nil) if style == FlattenStyle =>
        // flatten this array into the parent
        var dirty = false
        while (p.nextToken() != END_ARRAY) {
          dirty |= evaluatePath(p, g, style, Nil)
        }
        dirty

      case (_, Nil) =>
        // general case: just copy the child tree verbatim
        g.copyCurrentStructure(p)
        true

      case (START_OBJECT, Key :: xs) =>
        // walk the fields of this object; the FIELD_NAME cases below do the actual matching
        var dirty = false
        while (p.nextToken() != END_OBJECT) {
          if (dirty) {
            // once a match has been found we can skip other fields
            p.skipChildren()
          } else {
            dirty = evaluatePath(p, g, style, xs)
          }
        }
        dirty

      case (START_ARRAY, Subscript :: Wildcard :: Subscript :: Wildcard :: xs) =>
        // special handling for the non-structure preserving double wildcard behavior in Hive
        var dirty = false
        g.writeStartArray()
        while (p.nextToken() != END_ARRAY) {
          dirty |= evaluatePath(p, g, FlattenStyle, xs)
        }
        g.writeEndArray()
        dirty

      case (START_ARRAY, Subscript :: Wildcard :: xs) if style != QuotedStyle =>
        // retain Flatten, otherwise use Quoted... cannot use Raw within an array
        val nextStyle = style match {
          case RawStyle => QuotedStyle
          case FlattenStyle => FlattenStyle
          // excluded by the guard on this case
          case QuotedStyle => throw new IllegalStateException()
        }

        // temporarily buffer child matches, the emitted json will need to be
        // modified slightly if there is only a single element written
        val buffer = new StringWriter()
        val flattenGenerator = jsonFactory.createGenerator(buffer)
        flattenGenerator.writeStartArray()

        var dirty = 0
        while (p.nextToken() != END_ARRAY) {
          // track the number of array elements and only emit an outer array if
          // we've written more than one element, this matches Hive's behavior
          dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0)
        }
        flattenGenerator.writeEndArray()
        flattenGenerator.close()

        val buf = buffer.getBuffer
        if (dirty > 1) {
          g.writeRawValue(buf.toString)
        } else if (dirty == 1) {
          // remove outer array tokens
          g.writeRawValue(buf.substring(1, buf.length() - 1))
        } // else do not write anything

        dirty > 0

      case (START_ARRAY, Subscript :: Wildcard :: xs) =>
        // only reached with QuotedStyle: write the matches straight into a nested array
        var dirty = false
        g.writeStartArray()
        while (p.nextToken() != END_ARRAY) {
          // wildcards can have multiple matches, continually update the dirty count
          dirty |= evaluatePath(p, g, QuotedStyle, xs)
        }
        g.writeEndArray()

        dirty

      case (START_ARRAY, Subscript :: Index(idx) :: (xs @ Subscript :: Wildcard :: _)) =>
        p.nextToken()
        // we're going to have 1 or more results, switch to QuotedStyle
        arrayIndex(p, () => evaluatePath(p, g, QuotedStyle, xs))(idx)

      case (START_ARRAY, Subscript :: Index(idx) :: xs) =>
        p.nextToken()
        arrayIndex(p, () => evaluatePath(p, g, style, xs))(idx)

      case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name =>
        // exact field match
        p.nextToken()
        evaluatePath(p, g, style, xs)

      case (FIELD_NAME, Wildcard :: xs) =>
        // wildcard field match
        p.nextToken()
        evaluatePath(p, g, style, xs)

      case _ =>
        // the current token does not fit the next instruction: skip it and report no match
        p.skipChildren()
        false
    }
  }
}
Oops, something went wrong.