[SPARK-1087] Move python traceback utilities into new traceback_utils.py file.

Also made some cosmetic cleanups.

Author: Aaron Staple <[email protected]>

Closes apache#2385 from staple/SPARK-1087 and squashes the following commits:

7b3bb13 [Aaron Staple] Address review comments, cosmetic cleanups.
10ba6e1 [Aaron Staple] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file.
staple authored and JoshRosen committed Sep 16, 2014
1 parent da33acb commit 60050f4
Showing 3 changed files with 83 additions and 61 deletions.
8 changes: 2 additions & 6 deletions python/pyspark/context.py
@@ -20,7 +20,6 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
from collections import namedtuple

from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call

from py4j.java_collections import ListConverter

@@ -99,11 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
...
ValueError:...
"""
if rdd._extract_concise_traceback() is not None:
self._callsite = rdd._extract_concise_traceback()
else:
tempNamedTuple = namedtuple("Callsite", "function file linenum")
self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
self._callsite = first_spark_call() or CallSite(None, None, None)
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
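
A note on the context.py change above: the one-line assignment works because first_spark_call() returns None when it cannot extract a call site, and None is falsy, so the or-expression falls back to a placeholder CallSite. Below is a minimal standalone sketch of that pattern; stub_first_spark_call is an invented stand-in for the real helper and is not part of this commit.

    from collections import namedtuple

    CallSite = namedtuple("CallSite", "function file linenum")

    def stub_first_spark_call(found):
        # Invented stand-in for pyspark.traceback_utils.first_spark_call(),
        # which returns None when no call site can be extracted.
        return CallSite("collect", "job.py", 42) if found else None

    print(stub_first_spark_call(True) or CallSite(None, None, None))
    # CallSite(function='collect', file='job.py', linenum=42)
    print(stub_first_spark_call(False) or CallSite(None, None, None))
    # CallSite(function=None, file=None, linenum=None)
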
58 changes: 3 additions & 55 deletions python/pyspark/rdd.py
@@ -18,13 +18,11 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from collections import namedtuple
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -45,6 +43,7 @@
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
get_used_memory, ExternalSorter
from pyspark.traceback_utils import SCCallSiteSync

from py4j.java_collections import ListConverter, MapConverter

@@ -81,57 +80,6 @@ def portable_hash(x):
return hash(x)


def _extract_concise_traceback():
"""
This function returns the traceback info for a callsite, returns a dict
with function name, file name and line number
"""
tb = traceback.extract_stack()
callsite = namedtuple("Callsite", "function file linenum")
if len(tb) == 0:
return None
file, line, module, what = tb[len(tb) - 1]
sparkpath = os.path.dirname(file)
first_spark_frame = len(tb) - 1
for i in range(0, len(tb)):
file, line, fun, what = tb[i]
if file.startswith(sparkpath):
first_spark_frame = i
break
if first_spark_frame == 0:
file, line, fun, what = tb[0]
return callsite(function=fun, file=file, linenum=line)
sfile, sline, sfun, swhat = tb[first_spark_frame]
ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
return callsite(function=sfun, file=ufile, linenum=uline)

_spark_stack_depth = 0


class _JavaStackTrace(object):

def __init__(self, sc):
tb = _extract_concise_traceback()
if tb is not None:
self._traceback = "%s at %s:%s" % (
tb.function, tb.file, tb.linenum)
else:
self._traceback = "Error! Could not extract traceback info"
self._context = sc

def __enter__(self):
global _spark_stack_depth
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(self._traceback)
_spark_stack_depth += 1

def __exit__(self, type, value, tb):
global _spark_stack_depth
_spark_stack_depth -= 1
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(None)


class BoundedFloat(float):
"""
Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ def collect(self):
"""
Return a list that contains all of the elements in this RDD.
"""
with _JavaStackTrace(self.context) as st:
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))

@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator):

keyed = self.mapPartitionsWithIndex(add_shuffle_key)
keyed._bypass_serializer = True
with _JavaStackTrace(self.context) as st:
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
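
Both call sites in rdd.py switch from _JavaStackTrace to SCCallSiteSync; the depth-counting behavior is unchanged: when one wrapped Spark call happens inside another, only the outermost context manager sets the JVM call site, and only the outermost exit clears it. The standalone sketch below shows that reference counting with a fake gateway object in place of the real py4j JavaSparkContext; FakeJsc, FakeContext, and CallSiteSync are invented names for illustration only.

    class FakeJsc(object):
        # Invented stand-in for the py4j JavaSparkContext (_jsc) used by Spark.
        def setCallSite(self, site):
            print("setCallSite(%r)" % (site,))

    class FakeContext(object):
        def __init__(self):
            self._jsc = FakeJsc()

    class CallSiteSync(object):
        # Same depth-counting scheme as SCCallSiteSync (and the removed
        # _JavaStackTrace), with an illustrative hard-coded call site string.
        _spark_stack_depth = 0

        def __init__(self, sc):
            self._call_site = "collect at job.py:42"
            self._context = sc

        def __enter__(self):
            if CallSiteSync._spark_stack_depth == 0:
                self._context._jsc.setCallSite(self._call_site)
            CallSiteSync._spark_stack_depth += 1

        def __exit__(self, type, value, tb):
            CallSiteSync._spark_stack_depth -= 1
            if CallSiteSync._spark_stack_depth == 0:
                self._context._jsc.setCallSite(None)

    sc = FakeContext()
    with CallSiteSync(sc):          # outermost: sets the call site once
        with CallSiteSync(sc):      # nested: depth > 0, no extra JVM call
            pass
    # Prints setCallSite('collect at job.py:42'), then setCallSite(None).
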
78 changes: 78 additions & 0 deletions python/pyspark/traceback_utils.py
@@ -0,0 +1,78 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import namedtuple
import os
import traceback


CallSite = namedtuple("CallSite", "function file linenum")


def first_spark_call():
"""
Return a CallSite representing the first Spark call in the current call stack.
"""
tb = traceback.extract_stack()
if len(tb) == 0:
return None
file, line, module, what = tb[len(tb) - 1]
sparkpath = os.path.dirname(file)
first_spark_frame = len(tb) - 1
for i in range(0, len(tb)):
file, line, fun, what = tb[i]
if file.startswith(sparkpath):
first_spark_frame = i
break
if first_spark_frame == 0:
file, line, fun, what = tb[0]
return CallSite(function=fun, file=file, linenum=line)
sfile, sline, sfun, swhat = tb[first_spark_frame]
ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
return CallSite(function=sfun, file=ufile, linenum=uline)


class SCCallSiteSync(object):
"""
Helper for setting the spark context call site.
Example usage:
from pyspark.context import SCCallSiteSync
with SCCallSiteSync(<relevant SparkContext>) as css:
<a Spark call>
"""

_spark_stack_depth = 0

def __init__(self, sc):
call_site = first_spark_call()
if call_site is not None:
self._call_site = "%s at %s:%s" % (
call_site.function, call_site.file, call_site.linenum)
else:
self._call_site = "Error! Could not extract traceback info"
self._context = sc

def __enter__(self):
if SCCallSiteSync._spark_stack_depth == 0:
self._context._jsc.setCallSite(self._call_site)
SCCallSiteSync._spark_stack_depth += 1

def __exit__(self, type, value, tb):
SCCallSiteSync._spark_stack_depth -= 1
if SCCallSiteSync._spark_stack_depth == 0:
self._context._jsc.setCallSite(None)
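
A closing note on the new module: first_spark_call() scans the extracted stack from the outermost frame inward, stops at the first frame whose file lives under the pyspark source directory, and reports that frame's function name together with the file and line of the user frame that called into it (the frame just before it in the extracted stack). The standalone sketch below applies the same selection rule to a hand-made list of (file, line, function, text) tuples, the shape produced by traceback.extract_stack(); the paths and frame values are invented for illustration, and sparkpath is passed in rather than derived.

    def pick_call_site(tb, sparkpath):
        # Same selection rule as first_spark_call(), but over an explicit
        # frame list instead of the live call stack.
        if len(tb) == 0:
            return None
        first_spark_frame = len(tb) - 1
        for i in range(0, len(tb)):
            file, line, fun, what = tb[i]
            if file.startswith(sparkpath):
                first_spark_frame = i
                break
        if first_spark_frame == 0:
            file, line, fun, what = tb[0]
            return (fun, file, line)
        sfile, sline, sfun, swhat = tb[first_spark_frame]
        ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
        return (sfun, ufile, uline)

    fake_stack = [
        ("/home/alice/job.py", 10, "<module>", "counts = rdd.collect()"),
        ("/opt/spark/python/pyspark/rdd.py", 652, "collect", "..."),
        ("/opt/spark/python/pyspark/traceback_utils.py", 29, "first_spark_call", "..."),
    ]
    print(pick_call_site(fake_stack, "/opt/spark/python/pyspark"))
    # ('collect', '/home/alice/job.py', 10)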
