diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java index 95fce96e71cbd..ccd0febe06659 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java @@ -29,7 +29,7 @@ */ public class TaskEventHandler { - // Listeners for each event type + /** Listeners for each event type */ private final Multimap, EventListener> listeners = HashMultimap.create(); public void subscribe(EventListener listener, Class eventType) { @@ -45,7 +45,7 @@ public void unsubscribe(EventListener listener, Class> operatorState; @@ -290,11 +298,11 @@ public Task(TaskDeploymentDescriptor tdd, this.inputGates[i] = gate; inputGatesById.put(gate.getConsumedResultId(), gate); } + + invokableHasBeenCanceled = new AtomicBoolean(false); // finally, create the executing thread, but do not start it executingThread = new Thread(TASK_THREADS_GROUP, this, taskNameWithSubtask); - - invokableHasBeenCanceled = new AtomicBoolean(false); } // ------------------------------------------------------------------------ @@ -646,9 +654,17 @@ else if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { try { LOG.info("Freeing task resources for " + taskNameWithSubtask); + // stop the async dispatcher. 
+ // copy dispatcher reference to stack, against concurrent release + ExecutorService dispatcher = this.asyncCallDispatcher; + if (dispatcher != null && !dispatcher.isShutdown()) { + dispatcher.shutdownNow(); + } + // free the network resources network.unregisterTask(this); + // free memory resources if (invokable != null) { memoryManager.releaseAll(invokable); } @@ -797,6 +813,7 @@ else if (current == ExecutionState.RUNNING) { Runnable canceler = new TaskCanceler(LOG, invokable, executingThread, taskNameWithSubtask); Thread cancelThread = new Thread(executingThread.getThreadGroup(), canceler, "Canceler for " + taskNameWithSubtask); + cancelThread.setDaemon(true); cancelThread.start(); } return; @@ -955,11 +972,49 @@ else if (partitionState == ExecutionState.CANCELED LOG.debug("Ignoring partition state notification for not running task."); } } - + + /** + * Utility method to dispatch an asynchronous call on the invokable. + * + * @param runnable The async call runnable. + * @param callName The name of the call, for logging purposes. + */ private void executeAsyncCallRunnable(Runnable runnable, String callName) { - Thread thread = new Thread(runnable, callName); - thread.setDaemon(true); - thread.start(); + // make sure the executor is initialized. 
lock against concurrent calls to this function + synchronized (this) { + if (isCanceledOrFailed()) { + return; + } + + // get ourselves a reference on the stack that cannot be concurrently modified + ExecutorService executor = this.asyncCallDispatcher; + if (executor == null) { + // first time use, initialize + executor = Executors.newSingleThreadExecutor( + new DispatherThreadFactory(TASK_THREADS_GROUP, "Async calls on " + taskNameWithSubtask)); + this.asyncCallDispatcher = executor; + + // double-check for execution state, and make sure we clean up after ourselves + // if we created the dispatcher while the task was concurrently canceled + if (isCanceledOrFailed()) { + executor.shutdown(); + asyncCallDispatcher = null; + return; + } + } + + LOG.debug("Invoking async call {} on task {}", callName, taskNameWithSubtask); + + try { + executor.submit(runnable); + } + catch (RejectedExecutionException e) { + // may be that we are concurrently canceled. if not, report that something is fishy + if (!isCanceledOrFailed()) { + throw new RuntimeException("Async call was rejected, even though the task was not canceled.", e); + } + } + } } // ------------------------------------------------------------------------ @@ -1051,7 +1106,7 @@ public void run() { executer.interrupt(); try { - executer.join(5000); + executer.join(10000); } catch (InterruptedException e) { // we can ignore this diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java new file mode 100644 index 0000000000000..618c01fcbb5db --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.taskmanager; + +import akka.actor.ActorSystem; +import akka.actor.Props; +import akka.actor.UntypedActor; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.blob.BlobKey; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.filecache.FileCache; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; +import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator; +import 
org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator; +import org.apache.flink.runtime.memorymanager.MemoryManager; + +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.util.SerializedValue; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import scala.concurrent.duration.FiniteDuration; + +import java.util.Collections; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TaskAsyncCallTest { + + private static final int NUM_CALLS = 1000; + + private static ActorSystem actorSystem; + + private static OneShotLatch awaitLatch; + private static OneShotLatch triggerLatch; + + // ------------------------------------------------------------------------ + // Init & Shutdown + // ------------------------------------------------------------------------ + + @BeforeClass + public static void startActorSystem() { + actorSystem = AkkaUtils.createLocalActorSystem(new Configuration()); + } + + @AfterClass + public static void shutdown() { + actorSystem.shutdown(); + actorSystem.awaitTermination(); + } + + @Before + public void createQueuesAndActors() { + awaitLatch = new OneShotLatch(); + triggerLatch = new OneShotLatch(); + } + + + // ------------------------------------------------------------------------ + // Tests + // ------------------------------------------------------------------------ + + @Test + public void testCheckpointCallsInOrder() { + try { + Task task = createTask(); + task.startTaskThread(); + + awaitLatch.await(); + + for (int i = 1; i <= NUM_CALLS; i++) { + task.triggerCheckpointBarrier(i, 156865867234L); + } + + triggerLatch.await(); + + assertFalse(task.isCanceledOrFailed()); + 
assertEquals(ExecutionState.RUNNING, task.getExecutionState()); + + task.cancelExecution(); + task.getExecutingThread().join(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMixedAsyncCallsInOrder() { + try { + Task task = createTask(); + task.startTaskThread(); + + awaitLatch.await(); + + for (int i = 1; i <= NUM_CALLS; i++) { + task.triggerCheckpointBarrier(i, 156865867234L); + task.confirmCheckpoint(i, null); + } + + triggerLatch.await(); + + assertFalse(task.isCanceledOrFailed()); + assertEquals(ExecutionState.RUNNING, task.getExecutionState()); + + task.cancelExecution(); + task.getExecutingThread().join(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private static Task createTask() { + + LibraryCacheManager libCache = mock(LibraryCacheManager.class); + when(libCache.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader()); + + ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); + when(networkEnvironment.getPartitionManager()).thenReturn(partitionManager); + when(networkEnvironment.getPartitionConsumableNotifier()).thenReturn(consumableNotifier); + when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + + TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor( + new JobID(), new JobVertexID(), new ExecutionAttemptID(), + "Test Task", 0, 1, + new Configuration(), new Configuration(), + CheckpointsInOrderInvokable.class.getName(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + 0); + + return new Task(tdd, + mock(MemoryManager.class), + mock(IOManager.class), + networkEnvironment, + mock(BroadcastVariableManager.class), + 
actorSystem.actorOf(Props.create(BlackHoleActor.class)), + actorSystem.actorOf(Props.create(BlackHoleActor.class)), + new FiniteDuration(60, TimeUnit.SECONDS), + libCache, + mock(FileCache.class)); + } + + public static class CheckpointsInOrderInvokable extends AbstractInvokable + implements CheckpointedOperator, CheckpointCommittingOperator { + + private volatile long lastCheckpointId = 0; + + private volatile Exception error; + + @Override + public void registerInputOutput() {} + + @Override + public void invoke() throws Exception { + awaitLatch.trigger(); + + // wait forever (until canceled) + synchronized (this) { + while (error == null && lastCheckpointId < NUM_CALLS) { + wait(); + } + } + + triggerLatch.trigger(); + if (error != null) { + throw error; + } + } + + @Override + public void triggerCheckpoint(long checkpointId, long timestamp) throws Exception { + lastCheckpointId++; + if (checkpointId == lastCheckpointId) { + if (lastCheckpointId == NUM_CALLS) { + triggerLatch.trigger(); + } + } + else if (this.error == null) { + this.error = new Exception("calls out of order"); + synchronized (this) { + notifyAll(); + } + } + } + + @Override + public void confirmCheckpoint(long checkpointId, SerializedValue<StateHandle<?>> state) throws Exception { + if (checkpointId != lastCheckpointId && this.error == null) { + this.error = new Exception("calls out of order"); + synchronized (this) { + notifyAll(); + } + } + } + } + + public static class BlackHoleActor extends UntypedActor { + + @Override + public void onReceive(Object message) {} + } +}