Skip to content

Commit

Permalink
[fix][function] Use the schema set by the Function when it returns a …
Browse files Browse the repository at this point in the history
…Record (apache#17142)
  • Loading branch information
cbornet authored Aug 19, 2022
1 parent 23c2efd commit ee00400
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ protected AbstractSinkRecord(Record<?> sourceRecord) {

public abstract boolean shouldAlwaysSetMessageProperties();

/**
 * Tells the caller whether the schema reported by this record should be used when writing it out.
 *
 * @return {@code true} to write with the record's own schema; {@code false} to let the caller
 *         substitute another schema (the visible sink processors use the destination topic schema)
 */
public abstract boolean shouldSetSchema();

/**
 * Returns the {@code sourceRecord} held by this sink record.
 *
 * @return the value of the {@code sourceRecord} field
 */
public Record<?> getSourceRecord() {
return sourceRecord;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

@EqualsAndHashCode(callSuper = true)
@ToString
class OutputRecordSinkRecord<T> extends AbstractSinkRecord<T> {
public class OutputRecordSinkRecord<T> extends AbstractSinkRecord<T> {

private final Record<T> sinkRecord;

Expand Down Expand Up @@ -91,4 +91,9 @@ public Optional<Message<T>> getMessage() {
public boolean shouldAlwaysSetMessageProperties() {
return true;
}

/**
 * Always answers {@code true}: this record type unconditionally asks for its own schema to be used.
 *
 * @return {@code true}
 */
@Override
public boolean shouldSetSchema() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.pulsar.client.api.Message;
import org.apache.pulsar.client.api.Schema;
import org.apache.pulsar.functions.api.Record;
import org.apache.pulsar.functions.source.PulsarRecord;

@EqualsAndHashCode(callSuper = true)
@ToString
Expand Down Expand Up @@ -92,4 +93,9 @@ public Optional<Message<T>> getMessage() {
public boolean shouldAlwaysSetMessageProperties() {
return false;
}

/**
 * Decides whether this record's own schema should be used, based on where the source record came from.
 *
 * @return {@code false} when the source record is a {@link PulsarRecord}, {@code true} otherwise
 */
@Override
public boolean shouldSetSchema() {
    if (sourceRecord instanceof PulsarRecord) {
        return false;
    }
    return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ public PulsarSinkAtMostOnceProcessor(Schema schema, Crypto crypto) {
@Override
public TypedMessageBuilder<T> newMessage(AbstractSinkRecord<T> record) {
Schema<T> schemaToWrite = record.getSchema();
if (record.getSourceRecord() instanceof PulsarRecord) {
if (!record.shouldSetSchema()) {
// we are receiving data directly from another Pulsar topic
// and the Function return type is not a Record
// we must use the destination topic schema
schemaToWrite = schema;
}
Expand Down Expand Up @@ -304,8 +305,9 @@ public TypedMessageBuilder<T> newMessage(AbstractSinkRecord<T> record) {
"PartitionId needs to be specified for every record while in Effectively-once mode");
}
Schema<T> schemaToWrite = record.getSchema();
if (record.getSourceRecord() instanceof PulsarRecord) {
if (!record.shouldSetSchema()) {
// we are receiving data directly from another Pulsar topic
// and the Function return type is not a Record
// we must use the destination topic schema
schemaToWrite = schema;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pulsar.tests.integration.functions;

import java.io.ByteArrayOutputStream;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.pulsar.client.api.Schema;
import org.apache.pulsar.client.api.schema.GenericObject;
import org.apache.pulsar.client.api.schema.GenericRecord;
import org.apache.pulsar.client.api.schema.KeyValueSchema;
import org.apache.pulsar.common.schema.KeyValue;
import org.apache.pulsar.common.schema.SchemaType;
import org.apache.pulsar.functions.api.Context;
import org.apache.pulsar.functions.api.Function;
import org.apache.pulsar.functions.api.Record;

/**
 * Test function that removes the {@code "age"} field from an AVRO message and returns the result
 * as a {@link Record} that carries its own output schema.
 *
 * <p>Two input shapes are rewritten:
 * <ul>
 *   <li>a {@link KeyValue} whose value schema is AVRO: the field is removed from the value only;</li>
 *   <li>a plain AVRO record: the field is removed from the record itself.</li>
 * </ul>
 * In both cases the rewritten payload is emitted as AVRO-encoded bytes with a
 * {@code NATIVE_AVRO} schema. Any other input (or an AVRO schema without the field) is passed
 * through with the input schema.
 */
@Slf4j
public class RemoveAvroFieldRecordFunction implements Function<GenericObject, Record<GenericObject>> {

    private static final String FIELD_TO_REMOVE = "age";

    /**
     * Builds the output record for one input message.
     *
     * @param genericObject the incoming value
     * @param context the function context, used to read the current record and to build the output
     * @return a record holding the (possibly rewritten) value and the schema to write it with
     * @throws Exception if the AVRO re-encoding fails
     */
    // Raw Schema/KeyValue types are intentional: the output schema type is only known at runtime,
    // and the builder chain below relies on the unchecked conversion to compile.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override
    public Record<GenericObject> process(GenericObject genericObject, Context context) throws Exception {
        Record<?> currentRecord = context.getCurrentRecord();
        Object nativeObject = genericObject.getNativeObject();
        Schema<?> schema = currentRecord.getSchema();

        log.info("apply to {} {}", genericObject, nativeObject);
        // Record.getMessage() is an Optional: never let a missing message break processing
        // just because of a log statement.
        log.info("record with schema {} version {} {}", schema,
                currentRecord.getMessage().map(message -> message.getSchemaVersion()).orElse(null),
                currentRecord);

        Schema outputSchema = schema;
        Object outputObject = nativeObject;
        boolean fieldRemoved = false;

        if (schema instanceof KeyValueSchema && nativeObject instanceof KeyValue) {
            KeyValueSchema kvSchema = (KeyValueSchema) schema;
            Schema keySchema = kvSchema.getKeySchema();
            Schema valueSchema = kvSchema.getValueSchema();
            // remove the field from the value part only; the key is forwarded untouched
            if (valueSchema.getSchemaInfo().getType() == SchemaType.AVRO) {
                org.apache.avro.Schema avroSchema =
                        (org.apache.avro.Schema) valueSchema.getNativeSchema().get();
                if (avroSchema.getField(FIELD_TO_REMOVE) != null) {
                    org.apache.avro.Schema modified = withoutField(avroSchema, FIELD_TO_REMOVE);
                    KeyValue originalObject = (KeyValue) nativeObject;
                    GenericRecord value = (GenericRecord) originalObject.getValue();
                    byte[] newValue = projectAndEncode(
                            (org.apache.avro.generic.GenericRecord) value.getNativeObject(), modified);

                    outputSchema = Schema.KeyValue(keySchema, Schema.NATIVE_AVRO(modified),
                            kvSchema.getKeyValueEncodingType());
                    outputObject = new KeyValue(originalObject.getKey(), newValue);
                    fieldRemoved = true;
                }
            }
        } else if (schema.getSchemaInfo().getType() == SchemaType.AVRO) {
            org.apache.avro.Schema avroSchema = (org.apache.avro.Schema) schema.getNativeSchema().get();
            if (avroSchema.getField(FIELD_TO_REMOVE) != null) {
                org.apache.avro.Schema modified = withoutField(avroSchema, FIELD_TO_REMOVE);

                outputSchema = Schema.NATIVE_AVRO(modified);
                outputObject = projectAndEncode((org.apache.avro.generic.GenericRecord) nativeObject, modified);
                fieldRemoved = true;
            }
        }

        if (!fieldRemoved) {
            // pass-through: keep the input schema and pick the matching value representation
            SchemaType type = schema.getSchemaInfo().getType();
            boolean isStruct = type == SchemaType.AVRO
                    || type == SchemaType.JSON
                    || type == SchemaType.PROTOBUF_NATIVE;
            if (isStruct) {
                // GenericRecord must stay wrapped
                outputObject = currentRecord.getValue();
            } else {
                // primitives and KeyValue must be unwrapped
                outputObject = nativeObject;
            }
        }
        log.info("output {} schema {}", outputObject, outputSchema);

        return context.newOutputRecordBuilder()
                .schema(outputSchema)
                .value(outputObject)
                .build();
    }

    /**
     * Creates a copy of an AVRO record schema that lacks the given field.
     *
     * @param avroSchema the record schema to copy
     * @param fieldName the name of the field to leave out
     * @return a new record schema with the same name, doc, namespace and error flag, and every
     *         field except {@code fieldName}
     */
    private static org.apache.avro.Schema withoutField(org.apache.avro.Schema avroSchema, String fieldName) {
        // Work on a re-parsed copy so the input schema's own objects are not reused.
        org.apache.avro.Schema copy = new org.apache.avro.Schema.Parser().parse(avroSchema.toString(false));
        return org.apache.avro.Schema.createRecord(
                copy.getName(), copy.getDoc(), copy.getNamespace(), copy.isError(),
                copy.getFields().stream()
                        .filter(f -> !f.name().equals(fieldName))
                        // fresh Field instances are created for the new record schema
                        .map(f -> new org.apache.avro.Schema.Field(f.name(), f.schema(), f.doc(),
                                f.defaultVal(), f.order()))
                        .collect(Collectors.toList()));
    }

    /**
     * Copies the fields named by {@code targetSchema} out of {@code source} and encodes the
     * resulting record as AVRO binary.
     *
     * @param source the record to read field values from
     * @param targetSchema the schema of the record to produce
     * @return the AVRO binary encoding of the projected record
     * @throws java.io.IOException if writing the record fails
     */
    private static byte[] projectAndEncode(org.apache.avro.generic.GenericRecord source,
                                           org.apache.avro.Schema targetSchema) throws java.io.IOException {
        org.apache.avro.generic.GenericRecord projected = new GenericData.Record(targetSchema);
        for (org.apache.avro.Schema.Field field : targetSchema.getFields()) {
            projected.put(field.name(), source.get(field.name()));
        }
        GenericDatumWriter<org.apache.avro.generic.GenericRecord> writer =
                new GenericDatumWriter<>(targetSchema);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(out, null);
        writer.write(projected, encoder);
        // make sure everything has reached the stream before the bytes are read back
        encoder.flush();
        return out.toByteArray();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ public abstract class PulsarFunctionsTestBase extends PulsarTestSuite {
public static final String REMOVE_AVRO_FIELD_FUNCTION_JAVA_CLASS =
"org.apache.pulsar.tests.integration.functions.RemoveAvroFieldFunction";

/** Fully-qualified class name of the {@code RemoveAvroFieldRecordFunction} test function. */
public static final String REMOVE_AVRO_FIELD_RECORD_FUNCTION_JAVA_CLASS =
"org.apache.pulsar.tests.integration.functions.RemoveAvroFieldRecordFunction";

public static final String SERDE_JAVA_CLASS =
"org.apache.pulsar.functions.api.examples.CustomBaseToBaseFunction";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,30 @@ public void testGenericObjectFunction() throws Exception {
}

@Test(groups = {"java_function", "function"})
public void testGenericObjectRemoveFiledFunction() throws Exception {
public void testGenericObjectRemoveFieldFunction() throws Exception {
testGenericObjectFunction(REMOVE_AVRO_FIELD_FUNCTION_JAVA_CLASS, true, false);
}

/**
 * Runs the shared {@code testGenericObjectFunction} scenario against the function class named by
 * {@code REMOVE_AVRO_FIELD_RECORD_FUNCTION_JAVA_CLASS}.
 */
@Test(groups = {"java_function", "function"})
public void testGenericObjectRemoveFieldRecordFunction() throws Exception {
// NOTE(review): the meaning of the two boolean flags is defined by testGenericObjectFunction - confirm there
testGenericObjectFunction(REMOVE_AVRO_FIELD_RECORD_FUNCTION_JAVA_CLASS, true, false);
}

/**
 * Runs the shared {@code testGenericObjectFunction} scenario against the function class named by
 * {@code GENERIC_OBJECT_FUNCTION_JAVA_CLASS}.
 */
@Test(groups = {"java_function", "function"})
public void testGenericObjectFunctionKeyValue() throws Exception {
// NOTE(review): presumably the last flag selects the KeyValue variant of the scenario - confirm in the helper
testGenericObjectFunction(GENERIC_OBJECT_FUNCTION_JAVA_CLASS, false, true);
}

@Test(groups = {"java_function", "function"})
public void testGenericObjectRemoveFiledFunctionKeyValue() throws Exception {
public void testGenericObjectRemoveFieldFunctionKeyValue() throws Exception {
testGenericObjectFunction(REMOVE_AVRO_FIELD_FUNCTION_JAVA_CLASS, true, true);
}

/**
 * Runs the shared {@code testGenericObjectFunction} scenario against the function class named by
 * {@code REMOVE_AVRO_FIELD_RECORD_FUNCTION_JAVA_CLASS}, with both boolean flags set.
 */
@Test(groups = {"java_function", "function"})
public void testGenericObjectRemoveFieldRecordFunctionKeyValue() throws Exception {
// NOTE(review): presumably the last flag selects the KeyValue variant of the scenario - confirm in the helper
testGenericObjectFunction(REMOVE_AVRO_FIELD_RECORD_FUNCTION_JAVA_CLASS, true, true);
}

@Test(groups = {"java_function", "function"})
public void testRecordFunctionTest() throws Exception {
testRecordFunction();
Expand Down

0 comments on commit ee00400

Please sign in to comment.