[SPARK-12446][SQL] Add unit tests for JDBCRDD internal functions
There were no tests for JDBCRDD#compileFilter.

Author: Takeshi YAMAMURO <[email protected]>

Closes #10409 from maropu/AddTestsInJdbcRdd.
maropu authored and rxin committed Dec 22, 2015
1 parent 969d566 commit 8c1b867
Showing 2 changed files with 54 additions and 33 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala (63 changes: 32 additions & 31 deletions)

@@ -163,8 +163,37 @@ private[sql] object JDBCRDD extends Logging {
    * @return A Catalyst schema corresponding to columns in the given order.
    */
   private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
-    val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*)
-    new StructType(columns map { name => fieldMap(name) })
+    val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
+    new StructType(columns.map(name => fieldMap(name)))
   }
 
+  /**
+   * Converts value to SQL expression.
+   */
+  private def compileValue(value: Any): Any = value match {
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "'" + timestampValue + "'"
+    case dateValue: Date => "'" + dateValue + "'"
+    case _ => value
+  }
+
+  private def escapeSql(value: String): String =
+    if (value == null) null else StringUtils.replace(value, "'", "''")
+
+  /**
+   * Turns a single Filter into a String representing a SQL expression.
+   * Returns null for an unhandled filter.
+   */
+  private def compileFilter(f: Filter): String = f match {
+    case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+    case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}"
+    case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
+    case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
+    case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
+    case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
+    case IsNull(attr) => s"$attr IS NULL"
+    case IsNotNull(attr) => s"$attr IS NOT NULL"
+    case _ => null
+  }
+
   /**
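
This hunk moves compileValue, escapeSql, and compileFilter from the JDBCRDD class into its companion object, which is what lets the new suite reach them without constructing an RDD. For reference, a minimal Spark-free sketch of the same compilation logic; the local Filter ADT and the SqlFilterSketch name are stand-ins for org.apache.spark.sql.sources and the real object, and escapeSql uses String#replace in place of Commons Lang's StringUtils.replace:

import java.sql.{Date, Timestamp}

// Stand-in for the Catalyst source filters the real code matches on.
sealed trait Filter
case class EqualTo(attr: String, value: Any) extends Filter
case class Not(child: Filter) extends Filter
case class LessThan(attr: String, value: Any) extends Filter
case class IsNull(attr: String) extends Filter

object SqlFilterSketch {
  // Single quotes inside string literals are doubled, as in the patch.
  private def escapeSql(value: String): String =
    if (value == null) null else value.replace("'", "''")

  // Strings, timestamps, and dates are quoted; everything else is inlined.
  private def compileValue(value: Any): Any = value match {
    case s: String => s"'${escapeSql(s)}'"
    case t: Timestamp => "'" + t + "'"
    case d: Date => "'" + d + "'"
    case _ => value
  }

  // Unhandled filters compile to null and are filtered out by the caller.
  def compileFilter(f: Filter): String = f match {
    case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
    case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}"
    case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
    case IsNull(attr) => s"$attr IS NULL"
    case _ => null
  }

  def main(args: Array[String]): Unit = {
    println(compileFilter(EqualTo("col0", 3)))            // col0 = 3
    println(compileFilter(Not(EqualTo("col1", "a'bc"))))  // col1 != 'a''bc'
    println(compileFilter(LessThan("col4", Date.valueOf("1983-08-04")))) // col4 < '1983-08-04'
  }
}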
@@ -262,40 +291,12 @@ private[sql] class JDBCRDD(
     if (sb.length == 0) "1" else sb.substring(1)
   }
 
-  /**
-   * Converts value to SQL expression.
-   */
-  private def compileValue(value: Any): Any = value match {
-    case stringValue: String => s"'${escapeSql(stringValue)}'"
-    case timestampValue: Timestamp => "'" + timestampValue + "'"
-    case dateValue: Date => "'" + dateValue + "'"
-    case _ => value
-  }
-
-  private def escapeSql(value: String): String =
-    if (value == null) null else StringUtils.replace(value, "'", "''")
-
-  /**
-   * Turns a single Filter into a String representing a SQL expression.
-   * Returns null for an unhandled filter.
-   */
-  private def compileFilter(f: Filter): String = f match {
-    case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
-    case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}"
-    case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
-    case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
-    case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
-    case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
-    case IsNull(attr) => s"$attr IS NULL"
-    case IsNotNull(attr) => s"$attr IS NOT NULL"
-    case _ => null
-  }
-
   /**
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
    */
   private val filterWhereClause: String = {
-    val filterStrings = filters map compileFilter filter (_ != null)
+    val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null)
     if (filterStrings.size > 0) {
       val sb = new StringBuilder("WHERE ")
       filterStrings.foreach(x => sb.append(x).append(" AND "))
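
With the helpers gone from the class, filterWhereClause now delegates to JDBCRDD.compileFilter on the companion object and drops fragments that compiled to null before assembling the clause. A condensed sketch of that fold, using mkString where the patch builds a StringBuilder and appends " AND " after each fragment; the sample fragments are illustrative:

object WhereClauseSketch {
  // Joins the surviving fragments under a single WHERE; an empty result
  // means no WHERE clause is emitted.
  def whereClause(compiled: Seq[String]): String = {
    val kept = compiled.filter(_ != null)
    if (kept.nonEmpty) kept.mkString("WHERE ", " AND ", "") else ""
  }

  def main(args: Array[String]): Unit = {
    println(whereClause(Seq("col0 = 3", null, "col1 IS NOT NULL")))
    // WHERE col0 = 3 AND col1 IS NOT NULL
  }
}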
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala (24 changes: 22 additions & 2 deletions)
@@ -18,18 +18,22 @@
 package org.apache.spark.sql.jdbc
 
 import java.math.BigDecimal
-import java.sql.DriverManager
+import java.sql.{Date, DriverManager, Timestamp}
 import java.util.{Calendar, GregorianCalendar, Properties}
 
 import org.h2.jdbc.JdbcSQLException
 import org.scalatest.BeforeAndAfter
+import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.sources._
 import org.apache.spark.util.Utils
 
-class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
+class JDBCSuite extends SparkFunSuite
+  with BeforeAndAfter with PrivateMethodTester with SharedSQLContext {
   import testImplicits._
 
   val url = "jdbc:h2:mem:testdb0"
@@ -429,6 +433,22 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     assert(DerbyColumns === Seq(""""abc"""", """"key""""))
   }
 
+  test("compile filters") {
+    val compileFilter = PrivateMethod[String]('compileFilter)
+    def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f)
+    assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3")
+    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "col1 != 'abc'")
+    assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5")
+    assert(doCompileFilter(LessThan("col3",
+      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'")
+    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'")
+    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5")
+    assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3")
+    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3")
+    assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL")
+    assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL")
+  }
+
   test("Dialect unregister") {
     JdbcDialects.registerDialect(testH2Dialect)
     JdbcDialects.unregisterDialect(testH2Dialect)
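
The new test reaches the private compileFilter through ScalaTest's PrivateMethodTester rather than widening its visibility: PrivateMethod[String]('compileFilter) names the method by symbol and return type, and invokePrivate calls it reflectively on the JDBCRDD companion object. The same pattern in isolation, with Target and secret as placeholder names:

import org.scalatest.{FunSuite, PrivateMethodTester}

// Placeholder object with a private method, standing in for JDBCRDD.
object Target {
  private def secret(x: Int): String = s"got $x"
}

class PrivateMethodSketch extends FunSuite with PrivateMethodTester {
  test("call a private method reflectively") {
    // The type parameter is the method's return type; the symbol is its name.
    val secret = PrivateMethod[String]('secret)
    assert((Target invokePrivate secret(7)) === "got 7")
  }
}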
