diff --git a/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/instance/SinkRecord.java b/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/instance/SinkRecord.java index 61e9d5378aef4..b922b98858134 100644 --- a/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/instance/SinkRecord.java +++ b/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/instance/SinkRecord.java @@ -29,6 +29,7 @@ import org.apache.pulsar.client.impl.schema.KeyValueSchemaImpl; import org.apache.pulsar.functions.api.KVRecord; import org.apache.pulsar.functions.api.Record; +import org.apache.pulsar.functions.source.PulsarRecord; @Slf4j @Data @@ -82,6 +83,30 @@ public void ack() { sourceRecord.ack(); } + /** + * Some sink sometimes wants to control the ack type. + */ + public void cumulativeAck() { + if (sourceRecord instanceof PulsarRecord) { + PulsarRecord pulsarRecord = (PulsarRecord) sourceRecord; + pulsarRecord.cumulativeAck(); + } else { + throw new RuntimeException("SourceRecord class type must be PulsarRecord"); + } + } + + /** + * Some sink sometimes wants to control the ack type. + */ + public void individualAck() { + if (sourceRecord instanceof PulsarRecord) { + PulsarRecord pulsarRecord = (PulsarRecord) sourceRecord; + pulsarRecord.individualAck(); + } else { + throw new RuntimeException("SourceRecord class type must be PulsarRecord"); + } + } + @Override public void fail() { sourceRecord.fail(); diff --git a/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarRecord.java b/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarRecord.java index e53e7fe9ace80..cd8405e407d8d 100644 --- a/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarRecord.java +++ b/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarRecord.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -44,6 +45,7 @@ public class PulsarRecord implements RecordWithEncryptionContext { private final Runnable failFunction; private final Runnable ackFunction; + private final Consumer customAckFunction; @Override public Optional getKey() { @@ -93,6 +95,20 @@ public Optional getEventTime() { } } + /** + * Some sink sometimes wants to control the ack type. + */ + public void cumulativeAck() { + this.customAckFunction.accept(true); + } + + /** + * Some sink sometimes wants to control the ack type. + */ + public void individualAck() { + this.customAckFunction.accept(false); + } + @Override public Optional getEncryptionCtx() { return message.getEncryptionCtx(); @@ -121,4 +137,5 @@ public void fail() { public Optional> getMessage() { return Optional.of(message); } + } diff --git a/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarSource.java b/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarSource.java index 652c682bbedb0..1fb76459e6024 100644 --- a/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarSource.java +++ b/pulsar-functions/instance/src/main/java/org/apache/pulsar/functions/source/PulsarSource.java @@ -132,18 +132,21 @@ protected Record buildRecord(Consumer consumer, Message message) { .message(message) .schema(schema) .topicName(message.getTopicName()) + .customAckFunction(cumulative -> { + if (cumulative) { + consumer.acknowledgeCumulativeAsync(message) + .whenComplete((unused, throwable) -> message.release()); + } else { + consumer.acknowledgeAsync(message).whenComplete((unused, throwable) -> message.release()); + } + }) .ackFunction(() -> { - try { - if (pulsarSourceConfig - .getProcessingGuarantees() == FunctionConfig.ProcessingGuarantees.EFFECTIVELY_ONCE) { - consumer.acknowledgeCumulativeAsync(message); - } else { - consumer.acknowledgeAsync(message); - } - } finally { - // don't need to check if message pooling is set - // client will automatically check - message.release(); + if (pulsarSourceConfig + .getProcessingGuarantees() == FunctionConfig.ProcessingGuarantees.EFFECTIVELY_ONCE) { + consumer.acknowledgeCumulativeAsync(message) + .whenComplete((unused, throwable) -> message.release()); + } else { + consumer.acknowledgeAsync(message).whenComplete((unused, throwable) -> message.release()); } }).failFunction(() -> { try { diff --git a/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/instance/SinkRecordTest.java b/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/instance/SinkRecordTest.java new file mode 100644 index 0000000000000..56624581824c2 --- /dev/null +++ b/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/instance/SinkRecordTest.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.functions.instance; + +import org.apache.pulsar.functions.api.Record; +import org.apache.pulsar.functions.source.PulsarRecord; +import org.junit.Assert; +import org.mockito.Mockito; +import org.testng.annotations.Test; + +public class SinkRecordTest { + + @Test + public void testCustomAck() { + + PulsarRecord pulsarRecord = Mockito.mock(PulsarRecord.class); + SinkRecord sinkRecord = new SinkRecord<>(pulsarRecord, new Object()); + + sinkRecord.cumulativeAck(); + Mockito.verify(pulsarRecord, Mockito.times(1)).cumulativeAck(); + + sinkRecord = new SinkRecord(Mockito.mock(Record.class), new Object()); + try { + sinkRecord.individualAck(); + Assert.fail("Should throw runtime exception"); + } catch (Exception e) { + Assert.assertTrue(e instanceof RuntimeException); + Assert.assertEquals(e.getMessage(), "SourceRecord class type must be PulsarRecord"); + } + } +} \ No newline at end of file diff --git a/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/source/PulsarSourceTest.java b/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/source/PulsarSourceTest.java index 80c4001d36f0e..d7c0a8d818a11 100644 --- a/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/source/PulsarSourceTest.java +++ b/pulsar-functions/instance/src/test/java/org/apache/pulsar/functions/source/PulsarSourceTest.java @@ -38,6 +38,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.pulsar.client.api.Consumer; import org.apache.pulsar.client.api.ConsumerBuilder; +import org.apache.pulsar.client.api.Message; import org.apache.pulsar.client.api.PulsarClientException; import org.apache.pulsar.client.api.Schema; import org.apache.pulsar.client.api.SubscriptionInitialPosition; @@ -360,4 +361,24 @@ public void testInputConsumersGetter(PulsarSourceConfig pulsarSourceConfig) thro fail("Unknown config type"); } + + + @Test(dataProvider = "sourceImpls") + public void testPulsarRecordCustomAck(PulsarSourceConfig pulsarSourceConfig) throws Exception { + + PulsarSource pulsarSource = getPulsarSource(pulsarSourceConfig); + Message message = Mockito.mock(Message.class); + Consumer consumer = Mockito.mock(Consumer.class); + Mockito.when(consumer.acknowledgeAsync(message)).thenReturn(CompletableFuture.completedFuture(null)); + Mockito.when(consumer.acknowledgeCumulativeAsync(message)).thenReturn(CompletableFuture.completedFuture(null)); + + PulsarRecord record = (PulsarRecord) pulsarSource.buildRecord(consumer, message); + + record.cumulativeAck(); + Mockito.verify(consumer, Mockito.times(1)).acknowledgeCumulativeAsync(message); + + record.individualAck(); + Mockito.verify(consumer, Mockito.times(1)).acknowledgeAsync(message); + } + }