Skip to content

Commit

Permalink
[SPARK-16675][SQL] Avoid per-record type dispatch in JDBC when writing
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Currently, `JdbcUtils.savePartition` is doing type-based dispatch for each row to write appropriate values.

Instead, the appropriate setters for `PreparedStatement` can be created once, up front, according to the schema, and then applied to each row. This approach is similar to `CatalystWriteSupport`.

This PR simply creates the setters up front to avoid the per-record type dispatch.

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <[email protected]>

Closes apache#14323 from HyukjinKwon/SPARK-16675.
  • Loading branch information
HyukjinKwon authored and cloud-fan committed Jul 26, 2016
1 parent 03c2743 commit 3b2b785
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,19 +322,19 @@ private[sql] class JDBCRDD(
}
}

// A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
// into a field for `MutableRow`. The last argument `Int` means the index for the
// value to be set in the row and also used for the value to retrieve from `ResultSet`.
private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
// A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
// for `MutableRow`. The last argument `Int` means the index for the value to be set in
// the row and also used for the value in `ResultSet`.
private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit

/**
* Creates `JDBCValueSetter`s according to [[StructType]], which can set
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[MutableRow]] correctly.
*/
def makeSetters(schema: StructType): Array[JDBCValueSetter] =
schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
def makeGetters(schema: StructType): Array[JDBCValueGetter] =
schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
Expand Down Expand Up @@ -489,15 +489,15 @@ private[sql] class JDBCRDD(
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()

val setters: Array[JDBCValueSetter] = makeSetters(schema)
val getters: Array[JDBCValueGetter] = makeGetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
while (i < setters.length) {
setters(i).apply(rs, mutableRow, i)
while (i < getters.length) {
getters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,79 @@ object JdbcUtils extends Logging {
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}

// A `JDBCValueSetter` copies one field of a `Row` into the corresponding parameter of a
// `PreparedStatement`. The last argument `Int` is the position `pos` of the field in the
// `Row`; the setters built by `makeSetter` read field `pos` and bind it to statement
// parameter `pos + 1`.
private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

/**
 * Builds the `JDBCValueSetter` for a single column, selected once from the column's
 * data type so that no type dispatch is needed when the setter is applied to a row.
 *
 * Every returned setter reads field `pos` of the row and binds it to statement
 * parameter `pos + 1`. The setters perform no null check of their own.
 *
 * @param conn     connection used to create SQL arrays for `ArrayType` columns
 * @param dialect  dialect passed to `getJdbcType` to resolve the array element type name
 * @param dataType data type of the column the setter is built for
 * @return a setter for `dataType`; for a type not listed here, a setter that throws
 *         `IllegalArgumentException` when it is invoked (not when it is created)
 */
private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType =>
    (ps, r, pos) => ps.setInt(pos + 1, r.getInt(pos))

  case LongType =>
    (ps, r, pos) => ps.setLong(pos + 1, r.getLong(pos))

  case DoubleType =>
    (ps, r, pos) => ps.setDouble(pos + 1, r.getDouble(pos))

  case FloatType =>
    (ps, r, pos) => ps.setFloat(pos + 1, r.getFloat(pos))

  // Short and Byte values are widened and bound through `setInt`.
  case ShortType =>
    (ps, r, pos) => ps.setInt(pos + 1, r.getShort(pos))

  case ByteType =>
    (ps, r, pos) => ps.setInt(pos + 1, r.getByte(pos))

  case BooleanType =>
    (ps, r, pos) => ps.setBoolean(pos + 1, r.getBoolean(pos))

  case StringType =>
    (ps, r, pos) => ps.setString(pos + 1, r.getString(pos))

  case BinaryType =>
    (ps, r, pos) => ps.setBytes(pos + 1, r.getAs[Array[Byte]](pos))

  case TimestampType =>
    (ps, r, pos) => ps.setTimestamp(pos + 1, r.getAs[java.sql.Timestamp](pos))

  case DateType =>
    (ps, r, pos) => ps.setDate(pos + 1, r.getAs[java.sql.Date](pos))

  case _: DecimalType =>
    (ps, r, pos) => ps.setBigDecimal(pos + 1, r.getDecimal(pos))

  case ArrayType(et, _) =>
    // Resolved once, when the setter is built: the element type name with any
    // length parameters (everything from the first "(") removed from its end.
    val elementTypeDefinition = getJdbcType(et, dialect).databaseTypeDefinition
    val typeName = elementTypeDefinition.toLowerCase.split("\\(")(0)
    (ps, r, pos) => {
      val elements: Array[AnyRef] = r.getSeq[AnyRef](pos).toArray
      ps.setArray(pos + 1, conn.createArrayOf(typeName, elements))
    }

  case _ =>
    // No failure at construction time; the error surfaces only if the setter is applied.
    (_, _, pos) =>
      throw new IllegalArgumentException(
        s"Can't translate non-null value for field $pos")
}

/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction (unless isolation level is "NONE")
Expand Down Expand Up @@ -215,6 +288,9 @@ object JdbcUtils extends Logging {
conn.setTransactionIsolation(finalIsolationLevel)
}
val stmt = insertStatement(conn, table, rddSchema, dialect)
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
.map(makeSetter(conn, dialect, _)).toArray

try {
var rowCount = 0
while (iterator.hasNext) {
Expand All @@ -225,30 +301,7 @@ object JdbcUtils extends Logging {
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
rddSchema.fields(i).dataType match {
case IntegerType => stmt.setInt(i + 1, row.getInt(i))
case LongType => stmt.setLong(i + 1, row.getLong(i))
case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
case ShortType => stmt.setInt(i + 1, row.getShort(i))
case ByteType => stmt.setInt(i + 1, row.getByte(i))
case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
case StringType => stmt.setString(i + 1, row.getString(i))
case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case ArrayType(et, _) =>
// remove type length parameters from end of type name
val typeName = getJdbcType(et, dialect).databaseTypeDefinition
.toLowerCase.split("\\(")(0)
val array = conn.createArrayOf(
typeName,
row.getSeq[AnyRef](i).toArray)
stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
setters(i).apply(stmt, row, i)
}
i = i + 1
}
Expand Down Expand Up @@ -333,5 +386,4 @@ object JdbcUtils extends Logging {
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)
}

}

0 comments on commit 3b2b785

Please sign in to comment.