Skip to content

Commit

Permalink
[FLINK-2543] [core] Fix user object deserialization for file-based state handles.
Browse files Browse the repository at this point in the history

Send exceptions from JM --> JC in serialized form.
Exceptions sent from the JobManager to the JobClient were relying on
Akka's JavaSerialization, which does not have access to the user code classloader.

This closes apache#1048
  • Loading branch information
rmetzger authored and StephanEwen committed Aug 30, 2015
1 parent 554b77b commit bf8c8e5
Show file tree
Hide file tree
Showing 41 changed files with 519 additions and 198 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.flink.api.common.Program;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.optimizer.DataStatistics;
Expand Down Expand Up @@ -176,8 +175,7 @@ public JobExecutionResult executePlan(Plan plan) throws Exception {
JobGraph jobGraph = jgg.compileJobGraph(op);

boolean sysoutPrint = isPrintingStatusDuringExecution();
SerializedJobExecutionResult result = flink.submitJobAndWait(jobGraph,sysoutPrint);
return result.toJobExecutionResult(ClassLoader.getSystemClassLoader());
return flink.submitJobAndWait(jobGraph, sysoutPrint);
}
finally {
if (shutDownAtEnd) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.client.JobClient;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobmanager.JobManager;
Expand Down Expand Up @@ -425,15 +424,8 @@ public JobSubmissionResult run(JobGraph jobGraph, boolean wait) throws ProgramIn

try{
if (wait) {
SerializedJobExecutionResult result = JobClient.submitJobAndWait(actorSystem,
jobManagerGateway, jobGraph, timeout, printStatusDuringExecution);
try {
return result.toJobExecutionResult(this.userCodeClassLoader);
}
catch (Exception e) {
throw new ProgramInvocationException(
"Failed to deserialize the accumulator result after the job execution", e);
}
return JobClient.submitJobAndWait(actorSystem,
jobManagerGateway, jobGraph, timeout, printStatusDuringExecution, this.userCodeClassLoader);
}
else {
JobClient.submitJobDetached(jobManagerGateway, jobGraph, timeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class InstantiationUtil {
* user-code ClassLoader.
*
*/
private static class ClassLoaderObjectInputStream extends ObjectInputStream {
public static class ClassLoaderObjectInputStream extends ObjectInputStream {
private ClassLoader classLoader;

private static final HashMap<String, Class<?>> primitiveClasses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import akka.actor.Address;
import akka.actor.PoisonPill;
import akka.actor.Props;
import akka.actor.Status;
import akka.pattern.Patterns;
import akka.util.Timeout;
import com.google.common.base.Preconditions;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
Expand All @@ -36,6 +36,7 @@
import org.apache.flink.runtime.messages.JobClientMessages;
import org.apache.flink.runtime.messages.JobManagerMessages;

import org.apache.flink.runtime.util.SerializedThrowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -64,8 +65,7 @@ public class JobClient {
public static ActorSystem startJobClientActorSystem(Configuration config)
throws IOException {
LOG.info("Starting JobClient actor system");
Option<Tuple2<String, Object>> remoting =
new Some<Tuple2<String, Object>>(new Tuple2<String, Object>("", 0));
Option<Tuple2<String, Object>> remoting = new Some<>(new Tuple2<String, Object>("", 0));

// start a remote actor system to listen on an arbitrary port
ActorSystem system = AkkaUtils.createActorSystem(config, remoting);
Expand Down Expand Up @@ -123,12 +123,13 @@ public static InetSocketAddress getJobManagerAddress(Configuration config) throw
* @throws org.apache.flink.runtime.client.JobExecutionException Thrown if the job
* execution fails.
*/
public static SerializedJobExecutionResult submitJobAndWait(
public static JobExecutionResult submitJobAndWait(
ActorSystem actorSystem,
ActorGateway jobManagerGateway,
JobGraph jobGraph,
FiniteDuration timeout,
boolean sysoutLogUpdates) throws JobExecutionException {
boolean sysoutLogUpdates,
ClassLoader userCodeClassloader) throws JobExecutionException {

Preconditions.checkNotNull(actorSystem, "The actorSystem must not be null.");
Preconditions.checkNotNull(jobManagerGateway, "The jobManagerGateway must not be null.");
Expand Down Expand Up @@ -160,26 +161,30 @@ public static SerializedJobExecutionResult submitJobAndWait(

SerializedJobExecutionResult result = ((JobManagerMessages.JobResultSuccess) answer).result();
if (result != null) {
return result;
return result.toJobExecutionResult(userCodeClassloader);
} else {
throw new Exception("Job was successfully executed but result contained a null JobExecutionResult.");
}
} else if (answer instanceof Status.Failure) {
throw ((Status.Failure) answer).cause();
} else {
throw new Exception("Unknown answer after submitting the job: " + answer);
}
}
catch (JobExecutionException e) {
throw e;
if(e.getCause() instanceof SerializedThrowable) {
SerializedThrowable serializedThrowable = (SerializedThrowable)e.getCause();
Throwable deserialized = serializedThrowable.deserializeError(userCodeClassloader);
throw new JobExecutionException(jobGraph.getJobID(), "Job execution failed " + deserialized.getMessage(), deserialized);
} else {
throw e;
}
}
catch (TimeoutException e) {
throw new JobTimeoutException(jobGraph.getJobID(), "Timeout while waiting for JobManager answer. " +
"Job time exceeded " + AkkaUtils.INF_TIMEOUT(), e);
}
catch (Throwable t) {
catch (Throwable throwable) {
throw new JobExecutionException(jobGraph.getJobID(),
"Communication with JobManager failed: " + t.getMessage(), t);
"Communication with JobManager failed: " + throwable.getMessage(), throwable);
}
finally {
// failsafe shutdown of the client actor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.flink.runtime.messages.ExecutionGraphMessages;
import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.runtime.util.SerializableObject;
import org.apache.flink.runtime.util.SerializedThrowable;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.ExceptionUtils;

Expand Down Expand Up @@ -1028,8 +1029,12 @@ public void registerExecutionListener(ActorGateway listener) {

private void notifyJobStatusChange(JobStatus newState, Throwable error) {
if (jobStatusListenerActors.size() > 0) {
SerializedThrowable serializedThrowable = null;
if(error != null) {
serializedThrowable = new SerializedThrowable(error);
}
ExecutionGraphMessages.JobStatusChanged message =
new ExecutionGraphMessages.JobStatusChanged(jobID, newState, System.currentTimeMillis(), error);
new ExecutionGraphMessages.JobStatusChanged(jobID, newState, System.currentTimeMillis(), serializedThrowable);

for (ActorGateway listener: jobStatusListenerActors) {
listener.tell(message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ public interface OperatorStateCarrier<T extends StateHandle<?>> {
*
* @param stateHandle The handle to the state.
*/
public void setInitialState(T stateHandle) throws Exception;
void setInitialState(T stateHandle) throws Exception;

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.flink.runtime.state;

import org.apache.flink.util.InstantiationUtil;

import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
Expand Down Expand Up @@ -56,9 +58,9 @@ public ByteStreamStateHandle(Serializable state) {
protected abstract InputStream getInputStream() throws Exception;

@Override
public Serializable getState() throws Exception {
public Serializable getState(ClassLoader userCodeClassLoader) throws Exception {
if (!stateFetched()) {
ObjectInputStream stream = new ObjectInputStream(getInputStream());
ObjectInputStream stream = new InstantiationUtil.ClassLoaderObjectInputStream(getInputStream(), userCodeClassLoader);
try {
state = (Serializable) stream.readObject();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public LocalStateHandle(T state) {
}

@Override
public T getState() {
public T getState(ClassLoader userCodeClassLoader) {
// The object has been deserialized correctly before
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
/**
* Wrapper for storing the handles for each state in a partitioned form. It can
* be used to repartition the state before re-injecting to the tasks.
*
*
* TODO: This class needs testing!
*/
public class PartitionedStateHandle implements
StateHandle<Map<Serializable, StateHandle<Serializable>>> {
Expand All @@ -38,7 +39,7 @@ public PartitionedStateHandle(Map<Serializable, StateHandle<Serializable>> handl
}

@Override
public Map<Serializable, StateHandle<Serializable>> getState() throws Exception {
public Map<Serializable, StateHandle<Serializable>> getState(ClassLoader userCodeClassLoader) throws Exception {
return handles;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
public interface StateHandle<T> extends Serializable {

/**
* This retrieves and returns the state represented by the handle.
*
* This retrieves and returns the state represented by the handle.
*
* @param userCodeClassLoader Class loader for deserializing user code specific classes
*
* @return The state represented by the handle.
* @throws java.lang.Exception Thrown, if the state cannot be fetched.
*/
T getState() throws Exception;
T getState(ClassLoader userCodeClassLoader) throws Exception;

/**
* Discards the state referred to by this handle, to free up resources in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@

package org.apache.flink.runtime.taskmanager;

import java.util.Arrays;

import org.apache.flink.runtime.accumulators.AccumulatorSnapshot;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.api.common.JobID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.runtime.util.SerializedThrowable;

/**
* This class represents an update about a task's execution state.
Expand All @@ -47,11 +44,7 @@ public class TaskExecutionState implements java.io.Serializable {

private final ExecutionState executionState;

private final byte[] serializedError;

// The exception must not be (de)serialized with the class, as its
// class may not be part of the system class loader.
private transient Throwable cachedError;
private final SerializedThrowable throwable;

/** Serialized flink and user-defined accumulators */
private final AccumulatorSnapshot accumulators;
Expand Down Expand Up @@ -104,49 +97,19 @@ public TaskExecutionState(JobID jobID, ExecutionAttemptID executionId,
ExecutionState executionState, Throwable error,
AccumulatorSnapshot accumulators) {


if (jobID == null || executionId == null || executionState == null) {
if (jobID == null || executionId == null || executionState == null) {
throw new NullPointerException();
}

this.jobID = jobID;
this.executionId = executionId;
this.executionState = executionState;
this.cachedError = error;
this.accumulators = accumulators;

if (error != null) {
byte[] serializedError;
try {
serializedError = InstantiationUtil.serializeObject(error);
}
catch (Throwable t) {
// could not serialize exception. send the stringified version instead
try {
this.cachedError = new Exception(ExceptionUtils.stringifyException(error));
serializedError = InstantiationUtil.serializeObject(this.cachedError);
}
catch (Throwable tt) {
// seems like we cannot do much to report the actual exception
// report a placeholder instead
try {
this.cachedError = new Exception("Cause is a '" + error.getClass().getName()
+ "' (failed to serialize or stringify)");
serializedError = InstantiationUtil.serializeObject(this.cachedError);
}
catch (Throwable ttt) {
// this should never happen unless the JVM is fubar.
// we just report the state without the error
this.cachedError = null;
serializedError = null;
}
}
}
this.serializedError = serializedError;
}
else {
this.serializedError = null;
if(error != null) {
this.throwable = new SerializedThrowable(error);
} else {
this.throwable = null;
}
this.accumulators = accumulators;
}

// --------------------------------------------------------------------------------------------
Expand All @@ -160,19 +123,11 @@ public TaskExecutionState(JobID jobID, ExecutionAttemptID executionId,
* job this update refers to.
*/
public Throwable getError(ClassLoader usercodeClassloader) {
if (this.serializedError == null) {
if (this.throwable == null) {
return null;
} else {
return throwable.deserializeError(usercodeClassloader);
}

if (this.cachedError == null) {
try {
cachedError = (Throwable) InstantiationUtil.deserializeObject(this.serializedError, usercodeClassloader);
}
catch (Exception e) {
throw new RuntimeException("Error while deserializing the attached exception", e);
}
}
return this.cachedError;
}

/**
Expand Down Expand Up @@ -218,8 +173,8 @@ public boolean equals(Object obj) {
return other.jobID.equals(this.jobID) &&
other.executionId.equals(this.executionId) &&
other.executionState == this.executionState &&
(other.serializedError == null ? this.serializedError == null :
(this.serializedError != null && Arrays.equals(this.serializedError, other.serializedError)));
(other.throwable == null ? this.throwable == null :
(this.throwable != null && throwable.equals(other.throwable) ));
}
else {
return false;
Expand All @@ -235,7 +190,6 @@ public int hashCode() {
public String toString() {
return String.format("TaskState jobId=%s, executionId=%s, state=%s, error=%s",
jobID, executionId, executionState,
cachedError == null ? (serializedError == null ? "(null)" : "(serialized)")
: (cachedError.getClass().getName() + ": " + cachedError.getMessage()));
throwable == null ? "(null)" : throwable.toString());
}
}
Loading

0 comments on commit bf8c8e5

Please sign in to comment.