Skip to content

Commit

Permalink
feat: capture shard failures in the head runtime (jina-ai#5338)
Browse files Browse the repository at this point in the history
  • Loading branch information
girishc13 authored Nov 8, 2022
1 parent bef22c8 commit db1c406
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 104 deletions.
90 changes: 15 additions & 75 deletions jina/serve/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from jina import __default_endpoint__
from jina.enums import PollingType
from jina.excepts import EstablishGrpcConnectionError
from jina.excepts import EstablishGrpcConnectionError, InternalNetworkError
from jina.importer import ImportExtensions
from jina.logging.logger import JinaLogger
from jina.proto import jina_pb2, jina_pb2_grpc
Expand All @@ -26,7 +26,7 @@

from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
from grpc.aio._interceptor import ClientInterceptor
from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
Expand Down Expand Up @@ -755,42 +755,6 @@ def __init__(
)
self._deployment_address_map = {}

def send_request(
    self,
    request: Request,
    deployment: str,
    head: bool = False,
    shard_id: Optional[int] = None,
    polling_type: PollingType = PollingType.ANY,
    endpoint: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    timeout: Optional[float] = None,
    retries: Optional[int] = -1,
) -> List[asyncio.Task]:
    """Send a single request to a deployment by delegating to `send_requests`.

    The request is wrapped in a one-element list; every other argument is forwarded unchanged.

    :param request: a single request to send
    :param deployment: name of the Jina deployment to send the message to
    :param head: If True it is sent to the head, otherwise to the worker pods
    :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
    :param polling_type: defines if the message should be sent to any or all pooled connections for the target
    :param endpoint: endpoint to target with the request
    :param metadata: metadata to send with the request
    :param timeout: timeout for sending the requests
    :param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
    :return: list of asyncio.Task items for each send call
    """
    # Everything except the request itself is passed through verbatim.
    forwarded = dict(
        deployment=deployment,
        head=head,
        shard_id=shard_id,
        polling_type=polling_type,
        endpoint=endpoint,
        metadata=metadata,
        timeout=timeout,
        retries=retries,
    )
    return self.send_requests(requests=[request], **forwarded)

def send_requests(
self,
requests: List[Request],
Expand Down Expand Up @@ -872,36 +836,6 @@ def send_discover_endpoint(
)
return None

def send_request_once(
    self,
    request: Request,
    deployment: str,
    metadata: Optional[Dict[str, str]] = None,
    head: bool = False,
    shard_id: Optional[int] = None,
    timeout: Optional[float] = None,
    retries: Optional[int] = -1,
) -> asyncio.Task:
    """Send a single request via only one pooled connection by delegating to `send_requests_once`.

    The request is wrapped in a one-element list; every other argument is forwarded unchanged.

    :param request: request to send
    :param deployment: name of the Jina deployment to send the message to
    :param metadata: metadata to send with the request
    :param head: If True it is sent to the head, otherwise to the worker pods
    :param shard_id: Send to a specific shard of the deployment
    :param timeout: timeout for sending the requests
    :param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
    :return: asyncio.Task representing the send call
    """
    # The request list stays positional, exactly as the delegate was called before.
    forwarded = dict(
        deployment=deployment,
        metadata=metadata,
        head=head,
        shard_id=shard_id,
        timeout=timeout,
        retries=retries,
    )
    return self.send_requests_once([request], **forwarded)

def send_requests_once(
self,
requests: List[Request],
Expand All @@ -927,14 +861,15 @@ def send_requests_once(
"""
replicas = self._connections.get_replicas(deployment, head, shard_id)
if replicas:
return self._send_requests(
result = self._send_requests(
requests,
replicas,
endpoint=endpoint,
metadata=metadata,
timeout=timeout,
retries=retries,
)
return result
else:
self._logger.debug(
f'no available connections for deployment {deployment} and shard {shard_id}'
Expand Down Expand Up @@ -1005,7 +940,7 @@ async def _handle_aiorpcerror(
current_address: str = '', # the specific address that was contacted during this attempt
current_deployment: str = '', # the specific deployment that was contacted during this attempt
connection_list: Optional[ReplicaList] = None,
):
) -> 'Optional[Union[AioRpcError, InternalNetworkError]]':
# connection failures, cancelled requests, and timed out requests should be retried
# all other cases should not be retried; the error is returned to the caller immediately
# connection failures have the code grpc.StatusCode.UNAVAILABLE
Expand All @@ -1018,7 +953,7 @@ async def _handle_aiorpcerror(
and error.code() != grpc.StatusCode.CANCELLED
and error.code() != grpc.StatusCode.DEADLINE_EXCEEDED
):
raise
return error
elif (
error.code() == grpc.StatusCode.UNAVAILABLE
or error.code() == grpc.StatusCode.DEADLINE_EXCEEDED
Expand All @@ -1031,7 +966,7 @@ async def _handle_aiorpcerror(
if connection_list:
await connection_list.reset_connection(current_address, current_deployment)

raise InternalNetworkError(
return InternalNetworkError(
og_exception=error,
request_id=request_id,
dest_addr=tried_addresses,
Expand All @@ -1042,6 +977,7 @@ async def _handle_aiorpcerror(
f'GRPC call failed with code {error.code()}, retry attempt {retry_i + 1}/{total_num_tries - 1}.'
f' Trying next replica, if available.'
)
return None

def _send_requests(
self,
Expand All @@ -1051,7 +987,7 @@ def _send_requests(
metadata: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
) -> asyncio.Task:
) -> 'asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]':
# this wraps the awaitable object from grpc as a coroutine so it can be used as a task
# the grpc call function is not a coroutine but some _AioCall

Expand Down Expand Up @@ -1084,7 +1020,7 @@ async def task_wrapper():
timeout=timeout,
)
except AioRpcError as e:
await self._handle_aiorpcerror(
error = await self._handle_aiorpcerror(
error=e,
retry_i=i,
request_id=requests[0].request_id,
Expand All @@ -1094,6 +1030,8 @@ async def task_wrapper():
current_deployment=current_connection.deployment_name,
connection_list=connections,
)
if error:
return error

return asyncio.create_task(task_wrapper())

Expand Down Expand Up @@ -1128,7 +1066,7 @@ async def task_wrapper():
timeout=timeout,
)
except AioRpcError as e:
await self._handle_aiorpcerror(
error = await self._handle_aiorpcerror(
error=e,
retry_i=i,
tried_addresses=tried_addresses,
Expand All @@ -1137,6 +1075,8 @@ async def task_wrapper():
connection_list=connection_list,
total_num_tries=total_num_tries,
)
if error:
raise error
except AttributeError:
return default_endpoints_proto, None

Expand Down
11 changes: 9 additions & 2 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from typing import Dict, List, Optional, Tuple

import grpc.aio
from grpc.aio import AioRpcError

from jina import __default_endpoint__
from jina.excepts import InternalNetworkError
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.helper import _parse_specific_params
from jina.serve.runtimes.request_handlers.worker_request_handler import WorkerRequestHandler
from jina.serve.runtimes.request_handlers.worker_request_handler import (
WorkerRequestHandler,
)
from jina.types.request.data import DataRequest


Expand Down Expand Up @@ -157,7 +160,7 @@ async def _wait_previous_and_send(
return request, metadata
# otherwise, send to executor and get response
try:
resp, metadata = await connection_pool.send_requests_once(
result = await connection_pool.send_requests_once(
requests=self.parts_to_send,
deployment=self.name,
metadata=self._metadata,
Expand All @@ -166,6 +169,10 @@ async def _wait_previous_and_send(
timeout=self._timeout_send,
retries=self._retries,
)
if issubclass(type(result), BaseException):
raise result
else:
resp, metadata = result
if WorkerRequestHandler._KEY_RESULT in resp.parameters:
# Accumulate results from each Node and then add them to the original
self.result_in_params_returned = resp.parameters[
Expand Down
94 changes: 69 additions & 25 deletions jina/serve/runtimes/head/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import argparse
import asyncio
import contextlib
import json
import os
from abc import ABC
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import grpc
from grpc.aio import AioRpcError
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection

Expand All @@ -20,7 +20,9 @@
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime
from jina.serve.runtimes.helper import _get_grpc_server_options
from jina.serve.runtimes.request_handlers.worker_request_handler import WorkerRequestHandler
from jina.serve.runtimes.request_handlers.worker_request_handler import (
WorkerRequestHandler,
)
from jina.types.request.data import DataRequest, Response


Expand Down Expand Up @@ -302,6 +304,32 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:

return response

async def _gather_worker_tasks(self, requests, endpoint):
    """Send the requests to the workers and split the outcomes into successes and failures.

    The send tasks come from `self.connection_pool.send_requests` and are awaited together.
    Each awaited outcome that is a tuple is kept as a worker result; each outcome that is an
    exception instance is collected as a failure. Any other outcome is dropped from both lists.

    :param requests: requests forwarded to the connection pool
    :param endpoint: key into `self._polling` that selects the polling type for this call
    :return: tuple of (worker results, exceptions, total number of send tasks, number of failures)
    """
    worker_send_tasks = self.connection_pool.send_requests(
        requests=requests,
        deployment=self._deployment_name,
        polling_type=self._polling[endpoint],
        timeout=self.timeout_send,
        retries=self._retries,
    )

    all_worker_results = await asyncio.gather(*worker_send_tasks)

    # Failures arrive here as returned exception objects rather than raised ones,
    # so they are separated from the tuple-shaped successful results by type.
    worker_results = [
        outcome for outcome in all_worker_results if isinstance(outcome, tuple)
    ]
    exceptions = [
        outcome for outcome in all_worker_results if isinstance(outcome, BaseException)
    ]

    total_shards = len(worker_send_tasks)
    failed_shards = len(exceptions)
    if failed_shards:
        self.logger.warning(f'{failed_shards} shards out of {total_shards} failed.')

    return worker_results, exceptions, total_shards, failed_shards

async def _handle_data_request(
self, requests: List[DataRequest], endpoint: Optional[str]
) -> Tuple[DataRequest, Dict]:
Expand All @@ -311,28 +339,29 @@ async def _handle_data_request(

uses_before_metadata = None
if self.uses_before_address:
(
response,
uses_before_metadata,
) = await self.connection_pool.send_requests_once(
result = await self.connection_pool.send_requests_once(
requests,
deployment='uses_before',
timeout=self.timeout_send,
retries=self._retries,
)
requests = [response]

worker_send_tasks = self.connection_pool.send_requests(
requests=requests,
deployment=self._deployment_name,
polling_type=self._polling[endpoint],
timeout=self.timeout_send,
retries=self._retries,
)

worker_results = await asyncio.gather(*worker_send_tasks)
if issubclass(type(result), BaseException):
raise result
else:
response, uses_before_metadata = result
requests = [response]

(
worker_results,
exceptions,
total_shards,
failed_shards,
) = await self._gather_worker_tasks(requests, endpoint)

if len(worker_results) == 0:
if exceptions:
# raise the underlying error first
raise exceptions[0]
raise RuntimeError(
f'Head {self.name} did not receive a response when sending message to worker pods'
)
Expand All @@ -342,31 +371,43 @@ async def _handle_data_request(
response_request = worker_results[0]
uses_after_metadata = None
if self.uses_after_address:
(
response_request,
uses_after_metadata,
) = await self.connection_pool.send_requests_once(
result = await self.connection_pool.send_requests_once(
worker_results,
deployment='uses_after',
timeout=self.timeout_send,
retries=self._retries,
)
if issubclass(type(result), BaseException):
raise result
else:
response_request, uses_after_metadata = result
elif len(worker_results) > 1 and self._reduce:
WorkerRequestHandler.reduce_requests(worker_results)
response_request = WorkerRequestHandler.reduce_requests(worker_results)
elif len(worker_results) > 1 and not self._reduce:
# worker returned multiple responsed, but the head is configured to skip reduction
# worker returned multiple responses, but the head is configured to skip reduction
# just concatenate the docs in this case
response_request.data.docs = WorkerRequestHandler.get_docs_from_request(
requests, field='docs'
)

merged_metadata = self._merge_metadata(
metadata, uses_after_metadata, uses_before_metadata
metadata,
uses_after_metadata,
uses_before_metadata,
total_shards,
failed_shards,
)

return response_request, merged_metadata

def _merge_metadata(self, metadata, uses_after_metadata, uses_before_metadata):
def _merge_metadata(
self,
metadata,
uses_after_metadata,
uses_before_metadata,
total_shards,
failed_shards,
):
merged_metadata = {}
if uses_before_metadata:
for key, value in uses_before_metadata:
Expand All @@ -377,6 +418,9 @@ def _merge_metadata(self, metadata, uses_after_metadata, uses_before_metadata):
if uses_after_metadata:
for key, value in uses_after_metadata:
merged_metadata[key] = value

merged_metadata['total_shards'] = str(total_shards)
merged_metadata['failed_shards'] = str(failed_shards)
return merged_metadata

async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
Expand Down
Loading

0 comments on commit db1c406

Please sign in to comment.