Support specifying a shortened name for the compression codec
databricks#66
This PR adds support for shortened names for compression codecs, and adds a `CompressionCodecs` class in place of the implicit function, whose use is not recommended.
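
With this change, the `codec` option accepts a short, case-insensitive name in place of a fully qualified Hadoop codec class. A minimal usage sketch with the DSL from this repository (file paths are assumptions for illustration):

import org.apache.spark.sql.SaveMode

val cars = sqlContext.xmlFile("cars.xml")
// The short name "gzip" resolves to org.apache.hadoop.io.compress.GzipCodec.
cars.save("com.databricks.spark.xml", SaveMode.Overwrite,
  Map("path" -> "cars-compressed.xml", "codec" -> "gzip"))
// Before this change, the fully qualified class name was required:
//   Map("path" -> ..., "codec" -> "org.apache.hadoop.io.compress.GzipCodec")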

Author: hyukjinkwon <[email protected]>

Closes databricks#67 from HyukjinKwon/ISSUE-66-shorten-names.
HyukjinKwon committed Feb 1, 2016
1 parent 27e7714 commit c0758fa
Showing 6 changed files with 116 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/main/scala/com/databricks/spark/xml/DefaultSource.scala
@@ -20,6 +20,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
+import com.databricks.spark.xml.util.CompressionCodecs
 import com.databricks.spark.xml.util.XmlFile
 
 /**
@@ -89,7 +90,8 @@
     }
     if (doSave) {
       // Only save data when the save mode is not ignore.
-      val codecClass = compressionCodecClass(XmlOptions(parameters).codec)
+      val codecClass =
+        CompressionCodecs.getCodecClass(XmlOptions(parameters).codec)
       data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass)
     }
     createRelation(sqlContext, parameters, data.schema)
4 changes: 2 additions & 2 deletions src/main/scala/com/databricks/spark/xml/XmlRelation.scala
@@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources.{PrunedScan, InsertableRelation, BaseRelation, TableScan}
 import org.apache.spark.sql.types._
-import com.databricks.spark.xml.util.InferSchema
+import com.databricks.spark.xml.util.{CompressionCodecs, InferSchema}
 import com.databricks.spark.xml.parsers.StaxXmlParser
 
 case class XmlRelation protected[spark] (
@@ -106,7 +106,7 @@
           + s" to INSERT OVERWRITE a XML table:\n${e.toString}")
       }
       // Write the data. We assume that schema isn't changed, and we won't update it.
-      val codecClass = compressionCodecClass(options.codec)
+      val codecClass = CompressionCodecs.getCodecClass(options.codec)
       data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass)
     } else {
       sys.error("XML tables only support INSERT OVERWRITE for now.")
10 changes: 0 additions & 10 deletions src/main/scala/com/databricks/spark/xml/package.scala
@@ -28,16 +28,6 @@ import com.databricks.spark.xml.util.XmlFile
 import com.databricks.spark.xml.parsers.StaxXmlGenerator
 
 package object xml {
-  private[xml] def compressionCodecClass(className: String): Class[_ <: CompressionCodec] = {
-    className match {
-      case null => null
-      case codec =>
-        // scalastyle:off classforname
-        Class.forName(codec).asInstanceOf[Class[CompressionCodec]]
-        // scalastyle:on classforname
-    }
-  }
-
   /**
   * Adds a method, `xmlFile`, to [[SQLContext]] that allows reading XML data.
   */
49 changes: 49 additions & 0 deletions src/main/scala/com/databricks/spark/xml/util/CompressionCodecs.scala
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.spark.xml.util
+
+import scala.util.control.Exception._
+
+import org.apache.hadoop.io.compress._
+
+private[xml] object CompressionCodecs {
+  private val shortCompressionCodecNames: Map[String, String] = {
+    val codecMap = collection.mutable.Map.empty[String, String]
+    allCatch toTry(codecMap += "bzip2" -> classOf[BZip2Codec].getName)
+    allCatch toTry(codecMap += "gzip" -> classOf[GzipCodec].getName)
+    allCatch toTry(codecMap += "lz4" -> classOf[Lz4Codec].getName)
+    allCatch toTry(codecMap += "snappy" -> classOf[SnappyCodec].getName)
+    codecMap.toMap
+  }
+
+  /**
+   * Return the codec class of the given name.
+   */
+  def getCodecClass: String => Class[_ <: CompressionCodec] = {
+    case null => null
+    case codec =>
+      val codecName = shortCompressionCodecNames.getOrElse(codec.toLowerCase, codec)
+      try {
+        // scalastyle:off classforname
+        Class.forName(codecName).asInstanceOf[Class[CompressionCodec]]
+        // scalastyle:on classforname
+      } catch {
+        case e: ClassNotFoundException =>
+          throw new IllegalArgumentException(s"Codec [$codecName] is not " +
+            s"available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
+      }
+  }
+}
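
For reference, a rough sketch (a hypothetical REPL session, not part of the commit) of how `getCodecClass` resolves names:

CompressionCodecs.getCodecClass("GzIp")
// -> classOf[GzipCodec]: short names are lowercased before the map lookup,
//    so matching is case-insensitive
CompressionCodecs.getCodecClass("org.apache.hadoop.io.compress.BZip2Codec")
// -> classOf[BZip2Codec]: names not in the map fall through to Class.forName
CompressionCodecs.getCodecClass(null)
// -> null, meaning no compression
CompressionCodecs.getCodecClass("zstd")  // neither a known short name nor a class
// -> IllegalArgumentException listing the known codecs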
19 changes: 19 additions & 0 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
@@ -302,6 +302,25 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
     assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
   }
 
+  test("DSL save with gzip compression codec by shorten name") {
+    // Create temp directory
+    TestUtils.deleteRecursively(new File(tempEmptyDir))
+    new File(tempEmptyDir).mkdirs()
+    val copyFilePath = tempEmptyDir + "cars-copy.xml"
+
+    val cars = sqlContext.xmlFile(carsFile)
+    cars.save("com.databricks.spark.xml", SaveMode.Overwrite,
+      Map("path" -> copyFilePath, "codec" -> "gZiP"))
+    val carsCopyPartFile = new File(copyFilePath, "part-00000.gz")
+    // Check that the part file has a .gz extension
+    assert(carsCopyPartFile.exists())
+
+    val carsCopy = sqlContext.xmlFile(copyFilePath)
+
+    assert(carsCopy.count == cars.count)
+    assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
+  }
+
 test("DSL save") {
     // Create temp directory
     TestUtils.deleteRecursively(new File(tempEmptyDir))
43 changes: 43 additions & 0 deletions src/test/scala/com/databricks/spark/xml/util/CompressionCodecsSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.spark.xml.util
+
+import org.apache.hadoop.io.compress._
+import org.apache.hadoop.util.VersionInfo
+import org.scalatest.FunSuite
+
+class CompressionCodecsSuite extends FunSuite {
+
+  test("Get classes of compression codecs") {
+    assert(CompressionCodecs.getCodecClass(classOf[GzipCodec].getName) == classOf[GzipCodec])
+    assert(CompressionCodecs.getCodecClass(classOf[BZip2Codec].getName) == classOf[BZip2Codec])
+    assert(CompressionCodecs.getCodecClass(classOf[SnappyCodec].getName) == classOf[SnappyCodec])
+    assume(VersionInfo.getVersion.take(1) >= "2",
+      "Lz4 codec was added from Hadoop 2.x")
+    val codecClassName = "org.apache.hadoop.io.compress.Lz4Codec"
+    assert(CompressionCodecs.getCodecClass(codecClassName).getName == codecClassName)
+  }
+
+  test("Get classes of compression codecs with short names") {
+    assert(CompressionCodecs.getCodecClass("GzIp") == classOf[GzipCodec])
+    assert(CompressionCodecs.getCodecClass("bZip2") == classOf[BZip2Codec])
+    assert(CompressionCodecs.getCodecClass("Snappy") == classOf[SnappyCodec])
+    assume(VersionInfo.getVersion.take(1) >= "2",
+      "Lz4 codec was added from Hadoop 2.x")
+    val codecClassName = "org.apache.hadoop.io.compress.Lz4Codec"
+    assert(CompressionCodecs.getCodecClass("lz4").getName == codecClassName)
+  }
+}
