diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index b9aa1f5bc5838..54eb6e761407c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -25,13 +25,13 @@ import scala.collection.JavaConverters._
 import com.google.protobuf.ByteString
 
 import org.apache.spark.annotation.Evolving
+import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Command
 import org.apache.spark.connect.proto.WriteStreamOperationStart
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, ForeachWriter}
-import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.connect.common.ForeachWriterPacket
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, UdfUtils}
 import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
 import org.apache.spark.sql.execution.streaming.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.OneTimeTrigger
@@ -247,6 +247,24 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
     this
   }
 
+  /**
+   * :: Experimental ::
+   *
+   * (Java-specific) Sets the output of the streaming query to be processed using the provided
+   * function. This is supported only in the micro-batch execution modes (that is, when the
+   * trigger is not continuous). In every micro-batch, the provided function will be called
+   * with (i) the output rows as a Dataset and (ii) the batch identifier. The batchId can be
+   * used to deduplicate and transactionally write the output (that is, the provided Dataset)
+   * to external systems. The output Dataset is guaranteed to be exactly the same for the same
+   * batchId (assuming all operations are deterministic in the query).
+   *
+   * @since 3.5.0
+   */
+  @Evolving
+  def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = {
+    foreachBatch(UdfUtils.foreachBatchFuncToScalaFunc(function))
+  }
+
   /**
    * Starts the execution of the streaming query, which will continually output results to the
    * given path as new data arrives. The returned [[StreamingQuery]] object can be used to
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 3fc02d7c397f0..04b162eceec28 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -234,9 +234,6 @@ object CheckConnectJvmClientCompatibility {
       // DataStreamWriter
       ProblemFilters.exclude[MissingClassProblem](
        "org.apache.spark.sql.streaming.DataStreamWriter$"),
-      ProblemFilters.exclude[Problem](
-        "org.apache.spark.sql.streaming.DataStreamWriter.foreachBatch" // TODO(SPARK-42944)
-      ),
       ProblemFilters.exclude[Problem](
        "org.apache.spark.sql.streaming.DataStreamWriter.SOURCE*" // These are constant vals.
      ),
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index ab92431bc116a..944a999a860b6 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper}
 import org.apache.spark.sql.connect.client.util.QueryTest
@@ -412,11 +413,13 @@ class EventCollector extends StreamingQueryListener {
   }
 }
 
-class ForeachBatchFn(val viewName: String) extends ((DataFrame, Long) => Unit) with Serializable {
-  override def apply(df: DataFrame, batchId: Long): Unit = {
+class ForeachBatchFn(val viewName: String)
+    extends VoidFunction2[DataFrame, java.lang.Long]
+    with Serializable {
+  override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
     val count = df.count()
     df.sparkSession
-      .createDataFrame(Seq((batchId, count)))
+      .createDataFrame(Seq((batchId.toLong, count)))
       .createOrReplaceGlobalTempView(viewName)
   }
 }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index ceacc595d15f4..16d5823f4a474 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -59,6 +59,9 @@ private[sql] object UdfUtils extends Serializable {
   def foreachPartitionFuncToScalaFunc[T](f: ForeachPartitionFunction[T]): Iterator[T] => Unit =
     x => f.call(x.asJava)
 
+  def foreachBatchFuncToScalaFunc[D](f: VoidFunction2[D, java.lang.Long]): (D, Long) => Unit =
+    (d, i) => f.call(d, i)
+
   def flatMapFuncToScalaFunc[T, U](f: FlatMapFunction[T, U]): T => TraversableOnce[U] = x =>
     f.call(x).asScala
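
For context, the following is a minimal usage sketch (not part of this patch) of how the new Java-specific foreachBatch overload could be exercised from the Spark Connect Scala client, in the same spirit as the ForeachBatchFn test helper above. The connect endpoint, rate-source settings, and output path are illustrative assumptions only.

// Usage sketch (not part of this patch). Assumes a Spark Connect server at
// sc://localhost; the rate source and output path are placeholders.
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.sql.{DataFrame, SparkSession}

object ForeachBatchUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

    val input = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()

    // The new overload accepts a VoidFunction2[Dataset[T], java.lang.Long];
    // call() receives each micro-batch together with its batch id.
    val query = input.writeStream
      .foreachBatch(new VoidFunction2[DataFrame, java.lang.Long] {
        override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
          // batchId can be used to make the write idempotent across retries.
          df.write
            .mode("overwrite")
            .parquet(s"/tmp/foreach-batch-demo/batch_id=$batchId")
        }
      })
      .start()

    query.awaitTermination()
  }
}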