[SPARK-50661][CONNECT][SS] Fix Spark Connect Scala foreachBatch impl. to support Dataset[T]

### What changes were proposed in this pull request?
This PR fixes the incorrect implementation of Scala streaming foreachBatch in Spark Connect mode when the input dataset is a typed Dataset[T] rather than a DataFrame.

**Note** that this only affects `Scala`.

In `DataStreamWriter`:
- Serialize the foreachBatch function together with the dataset's encoder.
- Reuse `ForeachWriterPacket` for foreachBatch, since both foreach and foreachBatch are sink operations that only need a function/writer object and the encoder of the input. Optionally, `ForeachWriterPacket` could be renamed to something more general that covers both cases.

In `SparkConnectPlanner` / `StreamingForeachBatchHelper`:
- Use the encoder passed from the client to recover the `Dataset[T]` object so that the foreachBatch function is called with the expected type (see the sketch after this list).
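
The round-trip follows this general pattern: serialize the user function together with enough type information to rebuild the typed view on the server. Below is a minimal plain-Scala sketch of that pattern only; `Packet`, `serialize`, and `deserialize` are hypothetical stand-ins for Spark's `ForeachWriterPacket`, `SparkSerDeUtils`, and `AgnosticEncoder` machinery, not the actual classes used in the diff below.

```
import java.io._

// Hypothetical stand-in for ForeachWriterPacket: the function is shipped together
// with a "decoder" describing how to rebuild the typed values on the other side.
case class Packet[T](fn: (Seq[T], Long) => Unit, decode: String => T) extends Serializable

def serialize(obj: AnyRef): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(obj)
  oos.close()
  bos.toByteArray
}

def deserialize[A](bytes: Array[Byte]): A = {
  val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
  try ois.readObject().asInstanceOf[A] finally ois.close()
}

// "Client" side: package the user function with its decoder (the fix packages the
// foreachBatch closure with ds.agnosticEncoder) and serialize both together.
val payload = serialize(
  Packet[Long]((batch, id) => println(s"batch $id sum = ${batch.sum}"), _.toLong))

// "Server" side: recover both pieces and rebuild the typed batch before invoking the
// function, instead of handing it the untyped representation.
val packet = deserialize[Packet[Long]](payload)
val untypedBatch = Seq("1", "2", "3") // stands in for the incoming DataFrame rows
packet.fn(untypedBatch.map(packet.decode), 0L) // prints "batch 0 sum = 6"
```

The actual change carries the dataset's `AgnosticEncoder` instead of a decode function and rebuilds the typed dataset on the server with `Dataset.apply(df.sparkSession, df.logicalPlan)(encoder)`, falling back to the plain `DataFrame` when the encoder is the `UnboundRowEncoder`; see `StreamingForeachBatchHelper.scalaForeachBatchWrapper` in the diff below.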

### Why are the changes needed?
Without the fix, Scala foreachBatch will fail or give wrong results when the input dataset is not a DataFrame.

Below is a simple reproduction:

```
import org.apache.spark.sql._
spark.range(10).write.format("parquet").mode("overwrite").save("/tmp/test")

val q = spark.readStream
  .format("parquet")
  .schema("id LONG")
  .load("/tmp/test")
  .as[java.lang.Long]
  .writeStream
  .foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) =>
    println(ds.collect().map(_.asInstanceOf[Long]).sum))
  .start()

Thread.sleep(1000)
q.stop()
```

The code above should print 45 from the foreachBatch function (the sum of 0 through 9). Without the fix, the query fails because the foreachBatch function is invoked with a DataFrame instead of a Dataset[java.lang.Long].

### Does this PR introduce _any_ user-facing change?
Yes. This PR changes the Spark Connect Scala client around the foreachBatch API: the dataset's encoder is now serialized together with the foreachBatch function.

### How was this patch tested?
1. Ran an end-to-end test with spark-shell against a Spark Connect server, with the client running in connect mode.
2. Added new and updated unit tests that would have failed without the fix.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#49323 from haiyangsun-db/SPARK-50661.

Authored-by: Haiyang Sun <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
haiyangsun-db authored and HyukjinKwon committed Dec 28, 2024
1 parent af53ee4 commit 51b011f
Showing 4 changed files with 108 additions and 22 deletions.
@@ -135,7 +135,10 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt
/** @inheritdoc */
@Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = {
val serializedFn = SparkSerDeUtils.serialize(function)
// SPARK-50661: the client should send the encoder for the input dataset together with the
// function to the server.
val serializedFn =
SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder))
sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
.setPayload(ByteString.copyFrom(serializedFn))
.setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.

@@ -28,9 +28,8 @@ import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkException
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.functions.{col, lit, udf, window}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession}
@@ -567,7 +566,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
}
}

test("foreachBatch") {
test("foreachBatch with DataFrame") {
// Starts a streaming query with a foreachBatch function, which writes batchId and row count
// to a temp view. The test verifies that the view is populated with data.

@@ -581,7 +580,12 @@
.option("numPartitions", "1")
.load()
.writeStream
.foreachBatch(new ForeachBatchFn(viewName))
.foreachBatch((df: DataFrame, batchId: Long) => {
val count = df.collect().map(row => row.getLong(1)).sum
df.sparkSession
.createDataFrame(Seq((batchId, count)))
.createOrReplaceGlobalTempView(viewName)
})
.start()

eventually(timeout(30.seconds)) { // Wait for first progress.
@@ -596,13 +600,83 @@
.collect()
.toSeq
assert(rows.size > 0)
assert(rows.map(_.getLong(1)).sum > 0)
logInfo(s"Rows in $tableName: $rows")
}

q.stop()
}
}

test("foreachBatch with Dataset[java.lang.Long]") {
val viewName = "test_view"
val tableName = s"global_temp.$viewName"

withTable(tableName) {
val session = spark
import session.implicits._
val q = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.select($"value")
.as[java.lang.Long]
.writeStream
.foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) => {
val count = ds.collect().map(v => v.asInstanceOf[Long]).sum
ds.sparkSession
.createDataFrame(Seq((batchId, count)))
.createOrReplaceGlobalTempView(viewName)
})
.start()

eventually(timeout(30.seconds)) { // Wait for first progress.
assert(q.lastProgress != null, "Failed to make progress")
assert(q.lastProgress.numInputRows > 0)
}

eventually(timeout(30.seconds)) {
// There should be row(s) in temporary view created by foreachBatch.
val rows = spark
.sql(s"select * from $tableName")
.collect()
.toSeq
assert(rows.size > 0)
assert(rows.map(_.getLong(1)).sum > 0)
logInfo(s"Rows in $tableName: $rows")
}

q.stop()
}
}

test("foreachBatch with Dataset[TestClass]") {
val session: SparkSession = spark
import session.implicits._
val viewName = "test_view"
val tableName = s"global_temp.$viewName"

val df = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.load()

val q = df
.selectExpr("CAST(value AS INT)")
.as[TestClass]
.writeStream
.foreachBatch((ds: Dataset[TestClass], batchId: Long) => {
val count = ds.collect().map(_.value).sum
})
.start()
eventually(timeout(30.seconds)) {
assert(q.isActive)
assert(q.exception.isEmpty)
}
q.stop()
}

abstract class EventCollector extends StreamingQueryListener {
protected def tablePostfix: String

@@ -700,14 +774,3 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
case class TestClass(value: Int) {
override def toString: String = value.toString
}

class ForeachBatchFn(val viewName: String)
extends VoidFunction2[DataFrame, java.lang.Long]
with Serializable {
override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
val count = df.count()
df.sparkSession
.createDataFrame(Seq((batchId.toLong, count)))
.createOrReplaceGlobalTempView(viewName)
}
}

@@ -2957,10 +2957,9 @@ class SparkConnectPlanner(
fn

case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
StreamingForeachBatchHelper.scalaForeachBatchWrapper(
writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
Utils.getContextOrSparkClassLoader)
StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)
sessionHolder)

case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable

@@ -28,11 +28,14 @@ import org.apache.spark.SparkException
import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING, SESSION_ID}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
import org.apache.spark.sql.connect.common.ForeachWriterPacket
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.Utils

/**
* A helper class for handling ForeachBatch related functionality in Spark Connect servers
@@ -88,13 +91,31 @@ object StreamingForeachBatchHelper extends Logging {
* DataFrame, so the user code actually runs with legacy DataFrame and session..
*/
def scalaForeachBatchWrapper(
fn: ForeachBatchFnType,
payloadBytes: Array[Byte],
sessionHolder: SessionHolder): ForeachBatchFnType = {
val foreachBatchPkt =
Utils.deserialize[ForeachWriterPacket](payloadBytes, Utils.getContextOrSparkClassLoader)
val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit]
val encoder = foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]]
// TODO(SPARK-44462): Set up Spark Connect session.
// Do we actually need this for the first version?
dataFrameCachingWrapper(
(args: FnArgsWithId) => {
fn(args.df, args.batchId) // dfId is not used, see hack comment above.
// dfId is not used, see hack comment above.
try {
val ds = if (AgnosticEncoders.UnboundRowEncoder == encoder) {
// When the dataset is a DataFrame (Dataset[Row]).
args.df.asInstanceOf[Dataset[Any]]
} else {
// Recover the Dataset from the DataFrame using the encoder.
Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder)
}
fn(ds, args.batchId)
} catch {
case t: Throwable =>
logError(s"Calling foreachBatch fn failed", t)
throw t
}
},
sessionHolder)
}