chore: add virtual environment to the repository

- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12655 files
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# ruff: noqa: F401
from huggingface_hub.errors import (
BadRequestError,
CacheNotFound,
CorruptedCacheException,
DisabledRepoError,
EntryNotFoundError,
FileMetadataError,
GatedRepoError,
HfHubHTTPError,
HFValidationError,
LocalEntryNotFoundError,
LocalTokenNotFoundError,
NotASafetensorsRepoError,
OfflineModeIsEnabled,
RepositoryNotFoundError,
RevisionNotFoundError,
SafetensorsParsingError,
)
from . import tqdm as _tqdm # _tqdm is the module
from ._auth import get_stored_tokens, get_token
from ._cache_assets import cached_assets_path
from ._cache_manager import (
CachedFileInfo,
CachedRepoInfo,
CachedRevisionInfo,
DeleteCacheStrategy,
HFCacheInfo,
_format_size,
scan_cache_dir,
)
from ._chunk_utils import chunk_iterable
from ._datetime import parse_datetime
from ._experimental import experimental
from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump
from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential
from ._headers import build_hf_headers, get_token_to_send
from ._http import (
ASYNC_CLIENT_FACTORY_T,
CLIENT_FACTORY_T,
close_session,
fix_hf_endpoint_in_url,
get_async_session,
get_session,
hf_raise_for_status,
http_backoff,
http_stream_backoff,
set_async_client_factory,
set_client_factory,
)
from ._pagination import paginate
from ._paths import DEFAULT_IGNORE_PATTERNS, FORBIDDEN_FOLDERS, filter_repo_objects
from ._runtime import (
dump_environment_info,
get_aiohttp_version,
get_fastai_version,
get_fastapi_version,
get_fastcore_version,
get_gradio_version,
get_graphviz_version,
get_hf_hub_version,
get_jinja_version,
get_numpy_version,
get_pillow_version,
get_pydantic_version,
get_pydot_version,
get_python_version,
get_tensorboard_version,
get_tf_version,
get_torch_version,
installation_method,
is_aiohttp_available,
is_colab_enterprise,
is_fastai_available,
is_fastapi_available,
is_fastcore_available,
is_google_colab,
is_gradio_available,
is_graphviz_available,
is_jinja_available,
is_notebook,
is_numpy_available,
is_package_available,
is_pillow_available,
is_pydantic_available,
is_pydot_available,
is_safetensors_available,
is_tensorboard_available,
is_tf_available,
is_torch_available,
)
from ._safetensors import SafetensorsFileMetadata, SafetensorsRepoMetadata, TensorInfo
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
from ._telemetry import send_telemetry
from ._terminal import ANSI, tabulate
from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
from ._validators import validate_hf_hub_args, validate_repo_id
from ._xet import (
XetConnectionInfo,
XetFileData,
XetTokenType,
fetch_xet_connection_info_from_repo_info,
parse_xet_file_data_from_response,
refresh_xet_connection_info,
)
from .tqdm import (
are_progress_bars_disabled,
disable_progress_bars,
enable_progress_bars,
is_tqdm_disabled,
tqdm,
tqdm_stream_file,
)

@@ -0,0 +1,214 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a helper to get the token from machine (env variable, secret or config file)."""
import configparser
import logging
import os
import warnings
from pathlib import Path
from threading import Lock
from typing import Optional
from .. import constants
from ._runtime import is_colab_enterprise, is_google_colab
_IS_GOOGLE_COLAB_CHECKED = False
_GOOGLE_COLAB_SECRET_LOCK = Lock()
_GOOGLE_COLAB_SECRET: Optional[str] = None
logger = logging.getLogger(__name__)
def get_token() -> Optional[str]:
"""
Get token if user is logged in.
Note: in most cases, you should use [`huggingface_hub.utils.build_hf_headers`] instead. This method is only useful
if you want to retrieve the token for purposes other than sending an HTTP request.
The token is looked up first in the Google Colab secrets vault (when running in Colab), then in the `HF_TOKEN`
environment variable (or the legacy `HUGGING_FACE_HUB_TOKEN`), and finally in the token file located in the
Hugging Face home folder. Returns None if the user is not logged in. To log in, use [`login`] or `hf auth login`.
Returns:
`str` or `None`: The token, `None` if it doesn't exist.
"""
return _get_token_from_google_colab() or _get_token_from_environment() or _get_token_from_file()
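# Minimal usage sketch (illustrative, not part of the module). Lookup order follows
# the implementation above: Colab secret, then environment variables, then token file.
#
#   >>> from huggingface_hub.utils import get_token
#   >>> get_token()  # e.g. "hf_xxx" when logged in, or None otherwise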
def _get_token_from_google_colab() -> Optional[str]:
"""Get token from Google Colab secrets vault using `google.colab.userdata.get(...)`.
Token is read from the vault only once per session and then stored in a global variable to avoid re-requesting
access to the vault.
"""
# If it's not Google Colab, or if it's Colab Enterprise, fall back to environment variable or token file authentication
if not is_google_colab() or is_colab_enterprise():
return None
# `google.colab.userdata` is not thread-safe
# This can lead to a deadlock if multiple threads try to access it at the same time
# (typically when using `snapshot_download`)
# => use a lock
# See https://github.com/huggingface/huggingface_hub/issues/1952 for more details.
with _GOOGLE_COLAB_SECRET_LOCK:
global _GOOGLE_COLAB_SECRET
global _IS_GOOGLE_COLAB_CHECKED
if _IS_GOOGLE_COLAB_CHECKED: # request access only once
return _GOOGLE_COLAB_SECRET
try:
from google.colab import userdata # type: ignore
from google.colab.errors import Error as ColabError # type: ignore
except ImportError:
return None
try:
token = userdata.get("HF_TOKEN")
_GOOGLE_COLAB_SECRET = _clean_token(token)
except userdata.NotebookAccessError:
# Means the user has a secret called `HF_TOKEN` and got a popup "please grant access to HF_TOKEN" and refused it
# => warn user but ignore error => do not re-request access to user
warnings.warn(
"\nAccess to the secret `HF_TOKEN` has not been granted on this notebook."
"\nYou will not be requested again."
"\nPlease restart the session if you want to be prompted again."
)
_GOOGLE_COLAB_SECRET = None
except userdata.SecretNotFoundError:
# Means the user did not define a `HF_TOKEN` secret => warn
warnings.warn(
"\nThe secret `HF_TOKEN` does not exist in your Colab secrets."
"\nTo authenticate with the Hugging Face Hub, create a token in your settings tab "
"(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session."
"\nYou will be able to reuse this secret in all of your notebooks."
"\nPlease note that authentication is recommended but still optional to access public models or datasets."
)
_GOOGLE_COLAB_SECRET = None
except ColabError as e:
# Something happened but we don't know what => recommend opening a GitHub issue
warnings.warn(
f"\nError while fetching `HF_TOKEN` secret value from your vault: '{str(e)}'."
"\nYou are not authenticated with the Hugging Face Hub in this notebook."
"\nIf the error persists, please let us know by opening an issue on GitHub "
"(https://github.com/huggingface/huggingface_hub/issues/new)."
)
_GOOGLE_COLAB_SECRET = None
_IS_GOOGLE_COLAB_CHECKED = True
return _GOOGLE_COLAB_SECRET
def _get_token_from_environment() -> Optional[str]:
# `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility)
return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"))
def _get_token_from_file() -> Optional[str]:
try:
return _clean_token(Path(constants.HF_TOKEN_PATH).read_text())
except FileNotFoundError:
return None
def get_stored_tokens() -> dict[str, str]:
"""
Returns the access tokens parsed from the stored tokens INI file.
The file is located at `HF_STORED_TOKENS_PATH`, defaulting to `~/.cache/huggingface/stored_tokens`.
If the file does not exist, an empty dictionary is returned.
Returns: `dict[str, str]`
Key is the token name and value is the token.
"""
tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
if not tokens_path.exists():
stored_tokens = {}
config = configparser.ConfigParser()
try:
config.read(tokens_path)
stored_tokens = {token_name: config.get(token_name, "hf_token") for token_name in config.sections()}
except configparser.Error as e:
logger.error(f"Error parsing stored tokens file: {e}")
stored_tokens = {}
return stored_tokens
def _save_stored_tokens(stored_tokens: dict[str, str]) -> None:
"""
Saves the given configuration to the stored tokens file.
Args:
stored_tokens (`dict[str, str]`):
The stored tokens to save. Key is the token name and value is the token.
"""
stored_tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
# Write the stored tokens into an INI file
config = configparser.ConfigParser()
for token_name in sorted(stored_tokens.keys()):
config.add_section(token_name)
config.set(token_name, "hf_token", stored_tokens[token_name])
stored_tokens_path.parent.mkdir(parents=True, exist_ok=True)
with stored_tokens_path.open("w") as config_file:
config.write(config_file)
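# Illustrative sketch of the INI layout written by `_save_stored_tokens` and read back
# by `get_stored_tokens` (token names and values below are made up):
#
#   [finetuning]
#   hf_token = hf_yyyyyyyy
#
#   [read-only]
#   hf_token = hf_xxxxxxxx
#
# get_stored_tokens() would then return {"finetuning": "hf_yyyyyyyy", "read-only": "hf_xxxxxxxx"}.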
def _get_token_by_name(token_name: str) -> Optional[str]:
"""
Get the token by name.
Args:
token_name (`str`):
The name of the token to get.
Returns:
`str` or `None`: The token, `None` if it doesn't exist.
"""
stored_tokens = get_stored_tokens()
if token_name not in stored_tokens:
return None
return _clean_token(stored_tokens[token_name])
def _save_token(token: str, token_name: str) -> None:
"""
Save the given token.
If the stored tokens file does not exist, it will be created.
Args:
token (`str`):
The token to save.
token_name (`str`):
The name of the token.
"""
tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
stored_tokens = get_stored_tokens()
stored_tokens[token_name] = token
_save_stored_tokens(stored_tokens)
logger.info(f"The token `{token_name}` has been saved to {tokens_path}")
def _clean_token(token: Optional[str]) -> Optional[str]:
"""Clean token by removing trailing and leading spaces and newlines.
If token is an empty string, return None.
"""
if token is None:
return None
return token.replace("\r", "").replace("\n", "").strip() or None

@@ -0,0 +1,135 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Union
from ..constants import HF_ASSETS_CACHE
def cached_assets_path(
library_name: str,
namespace: str = "default",
subfolder: str = "default",
*,
assets_dir: Union[str, Path, None] = None,
):
"""Return a folder path to cache arbitrary files.
`huggingface_hub` provides a canonical folder path to store assets. This is the
recommended way to integrate caching in a downstream library, as it will benefit from
the built-in tools to scan and delete the cache properly.
The distinction is made between files cached from the Hub and assets. Files from the
Hub are cached in a git-aware manner and entirely managed by `huggingface_hub`. See
[related documentation](https://huggingface.co/docs/huggingface_hub/how-to-cache).
All other files that a downstream library caches are considered to be "assets"
(files downloaded from external sources, extracted from a .tar archive, preprocessed
for training,...).
Once the folder path is generated, it is guaranteed to exist and to be a directory.
The path is based on 3 levels of depth: the library name, a namespace and a
subfolder. Those 3 levels grant flexibility while allowing `huggingface_hub` to
expect folders when scanning/deleting parts of the assets cache. Within a library,
it is expected that all namespaces share the same subset of subfolder names, but this
is not a mandatory rule. The downstream library then has full control over the file
structure to adopt within its cache. Namespace and subfolder are optional (they
default to `"default/"`) but library name is mandatory as we want every
downstream library to manage its own cache.
Expected tree:
```text
assets/
└── datasets/
│ ├── SQuAD/
│ │ ├── downloaded/
│ │ ├── extracted/
│ │ └── processed/
│ ├── Helsinki-NLP--tatoeba_mt/
│ ├── downloaded/
│ ├── extracted/
│ └── processed/
└── transformers/
├── default/
│ ├── something/
├── bert-base-cased/
│ ├── default/
│ └── training/
hub/
└── models--julien-c--EsperBERTo-small/
├── blobs/
│ ├── (...)
│ ├── (...)
├── refs/
│ └── (...)
└── [ 128] snapshots/
├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/
│ ├── (...)
└── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/
└── (...)
```
Args:
library_name (`str`):
Name of the library that will manage the cache folder. Example: `"datasets"`.
namespace (`str`, *optional*, defaults to "default"):
Namespace to which the data belongs. Example: `"SQuAD"`.
subfolder (`str`, *optional*, defaults to "default"):
Subfolder in which the data will be stored. Example: `extracted`.
assets_dir (`str`, `Path`, *optional*):
Path to the folder where assets are cached. This must not be the same folder
where Hub files are cached. Defaults to `HF_HOME / "assets"` if not provided.
Can also be set with `HF_ASSETS_CACHE` environment variable.
Returns:
Path to the cache folder (`Path`).
Example:
```py
>>> from huggingface_hub import cached_assets_path
>>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download")
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/download')
>>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="extracted")
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/extracted')
>>> cached_assets_path(library_name="datasets", namespace="Helsinki-NLP/tatoeba_mt")
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/Helsinki-NLP--tatoeba_mt/default')
>>> cached_assets_path(library_name="datasets", assets_dir="/tmp/tmp123456")
PosixPath('/tmp/tmp123456/datasets/default/default')
```
"""
# Resolve assets_dir
if assets_dir is None:
assets_dir = HF_ASSETS_CACHE
assets_dir = Path(assets_dir).expanduser().resolve()
# Avoid names that could create path issues
for part in (" ", "/", "\\"):
library_name = library_name.replace(part, "--")
namespace = namespace.replace(part, "--")
subfolder = subfolder.replace(part, "--")
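# e.g. namespace "Helsinki-NLP/tatoeba_mt" becomes "Helsinki-NLP--tatoeba_mt" (see the docstring example above)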
# Path to subfolder is created
path = assets_dir / library_name / namespace / subfolder
try:
path.mkdir(exist_ok=True, parents=True)
except (FileExistsError, NotADirectoryError):
raise ValueError(f"Corrupted assets folder: cannot create directory because of an existing file ({path}).")
# Return
return path

@@ -0,0 +1,841 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to manage the HF cache directory."""
import os
import shutil
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Union
from huggingface_hub.errors import CacheNotFound, CorruptedCacheException
from ..constants import HF_HUB_CACHE
from . import logging
from ._parsing import format_timesince
from ._terminal import tabulate
logger = logging.get_logger(__name__)
REPO_TYPE_T = Literal["model", "dataset", "space"]
# List of OS-created helper files that need to be ignored
FILES_TO_IGNORE = [".DS_Store"]
@dataclass(frozen=True)
class CachedFileInfo:
"""Frozen data structure holding information about a single cached file.
Args:
file_name (`str`):
Name of the file. Example: `config.json`.
file_path (`Path`):
Path of the file in the `snapshots` directory. The file path is a symlink
referring to a blob in the `blobs` folder.
blob_path (`Path`):
Path of the blob file. This is equivalent to `file_path.resolve()`.
size_on_disk (`int`):
Size of the blob file in bytes.
blob_last_accessed (`float`):
Timestamp of the last time the blob file has been accessed (from any
revision).
blob_last_modified (`float`):
Timestamp of the last time the blob file has been modified/created.
> [!WARNING]
> `blob_last_accessed` and `blob_last_modified` reliability can depend on the OS you
> are using. See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
> for more details.
"""
file_name: str
file_path: Path
blob_path: Path
size_on_disk: int
blob_last_accessed: float
blob_last_modified: float
@property
def blob_last_accessed_str(self) -> str:
"""
(property) Timestamp of the last time the blob file has been accessed (from any
revision), returned as a human-readable string.
Example: "2 weeks ago".
"""
return format_timesince(self.blob_last_accessed)
@property
def blob_last_modified_str(self) -> str:
"""
(property) Timestamp of the last time the blob file has been modified, returned
as a human-readable string.
Example: "2 weeks ago".
"""
return format_timesince(self.blob_last_modified)
@property
def size_on_disk_str(self) -> str:
"""
(property) Size of the blob file as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
@dataclass(frozen=True)
class CachedRevisionInfo:
"""Frozen data structure holding information about a revision.
A revision corresponds to a folder in the `snapshots` folder and is populated with
the exact tree structure as the repo on the Hub but contains only symlinks. A
revision can be either referenced by 1 or more `refs` or be "detached" (no refs).
Args:
commit_hash (`str`):
Hash of the revision (unique).
Example: `"9338f7b671827df886678df2bdd7cc7b4f36dffd"`.
snapshot_path (`Path`):
Path to the revision directory in the `snapshots` folder. It contains the
exact tree structure as the repo on the Hub.
files (`frozenset[CachedFileInfo]`):
Set of [`~CachedFileInfo`] describing all files contained in the snapshot.
refs (`frozenset[str]`):
Set of `refs` pointing to this revision. If the revision has no `refs`, it
is considered detached.
Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`.
size_on_disk (`int`):
Sum of the blob file sizes that are symlink-ed by the revision.
last_modified (`float`):
Timestamp of the last time the revision has been created/modified.
> [!WARNING]
> `last_accessed` cannot be determined correctly on a single revision as blob files
> are shared across revisions.
> [!WARNING]
> `size_on_disk` is not necessarily the sum of all file sizes because of possible
> duplicated files. Besides, only blobs are taken into account, not the (negligible)
> size of folders and symlinks.
"""
commit_hash: str
snapshot_path: Path
size_on_disk: int
files: frozenset[CachedFileInfo]
refs: frozenset[str]
last_modified: float
@property
def last_modified_str(self) -> str:
"""
(property) Timestamp of the last time the revision has been modified, returned
as a human-readable string.
Example: "2 weeks ago".
"""
return format_timesince(self.last_modified)
@property
def size_on_disk_str(self) -> str:
"""
(property) Sum of the blob file sizes as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
@property
def nb_files(self) -> int:
"""
(property) Total number of files in the revision.
"""
return len(self.files)
@dataclass(frozen=True)
class CachedRepoInfo:
"""Frozen data structure holding information about a cached repository.
Args:
repo_id (`str`):
Repo id of the repo on the Hub. Example: `"google/fleurs"`.
repo_type (`Literal["dataset", "model", "space"]`):
Type of the cached repo.
repo_path (`Path`):
Local path to the cached repo.
size_on_disk (`int`):
Sum of the blob file sizes in the cached repo.
nb_files (`int`):
Total number of blob files in the cached repo.
revisions (`frozenset[CachedRevisionInfo]`):
Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo.
last_accessed (`float`):
Timestamp of the last time a blob file of the repo has been accessed.
last_modified (`float`):
Timestamp of the last time a blob file of the repo has been modified/created.
> [!WARNING]
> `size_on_disk` is not necessarily the sum of all revisions sizes because of
> duplicated files. Besides, only blobs are taken into account, not the (negligible)
> size of folders and symlinks.
> [!WARNING]
> `last_accessed` and `last_modified` reliability can depend on the OS you are using.
> See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
> for more details.
"""
repo_id: str
repo_type: REPO_TYPE_T
repo_path: Path
size_on_disk: int
nb_files: int
revisions: frozenset[CachedRevisionInfo]
last_accessed: float
last_modified: float
@property
def last_accessed_str(self) -> str:
"""
(property) Last time a blob file of the repo has been accessed, returned as a
human-readable string.
Example: "2 weeks ago".
"""
return format_timesince(self.last_accessed)
@property
def last_modified_str(self) -> str:
"""
(property) Last time a blob file of the repo has been modified, returned as a
human-readable string.
Example: "2 weeks ago".
"""
return format_timesince(self.last_modified)
@property
def size_on_disk_str(self) -> str:
"""
(property) Sum of the blob file sizes as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
@property
def cache_id(self) -> str:
"""Canonical `type/id` identifier used across cache tooling."""
return f"{self.repo_type}/{self.repo_id}"
@property
def refs(self) -> dict[str, CachedRevisionInfo]:
"""
(property) Mapping between `refs` and revision data structures.
"""
return {ref: revision for revision in self.revisions for ref in revision.refs}
@dataclass(frozen=True)
class DeleteCacheStrategy:
"""Frozen data structure holding the strategy to delete cached revisions.
This object is not meant to be instantiated programmatically but to be returned by
[`~utils.HFCacheInfo.delete_revisions`]. See documentation for usage example.
Args:
expected_freed_size (`int`):
Expected size (in bytes) freed once the strategy is executed.
blobs (`frozenset[Path]`):
Set of blob file paths to be deleted.
refs (`frozenset[Path]`):
Set of reference file paths to be deleted.
repos (`frozenset[Path]`):
Set of entire repo paths to be deleted.
snapshots (`frozenset[Path]`):
Set of snapshots to be deleted (directory of symlinks).
"""
expected_freed_size: int
blobs: frozenset[Path]
refs: frozenset[Path]
repos: frozenset[Path]
snapshots: frozenset[Path]
@property
def expected_freed_size_str(self) -> str:
"""
(property) Expected size that will be freed as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.expected_freed_size)
def execute(self) -> None:
"""Execute the defined strategy.
> [!WARNING]
> If this method is interrupted, the cache might get corrupted. Deletion order is
> implemented so that references and symlinks are deleted before the actual blob
> files.
> [!WARNING]
> This method is irreversible. If executed, cached files are erased and must be
> downloaded again.
"""
# Deletion order matters. Blobs are deleted last so that the user can't end
# up in a state where a `ref` refers to a missing snapshot or a snapshot
# symlink refers to a deleted blob.
# Delete entire repos
for path in self.repos:
_try_delete_path(path, path_type="repo")
# Delete snapshot directories
for path in self.snapshots:
_try_delete_path(path, path_type="snapshot")
# Delete refs files
for path in self.refs:
_try_delete_path(path, path_type="ref")
# Delete blob files
for path in self.blobs:
_try_delete_path(path, path_type="blob")
logger.info(f"Cache deletion done. Saved {self.expected_freed_size_str}.")
@dataclass(frozen=True)
class HFCacheInfo:
"""Frozen data structure holding information about the entire cache-system.
This data structure is returned by [`scan_cache_dir`] and is immutable.
Args:
size_on_disk (`int`):
Sum of all valid repo sizes in the cache-system.
repos (`frozenset[CachedRepoInfo]`):
Set of [`~CachedRepoInfo`] describing all valid cached repos found on the
cache-system while scanning.
warnings (`list[CorruptedCacheException]`):
List of [`~CorruptedCacheException`] that occurred while scanning the cache.
Those exceptions are captured so that the scan can continue. Corrupted repos
are skipped from the scan.
> [!WARNING]
> Here `size_on_disk` is equal to the sum of all repo sizes (only blobs). However if
> some cached repos are corrupted, their sizes are not taken into account.
"""
size_on_disk: int
repos: frozenset[CachedRepoInfo]
warnings: list[CorruptedCacheException]
@property
def size_on_disk_str(self) -> str:
"""
(property) Sum of all valid repo sizes in the cache-system as a human-readable
string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy:
"""Prepare the strategy to delete one or more revisions cached locally.
Input revisions can be any revision hash. If a revision hash is not found in the
local cache, a warning is logged but no error is raised. Revisions can be from
different cached repos since hashes are unique across repos.
Examples:
```py
>>> from huggingface_hub import scan_cache_dir
>>> cache_info = scan_cache_dir()
>>> delete_strategy = cache_info.delete_revisions(
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa"
... )
>>> print(f"Will free {delete_strategy.expected_freed_size_str}.")
Will free 7.9K.
>>> delete_strategy.execute()
Cache deletion done. Saved 7.9K.
```
```py
>>> from huggingface_hub import scan_cache_dir
>>> scan_cache_dir().delete_revisions(
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa",
... "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
... "6c0e6080953db56375760c0471a8c5f2929baf11",
... ).execute()
Cache deletion done. Saved 8.6G.
```
> [!WARNING]
> `delete_revisions` returns a [`~utils.DeleteCacheStrategy`] object that needs to
> be executed. The [`~utils.DeleteCacheStrategy`] is not meant to be modified but
> allows having a dry run before actually executing the deletion.
"""
hashes_to_delete: set[str] = set(revisions)
repos_with_revisions: dict[CachedRepoInfo, set[CachedRevisionInfo]] = defaultdict(set)
for repo in self.repos:
for revision in repo.revisions:
if revision.commit_hash in hashes_to_delete:
repos_with_revisions[repo].add(revision)
hashes_to_delete.remove(revision.commit_hash)
if len(hashes_to_delete) > 0:
logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}")
delete_strategy_blobs: set[Path] = set()
delete_strategy_refs: set[Path] = set()
delete_strategy_repos: set[Path] = set()
delete_strategy_snapshots: set[Path] = set()
delete_strategy_expected_freed_size = 0
for affected_repo, revisions_to_delete in repos_with_revisions.items():
other_revisions = affected_repo.revisions - revisions_to_delete
# If no other revisions, it means all revisions are deleted
# -> delete the entire cached repo
if len(other_revisions) == 0:
delete_strategy_repos.add(affected_repo.repo_path)
delete_strategy_expected_freed_size += affected_repo.size_on_disk
continue
# Some revisions of the repo will be deleted but not all. We need to filter
# which blob files will not be linked anymore.
for revision_to_delete in revisions_to_delete:
# Snapshot dir
delete_strategy_snapshots.add(revision_to_delete.snapshot_path)
# Refs dir
for ref in revision_to_delete.refs:
delete_strategy_refs.add(affected_repo.repo_path / "refs" / ref)
# Blobs dir
for file in revision_to_delete.files:
if file.blob_path not in delete_strategy_blobs:
is_file_alone = True
for revision in other_revisions:
for rev_file in revision.files:
if file.blob_path == rev_file.blob_path:
is_file_alone = False
break
if not is_file_alone:
break
# Blob file not referenced by remaining revisions -> delete
if is_file_alone:
delete_strategy_blobs.add(file.blob_path)
delete_strategy_expected_freed_size += file.size_on_disk
# Return the strategy instead of executing it.
return DeleteCacheStrategy(
blobs=frozenset(delete_strategy_blobs),
refs=frozenset(delete_strategy_refs),
repos=frozenset(delete_strategy_repos),
snapshots=frozenset(delete_strategy_snapshots),
expected_freed_size=delete_strategy_expected_freed_size,
)
def export_as_table(self, *, verbosity: int = 0) -> str:
"""Generate a table from the [`HFCacheInfo`] object.
Pass `verbosity=0` to get a table with a single row per repo, with columns
"repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path".
Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns
"repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path".
Example:
```py
>>> from huggingface_hub.utils import scan_cache_dir
>>> hf_cache_info = scan_cache_dir()
HFCacheInfo(...)
>>> print(hf_cache_info.export_as_table())
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
--------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- --------------------------------------------------------------------------------------------------
roberta-base model 2.7M 5 1 day ago 1 week ago main ~/.cache/huggingface/hub/models--roberta-base
suno/bark model 8.8K 1 1 week ago 1 week ago main ~/.cache/huggingface/hub/models--suno--bark
t5-base model 893.8M 4 4 days ago 7 months ago main ~/.cache/huggingface/hub/models--t5-base
t5-large model 3.0G 4 5 weeks ago 5 months ago main ~/.cache/huggingface/hub/models--t5-large
>>> print(hf_cache_info.export_as_table(verbosity=1))
REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH
--------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- -----------------------------------------------------------------------------------------------------------------------------------------------------
roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main ~/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b
suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main ~/.cache/huggingface/hub/models--suno--bark/snapshots/70a8a7d34168586dc5d028fa9666aceade177992
t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main ~/.cache/huggingface/hub/models--t5-base/snapshots/a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1
t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main ~/.cache/huggingface/hub/models--t5-large/snapshots/150ebc2c4b72291e770f58e6057481c8d2ed331a
```
Args:
verbosity (`int`, *optional*):
The verbosity level. Defaults to 0.
Returns:
`str`: The table as a string.
"""
if verbosity == 0:
return tabulate(
rows=[
[
repo.repo_id,
repo.repo_type,
"{:>12}".format(repo.size_on_disk_str),
repo.nb_files,
repo.last_accessed_str,
repo.last_modified_str,
", ".join(sorted(repo.refs)),
str(repo.repo_path),
]
for repo in sorted(self.repos, key=lambda repo: repo.repo_path)
],
headers=[
"REPO ID",
"REPO TYPE",
"SIZE ON DISK",
"NB FILES",
"LAST_ACCESSED",
"LAST_MODIFIED",
"REFS",
"LOCAL PATH",
],
)
else:
return tabulate(
rows=[
[
repo.repo_id,
repo.repo_type,
revision.commit_hash,
"{:>12}".format(revision.size_on_disk_str),
revision.nb_files,
revision.last_modified_str,
", ".join(sorted(revision.refs)),
str(revision.snapshot_path),
]
for repo in sorted(self.repos, key=lambda repo: repo.repo_path)
for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
],
headers=[
"REPO ID",
"REPO TYPE",
"REVISION",
"SIZE ON DISK",
"NB FILES",
"LAST_MODIFIED",
"REFS",
"LOCAL PATH",
],
)
def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo:
"""Scan the entire HF cache-system and return a [`~HFCacheInfo`] structure.
Use `scan_cache_dir` in order to programmatically scan your cache-system. The cache
will be scanned repo by repo. If a repo is corrupted, a [`~CorruptedCacheException`]
will be thrown internally but captured and returned in the [`~HFCacheInfo`]
structure. Only valid repos get a proper report.
```py
>>> from huggingface_hub import scan_cache_dir
>>> hf_cache_info = scan_cache_dir()
HFCacheInfo(
size_on_disk=3398085269,
repos=frozenset({
CachedRepoInfo(
repo_id='t5-small',
repo_type='model',
repo_path=PosixPath(...),
size_on_disk=970726914,
nb_files=11,
revisions=frozenset({
CachedRevisionInfo(
commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5',
size_on_disk=970726339,
snapshot_path=PosixPath(...),
files=frozenset({
CachedFileInfo(
file_name='config.json',
size_on_disk=1197,
file_path=PosixPath(...),
blob_path=PosixPath(...),
),
CachedFileInfo(...),
...
}),
),
CachedRevisionInfo(...),
...
}),
),
CachedRepoInfo(...),
...
}),
warnings=[
CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."),
CorruptedCacheException(...),
...
],
)
```
You can also print a detailed report directly from the `hf` command line using:
```text
> hf cache ls
ID SIZE LAST_ACCESSED LAST_MODIFIED REFS
--------------------------- -------- ------------- ------------- -----------
dataset/nyu-mll/glue 157.4M 2 days ago 2 days ago main script
model/LiquidAI/LFM2-VL-1.6B 3.2G 4 days ago 4 days ago main
model/microsoft/UserLM-8b 32.1G 4 days ago 4 days ago main
Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
Got 1 warning(s) while scanning. Use -vvv to print details.
```
Args:
cache_dir (`str` or `Path`, *optional*):
Cache directory to scan. Defaults to the default HF cache directory.
> [!WARNING]
> Raises:
>
> `CacheNotFound`
> If the cache directory does not exist.
>
> [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
> If the cache directory is a file, instead of a directory.
Returns: a [`~HFCacheInfo`] object.
"""
if cache_dir is None:
cache_dir = HF_HUB_CACHE
cache_dir = Path(cache_dir).expanduser().resolve()
if not cache_dir.exists():
raise CacheNotFound(
f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",
cache_dir=cache_dir,
)
if cache_dir.is_file():
raise ValueError(
f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."
)
repos: set[CachedRepoInfo] = set()
warnings: list[CorruptedCacheException] = []
for repo_path in cache_dir.iterdir():
if repo_path.name == ".locks": # skip './.locks/' folder
continue
try:
repos.add(_scan_cached_repo(repo_path))
except CorruptedCacheException as e:
warnings.append(e)
return HFCacheInfo(
repos=frozenset(repos),
size_on_disk=sum(repo.size_on_disk for repo in repos),
warnings=warnings,
)
def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
"""Scan a single cache repo and return information about it.
Any unexpected behavior will raise a [`~CorruptedCacheException`].
"""
if not repo_path.is_dir():
raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}")
if "--" not in repo_path.name:
raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}")
repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
repo_type = repo_type[:-1] # "models" -> "model"
repo_id = repo_id.replace("--", "/")  # "google--fleurs" -> "google/fleurs"
if repo_type not in {"dataset", "model", "space"}:
raise CorruptedCacheException(
f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})."
)
blob_stats: dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats
snapshots_path = repo_path / "snapshots"
refs_path = repo_path / "refs"
if not snapshots_path.exists() or not snapshots_path.is_dir():
raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}")
# Scan over `refs` directory
# key is revision hash, value is set of refs
refs_by_hash: dict[str, set[str]] = defaultdict(set)
if refs_path.exists():
# Example of `refs` directory
# ── refs
# ├── main
# └── refs
# └── pr
# └── 1
if refs_path.is_file():
raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}")
for ref_path in refs_path.glob("**/*"):
# glob("**/*") iterates over all files and directories -> skip directories
if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
continue
ref_name = str(ref_path.relative_to(refs_path))
with ref_path.open() as f:
commit_hash = f.read()
refs_by_hash[commit_hash].add(ref_name)
# Scan snapshots directory
cached_revisions: set[CachedRevisionInfo] = set()
for revision_path in snapshots_path.iterdir():
# Ignore OS-created helper files
if revision_path.name in FILES_TO_IGNORE:
continue
if revision_path.is_file():
raise CorruptedCacheException(f"Snapshots folder corrupted. Found a file: {revision_path}")
cached_files = set()
for file_path in revision_path.glob("**/*"):
# glob("**/*") iterates over all files and directories -> skip directories
if file_path.is_dir():
continue
blob_path = Path(file_path).resolve()
if not blob_path.exists():
raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}")
if blob_path not in blob_stats:
blob_stats[blob_path] = blob_path.stat()
cached_files.add(
CachedFileInfo(
file_name=file_path.name,
file_path=file_path,
size_on_disk=blob_stats[blob_path].st_size,
blob_path=blob_path,
blob_last_accessed=blob_stats[blob_path].st_atime,
blob_last_modified=blob_stats[blob_path].st_mtime,
)
)
# Last modified is either the last modified blob file or the revision folder
# itself if it is empty
if len(cached_files) > 0:
revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files)
else:
revision_last_modified = revision_path.stat().st_mtime
cached_revisions.add(
CachedRevisionInfo(
commit_hash=revision_path.name,
files=frozenset(cached_files),
refs=frozenset(refs_by_hash.pop(revision_path.name, set())),
size_on_disk=sum(
blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files)
),
snapshot_path=revision_path,
last_modified=revision_last_modified,
)
)
# Check that all refs referred to an existing revision
if len(refs_by_hash) > 0:
raise CorruptedCacheException(
f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})."
)
# Last modified is either the last modified blob file or the repo folder itself if
# no blob files have been found. Same for last accessed.
if len(blob_stats) > 0:
repo_last_accessed = max(stat.st_atime for stat in blob_stats.values())
repo_last_modified = max(stat.st_mtime for stat in blob_stats.values())
else:
repo_stats = repo_path.stat()
repo_last_accessed = repo_stats.st_atime
repo_last_modified = repo_stats.st_mtime
# Build and return frozen structure
return CachedRepoInfo(
nb_files=len(blob_stats),
repo_id=repo_id,
repo_path=repo_path,
repo_type=repo_type, # type: ignore
revisions=frozenset(cached_revisions),
size_on_disk=sum(stat.st_size for stat in blob_stats.values()),
last_accessed=repo_last_accessed,
last_modified=repo_last_modified,
)
def _format_size(num: int) -> str:
"""Format size in bytes into a human-readable string.
Taken from https://stackoverflow.com/a/1094933
"""
num_f = float(num)
for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
if abs(num_f) < 1000.0:
return f"{num_f:3.1f}{unit}"
num_f /= 1000.0
return f"{num_f:.1f}Y"
def _try_delete_path(path: Path, path_type: str) -> None:
"""Try to delete a local file or folder.
If the path does not exist, the error is logged as a warning and then ignored.
Args:
path (`Path`):
Path to delete. Can be a file or a folder.
path_type (`str`):
Type of path being deleted, used only for logging purposes. Example: "snapshot".
"""
logger.info(f"Delete {path_type}: {path}")
try:
if path.is_file():
os.remove(path)
else:
shutil.rmtree(path)
except FileNotFoundError:
logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True)
except PermissionError:
logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True)

@@ -0,0 +1,64 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a utility to iterate by chunks over an iterator."""
import itertools
from typing import Iterable, TypeVar
T = TypeVar("T")
def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]:
"""Iterates over an iterator chunk by chunk.
Taken from https://stackoverflow.com/a/8998040.
See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088.
Args:
iterable (`Iterable`):
The iterable on which we want to iterate.
chunk_size (`int`):
Size of the chunks. Must be a strictly positive integer (i.e. > 0).
Example:
```python
>>> from huggingface_hub.utils import chunk_iterable
>>> for items in chunk_iterable(range(17), chunk_size=8):
... print(items)
# [0, 1, 2, 3, 4, 5, 6, 7]
# [8, 9, 10, 11, 12, 13, 14, 15]
# [16] # smaller last chunk
```
Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If `chunk_size` <= 0.
> [!WARNING]
> The last chunk can be smaller than `chunk_size`.
"""
if not isinstance(chunk_size, int) or chunk_size <= 0:
raise ValueError("`chunk_size` must be a strictly positive integer (>0).")
iterator = iter(iterable)
while True:
try:
next_item = next(iterator)
except StopIteration:
return
yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1))
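# Note: each yielded chunk is a lazy iterator (itertools.chain over islice), so chunks
# must be consumed in order. A minimal sketch:
#
#   >>> for chunk in chunk_iterable(range(10), chunk_size=4):
#   ...     print(list(chunk))
#   [0, 1, 2, 3]
#   [4, 5, 6, 7]
#   [8, 9]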

@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle datetimes in Huggingface Hub."""
from datetime import datetime, timezone
def parse_datetime(date_string: str) -> datetime:
"""
Parses a date_string returned from the server to a datetime object.
This parser is a weak parser in the sense that it handles only a single format of
date_string. It is expected that the server format will never change. The
implementation depends only on the standard lib to avoid an external dependency
(python-dateutil). See full discussion about this decision on PR:
https://github.com/huggingface/huggingface_hub/pull/999.
Example:
```py
>>> parse_datetime('2022-08-19T07:19:38.123Z')
datetime.datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc)
```
Args:
date_string (`str`):
A string representing a datetime returned by the Hub server.
String is expected to follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern.
Returns:
A python datetime object.
Raises:
:class:`ValueError`:
If `date_string` cannot be parsed.
"""
try:
# Normalize the string to always have 6 digits of fractional seconds
if date_string.endswith("Z"):
# Case 1: No decimal point (e.g., "2024-11-16T00:27:02Z")
if "." not in date_string:
# No fractional seconds - insert .000000
date_string = date_string[:-1] + ".000000Z"
# Case 2: Has decimal point (e.g., "2022-08-19T07:19:38.123456789Z")
else:
# Get the fractional and base parts
base, fraction = date_string[:-1].split(".")
# fraction[:6] takes first 6 digits and :0<6 pads with zeros if less than 6 digits
date_string = f"{base}.{fraction[:6]:0<6}Z"
return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
except ValueError as e:
raise ValueError(
f"Cannot parse '{date_string}' as a datetime. Date string is expected to"
" follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern."
) from e
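# Illustrative inputs accepted by the normalization above (both return timezone-aware
# datetimes with tzinfo=timezone.utc):
#   parse_datetime("2024-11-16T00:27:02Z")            # no fraction -> padded to ".000000"
#   parse_datetime("2022-08-19T07:19:38.123456789Z")  # nanoseconds -> truncated to 6 digits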

@@ -0,0 +1,136 @@
import warnings
from functools import wraps
from inspect import Parameter, signature
from typing import Iterable, Optional
def _deprecate_positional_args(*, version: str):
"""Decorator for methods that issues warnings for positional arguments.
Using the keyword-only argument syntax from PEP 3102, arguments after the
`*` will issue a warning when passed as positional arguments.
Args:
version (`str`):
The version when positional arguments will result in error.
"""
def _inner_deprecate_positional_args(f):
sig = signature(f)
kwonly_args = []
all_args = []
for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@wraps(f)
def inner_f(*args, **kwargs):
extra_args = len(args) - len(all_args)
if extra_args <= 0:
return f(*args, **kwargs)
# extra_args > 0
args_msg = [
f"{name}='{arg}'" if isinstance(arg, str) else f"{name}={arg}"
for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
]
args_msg = ", ".join(args_msg)
warnings.warn(
f"Deprecated positional argument(s) used in '{f.__name__}': pass"
f" {args_msg} as keyword args. From version {version} passing these"
" as positional arguments will result in an error,",
FutureWarning,
)
kwargs.update(zip(sig.parameters, args))
return f(**kwargs)
return inner_f
return _inner_deprecate_positional_args
def _deprecate_arguments(
*,
version: str,
deprecated_args: Iterable[str],
custom_message: Optional[str] = None,
):
"""Decorator to issue warnings when using deprecated arguments.
TODO: could be useful to be able to set a custom error message.
Args:
version (`str`):
The version when deprecated arguments will result in error.
deprecated_args (`list[str]`):
List of the arguments to be deprecated.
custom_message (`str`, *optional*):
Warning message that is raised. If not passed, a default warning message
will be created.
"""
def _inner_deprecate_positional_args(f):
sig = signature(f)
@wraps(f)
def inner_f(*args, **kwargs):
# Check for used deprecated arguments
used_deprecated_args = []
for _, parameter in zip(args, sig.parameters.values()):
if parameter.name in deprecated_args:
used_deprecated_args.append(parameter.name)
for kwarg_name, kwarg_value in kwargs.items():
if (
# If argument is deprecated but still used
kwarg_name in deprecated_args
# And the passed value is not the default value
and kwarg_value != sig.parameters[kwarg_name].default
):
used_deprecated_args.append(kwarg_name)
# Warn and proceed
if len(used_deprecated_args) > 0:
message = (
f"Deprecated argument(s) used in '{f.__name__}':"
f" {', '.join(used_deprecated_args)}. Will not be supported from"
f" version '{version}'."
)
if custom_message is not None:
message += "\n\n" + custom_message
warnings.warn(message, FutureWarning)
return f(*args, **kwargs)
return inner_f
return _inner_deprecate_positional_args
def _deprecate_method(*, version: str, message: Optional[str] = None):
"""Decorator to issue warnings when using a deprecated method.
Args:
version (`str`):
The version when deprecated arguments will result in error.
message (`str`, *optional*):
Warning message that is raised. If not passed, a default warning message
will be created.
"""
def _inner_deprecate_method(f):
name = f.__name__
if name == "__init__":
name = f.__qualname__.split(".")[0] # class name instead of method name
@wraps(f)
def inner_f(*args, **kwargs):
warning_message = (
f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'."
)
if message is not None:
warning_message += " " + message
warnings.warn(warning_message, FutureWarning)
return f(*args, **kwargs)
return inner_f
return _inner_deprecate_method
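# Illustrative usage sketch (the decorated helper below is hypothetical):
#
#   @_deprecate_method(version="2.0", message="Use `new_helper` instead.")
#   @_deprecate_arguments(version="2.0", deprecated_args=["legacy_mode"])
#   def old_helper(repo_id: str, legacy_mode: bool = False):
#       ...
#
#   old_helper("my-repo", legacy_mode=True)  # emits two FutureWarnings before calling the function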

@@ -0,0 +1,55 @@
# AI-generated module (ChatGPT)
import re
from typing import Optional
def load_dotenv(dotenv_str: str, environ: Optional[dict[str, str]] = None) -> dict[str, str]:
"""
Parse a DOTENV-format string and return a dictionary of key-value pairs.
Handles quoted values, comments, export keyword, and blank lines.
"""
env: dict[str, str] = {}
line_pattern = re.compile(
r"""
^\s*
(?:export[^\S\n]+)? # optional export
([A-Za-z_][A-Za-z0-9_]*) # key
[^\S\n]*(=)?[^\S\n]*
( # value group
(?:
'(?:\\'|[^'])*' # single-quoted value
| \"(?:\\\"|[^\"])*\" # double-quoted value
| [^#\n\r]+? # unquoted value
)
)?
[^\S\n]*(?:\#.*)?$ # optional inline comment
""",
re.VERBOSE,
)
for line in dotenv_str.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue # Skip comments and empty lines
match = line_pattern.match(line)
if match:
key = match.group(1)
val = None
if match.group(2): # if there is '='
raw_val = match.group(3) or ""
val = raw_val.strip()
# Remove surrounding quotes if quoted
if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
val = val[1:-1]
val = val.replace(r"\n", "\n").replace(r"\t", "\t").replace(r"\"", '"').replace(r"\\", "\\")
if raw_val.startswith('"'):
val = val.replace(r"\$", "$") # only in double quotes
elif environ is not None:
# Get it from the current environment
val = environ.get(key)
if val is not None:
env[key] = val
return env
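# Illustrative usage sketch (values are made up), tracing the parsing rules above:
#
#   >>> load_dotenv('export HF_TOKEN="hf_xxx"  # inline comment\nEMPTY=\n# full-line comment\n')
#   {'HF_TOKEN': 'hf_xxx', 'EMPTY': ''}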

@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to flag a feature as "experimental" in Huggingface Hub."""
import warnings
from functools import wraps
from typing import Callable
from .. import constants
def experimental(fn: Callable) -> Callable:
"""Decorator to flag a feature as experimental.
An experimental feature triggers a warning when used as it might be subject to breaking changes without prior notice
in the future.
Warnings can be disabled by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable.
Args:
fn (`Callable`):
The function to flag as experimental.
Returns:
`Callable`: The decorated function.
Example:
```python
>>> from huggingface_hub.utils import experimental
>>> @experimental
... def my_function():
... print("Hello world!")
>>> my_function()
UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future without prior
notice. You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable.
Hello world!
```
"""
# For classes, the decorator is applied to the "__new__" method => strip the ".__new__" suffix so the warning message shows the class name
name = fn.__qualname__[: -len(".__new__")] if fn.__qualname__.endswith(".__new__") else fn.__qualname__
@wraps(fn)
def _inner_fn(*args, **kwargs):
if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING:
warnings.warn(
f"'{name}' is experimental and might be subject to breaking changes in the future without prior notice."
" You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment"
" variable.",
UserWarning,
)
return fn(*args, **kwargs)
return _inner_fn

@@ -0,0 +1,123 @@
import contextlib
import os
import shutil
import stat
import tempfile
import time
from functools import partial
from pathlib import Path
from typing import Callable, Generator, Optional, Union
import yaml
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
from .. import constants
from . import logging
logger = logging.get_logger(__name__)
# Wrap `yaml.dump` to set `allow_unicode=True` by default.
#
# Example:
# ```py
# >>> yaml.dump({"emoji": "👀", "some unicode": "日本か"})
# 'emoji: "\\U0001F440"\nsome unicode: "\\u65E5\\u672C\\u304B"\n'
#
# >>> yaml_dump({"emoji": "👀", "some unicode": "日本か"})
# 'emoji: "👀"\nsome unicode: "日本か"\n'
# ```
yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore
@contextlib.contextmanager
def SoftTemporaryDirectory(
suffix: Optional[str] = None,
prefix: Optional[str] = None,
dir: Optional[Union[Path, str]] = None,
**kwargs,
) -> Generator[Path, None, None]:
"""
Context manager to create a temporary directory and safely delete it.
If tmp directory cannot be deleted normally, we set the WRITE permission and retry.
If cleanup still fails, we give up but don't raise an exception. This is equivalent
to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in
Python 3.10.
See https://www.scivision.dev/python-tempfile-permission-error-windows/.
"""
tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs)
yield Path(tmpdir.name).resolve()
try:
# First try a normal cleanup
shutil.rmtree(tmpdir.name)
except Exception:
# If failed, try to set write permission and retry
try:
shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry)
except Exception:
pass
# And finally, cleanup the tmpdir.
# If it fails again, give up but do not throw error
try:
tmpdir.cleanup()
except Exception:
pass
def _set_write_permission_and_retry(func, path, excinfo):
os.chmod(path, stat.S_IWRITE)
func(path)
@contextlib.contextmanager
def WeakFileLock(
lock_file: Union[str, Path], *, timeout: Optional[float] = None
) -> Generator[BaseFileLock, None, None]:
"""A filelock with some custom logic.
This filelock is weaker than the default filelock in that:
1. It won't raise an exception if release fails.
2. It will default to a SoftFileLock if the filesystem does not support flock.
An INFO log message is emitted every 10 seconds if the lock is not acquired immediately.
If a timeout is provided, a `filelock.Timeout` exception is raised if the lock is not acquired within the timeout.
"""
log_interval = constants.FILELOCK_LOG_EVERY_SECONDS
lock = FileLock(lock_file, timeout=log_interval)
start_time = time.time()
while True:
elapsed_time = time.time() - start_time
if timeout is not None and elapsed_time >= timeout:
raise Timeout(str(lock_file))
try:
lock.acquire(timeout=min(log_interval, timeout - elapsed_time) if timeout else log_interval)
except Timeout:
logger.info(
f"Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)"
)
except NotImplementedError as e:
if "use SoftFileLock instead" in str(e):
logger.warning(
"FileSystem does not appear to support flock. Falling back to SoftFileLock for %s", lock_file
)
lock = SoftFileLock(lock_file, timeout=log_interval)
continue
else:
break
try:
yield lock
finally:
try:
lock.release()
except OSError:
try:
Path(lock_file).unlink()
except OSError:
pass
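# Illustrative usage sketch (paths are hypothetical):
#
#   with SoftTemporaryDirectory(prefix="hf-") as tmp_path:
#       (tmp_path / "data.bin").write_bytes(b"...")  # cleaned up on exit, cleanup errors ignored
#
#   with WeakFileLock("/tmp/example.lock", timeout=30):
#       ...  # critical section; raises filelock.Timeout if not acquired within 30 seconds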

@@ -0,0 +1,121 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to manage Git credentials."""
import re
import subprocess
from typing import Optional
from ..constants import ENDPOINT
from ._subprocess import run_interactive_subprocess, run_subprocess
GIT_CREDENTIAL_REGEX = re.compile(
r"""
^\s* # start of line
credential\.helper # credential.helper value
\s*=\s* # separator
([\w\-\/]+) # the helper name or absolute path (group 1)
(\s|$) # whitespace or end of line
""",
flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
)
def list_credential_helpers(folder: Optional[str] = None) -> list[str]:
"""Return the list of git credential helpers configured.
See https://git-scm.com/docs/gitcredentials.
Credentials are saved in all configured helpers (store, cache, macOS keychain,...).
Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential.
Args:
folder (`str`, *optional*):
The folder in which to check the configured helpers.
"""
try:
output = run_subprocess("git config --list", folder=folder).stdout
parsed = _parse_credential_output(output)
return parsed
except subprocess.CalledProcessError as exc:
raise EnvironmentError(exc.stderr)
def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None:
"""Save a username/token pair in git credential for HF Hub registry.
Credentials are saved in all configured helpers (store, cache, macOS keychain,...).
Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential.
Args:
username (`str`, defaults to `"hf_user"`):
A git username. Defaults to `"hf_user"`, the default user used in the Hub.
        token (`str`):
A git password. In practice, the User Access Token for the Hub.
See https://huggingface.co/settings/tokens.
folder (`str`, *optional*):
The folder in which to check the configured helpers.
"""
with run_interactive_subprocess("git credential approve", folder=folder) as (
stdin,
_,
):
stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n")
stdin.flush()
def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None:
"""Erase credentials from git credential for HF Hub registry.
Credentials are erased from the configured helpers (store, cache, macOS
keychain,...), if any. If `username` is not provided, any credential configured for
HF Hub endpoint is erased.
Calls "`git credential erase`" internally. See https://git-scm.com/docs/git-credential.
Args:
username (`str`, defaults to `"hf_user"`):
A git username. Defaults to `"hf_user"`, the default user used in the Hub.
folder (`str`, *optional*):
The folder in which to check the configured helpers.
"""
with run_interactive_subprocess("git credential reject", folder=folder) as (
stdin,
_,
):
standard_input = f"url={ENDPOINT}\n"
if username is not None:
standard_input += f"username={username.lower()}\n"
standard_input += "\n"
stdin.write(standard_input)
stdin.flush()
def _parse_credential_output(output: str) -> list[str]:
"""Parse the output of `git credential fill` to extract the password.
Args:
output (`str`):
The output of `git credential fill`.
"""
# NOTE: If user has set a helper for a custom URL, it will not be caught here.
# Example: `credential.https://huggingface.co.helper=store`
# See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508
return sorted( # Sort for nice printing
set( # Might have some duplicates
match[0] for match in GIT_CREDENTIAL_REGEX.findall(output)
)
)
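# Example workflow (illustrative sketch; the token value is a placeholder):
#
# ```py
# >>> list_credential_helpers()            # e.g. ['cache', 'store'], depending on local git config
# >>> set_git_credential(token="hf_***")   # store the token for the Hub endpoint
# >>> unset_git_credential()               # erase it again
# ```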

View File

@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle headers to send in calls to Huggingface Hub."""
from typing import Optional, Union
from huggingface_hub.errors import LocalTokenNotFoundError
from .. import constants
from ._auth import get_token
from ._runtime import (
get_hf_hub_version,
get_python_version,
get_torch_version,
is_torch_available,
)
from ._validators import validate_hf_hub_args
@validate_hf_hub_args
def build_hf_headers(
*,
token: Optional[Union[bool, str]] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[dict, str, None] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, str]:
"""
Build headers dictionary to send in a HF Hub call.
By default, authorization token is always provided either from argument (explicit
use) or retrieved from the cache (implicit use). To explicitly avoid sending the
token to the Hub, set `token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN`
environment variable.
In case of an API call that requires write access, an error is thrown if token is
`None` or token is an organization token (starting with `"api_org***"`).
In addition to the auth header, a user-agent is added to provide information about
the installed packages (versions of python, huggingface_hub, torch).
Args:
token (`str`, `bool`, *optional*):
The token to be sent in authorization header for the Hub call:
- if a string, it is used as the Hugging Face token
- if `True`, the token is read from the machine (cache or env variable)
- if `False`, authorization header is not set
- if `None`, the token is read from the machine only except if
`HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
library_name (`str`, *optional*):
The name of the library that is making the HTTP request. Will be added to
the user-agent header.
library_version (`str`, *optional*):
The version of the library that is making the HTTP request. Will be added
to the user-agent header.
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string. It will
be completed with information about the installed packages.
headers (`dict`, *optional*):
Additional headers to include in the request. Those headers take precedence
over the ones generated by this function.
Returns:
A `dict` of headers to pass in your API call.
Example:
```py
>>> build_hf_headers(token="hf_***") # explicit token
{"authorization": "Bearer hf_***", "user-agent": ""}
>>> build_hf_headers(token=True) # explicitly use cached token
{"authorization": "Bearer hf_***",...}
>>> build_hf_headers(token=False) # explicitly don't use cached token
{"user-agent": ...}
>>> build_hf_headers() # implicit use of the cached token
{"authorization": "Bearer hf_***",...}
# HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable
>>> build_hf_headers() # token is not sent
{"user-agent": ...}
>>> build_hf_headers(library_name="transformers", library_version="1.2.3")
{"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"}
```
Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If organization token is passed and "write" access is required.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If "write" access is required but token is not passed and not saved locally.
[`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
If `token=True` but token is not saved locally.
"""
# Get auth token to send
token_to_send = get_token_to_send(token)
# Combine headers
hf_headers = {
"user-agent": _http_user_agent(
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
)
}
if token_to_send is not None:
hf_headers["authorization"] = f"Bearer {token_to_send}"
if headers is not None:
hf_headers.update(headers)
return hf_headers
def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]:
"""Select the token to send from either `token` or the cache."""
# Case token is explicitly provided
if isinstance(token, str):
return token
# Case token is explicitly forbidden
if token is False:
return None
# Token is not provided: we get it from local cache
cached_token = get_token()
# Case token is explicitly required
if token is True:
if cached_token is None:
raise LocalTokenNotFoundError(
"Token is required (`token=True`), but no token found. You"
" need to provide a token or be logged in to Hugging Face with"
" `hf auth login` or `huggingface_hub.login`. See"
" https://huggingface.co/settings/tokens."
)
return cached_token
# Case implicit use of the token is forbidden by env variable
if constants.HF_HUB_DISABLE_IMPLICIT_TOKEN:
return None
# Otherwise: we use the cached token as the user has not explicitly forbidden it
return cached_token
def _http_user_agent(
*,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[dict, str, None] = None,
) -> str:
"""Format a user-agent string containing information about the installed packages.
Args:
library_name (`str`, *optional*):
The name of the library that is making the HTTP request.
library_version (`str`, *optional*):
The version of the library that is making the HTTP request.
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string.
Returns:
The formatted user-agent string.
"""
if library_name is not None:
ua = f"{library_name}/{library_version}"
else:
ua = "unknown/None"
ua += f"; hf_hub/{get_hf_hub_version()}"
ua += f"; python/{get_python_version()}"
if not constants.HF_HUB_DISABLE_TELEMETRY:
if is_torch_available():
ua += f"; torch/{get_torch_version()}"
if isinstance(user_agent, dict):
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
# Retrieve user-agent origin headers from environment variable
origin = constants.HF_HUB_USER_AGENT_ORIGIN
if origin is not None:
ua += "; origin/" + origin
return _deduplicate_user_agent(ua)
def _deduplicate_user_agent(user_agent: str) -> str:
"""Deduplicate redundant information in the generated user-agent."""
# Split around ";" > Strip whitespaces > Store as dict keys (ensure unicity) > format back as string
# Order is implicitly preserved by dictionary structure (see https://stackoverflow.com/a/53657523).
return "; ".join({key.strip(): None for key in user_agent.split(";")}.keys())

View File

@@ -0,0 +1,796 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle HTTP requests in huggingface_hub."""
import atexit
import io
import json
import os
import re
import threading
import time
import uuid
from contextlib import contextmanager
from http import HTTPStatus
from shlex import quote
from typing import Any, Callable, Generator, Optional, Union
import httpx
from huggingface_hub.errors import OfflineModeIsEnabled
from .. import constants
from ..errors import (
BadRequestError,
DisabledRepoError,
GatedRepoError,
HfHubHTTPError,
RemoteEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from . import logging
from ._lfs import SliceFileObj
from ._typing import HTTP_METHOD_T
logger = logging.get_logger(__name__)
# Both headers are used by the Hub to debug failed requests.
# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
X_REQUEST_ID = "x-request-id"
REPO_API_REGEX = re.compile(
r"""
# staging or production endpoint
^https://[^/]+
(
# on /api/repo_type/repo_id
/api/(models|datasets|spaces)/(.+)
|
# or /repo_id/resolve/revision/...
/(.+)/resolve/(.+)
)
""",
flags=re.VERBOSE,
)
def hf_request_event_hook(request: httpx.Request) -> Optional[str]:
"""
Event hook that will be used to make HTTP requests to the Hugging Face Hub.
What it does:
- Block requests if offline mode is enabled
- Add a request ID to the request headers
- Log the request if debug mode is enabled
"""
if constants.HF_HUB_OFFLINE:
raise OfflineModeIsEnabled(
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
)
# Add random request ID => easier for server-side debugging
if X_AMZN_TRACE_ID not in request.headers:
request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
request_id = request.headers.get(X_AMZN_TRACE_ID)
# Debug log
logger.debug(
"Request %s: %s %s (authenticated: %s)",
request_id,
request.method,
request.url,
request.headers.get("authorization") is not None,
)
if constants.HF_DEBUG:
logger.debug("Send: %s", _curlify(request))
return request_id
async def async_hf_request_event_hook(request: httpx.Request) -> Optional[str]:
"""
Async version of `hf_request_event_hook`.
"""
return hf_request_event_hook(request)
async def async_hf_response_event_hook(response: httpx.Response) -> None:
if response.status_code >= 400:
# If response will raise, read content from stream to have it available when raising the exception
# If content-length is not set or is too large, skip reading the content to avoid OOM
if "Content-length" in response.headers:
try:
length = int(response.headers["Content-length"])
except ValueError:
return
if length < 1_000_000:
await response.aread()
def default_client_factory() -> httpx.Client:
"""
Factory function to create a `httpx.Client` with the default transport.
"""
return httpx.Client(
event_hooks={"request": [hf_request_event_hook]},
follow_redirects=True,
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
)
def default_async_client_factory() -> httpx.AsyncClient:
"""
Factory function to create a `httpx.AsyncClient` with the default transport.
"""
return httpx.AsyncClient(
event_hooks={"request": [async_hf_request_event_hook], "response": [async_hf_response_event_hook]},
follow_redirects=True,
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
)
CLIENT_FACTORY_T = Callable[[], httpx.Client]
ASYNC_CLIENT_FACTORY_T = Callable[[], httpx.AsyncClient]
_CLIENT_LOCK = threading.Lock()
_GLOBAL_CLIENT_FACTORY: CLIENT_FACTORY_T = default_client_factory
_GLOBAL_ASYNC_CLIENT_FACTORY: ASYNC_CLIENT_FACTORY_T = default_async_client_factory
_GLOBAL_CLIENT: Optional[httpx.Client] = None
def set_client_factory(client_factory: CLIENT_FACTORY_T) -> None:
"""
Set the HTTP client factory to be used by `huggingface_hub`.
    The client factory is a method that returns a `httpx.Client` object. On the first call to [`get_session`], the client factory
    will be used to create a new `httpx.Client` object that will be shared between all calls made by `huggingface_hub`.
    This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. a custom proxy or certificates).
    Use [`get_session`] to get a correctly configured `httpx.Client`.
"""
global _GLOBAL_CLIENT_FACTORY
with _CLIENT_LOCK:
close_session()
_GLOBAL_CLIENT_FACTORY = client_factory
def set_async_client_factory(async_client_factory: ASYNC_CLIENT_FACTORY_T) -> None:
"""
Set the HTTP async client factory to be used by `huggingface_hub`.
The async client factory is a method that returns a `httpx.AsyncClient` object.
    This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. a custom proxy or certificates).
    Use [`get_async_session`] to get a correctly configured `httpx.AsyncClient`.
<Tip warning={true}>
Contrary to the `httpx.Client` that is shared between all calls made by `huggingface_hub`, the `httpx.AsyncClient` is not shared.
It is recommended to use an async context manager to ensure the client is properly closed when the context is exited.
</Tip>
"""
global _GLOBAL_ASYNC_CLIENT_FACTORY
_GLOBAL_ASYNC_CLIENT_FACTORY = async_client_factory
def get_session() -> httpx.Client:
"""
Get a `httpx.Client` object, using the transport factory from the user.
This client is shared between all calls made by `huggingface_hub`. Therefore you should not close it manually.
Use [`set_client_factory`] to customize the `httpx.Client`.
"""
global _GLOBAL_CLIENT
if _GLOBAL_CLIENT is None:
with _CLIENT_LOCK:
_GLOBAL_CLIENT = _GLOBAL_CLIENT_FACTORY()
return _GLOBAL_CLIENT
def get_async_session() -> httpx.AsyncClient:
"""
Return a `httpx.AsyncClient` object, using the transport factory from the user.
Use [`set_async_client_factory`] to customize the `httpx.AsyncClient`.
<Tip warning={true}>
Contrary to the `httpx.Client` that is shared between all calls made by `huggingface_hub`, the `httpx.AsyncClient` is not shared.
It is recommended to use an async context manager to ensure the client is properly closed when the context is exited.
</Tip>
"""
return _GLOBAL_ASYNC_CLIENT_FACTORY()
def close_session() -> None:
"""
Close the global `httpx.Client` used by `huggingface_hub`.
If a Client is closed, it will be recreated on the next call to [`get_session`].
Can be useful if e.g. an SSL certificate has been updated.
"""
global _GLOBAL_CLIENT
client = _GLOBAL_CLIENT
# First, set global client to None
_GLOBAL_CLIENT = None
# Then, close the clients
if client is not None:
try:
client.close()
except Exception as e:
logger.warning(f"Error closing client: {e}")
atexit.register(close_session)
if hasattr(os, "register_at_fork"):
os.register_at_fork(after_in_child=close_session)
def _http_backoff_base(
method: HTTP_METHOD_T,
url: str,
*,
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
),
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
stream: bool = False,
**kwargs,
) -> Generator[httpx.Response, None, None]:
"""Internal implementation of HTTP backoff logic shared between `http_backoff` and `http_stream_backoff`."""
    if isinstance(retry_on_exceptions, type):  # Normalize a single exception type to a tuple
        retry_on_exceptions = (retry_on_exceptions,)
    if isinstance(retry_on_status_codes, int):  # Normalize a single status code to a tuple
        retry_on_status_codes = (retry_on_status_codes,)
nb_tries = 0
sleep_time = base_wait_time
# If `data` is used and is a file object (or any IO), it will be consumed on the
# first HTTP request. We need to save the initial position so that the full content
# of the file is re-sent on http backoff. See warning tip in docstring.
io_obj_initial_pos = None
if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)):
io_obj_initial_pos = kwargs["data"].tell()
client = get_session()
while True:
nb_tries += 1
try:
# If `data` is used and is a file object (or any IO), set back cursor to
# initial position.
if io_obj_initial_pos is not None:
kwargs["data"].seek(io_obj_initial_pos)
# Perform request and handle response
def _should_retry(response: httpx.Response) -> bool:
"""Handle response and return True if should retry, False if should return/yield."""
if response.status_code not in retry_on_status_codes:
return False # Success, don't retry
# Wrong status code returned (HTTP 503 for instance)
logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}")
if nb_tries > max_retries:
hf_raise_for_status(response) # Will raise uncaught exception
# Return/yield response to avoid infinite loop in the corner case where the
                    # user asks for a retry on a status code that doesn't raise_for_status.
return False # Don't retry, return/yield response
return True # Should retry
if stream:
with client.stream(method=method, url=url, **kwargs) as response:
if not _should_retry(response):
yield response
return
else:
response = client.request(method=method, url=url, **kwargs)
if not _should_retry(response):
yield response
return
except retry_on_exceptions as err:
logger.warning(f"'{err}' thrown while requesting {method} {url}")
if isinstance(err, httpx.ConnectError):
close_session() # In case of SSLError it's best to close the shared httpx.Client objects
if nb_tries > max_retries:
raise err
# Sleep for X seconds
logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
time.sleep(sleep_time)
# Update sleep time for next retry
sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
def http_backoff(
method: HTTP_METHOD_T,
url: str,
*,
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
),
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
**kwargs,
) -> httpx.Response:
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
and/or on specific status codes (ex: service unavailable). If the call failed more
than `max_retries`, the exception is thrown or `raise_for_status` is called on the
response object.
    Re-implements mechanisms from the `backoff` library to avoid adding an external
    dependency to `huggingface_hub`. See https://github.com/litl/backoff.
Args:
method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
HTTP method to perform.
url (`str`):
The URL of the resource to fetch.
max_retries (`int`, *optional*, defaults to `5`):
            Maximum number of retries, defaults to 5.
base_wait_time (`float`, *optional*, defaults to `1`):
Duration (in seconds) to wait before retrying the first time.
Wait time between retries then grows exponentially, capped by
`max_wait_time`.
max_wait_time (`float`, *optional*, defaults to `8`):
Maximum duration (in seconds) to wait before retrying.
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
Define on which status codes the request must be retried. By default, only
HTTP 503 Service Unavailable is retried.
**kwargs (`dict`, *optional*):
kwargs to pass to `httpx.request`.
Example:
```
>>> from huggingface_hub.utils import http_backoff
# Same usage as "httpx.request".
>>> response = http_backoff("GET", "https://www.google.com")
>>> response.raise_for_status()
# If you expect a Gateway Timeout from time to time
>>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504)
>>> response.raise_for_status()
```
> [!WARNING]
    > When using `httpx` it is possible to stream data by passing an iterator to the
> `data` argument. On http backoff this is a problem as the iterator is not reset
> after a failed call. This issue is mitigated for file objects or any IO streams
> by saving the initial position of the cursor (with `data.tell()`) and resetting the
> cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
> will fail. If this is a hard constraint for you, please let us know by opening an
> issue on [Github](https://github.com/huggingface/huggingface_hub).
"""
return next(
_http_backoff_base(
method=method,
url=url,
max_retries=max_retries,
base_wait_time=base_wait_time,
max_wait_time=max_wait_time,
retry_on_exceptions=retry_on_exceptions,
retry_on_status_codes=retry_on_status_codes,
stream=False,
**kwargs,
)
)
@contextmanager
def http_stream_backoff(
method: HTTP_METHOD_T,
url: str,
*,
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
),
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
**kwargs,
) -> Generator[httpx.Response, None, None]:
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
and/or on specific status codes (ex: service unavailable). If the call failed more
than `max_retries`, the exception is thrown or `raise_for_status` is called on the
response object.
    Re-implements mechanisms from the `backoff` library to avoid adding an external
    dependency to `huggingface_hub`. See https://github.com/litl/backoff.
Args:
method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
HTTP method to perform.
url (`str`):
The URL of the resource to fetch.
max_retries (`int`, *optional*, defaults to `5`):
            Maximum number of retries, defaults to 5.
base_wait_time (`float`, *optional*, defaults to `1`):
Duration (in seconds) to wait before retrying the first time.
Wait time between retries then grows exponentially, capped by
`max_wait_time`.
max_wait_time (`float`, *optional*, defaults to `8`):
Maximum duration (in seconds) to wait before retrying.
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
            By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
Define on which status codes the request must be retried. By default, only
HTTP 503 Service Unavailable is retried.
**kwargs (`dict`, *optional*):
kwargs to pass to `httpx.request`.
Example:
```
>>> from huggingface_hub.utils import http_stream_backoff
# Same usage as "httpx.stream".
>>> with http_stream_backoff("GET", "https://www.google.com") as response:
... for chunk in response.iter_bytes():
... print(chunk)
# If you expect a Gateway Timeout from time to time
>>> with http_stream_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) as response:
... response.raise_for_status()
```
<Tip warning={true}>
When using `httpx` it is possible to stream data by passing an iterator to the
`data` argument. On http backoff this is a problem as the iterator is not reset
after a failed call. This issue is mitigated for file objects or any IO streams
by saving the initial position of the cursor (with `data.tell()`) and resetting the
cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
will fail. If this is a hard constraint for you, please let us know by opening an
issue on [Github](https://github.com/huggingface/huggingface_hub).
</Tip>
"""
yield from _http_backoff_base(
method=method,
url=url,
max_retries=max_retries,
base_wait_time=base_wait_time,
max_wait_time=max_wait_time,
retry_on_exceptions=retry_on_exceptions,
retry_on_status_codes=retry_on_status_codes,
stream=True,
**kwargs,
)
def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str:
"""Replace the default endpoint in a URL by a custom one.
This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint.
"""
endpoint = endpoint.rstrip("/") if endpoint else constants.ENDPOINT
# check if a proxy has been set => if yes, update the returned URL to use the proxy
if endpoint not in (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT):
url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint)
url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint)
return url
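# Example (illustrative sketch; assumes the default endpoint is https://huggingface.co):
#
# ```py
# >>> fix_hf_endpoint_in_url("https://huggingface.co/api/models/gpt2", endpoint="https://hub.my-proxy.internal")
# 'https://hub.my-proxy.internal/api/models/gpt2'
# ```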
def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = None) -> None:
"""
Internal version of `response.raise_for_status()` that will refine a potential HTTPError.
Raised exception will be an instance of [`~errors.HfHubHTTPError`].
This helper is meant to be the unique method to raise_for_status when making a call to the Hugging Face Hub.
Args:
response (`Response`):
Response from the server.
endpoint_name (`str`, *optional*):
Name of the endpoint that has been called. If provided, the error message will be more complete.
> [!WARNING]
> Raises when the request has failed:
>
> - [`~utils.RepositoryNotFoundError`]
> If the repository to download from cannot be found. This may be because it
> doesn't exist, because `repo_type` is not set correctly, or because the repo
> is `private` and you do not have access.
> - [`~utils.GatedRepoError`]
> If the repository exists but is gated and the user is not on the authorized
> list.
> - [`~utils.RevisionNotFoundError`]
> If the repository exists but the revision couldn't be found.
> - [`~utils.EntryNotFoundError`]
> If the repository exists but the entry (e.g. the requested file) couldn't be
    >   found.
> - [`~utils.BadRequestError`]
> If request failed with a HTTP 400 BadRequest error.
> - [`~utils.HfHubHTTPError`]
> If request failed for a reason not listed above.
"""
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
if response.status_code // 100 == 3:
return # Do not raise on redirects to stay consistent with `requests`
error_code = response.headers.get("X-Error-Code")
error_message = response.headers.get("X-Error-Message")
if error_code == "RevisionNotFound":
message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}."
raise _format(RevisionNotFoundError, message, response) from e
elif error_code == "EntryNotFound":
message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
raise _format(RemoteEntryNotFoundError, message, response) from e
elif error_code == "GatedRepo":
message = (
f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}."
)
raise _format(GatedRepoError, message, response) from e
elif error_message == "Access to this resource is disabled.":
message = (
f"{response.status_code} Client Error."
+ "\n\n"
+ f"Cannot access repository for url {response.url}."
+ "\n"
+ "Access to this resource is disabled."
)
raise _format(DisabledRepoError, message, response) from e
elif error_code == "RepoNotFound" or (
response.status_code == 401
and error_message != "Invalid credentials in Authorization header"
and response.request is not None
and response.request.url is not None
and REPO_API_REGEX.search(str(response.request.url)) is not None
):
# 401 is misleading as it is returned for:
# - private and gated repos if user is not authenticated
# - missing repos
# => for now, we process them as `RepoNotFound` anyway.
# See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
message = (
f"{response.status_code} Client Error."
+ "\n\n"
+ f"Repository Not Found for url: {response.url}."
+ "\nPlease make sure you specified the correct `repo_id` and"
" `repo_type`.\nIf you are trying to access a private or gated repo,"
" make sure you are authenticated. For more details, see"
" https://huggingface.co/docs/huggingface_hub/authentication"
)
raise _format(RepositoryNotFoundError, message, response) from e
elif response.status_code == 400:
message = (
f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:"
)
raise _format(BadRequestError, message, response) from e
elif response.status_code == 403:
message = (
f"\n\n{response.status_code} Forbidden: {error_message}."
+ f"\nCannot access content at: {response.url}."
+ "\nMake sure your token has the correct permissions."
)
raise _format(HfHubHTTPError, message, response) from e
elif response.status_code == 416:
range_header = response.request.headers.get("Range")
message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
raise _format(HfHubHTTPError, message, response) from e
# Convert `HTTPError` into a `HfHubHTTPError` to display request information
# as well (request id and/or server error message)
raise _format(HfHubHTTPError, str(e), response) from e
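# Example (illustrative sketch; the repo id is a placeholder and the exact error type depends on the server response):
#
# ```py
# >>> response = get_session().get("https://huggingface.co/api/models/non-existent-repo-id")
# >>> hf_raise_for_status(response)  # raises RepositoryNotFoundError (a subclass of HfHubHTTPError)
# ```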
def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError:
server_errors = []
# Retrieve server error from header
from_headers = response.headers.get("X-Error-Message")
if from_headers is not None:
server_errors.append(from_headers)
# Retrieve server error from body
try:
# Case errors are returned in a JSON format
try:
data = response.json()
except httpx.ResponseNotRead:
try:
response.read() # In case of streaming response, we need to read the response first
data = response.json()
except RuntimeError:
# In case of async streaming response, we can't read the stream here.
                # In practice if user is using the default async client from `get_async_session`, the stream will have
# already been read in the async event hook `async_hf_response_event_hook`.
#
# Here, we are skipping reading the response to avoid RuntimeError but it happens only if async + stream + used httpx.AsyncClient directly.
data = {}
error = data.get("error")
if error is not None:
if isinstance(error, list):
# Case {'error': ['my error 1', 'my error 2']}
server_errors.extend(error)
else:
# Case {'error': 'my error'}
server_errors.append(error)
errors = data.get("errors")
if errors is not None:
# Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]}
for error in errors:
if "message" in error:
server_errors.append(error["message"])
except json.JSONDecodeError:
# If content is not JSON and not HTML, append the text
content_type = response.headers.get("Content-Type", "")
if response.text and "html" not in content_type.lower():
server_errors.append(response.text)
# Strip all server messages
server_errors = [str(line).strip() for line in server_errors if str(line).strip()]
# Deduplicate server messages (keep order)
# taken from https://stackoverflow.com/a/17016257
server_errors = list(dict.fromkeys(server_errors))
# Format server error
server_message = "\n".join(server_errors)
# Add server error to custom message
final_error_message = custom_message
if server_message and server_message.lower() not in custom_message.lower():
if "\n\n" in custom_message:
final_error_message += "\n" + server_message
else:
final_error_message += "\n\n" + server_message
# Add Request ID
request_id = str(response.headers.get(X_REQUEST_ID, ""))
if request_id:
request_id_message = f" (Request ID: {request_id})"
else:
# Fallback to X-Amzn-Trace-Id
request_id = str(response.headers.get(X_AMZN_TRACE_ID, ""))
if request_id:
request_id_message = f" (Amzn Trace ID: {request_id})"
if request_id and request_id.lower() not in final_error_message.lower():
if "\n" in final_error_message:
newline_index = final_error_message.index("\n")
final_error_message = (
final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:]
)
else:
final_error_message += request_id_message
# Return
return error_type(final_error_message.strip(), response=response, server_message=server_message or None)
def _curlify(request: httpx.Request) -> str:
"""Convert a `httpx.Request` into a curl command (str).
Used for debug purposes only.
Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py.
MIT License Copyright (c) 2016 Egor.
"""
parts: list[tuple[Any, Any]] = [
("curl", None),
("-X", request.method),
]
for k, v in sorted(request.headers.items()):
if k.lower() == "authorization":
v = "<TOKEN>" # Hide authorization header, no matter its value (can be Bearer, Key, etc.)
parts += [("-H", f"{k}: {v}")]
body: Optional[str] = None
if request.content is not None:
body = request.content.decode("utf-8", errors="ignore")
if len(body) > 1000:
body = f"{body[:1000]} ... [truncated]"
elif request.stream is not None:
body = "<streaming body>"
if body is not None:
parts += [("-d", body.replace("\n", ""))]
parts += [(None, request.url)]
flat_parts = []
for k, v in parts:
if k:
flat_parts.append(quote(str(k)))
if v:
flat_parts.append(quote(str(v)))
return " ".join(flat_parts)
# Regex to parse HTTP Range header
RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE)
def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]:
"""
Adjust HTTP Range header to account for resume position.
"""
if not original_range:
return f"bytes={resume_size}-"
if "," in original_range:
raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.")
match = RANGE_REGEX.match(original_range)
if not match:
raise RuntimeError(f"Invalid range format - {original_range!r}.")
start, end = match.groups()
if not start:
if not end:
raise RuntimeError(f"Invalid range format - {original_range!r}.")
new_suffix = int(end) - resume_size
new_range = f"bytes=-{new_suffix}"
if new_suffix <= 0:
raise RuntimeError(f"Empty new range - {new_range!r}.")
return new_range
start = int(start)
new_start = start + resume_size
if end:
end = int(end)
new_range = f"bytes={new_start}-{end}"
if new_start > end:
raise RuntimeError(f"Empty new range - {new_range!r}.")
return new_range
return f"bytes={new_start}-"

View File

@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Git LFS related utilities"""
import io
import os
from contextlib import AbstractContextManager
from typing import BinaryIO
class SliceFileObj(AbstractContextManager):
"""
Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object.
This is NOT thread safe
Inspired by stackoverflow.com/a/29838711/593036
Credits to @julien-c
Args:
fileobj (`BinaryIO`):
A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course).
`fileobj` will be reset to its original position when exiting the context manager.
seek_from (`int`):
The start of the slice (offset from position 0 in bytes).
read_limit (`int`):
The maximum number of bytes to read from the slice.
Attributes:
previous_position (`int`):
The previous position
Examples:
Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327):
```python
>>> with open("path/to/file", "rb") as file:
... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice:
... fslice.read(...)
```
Reading a file in chunks of 512 bytes
```python
    >>> import os
    >>> from math import ceil
    >>> chunk_size = 512
    >>> file_size = os.path.getsize("path/to/file")
>>> with open("path/to/file", "rb") as file:
... for chunk_idx in range(ceil(file_size / chunk_size)):
... with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice:
... chunk = fslice.read(...)
```
"""
def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int):
self.fileobj = fileobj
self.seek_from = seek_from
self.read_limit = read_limit
def __enter__(self):
self._previous_position = self.fileobj.tell()
end_of_stream = self.fileobj.seek(0, os.SEEK_END)
self._len = min(self.read_limit, end_of_stream - self.seek_from)
# ^^ The actual number of bytes that can be read from the slice
self.fileobj.seek(self.seek_from, io.SEEK_SET)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.fileobj.seek(self._previous_position, io.SEEK_SET)
def read(self, n: int = -1):
pos = self.tell()
if pos >= self._len:
return b""
remaining_amount = self._len - pos
data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount))
return data
def tell(self) -> int:
return self.fileobj.tell() - self.seek_from
def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
start = self.seek_from
end = start + self._len
if whence in (os.SEEK_SET, os.SEEK_END):
offset = start + offset if whence == os.SEEK_SET else end + offset
offset = max(start, min(offset, end))
whence = os.SEEK_SET
elif whence == os.SEEK_CUR:
cur_pos = self.fileobj.tell()
offset = max(start - cur_pos, min(offset, end - cur_pos))
else:
raise ValueError(f"whence value {whence} is not supported")
return self.fileobj.seek(offset, whence) - self.seek_from
def __iter__(self):
yield self.read(n=4 * 1024 * 1024)
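# Example of slice-relative positions (illustrative sketch): `tell()` and `read()` are
# relative to `seek_from`, not to the start of the underlying file.
#
# ```py
# >>> import io
# >>> buf = io.BytesIO(b"0123456789")
# >>> with SliceFileObj(buf, seek_from=2, read_limit=5) as fslice:
# ...     fslice.read(3), fslice.tell()
# (b'234', 3)
# ```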

View File

@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle pagination on Huggingface Hub."""
from typing import Iterable, Optional
import httpx
from . import get_session, hf_raise_for_status, http_backoff, logging
logger = logging.get_logger(__name__)
def paginate(path: str, params: dict, headers: dict) -> Iterable:
"""Fetch a list of models/datasets/spaces and paginate through results.
This is using the same "Link" header format as GitHub.
See:
- https://requests.readthedocs.io/en/latest/api/#requests.Response.links
- https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header
"""
session = get_session()
r = session.get(path, params=params, headers=headers)
hf_raise_for_status(r)
yield from r.json()
# Follow pages
# Next link already contains query params
next_page = _get_next_page(r)
while next_page is not None:
logger.debug(f"Pagination detected. Requesting next page: {next_page}")
r = http_backoff("GET", next_page, max_retries=20, retry_on_status_codes=429, headers=headers)
hf_raise_for_status(r)
yield from r.json()
next_page = _get_next_page(r)
def _get_next_page(response: httpx.Response) -> Optional[str]:
return response.links.get("next", {}).get("url")
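# Example (illustrative sketch; URL and params are placeholders):
#
# ```py
# >>> for model in paginate("https://huggingface.co/api/models", params={"search": "bert"}, headers={}):
# ...     print(model["id"])
# ```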

View File

@@ -0,0 +1,98 @@
# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parsing helpers shared across modules."""
import re
import time
from typing import Dict
RE_NUMBER_WITH_UNIT = re.compile(r"(\d+)([a-z]+)", re.IGNORECASE)
BYTE_UNITS: Dict[str, int] = {
"k": 1_000,
"m": 1_000_000,
"g": 1_000_000_000,
"t": 1_000_000_000_000,
"p": 1_000_000_000_000_000,
}
TIME_UNITS: Dict[str, int] = {
"s": 1,
"m": 60,
"h": 60 * 60,
"d": 24 * 60 * 60,
"w": 7 * 24 * 60 * 60,
"mo": 30 * 24 * 60 * 60,
"y": 365 * 24 * 60 * 60,
}
def parse_size(value: str) -> int:
"""Parse a size expressed as a string with digits and unit (like `"10MB"`) to an integer (in bytes)."""
return _parse_with_unit(value, BYTE_UNITS)
def parse_duration(value: str) -> int:
"""Parse a duration expressed as a string with digits and unit (like `"10s"`) to an integer (in seconds)."""
return _parse_with_unit(value, TIME_UNITS)
def _parse_with_unit(value: str, units: Dict[str, int]) -> int:
"""Parse a numeric value with optional unit."""
stripped = value.strip()
if not stripped:
raise ValueError("Value cannot be empty.")
try:
return int(value)
except ValueError:
pass
match = RE_NUMBER_WITH_UNIT.fullmatch(stripped)
if not match:
raise ValueError(f"Invalid value '{value}'. Must match pattern '\\d+[a-z]+' or be a plain number.")
number = int(match.group(1))
unit = match.group(2).lower()
if unit not in units:
raise ValueError(f"Unknown unit '{unit}'. Must be one of {list(units.keys())}.")
return number * units[unit]
def format_timesince(ts: float) -> str:
"""Format timestamp in seconds into a human-readable string, relative to now.
Vaguely inspired by Django's `timesince` formatter.
"""
_TIMESINCE_CHUNKS = (
# Label, divider, max value
("second", 1, 60),
("minute", 60, 60),
("hour", 60 * 60, 24),
("day", 60 * 60 * 24, 6),
("week", 60 * 60 * 24 * 7, 6),
("month", 60 * 60 * 24 * 30, 11),
("year", 60 * 60 * 24 * 365, None),
)
delta = time.time() - ts
if delta < 20:
return "a few seconds ago"
for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007
value = round(delta / divider)
if max_value is not None and value <= max_value:
break
return f"{value} {label}{'s' if value > 1 else ''} ago"

View File

@@ -0,0 +1,141 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle paths in Huggingface Hub."""
from fnmatch import fnmatch
from pathlib import Path
from typing import Callable, Generator, Iterable, Optional, TypeVar, Union
T = TypeVar("T")
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
".git",
".git/*",
"*/.git",
"**/.git/**",
".cache/huggingface",
".cache/huggingface/*",
"*/.cache/huggingface",
"**/.cache/huggingface/**",
]
# Forbidden to commit these folders
FORBIDDEN_FOLDERS = [".git", ".cache"]
def filter_repo_objects(
items: Iterable[T],
*,
allow_patterns: Optional[Union[list[str], str]] = None,
ignore_patterns: Optional[Union[list[str], str]] = None,
key: Optional[Callable[[T], str]] = None,
) -> Generator[T, None, None]:
"""Filter repo objects based on an allowlist and a denylist.
Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
In the later case, `key` must be provided and specifies a function of one argument
that is used to extract a path from each element in iterable.
Patterns are Unix shell-style wildcards which are NOT regular expressions. See
https://docs.python.org/3/library/fnmatch.html for more details.
Args:
items (`Iterable`):
List of items to filter.
allow_patterns (`str` or `list[str]`, *optional*):
Patterns constituting the allowlist. If provided, item paths must match at
least one pattern from the allowlist.
ignore_patterns (`str` or `list[str]`, *optional*):
Patterns constituting the denylist. If provided, item paths must not match
any patterns from the denylist.
key (`Callable[[T], str]`, *optional*):
Single-argument function to extract a path from each item. If not provided,
the `items` must already be `str` or `Path`.
Returns:
Filtered list of objects, as a generator.
Raises:
:class:`ValueError`:
If `key` is not provided and items are not `str` or `Path`.
Example usage with paths:
```python
>>> # Filter only PDFs that are not hidden.
>>> list(filter_repo_objects(
... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
... allow_patterns=["*.pdf"],
... ignore_patterns=[".*"],
... ))
["aaa.pdf"]
```
Example usage with objects:
```python
>>> list(filter_repo_objects(
... [
    ...         CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf"),
    ...         CommitOperationAdd(path_or_fileobj="/tmp/bbb.jpg", path_in_repo="bbb.jpg"),
    ...         CommitOperationAdd(path_or_fileobj="/tmp/.ccc.pdf", path_in_repo=".ccc.pdf"),
    ...         CommitOperationAdd(path_or_fileobj="/tmp/.ddd.png", path_in_repo=".ddd.png"),
... ],
... allow_patterns=["*.pdf"],
... ignore_patterns=[".*"],
    ...     key=lambda x: x.path_in_repo
... ))
[CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf")]
```
"""
if isinstance(allow_patterns, str):
allow_patterns = [allow_patterns]
if isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
if allow_patterns is not None:
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
if ignore_patterns is not None:
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
key = _identity # Items must be `str` or `Path`, otherwise raise ValueError
for item in items:
path = key(item)
# Skip if there's an allowlist and path doesn't match any
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
continue
# Skip if there's a denylist and path matches any
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
continue
yield item
def _add_wildcard_to_directories(pattern: str) -> str:
if pattern[-1] == "/":
return pattern + "*"
return pattern
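# Example with a directory pattern (illustrative sketch): a trailing "/" is expanded to "/*"
# by `_add_wildcard_to_directories` before matching.
#
# ```py
# >>> list(filter_repo_objects(["logs/run1.txt", "model.bin"], ignore_patterns=["logs/"]))
# ['model.bin']
# ```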

View File

@@ -0,0 +1,431 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check presence of installed packages at runtime."""
import importlib.metadata
import os
import platform
import sys
import warnings
from pathlib import Path
from typing import Any, Literal
from .. import __version__, constants
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
_package_versions = {}
_CANDIDATES = {
"aiohttp": {"aiohttp"},
"fastai": {"fastai"},
"fastapi": {"fastapi"},
"fastcore": {"fastcore"},
"gradio": {"gradio"},
"graphviz": {"graphviz"},
"hf_xet": {"hf_xet"},
"jinja": {"Jinja2"},
"httpx": {"httpx"},
"keras": {"keras"},
"numpy": {"numpy"},
"pillow": {"Pillow"},
"pydantic": {"pydantic"},
"pydot": {"pydot"},
"safetensors": {"safetensors"},
"tensorboard": {"tensorboardX"},
"tensorflow": (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
),
"torch": {"torch"},
}
# Check once at runtime
for candidate_name, package_names in _CANDIDATES.items():
_package_versions[candidate_name] = "N/A"
for name in package_names:
try:
_package_versions[candidate_name] = importlib.metadata.version(name)
break
except importlib.metadata.PackageNotFoundError:
pass
def _get_version(package_name: str) -> str:
return _package_versions.get(package_name, "N/A")
def is_package_available(package_name: str) -> bool:
return _get_version(package_name) != "N/A"
# Python
def get_python_version() -> str:
return _PY_VERSION
# Huggingface Hub
def get_hf_hub_version() -> str:
return __version__
# aiohttp
def is_aiohttp_available() -> bool:
return is_package_available("aiohttp")
def get_aiohttp_version() -> str:
return _get_version("aiohttp")
# FastAI
def is_fastai_available() -> bool:
return is_package_available("fastai")
def get_fastai_version() -> str:
return _get_version("fastai")
# FastAPI
def is_fastapi_available() -> bool:
return is_package_available("fastapi")
def get_fastapi_version() -> str:
return _get_version("fastapi")
# Fastcore
def is_fastcore_available() -> bool:
return is_package_available("fastcore")
def get_fastcore_version() -> str:
return _get_version("fastcore")
# Gradio
def is_gradio_available() -> bool:
return is_package_available("gradio")
def get_gradio_version() -> str:
return _get_version("gradio")
# Graphviz
def is_graphviz_available() -> bool:
return is_package_available("graphviz")
def get_graphviz_version() -> str:
return _get_version("graphviz")
# httpx
def is_httpx_available() -> bool:
return is_package_available("httpx")
def get_httpx_version() -> str:
return _get_version("httpx")
# xet
def is_xet_available() -> bool:
# since hf_xet is automatically used if available, allow explicit disabling via environment variable
if constants.HF_HUB_DISABLE_XET:
return False
return is_package_available("hf_xet")
def get_xet_version() -> str:
return _get_version("hf_xet")
# keras
def is_keras_available() -> bool:
return is_package_available("keras")
def get_keras_version() -> str:
return _get_version("keras")
# Numpy
def is_numpy_available() -> bool:
return is_package_available("numpy")
def get_numpy_version() -> str:
return _get_version("numpy")
# Jinja
def is_jinja_available() -> bool:
return is_package_available("jinja")
def get_jinja_version() -> str:
return _get_version("jinja")
# Pillow
def is_pillow_available() -> bool:
return is_package_available("pillow")
def get_pillow_version() -> str:
return _get_version("pillow")
# Pydantic
def is_pydantic_available() -> bool:
if not is_package_available("pydantic"):
return False
# For Pydantic, we add an extra check to test whether it is correctly installed or not. If both pydantic 2.x and
# typing_extensions<=4.5.0 are installed, then pydantic will fail at import time. This should not happen when
# it is installed with `pip install huggingface_hub[inference]` but it can happen when it is installed manually
# by the user in an environment that we don't control.
#
# Usually we won't need to do this kind of check on optional dependencies. However, pydantic is a special case
# as it is automatically imported when doing `from huggingface_hub import ...` even if the user doesn't use it.
#
# See https://github.com/huggingface/huggingface_hub/pull/1829 for more details.
try:
from pydantic import validator # noqa: F401
    except ImportError as e:
        # Example: "ImportError: cannot import name 'TypeAliasType' from 'typing_extensions'"
        warnings.warn(
            "Pydantic is installed but cannot be imported. Please check your installation. `huggingface_hub` will "
            f"default to not using Pydantic. Error message: '{e}'"
        )
return False
return True
def get_pydantic_version() -> str:
return _get_version("pydantic")
# Pydot
def is_pydot_available() -> bool:
return is_package_available("pydot")
def get_pydot_version() -> str:
return _get_version("pydot")
# Tensorboard
def is_tensorboard_available() -> bool:
return is_package_available("tensorboard")
def get_tensorboard_version() -> str:
return _get_version("tensorboard")
# Tensorflow
def is_tf_available() -> bool:
return is_package_available("tensorflow")
def get_tf_version() -> str:
return _get_version("tensorflow")
# Torch
def is_torch_available() -> bool:
return is_package_available("torch")
def get_torch_version() -> str:
return _get_version("torch")
# Safetensors
def is_safetensors_available() -> bool:
return is_package_available("safetensors")
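# Example (illustrative sketch; output depends on the local environment):
#
# ```py
# >>> is_torch_available()
# True
# >>> get_torch_version()
# '2.4.0'
# ```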
# Shell-related helpers
try:
# Set to `True` if script is running in a Google Colab notebook.
# If running in Google Colab, git credential store is set globally which makes the
# warning disappear. See https://github.com/huggingface/huggingface_hub/issues/1043
#
# Taken from https://stackoverflow.com/a/63519730.
_is_google_colab = "google.colab" in str(get_ipython()) # type: ignore # noqa: F821
except NameError:
_is_google_colab = False
def is_notebook() -> bool:
"""Return `True` if code is executed in a notebook (Jupyter, Colab, QTconsole).
Taken from https://stackoverflow.com/a/39662359.
Adapted to make it work with Google colab as well.
"""
try:
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
for parent_class in shell_class.__mro__: # e.g. "is subclass of"
if parent_class.__name__ == "ZMQInteractiveShell":
return True # Jupyter notebook, Google colab or qtconsole
return False
except NameError:
return False # Probably standard Python interpreter
def is_google_colab() -> bool:
"""Return `True` if code is executed in a Google colab.
Taken from https://stackoverflow.com/a/63519730.
"""
return _is_google_colab
def is_colab_enterprise() -> bool:
"""Return `True` if code is executed in a Google Colab Enterprise environment."""
return os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE"
# Check how huggingface_hub has been installed
def installation_method() -> Literal["brew", "hf_installer", "unknown"]:
"""Return the installation method of the current environment.
- "hf_installer" if installed via the official installer script
- "brew" if installed via Homebrew
- "unknown" otherwise
"""
if _is_brew_installation():
return "brew"
elif _is_hf_installer_installation():
return "hf_installer"
else:
return "unknown"
def _is_brew_installation() -> bool:
"""Check if running from a Homebrew installation.
Note: AI-generated by Claude.
"""
exe_path = Path(sys.executable).resolve()
exe_str = str(exe_path)
# Check common Homebrew paths
# /opt/homebrew (Apple Silicon), /usr/local (Intel)
return "/Cellar/" in exe_str or "/opt/homebrew/" in exe_str or exe_str.startswith("/usr/local/Cellar/")
def _is_hf_installer_installation() -> bool:
"""Return `True` if the current environment was set up via the official hf installer script.
i.e. using one of
curl -LsSf https://hf.co/cli/install.sh | bash
powershell -ExecutionPolicy ByPass -c "irm https://hf.co/cli/install.ps1 | iex"
"""
venv = sys.prefix # points to venv root if active
marker = Path(venv) / ".hf_installer_marker"
return marker.exists()
def dump_environment_info() -> dict[str, Any]:
"""Dump information about the machine to help debugging issues.
Similar helpers exist in:
- `datasets` (https://github.com/huggingface/datasets/blob/main/src/datasets/commands/env.py)
- `diffusers` (https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/env.py)
- `transformers` (https://github.com/huggingface/transformers/blob/main/src/transformers/commands/env.py)
"""
from huggingface_hub import get_token, whoami
from huggingface_hub.utils import list_credential_helpers
token = get_token()
# Generic machine info
info: dict[str, Any] = {
"huggingface_hub version": get_hf_hub_version(),
"Platform": platform.platform(),
"Python version": get_python_version(),
}
# Interpreter info
try:
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
info["Running in iPython ?"] = "Yes"
info["iPython shell"] = shell_class.__name__
except NameError:
info["Running in iPython ?"] = "No"
info["Running in notebook ?"] = "Yes" if is_notebook() else "No"
info["Running in Google Colab ?"] = "Yes" if is_google_colab() else "No"
info["Running in Google Colab Enterprise ?"] = "Yes" if is_colab_enterprise() else "No"
# Login info
info["Token path ?"] = constants.HF_TOKEN_PATH
info["Has saved token ?"] = token is not None
if token is not None:
try:
info["Who am I ?"] = whoami()["name"]
except Exception:
pass
try:
info["Configured git credential helpers"] = ", ".join(list_credential_helpers())
except Exception:
pass
# How was huggingface_hub installed?
info["Installation method"] = installation_method()
# Installed dependencies
info["httpx"] = get_httpx_version()
info["hf_xet"] = get_xet_version()
info["gradio"] = get_gradio_version()
info["tensorboard"] = get_tensorboard_version()
# Environment variables
info["ENDPOINT"] = constants.ENDPOINT
info["HF_HUB_CACHE"] = constants.HF_HUB_CACHE
info["HF_ASSETS_CACHE"] = constants.HF_ASSETS_CACHE
info["HF_TOKEN_PATH"] = constants.HF_TOKEN_PATH
info["HF_STORED_TOKENS_PATH"] = constants.HF_STORED_TOKENS_PATH
info["HF_HUB_OFFLINE"] = constants.HF_HUB_OFFLINE
info["HF_HUB_DISABLE_TELEMETRY"] = constants.HF_HUB_DISABLE_TELEMETRY
info["HF_HUB_DISABLE_PROGRESS_BARS"] = constants.HF_HUB_DISABLE_PROGRESS_BARS
info["HF_HUB_DISABLE_SYMLINKS_WARNING"] = constants.HF_HUB_DISABLE_SYMLINKS_WARNING
info["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING
info["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = constants.HF_HUB_DISABLE_IMPLICIT_TOKEN
info["HF_HUB_DISABLE_XET"] = constants.HF_HUB_DISABLE_XET
info["HF_HUB_ETAG_TIMEOUT"] = constants.HF_HUB_ETAG_TIMEOUT
info["HF_HUB_DOWNLOAD_TIMEOUT"] = constants.HF_HUB_DOWNLOAD_TIMEOUT
info["HF_XET_HIGH_PERFORMANCE"] = constants.HF_XET_HIGH_PERFORMANCE
print("\nCopy-and-paste the text below in your GitHub issue.\n")
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
return info
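# Illustrative usage sketch (not part of the library file above): a minimal example of
# exercising the runtime helpers; all names below are exported by `huggingface_hub.utils`,
# and the printed values will differ per machine.
from huggingface_hub.utils import dump_environment_info, is_google_colab, is_notebook

env = dump_environment_info()  # prints a copy-pasteable report and returns it as a dict
print("notebook:", is_notebook(), "| colab:", is_google_colab())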

View File

@@ -0,0 +1,111 @@
import functools
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Literal, Optional
FILENAME_T = str
TENSOR_NAME_T = str
DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]
@dataclass
class TensorInfo:
"""Information about a tensor.
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
Attributes:
dtype (`str`):
The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
shape (`list[int]`):
The shape of the tensor.
data_offsets (`tuple[int, int]`):
The offsets of the data in the file as a tuple `[BEGIN, END]`.
parameter_count (`int`):
The number of parameters in the tensor.
"""
dtype: DTYPE_T
shape: list[int]
data_offsets: tuple[int, int]
parameter_count: int = field(init=False)
def __post_init__(self) -> None:
# Taken from https://stackoverflow.com/a/13840436
try:
self.parameter_count = functools.reduce(operator.mul, self.shape)
except TypeError:
self.parameter_count = 1 # scalar value has no shape
@dataclass
class SafetensorsFileMetadata:
"""Metadata for a Safetensors file hosted on the Hub.
This class is returned by [`parse_safetensors_file_metadata`].
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
Attributes:
metadata (`dict`):
The metadata contained in the file.
tensors (`dict[str, TensorInfo]`):
A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a
[`TensorInfo`] object.
parameter_count (`dict[str, int]`):
A map of the number of parameters per data type. Keys are data types and values are the number of parameters
of that data type.
"""
metadata: dict[str, str]
tensors: dict[TENSOR_NAME_T, TensorInfo]
parameter_count: dict[DTYPE_T, int] = field(init=False)
def __post_init__(self) -> None:
parameter_count: dict[DTYPE_T, int] = defaultdict(int)
for tensor in self.tensors.values():
parameter_count[tensor.dtype] += tensor.parameter_count
self.parameter_count = dict(parameter_count)
@dataclass
class SafetensorsRepoMetadata:
"""Metadata for a Safetensors repo.
A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-shared
model) or a 'model.safetensors.index.json' index file (sharded model) at its root.
This class is returned by [`get_safetensors_metadata`].
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
Attributes:
metadata (`dict`, *optional*):
The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded
models.
sharded (`bool`):
Whether the repo contains a sharded model or not.
weight_map (`dict[str, str]`):
A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors.
files_metadata (`dict[str, SafetensorsFileMetadata]`):
A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as
a [`SafetensorsFileMetadata`] object.
parameter_count (`dict[str, int]`):
A map of the number of parameters per data type. Keys are data types and values are the number of parameters
of that data type.
"""
metadata: Optional[dict]
sharded: bool
weight_map: dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename
files_metadata: dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata
parameter_count: dict[DTYPE_T, int] = field(init=False)
def __post_init__(self) -> None:
parameter_count: dict[DTYPE_T, int] = defaultdict(int)
for file_metadata in self.files_metadata.values():
for dtype, nb_parameters_ in file_metadata.parameter_count.items():
parameter_count[dtype] += nb_parameters_
self.parameter_count = dict(parameter_count)
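# Illustrative sketch (not part of the file above): the `parameter_count` fields are derived
# in `__post_init__`. `TensorInfo` and `SafetensorsFileMetadata` are exported by `huggingface_hub.utils`.
from huggingface_hub.utils import SafetensorsFileMetadata, TensorInfo

tensor = TensorInfo(dtype="F32", shape=[768, 768], data_offsets=(0, 768 * 768 * 4))
assert tensor.parameter_count == 768 * 768

file_metadata = SafetensorsFileMetadata(metadata={}, tensors={"weight": tensor})
assert file_metadata.parameter_count == {"F32": 768 * 768}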

View File

@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Contains utilities to easily handle subprocesses in `huggingface_hub`."""
import os
import subprocess
import sys
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from typing import IO, Generator, Optional, Union
from .logging import get_logger
logger = get_logger(__name__)
@contextmanager
def capture_output() -> Generator[StringIO, None, None]:
"""Capture output that is printed to terminal.
Taken from https://stackoverflow.com/a/34738440
Example:
```py
>>> with capture_output() as output:
... print("hello world")
>>> assert output.getvalue() == "hello world\n"
```
"""
output = StringIO()
previous_output = sys.stdout
sys.stdout = output
try:
yield output
finally:
sys.stdout = previous_output
def run_subprocess(
command: Union[str, list[str]],
folder: Optional[Union[str, Path]] = None,
check=True,
**kwargs,
) -> subprocess.CompletedProcess:
"""
Method to run subprocesses. Calling this will capture `stderr` and `stdout`; call
`subprocess.run` directly if you do not want them to be captured.
Args:
command (`str` or `list[str]`):
The command to execute as a string or list of strings.
folder (`str`, *optional*):
The folder in which to run the command. Defaults to current working
directory (from `os.getcwd()`).
check (`bool`, *optional*, defaults to `True`):
Setting `check` to `True` will raise a `subprocess.CalledProcessError`
when the subprocess has a non-zero exit code.
kwargs (`dict[str]`):
Keyword arguments to be passed to the `subprocess.run` underlying command.
Returns:
`subprocess.CompletedProcess`: The completed process.
"""
if isinstance(command, str):
command = command.split()
if isinstance(folder, Path):
folder = str(folder)
return subprocess.run(
command,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=check,
encoding="utf-8",
errors="replace", # if not utf-8, replace char by <20>
cwd=folder or os.getcwd(),
**kwargs,
)
@contextmanager
def run_interactive_subprocess(
command: Union[str, list[str]],
folder: Optional[Union[str, Path]] = None,
**kwargs,
) -> Generator[tuple[IO[str], IO[str]], None, None]:
"""Run a subprocess in an interactive mode in a context manager.
Args:
command (`str` or `list[str]`):
The command to execute as a string or list of strings.
folder (`str`, *optional*):
The folder in which to run the command. Defaults to current working
directory (from `os.getcwd()`).
kwargs (`dict[str]`):
Keyword arguments to be passed to the `subprocess.run` underlying command.
Returns:
`tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact
with the process (input and output are utf-8 encoded).
Example:
```python
with run_interactive_subprocess("git credential-store get") as (stdin, stdout):
# Write to stdin (the streams are text-mode, utf-8 encoded)
stdin.write("url=hf.co\nusername=obama\n")
stdin.flush()
# Read from stdout
output = stdout.read()
```
"""
if isinstance(command, str):
command = command.split()
with subprocess.Popen(
command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
encoding="utf-8",
errors="replace", # if not utf-8, replace char by <20>
cwd=folder or os.getcwd(),
**kwargs,
) as process:
assert process.stdin is not None, "subprocess is opened as subprocess.PIPE"
assert process.stdout is not None, "subprocess is opened as subprocess.PIPE"
yield process.stdin, process.stdout
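# Illustrative sketch (not part of the file above): running a command with captured output.
# Assumes `git` is installed; `run_subprocess` is exported by `huggingface_hub.utils`.
from huggingface_hub.utils import run_subprocess

result = run_subprocess("git --version")
print(result.stdout.strip())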

View File

@@ -0,0 +1,126 @@
from queue import Queue
from threading import Lock, Thread
from typing import Optional, Union
from urllib.parse import quote
from .. import constants, logging
from . import build_hf_headers, get_session, hf_raise_for_status
logger = logging.get_logger(__name__)
# Telemetry is sent by a separate thread to avoid blocking the main thread.
# A daemon thread is started once and consumes tasks from the _TELEMETRY_QUEUE.
# If the thread stops for some reason (which shouldn't happen), a new one is started.
_TELEMETRY_THREAD: Optional[Thread] = None
_TELEMETRY_THREAD_LOCK = Lock() # Lock to avoid starting multiple threads in parallel
_TELEMETRY_QUEUE: Queue = Queue()
def send_telemetry(
topic: str,
*,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[dict, str, None] = None,
) -> None:
"""
Sends telemetry that helps track usage of different HF libraries.
This usage data helps us debug issues and prioritize new features. However, we understand that not everyone wants
to share additional information, and we respect your privacy. You can disable telemetry collection by setting the
`HF_HUB_DISABLE_TELEMETRY=1` as environment variable. Telemetry is also disabled in offline mode (i.e. when setting
`HF_HUB_OFFLINE=1`).
Telemetry collection is run in a separate thread to minimize impact for the user.
Args:
topic (`str`):
Name of the topic that is monitored. The topic is directly used to build the URL. If you want to monitor
subtopics, just use "/" separation. Examples: "gradio", "transformers/examples",...
library_name (`str`, *optional*):
The name of the library that is making the HTTP request. Will be added to the user-agent header.
library_version (`str`, *optional*):
The version of the library that is making the HTTP request. Will be added to the user-agent header.
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages.
Example:
```py
>>> from huggingface_hub.utils import send_telemetry
# Send telemetry without library information
>>> send_telemetry("ping")
# Send telemetry to subtopic with library information
>>> send_telemetry("gradio/local_link", library_name="gradio", library_version="3.22.1")
# Send telemetry with additional data
>>> send_telemetry(
... topic="examples",
... library_name="transformers",
... library_version="4.26.0",
... user_agent={"pipeline": "text_classification", "framework": "flax"},
... )
```
"""
if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY:
return
_start_telemetry_thread() # starts thread only if doesn't exist yet
_TELEMETRY_QUEUE.put(
{"topic": topic, "library_name": library_name, "library_version": library_version, "user_agent": user_agent}
)
def _start_telemetry_thread():
"""Start a daemon thread to consume tasks from the telemetry queue.
If the thread is interrupted, start a new one.
"""
with _TELEMETRY_THREAD_LOCK: # avoid starting multiple threads if called concurrently
global _TELEMETRY_THREAD
if _TELEMETRY_THREAD is None or not _TELEMETRY_THREAD.is_alive():
_TELEMETRY_THREAD = Thread(target=_telemetry_worker, daemon=True)
_TELEMETRY_THREAD.start()
def _telemetry_worker():
"""Wait for a task and consume it."""
while True:
kwargs = _TELEMETRY_QUEUE.get()
_send_telemetry_in_thread(**kwargs)
_TELEMETRY_QUEUE.task_done()
def _send_telemetry_in_thread(
topic: str,
*,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[dict, str, None] = None,
) -> None:
"""Contains the actual data sending data to the Hub.
This function is called directly in gradio's analytics because
it is not possible to send telemetry from a daemon thread.
See here: https://github.com/gradio-app/gradio/pull/8180
Please do not rename or remove this function.
"""
path = "/".join(quote(part) for part in topic.split("/") if len(part) > 0)
try:
r = get_session().head(
f"{constants.ENDPOINT}/api/telemetry/{path}",
headers=build_hf_headers(
token=False, # no need to send a token for telemetry
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
),
)
hf_raise_for_status(r)
except Exception as e:
# We don't want to error in case of connection errors of any kind.
logger.debug(f"Error while sending telemetry: {e}")

View File

@@ -0,0 +1,69 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to print stuff to the terminal (styling, helpers)."""
import os
from typing import Union
class ANSI:
"""
Helper for en.wikipedia.org/wiki/ANSI_escape_code
"""
_bold = "\u001b[1m"
_gray = "\u001b[90m"
_red = "\u001b[31m"
_reset = "\u001b[0m"
_yellow = "\u001b[33m"
@classmethod
def bold(cls, s: str) -> str:
return cls._format(s, cls._bold)
@classmethod
def gray(cls, s: str) -> str:
return cls._format(s, cls._gray)
@classmethod
def red(cls, s: str) -> str:
return cls._format(s, cls._bold + cls._red)
@classmethod
def yellow(cls, s: str) -> str:
return cls._format(s, cls._yellow)
@classmethod
def _format(cls, s: str, code: str) -> str:
if os.environ.get("NO_COLOR"):
# See https://no-color.org/
return s
return f"{code}{s}{cls._reset}"
def tabulate(rows: list[list[Union[str, int]]], headers: list[str]) -> str:
"""
Inspired by:
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(row_format.format(*headers))
lines.append(row_format.format(*["-" * w for w in col_widths]))
for row in rows:
lines.append(row_format.format(*row))
return "\n".join(lines)

View File

@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handle typing imports based on system compatibility."""
import sys
from typing import Any, Callable, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
UNION_TYPES: list[Any] = [Union]
if sys.version_info >= (3, 10):
from types import UnionType
UNION_TYPES += [UnionType]
HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
# type hint meaning "function signature not changed by decorator"
CallableT = TypeVar("CallableT", bound=Callable)
_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None))
def is_jsonable(obj: Any, _visited: Optional[set[int]] = None) -> bool:
"""Check if an object is JSON serializable.
This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object.
It works correctly for basic use cases but does not guarantee an exhaustive check.
Object is considered to be recursively json serializable if:
- it is an instance of int, float, str, bool, or NoneType
- it is a list or tuple and all its items are json serializable
- it is a dict and all its keys are strings and all its values are json serializable
Uses a visited set to avoid infinite recursion on circular references. If object has already been visited, it is
considered not json serializable.
"""
# Initialize visited set to track object ids and detect circular references
if _visited is None:
_visited = set()
# Detect circular reference
obj_id = id(obj)
if obj_id in _visited:
return False
# Add current object to visited before recursive checks
_visited.add(obj_id)
try:
if isinstance(obj, _JSON_SERIALIZABLE_TYPES):
return True
if isinstance(obj, (list, tuple)):
return all(is_jsonable(item, _visited) for item in obj)
if isinstance(obj, dict):
return all(
isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value, _visited)
for key, value in obj.items()
)
if hasattr(obj, "__json__"):
return True
return False
except RecursionError:
return False
finally:
# Remove the object id from visited to avoid side effects for other branches
_visited.discard(obj_id)
def is_simple_optional_type(type_: Type) -> bool:
"""Check if a type is optional, i.e. Optional[Type] or Union[Type, None] or Type | None, where Type is a non-composite type."""
if get_origin(type_) in UNION_TYPES:
union_args = get_args(type_)
if len(union_args) == 2 and type(None) in union_args:
return True
return False
def unwrap_simple_optional_type(optional_type: Type) -> Type:
"""Unwraps a simple optional type, i.e. returns Type from Optional[Type]."""
for arg in get_args(optional_type):
if arg is not type(None):
return arg
raise ValueError(f"'{optional_type}' is not an optional type")

View File

@@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to validate argument values in `huggingface_hub`."""
import inspect
import re
import warnings
from functools import wraps
from itertools import chain
from typing import Any
from huggingface_hub.errors import HFValidationError
from ._typing import CallableT
REPO_ID_REGEX = re.compile(
r"""
^
(\b[\w\-.]+\b/)? # optional namespace (username or organization)
\b # starts with a word boundary
[\w\-.]{1,96} # repo_name: alphanumeric + . _ -
\b # ends with a word boundary
$
""",
flags=re.VERBOSE,
)
def validate_hf_hub_args(fn: CallableT) -> CallableT:
"""Validate values received as argument for any public method of `huggingface_hub`.
The goal of this decorator is to harmonize validation of arguments reused
everywhere. By default, all defined validators are tested.
Validators:
- [`~utils.validate_repo_id`]: `repo_id` must be `"repo_name"`
or `"namespace/repo_name"`. Namespace is a username or an organization.
- [`~utils.smoothly_deprecate_legacy_arguments`]: Ignore `proxies` when downloading files (should be set globally).
Example:
```py
>>> from huggingface_hub.utils import validate_hf_hub_args
>>> @validate_hf_hub_args
... def my_cool_method(repo_id: str):
... print(repo_id)
>>> my_cool_method(repo_id="valid_repo_id")
valid_repo_id
>>> my_cool_method("other..repo..id")
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
>>> my_cool_method(repo_id="other..repo..id")
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
```
Raises:
[`~utils.HFValidationError`]:
If an input is not valid.
"""
# TODO: add an argument to opt-out validation for specific argument?
signature = inspect.signature(fn)
@wraps(fn)
def _inner_fn(*args, **kwargs):
for arg_name, arg_value in chain(
zip(signature.parameters, args), # Args values
kwargs.items(), # Kwargs values
):
if arg_name in ["repo_id", "from_id", "to_id"]:
validate_repo_id(arg_value)
kwargs = smoothly_deprecate_legacy_arguments(fn_name=fn.__name__, kwargs=kwargs)
return fn(*args, **kwargs)
return _inner_fn # type: ignore
def validate_repo_id(repo_id: str) -> None:
"""Validate `repo_id` is valid.
This is not meant to replace the proper validation made on the Hub but rather to
avoid local inconsistencies whenever possible (example: passing `repo_type` in the
`repo_id` is forbidden).
Rules:
- Between 1 and 96 characters.
- Either "repo_name" or "namespace/repo_name"
- [a-zA-Z0-9] or "-", "_", "."
- "--" and ".." are forbidden
Valid: `"foo"`, `"foo/bar"`, `"123"`, `"Foo-BAR_foo.bar123"`
Not valid: `"datasets/foo/bar"`, `".repo_id"`, `"foo--bar"`, `"foo.git"`
Example:
```py
>>> from huggingface_hub.utils import validate_repo_id
>>> validate_repo_id(repo_id="valid_repo_id")
>>> validate_repo_id(repo_id="other..repo..id")
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
```
Discussed in https://github.com/huggingface/huggingface_hub/issues/1008.
In moon-landing (internal repository):
- https://github.com/huggingface/moon-landing/blob/main/server/lib/Names.ts#L27
- https://github.com/huggingface/moon-landing/blob/main/server/views/components/NewRepoForm/NewRepoForm.svelte#L138
"""
if not isinstance(repo_id, str):
# Typically, a Path is not a repo_id
raise HFValidationError(f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'.")
if repo_id.count("/") > 1:
raise HFValidationError(
"Repo id must be in the form 'repo_name' or 'namespace/repo_name':"
f" '{repo_id}'. Use `repo_type` argument if needed."
)
if not REPO_ID_REGEX.match(repo_id):
raise HFValidationError(
"Repo id must use alphanumeric chars, '-', '_' or '.'."
" The name cannot start or end with '-' or '.' and the maximum length is 96:"
f" '{repo_id}'."
)
if "--" in repo_id or ".." in repo_id:
raise HFValidationError(f"Cannot have -- or .. in repo_id: '{repo_id}'.")
if repo_id.endswith(".git"):
raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.")
def smoothly_deprecate_legacy_arguments(fn_name: str, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Smoothly deprecate legacy arguments in the `huggingface_hub` codebase.
This function ignores some deprecated arguments from the kwargs and warns the user they are ignored.
The goal is to avoid breaking existing code while guiding the user to the new way of doing things.
List of deprecated arguments:
- `proxies`:
To set up proxies, user must either use the HTTP_PROXY environment variable or configure the `httpx.Client`
manually using the [`set_client_factory`] function.
In huggingface_hub 0.x, `proxies` was a dictionary directly passed to `requests.request`.
In huggingface_hub 1.x, we migrated to `httpx` which does not support `proxies` the same way.
In particular, it is not possible to configure proxies on a per-request basis. The solution is to configure
it globally using the [`set_client_factory`] function or using the HTTP_PROXY environment variable.
For more details, see:
- https://www.python-httpx.org/advanced/proxies/
- https://www.python-httpx.org/compatibility/#proxy-keys.
- `resume_download`: deprecated without replacement. `huggingface_hub` always resumes downloads whenever possible.
- `force_filename`: deprecated without replacement. Filename is always the same as on the Hub.
- `local_dir_use_symlinks`: deprecated without replacement. Downloading to a local directory does not use symlinks anymore.
"""
new_kwargs = kwargs.copy() # do not mutate input !
# proxies
proxies = new_kwargs.pop("proxies", None) # remove from kwargs
if proxies is not None:
warnings.warn(
f"The `proxies` argument is ignored in `{fn_name}`. To set up proxies, use the HTTP_PROXY / HTTPS_PROXY"
" environment variables or configure the `httpx.Client` manually using `huggingface_hub.set_client_factory`."
" See https://www.python-httpx.org/advanced/proxies/ for more details."
)
# resume_download
resume_download = new_kwargs.pop("resume_download", None) # remove from kwargs
if resume_download is not None:
warnings.warn(
f"The `resume_download` argument is deprecated and ignored in `{fn_name}`. Downloads always resume"
" whenever possible."
)
# force_filename
force_filename = new_kwargs.pop("force_filename", None) # remove from kwargs
if force_filename is not None:
warnings.warn(
f"The `force_filename` argument is deprecated and ignored in `{fn_name}`. Filename is always the same "
"as on the Hub."
)
# local_dir_use_symlinks
local_dir_use_symlinks = new_kwargs.pop("local_dir_use_symlinks", None) # remove from kwargs
if local_dir_use_symlinks is not None:
warnings.warn(
f"The `local_dir_use_symlinks` argument is deprecated and ignored in `{fn_name}`. Downloading to a local"
" directory does not use symlinks anymore."
)
return new_kwargs
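# Illustrative sketch (not part of the file above): the decorator validates `repo_id` and
# drops deprecated kwargs with a warning. The `download` function below is hypothetical.
from huggingface_hub.utils import validate_hf_hub_args

@validate_hf_hub_args
def download(repo_id: str, **kwargs) -> str:
    return repo_id

download("user/model", resume_download=True)  # warns that `resume_download` is ignored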

View File

@@ -0,0 +1,167 @@
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
from .. import constants
from ..file_download import repo_folder_name
from .sha import git_hash, sha_fileobj
if TYPE_CHECKING:
from ..hf_api import RepoFile, RepoFolder
# using fullmatch for clarity and strictness
_REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
# Typed structure describing a checksum mismatch
class Mismatch(TypedDict):
path: str
expected: str
actual: str
algorithm: str
HashAlgo = Literal["sha256", "git-sha1"]
@dataclass(frozen=True)
class FolderVerification:
revision: str
checked_count: int
mismatches: list[Mismatch]
missing_paths: list[str]
extra_paths: list[str]
verified_path: Path
def collect_local_files(root: Path) -> dict[str, Path]:
"""
Return a mapping of repo-relative path -> absolute path for all files under `root`.
"""
return {p.relative_to(root).as_posix(): p for p in root.rglob("*") if p.is_file()}
def _resolve_commit_hash_from_cache(storage_folder: Path, revision: Optional[str]) -> str:
"""
Resolve a commit hash from a cache repo folder and an optional revision.
"""
if revision and _REGEX_COMMIT_HASH.fullmatch(revision):
return revision
refs_dir = storage_folder / "refs"
snapshots_dir = storage_folder / "snapshots"
if revision:
ref_path = refs_dir / revision
if ref_path.is_file():
return ref_path.read_text(encoding="utf-8").strip()
raise ValueError(f"Revision '{revision}' could not be resolved in cache (expected file '{ref_path}').")
# No revision provided: try common defaults
main_ref = refs_dir / "main"
if main_ref.is_file():
return main_ref.read_text(encoding="utf-8").strip()
if not snapshots_dir.is_dir():
raise ValueError(f"Cache repo is missing snapshots directory: {snapshots_dir}. Provide --revision explicitly.")
candidates = [p.name for p in snapshots_dir.iterdir() if p.is_dir() and _REGEX_COMMIT_HASH.fullmatch(p.name)]
if len(candidates) == 1:
return candidates[0]
raise ValueError(
"Ambiguous cached revision: multiple snapshots found and no refs to disambiguate. Please pass --revision."
)
def compute_file_hash(path: Path, algorithm: HashAlgo) -> str:
"""
Compute the checksum of a local file using the requested algorithm.
"""
with path.open("rb") as stream:
if algorithm == "sha256":
return sha_fileobj(stream).hex()
if algorithm == "git-sha1":
return git_hash(stream.read())
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
def verify_maps(
*,
remote_by_path: dict[str, Union["RepoFile", "RepoFolder"]],
local_by_path: dict[str, Path],
revision: str,
verified_path: Path,
) -> FolderVerification:
"""Compare remote entries and local files and return a verification result."""
remote_paths = set(remote_by_path)
local_paths = set(local_by_path)
missing = sorted(remote_paths - local_paths)
extra = sorted(local_paths - remote_paths)
both = sorted(remote_paths & local_paths)
mismatches: list[Mismatch] = []
for rel_path in both:
remote_entry = remote_by_path[rel_path]
local_path = local_by_path[rel_path]
lfs = getattr(remote_entry, "lfs", None)
lfs_sha = getattr(lfs, "sha256", None) if lfs is not None else None
if lfs_sha is None and isinstance(lfs, dict):
lfs_sha = lfs.get("sha256")
if lfs_sha:
algorithm: HashAlgo = "sha256"
expected = str(lfs_sha).lower()
else:
blob_id = remote_entry.blob_id # type: ignore
algorithm = "git-sha1"
expected = str(blob_id).lower()
actual = compute_file_hash(local_path, algorithm)
if actual != expected:
mismatches.append(Mismatch(path=rel_path, expected=expected, actual=actual, algorithm=algorithm))
return FolderVerification(
revision=revision,
checked_count=len(both),
mismatches=mismatches,
missing_paths=missing,
extra_paths=extra,
verified_path=verified_path,
)
def resolve_local_root(
*,
repo_id: str,
repo_type: str,
revision: Optional[str],
cache_dir: Optional[Path],
local_dir: Optional[Path],
) -> tuple[Path, str]:
"""
Resolve the root directory to scan locally and the remote revision to verify.
"""
if local_dir is not None:
root = Path(local_dir).expanduser().resolve()
if not root.is_dir():
raise ValueError(f"Local directory does not exist or is not a directory: {root}")
return root, (revision or constants.DEFAULT_REVISION)
cache_root = Path(cache_dir or constants.HF_HUB_CACHE).expanduser().resolve()
storage_folder = cache_root / repo_folder_name(repo_id=repo_id, repo_type=repo_type)
if not storage_folder.exists():
raise ValueError(
f"Repo is not present in cache: {storage_folder}. Use 'hf download' first or pass --local-dir."
)
commit = _resolve_commit_hash_from_cache(storage_folder, revision)
snapshot_dir = storage_folder / "snapshots" / commit
if not snapshot_dir.is_dir():
raise ValueError(f"Snapshot directory does not exist for revision '{commit}': {snapshot_dir}.")
return snapshot_dir, commit
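# Illustrative sketch (not part of the file above), assuming this module's helpers are
# importable in the current scope; the repo id and local directory are hypothetical.
from pathlib import Path

root, revision = resolve_local_root(
    repo_id="user/model",
    repo_type="model",
    revision=None,
    cache_dir=None,
    local_dir=Path("./my-model"),  # verify a local export instead of the cache
)
local_files = collect_local_files(root)
print(f"Found {len(local_files)} local files to verify at revision {revision}")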

View File

@@ -0,0 +1,235 @@
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import httpx
from .. import constants
from . import get_session, hf_raise_for_status, validate_hf_hub_args
XET_CONNECTION_INFO_SAFETY_PERIOD = 60 # seconds
XET_CONNECTION_INFO_CACHE_SIZE = 1_000
XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = {}
class XetTokenType(str, Enum):
READ = "read"
WRITE = "write"
@dataclass(frozen=True)
class XetFileData:
file_hash: str
refresh_route: str
@dataclass(frozen=True)
class XetConnectionInfo:
access_token: str
expiration_unix_epoch: int
endpoint: str
def parse_xet_file_data_from_response(
response: httpx.Response, endpoint: Optional[str] = None
) -> Optional[XetFileData]:
"""
Parse XET file metadata from an HTTP response.
This function extracts XET file metadata from the HTTP headers or HTTP links
of a given response object. If the required metadata is not found, it returns `None`.
Args:
response (`httpx.Response`):
The HTTP response object containing headers dict and links dict to extract the XET metadata from.
Returns:
`Optional[XetFileData]`:
An instance of `XetFileData` containing the file hash and refresh route if the metadata
is found. Returns `None` if the required metadata is missing.
"""
if response is None:
return None
try:
file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]
if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links:
refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"]
else:
refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
except KeyError:
return None
endpoint = endpoint if endpoint is not None else constants.ENDPOINT
if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
refresh_route = refresh_route.replace(constants.HUGGINGFACE_CO_URL_HOME.rstrip("/"), endpoint.rstrip("/"))
return XetFileData(
file_hash=file_hash,
refresh_route=refresh_route,
)
def parse_xet_connection_info_from_headers(headers: dict[str, str]) -> Optional[XetConnectionInfo]:
"""
Parse XET connection info from the HTTP headers or return None if not found.
Args:
headers (`dict`):
HTTP headers to extract the XET metadata from.
Returns:
`XetConnectionInfo` or `None`:
The information needed to connect to the XET storage service.
Returns `None` if the headers do not contain the XET connection info.
"""
try:
endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
access_token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
expiration_unix_epoch = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
except (KeyError, ValueError, TypeError):
return None
return XetConnectionInfo(
endpoint=endpoint,
access_token=access_token,
expiration_unix_epoch=expiration_unix_epoch,
)
@validate_hf_hub_args
def refresh_xet_connection_info(
*,
file_data: XetFileData,
headers: dict[str, str],
) -> XetConnectionInfo:
"""
Utilizes the information in the parsed metadata to request the Hub xet connection information.
This includes the access token, expiration, and XET service URL.
Args:
file_data: (`XetFileData`):
The file data needed to refresh the xet connection information.
headers (`dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
Returns:
`XetConnectionInfo`:
The connection information needed to make the request to the xet storage service.
Raises:
[`~utils.HfHubHTTPError`]
If the Hub API returned an error.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the Hub API response is improperly formatted.
"""
if file_data.refresh_route is None:
raise ValueError("The provided xet metadata does not contain a refresh endpoint.")
return _fetch_xet_connection_info_with_url(file_data.refresh_route, headers)
@validate_hf_hub_args
def fetch_xet_connection_info_from_repo_info(
*,
token_type: XetTokenType,
repo_id: str,
repo_type: str,
revision: Optional[str] = None,
headers: dict[str, str],
endpoint: Optional[str] = None,
params: Optional[dict[str, str]] = None,
) -> XetConnectionInfo:
"""
Uses the repo info to request a xet access token from Hub.
Args:
token_type (`XetTokenType`):
Type of the token to request: `"read"` or `"write"`.
repo_id (`str`):
A namespace (user or an organization) and a repo name separated by a `/`.
repo_type (`str`):
Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
revision (`str`, `optional`):
The revision of the repo to get the token for.
headers (`dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
endpoint (`str`, `optional`):
The endpoint to use for the request. Defaults to the Hub endpoint.
params (`dict[str, str]`, `optional`):
Additional parameters to pass with the request.
Returns:
`XetConnectionInfo`:
The connection information needed to make the request to the xet storage service.
Raises:
[`~utils.HfHubHTTPError`]
If the Hub API returned an error.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the Hub API response is improperly formatted.
"""
endpoint = endpoint if endpoint is not None else constants.ENDPOINT
url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token/{revision}"
return _fetch_xet_connection_info_with_url(url, headers, params)
@validate_hf_hub_args
def _fetch_xet_connection_info_with_url(
url: str,
headers: dict[str, str],
params: Optional[dict[str, str]] = None,
) -> XetConnectionInfo:
"""
Requests the xet connection info from the supplied URL. This includes the
access token, expiration time, and endpoint to use for the xet storage service.
Result is cached to avoid redundant requests.
Args:
url: (`str`):
The access token endpoint URL.
headers (`dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
params (`dict[str, str]`, `optional`):
Additional parameters to pass with the request.
Returns:
`XetConnectionInfo`:
The connection information needed to make the request to the xet storage service.
Raises:
[`~utils.HfHubHTTPError`]
If the Hub API returned an error.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the Hub API response is improperly formatted.
"""
# Check cache first
cache_key = _cache_key(url, headers, params)
cached_info = XET_CONNECTION_INFO_CACHE.get(cache_key)
if cached_info is not None:
if not _is_expired(cached_info):
return cached_info
# Fetch from server
resp = get_session().get(headers=headers, url=url, params=params)
hf_raise_for_status(resp)
metadata = parse_xet_connection_info_from_headers(resp.headers) # type: ignore
if metadata is None:
raise ValueError("Xet headers have not been correctly set by the server.")
# Delete expired cache entries
for k, v in list(XET_CONNECTION_INFO_CACHE.items()):
if _is_expired(v):
XET_CONNECTION_INFO_CACHE.pop(k, None)
# Enforce cache size limit
if len(XET_CONNECTION_INFO_CACHE) >= XET_CONNECTION_INFO_CACHE_SIZE:
XET_CONNECTION_INFO_CACHE.pop(next(iter(XET_CONNECTION_INFO_CACHE)))
# Update cache
XET_CONNECTION_INFO_CACHE[cache_key] = metadata
return metadata
def _cache_key(url: str, headers: dict[str, str], params: Optional[dict[str, str]]) -> str:
"""Return a unique cache key for the given request parameters."""
lower_headers = {k.lower(): v for k, v in headers.items()} # casing is not guaranteed here
auth_header = lower_headers.get("authorization", "")
params_str = "&".join(f"{k}={v}" for k, v in sorted((params or {}).items(), key=lambda x: x[0]))
return f"{url}|{auth_header}|{params_str}"
def _is_expired(connection_info: XetConnectionInfo) -> bool:
"""Check if the given XET connection info is expired."""
return connection_info.expiration_unix_epoch <= int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD
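# Illustrative sketch (not part of the file above): a token is treated as expired
# XET_CONNECTION_INFO_SAFETY_PERIOD (60s) before its real expiration. Assumes the names
# defined above are in scope; the token and endpoint values are made up.
import time

info = XetConnectionInfo(
    access_token="abc",
    expiration_unix_epoch=int(time.time()) + 30,  # falls within the safety period
    endpoint="https://example.invalid",
)
assert _is_expired(info)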

View File

@@ -0,0 +1,162 @@
from collections import OrderedDict
from typing import List
from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate
from . import is_google_colab, is_notebook
from .tqdm import tqdm
class XetProgressReporter:
"""
Reports on progress for Xet uploads.
Shows summary progress bars when running in notebooks or GUIs, and detailed per-file progress in console environments.
"""
def __init__(self, n_lines: int = 10, description_width: int = 30):
self.n_lines = n_lines
self.description_width = description_width
self.per_file_progress = is_google_colab() or not is_notebook()
self.tqdm_settings = {
"unit": "B",
"unit_scale": True,
"leave": True,
"unit_divisor": 1000,
"nrows": n_lines + 3 if self.per_file_progress else 3,
"miniters": 1,
"bar_format": "{l_bar}{bar}| {n_fmt:>5}B / {total_fmt:>5}B{postfix:>12}",
}
# Overall progress bars
self.data_processing_bar = tqdm(
total=0, desc=self.format_desc("Processing Files (0 / 0)", False), position=0, **self.tqdm_settings
)
self.upload_bar = tqdm(
total=0, desc=self.format_desc("New Data Upload", False), position=1, **self.tqdm_settings
)
self.known_items: set[str] = set()
self.completed_items: set[str] = set()
# Item bars (scrolling view)
self.item_state: OrderedDict[str, PyItemProgressUpdate] = OrderedDict()
self.current_bars: List = [None] * self.n_lines
def format_desc(self, name: str, indent: bool) -> str:
"""
If `name` is longer than the available width, prefix it with "..." and keep only its last characters; otherwise
left-justify the whole name into `description_width` characters. Also adds indentation padding when `indent` is True.
"""
if not self.per_file_progress:
# Here we just use the defaults.
return name
padding = " " if indent else ""
width = self.description_width - len(padding)
if len(name) > width:
name = f"...{name[-(width - 3) :]}"
return f"{padding}{name.ljust(width)}"
def update_progress(self, total_update: PyTotalProgressUpdate, item_updates: list[PyItemProgressUpdate]):
# Update all the per-item values.
for item in item_updates:
item_name = item.item_name
self.known_items.add(item_name)
# Only care about items where the processing has already started.
if item.bytes_completed == 0:
continue
# Overwrite the existing value in there.
self.item_state[item_name] = item
bar_idx = 0
new_completed = []
# Now, go through and update all the bars
for name, item in self.item_state.items():
# Is this ready to be removed on the next update?
if item.bytes_completed == item.total_bytes:
self.completed_items.add(name)
new_completed.append(name)
# If we're only showing summary information, then don't update the individual bars
if not self.per_file_progress:
continue
# If we've run out of bars to use, then collapse the last ones together.
if bar_idx >= len(self.current_bars):
bar = self.current_bars[-1]
in_final_bar_mode = True
final_bar_aggregation_count = bar_idx + 1 - len(self.current_bars)
else:
bar = self.current_bars[bar_idx]
in_final_bar_mode = False
if bar is None:
self.current_bars[bar_idx] = tqdm(
desc=self.format_desc(name, True),
position=2 + bar_idx, # Set to the position past the initial bars.
total=item.total_bytes,
initial=item.bytes_completed,
**self.tqdm_settings,
)
elif in_final_bar_mode:
bar.n += item.bytes_completed
bar.total += item.total_bytes
bar.set_description(self.format_desc(f"[+ {final_bar_aggregation_count} files]", True), refresh=False)
else:
bar.set_description(self.format_desc(name, True), refresh=False)
bar.n = item.bytes_completed
bar.total = item.total_bytes
bar_idx += 1
# Remove all the completed ones from the ordered dictionary
for name in new_completed:
# Only remove ones from consideration to make room for more items coming in.
if len(self.item_state) <= self.n_lines:
break
del self.item_state[name]
if self.per_file_progress:
# Now manually refresh each of the bars
for bar in self.current_bars:
if bar:
bar.refresh()
# Update overall bars
def postfix(speed):
s = tqdm.format_sizeof(speed) if speed is not None else "???"
return f"{s}B/s ".rjust(10, " ")
self.data_processing_bar.total = total_update.total_bytes
self.data_processing_bar.set_description(
self.format_desc(f"Processing Files ({len(self.completed_items)} / {len(self.known_items)})", False),
refresh=False,
)
self.data_processing_bar.set_postfix_str(postfix(total_update.total_bytes_completion_rate), refresh=False)
self.data_processing_bar.update(total_update.total_bytes_completion_increment)
self.upload_bar.total = total_update.total_transfer_bytes
self.upload_bar.set_postfix_str(postfix(total_update.total_transfer_bytes_completion_rate), refresh=False)
self.upload_bar.update(total_update.total_transfer_bytes_completion_increment)
def close(self, _success):
self.data_processing_bar.close()
self.upload_bar.close()
if self.per_file_progress:
for bar in self.current_bars:
if bar:
bar.close()

View File

@@ -0,0 +1,66 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helpful utility functions and classes in relation to exploring API endpoints
with the aim for a user-friendly interface.
"""
import math
import re
from typing import TYPE_CHECKING
from ..repocard_data import ModelCardData
if TYPE_CHECKING:
from ..hf_api import ModelInfo
def _is_emission_within_threshold(model_info: "ModelInfo", minimum_threshold: float, maximum_threshold: float) -> bool:
"""Checks if a model's emission is within a given threshold.
Args:
model_info (`ModelInfo`):
A model info object containing the model's emission information.
minimum_threshold (`float`):
A minimum carbon threshold to filter by, such as 1.
maximum_threshold (`float`):
A maximum carbon threshold to filter by, such as 10.
Returns:
`bool`: Whether the model's emission is within the given threshold.
"""
if minimum_threshold is None and maximum_threshold is None:
raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`")
if minimum_threshold is None:
minimum_threshold = -1
if maximum_threshold is None:
maximum_threshold = math.inf
card_data = getattr(model_info, "card_data", None)
if card_data is None or not isinstance(card_data, (dict, ModelCardData)):
return False
# Get CO2 emission metadata
emission = card_data.get("co2_eq_emissions", None)
if isinstance(emission, dict):
emission = emission["emissions"]
if not emission:
return False
# Filter out if value is missing or out of range
matched = re.search(r"\d+\.\d+|\d+", str(emission))
if matched is None:
return False
emission_value = float(matched.group(0))
return minimum_threshold <= emission_value <= maximum_threshold
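# Illustrative sketch (not part of the file above): the check reads `co2_eq_emissions`
# from the card data and parses the first number it finds. `_FakeModelInfo` is a
# hypothetical stand-in for a real `ModelInfo` object, assuming the helper above is in scope.
class _FakeModelInfo:
    card_data = {"co2_eq_emissions": {"emissions": "12.5"}}

assert _is_emission_within_threshold(_FakeModelInfo(), minimum_threshold=1, maximum_threshold=100)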

View File

@@ -0,0 +1,32 @@
# Taken from https://github.com/mlflow/mlflow/pull/10119
#
# DO NOT use this function for security purposes (e.g., password hashing).
#
# In Python >= 3.9, insecure hashing algorithms such as MD5 fail in FIPS-compliant
# environments unless `usedforsecurity=False` is explicitly passed.
#
# References:
# - https://github.com/mlflow/mlflow/issues/9905
# - https://github.com/mlflow/mlflow/pull/10119
# - https://docs.python.org/3/library/hashlib.html
# - https://github.com/huggingface/transformers/pull/27038
#
# Usage:
# ```python
# # Use
# from huggingface_hub.utils.insecure_hashlib import sha256
# # instead of
# from hashlib import sha256
#
# # Use
# from huggingface_hub.utils import insecure_hashlib
# # instead of
# import hashlib
# ```
import functools
import hashlib
md5 = functools.partial(hashlib.md5, usedforsecurity=False)
sha1 = functools.partial(hashlib.sha1, usedforsecurity=False)
sha256 = functools.partial(hashlib.sha256, usedforsecurity=False)

View File

@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logging utilities."""
import logging
import os
from logging import (
CRITICAL, # NOQA
DEBUG, # NOQA
ERROR, # NOQA
FATAL, # NOQA
INFO, # NOQA
NOTSET, # NOQA
WARN, # NOQA
WARNING, # NOQA
)
from typing import Optional
from .. import constants
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
def _get_default_logging_level():
"""
If `HF_HUB_VERBOSITY` env var is set to one of the valid choices return that as the new default level. If it is not
- fall back to `_default_log_level`
"""
env_level_str = os.getenv("HF_HUB_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option HF_HUB_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level
def _configure_library_root_logger() -> None:
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(logging.StreamHandler())
library_root_logger.setLevel(_get_default_logging_level())
def _reset_library_root_logger() -> None:
library_root_logger = _get_library_root_logger()
library_root_logger.setLevel(logging.NOTSET)
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Returns a logger with the specified name. This function is not supposed
to be directly accessed by library users.
Args:
name (`str`, *optional*):
The name of the logger to get, usually the filename
Example:
```python
>>> from huggingface_hub import get_logger
>>> logger = get_logger(__file__)
>>> logger.set_verbosity_info()
```
"""
if name is None:
name = _get_library_name()
return logging.getLogger(name)
def get_verbosity() -> int:
"""Return the current level for the HuggingFace Hub's root logger.
Returns:
Logging level, e.g., `huggingface_hub.logging.DEBUG` and
`huggingface_hub.logging.INFO`.
> [!TIP]
> HuggingFace Hub has the following logging levels:
>
> - `huggingface_hub.logging.CRITICAL`, `huggingface_hub.logging.FATAL`
> - `huggingface_hub.logging.ERROR`
> - `huggingface_hub.logging.WARNING`, `huggingface_hub.logging.WARN`
> - `huggingface_hub.logging.INFO`
> - `huggingface_hub.logging.DEBUG`
"""
return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
"""
Sets the level for the HuggingFace Hub's root logger.
Args:
verbosity (`int`):
Logging level, e.g., `huggingface_hub.logging.DEBUG` and
`huggingface_hub.logging.INFO`.
"""
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
"""
Sets the verbosity to `logging.INFO`.
"""
return set_verbosity(INFO)
def set_verbosity_warning():
"""
Sets the verbosity to `logging.WARNING`.
"""
return set_verbosity(WARNING)
def set_verbosity_debug():
"""
Sets the verbosity to `logging.DEBUG`.
"""
return set_verbosity(DEBUG)
def set_verbosity_error():
"""
Sets the verbosity to `logging.ERROR`.
"""
return set_verbosity(ERROR)
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is
disabled by default.
"""
_get_library_root_logger().propagate = False
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the
HuggingFace Hub's default handler to prevent double logging if the root
logger has been configured.
"""
_get_library_root_logger().propagate = True
_configure_library_root_logger()
if constants.HF_DEBUG:
# If `HF_DEBUG` environment variable is set, set the verbosity of `huggingface_hub` logger to `DEBUG`.
set_verbosity_debug()
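# Illustrative sketch (not part of the file above) of the verbosity helpers; `logging` here
# is the `huggingface_hub.utils.logging` submodule defined above.
from huggingface_hub.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)
logger.info("huggingface_hub logging is now at INFO level")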

View File

@@ -0,0 +1,64 @@
"""Utilities to efficiently compute the SHA 256 hash of a bunch of bytes."""
from typing import BinaryIO, Optional
from .insecure_hashlib import sha1, sha256
def sha_fileobj(fileobj: BinaryIO, chunk_size: Optional[int] = None) -> bytes:
"""
Computes the sha256 hash of the given file object, by chunks of size `chunk_size`.
Args:
fileobj (file-like object):
The File object to compute sha256 for, typically obtained with `open(path, "rb")`
chunk_size (`int`, *optional*):
The number of bytes to read from `fileobj` at once, defaults to 1MB.
Returns:
`bytes`: `fileobj`'s sha256 hash as bytes
"""
chunk_size = chunk_size if chunk_size is not None else 1024 * 1024
sha = sha256()
while True:
chunk = fileobj.read(chunk_size)
sha.update(chunk)
if not chunk:
break
return sha.digest()
def git_hash(data: bytes) -> str:
"""
Computes the git-sha1 hash of the given bytes, using the same algorithm as git.
This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
for more details.
Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
the LFS file content when we want to compare LFS files.
Args:
data (`bytes`):
The data to compute the git-hash for.
Returns:
`str`: the git-hash of `data` as an hexadecimal string.
Example:
```python
>>> from huggingface_hub.utils.sha import git_hash
>>> git_hash(b"Hello, World!")
'b45ef6fec89518d314f546fd6c3025367b721684'
```
"""
# Taken from https://gist.github.com/msabramo/763200
# Note: no need to optimize by reading the file in chunks as we're not supposed to hash huge files (5MB maximum).
sha = sha1()
sha.update(b"blob ")
sha.update(str(len(data)).encode())
sha.update(b"\0")
sha.update(data)
return sha.hexdigest()
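# Illustrative sketch (not part of the file above): hashing an in-memory buffer with the
# same helper; `sha_fileobj` lives in `huggingface_hub.utils.sha`.
import io

from huggingface_hub.utils.sha import sha_fileobj

digest = sha_fileobj(io.BytesIO(b"Hello, World!"))
print(digest.hex())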

View File

@@ -0,0 +1,308 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Utility helpers to handle progress bars in `huggingface_hub`.
Example:
1. Use `huggingface_hub.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`.
2. To disable progress bars, either use the `disable_progress_bars()` helper or set the
environment variable `HF_HUB_DISABLE_PROGRESS_BARS` to 1.
3. To re-enable progress bars, use `enable_progress_bars()`.
4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`.
NOTE: The environment variable `HF_HUB_DISABLE_PROGRESS_BARS` takes priority.
Example:
```py
>>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm
# Disable progress bars globally
>>> disable_progress_bars()
# Use as normal `tqdm`
>>> for _ in tqdm(range(5)):
... pass
# Still not showing progress bars, as `disable=False` is overwritten to `True`.
>>> for _ in tqdm(range(5), disable=False):
... pass
>>> are_progress_bars_disabled()
True
# Re-enable progress bars globally
>>> enable_progress_bars()
# Progress bar will be shown !
>>> for _ in tqdm(range(5)):
... pass
100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s]
```
Group-based control:
```python
# Disable progress bars for a specific group
>>> disable_progress_bars("peft.foo")
# Check state of different groups
>>> assert not are_progress_bars_disabled("peft"))
>>> assert not are_progress_bars_disabled("peft.something")
>>> assert are_progress_bars_disabled("peft.foo"))
>>> assert are_progress_bars_disabled("peft.foo.bar"))
# Enable progress bars for a subgroup
>>> enable_progress_bars("peft.foo.bar")
# Check if enabling a subgroup affects the parent group
>>> assert are_progress_bars_disabled("peft.foo"))
>>> assert not are_progress_bars_disabled("peft.foo.bar"))
# No progress bar for `name="peft.foo"`
>>> for _ in tqdm(range(5), name="peft.foo"):
... pass
# Progress bar will be shown for `name="peft.foo.bar"`
>>> for _ in tqdm(range(5), name="peft.foo.bar"):
... pass
100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s]
```
"""
import io
import logging
import os
import warnings
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import ContextManager, Iterator, Optional, Union
from tqdm.auto import tqdm as old_tqdm
from ..constants import HF_HUB_DISABLE_PROGRESS_BARS
# The `HF_HUB_DISABLE_PROGRESS_BARS` environment variable can be True, False, or not set (None),
# allowing for control over progress bar visibility. When set, this variable takes precedence
# over programmatic settings, dictating whether progress bars should be shown or hidden globally.
# Essentially, the environment variable's setting overrides any code-based configurations.
#
# If `HF_HUB_DISABLE_PROGRESS_BARS` is not defined (None), it implies that users can manage
# progress bar visibility through code. By default, progress bars are turned on.
progress_bar_states: dict[str, bool] = {}
def disable_progress_bars(name: Optional[str] = None) -> None:
"""
Disable progress bars either globally or for a specified group.
This function updates the state of progress bars based on a group name.
If no group name is provided, all progress bars are disabled. The operation
respects the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable's setting.
Args:
name (`str`, *optional*):
The name of the group for which to disable the progress bars. If None,
progress bars are disabled globally.
Raises:
Warning: If the environment variable precludes changes.
"""
if HF_HUB_DISABLE_PROGRESS_BARS is False:
warnings.warn(
"Cannot disable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=0` is set and has priority."
)
return
if name is None:
progress_bar_states.clear()
progress_bar_states["_global"] = False
else:
keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")]
for key in keys_to_remove:
del progress_bar_states[key]
progress_bar_states[name] = False
def enable_progress_bars(name: Optional[str] = None) -> None:
"""
Enable progress bars either globally or for a specified group.
This function sets the progress bars to enabled for the specified group or globally
if no group is specified. The operation is subject to the `HF_HUB_DISABLE_PROGRESS_BARS`
environment setting.
Args:
name (`str`, *optional*):
The name of the group for which to enable the progress bars. If None,
progress bars are enabled globally.
Raises:
Warning: If the environment variable precludes changes.
"""
if HF_HUB_DISABLE_PROGRESS_BARS is True:
warnings.warn(
"Cannot enable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=1` is set and has priority."
)
return
if name is None:
progress_bar_states.clear()
progress_bar_states["_global"] = True
else:
keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")]
for key in keys_to_remove:
del progress_bar_states[key]
progress_bar_states[name] = True
def are_progress_bars_disabled(name: Optional[str] = None) -> bool:
"""
Check if progress bars are disabled globally or for a specific group.
This function returns whether progress bars are disabled for a given group or globally.
It checks the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable first, then the programmatic
settings.
Args:
name (`str`, *optional*):
The group name to check; if None, checks the global setting.
Returns:
`bool`: True if progress bars are disabled, False otherwise.
"""
if HF_HUB_DISABLE_PROGRESS_BARS is True:
return True
if name is None:
return not progress_bar_states.get("_global", True)
while name:
if name in progress_bar_states:
return not progress_bar_states[name]
name = ".".join(name.split(".")[:-1])
return not progress_bar_states.get("_global", True)
def is_tqdm_disabled(log_level: int) -> Optional[bool]:
"""
Determine if tqdm progress bars should be disabled based on logging level and environment settings.
See https://github.com/huggingface/huggingface_hub/pull/2000 and https://github.com/huggingface/huggingface_hub/pull/2698.
"""
if log_level == logging.NOTSET:
return True
if os.getenv("TQDM_POSITION") == "-1":
return False
return None
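# A minimal sketch of how `is_tqdm_disabled` is typically combined with a logger's
# effective level to drive the `disable` argument of a progress bar (the logger
# name below is illustrative; a `None` result lets tqdm decide on its own):
#
#     >>> import logging
#     >>> logger = logging.getLogger("huggingface_hub.example")
#     >>> for _ in tqdm(range(3), disable=is_tqdm_disabled(logger.getEffectiveLevel())):
#     ...     pass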
class tqdm(old_tqdm):
"""
Class to override `disable` argument in case progress bars are globally disabled.
Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324.
"""
def __init__(self, *args, **kwargs):
name = kwargs.pop("name", None) # do not pass `name` to `tqdm`
if are_progress_bars_disabled(name):
kwargs["disable"] = True
super().__init__(*args, **kwargs)
def __delattr__(self, attr: str) -> None:
"""Fix for https://github.com/huggingface/huggingface_hub/issues/1603"""
try:
super().__delattr__(attr)
except AttributeError:
if attr != "_lock":
raise
@contextmanager
def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]:
"""
Open a file as binary and wrap the `read` method to display a progress bar when it's streamed.
First implemented in `transformers` in 2019 but removed when switched to git-lfs. Used in `huggingface_hub` to show
progress bar when uploading an LFS file to the Hub. See github.com/huggingface/transformers/pull/2078#discussion_r354739608
for implementation details.
Note: the current implementation only handles files stored on disk, as that is the most common use case. It could be
extended to stream any `BinaryIO` object, but we might have to debug some corner cases.
Example:
```py
>>> with tqdm_stream_file("config.json") as f:
>>> httpx.put(url, data=f)
config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
```
"""
if isinstance(path, str):
path = Path(path)
with path.open("rb") as f:
total_size = path.stat().st_size
pbar = tqdm(
unit="B",
unit_scale=True,
total=total_size,
initial=0,
desc=path.name,
)
f_read = f.read
def _inner_read(size: Optional[int] = -1) -> bytes:
data = f_read(size)
pbar.update(len(data))
return data
f.read = _inner_read # type: ignore
yield f
pbar.close()
def _get_progress_bar_context(
*,
desc: str,
log_level: int,
total: Optional[int] = None,
initial: int = 0,
unit: str = "B",
unit_scale: bool = True,
name: Optional[str] = None,
tqdm_class: Optional[type[old_tqdm]] = None,
_tqdm_bar: Optional[tqdm] = None,
) -> ContextManager[tqdm]:
if _tqdm_bar is not None:
return nullcontext(_tqdm_bar)
# ^ `contextlib.nullcontext` mimics a context manager that does nothing
# Makes it easier to use the same code path for both cases, but in the latter
# case, the progress bar is not closed when exiting the context manager.
return (tqdm_class or tqdm)( # type: ignore[return-value]
unit=unit,
unit_scale=unit_scale,
total=total,
initial=initial,
desc=desc,
disable=is_tqdm_disabled(log_level=log_level),
name=name,
)