Skip to content

Commit

Permalink
[SPARK-3634] [PySpark] User's module should take precedence over syst…
Browse files Browse the repository at this point in the history
…em modules

Python modules added through addPyFile should take precedence over system modules.

This patch puts the paths of user-added modules at the front of sys.path (just after '').

Author: Davies Liu <[email protected]>

Closes apache#2492 from davies/path and squashes the following commits:

4a2af78 [Davies Liu] fix tests
f7ff4da [Davies Liu] ad license header
6b0002f [Davies Liu] add tests
c16c392 [Davies Liu] put addPyFile in front of sys.path
  • Loading branch information
davies authored and JoshRosen committed Sep 24, 2014
1 parent 50f8633 commit c854b9f
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 8 deletions.
11 changes: 5 additions & 6 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,

SparkFiles._sc = self
root_dir = SparkFiles.getRootDirectory()
sys.path.append(root_dir)
sys.path.insert(1, root_dir)

# Deploy any code dependencies specified in the constructor
self._python_includes = list()
Expand All @@ -183,10 +183,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
self._python_includes.append(filename)
sys.path.append(path)
if dirname not in sys.path:
sys.path.append(dirname)
if filename.lower().endswith("zip") or filename.lower().endswith("egg"):
self._python_includes.append(filename)
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))

# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
Expand Down Expand Up @@ -667,7 +666,7 @@ def addPyFile(self, path):
if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
self._python_includes.append(filename)
# for tests in local mode
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))

def setCheckpointDir(self, dirName):
"""
Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,18 @@ def func():
from userlib import UserClass
self.assertEqual("Hello World from inside a package!", UserClass().hello())

def test_overwrite_system_module(self):
    """A module shipped with addPyFile must shadow the same-named system module.

    The shipped SimpleHTTPServer.py is expected to report "My Server" as
    its ``__name__``; that value is checked both for an import done
    directly in this test and for an import done inside a mapped function.
    """
    user_module = os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")
    self.sc.addPyFile(user_module)

    # Import directly in the test itself.
    import SimpleHTTPServer
    self.assertEqual("My Server", SimpleHTTPServer.__name__)

    # Import again from inside a function executed through map().
    def func(x):
        import SimpleHTTPServer
        return SimpleHTTPServer.__name__

    collected = self.sc.parallelize(range(1)).map(func).collect()
    self.assertEqual(["My Server"], collected)


class TestRDDFunctions(PySparkTestCase):

Expand Down
11 changes: 9 additions & 2 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish):
write_long(1000 * finish, outfile)


def add_path(path):
    """Insert ``path`` near the front of ``sys.path`` if it is not already present.

    The path is placed at index 1, i.e. directly after the first entry of
    ``sys.path``, so that modules found under it take precedence over
    system packages that appear later in the list.

    :param path: directory or archive path to make importable
    :return: None; ``sys.path`` is modified in place
    """
    # The worker may be reused, so do not add the same path multiple times.
    if path not in sys.path:
        # Insert ahead of the system entries so user modules override them.
        sys.path.insert(1, path)


def main(infile, outfile):
try:
boot_time = time.time()
Expand All @@ -61,11 +68,11 @@ def main(infile, outfile):
SparkFiles._is_running_on_worker = True

# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
add_path(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
add_path(os.path.join(spark_files_dir, filename))

# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
Expand Down
22 changes: 22 additions & 0 deletions python/test_support/SimpleHTTPServer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Used to test overriding the standard SimpleHTTPServer module.
"""

# Distinctive module name so that code importing this file can tell it apart
# from the standard SimpleHTTPServer module.
__name__ = "My Server"

0 comments on commit c854b9f

Please sign in to comment.