[SPARK-44740][CONNECT] Support specifying session_id in SPARK_REMOTE connection string

### What changes were proposed in this pull request?
To support cross-language session sharing in Spark Connect, we need to be able to inject the session ID via the connection string, because on the server side the client-provided session ID is already used together with the user ID to identify the session.

```
SparkSession.builder.remote("sc://localhost/;session_id=550e8400-e29b-41d4-a716-446655440000").getOrCreate()
```
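
For illustration, here is a minimal PySpark sketch (the endpoint, port, and UUID are placeholders, not part of this change): any client, whether Python, Scala, or another language, that presents the same `session_id` in its connection string and connects as the same user is routed to the same server-side session.

```
from pyspark.sql import SparkSession

# Placeholder UUID; the only requirement is that all participating clients agree on it.
shared_id = "550e8400-e29b-41d4-a716-446655440000"

# A Python client attaching to (or creating) the shared server-side session.
spark = SparkSession.builder.remote(f"sc://localhost:15002/;session_id={shared_id}").getOrCreate()
spark.range(5).show()

# A Scala client started with the same connection string would share this session.
```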

### Why are the changes needed?
Ease of use.

### Does this PR introduce _any_ user-facing change?
Yes, it adds a `session_id` parameter to the Spark Connect connection string.
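
As a rough sketch of the validation behavior (mirroring the unit tests added in this PR; the hostname and the invalid value are placeholders), a `session_id` that is not a well-formed UUID is rejected with a `ValueError` as soon as the client is created:

```
from pyspark.sql.connect.client import ChannelBuilder, SparkConnectClient

chan = ChannelBuilder("sc://localhost/;session_id=not-a-uuid")
try:
    SparkConnectClient(chan)  # reading chan.session_id validates the UUID format
except ValueError as e:
    print(e)  # Parameter value 'session_id' must be a valid UUID format.
```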

### How was this patch tested?
Added unit tests for the parameter.

Closes apache#42415 from grundprinzip/SPARK-44740.

Authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
grundprinzip authored and HyukjinKwon committed Aug 9, 2023
1 parent 4db378f commit 7af4e35
Showing 8 changed files with 94 additions and 7 deletions.
@@ -56,7 +56,7 @@ private[sql] class SparkConnectClient(
// Generate a unique session ID for this client. This UUID must be unique to allow
// concurrent Spark sessions of the same user. If the channel is closed, creating
// a new client will create a new session ID.
private[sql] val sessionId: String = UUID.randomUUID.toString
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)

private[client] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
@@ -432,6 +432,7 @@ object SparkConnectClient {
val PARAM_USE_SSL = "use_ssl"
val PARAM_TOKEN = "token"
val PARAM_USER_AGENT = "user_agent"
val PARAM_SESSION_ID = "session_id"
}

private def verifyURI(uri: URI): Unit = {
@@ -463,6 +464,21 @@ object SparkConnectClient {
this
}

def sessionId(value: String): Builder = {
  // Validate that the value is a well-formed UUID; the parsed result itself is discarded.
  try {
    UUID.fromString(value).toString
  } catch {
    case e: IllegalArgumentException =>
      throw new IllegalArgumentException(
        "Parameter value 'session_id' must be a valid UUID format.",
        e)
  }
  _configuration = _configuration.copy(sessionId = Some(value))
  this
}

def sessionId: Option[String] = _configuration.sessionId

def userAgent: String = _configuration.userAgent

def option(key: String, value: String): Builder = {
@@ -490,6 +506,7 @@ object SparkConnectClient {
case URIParams.PARAM_TOKEN => token(value)
case URIParams.PARAM_USE_SSL =>
if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
case URIParams.PARAM_SESSION_ID => sessionId(value)
case _ => option(key, value)
}
}
@@ -576,7 +593,8 @@ object SparkConnectClient {
userAgent: String = DEFAULT_USER_AGENT,
retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy(),
useReattachableExecute: Boolean = true,
interceptors: List[ClientInterceptor] = List.empty) {
interceptors: List[ClientInterceptor] = List.empty,
sessionId: Option[String] = None) {

def userContext: proto.UserContext = {
val builder = proto.UserContext.newBuilder()
@@ -71,6 +71,9 @@ private[sql] object SparkConnectClientParser {
case "--user_agent" :: tail =>
val (value, remainder) = extract("--user_agent", tail)
parse(remainder, builder.userAgent(value))
case "--session_id" :: tail =>
val (value, remainder) = extract("--session_id", tail)
parse(remainder, builder.sessionId(value))
case "--option" :: tail =>
if (args.isEmpty) {
throw new IllegalArgumentException("--option requires a key-value pair")
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.connect.client

import java.util.UUID

import org.apache.spark.sql.connect.client.util.ConnectFunSuite

/**
@@ -46,6 +48,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
argumentTest("user_id", "U1238", _.userId.get)
argumentTest("user_name", "alice", _.userName.get)
argumentTest("user_agent", "MY APP", _.userAgent)
argumentTest("session_id", UUID.randomUUID().toString, _.sessionId.get)

test("Argument - remote") {
val builder =
@@ -55,6 +58,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
assert(builder.token.contains("nahnah"))
assert(builder.userId.contains("x127"))
assert(builder.options === Map(("user_name", "Q"), ("param1", "x")))
assert(builder.sessionId.isEmpty)
}

test("Argument - use_ssl") {
@@ -164,6 +164,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
client => {
assert(client.configuration.host == "localhost")
assert(client.configuration.port == 1234)
assert(client.sessionId != null)
// Must be able to parse the UUID
assert(UUID.fromString(client.sessionId) != null)
}),
TestPackURI(
"sc://localhost/;",
@@ -193,6 +196,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true),
TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true),
TestPackURI("sc://host:123/;token=", isCorrect = false),
TestPackURI("sc://host:123/;session_id=", isCorrect = false),
TestPackURI("sc://host:123/;session_id=abcdefgh", isCorrect = false),
TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true),
TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
connector/connect/docs/client-connection-string.md (11 additions, 0 deletions)
@@ -91,6 +91,17 @@ sc://hostname:port/;param1=value;param2=value
<i>Default: </i><pre>_SPARK_CONNECT_PYTHON</pre> in the Python client</td>
<td><pre>user_agent=my_data_query_app</pre></td>
</tr>
<tr>
<td>session_id</td>
<td>String</td>
<td>In addition to the user ID, the Spark Connect server uses a session ID as the key
for its cache of Spark Sessions. This connection string option lets a client provide
that session ID, which allows the same user to share a Spark Session, for example
across multiple languages. The value must be a valid UUID string.<br/>
<i>Default: A randomly generated UUID.</i></td>
<td><pre>session_id=550e8400-e29b-41d4-a716-446655440000</pre></td>
</tr>
</table>

## Examples
python/pyspark/sql/connect/client/core.py (26 additions, 4 deletions)
@@ -156,6 +156,7 @@ class ChannelBuilder:
PARAM_TOKEN = "token"
PARAM_USER_ID = "user_id"
PARAM_USER_AGENT = "user_agent"
PARAM_SESSION_ID = "session_id"
MAX_MESSAGE_LENGTH = 128 * 1024 * 1024

@staticmethod
@@ -354,6 +355,22 @@ def get(self, key: str) -> Any:
"""
return self.params[key]

@property
def session_id(self) -> Optional[str]:
    """
    Returns
    -------
    The session_id extracted from the parameters of the connection string or `None` if not
    specified.
    """
    session_id = self.params.get(ChannelBuilder.PARAM_SESSION_ID, None)
    if session_id is not None:
        try:
            uuid.UUID(session_id, version=4)
        except ValueError as ve:
            raise ValueError("Parameter value 'session_id' must be a valid UUID format.", ve)
    return session_id

def toChannel(self) -> grpc.Channel:
"""
Applies the parameters of the connection string and creates a new
@@ -628,10 +645,15 @@ def __init__(
if retry_policy:
self._retry_policy.update(retry_policy)

# Generate a unique session ID for this client. This UUID must be unique to allow
# concurrent Spark sessions of the same user. If the channel is closed, creating
# a new client will create a new session ID.
self._session_id = str(uuid.uuid4())
if self._builder.session_id is None:
    # Generate a unique session ID for this client. This UUID must be unique to allow
    # concurrent Spark sessions of the same user. If the channel is closed, creating
    # a new client will create a new session ID.
    self._session_id = str(uuid.uuid4())
else:
    # Use the pre-defined session ID.
    self._session_id = str(self._builder.session_id)

if self._builder.userId is not None:
self._user_id = self._builder.userId
elif user_id is not None:
python/pyspark/sql/tests/connect/client/test_client.py (7 additions, 0 deletions)
@@ -16,6 +16,7 @@
#

import unittest
import uuid
from typing import Optional

from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
@@ -88,6 +89,12 @@ def test_is_closed(self):
client.close()
self.assertTrue(client.is_closed)

def test_channel_builder_with_session(self):
    dummy = str(uuid.uuid4())
    chan = ChannelBuilder(f"sc://foo/;session_id={dummy}")
    client = SparkConnectClient(chan)
    self.assertEqual(client._session_id, chan.session_id)


class MockService:
# Simplest mock of the SparkConnectService.
python/pyspark/sql/tests/connect/test_connect_basic.py (17 additions, 1 deletion)
@@ -23,6 +23,7 @@
import shutil
import string
import tempfile
import uuid
from collections import defaultdict

from pyspark.errors import (
@@ -76,7 +77,7 @@
from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
from pyspark.sql.connect.client.core import Retrying
from pyspark.sql.connect.client.core import Retrying, SparkConnectClient


class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
@@ -3522,6 +3523,21 @@ def test_metadata(self):
md = chan.metadata()
self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)

def test_session_id_in_connection_string(self):
    id = str(uuid.uuid4())
    chan = ChannelBuilder(f"sc://host/;session_id={id}")
    self.assertEqual(id, chan.session_id)

    with self.assertRaises(ValueError) as ve:
        chan = ChannelBuilder("sc://host/;session_id=abcd")
        SparkConnectClient(chan)
    self.assertIn(
        "Parameter value 'session_id' must be a valid UUID format.", str(ve.exception)
    )

    chan = ChannelBuilder("sc://host/")
    self.assertIsNone(chan.session_id)


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401