Skip to content

Commit

Permalink
[FLINK-24550][rpc] Use ContextClassLoader for message deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
zentol authored Nov 4, 2021
1 parent 4d9de6c commit 9b1529c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ private Object invokeRpc(Method method, Object[] args) throws Exception {
final CompletableFuture<?> resultFuture =
ask(rpcInvocation, futureTimeout)
.thenApply(
resultValue -> deserializeValueIfNeeded(resultValue, method));
resultValue ->
deserializeValueIfNeeded(
resultValue, method, flinkClassLoader));

final CompletableFuture<Object> completableFuture = new CompletableFuture<>();
resultFuture.whenComplete(
Expand Down Expand Up @@ -414,11 +416,11 @@ public CompletableFuture<Void> getTerminationFuture() {
return terminationFuture;
}

static Object deserializeValueIfNeeded(Object o, Method method) {
private static Object deserializeValueIfNeeded(
Object o, Method method, ClassLoader flinkClassLoader) {
if (o instanceof AkkaRpcSerializedValue) {
try {
return ((AkkaRpcSerializedValue) o)
.deserializeValue(AkkaInvocationHandler.class.getClassLoader());
return ((AkkaRpcSerializedValue) o).deserializeValue(flinkClassLoader);
} catch (IOException | ClassNotFoundException e) {
throw new CompletionException(
new RpcException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package org.apache.flink.runtime.rpc.akka;

import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.concurrent.akka.AkkaFutureUtils;
import org.apache.flink.runtime.rpc.RpcEndpoint;
import org.apache.flink.runtime.rpc.RpcGateway;
import org.apache.flink.runtime.rpc.RpcService;
import org.apache.flink.runtime.rpc.RpcUtils;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.FutureUtils;

Expand All @@ -31,6 +33,11 @@
import org.junit.Before;
import org.junit.Test;

import javax.annotation.Nullable;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
Expand All @@ -39,6 +46,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;

import static org.apache.flink.runtime.concurrent.akka.ClassLoadingUtils.runWithContextClassLoader;
import static org.hamcrest.CoreMatchers.either;
Expand Down Expand Up @@ -82,6 +90,8 @@ public void setup() {
actorSystem,
AkkaRpcServiceConfiguration.defaultConfiguration(),
pretendFlinkClassLoader);

PickyObject.classLoaderAssertion = this::assertIsFlinkClassLoader;
}

@After
Expand Down Expand Up @@ -280,6 +290,50 @@ public void testAkkaRpcInvocationHandler_RPCFutureCompletedWithFlinkContextClass
}
}

@Test
public void testAkkaRpcInvocationHandler_ContextClassLoaderUsedForDeserialization()
throws Exception {
// setup 2 actor systems and rpc services that support remote connections (for which RPCs go
// through serialization)
final AkkaRpcService serverAkkaRpcService =
new AkkaRpcService(
AkkaUtils.createActorSystem(
"serverActorSystem",
AkkaUtils.getAkkaConfig(
new Configuration(), new HostAndPort("localhost", 0))),
AkkaRpcServiceConfiguration.defaultConfiguration());

final AkkaRpcService clientAkkaRpcService =
new AkkaRpcService(
AkkaUtils.createActorSystem(
"clientActorSystem",
AkkaUtils.getAkkaConfig(
new Configuration(), new HostAndPort("localhost", 0))),
AkkaRpcServiceConfiguration.defaultConfiguration(),
pretendFlinkClassLoader);

try {
final TestEndpoint rpcEndpoint =
new TestEndpoint(serverAkkaRpcService, new PickyObject());
rpcEndpoint.start();

final TestEndpointGateway rpcGateway =
rpcEndpoint.getSelfGateway(TestEndpointGateway.class);

final TestEndpointGateway connect =
clientAkkaRpcService
.connect(rpcGateway.getAddress(), TestEndpointGateway.class)
.get();

// if the wrong classloader is used the deserialization fails and get() throws an
// exception
connect.getPickyObject().get();
} finally {
RpcUtils.terminateRpcService(clientAkkaRpcService, TIMEOUT);
RpcUtils.terminateRpcService(serverAkkaRpcService, TIMEOUT);
}
}

@Test
public void testSupervisorActor_TerminationFutureCompletedWithFlinkContextClassLoader()
throws Exception {
Expand Down Expand Up @@ -315,6 +369,20 @@ private interface TestEndpointGateway extends RpcGateway {
CompletableFuture<ClassLoader> doRunAsync();

void doSomethingWithoutReturningAnything();

CompletableFuture<PickyObject> getPickyObject();
}

/**
* An object that only allows deserialiation if its favorite ContextClassLoader is doing it.
*/
private static class PickyObject implements Serializable {
static Consumer<ClassLoader> classLoaderAssertion = null;

private void readObject(ObjectInputStream aInputStream)
throws ClassNotFoundException, IOException {
classLoaderAssertion.accept(Thread.currentThread().getContextClassLoader());
}
}

private static class TestEndpoint extends RpcEndpoint implements TestEndpointGateway {
Expand All @@ -325,8 +393,15 @@ private static class TestEndpoint extends RpcEndpoint implements TestEndpointGat
new CompletableFuture<>();
private final CompletableFuture<Void> rpcResponseFuture = new CompletableFuture<>();

@Nullable private final PickyObject pickyObject;

protected TestEndpoint(RpcService rpcService) {
this(rpcService, null);
}

protected TestEndpoint(RpcService rpcService, @Nullable PickyObject pickyObject) {
super(rpcService);
this.pickyObject = pickyObject;
}

@Override
Expand Down Expand Up @@ -368,6 +443,11 @@ public void doSomethingWithoutReturningAnything() {
voidOperationClassLoader.complete(Thread.currentThread().getContextClassLoader());
}

@Override
public CompletableFuture<PickyObject> getPickyObject() {
return CompletableFuture.completedFuture(pickyObject);
}

public void completeRPCFuture() {
rpcResponseFuture.complete(null);
}
Expand Down

0 comments on commit 9b1529c

Please sign in to comment.