Skip to content

Commit

Permalink
[SPARK-50578][PYTHON][SS] Add support for new version of state metada…
Browse files Browse the repository at this point in the history
…ta for TransformWithStateInPandas

### What changes were proposed in this pull request?

Enable TransformWithStateInPandas operator to write new versions of state metadata and state schema. This will enable state metadata source and state data source reader. And will also support future schema evolution changes.

To achieve this purpose, in this PR, we add a new implementation of driver side Python runner. This is because spark will need to get the state schema on the driver during planning inside `IncrementalExecution`. We will also need to start another state server in the new driver side Python runner to handle API calls in init().

### Why are the changes needed?

This is to match with the new versions of state metadata and state schema version implemented in Scala side of TransformWithState.

### Does this PR introduce _any_ user-facing change?

No.
But now users will be able to get results from state metadata source reader and state data source reader using the same API as Scala. E.g. for state metadata source reader, we can now read out state metadata as follows:
```
metadata_df = spark.read.format("state-metadata").load(checkpoint_path)
```
And we can read out state rows by using state data source as follows:
```
list_state_df = spark.read.format("statestore")\
  .option("path", checkpoint_path)\
  .option("stateVarName", "listState")\
  .load()
```

### How was this patch tested?

Add unit tests in `python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py`. Test state metadata and state schema files are written correctly by using state metadata source reader and state data source reader.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49156 from jingz-db/python-metadata.

Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
  • Loading branch information
2 people authored and HeartSaVioR committed Dec 26, 2024
1 parent aac494e commit c920210
Show file tree
Hide file tree
Showing 16 changed files with 887 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ private[spark] class StreamingPythonRunner(
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)

private val envVars: java.util.Map[String, String] = func.envVars
private val pythonExec: String = func.pythonExec
private var pythonWorker: Option[PythonWorker] = None
private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
protected val envVars: java.util.Map[String, String] = func.envVars
protected val pythonExec: String = func.pythonExec
protected var pythonWorker: Option[PythonWorker] = None
protected var pythonWorkerFactory: Option[PythonWorkerFactory] = None
protected val pythonVer: String = func.pythonVer

/**
Expand All @@ -68,7 +68,9 @@ private[spark] class StreamingPythonRunner(

envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
if (!connectUrl.isEmpty) {
envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
}

val workerFactory =
new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap, false)
Expand All @@ -83,7 +85,9 @@ private[spark] class StreamingPythonRunner(
PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)

// Send sessionId
PythonRDD.writeUTF(sessionId, dataOut)
if (!sessionId.isEmpty) {
PythonRDD.writeUTF(sessionId, dataOut)
}

// Send the user function to python process
PythonWorkerUtils.writePythonFunction(func, dataOut)
Expand Down
24 changes: 24 additions & 0 deletions python/pyspark/sql/pandas/group_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,24 @@ def transformWithStateInPandas(
if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

def handle_pre_init(
statefulProcessorApiClient: StatefulProcessorApiClient,
) -> Iterator["PandasDataFrameLike"]:
# Driver handle is different from the handle used on executors;
# On JVM side, we will use `DriverStatefulProcessorHandleImpl` for driver handle which
# will only be used for handling init() and get the state schema on the driver.
driver_handle = StatefulProcessorHandle(statefulProcessorApiClient)
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.PRE_INIT)
statefulProcessor.init(driver_handle)

# This method is used for the driver-side stateful processor after we have collected
# all the necessary schemas. This instance of the DriverStatefulProcessorHandleImpl
# won't be used again on JVM.
statefulProcessor.close()

# return a dummy results, no return value is needed for pre init
return iter([])

def handle_data_rows(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
Expand Down Expand Up @@ -560,6 +578,9 @@ def transformWithStateUDF(
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
return handle_pre_init(statefulProcessorApiClient)

handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
Expand Down Expand Up @@ -606,6 +627,9 @@ def transformWithStateWithInitStateUDF(
- `initialStates` is None, while `inputRows` is not empty. This is not first batch.
`initialStates` is initialized to the positional value as None.
"""
if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
return handle_pre_init(statefulProcessorApiClient)

handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Large diffs are not rendered by default.

22 changes: 12 additions & 10 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,21 @@ class _HandleStateEnumTypeWrapper(
builtins.type,
): # noqa: F821
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
CREATED: _HandleState.ValueType # 0
INITIALIZED: _HandleState.ValueType # 1
DATA_PROCESSED: _HandleState.ValueType # 2
TIMER_PROCESSED: _HandleState.ValueType # 3
CLOSED: _HandleState.ValueType # 4
PRE_INIT: _HandleState.ValueType # 0
CREATED: _HandleState.ValueType # 1
INITIALIZED: _HandleState.ValueType # 2
DATA_PROCESSED: _HandleState.ValueType # 3
TIMER_PROCESSED: _HandleState.ValueType # 4
CLOSED: _HandleState.ValueType # 5

class HandleState(_HandleState, metaclass=_HandleStateEnumTypeWrapper): ...

CREATED: HandleState.ValueType # 0
INITIALIZED: HandleState.ValueType # 1
DATA_PROCESSED: HandleState.ValueType # 2
TIMER_PROCESSED: HandleState.ValueType # 3
CLOSED: HandleState.ValueType # 4
PRE_INIT: HandleState.ValueType # 0
CREATED: HandleState.ValueType # 1
INITIALIZED: HandleState.ValueType # 2
DATA_PROCESSED: HandleState.ValueType # 3
TIMER_PROCESSED: HandleState.ValueType # 4
CLOSED: HandleState.ValueType # 5
global___HandleState = HandleState

class StateRequest(google.protobuf.message.Message):
Expand Down
14 changes: 11 additions & 3 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@


class StatefulProcessorHandleState(Enum):
PRE_INIT = 0
CREATED = 1
INITIALIZED = 2
DATA_PROCESSED = 3
Expand All @@ -48,14 +49,19 @@ class StatefulProcessorHandleState(Enum):


class StatefulProcessorApiClient:
def __init__(self, state_server_port: int, key_schema: StructType) -> None:
def __init__(
self, state_server_port: int, key_schema: StructType, is_driver: bool = False
) -> None:
self.key_schema = key_schema
self._client_socket = socket.socket()
self._client_socket.connect(("localhost", state_server_port))
self.sockfile = self._client_socket.makefile(
"rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
)
self.handle_state = StatefulProcessorHandleState.CREATED
if is_driver:
self.handle_state = StatefulProcessorHandleState.PRE_INIT
else:
self.handle_state = StatefulProcessorHandleState.CREATED
self.utf8_deserializer = UTF8Deserializer()
self.pickleSer = CPickleSerializer()
self.serializer = ArrowStreamSerializer()
Expand All @@ -70,7 +76,9 @@ def __init__(self, state_server_port: int, key_schema: StructType) -> None:
def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

if state == StatefulProcessorHandleState.CREATED:
if state == StatefulProcessorHandleState.PRE_INIT:
proto_state = stateMessage.PRE_INIT
elif state == StatefulProcessorHandleState.CREATED:
proto_state = stateMessage.CREATED
elif state == StatefulProcessorHandleState.INITIALIZED:
proto_state = stateMessage.INITIALIZED
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/streaming/stateful_processor_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ class TransformWithStateInPandasFuncMode(Enum):
PROCESS_DATA = 1
PROCESS_TIMER = 2
COMPLETE = 3
PRE_INIT = 4
102 changes: 102 additions & 0 deletions python/pyspark/sql/streaming/transform_with_state_driver_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import json
from typing import Any, Iterator, TYPE_CHECKING

from pyspark.util import local_connect_and_auth
from pyspark.serializers import (
write_int,
read_int,
UTF8Deserializer,
CPickleSerializer,
)
from pyspark import worker
from pyspark.util import handle_worker_exception
from typing import IO
from pyspark.worker_util import check_python_version
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import StructType

if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
DataFrameLike as PandasDataFrameLike,
)

pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()


def main(infile: IO, outfile: IO) -> None:
check_python_version(infile)

log_name = "Streaming TransformWithStateInPandas Python worker"
print(f"Starting {log_name}.\n")

def process(
processor: StatefulProcessorApiClient,
mode: TransformWithStateInPandasFuncMode,
key: Any,
input: Iterator["PandasDataFrameLike"],
) -> None:
print(f"{log_name} Starting execution of UDF: {func}.\n")
func(processor, mode, key, input)
print(f"{log_name} Completed execution of UDF: {func}.\n")

try:
func, return_type = worker.read_command(pickle_ser, infile)
print(
f"{log_name} finish init stage of Python runner. Received UDF from JVM: {func}, "
f"received return type of UDF: {return_type}.\n"
)
# send signal for getting args
write_int(0, outfile)
outfile.flush()

# This driver runner will only be used on the first batch of a query,
# and the following code block should be only run once for each query run
state_server_port = read_int(infile)
key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
print(
f"{log_name} received parameters for UDF. State server port: {state_server_port}, "
f"key schema: {key_schema}.\n"
)

stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
process(
stateful_processor_api_client,
TransformWithStateInPandasFuncMode.PRE_INIT,
None,
iter([]),
)
write_int(0, outfile)
outfile.flush()
except Exception as e:
handle_worker_exception(e, outfile)
outfile.flush()


if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
Loading

0 comments on commit c920210

Please sign in to comment.