[SPARK-8432] [SQL] fix hashCode() and equals() of BinaryType in Row
Also added more tests in LiteralExpressionSuite

Author: Davies Liu <[email protected]>

Closes apache#6876 from davies/fix_hashcode and squashes the following commits:

429c2c0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_hashcode
32d9811 [Davies Liu] fix test
a0626ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_hashcode
89c2432 [Davies Liu] fix style
bd20780 [Davies Liu] check with catalyst types
41caec6 [Davies Liu] change for to while
d96929b [Davies Liu] address comment
6ad2a90 [Davies Liu] fix style
5819d33 [Davies Liu] unify equals() and hashCode()
0fff25d [Davies Liu] fix style
53c38b1 [Davies Liu] fix hashCode() and equals() of BinaryType in Row
Davies Liu authored and marmbrus committed Jun 23, 2015
1 parent 7b1450b commit 6f4cadf
Showing 10 changed files with 139 additions and 135 deletions.
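For context (not part of the commit): JVM byte arrays compare by reference and hash by identity, so BinaryType values stored as Array[Byte] defeated the default Row.equals() and hashCode(). A minimal Scala sketch of the underlying behavior, using only the standard library:

    object BinaryEqualityDemo extends App {
      val a = Array[Byte](1, 2, 3)
      val b = Array[Byte](1, 2, 3)

      // Default JVM semantics: reference equality and identity hash codes.
      println(a == b)                    // false, although the contents match
      println(a.hashCode == b.hashCode)  // false in practice (identity hashes)

      // Content-based comparison, which the fix below relies on.
      println(java.util.Arrays.equals(a, b))                                 // true
      println(java.util.Arrays.hashCode(a) == java.util.Arrays.hashCode(b))  // true
    }

This is why the changes below route Array[Byte] through java.util.Arrays.equals and java.util.Arrays.hashCode instead of the default Object methods.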
21 changes: 0 additions & 21 deletions sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
@@ -155,27 +155,6 @@ public int fieldIndex(String name) {
throw new UnsupportedOperationException();
}

/**
* A generic version of Row.equals(Row), which is used for tests.
*/
@Override
public boolean equals(Object other) {
if (other instanceof Row) {
Row row = (Row) other;
int n = size();
if (n != row.size()) {
return false;
}
for (int i = 0; i < n; i ++) {
if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
return false;
}
}
return true;
}
return false;
}

@Override
public InternalRow copy() {
final int n = size();
32 changes: 0 additions & 32 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,8 +17,6 @@

package org.apache.spark.sql

import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType

@@ -365,36 +363,6 @@ trait Row extends Serializable {
false
}

override def equals(that: Any): Boolean = that match {
case null => false
case that: Row =>
if (this.length != that.length) {
return false
}
var i = 0
val len = this.length
while (i < len) {
if (apply(i) != that.apply(i)) {
return false
}
i += 1
}
true
case _ => false
}

override def hashCode: Int = {
// Using Scala's Seq hash code implementation.
var n = 0
var h = MurmurHash3.seqSeed
val len = length
while (n < len) {
h = MurmurHash3.mix(h, apply(n).##)
n += 1
}
MurmurHash3.finalizeHash(h, n)
}

/* ---------------------- utility methods for Scala ---------------------- */

/**
@@ -18,15 +18,78 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.expressions._

/**
* An abstract class for rows used internally in Spark SQL, which contain only columns of
* internal types.
*/
abstract class InternalRow extends Row {
// A default implementation to change the return type
override def copy(): InternalRow = {this}
override def copy(): InternalRow = this

override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[Row]) {
return false
}

val other = o.asInstanceOf[Row]
if (length != other.length) {
return false
}

var i = 0
while (i < length) {
if (isNullAt(i) != other.isNullAt(i)) {
return false
}
if (!isNullAt(i)) {
val o1 = apply(i)
val o2 = other.apply(i)
if (o1.isInstanceOf[Array[Byte]]) {
// handle equality of Array[Byte]
val b1 = o1.asInstanceOf[Array[Byte]]
if (!o2.isInstanceOf[Array[Byte]] ||
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
return false
}
} else if (o1 != o2) {
return false
}
}
i += 1
}
true
}

// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
var i = 0
while (i < length) {
val update: Int =
if (isNullAt(i)) {
0
} else {
apply(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
case i: Int => i
case l: Long => (l ^ (l >>> 32)).toInt
case f: Float => java.lang.Float.floatToIntBits(f)
case d: Double =>
val b = java.lang.Double.doubleToLongBits(d)
(b ^ (b >>> 32)).toInt
case a: Array[Byte] => java.util.Arrays.hashCode(a)
case other => other.hashCode()
}
}
result = 37 * result + update
i += 1
}
result
}
}

object InternalRow {
@@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
case BinaryType => s"java.util.Arrays.hashCode($col)"
case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"
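The interpreted InternalRow.hashCode above and the generated isNullAt($i) ? 0 : ($nonNull) expression follow the same 37-based combiner. A standalone sketch of that scheme (hypothetical name RowHashSketch, not part of the commit):

    object RowHashSketch {
      // result = 37 * result + (null column ? 0 : per-type hash), as in the diff above.
      def rowHash(values: Seq[Any]): Int = {
        var result = 37
        var i = 0
        while (i < values.length) {
          val update = values(i) match {
            case null           => 0
            case b: Boolean     => if (b) 0 else 1
            case b: Byte        => b.toInt
            case s: Short       => s.toInt
            case n: Int         => n
            case l: Long        => (l ^ (l >>> 32)).toInt
            case f: Float       => java.lang.Float.floatToIntBits(f)
            case d: Double      =>
              val bits = java.lang.Double.doubleToLongBits(d)
              (bits ^ (bits >>> 32)).toInt
            case a: Array[Byte] => java.util.Arrays.hashCode(a) // the BinaryType addition
            case other          => other.hashCode()
          }
          result = 37 * result + update
          i += 1
        }
        result
      }
    }

Keeping the two paths in sync is what lets ExpressionEvalHelper (below) compare the hashCode of a generated projection's row against a GenericRow built from the expected value.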
@@ -121,58 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
}
}

// TODO(davies): add getDate and getDecimal

// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37

var i = 0
while (i < values.length) {
val update: Int =
if (isNullAt(i)) {
0
} else {
apply(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
case i: Int => i
case l: Long => (l ^ (l >>> 32)).toInt
case f: Float => java.lang.Float.floatToIntBits(f)
case d: Double =>
val b = java.lang.Double.doubleToLongBits(d)
(b ^ (b >>> 32)).toInt
case other => other.hashCode()
}
}
result = 37 * result + update
i += 1
}
result
}

override def equals(o: Any): Boolean = o match {
case other: InternalRow =>
if (values.length != other.length) {
return false
}

var i = 0
while (i < values.length) {
if (isNullAt(i) != other.isNullAt(i)) {
return false
}
if (apply(i) != other.apply(i)) {
return false
}
i += 1
}
true

case _ => false
}

override def copy(): InternalRow = this
}

@@ -38,10 +38,23 @@ trait ExpressionEvalHelper {

protected def checkEvaluation(
expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
checkEvaluationWithoutCodegen(expression, expected, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
checkEvaluationWithOptimization(expression, expected, inputRow)
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
checkEvaluationWithOptimization(expression, catalystValue, inputRow)
}

/**
* Checks equality between the result of an expression and the expected value, handling
* Array[Byte] by content.
*/
protected def checkResult(result: Any, expected: Any): Boolean = {
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
case _ => result == expected
}
}

protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
@@ -55,7 +68,7 @@ trait ExpressionEvalHelper {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if (actual != expected) {
if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
@@ -83,7 +96,7 @@
}

val actual = plan(inputRow).apply(0)
if (actual != expected) {
if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
@@ -109,7 +122,7 @@
}

val actual = plan(inputRow)
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
val expectedRow = new GenericRow(Array[Any](expected))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
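A hypothetical usage of the checkResult helper introduced above (values invented for illustration), from code that mixes in ExpressionEvalHelper: a plain equality check treats identical byte arrays as different, while checkResult compares them by content.

    // Inside a suite extending SparkFunSuite with ExpressionEvalHelper:
    val produced = Array[Byte](0, 1, 2)
    val expected = Array[Byte](0, 1, 2)
    assert(produced != expected)             // != is reference-based for arrays
    assert(checkResult(produced, expected))  // content comparison via Arrays.equals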
@@ -18,38 +18,79 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types._


class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

// TODO: Add tests for all data types.
test("null") {
checkEvaluation(Literal.create(null, BooleanType), null)
checkEvaluation(Literal.create(null, ByteType), null)
checkEvaluation(Literal.create(null, ShortType), null)
checkEvaluation(Literal.create(null, IntegerType), null)
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, FloatType), null)
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal.create(null, BinaryType), null)
checkEvaluation(Literal.create(null, DecimalType()), null)
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
}

test("boolean literals") {
checkEvaluation(Literal(true), true)
checkEvaluation(Literal(false), false)
}

test("int literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(0L), 0L)
List(0, 1, Int.MinValue, Int.MaxValue).foreach { d =>
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toLong), d.toLong)
checkEvaluation(Literal(d.toShort), d.toShort)
checkEvaluation(Literal(d.toByte), d.toByte)
}
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
}

test("double literals") {
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
d => {
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)
}
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)
}
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)

}

test("string literals") {
checkEvaluation(Literal(""), "")
checkEvaluation(Literal("test"), "test")
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal("\0"), "\0")
}

test("sum two literals") {
checkEvaluation(Add(Literal(1), Literal(1)), 2)
}

test("binary literals") {
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
}

test("decimal") {
List(0.0, 1.2, 1.1111, 5).foreach { d =>
checkEvaluation(Literal(Decimal(d)), Decimal(d))
checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt))
checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong))
checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)),
Decimal((d * 1000L).toLong, 10, 1))
}
}

// TODO(davies): add tests for ArrayType, MapType and StructType
}
@@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
checkEvaluation(StringLength(regEx), 0, create_row(""))
checkEvaluation(StringLength(regEx), null, create_row(null))
// TODO currently bug in codegen, let's temporally disable this
// checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}


}