Skip to content

Commit

Permalink
Revert "Revert "[serve][xlang]Support deploying Python deployment fro…
Browse files Browse the repository at this point in the history
…m Java. …" (ray-project#27945)
  • Loading branch information
yaxife authored Aug 19, 2022
1 parent a6b7189 commit af488e1
Show file tree
Hide file tree
Showing 20 changed files with 354 additions and 95 deletions.
1 change: 1 addition & 0 deletions java/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ define_java_module(
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:commons_io_commons_io",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_httpcomponents_client5_httpclient5",
"@maven//:org_apache_httpcomponents_client5_httpclient5_fluent",
Expand Down
4 changes: 4 additions & 0 deletions java/serve/src/main/java/io/ray/serve/api/Serve.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.ray.serve.deployment.DeploymentRoute;
import io.ray.serve.exception.RayServeException;
import io.ray.serve.generated.ActorNameList;
import io.ray.serve.poll.LongPollClientFactory;
import io.ray.serve.replica.ReplicaContext;
import io.ray.serve.util.CollectionUtil;
import io.ray.serve.util.CommonUtil;
Expand Down Expand Up @@ -143,7 +144,10 @@ public static void shutdown() {
}

client.shutdown();
LongPollClientFactory.stop();
LongPollClientFactory.clearAllCache();
setGlobalClient(null);
setInternalReplicaContext(null);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,15 @@ public byte[] toProtoBytes() {
io.ray.serve.generated.DeploymentConfig.newBuilder()
.setNumReplicas(numReplicas)
.setMaxConcurrentQueries(maxConcurrentQueries)
.setUserConfig(
ByteString.copyFrom(
MessagePackSerializer.encode(userConfig).getKey())) // TODO-xlang
.setGracefulShutdownWaitLoopS(gracefulShutdownWaitLoopS)
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage);
if (null != userConfig) {
builder.setUserConfig(ByteString.copyFrom(MessagePackSerializer.encode(userConfig).getKey()));
}
if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,14 @@ public static ReplicaConfig fromProto(io.ray.serve.generated.ReplicaConfig proto
if (proto == null) {
return null;
}
Object[] initArgs = null;
if (0 != proto.getInitArgs().toByteArray().length) {
initArgs = MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null);
}
ReplicaConfig replicaConfig =
new ReplicaConfig(
proto.getDeploymentDefName(),
MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null), // TODO-xlang
initArgs,
gson.fromJson(proto.getRayActorOptions(), Map.class));
return replicaConfig;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ public DeploymentCreator options() {
.setGracefulShutdownWaitLoopS(this.config.getGracefulShutdownWaitLoopS())
.setGracefulShutdownTimeoutS(this.config.getGracefulShutdownTimeoutS())
.setHealthCheckPeriodS(this.config.getHealthCheckPeriodS())
.setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS());
.setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS())
.setLanguage(this.config.getDeploymentLanguage());
}

public String getDeploymentDef() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class DeploymentCreator {

private boolean routed;

private DeploymentLanguage deploymentLanguage;
private DeploymentLanguage language;

public Deployment create() {

Expand All @@ -97,7 +97,7 @@ public Deployment create() {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setDeploymentLanguage(deploymentLanguage);
.setDeploymentLanguage(language);

return new Deployment(
deploymentDef,
Expand Down Expand Up @@ -246,11 +246,12 @@ public DeploymentCreator setHealthCheckTimeoutS(Double healthCheckTimeoutS) {
return this;
}

public DeploymentLanguage getDeploymentLanguage() {
return deploymentLanguage;
public DeploymentLanguage getLanguage() {
return language;
}

public void setDeploymentLanguage(DeploymentLanguage deploymentLanguage) {
this.deploymentLanguage = deploymentLanguage;
public DeploymentCreator setLanguage(DeploymentLanguage language) {
this.language = language;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ public static synchronized void stop() {
scheduledExecutorService.shutdown();
}
inited = false;
LOGGER.info("LongPollClient was shopped.");
LOGGER.info("LongPollClient was stopped.");
}

public static boolean isInitialized() {
Expand Down
87 changes: 60 additions & 27 deletions java/serve/src/main/java/io/ray/serve/router/ReplicaSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,28 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorMethod;
import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.metric.TagKey;
import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants;
import io.ray.serve.deployment.Deployment;
import io.ray.serve.exception.RayServeException;
import io.ray.serve.generated.ActorNameList;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.metrics.RayServeMetrics;
import io.ray.serve.replica.RayServeWrappedReplica;
import io.ray.serve.util.CollectionUtil;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
Expand All @@ -31,7 +38,14 @@ public class ReplicaSet {

private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);

private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries;
// The key is the name of the actor, and the value is a set of all flight queries objectrefs of
// the actor.
private final Map<String, Set<ObjectRef<Object>>> inFlightQueries;

// Map the actor name to the handle of the actor.
private final Map<String, BaseActorHandle> allActorHandles;

private DeploymentLanguage language;

private AtomicInteger numQueuedQueries = new AtomicInteger();

Expand All @@ -41,6 +55,16 @@ public class ReplicaSet {

public ReplicaSet(String deploymentName) {
this.inFlightQueries = new ConcurrentHashMap<>();
this.allActorHandles = new ConcurrentHashMap<>();
try {
Deployment deployment = Serve.getDeployment(deploymentName);
this.language = deployment.getConfig().getDeploymentLanguage();
} catch (Exception e) {
LOGGER.warn(
"Failed to get language from controller. Set it to Java as default value. The exception is ",
e);
this.language = DeploymentLanguage.JAVA;
}
RayServeMetrics.execute(
() ->
this.numQueuedQueriesGauge =
Expand All @@ -54,26 +78,26 @@ public ReplicaSet(String deploymentName) {

@SuppressWarnings("unchecked")
public synchronized void updateWorkerReplicas(Object actorSet) {
List<String> actorNames = ((ActorNameList) actorSet).getNamesList();
Set<ActorHandle<RayServeWrappedReplica>> workerReplicas = new HashSet<>();
if (!CollectionUtil.isEmpty(actorNames)) {
actorNames.forEach(
name ->
workerReplicas.add(
(ActorHandle<RayServeWrappedReplica>)
Ray.getActor(name, Constants.SERVE_NAMESPACE).get()));
}

Set<ActorHandle<RayServeWrappedReplica>> added =
new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet()));
Set<ActorHandle<RayServeWrappedReplica>> removed =
new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas));

added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet()));
removed.forEach(inFlightQueries::remove);

if (added.size() > 0 || removed.size() > 0) {
LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
if (null != actorSet) {
Set<String> actorNameSet = new HashSet<>(((ActorNameList) actorSet).getNamesList());
Set<String> added = new HashSet<>(Sets.difference(actorNameSet, inFlightQueries.keySet()));
Set<String> removed = new HashSet<>(Sets.difference(inFlightQueries.keySet(), actorNameSet));
added.forEach(
name -> {
Optional<BaseActorHandle> handleOptional =
Ray.getActor(name, Constants.SERVE_NAMESPACE);
if (handleOptional.isPresent()) {
allActorHandles.put(name, handleOptional.get());
inFlightQueries.put(name, Sets.newConcurrentHashSet());
} else {
LOGGER.warn("Can not get actor handle. actor name is {}", name);
}
});
removed.forEach(inFlightQueries::remove);
removed.forEach(allActorHandles::remove);
if (added.size() > 0 || removed.size() > 0) {
LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
}
}
hasPullReplica = true;
}
Expand Down Expand Up @@ -121,20 +145,29 @@ private ObjectRef<Object> tryAssignReplica(Query query) {
}
loopCount++;
}
List<ActorHandle<RayServeWrappedReplica>> handles = new ArrayList<>(inFlightQueries.keySet());
List<BaseActorHandle> handles = new ArrayList<>(allActorHandles.values());
if (CollectionUtil.isEmpty(handles)) {
throw new RayServeException("ReplicaSet found no replica.");
}
int randomIndex = RandomUtils.nextInt(0, handles.size());
ActorHandle<RayServeWrappedReplica> replica =
BaseActorHandle replica =
handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries
LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
return replica
.task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
.remote();
if (language == DeploymentLanguage.PYTHON) {
return ((PyActorHandle) replica)
.task(
PyActorMethod.of("handle_request_from_java"),
query.getMetadata().toByteArray(),
query.getArgs())
.remote();
} else {
return ((ActorHandle<RayServeWrappedReplica>) replica)
.task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
.remote();
}
}

public Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> getInFlightQueries() {
public Map<String, Set<ObjectRef<Object>>> getInFlightQueries() {
return inFlightQueries;
}
}
20 changes: 20 additions & 0 deletions java/serve/src/main/resources/test_python_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This file is used by CrossLanguageDeploymentTest.java to test cross-language
# invocation.
from ray import serve


def echo_server(v):
return v


@serve.deployment
class Counter(object):
def __init__(self, value):
self.value = int(value)

def increase(self, delta):
self.value += int(delta)
return str(self.value)

def reconfigure(self, value_str):
self.value = int(value_str)
4 changes: 0 additions & 4 deletions java/serve/src/test/java/io/ray/serve/BaseServeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import io.ray.serve.api.Serve;
import io.ray.serve.api.ServeControllerClient;
import io.ray.serve.config.RayServeConfig;
import io.ray.serve.poll.LongPollClientFactory;
import java.lang.reflect.Method;
import java.util.Map;
import org.slf4j.Logger;
Expand Down Expand Up @@ -38,8 +37,5 @@ public void tearDownBase() {
} catch (Exception e) {
LOGGER.error("ray shutdown error", e);
}
LongPollClientFactory.stop();
LongPollClientFactory.clearAllCache();
Serve.setInternalReplicaContext(null);
}
}
4 changes: 4 additions & 0 deletions java/serve/src/test/java/io/ray/serve/BaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.ray.api.Ray;
import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants;
import io.ray.serve.poll.LongPollClientFactory;

public class BaseTest {

Expand All @@ -18,6 +19,9 @@ protected void init() {
}

protected void shutdown() {
LongPollClientFactory.stop();
LongPollClientFactory.clearAllCache();
Serve.setInternalReplicaContext(null);
if (!previousInited) {
Ray.shutdown();
}
Expand Down
Loading

0 comments on commit af488e1

Please sign in to comment.