|
|
|
|
@@ -23,9 +23,9 @@ import threading
|
|
|
|
|
import time
|
|
|
|
|
import uuid
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from http import HTTPStatus
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from shlex import quote
|
|
|
|
|
from typing import Any, Callable, Generator, Optional, Union
|
|
|
|
|
from typing import Any, Callable, Generator, Mapping, Optional, Union
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
|
|
|
|
|
|
@@ -48,6 +48,94 @@ from ._typing import HTTP_METHOD_T
|
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class RateLimitInfo:
|
|
|
|
|
"""
|
|
|
|
|
Parsed rate limit information from HTTP response headers.
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
resource_type (`str`): The type of resource being rate limited.
|
|
|
|
|
remaining (`int`): The number of requests remaining in the current window.
|
|
|
|
|
reset_in_seconds (`int`): The number of seconds until the rate limit resets.
|
|
|
|
|
limit (`int`, *optional*): The maximum number of requests allowed in the current window.
|
|
|
|
|
window_seconds (`int`, *optional*): The number of seconds in the current window.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
resource_type: str
|
|
|
|
|
remaining: int
|
|
|
|
|
reset_in_seconds: int
|
|
|
|
|
limit: Optional[int] = None
|
|
|
|
|
window_seconds: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Regex patterns for parsing rate limit headers
|
|
|
|
|
# e.g.: "api";r=0;t=55 --> resource_type="api", r=0, t=55
|
|
|
|
|
_RATELIMIT_REGEX = re.compile(r"\"(?P<resource_type>\w+)\"\s*;\s*r\s*=\s*(?P<r>\d+)\s*;\s*t\s*=\s*(?P<t>\d+)")
|
|
|
|
|
# e.g.: "fixed window";"api";q=500;w=300 --> q=500, w=300
|
|
|
|
|
_RATELIMIT_POLICY_REGEX = re.compile(r"q\s*=\s*(?P<q>\d+).*?w\s*=\s*(?P<w>\d+)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_ratelimit_headers(headers: Mapping[str, str]) -> Optional[RateLimitInfo]:
|
|
|
|
|
"""Parse rate limit information from HTTP response headers.
|
|
|
|
|
|
|
|
|
|
Follows IETF draft: https://www.ietf.org/archive/id/draft-ietf-httpapi-ratelimit-headers-09.html
|
|
|
|
|
Only a subset is implemented.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
```python
|
|
|
|
|
>>> from huggingface_hub.utils import parse_ratelimit_headers
|
|
|
|
|
>>> headers = {
|
|
|
|
|
... "ratelimit": '"api";r=0;t=55',
|
|
|
|
|
... "ratelimit-policy": '"fixed window";"api";q=500;w=300',
|
|
|
|
|
... }
|
|
|
|
|
>>> info = parse_ratelimit_headers(headers)
|
|
|
|
|
>>> info.remaining
|
|
|
|
|
0
|
|
|
|
|
>>> info.reset_in_seconds
|
|
|
|
|
55
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
ratelimit: Optional[str] = None
|
|
|
|
|
policy: Optional[str] = None
|
|
|
|
|
for key in headers:
|
|
|
|
|
lower_key = key.lower()
|
|
|
|
|
if lower_key == "ratelimit":
|
|
|
|
|
ratelimit = headers[key]
|
|
|
|
|
elif lower_key == "ratelimit-policy":
|
|
|
|
|
policy = headers[key]
|
|
|
|
|
|
|
|
|
|
if not ratelimit:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
match = _RATELIMIT_REGEX.search(ratelimit)
|
|
|
|
|
if not match:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
resource_type = match.group("resource_type")
|
|
|
|
|
remaining = int(match.group("r"))
|
|
|
|
|
reset_in_seconds = int(match.group("t"))
|
|
|
|
|
|
|
|
|
|
limit: Optional[int] = None
|
|
|
|
|
window_seconds: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
if policy:
|
|
|
|
|
policy_match = _RATELIMIT_POLICY_REGEX.search(policy)
|
|
|
|
|
if policy_match:
|
|
|
|
|
limit = int(policy_match.group("q"))
|
|
|
|
|
window_seconds = int(policy_match.group("w"))
|
|
|
|
|
|
|
|
|
|
return RateLimitInfo(
|
|
|
|
|
resource_type=resource_type,
|
|
|
|
|
remaining=remaining,
|
|
|
|
|
reset_in_seconds=reset_in_seconds,
|
|
|
|
|
limit=limit,
|
|
|
|
|
window_seconds=window_seconds,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Both headers are used by the Hub to debug failed requests.
|
|
|
|
|
# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
|
|
|
|
|
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
|
|
|
|
|
@@ -79,7 +167,7 @@ def hf_request_event_hook(request: httpx.Request) -> None:
|
|
|
|
|
- Add a request ID to the request headers
|
|
|
|
|
- Log the request if debug mode is enabled
|
|
|
|
|
"""
|
|
|
|
|
if constants.HF_HUB_OFFLINE:
|
|
|
|
|
if constants.is_offline_mode():
|
|
|
|
|
raise OfflineModeIsEnabled(
|
|
|
|
|
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
|
|
|
|
|
)
|
|
|
|
|
@@ -249,6 +337,10 @@ if hasattr(os, "register_at_fork"):
|
|
|
|
|
os.register_at_fork(after_in_child=close_session)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_DEFAULT_RETRY_ON_EXCEPTIONS: tuple[type[Exception], ...] = (httpx.TimeoutException, httpx.NetworkError)
|
|
|
|
|
_DEFAULT_RETRY_ON_STATUS_CODES: tuple[int, ...] = (429, 500, 502, 503, 504)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _http_backoff_base(
|
|
|
|
|
method: HTTP_METHOD_T,
|
|
|
|
|
url: str,
|
|
|
|
|
@@ -256,11 +348,8 @@ def _http_backoff_base(
|
|
|
|
|
max_retries: int = 5,
|
|
|
|
|
base_wait_time: float = 1,
|
|
|
|
|
max_wait_time: float = 8,
|
|
|
|
|
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
|
|
|
|
|
httpx.TimeoutException,
|
|
|
|
|
httpx.NetworkError,
|
|
|
|
|
),
|
|
|
|
|
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
|
|
|
|
|
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
|
|
|
|
|
retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
|
|
|
|
|
stream: bool = False,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> Generator[httpx.Response, None, None]:
|
|
|
|
|
@@ -273,6 +362,7 @@ def _http_backoff_base(
|
|
|
|
|
|
|
|
|
|
nb_tries = 0
|
|
|
|
|
sleep_time = base_wait_time
|
|
|
|
|
ratelimit_reset: Optional[int] = None # seconds to wait for rate limit reset if 429 response
|
|
|
|
|
|
|
|
|
|
# If `data` is used and is a file object (or any IO), it will be consumed on the
|
|
|
|
|
# first HTTP request. We need to save the initial position so that the full content
|
|
|
|
|
@@ -284,6 +374,7 @@ def _http_backoff_base(
|
|
|
|
|
client = get_session()
|
|
|
|
|
while True:
|
|
|
|
|
nb_tries += 1
|
|
|
|
|
ratelimit_reset = None
|
|
|
|
|
try:
|
|
|
|
|
# If `data` is used and is a file object (or any IO), set back cursor to
|
|
|
|
|
# initial position.
|
|
|
|
|
@@ -293,6 +384,8 @@ def _http_backoff_base(
|
|
|
|
|
# Perform request and handle response
|
|
|
|
|
def _should_retry(response: httpx.Response) -> bool:
|
|
|
|
|
"""Handle response and return True if should retry, False if should return/yield."""
|
|
|
|
|
nonlocal ratelimit_reset
|
|
|
|
|
|
|
|
|
|
if response.status_code not in retry_on_status_codes:
|
|
|
|
|
return False # Success, don't retry
|
|
|
|
|
|
|
|
|
|
@@ -304,6 +397,12 @@ def _http_backoff_base(
|
|
|
|
|
# user ask for retry on a status code that doesn't raise_for_status.
|
|
|
|
|
return False # Don't retry, return/yield response
|
|
|
|
|
|
|
|
|
|
# get rate limit reset time from headers if 429 response
|
|
|
|
|
if response.status_code == 429:
|
|
|
|
|
ratelimit_info = parse_ratelimit_headers(response.headers)
|
|
|
|
|
if ratelimit_info is not None:
|
|
|
|
|
ratelimit_reset = ratelimit_info.reset_in_seconds
|
|
|
|
|
|
|
|
|
|
return True # Should retry
|
|
|
|
|
|
|
|
|
|
if stream:
|
|
|
|
|
@@ -326,9 +425,14 @@ def _http_backoff_base(
|
|
|
|
|
if nb_tries > max_retries:
|
|
|
|
|
raise err
|
|
|
|
|
|
|
|
|
|
# Sleep for X seconds
|
|
|
|
|
logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
|
|
|
|
|
time.sleep(sleep_time)
|
|
|
|
|
if ratelimit_reset is not None:
|
|
|
|
|
actual_sleep = float(ratelimit_reset) + 1 # +1s to avoid rounding issues
|
|
|
|
|
logger.warning(f"Rate limited. Waiting {actual_sleep}s before retry [Retry {nb_tries}/{max_retries}].")
|
|
|
|
|
else:
|
|
|
|
|
actual_sleep = sleep_time
|
|
|
|
|
logger.warning(f"Retrying in {actual_sleep}s [Retry {nb_tries}/{max_retries}].")
|
|
|
|
|
|
|
|
|
|
time.sleep(actual_sleep)
|
|
|
|
|
|
|
|
|
|
# Update sleep time for next retry
|
|
|
|
|
sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
|
|
|
|
|
@@ -341,11 +445,8 @@ def http_backoff(
|
|
|
|
|
max_retries: int = 5,
|
|
|
|
|
base_wait_time: float = 1,
|
|
|
|
|
max_wait_time: float = 8,
|
|
|
|
|
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
|
|
|
|
|
httpx.TimeoutException,
|
|
|
|
|
httpx.NetworkError,
|
|
|
|
|
),
|
|
|
|
|
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
|
|
|
|
|
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
|
|
|
|
|
retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> httpx.Response:
|
|
|
|
|
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
|
|
|
|
|
@@ -374,9 +475,9 @@ def http_backoff(
|
|
|
|
|
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
|
|
|
|
|
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
|
|
|
|
|
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
|
|
|
|
|
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
|
|
|
|
|
Define on which status codes the request must be retried. By default, only
|
|
|
|
|
HTTP 503 Service Unavailable is retried.
|
|
|
|
|
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `(429, 500, 502, 503, 504)`):
|
|
|
|
|
Define on which status codes the request must be retried. By default, retries
|
|
|
|
|
on rate limit (429) and server errors (5xx).
|
|
|
|
|
**kwargs (`dict`, *optional*):
|
|
|
|
|
kwargs to pass to `httpx.request`.
|
|
|
|
|
|
|
|
|
|
@@ -425,11 +526,8 @@ def http_stream_backoff(
|
|
|
|
|
max_retries: int = 5,
|
|
|
|
|
base_wait_time: float = 1,
|
|
|
|
|
max_wait_time: float = 8,
|
|
|
|
|
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
|
|
|
|
|
httpx.TimeoutException,
|
|
|
|
|
httpx.NetworkError,
|
|
|
|
|
),
|
|
|
|
|
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
|
|
|
|
|
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
|
|
|
|
|
retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> Generator[httpx.Response, None, None]:
|
|
|
|
|
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
|
|
|
|
|
@@ -457,10 +555,10 @@ def http_stream_backoff(
|
|
|
|
|
Maximum duration (in seconds) to wait before retrying.
|
|
|
|
|
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
|
|
|
|
|
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
|
|
|
|
|
By default, retry on `httpx.Timeout` and `httpx.NetworkError`.
|
|
|
|
|
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
|
|
|
|
|
Define on which status codes the request must be retried. By default, only
|
|
|
|
|
HTTP 503 Service Unavailable is retried.
|
|
|
|
|
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
|
|
|
|
|
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `(429, 500, 502, 503, 504)`):
|
|
|
|
|
Define on which status codes the request must be retried. By default, retries
|
|
|
|
|
on rate limit (429) and server errors (5xx).
|
|
|
|
|
**kwargs (`dict`, *optional*):
|
|
|
|
|
kwargs to pass to `httpx.request`.
|
|
|
|
|
|
|
|
|
|
@@ -549,6 +647,12 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
|
|
|
|
|
> - [`~utils.HfHubHTTPError`]
|
|
|
|
|
> If request failed for a reason not listed above.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
_warn_on_warning_headers(response)
|
|
|
|
|
except Exception:
|
|
|
|
|
# Never raise on warning parsing
|
|
|
|
|
logger.debug("Failed to parse warning headers", exc_info=True)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
except httpx.HTTPStatusError as e:
|
|
|
|
|
@@ -619,6 +723,25 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
|
|
|
|
|
)
|
|
|
|
|
raise _format(HfHubHTTPError, message, response) from e
|
|
|
|
|
|
|
|
|
|
elif response.status_code == 429:
|
|
|
|
|
ratelimit_info = parse_ratelimit_headers(response.headers)
|
|
|
|
|
if ratelimit_info is not None:
|
|
|
|
|
message = (
|
|
|
|
|
f"\n\n429 Too Many Requests: you have reached your '{ratelimit_info.resource_type}' rate limit."
|
|
|
|
|
)
|
|
|
|
|
message += f"\nRetry after {ratelimit_info.reset_in_seconds} seconds"
|
|
|
|
|
if ratelimit_info.limit is not None and ratelimit_info.window_seconds is not None:
|
|
|
|
|
message += (
|
|
|
|
|
f" ({ratelimit_info.remaining}/{ratelimit_info.limit} requests remaining"
|
|
|
|
|
f" in current {ratelimit_info.window_seconds}s window)."
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
message += "."
|
|
|
|
|
message += f"\nUrl: {response.url}."
|
|
|
|
|
else:
|
|
|
|
|
message = f"\n\n429 Too Many Requests for url: {response.url}."
|
|
|
|
|
raise _format(HfHubHTTPError, message, response) from e
|
|
|
|
|
|
|
|
|
|
elif response.status_code == 416:
|
|
|
|
|
range_header = response.request.headers.get("Range")
|
|
|
|
|
message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
|
|
|
|
|
@@ -629,6 +752,33 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
|
|
|
|
|
raise _format(HfHubHTTPError, str(e), response) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_WARNED_TOPICS = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _warn_on_warning_headers(response: httpx.Response) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Emit warnings if warning headers are present in the HTTP response.
|
|
|
|
|
|
|
|
|
|
Expected header format: 'X-HF-Warning: topic; message'
|
|
|
|
|
|
|
|
|
|
Only the first warning for each topic will be shown. Topic is optional and can be empty. Note that several warning
|
|
|
|
|
headers can be present in a single response.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
response (`httpx.Response`):
|
|
|
|
|
The HTTP response to check for warning headers.
|
|
|
|
|
"""
|
|
|
|
|
server_warnings = response.headers.get_list("X-HF-Warning")
|
|
|
|
|
for server_warning in server_warnings:
|
|
|
|
|
topic, message = server_warning.split(";", 1) if ";" in server_warning else ("", server_warning)
|
|
|
|
|
topic = topic.strip()
|
|
|
|
|
if topic not in _WARNED_TOPICS:
|
|
|
|
|
message = message.strip()
|
|
|
|
|
if message:
|
|
|
|
|
_WARNED_TOPICS.add(topic)
|
|
|
|
|
logger.warning(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError:
|
|
|
|
|
server_errors = []
|
|
|
|
|
|
|
|
|
|
|