[SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

### What changes were proposed in this pull request?
Make the Torch Distributor support Spark Connect.

### Why are the changes needed?
Functionality parity between classic PySpark and Spark Connect.

**Note**: `local_mode` with `use_gpu` is not supported for now, since `sc.resources` is not yet available in Spark Connect (see SPARK-42994).
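
As a rough user-facing sketch (the `train` function and its `learning_rate` argument are hypothetical, and it assumes a Spark Connect endpoint like the `local[4]` remote used in the parity tests below), distributed CPU training is expected to work over Connect after this change:

```python
from pyspark.sql import SparkSession
from pyspark.ml.torch.distributor import TorchDistributor

# Create a Spark Connect (remote) session instead of a classic SparkContext-backed one.
spark = SparkSession.builder.remote("local[4]").getOrCreate()


def train(learning_rate):
    # Hypothetical training function; a real one would build and fit a torch model.
    import torch

    return torch.ones(1).item() * learning_rate


# Distributed, CPU-only training over Connect; local_mode=True together with
# use_gpu=True is the combination that is not supported yet (needs sc.resources).
distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=False)
result = distributor.run(train, 1e-3)

spark.stop()
```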

### Does this PR introduce _any_ user-facing change?
Yes, the PyTorch Distributor can now be used with Spark Connect sessions.

### How was this patch tested?
Reused the existing unit tests: the shared test mixins from `pyspark.ml.torch.tests.test_distributor` are run against a remote (Spark Connect) session in the new `pyspark.ml.tests.connect.test_parity_torch_distributor` module.

Closes apache#40607 from zhengruifeng/connect_torch.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Apr 7, 2023
1 parent 10fd918 commit ad013d3
Showing 6 changed files with 293 additions and 120 deletions.
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
      session.withActive {

        // Add debug information to the query execution so that the jobs are traceable.
-       val debugString = v.toString
-       session.sparkContext.setLocalProperty(
-         "callSite.short",
-         s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-       session.sparkContext.setLocalProperty(
-         "callSite.long",
-         StringUtils.abbreviate(debugString, 2048))
+       try {
+         val debugString = v.toString
+         session.sparkContext.setLocalProperty(
+           "callSite.short",
+           s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
+         session.sparkContext.setLocalProperty(
+           "callSite.long",
+           StringUtils.abbreviate(debugString, 2048))
+       } catch {
+         case e: Throwable =>
+           logWarning("Fail to extract or attach the debug information", e)
+       }

        v.getPlan.getOpTypeCase match {
          case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -777,6 +777,7 @@ def __hash__(self):
"pyspark.ml.connect.functions",
# ml unittests
"pyspark.ml.tests.connect.test_connect_function",
"pyspark.ml.tests.connect.test_parity_torch_distributor",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
134 changes: 134 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -0,0 +1,134 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import tempfile
import unittest

have_torch = True
try:
    import torch  # noqa: F401
except ImportError:
    have_torch = False

from pyspark.sql import SparkSession

from pyspark.ml.torch.distributor import TorchDistributor

from pyspark.ml.torch.tests.test_distributor import (
    TorchDistributorBaselineUnitTestsMixin,
    TorchDistributorLocalUnitTestsMixin,
    TorchDistributorDistributedUnitTestsMixin,
    TorchWrapperUnitTestsMixin,
)


@unittest.skipIf(not have_torch, "torch is required")
class TorchDistributorBaselineUnitTestsOnConnect(
    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
):
    def setUp(self) -> None:
        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()

    def tearDown(self) -> None:
        self.spark.stop()

    def test_get_num_tasks_fails(self) -> None:
        inputs = [1, 5, 4]

        # This is when the conf isn't set and we request GPUs
        for num_processes in inputs:
            with self.subTest():
                # TODO(SPARK-42994): Support sc.resources
                # with self.assertRaisesRegex(RuntimeError, "driver"):
                #     TorchDistributor(num_processes, True, True)
                with self.assertRaisesRegex(RuntimeError, "unset"):
                    TorchDistributor(num_processes, False, True)


@unittest.skipIf(not have_torch, "torch is required")
class TorchDistributorLocalUnitTestsOnConnect(
    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
):
    def setUp(self) -> None:
        class_name = self.__class__.__name__
        conf = self._get_spark_conf()
        builder = SparkSession.builder.appName(class_name)
        for k, v in conf.getAll():
            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
                builder = builder.config(k, v)
        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
        self.mnist_dir_path = tempfile.mkdtemp()

    def tearDown(self) -> None:
        shutil.rmtree(self.mnist_dir_path)
        os.unlink(self.gpu_discovery_script_file.name)
        self.spark.stop()

    # TODO(SPARK-42994): Support sc.resources
    @unittest.skip("need to support sc.resources")
    def test_get_num_tasks_locally(self):
        super().test_get_num_tasks_locally()

    # TODO(SPARK-42994): Support sc.resources
    @unittest.skip("need to support sc.resources")
    def test_get_gpus_owned_local(self):
        super().test_get_gpus_owned_local()

    # TODO(SPARK-42994): Support sc.resources
    @unittest.skip("need to support sc.resources")
    def test_local_training_succeeds(self):
        super().test_local_training_succeeds()


@unittest.skipIf(not have_torch, "torch is required")
class TorchDistributorDistributedUnitTestsOnConnect(
    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
):
    def setUp(self) -> None:
        class_name = self.__class__.__name__
        conf = self._get_spark_conf()
        builder = SparkSession.builder.appName(class_name)
        for k, v in conf.getAll():
            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
                builder = builder.config(k, v)

        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
        self.mnist_dir_path = tempfile.mkdtemp()

    def tearDown(self) -> None:
        shutil.rmtree(self.mnist_dir_path)
        os.unlink(self.gpu_discovery_script_file.name)
        self.spark.stop()


@unittest.skipIf(not have_torch, "torch is required")
class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, unittest.TestCase):
    pass


if __name__ == "__main__":
    from pyspark.ml.tests.connect.test_parity_torch_distributor import *  # noqa: F401,F403

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)