Add adaptation for the orbit reconnaissance scenario
@@ -46,7 +46,7 @@ import sys

from typing import TYPE_CHECKING

-__version__ = "1.1.4"
+__version__ = "1.2.4"

# Alphabetical order of definitions is ensured in tests
# WARNING: any comment added in this dictionary definition will be lost when
@@ -134,6 +134,7 @@ _SUBMOD_ATTRS = {
"REPO_TYPE_SPACE",
"TF2_WEIGHTS_NAME",
"TF_WEIGHTS_NAME",
+"is_offline_mode",
],
"fastai_utils": [
"_save_pretrained_fastai",
@@ -164,6 +165,8 @@ _SUBMOD_ATTRS = {
"HfApi",
"ModelInfo",
"Organization",
"RepoFile",
"RepoFolder",
"RepoUrl",
"SpaceInfo",
"User",
@@ -230,6 +233,7 @@ _SUBMOD_ATTRS = {
"inspect_scheduled_job",
"list_accepted_access_requests",
"list_collections",
+"list_daily_papers",
"list_datasets",
"list_inference_catalog",
"list_inference_endpoints",
@@ -296,6 +300,7 @@ _SUBMOD_ATTRS = {
"HfFileSystemFile",
"HfFileSystemResolvedPath",
"HfFileSystemStreamFile",
+"hffs",
],
"hub_mixin": [
"ModelHubMixin",
@@ -708,6 +713,8 @@ __all__ = [
"REPO_TYPE_MODEL",
"REPO_TYPE_SPACE",
"RepoCard",
"RepoFile",
"RepoFolder",
"RepoUrl",
"SentenceSimilarityInput",
"SentenceSimilarityInputData",
@@ -883,11 +890,14 @@ __all__ = [
"hf_hub_download",
"hf_hub_url",
"hf_raise_for_status",
+"hffs",
"inspect_job",
"inspect_scheduled_job",
"interpreter_login",
+"is_offline_mode",
"list_accepted_access_requests",
"list_collections",
+"list_daily_papers",
"list_datasets",
"list_inference_catalog",
"list_inference_endpoints",
@@ -1150,6 +1160,7 @@ if TYPE_CHECKING: # pragma: no cover
REPO_TYPE_SPACE, # noqa: F401
TF2_WEIGHTS_NAME, # noqa: F401
TF_WEIGHTS_NAME, # noqa: F401
+is_offline_mode, # noqa: F401
)
from .fastai_utils import (
_save_pretrained_fastai, # noqa: F401
@@ -1180,6 +1191,8 @@ if TYPE_CHECKING: # pragma: no cover
HfApi, # noqa: F401
ModelInfo, # noqa: F401
Organization, # noqa: F401
RepoFile, # noqa: F401
RepoFolder, # noqa: F401
RepoUrl, # noqa: F401
SpaceInfo, # noqa: F401
User, # noqa: F401
@@ -1246,6 +1259,7 @@ if TYPE_CHECKING: # pragma: no cover
inspect_scheduled_job, # noqa: F401
list_accepted_access_requests, # noqa: F401
list_collections, # noqa: F401
+list_daily_papers, # noqa: F401
list_datasets, # noqa: F401
list_inference_catalog, # noqa: F401
list_inference_endpoints, # noqa: F401
@@ -1312,6 +1326,7 @@ if TYPE_CHECKING: # pragma: no cover
HfFileSystemFile, # noqa: F401
HfFileSystemResolvedPath, # noqa: F401
HfFileSystemStreamFile, # noqa: F401
+hffs, # noqa: F401
)
from .hub_mixin import (
ModelHubMixin, # noqa: F401
@@ -27,6 +27,7 @@ from .utils import (
fetch_xet_connection_info_from_repo_info,
get_session,
hf_raise_for_status,
+http_backoff,
logging,
sha,
tqdm_stream_file,
@@ -739,7 +740,8 @@ def _fetch_upload_modes(
if gitignore_content is not None:
payload["gitIgnore"] = gitignore_content

-resp = get_session().post(
+resp = http_backoff(
+"POST",
f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
json=payload,
headers=headers,
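The switch from `get_session().post(...)` to `http_backoff("POST", ...)` means the preupload call is now retried on transient failures (the docstrings elsewhere in this commit describe the default retry set as 429, 5xx, timeouts and network errors). A minimal sketch of the pattern, assuming `http_backoff` forwards `json`/`headers` to the underlying request; all values below are illustrative placeholders:

    from huggingface_hub.utils import hf_raise_for_status, http_backoff

    # Hypothetical values, for illustration only.
    endpoint, repo_type, repo_id, revision = "https://huggingface.co", "model", "user/repo", "main"
    payload = {"files": []}
    headers = {}

    # POST with automatic backoff/retry instead of a single-shot session call.
    resp = http_backoff(
        "POST",
        f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
        json=payload,
        headers=headers,
    )
    hf_raise_for_status(resp)
    print(resp.status_code)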
@@ -34,6 +34,11 @@ class InferenceEndpointType(str, Enum):
PRIVATE = "private"


+class InferenceEndpointScalingMetric(str, Enum):
+PENDING_REQUESTS = "pendingRequests"
+HARDWARE_USAGE = "hardwareUsage"


@dataclass
class InferenceEndpoint:
"""
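A rough sketch of how the new enum combines with the `scaling_metric` / `scaling_threshold` parameters that this commit adds to `HfApi.create_inference_endpoint`. Only the scaling-related arguments come from this diff; every other argument below is an illustrative placeholder:

    from huggingface_hub import HfApi
    from huggingface_hub._inference_endpoints import InferenceEndpointScalingMetric

    api = HfApi()
    # Scale on queued requests; trigger a scale-up once more than 5 requests are pending.
    endpoint = api.create_inference_endpoint(
        "my-endpoint-name",               # placeholder endpoint name
        repository="gpt2",                # placeholder model repo
        framework="pytorch",              # placeholder deployment settings
        accelerator="cpu",
        instance_size="x2",
        instance_type="intel-icl",
        region="us-east-1",
        vendor="aws",
        min_replica=0,
        max_replica=2,
        scale_to_zero_timeout=15,
        scaling_metric=InferenceEndpointScalingMetric.PENDING_REQUESTS,
        scaling_threshold=5.0,
    )
    print(endpoint.name)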
@@ -308,6 +308,7 @@ def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDown
paths.metadata_path.unlink()
except Exception as e:
logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}")
+return None

try:
# check if the file exists and hasn't been modified since the metadata was saved
@@ -383,6 +384,9 @@ def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetad
except Exception as e:
logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}")

+# corrupted metadata => we don't know anything except its size
+return LocalUploadFileMetadata(size=paths.file_path.stat().st_size)

# TODO: can we do better?
if (
metadata.timestamp is not None
@@ -192,7 +192,11 @@ def auth_list() -> None:
tokens = get_stored_tokens()

if not tokens:
-logger.info("No access tokens found.")
+if _get_token_from_environment():
+logger.info("No stored access tokens found.")
+logger.warning("Note: Environment variable `HF_TOKEN` is set and is the current active token.")
+else:
+logger.info("No access tokens found.")
return
# Find current token
current_token = get_token()
@@ -23,7 +23,7 @@ from .utils import tqdm as hf_tqdm

logger = logging.get_logger(__name__)

-VERY_LARGE_REPO_THRESHOLD = 50000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough
+LARGE_REPO_THRESHOLD = 1000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough


@overload
@@ -335,9 +335,7 @@ def snapshot_download(
# In that case, we need to use the `list_repo_tree` method to prevent caching issues.
repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else []
unreliable_nb_files = (
-repo_info.siblings is None
-or len(repo_info.siblings) == 0
-or len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
+repo_info.siblings is None or len(repo_info.siblings) == 0 or len(repo_info.siblings) > LARGE_REPO_THRESHOLD
)
if unreliable_nb_files:
logger.info(
@@ -191,7 +191,7 @@ def upload_large_folder_internal(

if num_workers is None:
nb_cores = os.cpu_count() or 1
-num_workers = max(nb_cores - 2, 2) # Use all but 2 cores, or at least 2 cores
+num_workers = max(nb_cores // 2, 1) # Use at most half of cpu cores

# 2. Create repo if missing
repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True)
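A quick worked check of the new default worker count (a sketch; `os.cpu_count()` may return `None`, hence the `or 1` guard):

    import os

    nb_cores = os.cpu_count() or 1    # e.g. 8 cores
    workers = max(nb_cores // 2, 1)   # -> 4 workers on an 8-core machine, never fewer than 1
    print(workers)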
@@ -96,7 +96,7 @@ TokenOpt = Annotated[
]

PrivateOpt = Annotated[
-bool,
+Optional[bool],
typer.Option(
help="Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already exists.",
),
@@ -144,6 +144,7 @@ def _check_cli_update() -> None:
return

# Touch the file to mark that we did the check now
Path(constants.CHECK_FOR_UPDATE_DONE_PATH).parent.mkdir(parents=True, exist_ok=True)
Path(constants.CHECK_FOR_UPDATE_DONE_PATH).touch()

# Check latest version from PyPI
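With `PrivateOpt` now `Optional[bool]` defaulting to `None`, omitting `--private` no longer forces a public repo. A sketch of the equivalent Python call; reading `private=None` as "let the Hub apply the namespace's default visibility" is my interpretation, not something stated in this diff, and the repo id is a placeholder:

    from huggingface_hub import HfApi

    api = HfApi()
    # private left as None: the Hub decides visibility instead of the CLI
    # forcing private=False as it previously did.
    api.create_repo("my-user/my-new-repo", private=None, exist_ok=True)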
@@ -13,6 +13,7 @@
# limitations under the License.


+from huggingface_hub import constants
from huggingface_hub.cli._cli_utils import check_cli_update, typer_factory
from huggingface_hub.cli.auth import auth_cli
from huggingface_hub.cli.cache import cache_cli
@@ -51,7 +52,8 @@ app.add_typer(ie_cli, name="endpoints")


def main():
-logging.set_verbosity_info()
+if not constants.HF_DEBUG:
+logging.set_verbosity_info()
check_cli_update()
app()

@@ -5,7 +5,7 @@ from typing import Annotated, Optional
|
||||
|
||||
import typer
|
||||
|
||||
from huggingface_hub._inference_endpoints import InferenceEndpoint
|
||||
from huggingface_hub._inference_endpoints import InferenceEndpoint, InferenceEndpointScalingMetric
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from ._cli_utils import TokenOpt, get_hf_api, typer_factory
|
||||
@@ -112,6 +112,36 @@ def deploy(
|
||||
),
|
||||
] = None,
|
||||
token: TokenOpt = None,
|
||||
min_replica: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
help="The minimum number of replicas (instances) to keep running for the Inference Endpoint.",
|
||||
),
|
||||
] = 1,
|
||||
max_replica: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
help="The maximum number of replicas (instances) to scale to for the Inference Endpoint.",
|
||||
),
|
||||
] = 1,
|
||||
scale_to_zero_timeout: Annotated[
|
||||
Optional[int],
|
||||
typer.Option(
|
||||
help="The duration in minutes before an inactive endpoint is scaled to zero.",
|
||||
),
|
||||
] = None,
|
||||
scaling_metric: Annotated[
|
||||
Optional[InferenceEndpointScalingMetric],
|
||||
typer.Option(
|
||||
help="The metric reference for scaling.",
|
||||
),
|
||||
] = None,
|
||||
scaling_threshold: Annotated[
|
||||
Optional[float],
|
||||
typer.Option(
|
||||
help="The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.",
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Deploy an Inference Endpoint from a Hub repository."""
|
||||
api = get_hf_api(token=token)
|
||||
@@ -127,6 +157,11 @@ def deploy(
|
||||
namespace=namespace,
|
||||
task=task,
|
||||
token=token,
|
||||
min_replica=min_replica,
|
||||
max_replica=max_replica,
|
||||
scaling_metric=scaling_metric,
|
||||
scaling_threshold=scaling_threshold,
|
||||
scale_to_zero_timeout=scale_to_zero_timeout,
|
||||
)
|
||||
|
||||
_print_endpoint(endpoint)
|
||||
@@ -262,6 +297,18 @@ def update(
|
||||
help="The duration in minutes before an inactive endpoint is scaled to zero.",
|
||||
),
|
||||
] = None,
|
||||
scaling_metric: Annotated[
|
||||
Optional[InferenceEndpointScalingMetric],
|
||||
typer.Option(
|
||||
help="The metric reference for scaling.",
|
||||
),
|
||||
] = None,
|
||||
scaling_threshold: Annotated[
|
||||
Optional[float],
|
||||
typer.Option(
|
||||
help="The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.",
|
||||
),
|
||||
] = None,
|
||||
token: TokenOpt = None,
|
||||
) -> None:
|
||||
"""Update an existing endpoint."""
|
||||
@@ -280,6 +327,8 @@ def update(
|
||||
min_replica=min_replica,
|
||||
max_replica=max_replica,
|
||||
scale_to_zero_timeout=scale_to_zero_timeout,
|
||||
scaling_metric=scaling_metric,
|
||||
scaling_threshold=scaling_threshold,
|
||||
token=token,
|
||||
)
|
||||
except HfHubHTTPError as error:
|
||||
|
||||
@@ -66,7 +66,7 @@ def repo_create(
|
||||
help="Hugging Face Spaces SDK type. Required when --type is set to 'space'.",
|
||||
),
|
||||
] = None,
|
||||
private: PrivateOpt = False,
|
||||
private: PrivateOpt = None,
|
||||
token: TokenOpt = None,
|
||||
exist_ok: Annotated[
|
||||
bool,
|
||||
|
||||
@@ -80,7 +80,7 @@ def upload(
|
||||
] = None,
|
||||
repo_type: RepoTypeOpt = RepoType.model,
|
||||
revision: RevisionOpt = None,
|
||||
private: PrivateOpt = False,
|
||||
private: PrivateOpt = None,
|
||||
include: Annotated[
|
||||
Optional[list[str]],
|
||||
typer.Option(
|
||||
|
||||
@@ -38,7 +38,7 @@ def upload_large_folder(
|
||||
],
|
||||
repo_type: RepoTypeOpt = RepoType.model,
|
||||
revision: RevisionOpt = None,
|
||||
private: PrivateOpt = False,
|
||||
private: PrivateOpt = None,
|
||||
include: Annotated[
|
||||
Optional[list[str]],
|
||||
typer.Option(
|
||||
|
||||
@@ -162,6 +162,26 @@ HF_ASSETS_CACHE = os.path.expandvars(

HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE"))


+def is_offline_mode() -> bool:
+"""Returns whether we are in offline mode for the Hub.
+
+When offline mode is enabled, all HTTP requests made with `get_session` will raise an `OfflineModeIsEnabled` exception.
+
+Example:
+```py
+from huggingface_hub import is_offline_mode
+
+def list_files(repo_id: str):
+if is_offline_mode():
+... # list files from local cache (degraded experience but still functional)
+else:
+... # list files from Hub (complete experience)
+```
+"""
+return HF_HUB_OFFLINE
+

# File created to mark that the version check has been done.
# Check is performed once per 24 hours at most.
CHECK_FOR_UPDATE_DONE_PATH = os.path.join(HF_HOME, ".check_for_update_done")
@@ -316,7 +316,7 @@ def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
|
||||
base, *meta = get_args(value)
|
||||
if not _is_required_or_notrequired(base):
|
||||
base = NotRequired[base]
|
||||
type_hints[key] = Annotated[tuple([base] + list(meta))]
|
||||
type_hints[key] = Annotated[tuple([base] + list(meta))] # type: ignore
|
||||
elif not _is_required_or_notrequired(value):
|
||||
type_hints[key] = NotRequired[value]
|
||||
|
||||
|
||||
@@ -84,9 +84,15 @@ class HfHubHTTPError(HTTPError, OSError):
|
||||
"""Append additional information to the `HfHubHTTPError` initial message."""
|
||||
self.args = (self.args[0] + additional_message,) + self.args[1:]
|
||||
|
||||
@classmethod
|
||||
def _reconstruct_hf_hub_http_error(
|
||||
cls, message: str, response: Response, server_message: Optional[str]
|
||||
) -> "HfHubHTTPError":
|
||||
return cls(message, response=response, server_message=server_message)
|
||||
|
||||
def __reduce_ex__(self, protocol):
|
||||
"""Fix pickling of Exception subclass with kwargs. We need to override __reduce_ex__ of the parent class"""
|
||||
return (self.__class__, (str(self),), {"response": self.response, "server_message": self.server_message})
|
||||
return (self.__class__._reconstruct_hf_hub_http_error, (str(self), self.response, self.server_message))
|
||||
|
||||
|
||||
# INFERENCE CLIENT ERRORS
|
||||
|
||||
@@ -39,7 +39,13 @@ from .utils import (
|
||||
tqdm,
|
||||
validate_hf_hub_args,
|
||||
)
|
||||
from .utils._http import _adjust_range_header, http_backoff, http_stream_backoff
|
||||
from .utils._http import (
|
||||
_DEFAULT_RETRY_ON_EXCEPTIONS,
|
||||
_DEFAULT_RETRY_ON_STATUS_CODES,
|
||||
_adjust_range_header,
|
||||
http_backoff,
|
||||
http_stream_backoff,
|
||||
)
|
||||
from .utils._runtime import is_xet_available
|
||||
from .utils._typing import HTTP_METHOD_T
|
||||
from .utils.sha import sha_fileobj
|
||||
@@ -63,6 +69,9 @@ REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$")
|
||||
|
||||
_are_symlinks_supported_in_dir: dict[str, bool] = {}
|
||||
|
||||
# Internal retry timeout for metadata fetch when no local file exists
|
||||
_ETAG_RETRY_TIMEOUT = 60
|
||||
|
||||
|
||||
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
|
||||
"""Return whether the symlinks are supported on the machine.
|
||||
@@ -264,30 +273,38 @@ def hf_hub_url(
|
||||
return url
|
||||
|
||||
|
||||
def _httpx_follow_relative_redirects(method: HTTP_METHOD_T, url: str, **httpx_kwargs) -> httpx.Response:
|
||||
def _httpx_follow_relative_redirects(
|
||||
method: HTTP_METHOD_T, url: str, *, retry_on_errors: bool = False, **httpx_kwargs
|
||||
) -> httpx.Response:
|
||||
"""Perform an HTTP request with backoff and follow relative redirects only.
|
||||
|
||||
This is useful to follow a redirection to a renamed repository without following redirection to a CDN.
|
||||
|
||||
A backoff mechanism retries the HTTP call on 5xx errors and network errors.
|
||||
A backoff mechanism retries the HTTP call on errors (429, 5xx, timeout, network errors).
|
||||
|
||||
Args:
|
||||
method (`str`):
|
||||
HTTP method, such as 'GET' or 'HEAD'.
|
||||
url (`str`):
|
||||
The URL of the resource to fetch.
|
||||
retry_on_errors (`bool`, *optional*, defaults to `False`):
|
||||
Whether to retry on errors. If False, no retry is performed (fast fallback to local cache).
|
||||
If True, uses default retry behavior (429, 5xx, timeout, network errors).
|
||||
**httpx_kwargs (`dict`, *optional*):
|
||||
Params to pass to `httpx.request`.
|
||||
"""
|
||||
# if `retry_on_errors=False`, disable all retries for fast fallback to cache
|
||||
no_retry_kwargs: dict[str, Any] = (
|
||||
{} if retry_on_errors else {"retry_on_exceptions": (), "retry_on_status_codes": ()}
|
||||
)
|
||||
|
||||
while True:
|
||||
# Make the request
|
||||
response = http_backoff(
|
||||
method=method,
|
||||
url=url,
|
||||
**httpx_kwargs,
|
||||
follow_redirects=False,
|
||||
retry_on_exceptions=(),
|
||||
retry_on_status_codes=(429,),
|
||||
**no_retry_kwargs,
|
||||
)
|
||||
hf_raise_for_status(response)
|
||||
|
||||
@@ -1131,8 +1148,31 @@ def _hf_hub_download_to_cache_dir(
|
||||
if not force_download:
|
||||
return pointer_path
|
||||
|
||||
# Otherwise, raise appropriate error
|
||||
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
|
||||
if isinstance(head_call_error, _DEFAULT_RETRY_ON_EXCEPTIONS) or (
|
||||
isinstance(head_call_error, HfHubHTTPError)
|
||||
and head_call_error.response.status_code in _DEFAULT_RETRY_ON_STATUS_CODES
|
||||
):
|
||||
logger.info("No local file found. Retrying..")
|
||||
(url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = (
|
||||
_get_metadata_or_catch_error(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
repo_type=repo_type,
|
||||
revision=revision,
|
||||
endpoint=endpoint,
|
||||
etag_timeout=_ETAG_RETRY_TIMEOUT,
|
||||
headers=headers,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
storage_folder=storage_folder,
|
||||
relative_filename=relative_filename,
|
||||
retry_on_errors=True,
|
||||
)
|
||||
)
|
||||
|
||||
# If still error, raise
|
||||
if head_call_error is not None:
|
||||
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
|
||||
|
||||
# From now on, etag, commit_hash, url and size are not None.
|
||||
assert etag is not None, "etag must have been retrieved from server"
|
||||
@@ -1300,9 +1340,30 @@ def _hf_hub_download_to_local_dir(
|
||||
)
|
||||
if not force_download:
|
||||
return local_path
|
||||
elif not force_download:
|
||||
if isinstance(head_call_error, _DEFAULT_RETRY_ON_EXCEPTIONS) or (
|
||||
isinstance(head_call_error, HfHubHTTPError)
|
||||
and head_call_error.response.status_code in _DEFAULT_RETRY_ON_STATUS_CODES
|
||||
):
|
||||
logger.info("No local file found. Retrying..")
|
||||
(url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = (
|
||||
_get_metadata_or_catch_error(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
repo_type=repo_type,
|
||||
revision=revision,
|
||||
endpoint=endpoint,
|
||||
etag_timeout=_ETAG_RETRY_TIMEOUT,
|
||||
headers=headers,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
retry_on_errors=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Otherwise => raise
|
||||
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
|
||||
# If still error, raise
|
||||
if head_call_error is not None:
|
||||
_raise_on_head_call_error(head_call_error, force_download, local_files_only)
|
||||
|
||||
# From now on, etag, commit_hash, url and size are not None.
|
||||
assert etag is not None, "etag must have been retrieved from server"
|
||||
@@ -1501,12 +1562,13 @@ def try_to_load_from_cache(
|
||||
def get_hf_file_metadata(
|
||||
url: str,
|
||||
token: Union[bool, str, None] = None,
|
||||
timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
|
||||
timeout: Optional[float] = constants.HF_HUB_ETAG_TIMEOUT,
|
||||
library_name: Optional[str] = None,
|
||||
library_version: Optional[str] = None,
|
||||
user_agent: Union[dict, str, None] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
retry_on_errors: bool = False,
|
||||
) -> HfFileMetadata:
|
||||
"""Fetch metadata of a file versioned on the Hub for a given url.
|
||||
|
||||
@@ -1531,6 +1593,9 @@ def get_hf_file_metadata(
|
||||
Additional headers to be sent with the request.
|
||||
endpoint (`str`, *optional*):
|
||||
Endpoint of the Hub. Defaults to <https://huggingface.co>.
|
||||
retry_on_errors (`bool`, *optional*, defaults to `False`):
|
||||
Whether to retry on errors (429, 5xx, timeout, network errors).
|
||||
If False, no retry for fast fallback to local cache.
|
||||
|
||||
Returns:
|
||||
A [`HfFileMetadata`] object containing metadata such as location, etag, size and
|
||||
@@ -1546,7 +1611,9 @@ def get_hf_file_metadata(
|
||||
hf_headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file
|
||||
|
||||
# Retrieve metadata
|
||||
response = _httpx_follow_relative_redirects(method="HEAD", url=url, headers=hf_headers, timeout=timeout)
|
||||
response = _httpx_follow_relative_redirects(
|
||||
method="HEAD", url=url, headers=hf_headers, timeout=timeout, retry_on_errors=retry_on_errors
|
||||
)
|
||||
hf_raise_for_status(response)
|
||||
|
||||
# Return
|
||||
@@ -1579,6 +1646,7 @@ def _get_metadata_or_catch_error(
|
||||
local_files_only: bool,
|
||||
relative_filename: Optional[str] = None, # only used to store `.no_exists` in cache
|
||||
storage_folder: Optional[str] = None, # only used to store `.no_exists` in cache
|
||||
retry_on_errors: bool = False,
|
||||
) -> Union[
|
||||
# Either an exception is caught and returned
|
||||
tuple[None, None, None, None, None, Exception],
|
||||
@@ -1621,7 +1689,12 @@ def _get_metadata_or_catch_error(
|
||||
try:
|
||||
try:
|
||||
metadata = get_hf_file_metadata(
|
||||
url=url, timeout=etag_timeout, headers=headers, token=token, endpoint=endpoint
|
||||
url=url,
|
||||
timeout=etag_timeout,
|
||||
headers=headers,
|
||||
token=token,
|
||||
endpoint=endpoint,
|
||||
retry_on_errors=retry_on_errors,
|
||||
)
|
||||
except RemoteEntryNotFoundError as http_error:
|
||||
if storage_folder is not None and relative_filename is not None:
|
||||
|
||||
@@ -59,7 +59,7 @@ from ._commit_api import (
|
||||
_upload_files,
|
||||
_warn_on_overwriting_operations,
|
||||
)
|
||||
from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType
|
||||
from ._inference_endpoints import InferenceEndpoint, InferenceEndpointScalingMetric, InferenceEndpointType
|
||||
from ._jobs_api import JobInfo, JobSpec, ScheduledJobInfo, _create_job_spec
|
||||
from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable
|
||||
from ._upload_large_folder import upload_large_folder_internal
|
||||
@@ -75,6 +75,7 @@ from .errors import (
|
||||
BadRequestError,
|
||||
GatedRepoError,
|
||||
HfHubHTTPError,
|
||||
LocalTokenNotFoundError,
|
||||
RemoteEntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
@@ -194,6 +195,7 @@ ExpandSpaceProperty_T = Literal[
|
||||
|
||||
USERNAME_PLACEHOLDER = "hf_user"
|
||||
_REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$")
|
||||
_REGEX_HTTP_PROTOCOL = re.compile(r"https?://")
|
||||
|
||||
_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = (
|
||||
"\nNote: Creating a commit assumes that the repo already exists on the"
|
||||
@@ -238,28 +240,62 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
|
||||
"""
|
||||
input_hf_id = hf_id
|
||||
|
||||
hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT)
|
||||
is_hf_url = hub_url in hf_id and "@" not in hf_id
|
||||
# Get the hub_url (with or without protocol)
|
||||
full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT
|
||||
hub_url_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", full_hub_url)
|
||||
|
||||
# Check if hf_id is a URL containing the hub_url (check both with and without protocol)
|
||||
hf_id_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", hf_id)
|
||||
is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id
|
||||
|
||||
HFFS_PREFIX = "hf://"
|
||||
if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists
|
||||
hf_id = hf_id[len(HFFS_PREFIX) :]
|
||||
|
||||
# If it's a URL, strip the endpoint prefix to get the path
|
||||
if is_hf_url:
|
||||
# Remove protocol if present
|
||||
hf_id_normalized = _REGEX_HTTP_PROTOCOL.sub("", hf_id)
|
||||
|
||||
# Remove the hub_url prefix to get the relative path
|
||||
if hf_id_normalized.startswith(hub_url_without_protocol):
|
||||
# Strip the hub URL and any leading slashes
|
||||
hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/")
|
||||
|
||||
url_segments = hf_id.split("/")
|
||||
is_hf_id = len(url_segments) <= 3
|
||||
|
||||
namespace: Optional[str]
|
||||
if is_hf_url:
|
||||
namespace, repo_id = url_segments[-2:]
|
||||
if namespace == hub_url:
|
||||
namespace = None
|
||||
if len(url_segments) > 2 and hub_url not in url_segments[-3]:
|
||||
repo_type = url_segments[-3]
|
||||
elif namespace in constants.REPO_TYPES_MAPPING:
|
||||
# Mean canonical dataset or model
|
||||
repo_type = constants.REPO_TYPES_MAPPING[namespace]
|
||||
namespace = None
|
||||
# For URLs, we need to extract repo_type, namespace, repo_id
|
||||
# Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id
|
||||
|
||||
if len(url_segments) >= 3:
|
||||
# Check if first segment is a repo type
|
||||
if url_segments[0] in constants.REPO_TYPES_MAPPING:
|
||||
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
|
||||
namespace = url_segments[1]
|
||||
repo_id = url_segments[2]
|
||||
else:
|
||||
# First segment is namespace
|
||||
namespace = url_segments[0]
|
||||
repo_id = url_segments[1]
|
||||
repo_type = None
|
||||
elif len(url_segments) == 2:
|
||||
namespace = url_segments[0]
|
||||
repo_id = url_segments[1]
|
||||
|
||||
# Check if namespace is actually a repo type mapping
|
||||
if namespace in constants.REPO_TYPES_MAPPING:
|
||||
# Mean canonical dataset or model
|
||||
repo_type = constants.REPO_TYPES_MAPPING[namespace]
|
||||
namespace = None
|
||||
else:
|
||||
repo_type = None
|
||||
else:
|
||||
# Single segment
|
||||
repo_id = url_segments[0]
|
||||
namespace = None
|
||||
repo_type = None
|
||||
elif is_hf_id:
|
||||
if len(url_segments) == 3:
|
||||
@@ -1694,6 +1730,9 @@ class HfApi:
self.headers = headers
self._thread_pool: Optional[ThreadPoolExecutor] = None

+# /whoami-v2 is the only endpoint for which we may want to cache results
+self._whoami_cache: dict[str, dict] = {}
+
def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]:
"""
Run a method in the background and return a Future instance.
@@ -1735,22 +1774,51 @@ class HfApi:
return self._thread_pool.submit(fn, *args, **kwargs)

@validate_hf_hub_args
-def whoami(self, token: Union[bool, str, None] = None) -> dict:
+def whoami(self, token: Union[bool, str, None] = None, *, cache: bool = False) -> dict:
"""
Call HF API to know "whoami".

+If passing `cache=True`, the result will be cached for subsequent calls for the duration of the Python process. This is useful if you plan to call
+`whoami` multiple times as this endpoint is heavily rate-limited for security reasons.
+
Args:
token (`bool` or `str`, *optional*):
A valid user access token (string). Defaults to the locally saved
token, which is the recommended method for authentication (see
https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
To disable authentication, pass `False`.
+cache (`bool`, *optional*):
+Whether to cache the result of the `whoami` call for subsequent calls.
+If an error occurs during the first call, it won't be cached.
+Defaults to `False`.
"""
-# Get the effective token using the helper function get_token
-effective_token = token or self.token or get_token() or True
+token = self.token if token is None else token
+if token is False:
+raise ValueError("Cannot use `token=False` with `whoami` method as it requires authentication.")
+if token is True or token is None:
+token = get_token()
+if token is None:
+raise LocalTokenNotFoundError(
+"Token is required to call the /whoami-v2 endpoint, but no token found. You must provide a token or be logged in to "
+"Hugging Face with `hf auth login` or `huggingface_hub.login`. See https://huggingface.co/settings/tokens."
+)
+
+if cache and (cached_token := self._whoami_cache.get(token)):
+return cached_token
+
+# Call Hub
+output = self._inner_whoami(token=token)
+
+# Cache result and return
+if cache:
+self._whoami_cache[token] = output
+return output
+
+def _inner_whoami(self, token: str) -> dict:
r = get_session().get(
f"{self.endpoint}/api/whoami-v2",
-headers=self._build_hf_headers(token=effective_token),
+headers=self._build_hf_headers(token=token),
)
try:
hf_raise_for_status(r)
@@ -1758,16 +1826,22 @@ class HfApi:
if e.response.status_code == 401:
error_message = "Invalid user token."
# Check which token is the effective one and generate the error message accordingly
-if effective_token == _get_token_from_google_colab():
+if token == _get_token_from_google_colab():
error_message += " The token from Google Colab vault is invalid. Please update it from the UI."
-elif effective_token == _get_token_from_environment():
+elif token == _get_token_from_environment():
error_message += (
" The token from HF_TOKEN environment variable is invalid. "
"Note that HF_TOKEN takes precedence over `hf auth login`."
)
-elif effective_token == _get_token_from_file():
+elif token == _get_token_from_file():
error_message += " The token stored is invalid. Please run `hf auth login` to update it."
raise HfHubHTTPError(error_message, response=e.response) from e
+if e.response.status_code == 429:
+error_message = (
+"You've hit the rate limit for the /whoami-v2 endpoint, which is intentionally strict for security reasons."
+" If you're calling it often, consider caching the response with `whoami(..., cache=True)`."
+)
+raise HfHubHTTPError(error_message, response=e.response) from e
raise
return r.json()
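A sketch of the intended usage pattern for the new `cache` flag, based on the docstring above:

    from huggingface_hub import HfApi

    api = HfApi()

    # First call hits the /whoami-v2 endpoint; with cache=True the parsed response
    # is kept for the lifetime of the process, keyed by the token that was used.
    me = api.whoami(cache=True)

    # Subsequent calls with the same token return the cached dict instead of
    # re-hitting the heavily rate-limited endpoint.
    me_again = api.whoami(cache=True)
    print(me is me_again)  # True: the cached object is returned as-is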
@@ -3727,7 +3801,7 @@ class HfApi:
|
||||
self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token)
|
||||
if repo_type is None or repo_type == constants.REPO_TYPE_MODEL:
|
||||
return RepoUrl(f"{self.endpoint}/{repo_id}")
|
||||
return RepoUrl(f"{self.endpoint}/{repo_type}/{repo_id}")
|
||||
return RepoUrl(f"{self.endpoint}/{constants.REPO_TYPES_URL_PREFIXES[repo_type]}{repo_id}")
|
||||
except HfHubHTTPError:
|
||||
raise err
|
||||
else:
|
||||
@@ -5089,7 +5163,7 @@ class HfApi:
|
||||
ignore_patterns (`list[str]` or `str`, *optional*):
|
||||
If provided, files matching any of the patterns are not uploaded.
|
||||
num_workers (`int`, *optional*):
|
||||
Number of workers to start. Defaults to `os.cpu_count() - 2` (minimum 2).
|
||||
Number of workers to start. Defaults to half of CPU cores (minimum 1).
|
||||
A higher number of workers may speed up the process if your machine allows it. However, on machines with a
|
||||
slower connection, it is recommended to keep the number of workers low to ensure better resumability.
|
||||
Indeed, partially uploaded files will have to be completely re-uploaded if the process is interrupted.
|
||||
@@ -5169,7 +5243,7 @@ class HfApi:
|
||||
*,
|
||||
url: str,
|
||||
token: Union[bool, str, None] = None,
|
||||
timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
|
||||
timeout: Optional[float] = constants.HF_HUB_ETAG_TIMEOUT,
|
||||
) -> HfFileMetadata:
|
||||
"""Fetch metadata of a file versioned on the Hub for a given url.
|
||||
|
||||
@@ -7409,6 +7483,8 @@ class HfApi:
|
||||
account_id: Optional[str] = None,
|
||||
min_replica: int = 1,
|
||||
max_replica: int = 1,
|
||||
scaling_metric: Optional[InferenceEndpointScalingMetric] = None,
|
||||
scaling_threshold: Optional[float] = None,
|
||||
scale_to_zero_timeout: Optional[int] = None,
|
||||
revision: Optional[str] = None,
|
||||
task: Optional[str] = None,
|
||||
@@ -7449,6 +7525,12 @@ class HfApi:
|
||||
scaling to zero, set this value to 0 and adjust `scale_to_zero_timeout` accordingly. Defaults to 1.
|
||||
max_replica (`int`, *optional*):
|
||||
The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
|
||||
scaling_metric (`str` or [`InferenceEndpointScalingMetric `], *optional*):
|
||||
The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided. Defaults to
|
||||
None (meaning: let the HF Endpoints service specify the metric).
|
||||
scaling_threshold (`float`, *optional*):
|
||||
The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.
|
||||
Defaults to None (meaning: let the HF Endpoints service specify the threshold).
|
||||
scale_to_zero_timeout (`int`, *optional*):
|
||||
The duration in minutes before an inactive endpoint is scaled to zero, or no scaling to zero if
|
||||
set to None and `min_replica` is not 0. Defaults to None.
|
||||
@@ -7600,6 +7682,8 @@ class HfApi:
|
||||
},
|
||||
"type": type,
|
||||
}
|
||||
if scaling_metric:
|
||||
payload["compute"]["scaling"]["measure"] = {scaling_metric: scaling_threshold} # type: ignore
|
||||
if env:
|
||||
payload["model"]["env"] = env
|
||||
if secrets:
|
||||
@@ -7764,6 +7848,8 @@ class HfApi:
|
||||
min_replica: Optional[int] = None,
|
||||
max_replica: Optional[int] = None,
|
||||
scale_to_zero_timeout: Optional[int] = None,
|
||||
scaling_metric: Optional[InferenceEndpointScalingMetric] = None,
|
||||
scaling_threshold: Optional[float] = None,
|
||||
# Model update
|
||||
repository: Optional[str] = None,
|
||||
framework: Optional[str] = None,
|
||||
@@ -7804,7 +7890,12 @@ class HfApi:
|
||||
The maximum number of replicas (instances) to scale to for the Inference Endpoint.
|
||||
scale_to_zero_timeout (`int`, *optional*):
|
||||
The duration in minutes before an inactive endpoint is scaled to zero.
|
||||
|
||||
scaling_metric (`str` or [`InferenceEndpointScalingMetric `], *optional*):
|
||||
The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided.
|
||||
Defaults to None.
|
||||
scaling_threshold (`float`, *optional*):
|
||||
The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.
|
||||
Defaults to None.
|
||||
repository (`str`, *optional*):
|
||||
The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
|
||||
framework (`str`, *optional*):
|
||||
@@ -7858,6 +7949,8 @@ class HfApi:
|
||||
payload["compute"]["scaling"]["minReplica"] = min_replica
|
||||
if scale_to_zero_timeout is not None:
|
||||
payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout
|
||||
if scaling_metric:
|
||||
payload["compute"]["scaling"]["measure"] = {scaling_metric: scaling_threshold}
|
||||
if repository is not None:
|
||||
payload["model"]["repository"] = repository
|
||||
if framework is not None:
|
||||
@@ -8333,10 +8426,10 @@ class HfApi:
|
||||
collection_slug (`str`):
|
||||
Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
|
||||
item_id (`str`):
|
||||
ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`)
|
||||
or a paper id (e.g. `"2307.09288"`).
|
||||
Id of the item to add to the collection. Use the repo_id for repos/spaces/datasets,
|
||||
the paper id for papers, or the slug of another collection (e.g. `"moonshotai/kimi-k2"`).
|
||||
item_type (`str`):
|
||||
Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`.
|
||||
Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"`, `"paper"` or `"collection"`.
|
||||
note (`str`, *optional*):
|
||||
A note to attach to the item in the collection. The maximum size for a note is 500 characters.
|
||||
exists_ok (`bool`, *optional*):
|
||||
@@ -9792,6 +9885,75 @@ class HfApi:
|
||||
hf_raise_for_status(r)
|
||||
return PaperInfo(**r.json())
|
||||
|
||||
def list_daily_papers(
|
||||
self,
|
||||
*,
|
||||
date: Optional[str] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
week: Optional[str] = None,
|
||||
month: Optional[str] = None,
|
||||
submitter: Optional[str] = None,
|
||||
sort: Optional[Literal["publishedAt", "trending"]] = None,
|
||||
p: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> Iterable[PaperInfo]:
|
||||
"""
|
||||
List the daily papers published on a given date on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
date (`str`, *optional*):
|
||||
Date in ISO format (YYYY-MM-DD) for which to fetch daily papers.
|
||||
Defaults to most recent ones.
|
||||
token (Union[bool, str, None], *optional*):
|
||||
A valid user access token (string). Defaults to the locally saved
|
||||
token. To disable authentication, pass `False`.
|
||||
week (`str`, *optional*):
|
||||
Week in ISO format (YYYY-Www) for which to fetch daily papers. Example, `2025-W09`.
|
||||
month (`str`, *optional*):
|
||||
Month in ISO format (YYYY-MM) for which to fetch daily papers. Example, `2025-02`.
|
||||
submitter (`str`, *optional*):
|
||||
Username of the submitter to filter daily papers.
|
||||
sort (`Literal["publishedAt", "trending"]`, *optional*):
|
||||
Sort order for the daily papers. Can be either by `publishedAt` or by `trending`.
|
||||
Defaults to `"publishedAt"`
|
||||
p (`int`, *optional*):
|
||||
Page number for pagination. Defaults to 0.
|
||||
limit (`int`, *optional*):
|
||||
Limit of papers to fetch. Defaults to 50.
|
||||
|
||||
Returns:
|
||||
`Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from huggingface_hub import HfApi
|
||||
|
||||
>>> api = HfApi()
|
||||
>>> list(api.list_daily_papers(date="2025-10-29"))
|
||||
```
|
||||
"""
|
||||
path = f"{self.endpoint}/api/daily_papers"
|
||||
|
||||
params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"p": p,
|
||||
"limit": limit,
|
||||
"sort": sort,
|
||||
"date": date,
|
||||
"week": week,
|
||||
"month": month,
|
||||
"submitter": submitter,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
r = get_session().get(path, params=params, headers=self._build_hf_headers(token=token))
|
||||
hf_raise_for_status(r)
|
||||
for paper in r.json():
|
||||
yield PaperInfo(**paper)
|
||||
|
||||
def auth_check(
|
||||
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
|
||||
) -> None:
|
||||
@@ -9931,7 +10093,7 @@ class HfApi:
|
||||
timeout=timeout,
|
||||
)
|
||||
response = get_session().post(
|
||||
f"https://huggingface.co/api/jobs/{namespace}",
|
||||
f"{self.endpoint}/api/jobs/{namespace}",
|
||||
json=job_spec,
|
||||
headers=self._build_hf_headers(token=token),
|
||||
)
|
||||
@@ -9994,7 +10156,7 @@ class HfApi:
|
||||
try:
|
||||
with get_session().stream(
|
||||
"GET",
|
||||
f"https://huggingface.co/api/jobs/{namespace}/{job_id}/logs",
|
||||
f"{self.endpoint}/api/jobs/{namespace}/{job_id}/logs",
|
||||
headers=self._build_hf_headers(token=token),
|
||||
timeout=120,
|
||||
) as response:
|
||||
@@ -10022,7 +10184,7 @@ class HfApi:
|
||||
job_status = (
|
||||
get_session()
|
||||
.get(
|
||||
f"https://huggingface.co/api/jobs/{namespace}/{job_id}",
|
||||
f"{self.endpoint}/api/jobs/{namespace}/{job_id}",
|
||||
headers=self._build_hf_headers(token=token),
|
||||
)
|
||||
.json()
|
||||
@@ -10362,7 +10524,7 @@ class HfApi:
|
||||
if suspend is not None:
|
||||
input_json["suspend"] = suspend
|
||||
response = get_session().post(
|
||||
f"https://huggingface.co/api/scheduled-jobs/{namespace}",
|
||||
f"{self.endpoint}/api/scheduled-jobs/{namespace}",
|
||||
json=input_json,
|
||||
headers=self._build_hf_headers(token=token),
|
||||
)
|
||||
@@ -10813,6 +10975,7 @@ space_info = api.space_info
|
||||
|
||||
list_papers = api.list_papers
|
||||
paper_info = api.paper_info
|
||||
list_daily_papers = api.list_daily_papers
|
||||
|
||||
repo_exists = api.repo_exists
|
||||
revision_exists = api.revision_exists
|
||||
|
||||
@@ -123,32 +123,43 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
> layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible.
|
||||
|
||||
Args:
|
||||
token (`str` or `bool`, *optional*):
|
||||
endpoint (`str`, *optional*):
|
||||
Endpoint of the Hub. Defaults to <https://huggingface.co>.
|
||||
token (`bool` or `str`, *optional*):
|
||||
A valid user access token (string). Defaults to the locally saved
|
||||
token, which is the recommended method for authentication (see
|
||||
https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
|
||||
To disable authentication, pass `False`.
|
||||
endpoint (`str`, *optional*):
|
||||
Endpoint of the Hub. Defaults to <https://huggingface.co>.
|
||||
block_size (`int`, *optional*):
|
||||
Block size for reading and writing files.
|
||||
expand_info (`bool`, *optional*):
|
||||
Whether to expand the information of the files.
|
||||
**storage_options (`dict`, *optional*):
|
||||
Additional options for the filesystem. See [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.__init__).
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
>>> from huggingface_hub import HfFileSystem
|
||||
|
||||
>>> fs = HfFileSystem()
|
||||
>>> from huggingface_hub import hffs
|
||||
|
||||
>>> # List files
|
||||
>>> fs.glob("my-username/my-model/*.bin")
|
||||
>>> hffs.glob("my-username/my-model/*.bin")
|
||||
['my-username/my-model/pytorch_model.bin']
|
||||
>>> fs.ls("datasets/my-username/my-dataset", detail=False)
|
||||
>>> hffs.ls("datasets/my-username/my-dataset", detail=False)
|
||||
['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']
|
||||
|
||||
>>> # Read/write files
|
||||
>>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
|
||||
>>> with hffs.open("my-username/my-model/pytorch_model.bin") as f:
|
||||
... data = f.read()
|
||||
>>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
|
||||
>>> with hffs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
|
||||
... f.write(data)
|
||||
```
|
||||
|
||||
Specify a token for authentication:
|
||||
```python
|
||||
>>> from huggingface_hub import HfFileSystem
|
||||
>>> hffs = HfFileSystem(token=token)
|
||||
```
|
||||
"""
|
||||
|
||||
root_marker = ""
|
||||
@@ -160,6 +171,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
block_size: Optional[int] = None,
|
||||
expand_info: Optional[bool] = None,
|
||||
**storage_options,
|
||||
):
|
||||
super().__init__(*args, **storage_options)
|
||||
@@ -167,6 +179,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
self.token = token
|
||||
self._api = HfApi(endpoint=endpoint, token=token)
|
||||
self.block_size = block_size
|
||||
self.expand_info = expand_info
|
||||
# Maps (repo_type, repo_id, revision) to a 2-tuple with:
|
||||
# * the 1st element indicating whether the repository and the revision exist
|
||||
# * the 2nd element being the exception raised if the repository or revision doesn't exist
|
||||
@@ -330,14 +343,14 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
(resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision), None
|
||||
)
|
||||
|
||||
def _open(
|
||||
def _open( # type: ignore[override]
|
||||
self,
|
||||
path: str,
|
||||
mode: str = "rb",
|
||||
revision: Optional[str] = None,
|
||||
block_size: Optional[int] = None,
|
||||
revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> "HfFileSystemFile":
|
||||
) -> Union["HfFileSystemFile", "HfFileSystemStreamFile"]:
|
||||
block_size = block_size if block_size is not None else self.block_size
|
||||
if block_size is not None:
|
||||
kwargs["block_size"] = block_size
|
||||
@@ -453,9 +466,12 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
recursive: bool = False,
|
||||
refresh: bool = False,
|
||||
revision: Optional[str] = None,
|
||||
expand_info: bool = False,
|
||||
expand_info: Optional[bool] = None,
|
||||
maxdepth: Optional[int] = None,
|
||||
):
|
||||
expand_info = (
|
||||
expand_info if expand_info is not None else (self.expand_info if self.expand_info is not None else False)
|
||||
)
|
||||
resolved_path = self.resolve_path(path, revision=revision)
|
||||
path = resolved_path.unresolve()
|
||||
root_path = HfFileSystemResolvedPath(
|
||||
@@ -581,7 +597,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
|
||||
yield from super().walk(path, *args, **kwargs)
|
||||
|
||||
def glob(self, path: str, **kwargs) -> list[str]:
|
||||
def glob(self, path: str, maxdepth: Optional[int] = None, **kwargs) -> list[str]:
|
||||
"""
|
||||
Find files by glob-matching.
|
||||
|
||||
@@ -595,7 +611,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
`list[str]`: List of paths matching the pattern.
|
||||
"""
|
||||
path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
|
||||
return super().glob(path, **kwargs)
|
||||
return super().glob(path, maxdepth=maxdepth, **kwargs)
|
||||
|
||||
def find(
|
||||
self,
|
||||
@@ -756,7 +772,7 @@ class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
|
||||
resolved_path = self.resolve_path(path, revision=revision)
|
||||
path = resolved_path.unresolve()
|
||||
expand_info = kwargs.get(
|
||||
"expand_info", False
|
||||
"expand_info", self.expand_info if self.expand_info is not None else False
|
||||
) # don't expose it as a parameter in the public API to follow the spec
|
||||
if not resolved_path.path_in_repo:
|
||||
# Path is the root directory
|
||||
@@ -1189,7 +1205,6 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
|
||||
"GET",
|
||||
url,
|
||||
headers=headers,
|
||||
retry_on_status_codes=(500, 502, 503, 504),
|
||||
timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
|
||||
)
|
||||
)
|
||||
@@ -1255,3 +1270,6 @@ def make_instance(cls, args, kwargs, instance_state):
for attr, state_value in instance_state.items():
setattr(fs, attr, state_value)
return fs
+
+
+hffs = HfFileSystem()
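The ready-made `hffs` instance is meant as a drop-in default filesystem, mirroring the updated docstring examples earlier in this commit:

    from huggingface_hub import hffs

    # List model weight files without instantiating HfFileSystem() yourself
    print(hffs.glob("my-username/my-model/*.bin"))

    # Dataset paths are addressed with the "datasets/" prefix
    print(hffs.ls("datasets/my-username/my-dataset", detail=False))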
@@ -367,7 +367,7 @@ class ModelHubMixin:
|
||||
if is_simple_optional_type(expected_type):
|
||||
if value is None:
|
||||
return None
|
||||
expected_type = unwrap_simple_optional_type(expected_type)
|
||||
expected_type = unwrap_simple_optional_type(expected_type) # type: ignore[assignment]
|
||||
# Dataclass => handle it
|
||||
if is_dataclass(expected_type):
|
||||
return _load_dataclass(expected_type, value) # type: ignore[return-value]
|
||||
|
||||
@@ -135,7 +135,7 @@ class InferenceClient:
|
||||
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
|
||||
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
|
||||
provider (`str`, *optional*):
|
||||
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
|
||||
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
|
||||
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
|
||||
If model is a URL or `base_url` is passed, then `provider` is not used.
|
||||
token (`str`, *optional*):
|
||||
@@ -452,6 +452,7 @@ class InferenceClient:
|
||||
api_key=self.token,
|
||||
)
|
||||
response = self._inner_post(request_parameters)
|
||||
response = provider_helper.get_response(response, request_params=request_parameters)
|
||||
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
|
||||
|
||||
@overload
|
||||
@@ -1028,7 +1029,7 @@ class InferenceClient:
|
||||
normalize: Optional[bool] = None,
|
||||
prompt_name: Optional[str] = None,
|
||||
truncate: Optional[bool] = None,
|
||||
truncation_direction: Optional[Literal["Left", "Right"]] = None,
|
||||
truncation_direction: Optional[Literal["left", "right"]] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> "np.ndarray":
|
||||
"""
|
||||
@@ -1053,7 +1054,7 @@ class InferenceClient:
|
||||
truncate (`bool`, *optional*):
|
||||
Whether to truncate the embeddings or not.
|
||||
Only available on server powered by Text-Embedding-Inference.
|
||||
truncation_direction (`Literal["Left", "Right"]`, *optional*):
|
||||
truncation_direction (`Literal["left", "right"]`, *optional*):
|
||||
Which side of the input should be truncated when `truncate=True` is passed.
|
||||
|
||||
Returns:
|
||||
@@ -3195,10 +3196,7 @@ class InferenceClient:
|
||||
)
|
||||
response = self._inner_post(request_parameters)
|
||||
output = _bytes_to_dict(response)
|
||||
return [
|
||||
ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
|
||||
for label, score in zip(output["labels"], output["scores"])
|
||||
]
|
||||
return ZeroShotClassificationOutputElement.parse_obj_as_list(output)
|
||||
|
||||
def zero_shot_image_classification(
|
||||
self,
|
||||
|
||||
@@ -144,7 +144,7 @@ def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
|
||||
if hasattr(content, "read"): # duck-typing instead of isinstance(content, BinaryIO)
|
||||
logger.debug("Reading content from BinaryIO")
|
||||
data = content.read()
|
||||
mime_type = mimetypes.guess_type(content.name)[0] if hasattr(content, "name") else None
|
||||
mime_type = mimetypes.guess_type(str(content.name))[0] if hasattr(content, "name") else None
|
||||
if isinstance(data, str):
|
||||
raise TypeError("Expected binary stream (bytes), but got text stream")
|
||||
return MimeBytes(data, mime_type=mime_type)
|
||||
|
||||
@@ -126,7 +126,7 @@ class AsyncInferenceClient:
|
||||
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
|
||||
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
|
||||
provider (`str`, *optional*):
|
||||
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
|
||||
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
|
||||
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
|
||||
If model is a URL or `base_url` is passed, then `provider` is not used.
|
||||
token (`str`, *optional*):
|
||||
@@ -472,6 +472,7 @@ class AsyncInferenceClient:
|
||||
api_key=self.token,
|
||||
)
|
||||
response = await self._inner_post(request_parameters)
|
||||
response = provider_helper.get_response(response, request_params=request_parameters)
|
||||
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
|
||||
|
||||
@overload
|
||||
@@ -1055,7 +1056,7 @@ class AsyncInferenceClient:
|
||||
normalize: Optional[bool] = None,
|
||||
prompt_name: Optional[str] = None,
|
||||
truncate: Optional[bool] = None,
|
||||
truncation_direction: Optional[Literal["Left", "Right"]] = None,
|
||||
truncation_direction: Optional[Literal["left", "right"]] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> "np.ndarray":
|
||||
"""
|
||||
@@ -1080,7 +1081,7 @@ class AsyncInferenceClient:
|
||||
truncate (`bool`, *optional*):
|
||||
Whether to truncate the embeddings or not.
|
||||
Only available on server powered by Text-Embedding-Inference.
|
||||
truncation_direction (`Literal["Left", "Right"]`, *optional*):
|
||||
truncation_direction (`Literal["left", "right"]`, *optional*):
|
||||
Which side of the input should be truncated when `truncate=True` is passed.
|
||||
|
||||
Returns:
|
||||
@@ -3245,10 +3246,7 @@ class AsyncInferenceClient:
|
||||
)
|
||||
response = await self._inner_post(request_parameters)
|
||||
output = _bytes_to_dict(response)
|
||||
return [
|
||||
ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
|
||||
for label, score in zip(output["labels"], output["scores"])
|
||||
]
|
||||
return ZeroShotClassificationOutputElement.parse_obj_as_list(output)
|
||||
|
||||
async def zero_shot_image_classification(
|
||||
self,
|
||||
|
||||
@@ -19,6 +19,8 @@ import types
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, TypeVar, Union, get_args
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseInferenceType")
|
||||
|
||||
@@ -29,6 +31,7 @@ def _repr_with_extra(self):
|
||||
return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})"
|
||||
|
||||
|
||||
@dataclass_transform()
|
||||
def dataclass_with_extra(cls: type[T]) -> type[T]:
|
||||
"""Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones.
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra


-FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]
+FeatureExtractionInputTruncationDirection = Literal["left", "right"]


@dataclass_with_extra
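In practice the lowercase values are what `InferenceClient.feature_extraction` now accepts for `truncation_direction`; a short sketch, with the model id as a placeholder:

    from huggingface_hub import InferenceClient

    client = InferenceClient()
    # Truncate from the left when the input exceeds the model's maximum length.
    embedding = client.feature_extraction(
        "An example sentence to embed.",
        model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model id
        truncate=True,
        truncation_direction="left",
    )
    print(embedding.shape)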
@@ -5,8 +5,8 @@ import traceback
from typing import Optional

import typer
from rich import print

from ...utils import ANSI
from ._cli_hacks import _async_prompt, _patch_anyio_open_process
from .agent import Agent
from .utils import _load_agent_config
@@ -55,10 +55,10 @@ async def run_agent(
if first_sigint:
first_sigint = False
abort_event.set()
print("\n[red]Interrupted. Press Ctrl+C again to quit.[/red]", flush=True)
print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
return

print("\n[red]Exiting...[/red]", flush=True)
print(ANSI.red("\nExiting..."), flush=True)
exit_event.set()

try:
@@ -75,8 +75,12 @@ async def run_agent(

if len(inputs) > 0:
print(
"[bold blue]Some initial inputs are required by the agent. "
"Please provide a value or leave empty to load from env.[/bold blue]"
ANSI.bold(
ANSI.blue(
"Some initial inputs are required by the agent. "
"Please provide a value or leave empty to load from env."
)
)
)
for input_item in inputs:
input_id = input_item["id"]
@@ -98,15 +102,17 @@ async def run_agent(

if not input_usages:
print(
f"[yellow]Input '{input_id}' defined in config but not used by any server or as an API key."
" Skipping.[/yellow]"
ANSI.yellow(
f"Input '{input_id}' defined in config but not used by any server or as an API key."
" Skipping."
)
)
continue

# Prompt user for input
env_variable_key = input_id.replace("-", "_").upper()
print(
f"[blue] • {input_id}[/blue]: {description}. (default: load from {env_variable_key}).",
ANSI.blue(f" • {input_id}") + f": {description}. (default: load from {env_variable_key}).",
end=" ",
)
user_input = (await _async_prompt(exit_event=exit_event)).strip()
@@ -118,10 +124,12 @@ async def run_agent(
if not final_value:
final_value = os.getenv(env_variable_key, "")
if final_value:
print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
print(ANSI.green(f"Value successfully loaded from '{env_variable_key}'"))
else:
print(
f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
ANSI.yellow(
f"No value found for '{env_variable_key}' in environment variables. Continuing."
)
)
resolved_inputs[input_id] = final_value

@@ -150,9 +158,9 @@ async def run_agent(
prompt=prompt,
) as agent:
await agent.load_tools()
print(f"[bold blue]Agent loaded with {len(agent.available_tools)} tools:[/bold blue]")
print(ANSI.bold(ANSI.blue("Agent loaded with {} tools:".format(len(agent.available_tools)))))
for t in agent.available_tools:
print(f"[blue] • {t.function.name}[/blue]")
print(ANSI.blue(f" • {t.function.name}"))

while True:
abort_event.clear()
@@ -165,13 +173,13 @@ async def run_agent(
user_input = await _async_prompt(exit_event=exit_event)
first_sigint = True
except EOFError:
print("\n[red]EOF received, exiting.[/red]", flush=True)
print(ANSI.red("\nEOF received, exiting."), flush=True)
break
except KeyboardInterrupt:
if not first_sigint and abort_event.is_set():
continue
else:
print("\n[red]Keyboard interrupt during input processing.[/red]", flush=True)
print(ANSI.red("\nKeyboard interrupt during input processing."), flush=True)
break

try:
@@ -195,7 +203,7 @@ async def run_agent(
print(f"{call.function.arguments}", end="")
else:
print(
f"\n\n[green]Tool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}[/green]\n",
ANSI.green(f"\n\nTool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}\n"),
flush=True,
)

@@ -203,12 +211,12 @@ async def run_agent(

except Exception as e:
tb_str = traceback.format_exc()
print(f"\n[bold red]Error during agent run: {e}\n{tb_str}[/bold red]", flush=True)
print(ANSI.red(f"\nError during agent run: {e}\n{tb_str}"), flush=True)
first_sigint = True # Allow graceful interrupt for the next command

except Exception as e:
tb_str = traceback.format_exc()
print(f"\n[bold red]An unexpected error occurred: {e}\n{tb_str}[/bold red]", flush=True)
print(ANSI.red(f"\nAn unexpected error occurred: {e}\n{tb_str}"), flush=True)
raise e

finally:
@@ -236,10 +244,10 @@ def run(
try:
asyncio.run(run_agent(path))
except KeyboardInterrupt:
print("\n[red]Application terminated by KeyboardInterrupt.[/red]", flush=True)
print(ANSI.red("\nApplication terminated by KeyboardInterrupt."), flush=True)
raise typer.Exit(code=130)
except Exception as e:
print(f"\n[bold red]An unexpected error occurred: {e}[/bold red]", flush=True)
print(ANSI.red(f"\nAn unexpected error occurred: {e}"), flush=True)
raise e



@@ -57,10 +57,10 @@ def format_result(result: "mcp_types.CallToolResult") -> str:
elif item.type == "resource":
resource = item.resource

if hasattr(resource, "text"):
if hasattr(resource, "text") and isinstance(resource.text, str):
formatted_parts.append(resource.text)

elif hasattr(resource, "blob"):
elif hasattr(resource, "blob") and isinstance(resource.blob, str):
formatted_parts.append(
f"[Binary Content ({resource.uri}): {resource.mimeType}, {_get_base64_size(resource.blob)} bytes]\n"
f"The task is complete and the content accessible to the User"

@@ -38,8 +38,15 @@ from .nebius import (
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
from .openai import OpenAIConversationalTask
from .ovhcloud import OVHcloudConversationalTask
from .publicai import PublicAIConversationalTask
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .replicate import (
ReplicateAutomaticSpeechRecognitionTask,
ReplicateImageToImageTask,
ReplicateTask,
ReplicateTextToImageTask,
ReplicateTextToSpeechTask,
)
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
@@ -70,6 +77,7 @@ PROVIDER_T = Literal[
"novita",
"nscale",
"openai",
"ovhcloud",
"publicai",
"replicate",
"sambanova",
@@ -166,10 +174,14 @@ PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
"openai": {
"conversational": OpenAIConversationalTask(),
},
"ovhcloud": {
"conversational": OVHcloudConversationalTask(),
},
"publicai": {
"conversational": PublicAIConversationalTask(),
},
"replicate": {
"automatic-speech-recognition": ReplicateAutomaticSpeechRecognitionTask(),
"image-to-image": ReplicateImageToImageTask(),
"text-to-image": ReplicateTextToImageTask(),
"text-to-speech": ReplicateTextToSpeechTask(),

@@ -32,6 +32,7 @@ HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]
"hyperbolic": {},
"nebius": {},
"nscale": {},
"ovhcloud": {},
"replicate": {},
"sambanova": {},
"scaleway": {},

@@ -112,7 +112,7 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
text = _as_dict(response)["text"]
if not isinstance(text, str):
raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
return text
return {"text": text}


class FalAITextToImageTask(FalAITask):

@@ -72,6 +72,67 @@ class ReplicateTextToSpeechTask(ReplicateTask):
return payload


class ReplicateAutomaticSpeechRecognitionTask(ReplicateTask):
def __init__(self) -> None:
super().__init__("automatic-speech-recognition")

def _prepare_payload_as_dict(
self,
inputs: Any,
parameters: dict,
provider_mapping_info: InferenceProviderMapping,
) -> Optional[dict]:
mapped_model = provider_mapping_info.provider_id
audio_url = _as_url(inputs, default_mime_type="audio/wav")

payload: dict[str, Any] = {
"input": {
**{"audio": audio_url},
**filter_none(parameters),
}
}

if ":" in mapped_model:
payload["version"] = mapped_model.split(":", 1)[1]

return payload

def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
response_dict = _as_dict(response)
output = response_dict.get("output")

if isinstance(output, str):
return {"text": output}

if isinstance(output, list) and output:
first_item = output[0]
if isinstance(first_item, str):
return {"text": first_item}
if isinstance(first_item, dict):
output = first_item

text: Optional[str] = None
if isinstance(output, dict):
transcription = output.get("transcription")
if isinstance(transcription, str):
text = transcription

translation = output.get("translation")
if isinstance(translation, str):
text = translation

txt_file = output.get("txt_file")
if isinstance(txt_file, str):
text_response = get_session().get(txt_file)
text_response.raise_for_status()
text = text_response.text

if text is not None:
return {"text": text}

raise ValueError("Received malformed response from Replicate automatic-speech-recognition API")


class ReplicateImageToImageTask(ReplicateTask):
def __init__(self):
super().__init__("image-to-image")

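Not part of the diff: a minimal usage sketch of the new Replicate automatic-speech-recognition route through the high-level client. The token, audio file, and model id are placeholders; the model must be mapped for the Replicate provider on the Hub.

```python
from huggingface_hub import InferenceClient

# Hypothetical example: route speech-to-text through the Replicate provider.
client = InferenceClient(provider="replicate", api_key="hf_xxx")
result = client.automatic_speech_recognition("sample.flac", model="openai/whisper-large-v3")
print(result.text)  # normalized to a single "text" field by get_response() above
```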
@@ -102,7 +102,7 @@ def split_state_dict_into_shards_factory(
continue

# If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
storage_id = get_storage_id(tensor)
storage_id = get_storage_id(tensor) # type: ignore[invalid-argument-type]
if storage_id is not None:
if storage_id in storage_id_to_tensors:
# We skip this tensor for now and will reassign to correct shard later
@@ -114,7 +114,7 @@ def split_state_dict_into_shards_factory(
storage_id_to_tensors[storage_id] = [key]

# Compute tensor size
tensor_size = get_storage_size(tensor)
tensor_size = get_storage_size(tensor) # type: ignore[invalid-argument-type]

# If this tensor is bigger than the maximal size, we put it in its own shard
if tensor_size > max_shard_size:

@@ -54,6 +54,7 @@ from ._headers import build_hf_headers, get_token_to_send
from ._http import (
ASYNC_CLIENT_FACTORY_T,
CLIENT_FACTORY_T,
RateLimitInfo,
close_session,
fix_hf_endpoint_in_url,
get_async_session,
@@ -61,6 +62,7 @@ from ._http import (
hf_raise_for_status,
http_backoff,
http_stream_backoff,
parse_ratelimit_headers,
set_async_client_factory,
set_client_factory,
)

@@ -24,7 +24,7 @@ def cached_assets_path(
subfolder: str = "default",
*,
assets_dir: Union[str, Path, None] = None,
):
) -> Path:
"""Return a folder path to cache arbitrary files.

`huggingface_hub` provides a canonical folder path to store assets. This is the

@@ -23,9 +23,9 @@ import threading
import time
import uuid
from contextlib import contextmanager
from http import HTTPStatus
from dataclasses import dataclass
from shlex import quote
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Generator, Mapping, Optional, Union

import httpx

@@ -48,6 +48,94 @@ from ._typing import HTTP_METHOD_T

logger = logging.get_logger(__name__)


@dataclass(frozen=True)
class RateLimitInfo:
"""
Parsed rate limit information from HTTP response headers.

Attributes:
resource_type (`str`): The type of resource being rate limited.
remaining (`int`): The number of requests remaining in the current window.
reset_in_seconds (`int`): The number of seconds until the rate limit resets.
limit (`int`, *optional*): The maximum number of requests allowed in the current window.
window_seconds (`int`, *optional*): The number of seconds in the current window.

"""

resource_type: str
remaining: int
reset_in_seconds: int
limit: Optional[int] = None
window_seconds: Optional[int] = None


# Regex patterns for parsing rate limit headers
# e.g.: "api";r=0;t=55 --> resource_type="api", r=0, t=55
_RATELIMIT_REGEX = re.compile(r"\"(?P<resource_type>\w+)\"\s*;\s*r\s*=\s*(?P<r>\d+)\s*;\s*t\s*=\s*(?P<t>\d+)")
# e.g.: "fixed window";"api";q=500;w=300 --> q=500, w=300
_RATELIMIT_POLICY_REGEX = re.compile(r"q\s*=\s*(?P<q>\d+).*?w\s*=\s*(?P<w>\d+)")


def parse_ratelimit_headers(headers: Mapping[str, str]) -> Optional[RateLimitInfo]:
"""Parse rate limit information from HTTP response headers.

Follows IETF draft: https://www.ietf.org/archive/id/draft-ietf-httpapi-ratelimit-headers-09.html
Only a subset is implemented.

Example:
```python
>>> from huggingface_hub.utils import parse_ratelimit_headers
>>> headers = {
... "ratelimit": '"api";r=0;t=55',
... "ratelimit-policy": '"fixed window";"api";q=500;w=300',
... }
>>> info = parse_ratelimit_headers(headers)
>>> info.remaining
0
>>> info.reset_in_seconds
55
```
"""

ratelimit: Optional[str] = None
policy: Optional[str] = None
for key in headers:
lower_key = key.lower()
if lower_key == "ratelimit":
ratelimit = headers[key]
elif lower_key == "ratelimit-policy":
policy = headers[key]

if not ratelimit:
return None

match = _RATELIMIT_REGEX.search(ratelimit)
if not match:
return None

resource_type = match.group("resource_type")
remaining = int(match.group("r"))
reset_in_seconds = int(match.group("t"))

limit: Optional[int] = None
window_seconds: Optional[int] = None

if policy:
policy_match = _RATELIMIT_POLICY_REGEX.search(policy)
if policy_match:
limit = int(policy_match.group("q"))
window_seconds = int(policy_match.group("w"))

return RateLimitInfo(
resource_type=resource_type,
remaining=remaining,
reset_in_seconds=reset_in_seconds,
limit=limit,
window_seconds=window_seconds,
)


# Both headers are used by the Hub to debug failed requests.
# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
@@ -79,7 +167,7 @@ def hf_request_event_hook(request: httpx.Request) -> None:
- Add a request ID to the request headers
- Log the request if debug mode is enabled
"""
if constants.HF_HUB_OFFLINE:
if constants.is_offline_mode():
raise OfflineModeIsEnabled(
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
)
@@ -249,6 +337,10 @@ if hasattr(os, "register_at_fork"):
os.register_at_fork(after_in_child=close_session)


_DEFAULT_RETRY_ON_EXCEPTIONS: tuple[type[Exception], ...] = (httpx.TimeoutException, httpx.NetworkError)
_DEFAULT_RETRY_ON_STATUS_CODES: tuple[int, ...] = (429, 500, 502, 503, 504)


def _http_backoff_base(
method: HTTP_METHOD_T,
url: str,
@@ -256,11 +348,8 @@ def _http_backoff_base(
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
),
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
stream: bool = False,
**kwargs,
) -> Generator[httpx.Response, None, None]:
@@ -273,6 +362,7 @@ def _http_backoff_base(

nb_tries = 0
sleep_time = base_wait_time
ratelimit_reset: Optional[int] = None # seconds to wait for rate limit reset if 429 response

# If `data` is used and is a file object (or any IO), it will be consumed on the
# first HTTP request. We need to save the initial position so that the full content
@@ -284,6 +374,7 @@ def _http_backoff_base(
client = get_session()
while True:
nb_tries += 1
ratelimit_reset = None
try:
# If `data` is used and is a file object (or any IO), set back cursor to
# initial position.
@@ -293,6 +384,8 @@ def _http_backoff_base(
# Perform request and handle response
def _should_retry(response: httpx.Response) -> bool:
"""Handle response and return True if should retry, False if should return/yield."""
nonlocal ratelimit_reset

if response.status_code not in retry_on_status_codes:
return False # Success, don't retry

@@ -304,6 +397,12 @@ def _http_backoff_base(
# user ask for retry on a status code that doesn't raise_for_status.
return False # Don't retry, return/yield response

# get rate limit reset time from headers if 429 response
if response.status_code == 429:
ratelimit_info = parse_ratelimit_headers(response.headers)
if ratelimit_info is not None:
ratelimit_reset = ratelimit_info.reset_in_seconds

return True # Should retry

if stream:
@@ -326,9 +425,14 @@ def _http_backoff_base(
if nb_tries > max_retries:
raise err

# Sleep for X seconds
logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
time.sleep(sleep_time)
if ratelimit_reset is not None:
actual_sleep = float(ratelimit_reset) + 1 # +1s to avoid rounding issues
logger.warning(f"Rate limited. Waiting {actual_sleep}s before retry [Retry {nb_tries}/{max_retries}].")
else:
actual_sleep = sleep_time
logger.warning(f"Retrying in {actual_sleep}s [Retry {nb_tries}/{max_retries}].")

time.sleep(actual_sleep)

# Update sleep time for next retry
sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
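An illustrative aside, not part of the diff: the wait schedule produced by the capped doubling above, assuming the default base_wait_time=1 and max_wait_time=8.

```python
# Sketch of the exponential backoff schedule only; a 429 response that advertises
# ratelimit headers sleeps reset_in_seconds + 1 instead of the exponential value.
sleep_time, max_wait_time = 1, 8
schedule = []
for _ in range(6):
    schedule.append(sleep_time)
    sleep_time = min(max_wait_time, sleep_time * 2)
print(schedule)  # [1, 2, 4, 8, 8, 8]
```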
@@ -341,11 +445,8 @@ def http_backoff(
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
),
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
**kwargs,
) -> httpx.Response:
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
@@ -374,9 +475,9 @@ def http_backoff(
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
Define on which status codes the request must be retried. By default, only
HTTP 503 Service Unavailable is retried.
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `(429, 500, 502, 503, 504)`):
Define on which status codes the request must be retried. By default, retries
on rate limit (429) and server errors (5xx).
**kwargs (`dict`, *optional*):
kwargs to pass to `httpx.request`.

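Not part of the diff: a minimal usage sketch of `http_backoff` with the new defaults; the URL is only an example.

```python
from huggingface_hub.utils import http_backoff

# With the new defaults, 429 and 5xx responses, as well as timeouts and
# network errors, are retried automatically with exponential backoff.
response = http_backoff("GET", "https://huggingface.co/api/models?limit=1")
response.raise_for_status()
print(response.status_code)
```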
@@ -425,11 +526,8 @@ def http_stream_backoff(
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
),
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = _DEFAULT_RETRY_ON_EXCEPTIONS,
retry_on_status_codes: Union[int, tuple[int, ...]] = _DEFAULT_RETRY_ON_STATUS_CODES,
**kwargs,
) -> Generator[httpx.Response, None, None]:
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
@@ -457,10 +555,10 @@ def http_stream_backoff(
Maximum duration (in seconds) to wait before retrying.
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
By default, retry on `httpx.Timeout` and `httpx.NetworkError`.
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
Define on which status codes the request must be retried. By default, only
HTTP 503 Service Unavailable is retried.
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `(429, 500, 502, 503, 504)`):
Define on which status codes the request must be retried. By default, retries
on rate limit (429) and server errors (5xx).
**kwargs (`dict`, *optional*):
kwargs to pass to `httpx.request`.

@@ -549,6 +647,12 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
> - [`~utils.HfHubHTTPError`]
> If request failed for a reason not listed above.
"""
try:
_warn_on_warning_headers(response)
except Exception:
# Never raise on warning parsing
logger.debug("Failed to parse warning headers", exc_info=True)

try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -619,6 +723,25 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
)
raise _format(HfHubHTTPError, message, response) from e

elif response.status_code == 429:
ratelimit_info = parse_ratelimit_headers(response.headers)
if ratelimit_info is not None:
message = (
f"\n\n429 Too Many Requests: you have reached your '{ratelimit_info.resource_type}' rate limit."
)
message += f"\nRetry after {ratelimit_info.reset_in_seconds} seconds"
if ratelimit_info.limit is not None and ratelimit_info.window_seconds is not None:
message += (
f" ({ratelimit_info.remaining}/{ratelimit_info.limit} requests remaining"
f" in current {ratelimit_info.window_seconds}s window)."
)
else:
message += "."
message += f"\nUrl: {response.url}."
else:
message = f"\n\n429 Too Many Requests for url: {response.url}."
raise _format(HfHubHTTPError, message, response) from e

elif response.status_code == 416:
range_header = response.request.headers.get("Range")
message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
@@ -629,6 +752,33 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] =
raise _format(HfHubHTTPError, str(e), response) from e


_WARNED_TOPICS = set()


def _warn_on_warning_headers(response: httpx.Response) -> None:
"""
Emit warnings if warning headers are present in the HTTP response.

Expected header format: 'X-HF-Warning: topic; message'

Only the first warning for each topic will be shown. Topic is optional and can be empty. Note that several warning
headers can be present in a single response.

Args:
response (`httpx.Response`):
The HTTP response to check for warning headers.
"""
server_warnings = response.headers.get_list("X-HF-Warning")
for server_warning in server_warnings:
topic, message = server_warning.split(";", 1) if ";" in server_warning else ("", server_warning)
topic = topic.strip()
if topic not in _WARNED_TOPICS:
message = message.strip()
if message:
_WARNED_TOPICS.add(topic)
logger.warning(message)


def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError:
server_errors = []


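Not part of the diff: a tiny illustration of how an `X-HF-Warning` header value is split into topic and message by the helper above; the header value is hypothetical.

```python
# Hypothetical header value, same split logic as _warn_on_warning_headers.
server_warning = "deprecation; This model will be removed soon."
topic, message = server_warning.split(";", 1) if ";" in server_warning else ("", server_warning)
print(topic.strip(), "->", message.strip())  # deprecation -> This model will be removed soon.
```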
@@ -42,7 +42,7 @@ def paginate(path: str, params: dict, headers: dict) -> Iterable:
next_page = _get_next_page(r)
while next_page is not None:
logger.debug(f"Pagination detected. Requesting next page: {next_page}")
r = http_backoff("GET", next_page, max_retries=20, retry_on_status_codes=429, headers=headers)
r = http_backoff("GET", next_page, headers=headers)
hf_raise_for_status(r)
yield from r.json()
next_page = _get_next_page(r)

@@ -64,7 +64,7 @@ def send_telemetry(
... )
```
"""
if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY:
if constants.is_offline_mode() or constants.HF_HUB_DISABLE_TELEMETRY:
return

_start_telemetry_thread() # starts thread only if doesn't exist yet

@@ -22,12 +22,18 @@ class ANSI:
Helper for en.wikipedia.org/wiki/ANSI_escape_code
"""

_blue = "\u001b[34m"
_bold = "\u001b[1m"
_gray = "\u001b[90m"
_green = "\u001b[32m"
_red = "\u001b[31m"
_reset = "\u001b[0m"
_yellow = "\u001b[33m"

@classmethod
def blue(cls, s: str) -> str:
return cls._format(s, cls._blue)

@classmethod
def bold(cls, s: str) -> str:
return cls._format(s, cls._bold)
@@ -36,6 +42,10 @@ class ANSI:
def gray(cls, s: str) -> str:
return cls._format(s, cls._gray)

@classmethod
def green(cls, s: str) -> str:
return cls._format(s, cls._green)

@classmethod
def red(cls, s: str) -> str:
return cls._format(s, cls._bold + cls._red)

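Not part of the diff: a minimal sketch of how the agent CLI messages render once routed through the ANSI helper instead of rich markup; the strings are illustrative.

```python
from huggingface_hub.utils import ANSI

# Same pattern the agent CLI now uses: colorize with ANSI helpers instead of rich markup.
print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
print(ANSI.bold(ANSI.blue("Agent loaded with 3 tools:")))
print(ANSI.blue(" • get_weather"))
```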