Spark 2.4: Fix ClassCastException when using bucket UDF (apache#3570)
Port of apache#3368 to Spark 2.4.
wypoon authored Nov 18, 2021
1 parent 4eef02d commit 3b2c32d
Showing 3 changed files with 129 additions and 9 deletions.
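
The change in a nutshell: registerBucketUDF used to hand Spark the raw Iceberg bucket transform, which expects Iceberg's internal value representation, while Spark calls UDFs with external Java types (java.sql.Date, java.sql.Timestamp, java.math.BigDecimal, Short, Byte). For several source types the cast inside the transform then failed with a ClassCastException. A minimal repro sketch, assuming only what the diff below shows (the class and session setup here are illustrative, not part of the commit):

    import org.apache.iceberg.spark.IcebergSpark;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;

    public class BucketUdfRepro {
      public static void main(String[] args) {
        // Throwaway local session, purely for illustration.
        SparkSession spark = SparkSession.builder()
            .master("local[2]")
            .appName("bucket-udf-repro")
            .getOrCreate();

        // Register a 16-bucket UDF over DATE, exactly as the new tests do.
        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);

        // Before this fix, evaluating the UDF threw ClassCastException: the raw
        // transform received a java.sql.Date where it expects the int "days from
        // epoch" value. With the fix it returns a bucket id.
        spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").show();

        spark.stop();
      }
    }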

IcebergSpark.registerBucketUDF:

@@ -34,6 +34,7 @@ public static void registerBucketUDF(SparkSession session, String funcName, Data
     SparkTypeToType typeConverter = new SparkTypeToType();
     Type sourceIcebergType = typeConverter.atomic(sourceType);
     Transform<Object, Integer> bucket = Transforms.bucket(sourceIcebergType, numBuckets);
-    session.udf().register(funcName, bucket::apply, DataTypes.IntegerType);
+    session.udf().register(funcName,
+        value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType);
   }
 }
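
The registered function now converts each incoming Spark value to Iceberg's internal representation before bucketing. A sketch of what the lambda effectively computes per row for a DATE argument (standalone scaffolding assumed; the DATE branch of SparkValueConverter.convert is assumed to mirror the TIMESTAMP branch visible in the next hunk, as the new date test's use of DateTimeUtils.fromJavaDate suggests):

    import java.sql.Date;
    import org.apache.iceberg.spark.SparkValueConverter;
    import org.apache.iceberg.transforms.Transform;
    import org.apache.iceberg.transforms.Transforms;
    import org.apache.iceberg.types.Type;
    import org.apache.iceberg.types.Types;

    public class ConvertThenBucket {
      public static void main(String[] args) {
        Type sourceIcebergType = Types.DateType.get();
        Transform<Object, Integer> bucket = Transforms.bucket(sourceIcebergType, 16);

        // Spark hands the UDF a java.sql.Date; the converter rewrites it
        // into the int day ordinal that the bucket transform expects.
        Object sparkValue = Date.valueOf("2021-06-30");
        Object icebergValue = SparkValueConverter.convert(sourceIcebergType, sparkValue);

        // This is the step that used to fail when the Spark value was passed directly.
        int bucketId = bucket.apply(icebergValue);
        System.out.println(bucketId);
      }
    }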

SparkValueConverter.convert:

@@ -79,8 +79,9 @@ public static Object convert(Type type, Object object) {
         return DateTimeUtils.fromJavaTimestamp((Timestamp) object);
       case BINARY:
         return ByteBuffer.wrap((byte[]) object);
-      case BOOLEAN:
       case INTEGER:
+        return ((Number) object).intValue();
+      case BOOLEAN:
       case LONG:
       case FLOAT:
       case DOUBLE:
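
Why this hunk is needed: Spark's ShortType and ByteType both map to Iceberg's 32-bit int type, but Spark passes the UDF a boxed Short or Byte, which cannot be cast straight to Integer. The INTEGER case therefore now widens through Number, and BOOLEAN drops down to the fall-through group with LONG, FLOAT, and DOUBLE (which presumably returns the value unchanged; that return sits outside the visible hunk). A small sketch of the assumed behavior:

    import org.apache.iceberg.spark.SparkValueConverter;
    import org.apache.iceberg.types.Types;

    public class ConverterWidening {
      public static void main(String[] args) {
        // Short and byte values widen to Integer via Number.intValue().
        System.out.println(SparkValueConverter.convert(Types.IntegerType.get(), (short) 1)); // 1 (Integer)
        System.out.println(SparkValueConverter.convert(Types.IntegerType.get(), (byte) 1));  // 1 (Integer)

        // Booleans are presumed to pass through unchanged in the fall-through group.
        System.out.println(SparkValueConverter.convert(Types.BooleanType.get(), true));      // true
      }
    }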

The test class (org.apache.iceberg.spark.source):

@@ -19,13 +19,22 @@

 package org.apache.iceberg.spark.source;

+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.sql.Date;
+import java.sql.Timestamp;
 import java.util.List;
 import org.apache.iceberg.spark.IcebergSpark;
 import org.apache.iceberg.transforms.Transforms;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.types.CharType;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.VarcharType;
+import org.assertj.core.api.Assertions;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;

@@ -48,23 +57,132 @@ public static void stopSpark() {
   }

   @Test
-  public void testRegisterBucketUDF() {
+  public void testRegisterIntegerBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16);
     List<Row> results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList();
     Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
         results.get(0).getInt(0));
+  }

+  @Test
+  public void testRegisterShortBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterByteBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterLongBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16);
-    List<Row> results2 = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
-    Assert.assertEquals(1, results2.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L),
-        results2.get(0).getInt(0));
+        results.get(0).getInt(0));
+  }

+  @Test
+  public void testRegisterStringBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16);
-    List<Row> results3 = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
-    Assert.assertEquals(1, results3.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
-        results3.get(0).getInt(0));
+        results.get(0).getInt(0));
   }
+
+  @Test
+  public void testRegisterCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterVarCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDateBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16)
+            .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterTimestampBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16)
+            .apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBinaryBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16)
+            .apply(ByteBuffer.wrap(new byte[]{0x00, 0x20, 0x00, 0x1F})),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDecimalBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16)
+            .apply(new BigDecimal("11.11")),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBooleanBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: boolean");
+  }
+
+  @Test
+  public void testRegisterDoubleBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: double");
+  }
+
+  @Test
+  public void testRegisterFloatBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: float");
+  }
 }
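
One note on the three negative tests: they exercise registration itself, not query evaluation. Transforms.bucket, called inside registerBucketUDF, rejects boolean, double, and float sources, so registration fails fast with IllegalArgumentException rather than installing a UDF for a type the bucket transform does not support.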
