Commit 241185a
[FLINK-20321][formats] Fix NPE when using Avro/Json/Csv formats to deserialize null input (apache#14539)

This closes apache#14539

Co-authored-by: Alex Wang <[email protected]>
xuewang and Alex Wang authored Jan 8, 2021
1 parent af36844 commit 241185a
Showing 12 changed files with 130 additions and 34 deletions.
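
In short: every affected format now treats null input as "no record" instead of throwing a NullPointerException. The single-record deserialize(byte[]) implementations return null, and the collector-based Canal JSON variant emits nothing. A minimal caller-side sketch of the new contract; the constructor call mirrors testDeserializationNullRow in the JSON test below, and the import locations assume the Flink tree this commit targets:

    import org.apache.flink.formats.json.JsonRowDataDeserializationSchema;
    import org.apache.flink.formats.json.TimestampFormat;
    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
    import org.apache.flink.table.types.DataType;
    import org.apache.flink.table.types.logical.RowType;

    import static org.apache.flink.table.api.DataTypes.FIELD;
    import static org.apache.flink.table.api.DataTypes.ROW;
    import static org.apache.flink.table.api.DataTypes.STRING;

    public class NullInputContractSketch {
        public static void main(String[] args) throws Exception {
            DataType dataType = ROW(FIELD("name", STRING()));
            RowType schema = (RowType) dataType.getLogicalType();

            JsonRowDataDeserializationSchema deserializationSchema =
                    new JsonRowDataDeserializationSchema(
                            schema, InternalTypeInfo.of(schema), true, false, TimestampFormat.ISO_8601);

            // Before this commit: NullPointerException. After: null, meaning "no record".
            RowData row = deserializationSchema.deserialize(null);
            System.out.println(row == null); // prints true
        }
    }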
@@ -54,6 +54,7 @@
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.core.Is.is;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
 
 /**
@@ -149,6 +150,8 @@ private void testRowDataWriteReadWithSchema(Schema schema) throws Exception {
         serializer.open(null);
         deserializer.open(null);
 
+        assertNull(deserializer.deserialize(null));
+
         RowData oriData = address2RowData(address);
         byte[] serialized = serializer.serialize(oriData);
         RowData rowData = deserializer.deserialize(serialized);
@@ -127,7 +127,10 @@ Decoder getDecoder() {
     }
 
     @Override
-    public T deserialize(byte[] message) throws IOException {
+    public T deserialize(@Nullable byte[] message) throws IOException {
+        if (message == null) {
+            return null;
+        }
         // read record
         checkAvroInitialized();
         inputStream.setBuffer(message);
@@ -26,6 +26,8 @@
 
 import org.apache.avro.generic.GenericRecord;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.Objects;
 
@@ -93,7 +95,10 @@ public void open(InitializationContext context) throws Exception {
     }
 
     @Override
-    public RowData deserialize(byte[] message) throws IOException {
+    public RowData deserialize(@Nullable byte[] message) throws IOException {
+        if (message == null) {
+            return null;
+        }
         try {
             GenericRecord deserialize = nestedSchema.deserialize(message);
             return (RowData) runtimeConverter.convert(deserialize);
@@ -64,7 +64,10 @@ public RegistryAvroDeserializationSchema(
     }
 
     @Override
-    public T deserialize(byte[] message) throws IOException {
+    public T deserialize(@Nullable byte[] message) throws IOException {
+        if (message == null) {
+            return null;
+        }
         checkAvroInitialized();
         getInputStream().setBuffer(message);
         Schema writerSchema = schemaCoder.readSchema(getInputStream());
@@ -31,12 +31,22 @@
 
 import static org.apache.flink.formats.avro.utils.AvroTestUtils.writeRecord;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 
 /** Tests for {@link AvroDeserializationSchema}. */
 public class AvroDeserializationSchemaTest {
 
     private static final Address address = TestDataGenerator.generateRandomAddress(new Random());
 
+    @Test
+    public void testNullRecord() throws Exception {
+        DeserializationSchema<Address> deserializer =
+                AvroDeserializationSchema.forSpecific(Address.class);
+
+        Address deserializedAddress = deserializer.deserialize(null);
+        assertNull(deserializedAddress);
+    }
+
     @Test
     public void testGenericRecord() throws Exception {
         DeserializationSchema<GenericRecord> deserializationSchema =
@@ -68,10 +68,20 @@
 import static org.apache.flink.table.api.DataTypes.TIMESTAMP;
 import static org.apache.flink.table.api.DataTypes.TINYINT;
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertNull;
 
 /** Test for the Avro serialization and deserialization schema. */
 public class AvroRowDataDeSerializationSchemaTest {
 
+    @Test
+    public void testDeserializeNullRow() throws Exception {
+        final DataType dataType = ROW(FIELD("bool", BOOLEAN())).nullable();
+        AvroRowDataDeserializationSchema deserializationSchema =
+                createDeserializationSchema(dataType);
+
+        assertNull(deserializationSchema.deserialize(null));
+    }
+
     @Test
     public void testSerializeDeserialize() throws Exception {
         final DataType dataType =
@@ -97,7 +107,6 @@ public void testSerializeDeserialize() throws Exception {
                                 FIELD("nullEntryMap", MAP(STRING(), STRING())))
                         .notNull();
         final RowType rowType = (RowType) dataType.getLogicalType();
-        final TypeInformation<RowData> typeInfo = InternalTypeInfo.of(rowType);
 
         final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
         final GenericRecord record = new GenericData.Record(schema);
@@ -148,12 +157,9 @@ public void testSerializeDeserialize() throws Exception {
         map2.put("key1", null);
         record.put(18, map2);
 
-        AvroRowDataSerializationSchema serializationSchema =
-                new AvroRowDataSerializationSchema(rowType);
-        serializationSchema.open(null);
+        AvroRowDataSerializationSchema serializationSchema = createSerializationSchema(dataType);
         AvroRowDataDeserializationSchema deserializationSchema =
-                new AvroRowDataDeserializationSchema(rowType, typeInfo);
-        deserializationSchema.open(null);
+                createDeserializationSchema(dataType);
 
         ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
         GenericDatumWriter<IndexedRecord> datumWriter = new GenericDatumWriter<>(schema);
@@ -189,14 +195,9 @@ public void testSpecificType() throws Exception {
                         FIELD("type_date", DATE().notNull()),
                         FIELD("type_time_millis", TIME(3).notNull()))
                 .notNull();
-        final RowType rowType = (RowType) dataType.getLogicalType();
-        final TypeInformation<RowData> typeInfo = InternalTypeInfo.of(rowType);
-        AvroRowDataSerializationSchema serializationSchema =
-                new AvroRowDataSerializationSchema(rowType);
-        serializationSchema.open(null);
+        AvroRowDataSerializationSchema serializationSchema = createSerializationSchema(dataType);
         AvroRowDataDeserializationSchema deserializationSchema =
-                new AvroRowDataDeserializationSchema(rowType, typeInfo);
-        deserializationSchema.open(null);
+                createDeserializationSchema(dataType);
 
         RowData rowData = deserializationSchema.deserialize(input);
         byte[] output = serializationSchema.serialize(rowData);
@@ -214,4 +215,25 @@
                                 .toExternal(rowData.getInt(2))
                                 .toString());
     }
+
+    private AvroRowDataSerializationSchema createSerializationSchema(DataType dataType)
+            throws Exception {
+        final RowType rowType = (RowType) dataType.getLogicalType();
+
+        AvroRowDataSerializationSchema serializationSchema =
+                new AvroRowDataSerializationSchema(rowType);
+        serializationSchema.open(null);
+        return serializationSchema;
+    }
+
+    private AvroRowDataDeserializationSchema createDeserializationSchema(DataType dataType)
+            throws Exception {
+        final RowType rowType = (RowType) dataType.getLogicalType();
+        final TypeInformation<RowData> typeInfo = InternalTypeInfo.of(rowType);
+
+        AvroRowDataDeserializationSchema deserializationSchema =
+                new AvroRowDataDeserializationSchema(rowType, typeInfo);
+        deserializationSchema.open(null);
+        return deserializationSchema;
+    }
 }
@@ -30,6 +30,8 @@
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.dataformat.csv.CsvMapper;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.dataformat.csv.CsvSchema;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Objects;
@@ -140,7 +142,10 @@ public CsvRowDataDeserializationSchema build() {
     }
 
     @Override
-    public RowData deserialize(byte[] message) throws IOException {
+    public RowData deserialize(@Nullable byte[] message) throws IOException {
+        if (message == null) {
+            return null;
+        }
         try {
             final JsonNode root = objectReader.readValue(message);
             return (RowData) runtimeConverter.convert(root);
@@ -206,6 +206,12 @@ public void testDeserializeUnsupportedNull() throws Exception {
                 Row.of("Test", null, "Test"), testDeserialization(true, false, "Test,null,Test"));
     }
 
+    @Test
+    public void testDeserializeNullRow() throws Exception {
+        // return null for null input
+        assertNull(testDeserialization(false, false, null));
+    }
+
     @Test
     public void testDeserializeIncompleteRow() throws Exception {
         // last two columns are missing
@@ -404,7 +410,7 @@ private static RowData deserialize(
                 InstantiationUtil.deserializeObject(
                         InstantiationUtil.serializeObject(deserSchemaBuilder.build()),
                         CsvRowDeSerializationSchemaTest.class.getClassLoader());
-        return schema.deserialize(csv.getBytes());
+        return schema.deserialize(csv != null ? csv.getBytes() : null);
    }
 
     private static RowData rowData(String str1, int integer, String str2) {
@@ -31,6 +31,8 @@
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.Objects;
 
@@ -95,7 +97,10 @@ public JsonRowDataDeserializationSchema(
     }
 
     @Override
-    public RowData deserialize(byte[] message) throws IOException {
+    public RowData deserialize(@Nullable byte[] message) throws IOException {
+        if (message == null) {
+            return null;
+        }
         try {
             final JsonNode root = objectMapper.readTree(message);
             return (RowData) runtimeConverter.convert(root);
@@ -188,7 +188,10 @@ public RowData deserialize(byte[] message) throws IOException {
     }
 
     @Override
-    public void deserialize(byte[] message, Collector<RowData> out) throws IOException {
+    public void deserialize(@Nullable byte[] message, Collector<RowData> out) throws IOException {
+        if (message == null || message.length == 0) {
+            return;
+        }
         try {
             GenericRowData row = (GenericRowData) jsonDeserializer.deserialize(message);
             if (database != null) {
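
Unlike the row-returning variants above, this collector-based method signals "no record" by emitting nothing, for both null and empty messages. A minimal caller-side sketch of that behavior; ListCollector is a hypothetical helper for illustration and is not part of this commit:

    import org.apache.flink.util.Collector;

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical list-backed Collector, for illustration only.
    class ListCollector<T> implements Collector<T> {
        final List<T> collected = new ArrayList<>();

        @Override
        public void collect(T record) {
            collected.add(record);
        }

        @Override
        public void close() {}
    }

    // Usage (assuming a configured CanalJsonDeserializationSchema `schema`):
    //     ListCollector<RowData> out = new ListCollector<>();
    //     schema.deserialize(null, out);          // no-op after this change
    //     schema.deserialize(new byte[0], out);   // no-op after this change
    //     assert out.collected.isEmpty();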
@@ -72,6 +72,7 @@
 import static org.apache.flink.table.api.DataTypes.TINYINT;
 import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.fail;
 
 /**
@@ -377,6 +378,18 @@ public void testSerDeMultiRowsWithNullValues() throws Exception {
         }
     }
 
+    @Test
+    public void testDeserializationNullRow() throws Exception {
+        DataType dataType = ROW(FIELD("name", STRING()));
+        RowType schema = (RowType) dataType.getLogicalType();
+
+        JsonRowDataDeserializationSchema deserializationSchema =
+                new JsonRowDataDeserializationSchema(
+                        schema, InternalTypeInfo.of(schema), true, false, TimestampFormat.ISO_8601);
+
+        assertNull(deserializationSchema.deserialize(null));
+    }
+
     @Test
     public void testDeserializationMissingNode() throws Exception {
         DataType dataType = ROW(FIELD("name", STRING()));
@@ -81,6 +81,18 @@ public void testFilteringTables() throws Exception {
         runTest(lines, deserializationSchema);
     }
 
+    @Test
+    public void testDeserializeNullRow() throws Exception {
+        final List<ReadableMetadata> requestedMetadata = Arrays.asList(ReadableMetadata.values());
+        final CanalJsonDeserializationSchema deserializationSchema =
+                createCanalJsonDeserializationSchema(null, null, requestedMetadata);
+        final SimpleCollector collector = new SimpleCollector();
+
+        deserializationSchema.deserialize(null, collector);
+        deserializationSchema.deserialize(new byte[0], collector);
+        assertEquals(0, collector.list.size());
+    }
+
     @Test
     public void testDeserializationWithMetadata() throws Exception {
         testDeserializationWithMetadata(
@@ -251,26 +263,32 @@ private void testDeserializationWithMetadata(
         // we only read the first line for keeping the test simple
         final String firstLine = readLines(resourceFile).get(0);
         final List<ReadableMetadata> requestedMetadata = Arrays.asList(ReadableMetadata.values());
+        final CanalJsonDeserializationSchema deserializationSchema =
+                createCanalJsonDeserializationSchema(database, table, requestedMetadata);
+        final SimpleCollector collector = new SimpleCollector();
+
+        deserializationSchema.deserialize(firstLine.getBytes(StandardCharsets.UTF_8), collector);
+        assertEquals(9, collector.list.size());
+        testConsumer.accept(collector.list.get(0));
+    }
+
+    private CanalJsonDeserializationSchema createCanalJsonDeserializationSchema(
+            String database, String table, List<ReadableMetadata> requestedMetadata) {
         final DataType producedDataType =
                 DataTypeUtils.appendRowFields(
                         PHYSICAL_DATA_TYPE,
                         requestedMetadata.stream()
                                 .map(m -> DataTypes.FIELD(m.key, m.dataType))
                                 .collect(Collectors.toList()));
-        final CanalJsonDeserializationSchema deserializationSchema =
-                CanalJsonDeserializationSchema.builder(
-                                PHYSICAL_DATA_TYPE,
-                                requestedMetadata,
-                                InternalTypeInfo.of(producedDataType.getLogicalType()))
-                        .setDatabase(database)
-                        .setTable(table)
-                        .setIgnoreParseErrors(false)
-                        .setTimestampFormat(TimestampFormat.ISO_8601)
-                        .build();
-        final SimpleCollector collector = new SimpleCollector();
-        deserializationSchema.deserialize(firstLine.getBytes(StandardCharsets.UTF_8), collector);
-        assertEquals(9, collector.list.size());
-        testConsumer.accept(collector.list.get(0));
+        return CanalJsonDeserializationSchema.builder(
+                        PHYSICAL_DATA_TYPE,
+                        requestedMetadata,
+                        InternalTypeInfo.of(producedDataType.getLogicalType()))
+                .setDatabase(database)
+                .setTable(table)
+                .setIgnoreParseErrors(false)
+                .setTimestampFormat(TimestampFormat.ISO_8601)
+                .build();
     }
 
     // --------------------------------------------------------------------------------------------
