Add adaptation for the surround-reconnaissance scenario

2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -46,7 +46,7 @@ import sys
from typing import TYPE_CHECKING
__version__ = "1.1.4"
__version__ = "1.2.4"
# Alphabetical order of definitions is ensured in tests
# WARNING: any comment added in this dictionary definition will be lost when
@@ -134,6 +134,7 @@ _SUBMOD_ATTRS = {
"REPO_TYPE_SPACE",
"TF2_WEIGHTS_NAME",
"TF_WEIGHTS_NAME",
"is_offline_mode",
],
"fastai_utils": [
"_save_pretrained_fastai",
@@ -164,6 +165,8 @@ _SUBMOD_ATTRS = {
"HfApi",
"ModelInfo",
"Organization",
"RepoFile",
"RepoFolder",
"RepoUrl",
"SpaceInfo",
"User",
@@ -230,6 +233,7 @@ _SUBMOD_ATTRS = {
"inspect_scheduled_job",
"list_accepted_access_requests",
"list_collections",
"list_daily_papers",
"list_datasets",
"list_inference_catalog",
"list_inference_endpoints",
@@ -296,6 +300,7 @@ _SUBMOD_ATTRS = {
"HfFileSystemFile",
"HfFileSystemResolvedPath",
"HfFileSystemStreamFile",
"hffs",
],
"hub_mixin": [
"ModelHubMixin",
@@ -708,6 +713,8 @@ __all__ = [
"REPO_TYPE_MODEL",
"REPO_TYPE_SPACE",
"RepoCard",
"RepoFile",
"RepoFolder",
"RepoUrl",
"SentenceSimilarityInput",
"SentenceSimilarityInputData",
@@ -883,11 +890,14 @@ __all__ = [
"hf_hub_download",
"hf_hub_url",
"hf_raise_for_status",
"hffs",
"inspect_job",
"inspect_scheduled_job",
"interpreter_login",
"is_offline_mode",
"list_accepted_access_requests",
"list_collections",
"list_daily_papers",
"list_datasets",
"list_inference_catalog",
"list_inference_endpoints",
@@ -1150,6 +1160,7 @@ if TYPE_CHECKING: # pragma: no cover
REPO_TYPE_SPACE, # noqa: F401
TF2_WEIGHTS_NAME, # noqa: F401
TF_WEIGHTS_NAME, # noqa: F401
is_offline_mode, # noqa: F401
)
from .fastai_utils import (
_save_pretrained_fastai, # noqa: F401
@@ -1180,6 +1191,8 @@ if TYPE_CHECKING: # pragma: no cover
HfApi, # noqa: F401
ModelInfo, # noqa: F401
Organization, # noqa: F401
RepoFile, # noqa: F401
RepoFolder, # noqa: F401
RepoUrl, # noqa: F401
SpaceInfo, # noqa: F401
User, # noqa: F401
@@ -1246,6 +1259,7 @@ if TYPE_CHECKING: # pragma: no cover
inspect_scheduled_job, # noqa: F401
list_accepted_access_requests, # noqa: F401
list_collections, # noqa: F401
list_daily_papers, # noqa: F401
list_datasets, # noqa: F401
list_inference_catalog, # noqa: F401
list_inference_endpoints, # noqa: F401
@@ -1312,6 +1326,7 @@ if TYPE_CHECKING: # pragma: no cover
HfFileSystemFile, # noqa: F401
HfFileSystemResolvedPath, # noqa: F401
HfFileSystemStreamFile, # noqa: F401
hffs, # noqa: F401
)
from .hub_mixin import (
ModelHubMixin, # noqa: F401

View File

@@ -27,6 +27,7 @@ from .utils import (
fetch_xet_connection_info_from_repo_info,
get_session,
hf_raise_for_status,
http_backoff,
logging,
sha,
tqdm_stream_file,
@@ -739,7 +740,8 @@ def _fetch_upload_modes(
if gitignore_content is not None:
payload["gitIgnore"] = gitignore_content
resp = get_session().post(
resp = http_backoff(
"POST",
f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
json=payload,
headers=headers,
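For context, a minimal sketch of what swapping `get_session().post` for `http_backoff` buys: the same call, but with exponential backoff and automatic retries on transient failures (429/5xx responses, timeouts, network errors by default). The URL below is illustrative only.

```python
from huggingface_hub.utils import http_backoff

# Retried automatically on transient failures before the response is returned.
response = http_backoff("GET", "https://huggingface.co/api/models/gpt2")
response.raise_for_status()
```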

View File

@@ -34,6 +34,11 @@ class InferenceEndpointType(str, Enum):
PRIVATE = "private"
class InferenceEndpointScalingMetric(str, Enum):
PENDING_REQUESTS = "pendingRequests"
HARDWARE_USAGE = "hardwareUsage"
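Since the enum subclasses `str`, either the member or its raw value is accepted wherever a scaling metric is expected; a quick sketch (import path taken from this diff):

```python
from huggingface_hub._inference_endpoints import InferenceEndpointScalingMetric

# str-based enum: each member compares equal to its raw string value.
assert InferenceEndpointScalingMetric.HARDWARE_USAGE == "hardwareUsage"
assert InferenceEndpointScalingMetric.PENDING_REQUESTS == "pendingRequests"
```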
@dataclass
class InferenceEndpoint:
"""

View File

@@ -308,6 +308,7 @@ def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDown
paths.metadata_path.unlink()
except Exception as e:
logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}")
return None
try:
# check if the file exists and hasn't been modified since the metadata was saved
@@ -383,6 +384,9 @@ def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetad
except Exception as e:
logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}")
# corrupted metadata => we don't know anything except its size
return LocalUploadFileMetadata(size=paths.file_path.stat().st_size)
# TODO: can we do better?
if (
metadata.timestamp is not None

View File

@@ -192,7 +192,11 @@ def auth_list() -> None:
tokens = get_stored_tokens()
if not tokens:
logger.info("No access tokens found.")
if _get_token_from_environment():
logger.info("No stored access tokens found.")
logger.warning("Note: Environment variable `HF_TOKEN` is set and is the current active token.")
else:
logger.info("No access tokens found.")
return
# Find current token
current_token = get_token()

View File

@@ -23,7 +23,7 @@ from .utils import tqdm as hf_tqdm
logger = logging.get_logger(__name__)
VERY_LARGE_REPO_THRESHOLD = 50000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough
LARGE_REPO_THRESHOLD = 1000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough
@overload
@@ -335,9 +335,7 @@ def snapshot_download(
# In that case, we need to use the `list_repo_tree` method to prevent caching issues.
repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else []
unreliable_nb_files = (
repo_info.siblings is None
or len(repo_info.siblings) == 0
or len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
repo_info.siblings is None or len(repo_info.siblings) == 0 or len(repo_info.siblings) > LARGE_REPO_THRESHOLD
)
if unreliable_nb_files:
logger.info(

View File

@@ -191,7 +191,7 @@ def upload_large_folder_internal(
if num_workers is None:
nb_cores = os.cpu_count() or 1
num_workers = max(nb_cores - 2, 2) # Use all but 2 cores, or at least 2 cores
num_workers = max(nb_cores // 2, 1) # Use at most half of cpu cores
# 2. Create repo if missing
repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True)
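A worked example of the new default, assuming `os.cpu_count()` reports the machine's core count:

```python
import os

nb_cores = os.cpu_count() or 1       # e.g. 8 cores
num_workers = max(nb_cores // 2, 1)  # -> 4 workers; a single-core machine still gets 1
```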

View File

@@ -96,7 +96,7 @@ TokenOpt = Annotated[
]
PrivateOpt = Annotated[
bool,
Optional[bool],
typer.Option(
help="Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already exists.",
),
@@ -144,6 +144,7 @@ def _check_cli_update() -> None:
return
# Touch the file to mark that we did the check now
Path(constants.CHECK_FOR_UPDATE_DONE_PATH).parent.mkdir(parents=True, exist_ok=True)
Path(constants.CHECK_FOR_UPDATE_DONE_PATH).touch()
# Check latest version from PyPI

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from huggingface_hub import constants
from huggingface_hub.cli._cli_utils import check_cli_update, typer_factory
from huggingface_hub.cli.auth import auth_cli
from huggingface_hub.cli.cache import cache_cli
@@ -51,7 +52,8 @@ app.add_typer(ie_cli, name="endpoints")
def main():
logging.set_verbosity_info()
if not constants.HF_DEBUG:
logging.set_verbosity_info()
check_cli_update()
app()

View File

@@ -5,7 +5,7 @@ from typing import Annotated, Optional
import typer
from huggingface_hub._inference_endpoints import InferenceEndpoint
from huggingface_hub._inference_endpoints import InferenceEndpoint, InferenceEndpointScalingMetric
from huggingface_hub.errors import HfHubHTTPError
from ._cli_utils import TokenOpt, get_hf_api, typer_factory
@@ -112,6 +112,36 @@ def deploy(
),
] = None,
token: TokenOpt = None,
min_replica: Annotated[
int,
typer.Option(
help="The minimum number of replicas (instances) to keep running for the Inference Endpoint.",
),
] = 1,
max_replica: Annotated[
int,
typer.Option(
help="The maximum number of replicas (instances) to scale to for the Inference Endpoint.",
),
] = 1,
scale_to_zero_timeout: Annotated[
Optional[int],
typer.Option(
help="The duration in minutes before an inactive endpoint is scaled to zero.",
),
] = None,
scaling_metric: Annotated[
Optional[InferenceEndpointScalingMetric],
typer.Option(
help="The metric reference for scaling.",
),
] = None,
scaling_threshold: Annotated[
Optional[float],
typer.Option(
help="The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.",
),
] = None,
) -> None:
"""Deploy an Inference Endpoint from a Hub repository."""
api = get_hf_api(token=token)
@@ -127,6 +157,11 @@ def deploy(
namespace=namespace,
task=task,
token=token,
min_replica=min_replica,
max_replica=max_replica,
scaling_metric=scaling_metric,
scaling_threshold=scaling_threshold,
scale_to_zero_timeout=scale_to_zero_timeout,
)
_print_endpoint(endpoint)
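A hedged sketch of the `HfApi` call these new CLI options feed into (endpoint name, repository, and instance values below are hypothetical; see the matching `hf_api` changes further down):

```python
from huggingface_hub import HfApi

api = HfApi()
endpoint = api.create_inference_endpoint(
    "my-endpoint",                         # hypothetical name
    repository="openai/whisper-large-v3",  # hypothetical model repo
    framework="pytorch",
    accelerator="gpu",
    instance_size="x1",                    # hypothetical instance values
    instance_type="nvidia-a10g",
    region="us-east-1",
    vendor="aws",
    min_replica=1,
    max_replica=4,
    scale_to_zero_timeout=15,
    scaling_metric="hardwareUsage",        # or InferenceEndpointScalingMetric.HARDWARE_USAGE
    scaling_threshold=0.8,
)
```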
@@ -262,6 +297,18 @@ def update(
help="The duration in minutes before an inactive endpoint is scaled to zero.",
),
] = None,
scaling_metric: Annotated[
Optional[InferenceEndpointScalingMetric],
typer.Option(
help="The metric reference for scaling.",
),
] = None,
scaling_threshold: Annotated[
Optional[float],
typer.Option(
help="The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.",
),
] = None,
token: TokenOpt = None,
) -> None:
"""Update an existing endpoint."""
@@ -280,6 +327,8 @@ def update(
min_replica=min_replica,
max_replica=max_replica,
scale_to_zero_timeout=scale_to_zero_timeout,
scaling_metric=scaling_metric,
scaling_threshold=scaling_threshold,
token=token,
)
except HfHubHTTPError as error:

View File

@@ -66,7 +66,7 @@ def repo_create(
help="Hugging Face Spaces SDK type. Required when --type is set to 'space'.",
),
] = None,
private: PrivateOpt = False,
private: PrivateOpt = None,
token: TokenOpt = None,
exist_ok: Annotated[
bool,

View File

@@ -80,7 +80,7 @@ def upload(
] = None,
repo_type: RepoTypeOpt = RepoType.model,
revision: RevisionOpt = None,
private: PrivateOpt = False,
private: PrivateOpt = None,
include: Annotated[
Optional[list[str]],
typer.Option(

View File

@@ -38,7 +38,7 @@ def upload_large_folder(
],
repo_type: RepoTypeOpt = RepoType.model,
revision: RevisionOpt = None,
private: PrivateOpt = False,
private: PrivateOpt = None,
include: Annotated[
Optional[list[str]],
typer.Option(

View File

@@ -162,6 +162,26 @@ HF_ASSETS_CACHE = os.path.expandvars(
HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE"))
def is_offline_mode() -> bool:
"""Returns whether we are in offline mode for the Hub.
When offline mode is enabled, all HTTP requests made with `get_session` will raise an `OfflineModeIsEnabled` exception.
Example:
```py
from huggingface_hub import is_offline_mode
def list_files(repo_id: str):
if is_offline_mode():
... # list files from local cache (degraded experience but still functional)
else:
... # list files from Hub (complete experience)
```
"""
return HF_HUB_OFFLINE
# File created to mark that the version check has been done.
# Check is performed once per 24 hours at most.
CHECK_FOR_UPDATE_DONE_PATH = os.path.join(HF_HOME, ".check_for_update_done")

View File

@@ -316,7 +316,7 @@ def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
base, *meta = get_args(value)
if not _is_required_or_notrequired(base):
base = NotRequired[base]
type_hints[key] = Annotated[tuple([base] + list(meta))]
type_hints[key] = Annotated[tuple([base] + list(meta))] # type: ignore
elif not _is_required_or_notrequired(value):
type_hints[key] = NotRequired[value]

View File

@@ -84,9 +84,15 @@ class HfHubHTTPError(HTTPError, OSError):
"""Append additional information to the `HfHubHTTPError` initial message."""
self.args = (self.args[0] + additional_message,) + self.args[1:]
@classmethod
def _reconstruct_hf_hub_http_error(
cls, message: str, response: Response, server_message: Optional[str]
) -> "HfHubHTTPError":
return cls(message, response=response, server_message=server_message)
def __reduce_ex__(self, protocol):
"""Fix pickling of Exception subclass with kwargs. We need to override __reduce_ex__ of the parent class"""
return (self.__class__, (str(self),), {"response": self.response, "server_message": self.server_message})
return (self.__class__._reconstruct_hf_hub_http_error, (str(self), self.response, self.server_message))
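A hedged sketch of the behavior this classmethod restores: the exception now survives a pickle round-trip, e.g. when raised inside a multiprocessing worker.

```python
import pickle

from huggingface_hub.errors import HfHubHTTPError

err = HfHubHTTPError("request failed", response=None, server_message="server said no")
restored = pickle.loads(pickle.dumps(err))
assert restored.server_message == "server said no"
```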
# INFERENCE CLIENT ERRORS

View File

@@ -39,7 +39,13 @@ from .utils import (
tqdm,
validate_hf_hub_args,
)
from .utils._http import _adjust_range_header, http_backoff, http_stream_backoff
from .utils._http import (
_DEFAULT_RETRY_ON_EXCEPTIONS,
_DEFAULT_RETRY_ON_STATUS_CODES,
_adjust_range_header,
http_backoff,
http_stream_backoff,
)
from .utils._runtime import is_xet_available
from .utils._typing import HTTP_METHOD_T
from .utils.sha import sha_fileobj
@@ -63,6 +69,9 @@ REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$")
_are_symlinks_supported_in_dir: dict[str, bool] = {}
# Internal retry timeout for metadata fetch when no local file exists
_ETAG_RETRY_TIMEOUT = 60
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
"""Return whether the symlinks are supported on the machine.
@@ -264,30 +273,38 @@ def hf_hub_url(
return url
def _httpx_follow_relative_redirects(method: HTTP_METHOD_T, url: str, **httpx_kwargs) -> httpx.Response:
def _httpx_follow_relative_redirects(
method: HTTP_METHOD_T, url: str, *, retry_on_errors: bool = False, **httpx_kwargs
) -> httpx.Response:
"""Perform an HTTP request with backoff and follow relative redirects only.
This is useful to follow a redirection to a renamed repository without following a redirection to a CDN.
A backoff mechanism retries the HTTP call on 5xx errors and network errors.
A backoff mechanism retries the HTTP call on errors (429, 5xx, timeout, network errors).
Args:
method (`str`):
HTTP method, such as 'GET' or 'HEAD'.
url (`str`):
The URL of the resource to fetch.
retry_on_errors (`bool`, *optional*, defaults to `False`):
Whether to retry on errors. If False, no retry is performed (fast fallback to local cache).
If True, uses default retry behavior (429, 5xx, timeout, network errors).
**httpx_kwargs (`dict`, *optional*):
Params to pass to `httpx.request`.
"""
# if `retry_on_errors=False`, disable all retries for fast fallback to cache
no_retry_kwargs: dict[str, Any] = (
{} if retry_on_errors else {"retry_on_exceptions": (), "retry_on_status_codes": ()}
)
while True:
# Make the request
response = http_backoff(
method=method,
url=url,
**httpx_kwargs,
follow_redirects=False,
retry_on_exceptions=(),
retry_on_status_codes=(429,),
**no_retry_kwargs,
)
hf_raise_for_status(response)
@@ -1131,8 +1148,31 @@ def _hf_hub_download_to_cache_dir(
if not force_download:
return pointer_path
# Otherwise, raise appropriate error
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
if isinstance(head_call_error, _DEFAULT_RETRY_ON_EXCEPTIONS) or (
isinstance(head_call_error, HfHubHTTPError)
and head_call_error.response.status_code in _DEFAULT_RETRY_ON_STATUS_CODES
):
logger.info("No local file found. Retrying..")
(url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = (
_get_metadata_or_catch_error(
repo_id=repo_id,
filename=filename,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
etag_timeout=_ETAG_RETRY_TIMEOUT,
headers=headers,
token=token,
local_files_only=local_files_only,
storage_folder=storage_folder,
relative_filename=relative_filename,
retry_on_errors=True,
)
)
# If still error, raise
if head_call_error is not None:
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
# From now on, etag, commit_hash, url and size are not None.
assert etag is not None, "etag must have been retrieved from server"
@@ -1300,9 +1340,30 @@ def _hf_hub_download_to_local_dir(
)
if not force_download:
return local_path
elif not force_download:
if isinstance(head_call_error, _DEFAULT_RETRY_ON_EXCEPTIONS) or (
isinstance(head_call_error, HfHubHTTPError)
and head_call_error.response.status_code in _DEFAULT_RETRY_ON_STATUS_CODES
):
logger.info("No local file found. Retrying..")
(url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = (
_get_metadata_or_catch_error(
repo_id=repo_id,
filename=filename,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
etag_timeout=_ETAG_RETRY_TIMEOUT,
headers=headers,
token=token,
local_files_only=local_files_only,
retry_on_errors=True,
)
)
# Otherwise => raise
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
# If still error, raise
if head_call_error is not None:
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
# From now on, etag, commit_hash, url and size are not None.
assert etag is not None, "etag must have been retrieved from server"
@@ -1501,12 +1562,13 @@ def try_to_load_from_cache(
def get_hf_file_metadata(
url: str,
token: Union[bool, str, None] = None,
timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
timeout: Optional[float] = constants.HF_HUB_ETAG_TIMEOUT,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[dict, str, None] = None,
headers: Optional[dict[str, str]] = None,
endpoint: Optional[str] = None,
retry_on_errors: bool = False,
) -> HfFileMetadata:
"""Fetch metadata of a file versioned on the Hub for a given url.
@@ -1531,6 +1593,9 @@ def get_hf_file_metadata(
Additional headers to be sent with the request.
endpoint (`str`, *optional*):
Endpoint of the Hub. Defaults to <https://huggingface.co>.
retry_on_errors (`bool`, *optional*, defaults to `False`):
Whether to retry on errors (429, 5xx, timeout, network errors).
If `False`, no retry is performed, allowing a fast fallback to the local cache.
Returns:
A [`HfFileMetadata`] object containing metadata such as location, etag, size and
@@ -1546,7 +1611,9 @@ def get_hf_file_metadata(
hf_headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file
# Retrieve metadata
response = _httpx_follow_relative_redirects(method="HEAD", url=url, headers=hf_headers, timeout=timeout)
response = _httpx_follow_relative_redirects(
method="HEAD", url=url, headers=hf_headers, timeout=timeout, retry_on_errors=retry_on_errors
)
hf_raise_for_status(response)
# Return
@@ -1579,6 +1646,7 @@ def _get_metadata_or_catch_error(
local_files_only: bool,
relative_filename: Optional[str] = None, # only used to store `.no_exists` in cache
storage_folder: Optional[str] = None, # only used to store `.no_exists` in cache
retry_on_errors: bool = False,
) -> Union[
# Either an exception is caught and returned
tuple[None, None, None, None, None, Exception],
@@ -1621,7 +1689,12 @@ def _get_metadata_or_catch_error(
try:
try:
metadata = get_hf_file_metadata(
url=url, timeout=etag_timeout, headers=headers, token=token, endpoint=endpoint
url=url,
timeout=etag_timeout,
headers=headers,
token=token,
endpoint=endpoint,
retry_on_errors=retry_on_errors,
)
except RemoteEntryNotFoundError as http_error:
if storage_folder is not None and relative_filename is not None:
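The net effect of this file's changes, as a hedged sketch: metadata calls stay retry-free by default (fast fallback to the local cache), and callers opt in to retries via the new flag. The repo/file below are illustrative.

```python
from huggingface_hub import hf_hub_url
from huggingface_hub.file_download import get_hf_file_metadata

url = hf_hub_url(repo_id="gpt2", filename="config.json")

meta = get_hf_file_metadata(url)                        # default: no retry, fast cache fallback
meta = get_hf_file_metadata(url, retry_on_errors=True)  # opt-in: retry on 429/5xx/timeouts/network errors
print(meta.etag, meta.size)
```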

View File

@@ -59,7 +59,7 @@ from ._commit_api import (
_upload_files,
_warn_on_overwriting_operations,
)
from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType
from ._inference_endpoints import InferenceEndpoint, InferenceEndpointScalingMetric, InferenceEndpointType
from ._jobs_api import JobInfo, JobSpec, ScheduledJobInfo, _create_job_spec
from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable
from ._upload_large_folder import upload_large_folder_internal
@@ -75,6 +75,7 @@ from .errors import (
BadRequestError,
GatedRepoError,
HfHubHTTPError,
LocalTokenNotFoundError,
RemoteEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
@@ -194,6 +195,7 @@ ExpandSpaceProperty_T = Literal[
USERNAME_PLACEHOLDER = "hf_user"
_REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$")
_REGEX_HTTP_PROTOCOL = re.compile(r"https?://")
_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = (
"\nNote: Creating a commit assumes that the repo already exists on the"
@@ -238,28 +240,62 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
"""
input_hf_id = hf_id
hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT)
is_hf_url = hub_url in hf_id and "@" not in hf_id
# Get the hub_url (with or without protocol)
full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT
hub_url_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", full_hub_url)
# Check if hf_id is a URL containing the hub_url (check both with and without protocol)
hf_id_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", hf_id)
is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id
HFFS_PREFIX = "hf://"
if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists
hf_id = hf_id[len(HFFS_PREFIX) :]
# If it's a URL, strip the endpoint prefix to get the path
if is_hf_url:
# Remove protocol if present
hf_id_normalized = _REGEX_HTTP_PROTOCOL.sub("", hf_id)
# Remove the hub_url prefix to get the relative path
if hf_id_normalized.startswith(hub_url_without_protocol):
# Strip the hub URL and any leading slashes
hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/")
url_segments = hf_id.split("/")
is_hf_id = len(url_segments) <= 3
namespace: Optional[str]
if is_hf_url:
namespace, repo_id = url_segments[-2:]
if namespace == hub_url:
namespace = None
if len(url_segments) > 2 and hub_url not in url_segments[-3]:
repo_type = url_segments[-3]
elif namespace in constants.REPO_TYPES_MAPPING:
# Mean canonical dataset or model
repo_type = constants.REPO_TYPES_MAPPING[namespace]
namespace = None
# For URLs, we need to extract repo_type, namespace, repo_id
# Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id
if len(url_segments) >= 3:
# Check if first segment is a repo type
if url_segments[0] in constants.REPO_TYPES_MAPPING:
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
namespace = url_segments[1]
repo_id = url_segments[2]
else:
# First segment is namespace
namespace = url_segments[0]
repo_id = url_segments[1]
repo_type = None
elif len(url_segments) == 2:
namespace = url_segments[0]
repo_id = url_segments[1]
# Check if namespace is actually a repo type mapping
if namespace in constants.REPO_TYPES_MAPPING:
# Means canonical dataset or model
repo_type = constants.REPO_TYPES_MAPPING[namespace]
namespace = None
else:
repo_type = None
else:
# Single segment
repo_id = url_segments[0]
namespace = None
repo_type = None
elif is_hf_id:
if len(url_segments) == 3:
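As a sanity check on the reworked parsing, a hedged sketch of expected outputs (the return tuple is `(repo_type, namespace, repo_id)`; values are inferred from the new logic, not verified output):

```python
from huggingface_hub.hf_api import repo_type_and_id_from_hf_id

repo_type_and_id_from_hf_id("https://huggingface.co/datasets/user/data")
# expected: ("dataset", "user", "data")
repo_type_and_id_from_hf_id("hf://user/model")
# expected: (None, "user", "model")
```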
@@ -1694,6 +1730,9 @@ class HfApi:
self.headers = headers
self._thread_pool: Optional[ThreadPoolExecutor] = None
# /whoami-v2 is the only endpoint for which we may want to cache results
self._whoami_cache: dict[str, dict] = {}
def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]:
"""
Run a method in the background and return a Future instance.
@@ -1735,22 +1774,51 @@ class HfApi:
return self._thread_pool.submit(fn, *args, **kwargs)
@validate_hf_hub_args
def whoami(self, token: Union[bool, str, None] = None) -> dict:
def whoami(self, token: Union[bool, str, None] = None, *, cache: bool = False) -> dict:
"""
Call HF API to know "whoami".
If passing `cache=True`, the result will be cached for subsequent calls for the duration of the Python process. This is useful if you plan to call
`whoami` multiple times as this endpoint is heavily rate-limited for security reasons.
Args:
token (`bool` or `str`, *optional*):
A valid user access token (string). Defaults to the locally saved
token, which is the recommended method for authentication (see
https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
Passing `False` is not supported here: `whoami` requires authentication and raises a `ValueError`.
cache (`bool`, *optional*):
Whether to cache the result of the `whoami` call for subsequent calls.
If an error occurs during the first call, it won't be cached.
Defaults to `False`.
"""
# Get the effective token using the helper function get_token
effective_token = token or self.token or get_token() or True
token = self.token if token is None else token
if token is False:
raise ValueError("Cannot use `token=False` with `whoami` method as it requires authentication.")
if token is True or token is None:
token = get_token()
if token is None:
raise LocalTokenNotFoundError(
"Token is required to call the /whoami-v2 endpoint, but no token found. You must provide a token or be logged in to "
"Hugging Face with `hf auth login` or `huggingface_hub.login`. See https://huggingface.co/settings/tokens."
)
if cache and (cached_token := self._whoami_cache.get(token)):
return cached_token
# Call Hub
output = self._inner_whoami(token=token)
# Cache result and return
if cache:
self._whoami_cache[token] = output
return output
def _inner_whoami(self, token: str) -> dict:
r = get_session().get(
f"{self.endpoint}/api/whoami-v2",
headers=self._build_hf_headers(token=effective_token),
headers=self._build_hf_headers(token=token),
)
try:
hf_raise_for_status(r)
@@ -1758,16 +1826,22 @@ class HfApi:
if e.response.status_code == 401:
error_message = "Invalid user token."
# Check which token is the effective one and generate the error message accordingly
if effective_token == _get_token_from_google_colab():
if token == _get_token_from_google_colab():
error_message += " The token from Google Colab vault is invalid. Please update it from the UI."
elif effective_token == _get_token_from_environment():
elif token == _get_token_from_environment():
error_message += (
" The token from HF_TOKEN environment variable is invalid. "
"Note that HF_TOKEN takes precedence over `hf auth login`."
)
elif effective_token == _get_token_from_file():
elif token == _get_token_from_file():
error_message += " The token stored is invalid. Please run `hf auth login` to update it."
raise HfHubHTTPError(error_message, response=e.response) from e
if e.response.status_code == 429:
error_message = (
"You've hit the rate limit for the /whoami-v2 endpoint, which is intentionally strict for security reasons."
" If you're calling it often, consider caching the response with `whoami(..., cache=True)`."
)
raise HfHubHTTPError(error_message, response=e.response) from e
raise
return r.json()
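A hedged usage sketch of the new opt-in cache (requires a logged-in token):

```python
from huggingface_hub import HfApi

api = HfApi()
me = api.whoami(cache=True)  # first call hits /whoami-v2
me = api.whoami(cache=True)  # subsequent calls reuse the cached payload
print(me["name"])
```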
@@ -3727,7 +3801,7 @@ class HfApi:
self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token)
if repo_type is None or repo_type == constants.REPO_TYPE_MODEL:
return RepoUrl(f"{self.endpoint}/{repo_id}")
return RepoUrl(f"{self.endpoint}/{repo_type}/{repo_id}")
return RepoUrl(f"{self.endpoint}/{constants.REPO_TYPES_URL_PREFIXES[repo_type]}{repo_id}")
except HfHubHTTPError:
raise err
else:
@@ -5089,7 +5163,7 @@ class HfApi:
ignore_patterns (`list[str]` or `str`, *optional*):
If provided, files matching any of the patterns are not uploaded.
num_workers (`int`, *optional*):
Number of workers to start. Defaults to `os.cpu_count() - 2` (minimum 2).
Number of workers to start. Defaults to half of CPU cores (minimum 1).
A higher number of workers may speed up the process if your machine allows it. However, on machines with a
slower connection, it is recommended to keep the number of workers low to ensure better resumability.
Indeed, partially uploaded files will have to be completely re-uploaded if the process is interrupted.
@@ -5169,7 +5243,7 @@ class HfApi:
*,
url: str,
token: Union[bool, str, None] = None,
timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
timeout: Optional[float] = constants.HF_HUB_ETAG_TIMEOUT,
) -> HfFileMetadata:
"""Fetch metadata of a file versioned on the Hub for a given url.
@@ -7409,6 +7483,8 @@ class HfApi:
account_id: Optional[str] = None,
min_replica: int = 1,
max_replica: int = 1,
scaling_metric: Optional[InferenceEndpointScalingMetric] = None,
scaling_threshold: Optional[float] = None,
scale_to_zero_timeout: Optional[int] = None,
revision: Optional[str] = None,
task: Optional[str] = None,
@@ -7449,6 +7525,12 @@ class HfApi:
scaling to zero, set this value to 0 and adjust `scale_to_zero_timeout` accordingly. Defaults to 1.
max_replica (`int`, *optional*):
The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
scaling_metric (`str` or [`InferenceEndpointScalingMetric`], *optional*):
The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided. Defaults to
None (meaning: let the HF Endpoints service specify the metric).
scaling_threshold (`float`, *optional*):
The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.
Defaults to None (meaning: let the HF Endpoints service specify the threshold).
scale_to_zero_timeout (`int`, *optional*):
The duration in minutes before an inactive endpoint is scaled to zero, or no scaling to zero if
set to None and `min_replica` is not 0. Defaults to None.
@@ -7600,6 +7682,8 @@ class HfApi:
},
"type": type,
}
if scaling_metric:
payload["compute"]["scaling"]["measure"] = {scaling_metric: scaling_threshold} # type: ignore
if env:
payload["model"]["env"] = env
if secrets:
@@ -7764,6 +7848,8 @@ class HfApi:
min_replica: Optional[int] = None,
max_replica: Optional[int] = None,
scale_to_zero_timeout: Optional[int] = None,
scaling_metric: Optional[InferenceEndpointScalingMetric] = None,
scaling_threshold: Optional[float] = None,
# Model update
repository: Optional[str] = None,
framework: Optional[str] = None,
@@ -7804,7 +7890,12 @@ class HfApi:
The maximum number of replicas (instances) to scale to for the Inference Endpoint.
scale_to_zero_timeout (`int`, *optional*):
The duration in minutes before an inactive endpoint is scaled to zero.
scaling_metric (`str` or [`InferenceEndpointScalingMetric`], *optional*):
The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided.
Defaults to None.
scaling_threshold (`float`, *optional*):
The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.
Defaults to None.
repository (`str`, *optional*):
The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
framework (`str`, *optional*):
@@ -7858,6 +7949,8 @@ class HfApi:
payload["compute"]["scaling"]["minReplica"] = min_replica
if scale_to_zero_timeout is not None:
payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout
if scaling_metric:
payload["compute"]["scaling"]["measure"] = {scaling_metric: scaling_threshold}
if repository is not None:
payload["model"]["repository"] = repository
if framework is not None:
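And the update path, sketched with a hypothetical endpoint name:

```python
from huggingface_hub import HfApi

api = HfApi()
api.update_inference_endpoint(
    "my-endpoint",                    # hypothetical name
    min_replica=0,
    scale_to_zero_timeout=15,         # minutes of inactivity before scaling to zero
    scaling_metric="pendingRequests",
    scaling_threshold=100.0,
)
```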
@@ -8333,10 +8426,10 @@ class HfApi:
collection_slug (`str`):
Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
item_id (`str`):
ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`)
or a paper id (e.g. `"2307.09288"`).
ID of the item to add to the collection. Use the repo_id for repos/spaces/datasets,
the paper ID for papers, or the slug of another collection (e.g. `"moonshotai/kimi-k2"`).
item_type (`str`):
Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`.
Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"`, `"paper"` or `"collection"`.
note (`str`, *optional*):
A note to attach to the item in the collection. The maximum size for a note is 500 characters.
exists_ok (`bool`, *optional*):
@@ -9792,6 +9885,75 @@ class HfApi:
hf_raise_for_status(r)
return PaperInfo(**r.json())
def list_daily_papers(
self,
*,
date: Optional[str] = None,
token: Union[bool, str, None] = None,
week: Optional[str] = None,
month: Optional[str] = None,
submitter: Optional[str] = None,
sort: Optional[Literal["publishedAt", "trending"]] = None,
p: Optional[int] = None,
limit: Optional[int] = None,
) -> Iterable[PaperInfo]:
"""
List the daily papers published on a given date on the Hugging Face Hub.
Args:
date (`str`, *optional*):
Date in ISO format (YYYY-MM-DD) for which to fetch daily papers.
Defaults to most recent ones.
token (Union[bool, str, None], *optional*):
A valid user access token (string). Defaults to the locally saved
token. To disable authentication, pass `False`.
week (`str`, *optional*):
Week in ISO format (YYYY-Www) for which to fetch daily papers. Example: `2025-W09`.
month (`str`, *optional*):
Month in ISO format (YYYY-MM) for which to fetch daily papers. Example: `2025-02`.
submitter (`str`, *optional*):
Username of the submitter to filter daily papers.
sort (`Literal["publishedAt", "trending"]`, *optional*):
Sort order for the daily papers. Can be either by `publishedAt` or by `trending`.
Defaults to `"publishedAt"`
p (`int`, *optional*):
Page number for pagination. Defaults to 0.
limit (`int`, *optional*):
Limit of papers to fetch. Defaults to 50.
Returns:
`Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects.
Example:
```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> list(api.list_daily_papers(date="2025-10-29"))
```
"""
path = f"{self.endpoint}/api/daily_papers"
params = {
k: v
for k, v in {
"p": p,
"limit": limit,
"sort": sort,
"date": date,
"week": week,
"month": month,
"submitter": submitter,
}.items()
if v is not None
}
r = get_session().get(path, params=params, headers=self._build_hf_headers(token=token))
hf_raise_for_status(r)
for paper in r.json():
yield PaperInfo(**paper)
def auth_check(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> None:
@@ -9931,7 +10093,7 @@ class HfApi:
timeout=timeout,
)
response = get_session().post(
f"https://huggingface.co/api/jobs/{namespace}",
f"{self.endpoint}/api/jobs/{namespace}",
json=job_spec,
headers=self._build_hf_headers(token=token),
)
@@ -9994,7 +10156,7 @@ class HfApi:
try:
with get_session().stream(
"GET",
f"https://huggingface.co/api/jobs/{namespace}/{job_id}/logs",
f"{self.endpoint}/api/jobs/{namespace}/{job_id}/logs",
headers=self._build_hf_headers(token=token),
timeout=120,
) as response:
@@ -10022,7 +10184,7 @@ class HfApi:
job_status = (
get_session()
.get(
f"https://huggingface.co/api/jobs/{namespace}/{job_id}",
f"{self.endpoint}/api/jobs/{namespace}/{job_id}",
headers=self._build_hf_headers(token=token),
)
.json()
@@ -10362,7 +10524,7 @@ class HfApi:
if suspend is not None:
input_json["suspend"] = suspend
response = get_session().post(
f"https://huggingface.co/api/scheduled-jobs/{namespace}",
f"{self.endpoint}/api/scheduled-jobs/{namespace}",
json=input_json,
headers=self._build_hf_headers(token=token),
)
@@ -10813,6 +10975,7 @@ space_info = api.space_info
list_papers = api.list_papers
paper_info = api.paper_info
list_daily_papers = api.list_daily_papers
repo_exists = api.repo_exists
revision_exists = api.revision_exists

View File

@@ -123,32 +123,43 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
> layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible.
Args:
token (`str` or `bool`, *optional*):
endpoint (`str`, *optional*):
Endpoint of the Hub. Defaults to <https://huggingface.co>.
token (`bool` or `str`, *optional*):
A valid user access token (string). Defaults to the locally saved
token, which is the recommended method for authentication (see
https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
To disable authentication, pass `False`.
endpoint (`str`, *optional*):
Endpoint of the Hub. Defaults to <https://huggingface.co>.
block_size (`int`, *optional*):
Block size for reading and writing files.
expand_info (`bool`, *optional*):
Whether to expand file information when listing, i.e. fetch extended metadata for each entry.
**storage_options (`dict`, *optional*):
Additional options for the filesystem. See [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.__init__).
Usage:
```python
>>> from huggingface_hub import HfFileSystem
>>> fs = HfFileSystem()
>>> from huggingface_hub import hffs
>>> # List files
>>> fs.glob("my-username/my-model/*.bin")
>>> hffs.glob("my-username/my-model/*.bin")
['my-username/my-model/pytorch_model.bin']
>>> fs.ls("datasets/my-username/my-dataset", detail=False)
>>> hffs.ls("datasets/my-username/my-dataset", detail=False)
['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']
>>> # Read/write files
>>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
>>> with hffs.open("my-username/my-model/pytorch_model.bin") as f:
... data = f.read()
>>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
>>> with hffs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
... f.write(data)
```
Specify a token for authentication:
```python
>>> from huggingface_hub import HfFileSystem
>>> hffs = HfFileSystem(token=token)
```
"""
root_marker = ""
@@ -160,6 +171,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
endpoint: Optional[str] = None,
token: Union[bool, str, None] = None,
block_size: Optional[int] = None,
expand_info: Optional[bool] = None,
**storage_options,
):
super().__init__(*args, **storage_options)
@@ -167,6 +179,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
self.token = token
self._api = HfApi(endpoint=endpoint, token=token)
self.block_size = block_size
self.expand_info = expand_info
# Maps (repo_type, repo_id, revision) to a 2-tuple with:
# * the 1st element indicating whether the repository and the revision exist
# * the 2nd element being the exception raised if the repository or revision doesn't exist
@@ -330,14 +343,14 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
(resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision), None
)
def _open(
def _open( # type: ignore[override]
self,
path: str,
mode: str = "rb",
revision: Optional[str] = None,
block_size: Optional[int] = None,
revision: Optional[str] = None,
**kwargs,
) -> "HfFileSystemFile":
) -> Union["HfFileSystemFile", "HfFileSystemStreamFile"]:
block_size = block_size if block_size is not None else self.block_size
if block_size is not None:
kwargs["block_size"] = block_size
@@ -453,9 +466,12 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
recursive: bool = False,
refresh: bool = False,
revision: Optional[str] = None,
expand_info: bool = False,
expand_info: Optional[bool] = None,
maxdepth: Optional[int] = None,
):
expand_info = (
expand_info if expand_info is not None else (self.expand_info if self.expand_info is not None else False)
)
resolved_path = self.resolve_path(path, revision=revision)
path = resolved_path.unresolve()
root_path = HfFileSystemResolvedPath(
@@ -581,7 +597,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
yield from super().walk(path, *args, **kwargs)
def glob(self, path: str, **kwargs) -> list[str]:
def glob(self, path: str, maxdepth: Optional[int] = None, **kwargs) -> list[str]:
"""
Find files by glob-matching.
@@ -595,7 +611,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
`list[str]`: List of paths matching the pattern.
"""
path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
return super().glob(path, **kwargs)
return super().glob(path, maxdepth=maxdepth, **kwargs)
def find(
self,
@@ -756,7 +772,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
resolved_path = self.resolve_path(path, revision=revision)
path = resolved_path.unresolve()
expand_info = kwargs.get(
"expand_info", False
"expand_info", self.expand_info if self.expand_info is not None else False
) # don't expose it as a parameter in the public API to follow the spec
if not resolved_path.path_in_repo:
# Path is the root directory
@@ -1189,7 +1205,6 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
"GET",
url,
headers=headers,
retry_on_status_codes=(500, 502, 503, 504),
timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
)
)
@@ -1255,3 +1270,6 @@ def make_instance(cls, args, kwargs, instance_state):
for attr, state_value in instance_state.items():
setattr(fs, attr, state_value)
return fs
hffs = HfFileSystem()
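A hedged sketch of the two additions together: the module-level `hffs` singleton and the per-instance `expand_info` default (repo path illustrative):

```python
from huggingface_hub import HfFileSystem, hffs

# Ready-to-use singleton, no construction needed.
files = hffs.ls("gpt2", detail=False)

# Custom instance that always fetches extended file info when listing.
fs = HfFileSystem(expand_info=True)
entries = fs.ls("gpt2", detail=True)
```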

View File

@@ -367,7 +367,7 @@ class ModelHubMixin:
if is_simple_optional_type(expected_type):
if value is None:
return None
expected_type = unwrap_simple_optional_type(expected_type)
expected_type = unwrap_simple_optional_type(expected_type) # type: ignore[assignment]
# Dataclass => handle it
if is_dataclass(expected_type):
return _load_dataclass(expected_type, value) # type: ignore[return-value]

View File

@@ -135,7 +135,7 @@ class InferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
@@ -452,6 +452,7 @@ class InferenceClient:
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_params=request_parameters)
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
@overload
@@ -1028,7 +1029,7 @@ class InferenceClient:
normalize: Optional[bool] = None,
prompt_name: Optional[str] = None,
truncate: Optional[bool] = None,
truncation_direction: Optional[Literal["Left", "Right"]] = None,
truncation_direction: Optional[Literal["left", "right"]] = None,
model: Optional[str] = None,
) -> "np.ndarray":
"""
@@ -1053,7 +1054,7 @@ class InferenceClient:
truncate (`bool`, *optional*):
Whether to truncate the embeddings or not.
Only available on server powered by Text-Embedding-Inference.
truncation_direction (`Literal["Left", "Right"]`, *optional*):
truncation_direction (`Literal["left", "right"]`, *optional*):
Which side of the input should be truncated when `truncate=True` is passed.
Returns:
@@ -3195,10 +3196,7 @@ class InferenceClient:
)
response = self._inner_post(request_parameters)
output = _bytes_to_dict(response)
return [
ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
for label, score in zip(output["labels"], output["scores"])
]
return ZeroShotClassificationOutputElement.parse_obj_as_list(output)
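Two caller-visible changes in this client (mirrored in the async client further down): `truncation_direction` literals are now lowercase, and zero-shot output is parsed via `parse_obj_as_list`. A hedged sketch of the former (model name hypothetical, assumes a TEI-backed deployment):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")  # hypothetical
embedding = client.feature_extraction(
    "Hello world!",
    truncate=True,
    truncation_direction="left",  # previously "Left"
)
print(embedding.shape)
```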
def zero_shot_image_classification(
self,

View File

@@ -144,7 +144,7 @@ def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
if hasattr(content, "read"): # duck-typing instead of isinstance(content, BinaryIO)
logger.debug("Reading content from BinaryIO")
data = content.read()
mime_type = mimetypes.guess_type(content.name)[0] if hasattr(content, "name") else None
mime_type = mimetypes.guess_type(str(content.name))[0] if hasattr(content, "name") else None
if isinstance(data, str):
raise TypeError("Expected binary stream (bytes), but got text stream")
return MimeBytes(data, mime_type=mime_type)
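A hedged sketch of why the `str()` wrapper matters: a stream's `.name` is not always a string (on POSIX, an anonymous temporary file exposes its file descriptor, an `int`), and `mimetypes.guess_type` rejects non-path arguments.

```python
import mimetypes
import tempfile

with tempfile.TemporaryFile() as f:
    print(f.name)                                # e.g. 3 on POSIX: an int fd, not a path
    print(mimetypes.guess_type(str(f.name))[0])  # None, instead of a TypeError
```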

View File

@@ -126,7 +126,7 @@ class AsyncInferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
@@ -472,6 +472,7 @@ class AsyncInferenceClient:
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_params=request_parameters)
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
@overload
@@ -1055,7 +1056,7 @@ class AsyncInferenceClient:
normalize: Optional[bool] = None,
prompt_name: Optional[str] = None,
truncate: Optional[bool] = None,
truncation_direction: Optional[Literal["Left", "Right"]] = None,
truncation_direction: Optional[Literal["left", "right"]] = None,
model: Optional[str] = None,
) -> "np.ndarray":
"""
@@ -1080,7 +1081,7 @@ class AsyncInferenceClient:
truncate (`bool`, *optional*):
Whether to truncate the embeddings or not.
Only available on server powered by Text-Embedding-Inference.
truncation_direction (`Literal["Left", "Right"]`, *optional*):
truncation_direction (`Literal["left", "right"]`, *optional*):
Which side of the input should be truncated when `truncate=True` is passed.
Returns:
@@ -3245,10 +3246,7 @@ class AsyncInferenceClient:
)
response = await self._inner_post(request_parameters)
output = _bytes_to_dict(response)
return [
ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
for label, score in zip(output["labels"], output["scores"])
]
return ZeroShotClassificationOutputElement.parse_obj_as_list(output)
async def zero_shot_image_classification(
self,

View File

@@ -19,6 +19,8 @@ import types
from dataclasses import asdict, dataclass
from typing import Any, TypeVar, Union, get_args
from typing_extensions import dataclass_transform
T = TypeVar("T", bound="BaseInferenceType")
@@ -29,6 +31,7 @@ def _repr_with_extra(self):
return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})"
@dataclass_transform()
def dataclass_with_extra(cls: type[T]) -> type[T]:
"""Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones.

View File

@@ -8,7 +8,7 @@ from typing import Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]
FeatureExtractionInputTruncationDirection = Literal["left", "right"]
@dataclass_with_extra

View File

@@ -5,8 +5,8 @@ import traceback
from typing import Optional
import typer
from rich import print
from ...utils import ANSI
from ._cli_hacks import _async_prompt, _patch_anyio_open_process
from .agent import Agent
from .utils import _load_agent_config
@@ -55,10 +55,10 @@ async def run_agent(
if first_sigint:
first_sigint = False
abort_event.set()
print("\n[red]Interrupted. Press Ctrl+C again to quit.[/red]", flush=True)
print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
return
print("\n[red]Exiting...[/red]", flush=True)
print(ANSI.red("\nExiting..."), flush=True)
exit_event.set()
try:
@@ -75,8 +75,12 @@ async def run_agent(
if len(inputs) > 0:
print(
"[bold blue]Some initial inputs are required by the agent. "
"Please provide a value or leave empty to load from env.[/bold blue]"
ANSI.bold(
ANSI.blue(
"Some initial inputs are required by the agent. "
"Please provide a value or leave empty to load from env."
)
)
)
for input_item in inputs:
input_id = input_item["id"]
@@ -98,15 +102,17 @@ async def run_agent(
if not input_usages:
print(
f"[yellow]Input '{input_id}' defined in config but not used by any server or as an API key."
" Skipping.[/yellow]"
ANSI.yellow(
f"Input '{input_id}' defined in config but not used by any server or as an API key."
" Skipping."
)
)
continue
# Prompt user for input
env_variable_key = input_id.replace("-", "_").upper()
print(
f"[blue]{input_id}[/blue]: {description}. (default: load from {env_variable_key}).",
ANSI.blue(f"{input_id}") + f": {description}. (default: load from {env_variable_key}).",
end=" ",
)
user_input = (await _async_prompt(exit_event=exit_event)).strip()
@@ -118,10 +124,12 @@ async def run_agent(
if not final_value:
final_value = os.getenv(env_variable_key, "")
if final_value:
print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
print(ANSI.green(f"Value successfully loaded from '{env_variable_key}'"))
else:
print(
f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
ANSI.yellow(
f"No value found for '{env_variable_key}' in environment variables. Continuing."
)
)
resolved_inputs[input_id] = final_value
@@ -150,9 +158,9 @@ async def run_agent(
prompt=prompt,
) as agent:
await agent.load_tools()
print(f"[bold blue]Agent loaded with {len(agent.available_tools)} tools:[/bold blue]")
print(ANSI.bold(ANSI.blue("Agent loaded with {} tools:".format(len(agent.available_tools)))))
for t in agent.available_tools:
print(f"[blue]{t.function.name}[/blue]")
print(ANSI.blue(f"{t.function.name}"))
while True:
abort_event.clear()
@@ -165,13 +173,13 @@ async def run_agent(
user_input = await _async_prompt(exit_event=exit_event)
first_sigint = True
except EOFError:
print("\n[red]EOF received, exiting.[/red]", flush=True)
print(ANSI.red("\nEOF received, exiting."), flush=True)
break
except KeyboardInterrupt:
if not first_sigint and abort_event.is_set():
continue
else:
print("\n[red]Keyboard interrupt during input processing.[/red]", flush=True)
print(ANSI.red("\nKeyboard interrupt during input processing."), flush=True)
break
try:
@@ -195,7 +203,7 @@ async def run_agent(
print(f"{call.function.arguments}", end="")
else:
print(
f"\n\n[green]Tool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}[/green]\n",
ANSI.green(f"\n\nTool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}\n"),
flush=True,
)
@@ -203,12 +211,12 @@ async def run_agent(
except Exception as e:
tb_str = traceback.format_exc()
print(f"\n[bold red]Error during agent run: {e}\n{tb_str}[/bold red]", flush=True)
print(ANSI.red(f"\nError during agent run: {e}\n{tb_str}"), flush=True)
first_sigint = True # Allow graceful interrupt for the next command
except Exception as e:
tb_str = traceback.format_exc()
print(f"\n[bold red]An unexpected error occurred: {e}\n{tb_str}[/bold red]", flush=True)
print(ANSI.red(f"\nAn unexpected error occurred: {e}\n{tb_str}"), flush=True)
raise e
finally:
@@ -236,10 +244,10 @@ def run(
try:
asyncio.run(run_agent(path))
except KeyboardInterrupt:
print("\n[red]Application terminated by KeyboardInterrupt.[/red]", flush=True)
print(ANSI.red("\nApplication terminated by KeyboardInterrupt."), flush=True)
raise typer.Exit(code=130)
except Exception as e:
print(f"\n[bold red]An unexpected error occurred: {e}[/bold red]", flush=True)
print(ANSI.red(f"\nAn unexpected error occurred: {e}"), flush=True)
raise e
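For reference, a minimal sketch of the `ANSI` helper these prints now use instead of rich-style markup (helpers as imported in this diff):

```python
from huggingface_hub.utils import ANSI

print(ANSI.red("Interrupted. Press Ctrl+C again to quit."))
print(ANSI.bold(ANSI.blue("Agent loaded with 3 tools:")))
print(ANSI.yellow("No value found. Continuing."))
```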

View File

@@ -57,10 +57,10 @@ def format_result(result: "mcp_types.CallToolResult") -> str:
elif item.type == "resource":
resource = item.resource
if hasattr(resource, "text"):
if hasattr(resource, "text") and isinstance(resource.text, str):
formatted_parts.append(resource.text)
elif hasattr(resource, "blob"):
elif hasattr(resource, "blob") and isinstance(resource.blob, str):
formatted_parts.append(
f"[Binary Content ({resource.uri}): {resource.mimeType}, {_get_base64_size(resource.blob)} bytes]\n"
f"The task is complete and the content accessible to the User"

View File

@@ -38,8 +38,15 @@ from .nebius import (
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
from .openai import OpenAIConversationalTask
from .ovhcloud import OVHcloudConversationalTask
from .publicai import PublicAIConversationalTask
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .replicate import (
ReplicateAutomaticSpeechRecognitionTask,
ReplicateImageToImageTask,
ReplicateTask,
ReplicateTextToImageTask,
ReplicateTextToSpeechTask,
)
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
@@ -70,6 +77,7 @@ PROVIDER_T = Literal[
"novita",
"nscale",
"openai",
"ovhcloud",
"publicai",
"replicate",
"sambanova",
@@ -166,10 +174,14 @@ PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
"openai": {
"conversational": OpenAIConversationalTask(),
},
"ovhcloud": {
"conversational": OVHcloudConversationalTask(),
},
"publicai": {
"conversational": PublicAIConversationalTask(),
},
"replicate": {
"automatic-speech-recognition": ReplicateAutomaticSpeechRecognitionTask(),
"image-to-image": ReplicateImageToImageTask(),
"text-to-image": ReplicateTextToImageTask(),
"text-to-speech": ReplicateTextToSpeechTask(),

View File

@@ -32,6 +32,7 @@ HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]
"hyperbolic": {},
"nebius": {},
"nscale": {},
"ovhcloud": {},
"replicate": {},
"sambanova": {},
"scaleway": {},

View File

@@ -112,7 +112,7 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
text = _as_dict(response)["text"]
if not isinstance(text, str):
raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
return text
return {"text": text}
class FalAITextToImageTask(FalAITask):

View File

@@ -72,6 +72,67 @@ class ReplicateTextToSpeechTask(ReplicateTask):
return payload
class ReplicateAutomaticSpeechRecognitionTask(ReplicateTask):
def __init__(self) -> None:
super().__init__("automatic-speech-recognition")
def _prepare_payload_as_dict(
self,
inputs: Any,
parameters: dict,
provider_mapping_info: InferenceProviderMapping,
) -> Optional[dict]:
mapped_model = provider_mapping_info.provider_id
audio_url = _as_url(inputs, default_mime_type="audio/wav")
payload: dict[str, Any] = {
"input": {
**{"audio": audio_url},
**filter_none(parameters),
}
}
if ":" in mapped_model:
payload["version"] = mapped_model.split(":", 1)[1]
return payload
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
response_dict = _as_dict(response)
output = response_dict.get("output")
if isinstance(output, str):
return {"text": output}
if isinstance(output, list) and output:
first_item = output[0]
if isinstance(first_item, str):
return {"text": first_item}
if isinstance(first_item, dict):
output = first_item
text: Optional[str] = None
if isinstance(output, dict):
transcription = output.get("transcription")
if isinstance(transcription, str):
text = transcription
translation = output.get("translation")
if isinstance(translation, str):
text = translation
txt_file = output.get("txt_file")
if isinstance(txt_file, str):
text_response = get_session().get(txt_file)
text_response.raise_for_status()
text = text_response.text
if text is not None:
return {"text": text}
raise ValueError("Received malformed response from Replicate automatic-speech-recognition API")
class ReplicateImageToImageTask(ReplicateTask):
def __init__(self):
super().__init__("image-to-image")

View File

@@ -102,7 +102,7 @@ def split_state_dict_into_shards_factory(
continue
# If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
storage_id = get_storage_id(tensor)
storage_id = get_storage_id(tensor) # type: ignore[invalid-argument-type]
if storage_id is not None:
if storage_id in storage_id_to_tensors:
# We skip this tensor for now and will reassign to correct shard later
@@ -114,7 +114,7 @@ def split_state_dict_into_shards_factory(
storage_id_to_tensors[storage_id] = [key]
# Compute tensor size
tensor_size = get_storage_size(tensor)
tensor_size = get_storage_size(tensor) # type: ignore[invalid-argument-type]
# If this tensor is bigger than the maximal size, we put it in its own shard
if tensor_size > max_shard_size:

View File

@@ -54,6 +54,7 @@ from ._headers import build_hf_headers, get_token_to_send
from ._http import (
    ASYNC_CLIENT_FACTORY_T,
    CLIENT_FACTORY_T,
    RateLimitInfo,
    close_session,
    fix_hf_endpoint_in_url,
    get_async_session,
@@ -61,6 +62,7 @@ from ._http import (
    hf_raise_for_status,
    http_backoff,
    http_stream_backoff,
    parse_ratelimit_headers,
    set_async_client_factory,
    set_client_factory,
)

View File

@@ -24,7 +24,7 @@ def cached_assets_path(
    subfolder: str = "default",
    *,
    assets_dir: Union[str, Path, None] = None,
):
) -> Path:
    """Return a folder path to cache arbitrary files.

    `huggingface_hub` provides a canonical folder path to store assets. This is the

View File

@@ -23,9 +23,9 @@ import threading
import time
import uuid
from contextlib import contextmanager
from http import HTTPStatus
from dataclasses import dataclass
from shlex import quote
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Generator, Mapping, Optional, Union

import httpx
@@ -48,6 +48,94 @@ from ._typing import HTTP_METHOD_T

logger = logging.get_logger(__name__)


@dataclass(frozen=True)
class RateLimitInfo:
    """
    Parsed rate limit information from HTTP response headers.

    Attributes:
        resource_type (`str`): The type of resource being rate limited.
        remaining (`int`): The number of requests remaining in the current window.
        reset_in_seconds (`int`): The number of seconds until the rate limit resets.
        limit (`int`, *optional*): The maximum number of requests allowed in the current window.
        window_seconds (`int`, *optional*): The number of seconds in the current window.
    """

    resource_type: str
    remaining: int
    reset_in_seconds: int
    limit: Optional[int] = None
    window_seconds: Optional[int] = None


# Regex patterns for parsing rate limit headers
# e.g.: "api";r=0;t=55 --> resource_type="api", r=0, t=55
_RATELIMIT_REGEX = re.compile(r"\"(?P<resource_type>\w+)\"\s*;\s*r\s*=\s*(?P<r>\d+)\s*;\s*t\s*=\s*(?P<t>\d+)")
# e.g.: "fixed window";"api";q=500;w=300 --> q=500, w=300
_RATELIMIT_POLICY_REGEX = re.compile(r"q\s*=\s*(?P<q>\d+).*?w\s*=\s*(?P<w>\d+)")


def parse_ratelimit_headers(headers: Mapping[str, str]) -> Optional[RateLimitInfo]:
    """Parse rate limit information from HTTP response headers.

    Follows IETF draft: https://www.ietf.org/archive/id/draft-ietf-httpapi-ratelimit-headers-09.html.
    Only a subset of the draft is implemented.

    Example:
    ```python
    >>> from huggingface_hub.utils import parse_ratelimit_headers
    >>> headers = {
    ...     "ratelimit": '"api";r=0;t=55',
    ...     "ratelimit-policy": '"fixed window";"api";q=500;w=300',
    ... }
    >>> info = parse_ratelimit_headers(headers)
    >>> info.remaining
    0
    >>> info.reset_in_seconds
    55
    ```
    """
    ratelimit: Optional[str] = None
    policy: Optional[str] = None
    for key in headers:
        lower_key = key.lower()
        if lower_key == "ratelimit":
            ratelimit = headers[key]
        elif lower_key == "ratelimit-policy":
            policy = headers[key]
    if not ratelimit:
        return None

    match = _RATELIMIT_REGEX.search(ratelimit)
    if not match:
        return None
    resource_type = match.group("resource_type")
    remaining = int(match.group("r"))
    reset_in_seconds = int(match.group("t"))

    limit: Optional[int] = None
    window_seconds: Optional[int] = None
    if policy:
        policy_match = _RATELIMIT_POLICY_REGEX.search(policy)
        if policy_match:
            limit = int(policy_match.group("q"))
            window_seconds = int(policy_match.group("w"))

    return RateLimitInfo(
        resource_type=resource_type,
        remaining=remaining,
        reset_in_seconds=reset_in_seconds,
        limit=limit,
        window_seconds=window_seconds,
    )
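
As a complement to the doctest above, a quick sketch (header values hypothetical) showing that the header lookup is case-insensitive and that the optional `RateLimit-Policy` header fills in `limit` and `window_seconds`:

```python
info = parse_ratelimit_headers(
    {
        "RateLimit": '"api";r=12;t=30',  # keys are matched case-insensitively
        "RateLimit-Policy": '"fixed window";"api";q=500;w=300',
    }
)
assert info is not None
assert (info.resource_type, info.remaining, info.reset_in_seconds) == ("api", 12, 30)
assert (info.limit, info.window_seconds) == (500, 300)  # parsed from the policy header

# Without a "ratelimit" header, nothing is parsed at all
assert parse_ratelimit_headers({"ratelimit-policy": "q=500;w=300"}) is None
```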
# Both headers are used by the Hub to debug failed requests.
# `X_AMZN_TRACE_ID` is preferred as it also works for debugging on CloudFront and ALB.
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
@@ -79,7 +167,7 @@ def hf_request_event_hook(request: httpx.Request) -> None:
    - Add a request ID to the request headers
    - Log the request if debug mode is enabled
    """
    if constants.HF_HUB_OFFLINE:
    if constants.is_offline_mode():
        raise OfflineModeIsEnabled(
            f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
        )
@@ -249,6 +337,10 @@ if hasattr(os, "register_at_fork"):
    os.register_at_fork(after_in_child=close_session)


_DEFAULT_RETRY_ON_EXCEPTIONS: tuple[type[Exception], ...] = (httpx.TimeoutException, httpx.NetworkError)
_DEFAULT_RETRY_ON_STATUS_CODES: tuple[int, ...] = (429, 500, 502, 503, 504)


def _http_backoff_base(
    method: HTTP_METHOD_T,
    url: str,
@@ -256,11 +348,8 @@ def _http_backoff_base(
    max_retries: int = 5,
    base_wait_time: float = 1,
    max_wait_time: float = 8,
    retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
        httpx.TimeoutException,
        httpx.NetworkError,
    ),
    retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
    retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
    retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
    stream: bool = False,
    **kwargs,
) -> Generator[httpx.Response, None, None]:
@@ -273,6 +362,7 @@ def _http_backoff_base(
    nb_tries = 0
    sleep_time = base_wait_time
    ratelimit_reset: Optional[int] = None  # seconds to wait for rate limit reset if 429 response

    # If `data` is used and is a file object (or any IO), it will be consumed on the
    # first HTTP request. We need to save the initial position so that the full content
@@ -284,6 +374,7 @@ def _http_backoff_base(
    client = get_session()
    while True:
        nb_tries += 1
        ratelimit_reset = None
        try:
            # If `data` is used and is a file object (or any IO), set back cursor to
            # initial position.
@@ -293,6 +384,8 @@ def _http_backoff_base(
            # Perform request and handle response
            def _should_retry(response: httpx.Response) -> bool:
                """Handle response and return True if should retry, False if should return/yield."""
                nonlocal ratelimit_reset

                if response.status_code not in retry_on_status_codes:
                    return False  # Success, don't retry
@@ -304,6 +397,12 @@ def _http_backoff_base(
                    # user ask for retry on a status code that doesn't raise_for_status.
                    return False  # Don't retry, return/yield response

                # get rate limit reset time from headers if 429 response
                if response.status_code == 429:
                    ratelimit_info = parse_ratelimit_headers(response.headers)
                    if ratelimit_info is not None:
                        ratelimit_reset = ratelimit_info.reset_in_seconds
                return True  # Should retry

            if stream:
@@ -326,9 +425,14 @@ def _http_backoff_base(
        if nb_tries > max_retries:
            raise err

        # Sleep for X seconds
        logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
        time.sleep(sleep_time)
        if ratelimit_reset is not None:
            actual_sleep = float(ratelimit_reset) + 1  # +1s to avoid rounding issues
            logger.warning(f"Rate limited. Waiting {actual_sleep}s before retry [Retry {nb_tries}/{max_retries}].")
        else:
            actual_sleep = sleep_time
            logger.warning(f"Retrying in {actual_sleep}s [Retry {nb_tries}/{max_retries}].")
        time.sleep(actual_sleep)

        # Update sleep time for next retry
        sleep_time = min(max_wait_time, sleep_time * 2)  # Exponential backoff
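
Net effect of this hunk: a 429 response with parseable ratelimit headers sleeps until the server-advertised reset instead of following the exponential schedule. A toy model of the sleep decision (standalone sketch, not the actual helper):

```python
from typing import Optional


def next_sleep(sleep_time: float, ratelimit_reset: Optional[int], max_wait_time: float = 8) -> tuple[float, float]:
    """Return (seconds to sleep now, sleep_time to use on the next retry)."""
    if ratelimit_reset is not None:
        actual_sleep = float(ratelimit_reset) + 1  # +1s to avoid rounding issues
    else:
        actual_sleep = sleep_time
    return actual_sleep, min(max_wait_time, sleep_time * 2)  # exponential backoff, capped


assert next_sleep(1.0, None) == (1.0, 2.0)  # normal retry path: 1s, 2s, 4s, 8s, 8s, ...
assert next_sleep(4.0, 30) == (31.0, 8.0)   # rate limited: honor the 30s reset (+1s margin)
```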
@@ -341,11 +445,8 @@ def http_backoff(
    max_retries: int = 5,
    base_wait_time: float = 1,
    max_wait_time: float = 8,
    retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
        httpx.TimeoutException,
        httpx.NetworkError,
    ),
    retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
    retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
    retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
    **kwargs,
) -> httpx.Response:
    """Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
@@ -374,9 +475,9 @@ def http_backoff(
        retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
            Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
            By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
        retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
            Define on which status codes the request must be retried. By default, only
            HTTP 503 Service Unavailable is retried.
        retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `(429, 500, 502, 503, 504)`):
            Define on which status codes the request must be retried. By default, retries
            on rate limit (429) and server errors (5xx).
        **kwargs (`dict`, *optional*):
            kwargs to pass to `httpx.request`.
@@ -425,11 +526,8 @@ def http_stream_backoff(
    max_retries: int = 5,
    base_wait_time: float = 1,
    max_wait_time: float = 8,
    retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
        httpx.TimeoutException,
        httpx.NetworkError,
    ),
    retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
    retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
    retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
    **kwargs,
) -> Generator[httpx.Response, None, None]:
    """Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
@@ -457,10 +555,10 @@ def http_stream_backoff(
            Maximum duration (in seconds) to wait before retrying.
        retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
            Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
            By default, retry on `httpx.Timeout` and `httpx.NetworkError`.
        retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
            Define on which status codes the request must be retried. By default, only
            HTTP 503 Service Unavailable is retried.
            By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
        retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `(429, 500, 502, 503, 504)`):
            Define on which status codes the request must be retried. By default, retries
            on rate limit (429) and server errors (5xx).
        **kwargs (`dict`, *optional*):
            kwargs to pass to `httpx.request`.
@@ -549,6 +647,12 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
    > - [`~utils.HfHubHTTPError`]
    >   If request failed for a reason not listed above.
    """
    try:
        _warn_on_warning_headers(response)
    except Exception:
        # Never raise on warning parsing
        logger.debug("Failed to parse warning headers", exc_info=True)

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
@@ -619,6 +723,25 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
            )
            raise _format(HfHubHTTPError, message, response) from e
        elif response.status_code == 429:
            ratelimit_info = parse_ratelimit_headers(response.headers)
            if ratelimit_info is not None:
                message = (
                    f"\n\n429 Too Many Requests: you have reached your '{ratelimit_info.resource_type}' rate limit."
                )
                message += f"\nRetry after {ratelimit_info.reset_in_seconds} seconds"
                if ratelimit_info.limit is not None and ratelimit_info.window_seconds is not None:
                    message += (
                        f" ({ratelimit_info.remaining}/{ratelimit_info.limit} requests remaining"
                        f" in current {ratelimit_info.window_seconds}s window)."
                    )
                else:
                    message += "."
                message += f"\nUrl: {response.url}."
            else:
                message = f"\n\n429 Too Many Requests for url: {response.url}."
            raise _format(HfHubHTTPError, message, response) from e
        elif response.status_code == 416:
            range_header = response.request.headers.get("Range")
            message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
@@ -629,6 +752,33 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
        raise _format(HfHubHTTPError, str(e), response) from e
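
With the hypothetical header values from the `parse_ratelimit_headers` doctest (r=0, t=55, q=500, w=300; URL made up), the 429 branch above surfaces an error message along these lines:

```
429 Too Many Requests: you have reached your 'api' rate limit.
Retry after 55 seconds (0/500 requests remaining in current 300s window).
Url: https://huggingface.co/api/models.
```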

_WARNED_TOPICS = set()


def _warn_on_warning_headers(response: httpx.Response) -> None:
    """
    Emit warnings if warning headers are present in the HTTP response.

    Expected header format: 'X-HF-Warning: topic; message'

    Only the first warning for each topic will be shown. Topic is optional and can be empty. Note that several
    warning headers can be present in a single response.

    Args:
        response (`httpx.Response`):
            The HTTP response to check for warning headers.
    """
    server_warnings = response.headers.get_list("X-HF-Warning")
    for server_warning in server_warnings:
        topic, message = server_warning.split(";", 1) if ";" in server_warning else ("", server_warning)
        topic = topic.strip()
        if topic not in _WARNED_TOPICS:
            message = message.strip()
            if message:
                _WARNED_TOPICS.add(topic)
                logger.warning(message)

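A short sketch of the deduplication behavior (header values hypothetical): warnings sharing a topic are logged once per process, and a warning without a `;` separator falls back to the empty-string topic.

```python
import httpx

response = httpx.Response(
    200,
    headers=[
        ("X-HF-Warning", "deprecation; Model X is scheduled for removal"),
        ("X-HF-Warning", "deprecation; Model X has a replacement available"),  # same topic: skipped
        ("X-HF-Warning", "A warning without any topic"),  # topic falls back to ""
    ],
)
_warn_on_warning_headers(response)  # logs the first and third messages only
```
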
def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError:
    server_errors = []

View File

@@ -42,7 +42,7 @@ def paginate(path: str, params: dict, headers: dict) -> Iterable:
    next_page = _get_next_page(r)
    while next_page is not None:
        logger.debug(f"Pagination detected. Requesting next page: {next_page}")
        r = http_backoff("GET", next_page, max_retries=20, retry_on_status_codes=429, headers=headers)
        r = http_backoff("GET", next_page, headers=headers)
        hf_raise_for_status(r)
        yield from r.json()
        next_page = _get_next_page(r)
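
Context for this simplification: with the new `_DEFAULT_RETRY_ON_STATUS_CODES` shown earlier, a bare `http_backoff` call already retries 429 (sleeping until the advertised reset) as well as 5xx, so the per-call opt-in is no longer needed. Note the retry budget also drops from 20 to the default 5, presumably acceptable now that 429 waits honor the server's reset time. A minimal sketch (URL hypothetical):

```python
from huggingface_hub.utils import http_backoff

# Defaults now cover rate limiting: retry_on_status_codes=(429, 500, 502, 503, 504)
r = http_backoff("GET", "https://huggingface.co/api/models?cursor=abc123")
r.raise_for_status()
```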

View File

@@ -64,7 +64,7 @@ def send_telemetry(
    ...     )
    ```
    """
    if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY:
    if constants.is_offline_mode() or constants.HF_HUB_DISABLE_TELEMETRY:
        return

    _start_telemetry_thread()  # starts thread only if doesn't exist yet

View File

@@ -22,12 +22,18 @@ class ANSI:
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    _blue = "\u001b[34m"
    _bold = "\u001b[1m"
    _gray = "\u001b[90m"
    _green = "\u001b[32m"
    _red = "\u001b[31m"
    _reset = "\u001b[0m"
    _yellow = "\u001b[33m"

    @classmethod
    def blue(cls, s: str) -> str:
        return cls._format(s, cls._blue)

    @classmethod
    def bold(cls, s: str) -> str:
        return cls._format(s, cls._bold)
@@ -36,6 +42,10 @@ class ANSI:
    def gray(cls, s: str) -> str:
        return cls._format(s, cls._gray)

    @classmethod
    def green(cls, s: str) -> str:
        return cls._format(s, cls._green)

    @classmethod
    def red(cls, s: str) -> str:
        return cls._format(s, cls._bold + cls._red)
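
A usage sketch for the new helpers (assuming the class as shown; note that `red` deliberately combines bold with red, and `_format` presumably appends `_reset` after the string):

```python
# Each helper wraps the string in its color code before printing.
print(ANSI.blue("info:"), "checking credentials")
print(ANSI.green("success:"), "token is valid")
print(ANSI.red("error:"), "token expired")  # rendered bold + red
```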