ORC: Fix decimal and timestamp bugs (apache#1271)
openinx authored Aug 7, 2020
1 parent d861581 commit 6f96b36
Showing 7 changed files with 203 additions and 19 deletions.
1 change: 1 addition & 0 deletions build.gradle
@@ -510,6 +510,7 @@ project(':iceberg-spark') {
}
testCompile project(path: ':iceberg-hive', configuration: 'testArtifacts')
testCompile project(path: ':iceberg-api', configuration: 'testArtifacts')
testCompile project(path: ':iceberg-data', configuration: 'testArtifacts')
}

test {
@@ -103,7 +103,7 @@ public OrcValueWriter<?> primitive(Type.PrimitiveType iPrimitive, TypeDescription
return GenericOrcWriters.byteBuffers();
case DECIMAL:
Types.DecimalType decimalType = (Types.DecimalType) iPrimitive;
return GenericOrcWriters.decimal(decimalType.scale(), decimalType.precision());
return GenericOrcWriters.decimal(decimalType.precision(), decimalType.scale());
default:
throw new IllegalArgumentException(String.format("Invalid iceberg type %s corresponding to ORC type %s",
iPrimitive, primitive));
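The one-line change above swaps the call's argument order so that GenericOrcWriters.decimal receives precision before scale. A minimal sketch (plain java.math plus a hypothetical chooseWriter helper, not the Iceberg factory) of why the order matters when the factory dispatches on precision:

import java.math.BigDecimal;

public class DecimalArgOrderDemo {
  // Hypothetical stand-in for the writer factory: it dispatches on precision.
  static String chooseWriter(int precision, int scale) {
    return precision <= 18
        ? "Decimal18Writer(" + precision + "," + scale + ")"
        : "Decimal38Writer(" + precision + "," + scale + ")";
  }

  public static void main(String[] args) {
    BigDecimal value = new BigDecimal("12345678901234567.89");
    System.out.println(value.precision() + " / " + value.scale());  // 19 / 2

    // Correct order: precision 19 requires the 38-digit writer.
    System.out.println(chooseWriter(value.precision(), value.scale()));
    // Swapped order (the old bug): scale 2 is treated as the precision, so the
    // value is routed to the long-backed 18-digit writer with precision and scale reversed.
    System.out.println(chooseWriter(value.scale(), value.precision()));
  }
}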
@@ -33,6 +33,7 @@
import java.util.Map;
import java.util.UUID;
import org.apache.iceberg.orc.OrcValueWriter;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.orc.storage.common.type.HiveDecimal;
import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -103,11 +104,13 @@ public static OrcValueWriter<LocalDateTime> timestamp() {
return TimestampWriter.INSTANCE;
}

public static OrcValueWriter<BigDecimal> decimal(int scale, int precision) {
public static OrcValueWriter<BigDecimal> decimal(int precision, int scale) {
if (precision <= 18) {
return new Decimal18Writer(scale);
return new Decimal18Writer(precision, scale);
} else if (precision <= 38) {
return new Decimal38Writer(precision, scale);
} else {
return Decimal38Writer.INSTANCE;
throw new IllegalArgumentException("Invalid precision: " + precision);
}
}

@@ -288,8 +291,10 @@ public Class<OffsetDateTime> getJavaClass() {
@Override
public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) {
TimestampColumnVector cv = (TimestampColumnVector) output;
cv.time[rowId] = data.toInstant().toEpochMilli(); // millis
cv.nanos[rowId] = (data.getNano() / 1_000) * 1_000; // truncate nanos to only keep microsecond precision
// millis
cv.time[rowId] = data.toInstant().toEpochMilli();
// truncate nanos to only keep microsecond precision
cv.nanos[rowId] = data.getNano() / 1_000 * 1_000;
}
}
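The comment above notes that nanoseconds are truncated to microsecond precision before they reach the ORC TimestampColumnVector, since Iceberg timestamps carry microseconds. A minimal sketch of that truncation (plain java.time, no ORC classes):

import java.time.OffsetDateTime;
import java.time.ZoneOffset;

public class NanoTruncationDemo {
  public static void main(String[] args) {
    OffsetDateTime ts = OffsetDateTime.of(2020, 8, 7, 12, 0, 0, 123_456_789, ZoneOffset.UTC);

    // Integer division by 1_000 drops the sub-microsecond digits; multiplying
    // back by 1_000 restores the nanosecond unit the column vector stores.
    int truncated = ts.getNano() / 1_000 * 1_000;

    System.out.println(ts.getNano());  // 123456789
    System.out.println(truncated);     // 123456000
  }
}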

@@ -311,9 +316,11 @@ public void nonNullWrite(int rowId, LocalDateTime data, ColumnVector output) {
}

private static class Decimal18Writer implements OrcValueWriter<BigDecimal> {
private final int precision;
private final int scale;

Decimal18Writer(int scale) {
Decimal18Writer(int precision, int scale) {
this.precision = precision;
this.scale = scale;
}

@@ -324,14 +331,24 @@ public Class<BigDecimal> getJavaClass() {

@Override
public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
// TODO: validate precision and scale from schema
Preconditions.checkArgument(data.scale() == scale,
"Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, data);
Preconditions.checkArgument(data.precision() <= precision,
"Cannot write value as decimal(%s,%s), invalid precision: %s", precision, scale, data);

((DecimalColumnVector) output).vector[rowId]
.setFromLongAndScale(data.unscaledValue().longValueExact(), scale);
}
}

private static class Decimal38Writer implements OrcValueWriter<BigDecimal> {
private static final OrcValueWriter<BigDecimal> INSTANCE = new Decimal38Writer();
private final int precision;
private final int scale;

Decimal38Writer(int precision, int scale) {
this.precision = precision;
this.scale = scale;
}

@Override
public Class<BigDecimal> getJavaClass() {
@@ -340,7 +357,11 @@ public Class<BigDecimal> getJavaClass() {

@Override
public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
// TODO: validate precision and scale from schema
Preconditions.checkArgument(data.scale() == scale,
"Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, data);
Preconditions.checkArgument(data.precision() <= precision,
"Cannot write value as decimal(%s,%s), invalid precision: %s", precision, scale, data);

((DecimalColumnVector) output).vector[rowId].set(HiveDecimal.create(data, false));
}
}
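With these changes both decimal writers validate the value against the declared ORC type instead of silently writing a mismatched scale or an oversized precision. A minimal sketch of the same checks using a hypothetical standalone helper (plain Java, no Guava Preconditions):

import java.math.BigDecimal;

public class DecimalValidationDemo {
  // Hypothetical helper mirroring the checks added to the writers above.
  static void validate(BigDecimal value, int precision, int scale) {
    if (value.scale() != scale) {
      throw new IllegalArgumentException(String.format(
          "Cannot write value as decimal(%d,%d), wrong scale: %s", precision, scale, value));
    }
    if (value.precision() > precision) {
      throw new IllegalArgumentException(String.format(
          "Cannot write value as decimal(%d,%d), invalid precision: %s", precision, scale, value));
    }
  }

  public static void main(String[] args) {
    validate(new BigDecimal("101.00"), 10, 2);  // passes: precision 5, scale 2
    validate(new BigDecimal("101.0"), 10, 2);   // throws: scale is 1, not 2
  }
}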
@@ -24,6 +24,7 @@
import java.util.Map;
import org.apache.iceberg.orc.OrcValueReader;
import org.apache.iceberg.orc.OrcValueReaders;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Types;
import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -57,8 +58,10 @@ public static OrcValueReader<Long> timestampTzs() {
public static OrcValueReader<Decimal> decimals(int precision, int scale) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return new SparkOrcValueReaders.Decimal18Reader(precision, scale);
} else {
} else if (precision <= 38) {
return new SparkOrcValueReaders.Decimal38Reader(precision, scale);
} else {
throw new IllegalArgumentException("Invalid precision: " + precision);
}
}

@@ -177,13 +180,12 @@ private TimestampTzReader() {

@Override
public Long nonNullRead(ColumnVector vector, int row) {
TimestampColumnVector timestampVector = (TimestampColumnVector) vector;
return (timestampVector.time[row] / 1000) * 1_000_000 + timestampVector.nanos[row] / 1000;
TimestampColumnVector tcv = (TimestampColumnVector) vector;
return (Math.floorDiv(tcv.time[row], 1_000)) * 1_000_000 + Math.floorDiv(tcv.nanos[row], 1000);
}
}

private static class Decimal18Reader implements OrcValueReader<Decimal> {
//TODO: these are being unused. check for bug
private final int precision;
private final int scale;

@@ -195,7 +197,15 @@ private static class Decimal18Reader implements OrcValueReader<Decimal> {
@Override
public Decimal nonNullRead(ColumnVector vector, int row) {
HiveDecimalWritable value = ((DecimalColumnVector) vector).vector[row];
return new Decimal().set(value.serialize64(value.scale()), value.precision(), value.scale());

// The scale of a decimal read from a Hive ORC file may not equal the expected scale. For the data type
// decimal(10,3) and the value 10.100, the Hive ORC writer removes the trailing zeros and stores it
// as 101*10^(-1), adjusting its scale from 3 to 1. So here we cannot assert that value.scale() == scale;
// we also need to convert the Hive ORC decimal to a decimal with the expected precision and scale.
Preconditions.checkArgument(value.precision() <= precision,
"Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value);

return new Decimal().set(value.serialize64(scale), precision, scale);
}
}
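The new comment pins down the root cause on the read side: Hive's ORC writer may strip trailing zeros, so the stored scale can be smaller than the declared scale, and the reader must serialize at the expected scale rather than the stored one. A minimal sketch of that effect (plain java.math, no Hive classes):

import java.math.BigDecimal;
import java.math.RoundingMode;

public class TrailingZeroDemo {
  public static void main(String[] args) {
    BigDecimal written = new BigDecimal("10.100");      // declared as decimal(10, 3)
    BigDecimal stored = written.stripTrailingZeros();   // 10.1, i.e. 101 * 10^(-1)

    System.out.println(stored.unscaledValue() + " * 10^-" + stored.scale());  // 101 * 10^-1

    // Rescaling back to the declared scale is exact, so no rounding occurs.
    BigDecimal read = stored.setScale(3, RoundingMode.UNNECESSARY);
    System.out.println(read.equals(written));  // true
  }
}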

@@ -212,6 +222,10 @@ private static class Decimal38Reader implements OrcValueReader<Decimal> {
public Decimal nonNullRead(ColumnVector vector, int row) {
BigDecimal value = ((DecimalColumnVector) vector).vector[row]
.getHiveDecimal().bigDecimalValue();

Preconditions.checkArgument(value.precision() <= precision,
"Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value);

return new Decimal().set(new scala.math.BigDecimal(value), precision, scale);
}
}
@@ -183,9 +183,9 @@ private static class TimestampTzWriter implements SparkOrcValueWriter {
@Override
public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
TimestampColumnVector cv = (TimestampColumnVector) output;
long micros = data.getLong(column);
cv.time[rowId] = micros / 1_000; // millis
cv.nanos[rowId] = (int) (micros % 1_000_000) * 1_000; // nanos
long micros = data.getLong(column); // it could be negative.
cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis
cv.nanos[rowId] = (int) (Math.floorMod(micros, 1_000_000)) * 1_000; // nanos
}
}

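The added comment flags the corner case the old arithmetic missed: for timestamps before the Unix epoch the microsecond value is negative, and Java's truncating division and remainder round toward zero, leaving a negative nanos component. A minimal sketch comparing the two (plain Java, no ORC classes):

public class NegativeMicrosDemo {
  public static void main(String[] args) {
    long micros = -1L;  // one microsecond before 1970-01-01T00:00:00Z

    long oldMillis = micros / 1_000;                                // 0
    int oldNanos = (int) (micros % 1_000_000) * 1_000;              // -1000

    long newMillis = Math.floorDiv(micros, 1_000);                  // -1
    int newNanos = (int) Math.floorMod(micros, 1_000_000) * 1_000;  // 999999000

    // -1 ms plus 999999000 ns is exactly -1 us, whereas the old pair relied on
    // a negative nanos value that the ORC timestamp vector does not expect.
    System.out.println(oldMillis + " ms, " + oldNanos + " ns");
    System.out.println(newMillis + " ms, " + newNanos + " ns");
  }
}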
@@ -237,7 +237,7 @@ private static void assertEqualsUnsafe(Type type, Object expected, Object actual
break;
case DATE:
Assert.assertTrue("Should expect a LocalDate", expected instanceof LocalDate);
long expectedDays = ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected);
int expectedDays = (int) ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected);
Assert.assertEquals("Primitive value should be equal to expected", expectedDays, actual);
break;
case TIMESTAMP:
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iceberg.spark.data;

import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.Iterator;
import java.util.List;
import org.apache.iceberg.Files;
import org.apache.iceberg.Schema;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.data.RandomGenericData;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.data.orc.GenericOrcReader;
import org.apache.iceberg.data.orc.GenericOrcWriter;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.junit.Assert;
import org.junit.Test;

import static org.apache.iceberg.types.Types.NestedField.required;

public class TestSparkRecordOrcReaderWriter extends AvroDataTest {
private static final int NUM_RECORDS = 200;

private void writeAndValidate(Schema schema, List<Record> expectedRecords) throws IOException {
final File originalFile = temp.newFile();
Assert.assertTrue("Delete should succeed", originalFile.delete());

// Write a few generic records into the original test file.
try (FileAppender<Record> writer = ORC.write(Files.localOutput(originalFile))
.createWriterFunc(GenericOrcWriter::buildWriter)
.schema(schema)
.build()) {
writer.addAll(expectedRecords);
}

// Read into spark InternalRow from the original test file.
List<InternalRow> internalRows = Lists.newArrayList();
try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(originalFile))
.project(schema)
.createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema))
.build()) {
reader.forEach(internalRows::add);
assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size());
}

final File anotherFile = temp.newFile();
Assert.assertTrue("Delete should succeed", anotherFile.delete());

// Write those spark InternalRows into a new file again.
try (FileAppender<InternalRow> writer = ORC.write(Files.localOutput(anotherFile))
.createWriterFunc(SparkOrcWriter::new)
.schema(schema)
.build()) {
writer.addAll(internalRows);
}

// Check whether the InternalRows match the expected records.
try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(anotherFile))
.project(schema)
.createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema))
.build()) {
assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size());
}

// Read into iceberg GenericRecord and check again.
try (CloseableIterable<Record> reader = ORC.read(Files.localInput(anotherFile))
.createReaderFunc(typeDesc -> GenericOrcReader.buildReader(schema, typeDesc))
.project(schema)
.build()) {
assertRecordEquals(expectedRecords, reader, expectedRecords.size());
}
}

@Override
protected void writeAndValidate(Schema schema) throws IOException {
List<Record> expectedRecords = RandomGenericData.generate(schema, NUM_RECORDS, 1992L);
writeAndValidate(schema, expectedRecords);
}

@Test
public void testDecimalWithTrailingZero() throws IOException {
Schema schema = new Schema(
required(1, "d1", Types.DecimalType.of(10, 2)),
required(2, "d2", Types.DecimalType.of(20, 5)),
required(3, "d3", Types.DecimalType.of(38, 20))
);

List<Record> expected = Lists.newArrayList();

GenericRecord record = GenericRecord.create(schema);
record.set(0, new BigDecimal("101.00"));
record.set(1, new BigDecimal("10.00E-3"));
record.set(2, new BigDecimal("1001.0000E-16"));

expected.add(record.copy());

writeAndValidate(schema, expected);
}

private static void assertRecordEquals(Iterable<Record> expected, Iterable<Record> actual, int size) {
Iterator<Record> expectedIter = expected.iterator();
Iterator<Record> actualIter = actual.iterator();
for (int i = 0; i < size; i += 1) {
Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext());
Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext());
Assert.assertEquals("Should have same rows.", expectedIter.next(), actualIter.next());
}
Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext());
Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext());
}

private static void assertEqualsUnsafe(Types.StructType struct, Iterable<Record> expected,
Iterable<InternalRow> actual, int size) {
Iterator<Record> expectedIter = expected.iterator();
Iterator<InternalRow> actualIter = actual.iterator();
for (int i = 0; i < size; i += 1) {
Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext());
Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext());
GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next());
}
Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext());
Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext());
}
}
