Skip to content

Commit

Permalink
[client] Use application specific error code to propagate ray errors (r…
Browse files Browse the repository at this point in the history
…ay-project#18278)

* Raise decoded exception if generated by grpc lib

* Switch to missing client_id error to FAILED_PRECONDITION

* switch to ABORTED

* fix comment

* fix decode_exception comment
  • Loading branch information
ckw017 authored Sep 10, 2021
1 parent 3f89f35 commit 6f94d0f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion python/ray/util/client/server/proxier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _get_client_id_from_context(context: Any) -> str:
client_id = metadata.get("client_id") or ""
if client_id == "":
logger.error("Client connecting with no client_id")
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
return client_id


Expand Down
5 changes: 4 additions & 1 deletion python/ray/util/client/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,10 @@ def unify_and_track_outputs(self, output, client_id):
def return_exception_in_context(err, context):
if context is not None:
context.set_details(encode_exception(err))
context.set_code(grpc.StatusCode.INTERNAL)
# Note: https://grpc.github.io/grpc/core/md_doc_statuscodes.html
# ABORTED used here since it should never be generated by the
# grpc lib -- this way we know the error was generated by ray logic
context.set_code(grpc.StatusCode.ABORTED)


def encode_exception(exception) -> str:
Expand Down
23 changes: 15 additions & 8 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def connection_info(self):
try:
data = self.data_client.ConnectionInfo()
except grpc.RpcError as e:
raise decode_exception(e.details())
raise decode_exception(e)
return {
"num_clients": data.num_clients,
"python_version": data.python_version,
Expand Down Expand Up @@ -241,7 +241,7 @@ def _get(self, ref: List[ClientObjectRef], timeout: float):
try:
data = self.data_client.GetObject(req)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise decode_exception(e)
if not data.valid:
try:
err = cloudpickle.loads(data.error)
Expand Down Expand Up @@ -337,7 +337,7 @@ def _call_schedule_for_task(
try:
ticket = self.server.Schedule(task, metadata=self.metadata)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise decode_exception(e)

if not ticket.valid:
try:
Expand Down Expand Up @@ -427,7 +427,7 @@ def terminate_actor(self, actor: ClientActorHandle,
term.client_id = self._client_id
self.server.Terminate(term, metadata=self.metadata)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise decode_exception(e)

def terminate_task(self, obj: ClientObjectRef, force: bool,
recursive: bool) -> None:
Expand All @@ -444,7 +444,7 @@ def terminate_task(self, obj: ClientObjectRef, force: bool,
term.client_id = self._client_id
self.server.Terminate(term, metadata=self.metadata)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise decode_exception(e)

def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
req = ray_client_pb2.ClusterInfoRequest()
Expand Down Expand Up @@ -539,7 +539,7 @@ def _server_init(self,
f"Initialization failure from server:\n{response.msg}")

except grpc.RpcError as e:
raise decode_exception(e.details())
raise decode_exception(e)

def _convert_actor(self, actor: "ActorClass") -> str:
"""Register a ClientActorClass for the ActorClass and return a UUID"""
Expand Down Expand Up @@ -592,6 +592,13 @@ def make_client_id() -> str:
return id.hex


def decode_exception(data) -> Exception:
data = base64.standard_b64decode(data)
def decode_exception(e: grpc.RpcError) -> Exception:
if e.code() != grpc.StatusCode.ABORTED:
# The ABORTED status code is used by the server when an application
# error is serialized into the the exception details. If the code
# isn't ABORTED, then raise the original error since there's no
# serialized error to decode.
# See server.py::return_exception_in_context for details
raise
data = base64.standard_b64decode(e.details())
return loads_from_server(data)

0 comments on commit 6f94d0f

Please sign in to comment.