Commit

Spark: Fix issue when partitioning by UUID (apache#8250)
nastra authored May 16, 2024
1 parent 139721f commit bd046f8
Showing 43 changed files with 216 additions and 59 deletions.
4 changes: 2 additions & 2 deletions data/src/test/java/org/apache/iceberg/RecordWrapperTest.java
@@ -44,8 +44,8 @@ public abstract class RecordWrapperTest {
          optional(113, "bytes", Types.BinaryType.get()),
          required(114, "dec_9_0", Types.DecimalType.of(9, 0)),
          required(115, "dec_11_2", Types.DecimalType.of(11, 2)),
-         required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // maximum precision
-         );
+         required(116, "dec_38_10", Types.DecimalType.of(38, 10)), // maximum precision
+         optional(117, "uuid", Types.UUIDType.get()));

  private static final Types.StructType TIMESTAMP_WITHOUT_ZONE =
      Types.StructType.of(

@@ -193,7 +193,8 @@ public void writePartitionedClusteredDataWriter(Blackhole blackhole) throws IOEx

    PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
    StructType dataSparkType = SparkSchemaUtil.convert(table().schema());
-   InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType);
+   InternalRowWrapper internalRowWrapper =
+       new InternalRowWrapper(dataSparkType, table().schema().asStruct());

    try (ClusteredDataWriter<InternalRow> closeableWriter = writer) {
      for (InternalRow row : rows) {

@@ -256,7 +257,8 @@ public void writePartitionedFanoutDataWriter(Blackhole blackhole) throws IOExcep

    PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
    StructType dataSparkType = SparkSchemaUtil.convert(table().schema());
-   InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType);
+   InternalRowWrapper internalRowWrapper =
+       new InternalRowWrapper(dataSparkType, table().schema().asStruct());

    try (FanoutDataWriter<InternalRow> closeableWriter = writer) {
      for (InternalRow row : rows) {

@@ -324,7 +326,8 @@ public void writePartitionedClusteredEqualityDeleteWriter(Blackhole blackhole)

    PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
    StructType deleteSparkType = SparkSchemaUtil.convert(table().schema());
-   InternalRowWrapper internalRowWrapper = new InternalRowWrapper(deleteSparkType);
+   InternalRowWrapper internalRowWrapper =
+       new InternalRowWrapper(deleteSparkType, table().schema().asStruct());

    try (ClusteredEqualityDeleteWriter<InternalRow> closeableWriter = writer) {
      for (InternalRow row : rows) {

@@ -259,7 +259,9 @@ protected class SparkDeleteFilter extends DeleteFilter<InternalRow> {

    SparkDeleteFilter(String filePath, List<DeleteFile> deletes, DeleteCounter counter) {
      super(filePath, deletes, tableSchema, expectedSchema, counter);
-     this.asStructLike = new InternalRowWrapper(SparkSchemaUtil.convert(requiredSchema()));
+     this.asStructLike =
+         new InternalRowWrapper(
+             SparkSchemaUtil.convert(requiredSchema()), requiredSchema().asStruct());
    }

    @Override

@@ -19,9 +19,13 @@
package org.apache.iceberg.spark.source;

import java.nio.ByteBuffer;
+import java.util.UUID;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import org.apache.iceberg.StructLike;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.DataType;

@@ -40,9 +44,17 @@ class InternalRowWrapper implements StructLike {
  private InternalRow row = null;

  @SuppressWarnings("unchecked")
- InternalRowWrapper(StructType rowType) {
+ InternalRowWrapper(StructType rowType, Types.StructType icebergSchema) {
    this.types = Stream.of(rowType.fields()).map(StructField::dataType).toArray(DataType[]::new);
-   this.getters = Stream.of(types).map(InternalRowWrapper::getter).toArray(BiFunction[]::new);
+   Preconditions.checkArgument(
+       types.length == icebergSchema.fields().size(),
+       "Invalid length: Spark struct type (%s) != Iceberg struct type (%s)",
+       types.length,
+       icebergSchema.fields().size());
+   this.getters = new BiFunction[types.length];
+   for (int i = 0; i < types.length; i++) {
+     getters[i] = getter(icebergSchema.fields().get(i).type(), types[i]);
+   }
  }

  InternalRowWrapper wrap(InternalRow internalRow) {

@@ -71,8 +83,13 @@ public <T> void set(int pos, T value) {
    row.update(pos, value);
  }

- private static BiFunction<InternalRow, Integer, ?> getter(DataType type) {
+ private static BiFunction<InternalRow, Integer, ?> getter(Type icebergType, DataType type) {
    if (type instanceof StringType) {
+     // Spark represents UUIDs as strings
+     if (Type.TypeID.UUID == icebergType.typeId()) {
+       return (row, pos) -> UUID.fromString(row.getUTF8String(pos).toString());
+     }
+
      return (row, pos) -> row.getUTF8String(pos).toString();
    } else if (type instanceof DecimalType) {
      DecimalType decimal = (DecimalType) type;

@@ -82,7 +99,8 @@ public <T> void set(int pos, T value) {
      return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos));
    } else if (type instanceof StructType) {
      StructType structType = (StructType) type;
-     InternalRowWrapper nestedWrapper = new InternalRowWrapper(structType);
+     InternalRowWrapper nestedWrapper =
+         new InternalRowWrapper(structType, icebergType.asStructType());
      return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size()));
    }
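
The UUID branch in the getter above is the heart of the fix: Spark's converted schema exposes an Iceberg uuid column as a string, while Iceberg's partition transforms for Types.UUIDType expect java.util.UUID values, so the wrapper now needs the Iceberg struct type to know when to convert. Below is a minimal sketch, not part of the commit, of how a partitioned write path would exercise the new two-argument constructor; since InternalRowWrapper is package-private, the sketch only compiles from within org.apache.iceberg.spark.source, and the class name is made up for illustration.

package org.apache.iceberg.spark.source;

import java.util.UUID;
import org.apache.iceberg.PartitionKey;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

class UuidPartitioningSketch {

  static void demo() {
    // Iceberg table schema with a single uuid column, bucketed into 16 partitions.
    Schema schema = new Schema(Types.NestedField.optional(1, "uuid", Types.UUIDType.get()));
    PartitionSpec spec = PartitionSpec.builderFor(schema).bucket("uuid", 16).build();

    // Spark's view of the same schema: the uuid column surfaces as a StringType field.
    StructType sparkType = SparkSchemaUtil.convert(schema);

    // The wrapper now receives both representations and picks a UUID-aware getter.
    InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, schema.asStruct());
    PartitionKey partitionKey = new PartitionKey(spec, schema);

    // Spark stores the value as a UTF8String ...
    InternalRow row =
        new GenericInternalRow(
            new Object[] {UTF8String.fromString(UUID.randomUUID().toString())});

    // ... but the bucket transform is handed a java.util.UUID, so partitioning succeeds.
    partitionKey.partition(wrapper.wrap(row));
  }
}

Before this change the same call chain handed the transform a raw Java String instead of a UUID, which is what made partitioning by a UUID column fail.
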

@@ -44,7 +44,7 @@ public SparkPartitionedFanoutWriter(
      StructType sparkSchema) {
    super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
    this.partitionKey = new PartitionKey(spec, schema);
-   this.internalRowWrapper = new InternalRowWrapper(sparkSchema);
+   this.internalRowWrapper = new InternalRowWrapper(sparkSchema, schema.asStruct());
  }

  @Override

@@ -44,7 +44,7 @@ public SparkPartitionedWriter(
      StructType sparkSchema) {
    super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
    this.partitionKey = new PartitionKey(spec, schema);
-   this.internalRowWrapper = new InternalRowWrapper(sparkSchema);
+   this.internalRowWrapper = new InternalRowWrapper(sparkSchema, schema.asStruct());
  }

  @Override

@@ -375,7 +375,7 @@ private abstract static class BaseDeltaWriter implements DeltaWriter<InternalRow

    protected InternalRowWrapper initPartitionRowWrapper(Types.StructType partitionType) {
      StructType sparkPartitionType = (StructType) SparkSchemaUtil.convert(partitionType);
-     return new InternalRowWrapper(sparkPartitionType);
+     return new InternalRowWrapper(sparkPartitionType, partitionType);
    }

    protected Map<Integer, StructProjection> buildPartitionProjections(

@@ -645,7 +645,8 @@ private static class PartitionedDeltaWriter extends DeleteAndDataDeltaWriter {

      this.dataSpec = table.spec();
      this.dataPartitionKey = new PartitionKey(dataSpec, context.dataSchema());
-     this.internalRowDataWrapper = new InternalRowWrapper(context.dataSparkType());
+     this.internalRowDataWrapper =
+         new InternalRowWrapper(context.dataSparkType(), context.dataSchema().asStruct());
    }

    @Override

@@ -741,7 +741,7 @@ private PartitionedDataWriter(
      this.io = io;
      this.spec = spec;
      this.partitionKey = new PartitionKey(spec, dataSchema);
-     this.internalRowWrapper = new InternalRowWrapper(dataSparkType);
+     this.internalRowWrapper = new InternalRowWrapper(dataSparkType, dataSchema.asStruct());
    }

    @Override

@@ -53,7 +53,8 @@ protected void generateAndValidate(Schema schema, AssertMethod assertMethod) {
    Iterable<InternalRow> rowList = RandomData.generateSpark(schema, numRecords, 101L);

    InternalRecordWrapper recordWrapper = new InternalRecordWrapper(schema.asStruct());
-   InternalRowWrapper rowWrapper = new InternalRowWrapper(SparkSchemaUtil.convert(schema));
+   InternalRowWrapper rowWrapper =
+       new InternalRowWrapper(SparkSchemaUtil.convert(schema), schema.asStruct());

    Iterator<Record> actual = recordList.iterator();
    Iterator<InternalRow> expected = rowList.iterator();

@@ -56,7 +56,7 @@ protected InternalRow createRow(Integer id, String data) {
  protected StructLikeSet expectedRowSet(Iterable<InternalRow> rows) {
    StructLikeSet set = StructLikeSet.create(table.schema().asStruct());
    for (InternalRow row : rows) {
-     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType);
+     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct());
      set.add(wrapper.wrap(row));
    }
    return set;

@@ -61,7 +61,7 @@ protected StructLikeSet toSet(Iterable<InternalRow> rows) {
    StructLikeSet set = StructLikeSet.create(table.schema().asStruct());
    StructType sparkType = SparkSchemaUtil.convert(table.schema());
    for (InternalRow row : rows) {
-     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType);
+     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct());
      set.add(wrapper.wrap(row));
    }
    return set;

@@ -61,7 +61,7 @@ protected StructLikeSet toSet(Iterable<InternalRow> rows) {
    StructLikeSet set = StructLikeSet.create(table.schema().asStruct());
    StructType sparkType = SparkSchemaUtil.convert(table.schema());
    for (InternalRow row : rows) {
-     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType);
+     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct());
      set.add(wrapper.wrap(row));
    }
    return set;

@@ -61,7 +61,7 @@ protected StructLikeSet toSet(Iterable<InternalRow> rows) {
    StructLikeSet set = StructLikeSet.create(table.schema().asStruct());
    StructType sparkType = SparkSchemaUtil.convert(table.schema());
    for (InternalRow row : rows) {
-     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType);
+     InternalRowWrapper wrapper = new InternalRowWrapper(sparkType, table.schema().asStruct());
      set.add(wrapper.wrap(row));
    }
    return set;

@@ -315,7 +315,8 @@ public void testReadEqualityDeleteRows() throws IOException {
        new EqualityDeleteRowReader(task, table, null, table.schema(), false)) {
      while (reader.next()) {
        actualRowSet.add(
-           new InternalRowWrapper(SparkSchemaUtil.convert(table.schema()))
+           new InternalRowWrapper(
+               SparkSchemaUtil.convert(table.schema()), table.schema().asStruct())
                .wrap(reader.get().copy()));
      }
    }

@@ -20,6 +20,7 @@

import java.io.File;
import java.util.Map;
+import java.util.UUID;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.BaseTable;
import org.apache.iceberg.PartitionSpec;

@@ -33,6 +34,7 @@
import org.apache.iceberg.types.Types.NestedField;
import org.apache.iceberg.types.Types.StructType;
import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;

@@ -104,6 +106,31 @@ public void testCreateTable() {
        table.properties().get(TableProperties.DEFAULT_FILE_FORMAT));
  }

+ @Test
+ public void testCreateTablePartitionedByUUID() {
+   Assertions.assertThat(validationCatalog.tableExists(tableIdent)).isFalse();
+   Schema schema = new Schema(1, Types.NestedField.optional(1, "uuid", Types.UUIDType.get()));
+   PartitionSpec spec = PartitionSpec.builderFor(schema).bucket("uuid", 16).build();
+   validationCatalog.createTable(tableIdent, schema, spec);
+
+   Table table = validationCatalog.loadTable(tableIdent);
+   Assertions.assertThat(table).isNotNull();
+
+   StructType expectedSchema =
+       StructType.of(Types.NestedField.optional(1, "uuid", Types.UUIDType.get()));
+   Assertions.assertThat(table.schema().asStruct()).isEqualTo(expectedSchema);
+   Assertions.assertThat(table.spec().fields()).hasSize(1);
+
+   String uuid = UUID.randomUUID().toString();
+
+   sql("INSERT INTO %s VALUES('%s')", tableName, uuid);
+
+   Assertions.assertThat(sql("SELECT uuid FROM %s", tableName))
+       .hasSize(1)
+       .element(0)
+       .isEqualTo(row(uuid));
+ }
+
  @Test
  public void testCreateTableInRootNamespace() {
    Assume.assumeTrue(
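
Spark SQL has no UUID column type, which is why the new test above creates the table through the catalog API and only uses SQL for the INSERT and SELECT. The following is a rough end-to-end sketch of the same flow outside the test harness; the catalog name "demo", the warehouse path, and the db/table names are placeholders I chose for illustration, not values from the commit.

import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.catalog.TableIdentifier;
import org.apache.iceberg.hadoop.HadoopCatalog;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.SparkSession;

public class UuidPartitionEndToEndSketch {

  public static void main(String[] args) {
    String warehouse = "/tmp/warehouse";

    // Assumed setup: an Iceberg Hadoop catalog registered in Spark as "demo"
    // over the same warehouse directory.
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .appName("uuid-partitioning")
            .config("spark.sql.catalog.demo", "org.apache.iceberg.spark.SparkCatalog")
            .config("spark.sql.catalog.demo.type", "hadoop")
            .config("spark.sql.catalog.demo.warehouse", warehouse)
            .getOrCreate();

    // Create the table through the Iceberg API so the column really is uuid-typed.
    HadoopCatalog catalog = new HadoopCatalog(new Configuration(), warehouse);
    Schema schema = new Schema(Types.NestedField.optional(1, "uuid", Types.UUIDType.get()));
    PartitionSpec spec = PartitionSpec.builderFor(schema).bucket("uuid", 16).build();
    catalog.createTable(TableIdentifier.of("db", "uuid_tbl"), schema, spec);

    // In Spark the column surfaces as a string; the write below is the step that
    // previously failed when the table was partitioned by a uuid column.
    spark.sql("INSERT INTO demo.db.uuid_tbl VALUES ('f47ac10b-58cc-4372-a567-0e02b2c3d479')");
    spark.sql("SELECT uuid FROM demo.db.uuid_tbl").show();
  }
}
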

@@ -220,7 +220,8 @@ public void writePartitionedClusteredDataWriter(Blackhole blackhole) throws IOEx

    PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
    StructType dataSparkType = SparkSchemaUtil.convert(table().schema());
-   InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType);
+   InternalRowWrapper internalRowWrapper =
+       new InternalRowWrapper(dataSparkType, table().schema().asStruct());

    try (ClusteredDataWriter<InternalRow> closeableWriter = writer) {
      for (InternalRow row : rows) {

@@ -283,7 +284,8 @@ public void writePartitionedFanoutDataWriter(Blackhole blackhole) throws IOExcep

    PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
    StructType dataSparkType = SparkSchemaUtil.convert(table().schema());
-   InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType);
+   InternalRowWrapper internalRowWrapper =
+       new InternalRowWrapper(dataSparkType, table().schema().asStruct());

    try (FanoutDataWriter<InternalRow> closeableWriter = writer) {
      for (InternalRow row : rows) {

@@ -351,7 +353,8 @@ public void writePartitionedClusteredEqualityDeleteWriter(Blackhole blackhole)

    PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
    StructType deleteSparkType = SparkSchemaUtil.convert(table().schema());
-   InternalRowWrapper internalRowWrapper = new InternalRowWrapper(deleteSparkType);
+   InternalRowWrapper internalRowWrapper =
+       new InternalRowWrapper(deleteSparkType, table().schema().asStruct());

    try (ClusteredEqualityDeleteWriter<InternalRow> closeableWriter = writer) {
      for (InternalRow row : rows) {

@@ -264,7 +264,9 @@ protected class SparkDeleteFilter extends DeleteFilter<InternalRow> {

    SparkDeleteFilter(String filePath, List<DeleteFile> deletes, DeleteCounter counter) {
      super(filePath, deletes, tableSchema, expectedSchema, counter);
-     this.asStructLike = new InternalRowWrapper(SparkSchemaUtil.convert(requiredSchema()));
+     this.asStructLike =
+         new InternalRowWrapper(
+             SparkSchemaUtil.convert(requiredSchema()), requiredSchema().asStruct());
    }

    @Override

@@ -19,9 +19,13 @@
package org.apache.iceberg.spark.source;

import java.nio.ByteBuffer;
+import java.util.UUID;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import org.apache.iceberg.StructLike;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.DataType;

@@ -40,9 +44,17 @@ class InternalRowWrapper implements StructLike {
  private InternalRow row = null;

  @SuppressWarnings("unchecked")
- InternalRowWrapper(StructType rowType) {
+ InternalRowWrapper(StructType rowType, Types.StructType icebergSchema) {
    this.types = Stream.of(rowType.fields()).map(StructField::dataType).toArray(DataType[]::new);
-   this.getters = Stream.of(types).map(InternalRowWrapper::getter).toArray(BiFunction[]::new);
+   Preconditions.checkArgument(
+       types.length == icebergSchema.fields().size(),
+       "Invalid length: Spark struct type (%s) != Iceberg struct type (%s)",
+       types.length,
+       icebergSchema.fields().size());
+   this.getters = new BiFunction[types.length];
+   for (int i = 0; i < types.length; i++) {
+     getters[i] = getter(icebergSchema.fields().get(i).type(), types[i]);
+   }
  }

  InternalRowWrapper wrap(InternalRow internalRow) {

@@ -71,8 +83,13 @@ public <T> void set(int pos, T value) {
    row.update(pos, value);
  }

- private static BiFunction<InternalRow, Integer, ?> getter(DataType type) {
+ private static BiFunction<InternalRow, Integer, ?> getter(Type icebergType, DataType type) {
    if (type instanceof StringType) {
+     // Spark represents UUIDs as strings
+     if (Type.TypeID.UUID == icebergType.typeId()) {
+       return (row, pos) -> UUID.fromString(row.getUTF8String(pos).toString());
+     }
+
      return (row, pos) -> row.getUTF8String(pos).toString();
    } else if (type instanceof DecimalType) {
      DecimalType decimal = (DecimalType) type;

@@ -82,7 +99,8 @@ public <T> void set(int pos, T value) {
      return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos));
    } else if (type instanceof StructType) {
      StructType structType = (StructType) type;
-     InternalRowWrapper nestedWrapper = new InternalRowWrapper(structType);
+     InternalRowWrapper nestedWrapper =
+         new InternalRowWrapper(structType, icebergType.asStructType());
      return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size()));
    }

@@ -44,7 +44,7 @@ public SparkPartitionedFanoutWriter(
      StructType sparkSchema) {
    super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
    this.partitionKey = new PartitionKey(spec, schema);
-   this.internalRowWrapper = new InternalRowWrapper(sparkSchema);
+   this.internalRowWrapper = new InternalRowWrapper(sparkSchema, schema.asStruct());
  }

  @Override

@@ -44,7 +44,7 @@ public SparkPartitionedWriter(
      StructType sparkSchema) {
    super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
    this.partitionKey = new PartitionKey(spec, schema);
-   this.internalRowWrapper = new InternalRowWrapper(sparkSchema);
+   this.internalRowWrapper = new InternalRowWrapper(sparkSchema, schema.asStruct());
  }

  @Override

@@ -391,7 +391,7 @@ private abstract static class BaseDeltaWriter implements DeltaWriter<InternalRow

    protected InternalRowWrapper initPartitionRowWrapper(Types.StructType partitionType) {
      StructType sparkPartitionType = (StructType) SparkSchemaUtil.convert(partitionType);
-     return new InternalRowWrapper(sparkPartitionType);
+     return new InternalRowWrapper(sparkPartitionType, partitionType);
    }

    protected Map<Integer, StructProjection> buildPartitionProjections(

@@ -653,7 +653,8 @@ private static class PartitionedDeltaWriter extends DeleteAndDataDeltaWriter {

      this.dataSpec = table.spec();
      this.dataPartitionKey = new PartitionKey(dataSpec, context.dataSchema());
-     this.internalRowDataWrapper = new InternalRowWrapper(context.dataSparkType());
+     this.internalRowDataWrapper =
+         new InternalRowWrapper(context.dataSparkType(), context.dataSchema().asStruct());
    }

    @Override