diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java new file mode 100644 index 0000000000..09b8ce02bd --- /dev/null +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import scala.Function0; +import scala.Function1; +import scala.Unit; +import scala.collection.JavaConversions; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskCompletionListenerException; + +/** +* :: DeveloperApi :: +* Contextual information about a task which can be read or mutated during execution. +*/ +@DeveloperApi +public class TaskContext implements Serializable { + + private int stageId; + private int partitionId; + private long attemptId; + private boolean runningLocally; + private TaskMetrics taskMetrics; + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, + TaskMetrics taskMetrics) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = taskMetrics; + } + + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, + Boolean runningLocally) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = false; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + private static ThreadLocal taskContext = + new ThreadLocal(); + + /** + * :: Internal API :: + * This is spark internal API, not intended to be called from user programs. + */ + public static void setTaskContext(TaskContext tc) { + taskContext.set(tc); + } + + public static TaskContext get() { + return taskContext.get(); + } + + /** + * :: Internal API :: + */ + public static void remove() { + taskContext.remove(); + } + + // List of callback functions to execute when the task completes. + private transient List onCompleteCallbacks = + new ArrayList(); + + // Whether the corresponding task has been killed. + private volatile Boolean interrupted = false; + + // Whether the task has completed. + private volatile Boolean completed = false; + + /** + * Checks whether the task has completed. + */ + public Boolean isCompleted() { + return completed; + } + + /** + * Checks whether the task has been killed. + */ + public Boolean isInterrupted() { + return interrupted; + } + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + *

+ * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { + onCompleteCallbacks.add(listener); + return this; + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situations - success, failure, or cancellation. + *

+ * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(final Function1 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(context); + } + }); + return this; + } + + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * Deprecated: use addTaskCompletionListener + * + * @param f Callback function. + */ + @Deprecated + public void addOnCompleteCallback(final Function0 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(); + } + }); + } + + /** + * ::Internal API:: + * Marks the task as completed and triggers the listeners. + */ + public void markTaskCompleted() throws TaskCompletionListenerException { + completed = true; + List errorMsgs = new ArrayList(2); + // Process complete callbacks in the reverse order of registration + List revlist = + new ArrayList(onCompleteCallbacks); + Collections.reverse(revlist); + for (TaskCompletionListener tcl: revlist) { + try { + tcl.onTaskCompletion(this); + } catch (Throwable e) { + errorMsgs.add(e.getMessage()); + } + } + + if (!errorMsgs.isEmpty()) { + throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); + } + } + + /** + * ::Internal API:: + * Marks the task for interruption, i.e. cancellation. + */ + public void markInterrupted() { + interrupted = true; + } + + @Deprecated + /** Deprecated: use getStageId() */ + public int stageId() { + return stageId; + } + + @Deprecated + /** Deprecated: use getPartitionId() */ + public int partitionId() { + return partitionId; + } + + @Deprecated + /** Deprecated: use getAttemptId() */ + public long attemptId() { + return attemptId; + } + + @Deprecated + /** Deprecated: use getRunningLocally() */ + public boolean runningLocally() { + return runningLocally; + } + + public boolean getRunningLocally() { + return runningLocally; + } + + public int getStageId() { + return stageId; + } + + public int getPartitionId() { + return partitionId; + } + + public long getAttemptId() { + return attemptId; + } + + /** ::Internal API:: */ + public TaskMetrics taskMetrics() { + return taskMetrics; + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala deleted file mode 100644 index 51b3e4d5e0..0000000000 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} - - -/** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ -@DeveloperApi -class TaskContext( - val stageId: Int, - val partitionId: Int, - val attemptId: Long, - val runningLocally: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable with Logging { - - @deprecated("use partitionId", "0.8.1") - def splitId = partitionId - - // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] - - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false - - // Whether the task has completed. - @volatile private var completed: Boolean = false - - /** Checks whether the task has completed. */ - def isCompleted: Boolean = completed - - /** Checks whether the task has been killed. */ - def isInterrupted: Boolean = interrupted - - // TODO: Also track whether the task has completed successfully or with exception. - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener - this - } - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } - this - } - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * @param f Callback function. - */ - @deprecated("use addTaskCompletionListener", "1.1.0") - def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f() - } - } - - /** Marks the task as completed and triggers the listeners. */ - private[spark] def markTaskCompleted(): Unit = { - completed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => - try { - listener.onTaskCompletion(this) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) - } - } - - /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(): Unit = { - interrupted = true - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0e90caa5c9..ba712c9d77 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ @DeveloperApi + @deprecated("use TaskContext.get", "1.2.0") def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b2774dfc47..32cf29ed14 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -634,12 +634,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true) + new TaskContext(job.finalStage.id, job.partitions(0), 0, true) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() + TaskContext.remove() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 6aa0cca068..bf73f6f7bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context = new TaskContext(stageId, partitionId, attemptId, false) + TaskContext.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + TaskContext.remove() + } } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8c23d524e..4a07843544 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); + TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 90dcadcffd..d735010d7c 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true) + val context = new TaskContext(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) }