Refactor options to be managed in a single place.
Currently, there are many options, but they are not managed in a single place. Just like `JSONOptions` (SPARK-11745) in the Spark JSON data source, this PR makes this library manage its options in a single `XmlOptions` class and companion object.

Author: hyukjinkwon <[email protected]>

Closes databricks#50 from HyukjinKwon/separate-options.
HyukjinKwon committed Jan 13, 2016
1 parent 825f0ff commit ff0d067
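
For context, every option key shown in the diffs below (`rowTag`, `samplingRatio`, `failFast`, ...) now flows through one parser. A minimal usage sketch, assuming a running `SparkContext` named `sc` and a sample `books.xml` (neither is part of this commit):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val df = sqlContext.read
  .format("com.databricks.spark.xml")
  .option("rowTag", "book")        // the element to treat as a row
  .option("samplingRatio", "0.5")  // fraction of rows sampled for schema inference
  .option("failFast", "true")      // fail on malformed rows instead of skipping them
  .load("books.xml")
```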
Showing 11 changed files with 166 additions and 182 deletions.
38 changes: 6 additions & 32 deletions src/main/scala/com/databricks/spark/xml/DefaultSource.scala
@@ -51,43 +51,17 @@ class DefaultSource
sqlContext: SQLContext,
parameters: Map[String, String],
schema: StructType): XmlRelation = {
def checkedCastToBoolean(value: String, name: String): Boolean = {
if (TypeCast.isBoolean(value)) {
value.toBoolean
}
else {
throw new Exception(s"$name can be only true or false")
}
}

val path = checkPath(parameters)

// TODO Support different encoding types.
val charset = parameters.getOrElse("charset", XmlFile.DEFAULT_CHARSET.name())
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val rowTag = parameters.getOrElse("rowTag", XmlFile.DEFAULT_ROW_TAG)
val attributePrefix = parameters.getOrElse("attributePrefix", XmlFile.DEFAULT_ATTRIBUTE_PREFIX)
val valueTag = parameters.getOrElse("valueTag", XmlFile.DEFAULT_VALUE_TAG)

val failFast = parameters.getOrElse("failFast", "false")
val failFastFlag = checkedCastToBoolean(failFast, "failFast")

val excludeAttribute = parameters.getOrElse("excludeAttribute", "false")
val excludeAttributeFlag = checkedCastToBoolean(excludeAttribute, "excludeAttribute")

val treatEmptyValuesAsNulls = parameters.getOrElse("treatEmptyValuesAsNulls", "false")
val treatEmptyValuesAsNullsFlag =
checkedCastToBoolean(treatEmptyValuesAsNulls, "treatEmptyValuesAsNulls")
// We need the `charset` and `rowTag` before creating the relation.
val (charset, rowTag) = {
val options = XmlOptions.createFromConfigMap(parameters)
(options.charset, options.rowTag)
}

XmlRelation(
() => XmlFile.withCharset(sqlContext.sparkContext, path, charset, rowTag),
Some(path),
samplingRatio,
excludeAttributeFlag,
treatEmptyValuesAsNullsFlag,
failFastFlag,
attributePrefix,
valueTag,
parameters,
schema)(sqlContext)
}

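The same consolidated map is what `createRelation` receives when the source is registered through SQL; a hypothetical sketch (table and file names are illustrative):

```scala
sqlContext.sql(
  """CREATE TEMPORARY TABLE books
    |USING com.databricks.spark.xml
    |OPTIONS (path "books.xml", rowTag "book", excludeAttribute "false")""".stripMargin)
```
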
57 changes: 57 additions & 0 deletions src/main/scala/com/databricks/spark/xml/XmlOptions.scala
@@ -0,0 +1,57 @@
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.xml

/**
* Options for the XML data source.
*/
private[xml] case class XmlOptions(
charset: String = XmlOptions.DEFAULT_CHARSET,
rowTag: String = XmlOptions.DEFAULT_ROW_TAG,
rootTag: String = XmlOptions.DEFAULT_ROOT_TAG,
samplingRatio: Double = 1.0,
excludeAttributeFlag: Boolean = false,
treatEmptyValuesAsNulls: Boolean = false,
failFastFlag: Boolean = false,
attributePrefix: String = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX,
valueTag: String = XmlOptions.DEFAULT_VALUE_TAG,
nullValue: String = XmlOptions.DEFAULT_NULL_VALUE
)

private[xml] object XmlOptions {
val DEFAULT_ATTRIBUTE_PREFIX = "@"
val DEFAULT_VALUE_TAG = "#VALUE"
val DEFAULT_ROW_TAG = "ROW"
val DEFAULT_ROOT_TAG = "ROWS"
val DEFAULT_CHARSET = "UTF-8"
val DEFAULT_NULL_VALUE = "null"

def createFromConfigMap(parameters: Map[String, String]): XmlOptions = XmlOptions(
// TODO Support different encoding types.
// TODO Validate encoding types, maybe with Charset.forName().
charset = parameters.getOrElse("charset", DEFAULT_CHARSET),
rowTag = parameters.getOrElse("rowTag", DEFAULT_ROW_TAG),
rootTag = parameters.getOrElse("rootTag", DEFAULT_ROOT_TAG),
samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0),
excludeAttributeFlag = parameters.get("excludeAttribute").map(_.toBoolean).getOrElse(false),
treatEmptyValuesAsNulls =
parameters.get("treatEmptyValuesAsNulls").map(_.toBoolean).getOrElse(false),
failFastFlag = parameters.get("failFast").map(_.toBoolean).getOrElse(false),
attributePrefix = parameters.getOrElse("attributePrefix", DEFAULT_ATTRIBUTE_PREFIX),
valueTag = parameters.getOrElse("valueTag", DEFAULT_VALUE_TAG),
nullValue = parameters.getOrElse("nullValue", DEFAULT_NULL_VALUE)
)
}
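
A quick sketch of the new factory in isolation (callable from inside the `com.databricks.spark.xml` package, since the class is `private[xml]`); unspecified keys fall back to the companion object's defaults:

```scala
val options = XmlOptions.createFromConfigMap(Map(
  "rowTag" -> "book",
  "samplingRatio" -> "0.5"))

assert(options.rowTag == "book")
assert(options.samplingRatio == 0.5)
assert(options.charset == XmlOptions.DEFAULT_CHARSET)    // "UTF-8"
assert(options.valueTag == XmlOptions.DEFAULT_VALUE_TAG) // "#VALUE"
```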
44 changes: 16 additions & 28 deletions src/main/scala/com/databricks/spark/xml/XmlReader.scala
@@ -24,53 +24,46 @@ import com.databricks.spark.xml.util.XmlFile
* A collection of static functions for working with XML files in Spark SQL
*/
class XmlReader extends Serializable {
private var charset: String = XmlFile.DEFAULT_CHARSET.name()
private var rowTag: String = XmlFile.DEFAULT_ROW_TAG
private var samplingRatio: Double = 1.0
private var excludeAttributeFlag: Boolean = false
private var treatEmptyValuesAsNulls: Boolean = false
private var failFastFlag: Boolean = false
private var attributePrefix: String = XmlFile.DEFAULT_ATTRIBUTE_PREFIX
private var valueTag: String = XmlFile.DEFAULT_VALUE_TAG
private var parameters = collection.mutable.Map.empty[String, String]
private var schema: StructType = null

def withCharset(charset: String): XmlReader = {
this.charset = charset
parameters += ("charset" -> charset)
this
}

def withRowTag(rowTag: String): XmlReader = {
this.rowTag = rowTag
parameters += ("rowTag" -> rowTag)
this
}

def withSamplingRatio(samplingRatio: Double): XmlReader = {
this.samplingRatio = samplingRatio
parameters += ("samplingRatio" -> samplingRatio.toString)
this
}

def withExcludeAttribute(exclude: Boolean): XmlReader = {
this.excludeAttributeFlag = exclude
parameters += ("excludeAttribute" -> exclude.toString)
this
}

def withTreatEmptyValuesAsNulls(treatAsNull: Boolean): XmlReader = {
this.treatEmptyValuesAsNulls = treatAsNull
parameters += ("treatEmptyValuesAsNulls" -> treatAsNull.toString)
this
}

def withFailFast(failFast: Boolean): XmlReader = {
this.failFastFlag = failFast
parameters += ("failFast" -> failFast.toString)
this
}

def withAttributePrefix(attributePrefix: String): XmlReader = {
this.attributePrefix = attributePrefix
parameters += ("attributePrefix" -> attributePrefix)
this
}

def withValueTag(valueTag: String): XmlReader = {
this.valueTag = valueTag
parameters += ("valueTag" -> valueTag)
this
}

@@ -82,15 +75,15 @@ class XmlReader extends Serializable {
/** Returns a Schema RDD for the given XML path. */
@throws[RuntimeException]
def xmlFile(sqlContext: SQLContext, path: String): DataFrame = {
// We need the `charset` and `rowTag` before creating the relation.
val (charset, rowTag) = {
val options = XmlOptions.createFromConfigMap(parameters.toMap)
(options.charset, options.rowTag)
}
val relation: XmlRelation = XmlRelation(
() => XmlFile.withCharset(sqlContext.sparkContext, path, charset, rowTag),
Some(path),
samplingRatio,
excludeAttributeFlag,
treatEmptyValuesAsNulls,
failFastFlag,
attributePrefix,
valueTag,
parameters.toMap,
schema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
@@ -99,12 +92,7 @@
val relation: XmlRelation = XmlRelation(
() => xmlRDD,
None,
samplingRatio,
excludeAttributeFlag,
treatEmptyValuesAsNulls,
failFastFlag,
attributePrefix,
valueTag,
parameters.toMap,
schema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
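`XmlReader`'s public surface is unchanged; each `with*` call now records a string key in the shared `parameters` map, so the builder resolves options through the same `XmlOptions` path as the data source. A usage sketch, assuming a `sqlContext` and a `books.xml`:

```scala
import com.databricks.spark.xml.XmlReader

val df = new XmlReader()
  .withRowTag("book")
  .withSamplingRatio(0.5)
  .withFailFast(true)
  .xmlFile(sqlContext, "books.xml")
```
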
22 changes: 5 additions & 17 deletions src/main/scala/com/databricks/spark/xml/XmlRelation.scala
@@ -25,46 +25,34 @@ import org.apache.spark.sql._
import org.apache.spark.sql.sources.{InsertableRelation, BaseRelation, TableScan}
import org.apache.spark.sql.types._
import com.databricks.spark.xml.util.InferSchema
import com.databricks.spark.xml.parsers.{StaxXmlParser, StaxConfiguration}
import com.databricks.spark.xml.parsers.StaxXmlParser

case class XmlRelation protected[spark] (
baseRDD: () => RDD[String],
location: Option[String],
samplingRatio: Double,
excludeAttributeFlag: Boolean,
treatEmptyValuesAsNulls: Boolean,
failFastFlag: Boolean,
attributePrefix: String,
valueTag: String,
parameters: Map[String, String],
userSchema: StructType = null)(@transient val sqlContext: SQLContext)
extends BaseRelation
with InsertableRelation
with TableScan {

private val logger = LoggerFactory.getLogger(XmlRelation.getClass)

private val parseConf = StaxConfiguration(
samplingRatio,
excludeAttributeFlag,
treatEmptyValuesAsNulls,
failFastFlag,
attributePrefix,
valueTag
)
private val options = XmlOptions.createFromConfigMap(parameters)

override val schema: StructType = {
Option(userSchema).getOrElse {
InferSchema.infer(
baseRDD(),
parseConf)
options)
}
}

override def buildScan: RDD[Row] = {
StaxXmlParser.parse(
baseRDD(),
schema,
parseConf)
options)
}

// The function below was borrowed from JSONRelation
47 changes: 21 additions & 26 deletions src/main/scala/com/databricks/spark/xml/package.scala
@@ -18,15 +18,14 @@ package com.databricks.spark
import java.io.CharArrayWriter
import javax.xml.stream.XMLOutputFactory

import com.databricks.spark.xml.parsers.StaxXmlGenerator

import scala.collection.Map

import com.sun.xml.internal.txw2.output.IndentingXMLStreamWriter
import org.apache.hadoop.io.compress.CompressionCodec

import org.apache.spark.sql.{DataFrame, SQLContext}
import com.databricks.spark.xml.util.XmlFile
import com.databricks.spark.xml.parsers.StaxXmlGenerator

package object xml {
/**
@@ -35,24 +34,28 @@ package object xml {
implicit class XmlContext(sqlContext: SQLContext) extends Serializable {
def xmlFile(
filePath: String,
rowTag: String = XmlFile.DEFAULT_ROW_TAG,
rowTag: String = XmlOptions.DEFAULT_ROW_TAG,
samplingRatio: Double = 1.0,
excludeAttributeFlag: Boolean = false,
excludeAttribute: Boolean = false,
treatEmptyValuesAsNulls: Boolean = false,
failFastFlag: Boolean = false,
attributePrefix: String = XmlFile.DEFAULT_ATTRIBUTE_PREFIX,
valueTag: String = XmlFile.DEFAULT_VALUE_TAG,
charset: String = XmlFile.DEFAULT_CHARSET.name()): DataFrame = {
failFast: Boolean = false,
attributePrefix: String = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX,
valueTag: String = XmlOptions.DEFAULT_VALUE_TAG,
charset: String = XmlOptions.DEFAULT_CHARSET): DataFrame = {

val parameters = Map(
"rowTag" -> rowTag,
"samplingRatio" -> samplingRatio.toString,
"excludeAttribute" -> excludeAttribute.toString,
"treatEmptyValuesAsNulls" -> treatEmptyValuesAsNulls.toString,
"failFast" -> failFast.toString,
"attributePrefix" -> attributePrefix,
"valueTag" -> valueTag,
"charset" -> charset)
val xmlRelation = XmlRelation(
() => XmlFile.withCharset(sqlContext.sparkContext, filePath, charset, rowTag),
location = Some(filePath),
samplingRatio = samplingRatio,
excludeAttributeFlag = excludeAttributeFlag,
treatEmptyValuesAsNulls = treatEmptyValuesAsNulls,
failFastFlag = failFastFlag,
attributePrefix = attributePrefix,
valueTag = valueTag)(sqlContext)
parameters = parameters.toMap)(sqlContext)
sqlContext.baseRelationToDataFrame(xmlRelation)
}
}
@@ -79,14 +82,9 @@ package object xml {
// Namely, roundtrip in writing and reading can end up in different schema structure.
def saveAsXmlFile(path: String, parameters: Map[String, String] = Map(),
compressionCodec: Class[_ <: CompressionCodec] = null): Unit = {
val nullValue = parameters.getOrElse("nullValue", "null")
val rootTag = parameters.getOrElse("rootTag", XmlFile.DEFAULT_ROOT_TAG)
val rowTag = parameters.getOrElse("rowTag", XmlFile.DEFAULT_ROW_TAG)
val attributePrefix =
parameters.getOrElse("attributePrefix", XmlFile.DEFAULT_ATTRIBUTE_PREFIX)
val valueTag = parameters.getOrElse("valueTag", XmlFile.DEFAULT_VALUE_TAG)
val startElement = s"<$rootTag>"
val endElement = s"</$rootTag>"
val options = XmlOptions.createFromConfigMap(parameters.toMap)
val startElement = s"<${options.rootTag}>"
val endElement = s"</${options.rootTag}>"
val rowSchema = dataFrame.schema
val indent = XmlFile.DEFAULT_INDENT
val rowSeparator = XmlFile.DEFAULT_ROW_SEPARATOR
@@ -109,11 +107,8 @@ package object xml {
val xml = {
StaxXmlGenerator(
rowSchema,
rowTag,
indentingXmlWriter,
nullValue,
attributePrefix,
valueTag)(iter.next())
options)(iter.next())
writer.toString
}
writer.reset()
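
The write side now goes through the same consolidation: `rootTag`, `rowTag`, and `nullValue` are resolved by `XmlOptions.createFromConfigMap` rather than ad-hoc `getOrElse` lookups. A sketch, assuming a DataFrame `df` (output path illustrative):

```scala
import com.databricks.spark.xml._

df.saveAsXmlFile("books-out", Map(
  "rootTag"   -> "catalog",
  "rowTag"    -> "book",
  "nullValue" -> ""))  // empty string written for null fields
```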
