69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
from typing import Optional, Set
|
|
import grpc
|
|
from tenacity import retry, stop_after_attempt, wait_exponential_jitter, retry_if_result
|
|
from opentelemetry.trace import Span
|
|
|
|
|
|
class RetryOnRpcErrorClientInterceptor(
    grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor
):
    """
    A gRPC client interceptor that retries RPCs on specific status codes. By default, it retries on UNAVAILABLE and UNKNOWN status codes.

    This interceptor should be placed after the OpenTelemetry interceptor in the interceptor list.

    Retries use exponential backoff with jitter and stop after
    ``max_attempts`` total attempts. Only the status code of the returned
    call is inspected when deciding whether to retry.
    """

    # NOTE(review): only the unary-request interceptor base classes are
    # inherited, yet intercept_stream_unary / intercept_stream_stream are
    # defined below. Confirm whether gRPC will ever dispatch to them without
    # the matching Stream*ClientInterceptor bases.

    # Maximum number of attempts (first call included) before giving up.
    max_attempts: int
    # Status codes for which a finished call is attempted again.
    retryable_status_codes: Set[grpc.StatusCode]

    def __init__(
        self,
        max_attempts: int = 5,
        retryable_status_codes: Optional[Set[grpc.StatusCode]] = None,
    ) -> None:
        """
        Configure the retry policy.

        Args:
            max_attempts: Total number of attempts made for one RPC.
            retryable_status_codes: Status codes that trigger a retry. When
                omitted (or ``None``), UNAVAILABLE and UNKNOWN are used. A set
                passed explicitly is stored as given (not copied).
        """
        self.max_attempts = max_attempts
        if retryable_status_codes is None:
            # Build a fresh set per instance: a set created once in the
            # signature would be shared by every interceptor that relies on
            # the default, so mutating one instance's attribute would change
            # the retry policy of all the others.
            retryable_status_codes = {
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.UNKNOWN,
            }
        self.retryable_status_codes = retryable_status_codes

    def _intercept_call(self, continuation, client_call_details, request_or_iterator):
        """
        Invoke ``continuation`` and retry while it yields a retryable status.

        Args:
            continuation: Callable that performs the RPC; it is invoked as
                ``continuation(client_call_details, request_or_iterator)``.
            client_call_details: Call details forwarded unchanged.
            request_or_iterator: Request (or request iterator) forwarded
                unchanged on every attempt.

        Returns:
            The object returned by the last ``continuation`` invocation.

        NOTE(review): with this tenacity configuration (no ``reraise`` and no
        ``retry_error_callback``), exhausting ``max_attempts`` while the
        result is still retryable is expected to raise
        ``tenacity.RetryError`` rather than return the last call -- confirm
        that callers expect this.
        """
        # Span covering the back-off wait between two attempts. It is opened
        # just before sleeping and closed when the next attempt starts.
        sleep_span: Optional[Span] = None

        def before_sleep(_):
            # Imported at call time rather than at module import.
            from chromadb.telemetry.opentelemetry import tracer

            nonlocal sleep_span
            # The tracer may be unset; in that case waiting is not traced.
            if tracer is not None:
                sleep_span = tracer.start_span("Waiting to retry RPC")

        @retry(
            # Exponential back-off starting at 0.1s with up to 0.1s of jitter.
            wait=wait_exponential_jitter(0.1, jitter=0.1),
            stop=stop_after_attempt(self.max_attempts),
            # Retry only on the status code of the returned call object.
            retry=retry_if_result(lambda x: x.code() in self.retryable_status_codes),
            before_sleep=before_sleep,
        )
        def wrapped(*args, **kwargs):
            nonlocal sleep_span
            # Close the span opened by before_sleep for the preceding wait.
            if sleep_span is not None:
                sleep_span.end()
                sleep_span = None
            return continuation(*args, **kwargs)

        return wrapped(client_call_details, request_or_iterator)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Apply the retry policy to a unary-unary RPC."""
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Apply the retry policy to a unary-stream RPC."""
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        """Apply the retry policy to a stream-unary RPC."""
        # The same request_iterator object is handed to every attempt.
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        """Apply the retry policy to a stream-stream RPC."""
        # The same request_iterator object is handed to every attempt.
        return self._intercept_call(continuation, client_call_details, request_iterator)
|