[SPARK-3027] TaskContext: tighten visibility and provide Java friendly callback API

Note this also passes the TaskContext itself to the TaskCompletionListener. In the future we can mark the TaskContext with the exception object if an exception occurs during task execution.

Author: Reynold Xin <[email protected]>

Closes apache#1938 from rxin/TaskContext and squashes the following commits:

145de43 [Reynold Xin] Added JavaTaskCompletionListenerImpl for Java API friendly guarantee.
f435ea5 [Reynold Xin] Added license header for TaskCompletionListener.
dc4ed27 [Reynold Xin] [SPARK-3027] TaskContext: tighten the visibility and provide Java friendly callback API
rxin committed Aug 15, 2014
1 parent fa5a08e commit 655699f
Showing 14 changed files with 144 additions and 23 deletions.
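
For orientation before the diffs: a minimal sketch of how an RDD-style reader might register a cleanup callback through the closure overload added in the TaskContext.scala hunk below. The readPartition helper, its stream handling, and the enclosing object are illustrative only, not part of the commit.

import java.io.InputStream

import org.apache.spark.TaskContext

object PartitionReaderSketch {
  // Hypothetical reader: open a stream for one partition and guarantee it is
  // closed when the task finishes, whether it succeeds, fails, or is killed.
  def readPartition(context: TaskContext, open: () => InputStream): Iterator[Int] = {
    val stream = open()
    // The listener now receives the TaskContext itself.
    context.addTaskCompletionListener { ctx =>
      stream.close()
    }
    Iterator.continually(stream.read()).takeWhile(_ != -1)
  }
}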
@@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
if (context.interrupted) {
if (context.isInterrupted) {
throw new TaskKilledException
} else {
delegate.hasNext
63 changes: 56 additions & 7 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.TaskCompletionListener


/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
* @param taskMetrics performance metrics of the task
*/
@DeveloperApi
class TaskContext(
@@ -39,27 +39,68 @@ class TaskContext(
def splitId = partitionId

// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

// Whether the corresponding task has been killed.
@volatile var interrupted: Boolean = false
@volatile private var interrupted: Boolean = false

// Whether the task has completed.
@volatile private var completed: Boolean = false

/** Checks whether the task has completed. */
def isCompleted: Boolean = completed

// Whether the task has completed, before the onCompleteCallbacks are executed.
@volatile var completed: Boolean = false
/** Checks whether the task has been killed. */
def isInterrupted: Boolean = interrupted

// TODO: Also track whether the task has completed successfully or with exception.

/**
* Add a (Java friendly) listener to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
}

/**
* Add a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
onCompleteCallbacks += new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f(context)
}
this
}

/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
@deprecated("use addTaskCompletionListener", "1.1.0")
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
onCompleteCallbacks += new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f()
}
}

def executeOnCompleteCallbacks() {
/** Marks the task as completed and triggers the listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach { _() }
onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
}

/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(): Unit = {
interrupted = true
}
}
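
Both addTaskCompletionListener overloads above return this.type, so registrations can be chained. A hedged sketch of replacing the deprecated addOnCompleteCallback with the two new forms; the cleanup functions and the registerCleanups helper are made up for illustration.

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

object CleanupRegistrationSketch {
  def registerCleanups(context: TaskContext,
                       closeInput: () => Unit,
                       releaseBuffer: () => Unit): Unit = {
    context
      // Scala-closure overload: the callback is handed the TaskContext.
      .addTaskCompletionListener { ctx => closeInput() }
      // Java-friendly overload: an explicit TaskCompletionListener instance.
      .addTaskCompletionListener(new TaskCompletionListener {
        override def onTaskCompletion(ctx: TaskContext): Unit = releaseBuffer()
      })
  }
}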
12 changes: 6 additions & 6 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -68,7 +68,7 @@ private[spark] class PythonRDD(
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)

context.addOnCompleteCallback { () =>
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()

// Cleanup the worker socket. This will also cause the Python worker to exit.
@@ -137,7 +137,7 @@ private[spark] class PythonRDD(
}
} catch {

case e: Exception if context.interrupted =>
case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException

@@ -176,7 +176,7 @@ private[spark] class PythonRDD(

/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
assert(context.completed)
assert(context.isCompleted)
this.interrupt()
}

@@ -209,7 +209,7 @@ private[spark] class PythonRDD(
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
} catch {
case e: Exception if context.completed || context.interrupted =>
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)

case e: Exception =>
@@ -235,10 +235,10 @@ private[spark] class PythonRDD(
override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.interrupted && !context.completed) {
while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
if (!context.completed) {
if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
@@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging {
val deserializeStream = serializer.deserializeStream(fileInputStream)

// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => deserializeStream.close())
context.addTaskCompletionListener(context => deserializeStream.close())

deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -197,7 +197,7 @@ class HadoopRDD[K, V](
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)

// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
context.addTaskCompletionListener{ context => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()

2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag](
}

override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
context.addOnCompleteCallback{ () => closeIfNeeded() }
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
@@ -129,7 +129,7 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)

// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => close())
context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false

@@ -634,7 +634,7 @@ class DAGScheduler(
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.executeOnCompleteCallbacks()
taskContext.markTaskCompleted()
}
} catch {
case e: Exception =>
@@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U](
try {
func(context, rdd.iterator(partition, context))
} finally {
context.executeOnCompleteCallbacks()
context.markTaskCompleted()
}
}

@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
}
throw e
} finally {
context.executeOnCompleteCallbacks()
context.markTaskCompleted()
}
}

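
The three scheduler-side call sites above share one pattern: run the task body and mark completion in a finally block, so registered listeners fire on success, failure, or cancellation alike. A condensed sketch of that lifecycle; the runWithCompletion helper is illustrative, and it sits inside the org.apache.spark package tree only because markTaskCompleted() is private[spark].

package org.apache.spark.scheduler

import org.apache.spark.TaskContext

object TaskLifecycleSketch {
  def runWithCompletion[T](context: TaskContext)(runBody: => T): T = {
    try {
      runBody
    } finally {
      // Listeners registered on the context fire here, even if runBody threw.
      context.markTaskCompleted()
    }
  }
}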
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
context.interrupted = true
context.markInterrupted()
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.util.EventListener

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi

/**
* :: DeveloperApi ::
*
* Listener providing a callback function to invoke when a task's execution completes.
*/
@DeveloperApi
trait TaskCompletionListener extends EventListener {
def onTaskCompletion(context: TaskContext)
}
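
To illustrate the trait above with a named implementation rather than an anonymous class: the Closeable-wrapping class below is hypothetical and not part of the commit.

import java.io.Closeable

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Closes a resource when the task completes, regardless of whether the task
// succeeded, failed, or was cancelled. Register it with
// context.addTaskCompletionListener(new CloseOnTaskCompletion(stream)).
class CloseOnTaskCompletion(resource: Closeable) extends TaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = resource.close()
}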
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util;

import org.apache.spark.TaskContext;


/**
* A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
* TaskContext is Java friendly.
*/
public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {

@Override
public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
context.stageId();
context.partitionId();
context.runningLocally();
context.taskMetrics();
context.addTaskCompletionListener(this);
}
}
@@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext) = {
context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
sys.error("failed")
}
}
