ORC: Collect NaN counts in ORC writers (apache#1790)

yyanyy authored Feb 3, 2021
1 parent 97703fb commit 8e026f1
Showing 17 changed files with 291 additions and 40 deletions.
4 changes: 1 addition & 3 deletions core/src/test/java/org/apache/iceberg/TestMetrics.java
@@ -628,9 +628,7 @@ protected void assertCounts(int fieldId, Long valueCount, Long nullValueCount, Long nanValueCount, Metrics metrics) {
     Map<Integer, Long> nanValueCounts = metrics.nanValueCounts();
     Assert.assertEquals(valueCount, valueCounts.get(fieldId));
     Assert.assertEquals(nullValueCount, nullValueCounts.get(fieldId));
-    if (fileFormat() != FileFormat.ORC) {
-      Assert.assertEquals(nanValueCount, nanValueCounts.get(fieldId));
-    }
+    Assert.assertEquals(nanValueCount, nanValueCounts.get(fieldId));
   }

   protected <T> void assertBounds(int fieldId, Type type, T lowerBound, T upperBound, Metrics metrics) {
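
Aside (not part of the diff): NaN needs a dedicated counter because it fails every ordinary comparison and therefore can never land in lower/upper bound metrics. This is standard JDK floating-point behavior:

    class NanSemantics {
      public static void main(String[] args) {
        double nan = Double.NaN;
        System.out.println(nan == nan);        // false: NaN is not equal to itself
        System.out.println(nan < 1.0);         // false: all ordered comparisons fail
        System.out.println(Double.isNaN(nan)); // true: the check the writers below use
      }
    }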

data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java
@@ -20,8 +20,11 @@
 package org.apache.iceberg.data.orc;

 import java.util.List;
+import java.util.stream.Stream;
+import org.apache.iceberg.FieldMetrics;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.data.Record;
+import org.apache.iceberg.orc.ORCSchemaUtil;
 import org.apache.iceberg.orc.OrcRowWriter;
 import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor;
 import org.apache.iceberg.orc.OrcValueWriter;
@@ -79,9 +82,9 @@ public OrcValueWriter<?> primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) {
       case LONG:
         return GenericOrcWriters.longs();
       case FLOAT:
-        return GenericOrcWriters.floats();
+        return GenericOrcWriters.floats(ORCSchemaUtil.fieldId(primitive));
       case DOUBLE:
-        return GenericOrcWriters.doubles();
+        return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive));
       case DATE:
         return GenericOrcWriters.dates();
       case TIME:
@@ -125,6 +128,11 @@ public void write(Record value, VectorizedRowBatch output) {
     }
   }

+  @Override
+  public Stream<FieldMetrics> metrics() {
+    return writer.metrics();
+  }
+
   private static class RecordWriter implements OrcValueWriter<Record> {
     private final List<OrcValueWriter<?>> writers;

@@ -150,5 +158,10 @@ public void nonNullWrite(int rowId, Record data, ColumnVector output) {
         child.write(rowId, data.get(c, child.getJavaClass()), cv.fields[c]);
       }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return writers.stream().flatMap(OrcValueWriter::metrics);
+    }
   }
 }
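
Aside (not part of the diff): a minimal sketch of how a caller might fold the per-field metrics stream into a NaN-count map. It assumes FieldMetrics exposes id() and nanValueCount() accessors, which the FloatFieldMetrics(id, nanCount) constructor below suggests but this excerpt does not show:

    import java.util.Map;
    import java.util.stream.Collectors;
    import org.apache.iceberg.FieldMetrics;
    import org.apache.iceberg.orc.OrcRowWriter;

    class NanCountSketch {
      // Collect each leaf writer's FieldMetrics into fieldId -> NaN count.
      static Map<Integer, Long> nanValueCounts(OrcRowWriter<?> writer) {
        return writer.metrics()
            .collect(Collectors.toMap(FieldMetrics::id, FieldMetrics::nanValueCount));
      }
    }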

data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java
@@ -32,6 +32,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.iceberg.FieldMetrics;
+import org.apache.iceberg.FloatFieldMetrics;
 import org.apache.iceberg.orc.OrcValueWriter;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
@@ -77,12 +80,12 @@ public static OrcValueWriter<Long> longs() {
     return LongWriter.INSTANCE;
   }

-  public static OrcValueWriter<Float> floats() {
-    return FloatWriter.INSTANCE;
+  public static OrcValueWriter<Float> floats(int id) {
+    return new FloatWriter(id);
   }

-  public static OrcValueWriter<Double> doubles() {
-    return DoubleWriter.INSTANCE;
+  public static OrcValueWriter<Double> doubles(int id) {
+    return new DoubleWriter(id);
   }

   public static OrcValueWriter<String> strings() {
@@ -216,7 +219,13 @@ public void nonNullWrite(int rowId, Long data, ColumnVector output) {
   }

   private static class FloatWriter implements OrcValueWriter<Float> {
-    private static final OrcValueWriter<Float> INSTANCE = new FloatWriter();
+    private final int id;
+    private long nanCount;
+
+    private FloatWriter(int id) {
+      this.id = id;
+      this.nanCount = 0;
+    }

     @Override
     public Class<Float> getJavaClass() {
@@ -226,11 +235,25 @@ public Class<Float> getJavaClass() {
     @Override
     public void nonNullWrite(int rowId, Float data, ColumnVector output) {
       ((DoubleColumnVector) output).vector[rowId] = data;
+      if (Float.isNaN(data)) {
+        nanCount++;
+      }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return Stream.of(new FloatFieldMetrics(id, nanCount));
+    }
   }

   private static class DoubleWriter implements OrcValueWriter<Double> {
-    private static final OrcValueWriter<Double> INSTANCE = new DoubleWriter();
+    private final int id;
+    private long nanCount;
+
+    private DoubleWriter(Integer id) {
+      this.id = id;
+      this.nanCount = 0;
+    }

     @Override
     public Class<Double> getJavaClass() {
@@ -240,6 +263,14 @@ public Class<Double> getJavaClass() {
     @Override
     public void nonNullWrite(int rowId, Double data, ColumnVector output) {
       ((DoubleColumnVector) output).vector[rowId] = data;
+      if (Double.isNaN(data)) {
+        nanCount++;
+      }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return Stream.of(new FloatFieldMetrics(id, nanCount));
+    }
   }

@@ -436,6 +467,11 @@ public void nonNullWrite(int rowId, List<T> value, ColumnVector output) {
         element.write((int) (e + cv.offsets[rowId]), value.get(e), cv.child);
       }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return element.metrics();
+    }
   }

   private static class MapWriter<K, V> implements OrcValueWriter<Map<K, V>> {
@@ -475,5 +511,10 @@ public void nonNullWrite(int rowId, Map<K, V> map, ColumnVector output) {
         valueWriter.write(pos, values.get(e), cv.values);
       }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return Stream.concat(keyWriter.metrics(), valueWriter.metrics());
+    }
   }
 }
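
Aside (not part of the diff): a sketch of how these per-writer streams compose through nesting. A map<float, double> writer concatenates its key and value streams, so a single metrics() call on the outermost writer surfaces every numeric leaf. The field ids 1 and 2 are arbitrary placeholders:

    import java.util.stream.Stream;
    import org.apache.iceberg.FieldMetrics;
    import org.apache.iceberg.data.orc.GenericOrcWriters;
    import org.apache.iceberg.orc.OrcValueWriter;

    class ComposedMetricsSketch {
      static Stream<FieldMetrics> mapLeafMetrics() {
        OrcValueWriter<Float> keys = GenericOrcWriters.floats(1);
        OrcValueWriter<Double> values = GenericOrcWriters.doubles(2);
        // Same composition MapWriter.metrics() uses above.
        return Stream.concat(keys.metrics(), values.metrics());
      }
    }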

TestMergingMetrics.java
@@ -90,7 +90,7 @@ public abstract class TestMergingMetrics<T> {

   @Parameterized.Parameters(name = "fileFormat = {0}")
   public static Object[] parameters() {
-    return new Object[] {FileFormat.PARQUET };
+    return new Object[] { FileFormat.PARQUET, FileFormat.ORC };
   }

   public TestMergingMetrics(FileFormat fileFormat) {

(Test file moved to package org.apache.iceberg.parquet)
@@ -17,10 +17,12 @@
  * under the License.
  */

-package org.apache.iceberg;
+package org.apache.iceberg.parquet;

 import java.io.IOException;
 import java.util.List;
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.TestMergingMetrics;
 import org.apache.iceberg.data.GenericAppenderFactory;
 import org.apache.iceberg.data.Record;
 import org.apache.iceberg.io.FileAppender;

flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriter.java
@@ -19,14 +19,18 @@

 package org.apache.iceberg.flink.data;

+import java.util.Deque;
 import java.util.List;
+import java.util.stream.Stream;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.iceberg.FieldMetrics;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.data.orc.GenericOrcWriters;
 import org.apache.iceberg.orc.OrcRowWriter;
 import org.apache.iceberg.orc.OrcValueWriter;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.Types;
@@ -63,10 +67,27 @@ public void write(RowData row, VectorizedRowBatch output) {
     }
   }

+  @Override
+  public Stream<FieldMetrics> metrics() {
+    return writer.metrics();
+  }
+
   private static class WriteBuilder extends FlinkSchemaVisitor<OrcValueWriter<?>> {
+    private final Deque<Integer> fieldIds = Lists.newLinkedList();
+
     private WriteBuilder() {
     }

+    @Override
+    public void beforeField(Types.NestedField field) {
+      fieldIds.push(field.fieldId());
+    }
+
+    @Override
+    public void afterField(Types.NestedField field) {
+      fieldIds.pop();
+    }
+
     @Override
     public OrcValueWriter<RowData> record(Types.StructType iStruct,
                                           List<OrcValueWriter<?>> results,
@@ -101,9 +122,15 @@ public OrcValueWriter<?> primitive(Type.PrimitiveType iPrimitive, LogicalType flinkPrimitive) {
       case LONG:
         return GenericOrcWriters.longs();
       case FLOAT:
-        return GenericOrcWriters.floats();
+        Preconditions.checkArgument(fieldIds.peek() != null,
+            String.format("[BUG] Cannot find field id for primitive field with type %s. This is likely because id " +
+                "information is not properly pushed during schema visiting.", iPrimitive));
+        return GenericOrcWriters.floats(fieldIds.peek());
       case DOUBLE:
-        return GenericOrcWriters.doubles();
+        Preconditions.checkArgument(fieldIds.peek() != null,
+            String.format("[BUG] Cannot find field id for primitive field with type %s. This is likely because id " +
+                "information is not properly pushed during schema visiting.", iPrimitive));
+        return GenericOrcWriters.doubles(fieldIds.peek());
       case DATE:
         return FlinkOrcWriters.dates();
       case TIME:
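
Aside (not part of the diff): the Deque acts as a stack of field ids that mirrors the visitor's descent, so peek() inside primitive() always sees the id of the field currently being visited. A minimal sketch of that stack discipline, with hypothetical field ids 1 and 3:

    import java.util.ArrayDeque;
    import java.util.Deque;

    class FieldIdStackSketch {
      public static void main(String[] args) {
        Deque<Integer> fieldIds = new ArrayDeque<>();
        fieldIds.push(1);                    // beforeField: struct field, id 1
        fieldIds.push(3);                    //   beforeField: nested float, id 3
        System.out.println(fieldIds.peek()); //   3 -- what primitive() would use
        fieldIds.pop();                      //   afterField
        fieldIds.pop();                      // afterField
      }
    }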

flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriters.java
@@ -23,13 +23,15 @@
 import java.time.OffsetDateTime;
 import java.time.ZoneOffset;
 import java.util.List;
+import java.util.stream.Stream;
 import org.apache.flink.table.data.ArrayData;
 import org.apache.flink.table.data.DecimalData;
 import org.apache.flink.table.data.MapData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.StringData;
 import org.apache.flink.table.data.TimestampData;
 import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.iceberg.FieldMetrics;
 import org.apache.iceberg.orc.OrcValueWriter;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
@@ -254,6 +256,12 @@ public void nonNullWrite(int rowId, ArrayData data, ColumnVector output) {
         elementWriter.write((int) (e + cv.offsets[rowId]), (T) value, cv.child);
       }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return elementWriter.metrics();
+    }
+
   }

   static class MapWriter<K, V> implements OrcValueWriter<MapData> {
@@ -296,6 +304,11 @@ public void nonNullWrite(int rowId, MapData data, ColumnVector output) {
         valueWriter.write(pos, (V) valueGetter.getElementOrNull(valArray, e), cv.values);
       }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return Stream.concat(keyWriter.metrics(), valueWriter.metrics());
+    }
   }

   static class StructWriter implements OrcValueWriter<RowData> {
@@ -329,5 +342,10 @@ public void nonNullWrite(int rowId, RowData data, ColumnVector output) {
         writer.write(rowId, fieldGetters.get(c).getFieldOrNull(data), cv.fields[c]);
       }
     }
+
+    @Override
+    public Stream<FieldMetrics> metrics() {
+      return writers.stream().flatMap(OrcValueWriter::metrics);
+    }
   }
 }
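
Aside (not part of the diff): in the struct case, flatMap flattens each child writer's stream, and writers for non-floating-point types presumably contribute Stream.empty() — that default on OrcValueWriter.metrics() is an assumption here, since the interface change is in one of the files this excerpt omits. A sketch with hypothetical field ids:

    import java.util.List;
    import org.apache.iceberg.data.orc.GenericOrcWriters;
    import org.apache.iceberg.orc.OrcValueWriter;
    import org.apache.iceberg.relocated.com.google.common.collect.Lists;

    class StructMetricsSketch {
      static long reportedLeafCount() {
        List<OrcValueWriter<?>> writers = Lists.newArrayList(
            GenericOrcWriters.floats(1),   // reports one FloatFieldMetrics
            GenericOrcWriters.strings());  // assumed to report nothing
        // Same flattening StructWriter.metrics() uses above; yields 1.
        return writers.stream().flatMap(OrcValueWriter::metrics).count();
      }
    }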

flink/src/main/java/org/apache/iceberg/flink/data/FlinkSchemaVisitor.java
@@ -44,17 +44,39 @@ private static <T> T visit(LogicalType flinkType, Type iType, FlinkSchemaVisitor<T> visitor) {
       case MAP:
         MapType mapType = (MapType) flinkType;
         Types.MapType iMapType = iType.asMapType();
-
-        T key = visit(mapType.getKeyType(), iMapType.keyType(), visitor);
-        T value = visit(mapType.getValueType(), iMapType.valueType(), visitor);
+        T key;
+        T value;
+
+        Types.NestedField keyField = iMapType.field(iMapType.keyId());
+        visitor.beforeMapKey(keyField);
+        try {
+          key = visit(mapType.getKeyType(), iMapType.keyType(), visitor);
+        } finally {
+          visitor.afterMapKey(keyField);
+        }
+
+        Types.NestedField valueField = iMapType.field(iMapType.valueId());
+        visitor.beforeMapValue(valueField);
+        try {
+          value = visit(mapType.getValueType(), iMapType.valueType(), visitor);
+        } finally {
+          visitor.afterMapValue(valueField);
+        }

         return visitor.map(iMapType, key, value, mapType.getKeyType(), mapType.getValueType());

       case LIST:
         ArrayType listType = (ArrayType) flinkType;
         Types.ListType iListType = iType.asListType();
+        T element;

-        T element = visit(listType.getElementType(), iListType.elementType(), visitor);
+        Types.NestedField elementField = iListType.field(iListType.elementId());
+        visitor.beforeListElement(elementField);
+        try {
+          element = visit(listType.getElementType(), iListType.elementType(), visitor);
+        } finally {
+          visitor.afterListElement(elementField);
+        }

         return visitor.list(iListType, element, listType.getElementType());

@@ -82,7 +104,13 @@ private static <T> T visitRecord(LogicalType flinkType, Types.StructType struct,
       LogicalType fieldFlinkType = rowType.getTypeAt(fieldIndex);

       fieldTypes.add(fieldFlinkType);
-      results.add(visit(fieldFlinkType, iField.type(), visitor));
+
+      visitor.beforeField(iField);
+      try {
+        results.add(visit(fieldFlinkType, iField.type(), visitor));
+      } finally {
+        visitor.afterField(iField);
+      }
     }

     return visitor.record(struct, results, fieldTypes);
@@ -103,4 +131,34 @@ public T map(Types.MapType iMap, T key, T value, LogicalType keyType, LogicalType valueType) {
   public T primitive(Type.PrimitiveType iPrimitive, LogicalType flinkPrimitive) {
     return null;
   }
+
+  public void beforeField(Types.NestedField field) {
+  }
+
+  public void afterField(Types.NestedField field) {
+  }
+
+  public void beforeListElement(Types.NestedField elementField) {
+    beforeField(elementField);
+  }
+
+  public void afterListElement(Types.NestedField elementField) {
+    afterField(elementField);
+  }
+
+  public void beforeMapKey(Types.NestedField keyField) {
+    beforeField(keyField);
+  }
+
+  public void afterMapKey(Types.NestedField keyField) {
+    afterField(keyField);
+  }
+
+  public void beforeMapValue(Types.NestedField valueField) {
+    beforeField(valueField);
+  }
+
+  public void afterMapValue(Types.NestedField valueField) {
+    afterField(valueField);
+  }
 }
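
Aside (not part of the diff): because every before/after pair is driven through try/finally in visit() and visitRecord(), a subclass that keeps state in these hooks stays balanced even when a child visit throws. A hypothetical subclass mirroring WriteBuilder's bookkeeping, assuming the visitor's other callbacks have default implementations as map() and primitive() above do:

    import java.util.ArrayDeque;
    import java.util.Deque;
    import org.apache.iceberg.types.Types;

    class PathTrackingVisitor<T> extends FlinkSchemaVisitor<T> {
      private final Deque<Integer> path = new ArrayDeque<>();

      @Override
      public void beforeField(Types.NestedField field) {
        path.push(field.fieldId());
      }

      @Override
      public void afterField(Types.NestedField field) {
        path.pop(); // guaranteed to run: the caller invokes it in a finally block
      }
    }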