chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment - Includes all Python dependency packages - Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,129 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
# ruff: noqa: F401
|
||||
from huggingface_hub.errors import (
|
||||
BadRequestError,
|
||||
CacheNotFound,
|
||||
CorruptedCacheException,
|
||||
DisabledRepoError,
|
||||
EntryNotFoundError,
|
||||
FileMetadataError,
|
||||
GatedRepoError,
|
||||
HfHubHTTPError,
|
||||
HFValidationError,
|
||||
LocalEntryNotFoundError,
|
||||
LocalTokenNotFoundError,
|
||||
NotASafetensorsRepoError,
|
||||
OfflineModeIsEnabled,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
SafetensorsParsingError,
|
||||
)
|
||||
|
||||
from . import tqdm as _tqdm # _tqdm is the module
|
||||
from ._auth import get_stored_tokens, get_token
|
||||
from ._cache_assets import cached_assets_path
|
||||
from ._cache_manager import (
|
||||
CachedFileInfo,
|
||||
CachedRepoInfo,
|
||||
CachedRevisionInfo,
|
||||
DeleteCacheStrategy,
|
||||
HFCacheInfo,
|
||||
_format_size,
|
||||
scan_cache_dir,
|
||||
)
|
||||
from ._chunk_utils import chunk_iterable
|
||||
from ._datetime import parse_datetime
|
||||
from ._experimental import experimental
|
||||
from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump
|
||||
from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential
|
||||
from ._headers import build_hf_headers, get_token_to_send
|
||||
from ._http import (
|
||||
ASYNC_CLIENT_FACTORY_T,
|
||||
CLIENT_FACTORY_T,
|
||||
close_session,
|
||||
fix_hf_endpoint_in_url,
|
||||
get_async_session,
|
||||
get_session,
|
||||
hf_raise_for_status,
|
||||
http_backoff,
|
||||
http_stream_backoff,
|
||||
set_async_client_factory,
|
||||
set_client_factory,
|
||||
)
|
||||
from ._pagination import paginate
|
||||
from ._paths import DEFAULT_IGNORE_PATTERNS, FORBIDDEN_FOLDERS, filter_repo_objects
|
||||
from ._runtime import (
|
||||
dump_environment_info,
|
||||
get_aiohttp_version,
|
||||
get_fastai_version,
|
||||
get_fastapi_version,
|
||||
get_fastcore_version,
|
||||
get_gradio_version,
|
||||
get_graphviz_version,
|
||||
get_hf_hub_version,
|
||||
get_jinja_version,
|
||||
get_numpy_version,
|
||||
get_pillow_version,
|
||||
get_pydantic_version,
|
||||
get_pydot_version,
|
||||
get_python_version,
|
||||
get_tensorboard_version,
|
||||
get_tf_version,
|
||||
get_torch_version,
|
||||
installation_method,
|
||||
is_aiohttp_available,
|
||||
is_colab_enterprise,
|
||||
is_fastai_available,
|
||||
is_fastapi_available,
|
||||
is_fastcore_available,
|
||||
is_google_colab,
|
||||
is_gradio_available,
|
||||
is_graphviz_available,
|
||||
is_jinja_available,
|
||||
is_notebook,
|
||||
is_numpy_available,
|
||||
is_package_available,
|
||||
is_pillow_available,
|
||||
is_pydantic_available,
|
||||
is_pydot_available,
|
||||
is_safetensors_available,
|
||||
is_tensorboard_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from ._safetensors import SafetensorsFileMetadata, SafetensorsRepoMetadata, TensorInfo
|
||||
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
|
||||
from ._telemetry import send_telemetry
|
||||
from ._terminal import ANSI, tabulate
|
||||
from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
|
||||
from ._validators import validate_hf_hub_args, validate_repo_id
|
||||
from ._xet import (
|
||||
XetConnectionInfo,
|
||||
XetFileData,
|
||||
XetTokenType,
|
||||
fetch_xet_connection_info_from_repo_info,
|
||||
parse_xet_file_data_from_response,
|
||||
refresh_xet_connection_info,
|
||||
)
|
||||
from .tqdm import (
|
||||
are_progress_bars_disabled,
|
||||
disable_progress_bars,
|
||||
enable_progress_bars,
|
||||
is_tqdm_disabled,
|
||||
tqdm,
|
||||
tqdm_stream_file,
|
||||
)
|
||||
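The block above only re-exports symbols. A minimal sketch of how downstream code typically consumes a few of them, assuming the vendored huggingface_hub release exposes these names:

```py
# Minimal usage sketch (assumes this vendored huggingface_hub version re-exports these names).
from huggingface_hub.utils import build_hf_headers, scan_cache_dir, validate_repo_id

validate_repo_id("google/fleurs")      # raises HFValidationError for malformed repo ids
headers = build_hf_headers()           # adds an Authorization header if a token is configured
cache_report = scan_cache_dir()        # snapshot of the local HF cache
print(cache_report.size_on_disk_str)
```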
@@ -0,0 +1,214 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains a helper to get the token from machine (env variable, secret or config file)."""
|
||||
|
||||
import configparser
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
from .. import constants
|
||||
from ._runtime import is_colab_enterprise, is_google_colab
|
||||
|
||||
|
||||
_IS_GOOGLE_COLAB_CHECKED = False
|
||||
_GOOGLE_COLAB_SECRET_LOCK = Lock()
|
||||
_GOOGLE_COLAB_SECRET: Optional[str] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_token() -> Optional[str]:
|
||||
"""
|
||||
Get token if user is logged in.
|
||||
|
||||
Note: in most cases, you should use [`huggingface_hub.utils.build_hf_headers`] instead. This method is only useful
|
||||
if you want to retrieve the token for other purposes than sending an HTTP request.
|
||||
|
||||
Token is retrieved in priority from the `HF_TOKEN` environment variable. Otherwise, we read the token file located
|
||||
in the Hugging Face home folder. Returns None if user is not logged in. To log in, use [`login`] or
|
||||
`hf auth login`.
|
||||
|
||||
Returns:
|
||||
`str` or `None`: The token, `None` if it doesn't exist.
|
||||
"""
|
||||
return _get_token_from_google_colab() or _get_token_from_environment() or _get_token_from_file()
|
||||
|
||||
|
||||
def _get_token_from_google_colab() -> Optional[str]:
|
||||
"""Get token from Google Colab secrets vault using `google.colab.userdata.get(...)`.
|
||||
|
||||
Token is read from the vault only once per session and then stored in a global variable to avoid re-requesting
|
||||
access to the vault.
|
||||
"""
|
||||
# If it's not a Google Colab notebook or it's Colab Enterprise, fall back to environment variable or token file authentication
|
||||
if not is_google_colab() or is_colab_enterprise():
|
||||
return None
|
||||
|
||||
# `google.colab.userdata` is not thread-safe
|
||||
# This can lead to a deadlock if multiple threads try to access it at the same time
|
||||
# (typically when using `snapshot_download`)
|
||||
# => use a lock
|
||||
# See https://github.com/huggingface/huggingface_hub/issues/1952 for more details.
|
||||
with _GOOGLE_COLAB_SECRET_LOCK:
|
||||
global _GOOGLE_COLAB_SECRET
|
||||
global _IS_GOOGLE_COLAB_CHECKED
|
||||
|
||||
if _IS_GOOGLE_COLAB_CHECKED: # request access only once
|
||||
return _GOOGLE_COLAB_SECRET
|
||||
|
||||
try:
|
||||
from google.colab import userdata # type: ignore
|
||||
from google.colab.errors import Error as ColabError # type: ignore
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
token = userdata.get("HF_TOKEN")
|
||||
_GOOGLE_COLAB_SECRET = _clean_token(token)
|
||||
except userdata.NotebookAccessError:
|
||||
# Means the user has a secret called `HF_TOKEN` and got a popup "please grant access to HF_TOKEN" and refused it
|
||||
# => warn user but ignore error => do not re-request access to user
|
||||
warnings.warn(
|
||||
"\nAccess to the secret `HF_TOKEN` has not been granted on this notebook."
|
||||
"\nYou will not be requested again."
|
||||
"\nPlease restart the session if you want to be prompted again."
|
||||
)
|
||||
_GOOGLE_COLAB_SECRET = None
|
||||
except userdata.SecretNotFoundError:
|
||||
# Means the user did not define a `HF_TOKEN` secret => warn
|
||||
warnings.warn(
|
||||
"\nThe secret `HF_TOKEN` does not exist in your Colab secrets."
|
||||
"\nTo authenticate with the Hugging Face Hub, create a token in your settings tab "
|
||||
"(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session."
|
||||
"\nYou will be able to reuse this secret in all of your notebooks."
|
||||
"\nPlease note that authentication is recommended but still optional to access public models or datasets."
|
||||
)
|
||||
_GOOGLE_COLAB_SECRET = None
|
||||
except ColabError as e:
|
||||
# Something happened but we don't know what => recommend opening a GitHub issue
|
||||
warnings.warn(
|
||||
f"\nError while fetching `HF_TOKEN` secret value from your vault: '{str(e)}'."
|
||||
"\nYou are not authenticated with the Hugging Face Hub in this notebook."
|
||||
"\nIf the error persists, please let us know by opening an issue on GitHub "
|
||||
"(https://github.com/huggingface/huggingface_hub/issues/new)."
|
||||
)
|
||||
_GOOGLE_COLAB_SECRET = None
|
||||
|
||||
_IS_GOOGLE_COLAB_CHECKED = True
|
||||
return _GOOGLE_COLAB_SECRET
|
||||
|
||||
|
||||
def _get_token_from_environment() -> Optional[str]:
|
||||
# `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility)
|
||||
return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"))
|
||||
|
||||
|
||||
def _get_token_from_file() -> Optional[str]:
|
||||
try:
|
||||
return _clean_token(Path(constants.HF_TOKEN_PATH).read_text())
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def get_stored_tokens() -> dict[str, str]:
|
||||
"""
|
||||
Returns the parsed INI file containing the access tokens.
|
||||
The file is located at `HF_STORED_TOKENS_PATH`, defaulting to `~/.cache/huggingface/stored_tokens`.
|
||||
If the file does not exist, an empty dictionary is returned.
|
||||
|
||||
Returns: `dict[str, str]`
|
||||
Key is the token name and value is the token.
|
||||
"""
|
||||
tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
|
||||
if not tokens_path.exists():
|
||||
stored_tokens = {}
|
||||
config = configparser.ConfigParser()
|
||||
try:
|
||||
config.read(tokens_path)
|
||||
stored_tokens = {token_name: config.get(token_name, "hf_token") for token_name in config.sections()}
|
||||
except configparser.Error as e:
|
||||
logger.error(f"Error parsing stored tokens file: {e}")
|
||||
stored_tokens = {}
|
||||
return stored_tokens
|
||||
|
||||
|
||||
def _save_stored_tokens(stored_tokens: dict[str, str]) -> None:
|
||||
"""
|
||||
Saves the given configuration to the stored tokens file.
|
||||
|
||||
Args:
|
||||
stored_tokens (`dict[str, str]`):
|
||||
The stored tokens to save. Key is the token name and value is the token.
|
||||
"""
|
||||
stored_tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
|
||||
|
||||
# Write the stored tokens into an INI file
|
||||
config = configparser.ConfigParser()
|
||||
for token_name in sorted(stored_tokens.keys()):
|
||||
config.add_section(token_name)
|
||||
config.set(token_name, "hf_token", stored_tokens[token_name])
|
||||
|
||||
stored_tokens_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with stored_tokens_path.open("w") as config_file:
|
||||
config.write(config_file)
|
||||
|
||||
|
||||
def _get_token_by_name(token_name: str) -> Optional[str]:
|
||||
"""
|
||||
Get the token by name.
|
||||
|
||||
Args:
|
||||
token_name (`str`):
|
||||
The name of the token to get.
|
||||
|
||||
Returns:
|
||||
`str` or `None`: The token, `None` if it doesn't exist.
|
||||
|
||||
"""
|
||||
stored_tokens = get_stored_tokens()
|
||||
if token_name not in stored_tokens:
|
||||
return None
|
||||
return _clean_token(stored_tokens[token_name])
|
||||
|
||||
|
||||
def _save_token(token: str, token_name: str) -> None:
|
||||
"""
|
||||
Save the given token.
|
||||
|
||||
If the stored tokens file does not exist, it will be created.
|
||||
Args:
|
||||
token (`str`):
|
||||
The token to save.
|
||||
token_name (`str`):
|
||||
The name of the token.
|
||||
"""
|
||||
tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
|
||||
stored_tokens = get_stored_tokens()
|
||||
stored_tokens[token_name] = token
|
||||
_save_stored_tokens(stored_tokens)
|
||||
logger.info(f"The token `{token_name}` has been saved to {tokens_path}")
|
||||
|
||||
|
||||
def _clean_token(token: Optional[str]) -> Optional[str]:
|
||||
"""Clean token by removing trailing and leading spaces and newlines.
|
||||
|
||||
If token is an empty string, return None.
|
||||
"""
|
||||
if token is None:
|
||||
return None
|
||||
return token.replace("\r", "").replace("\n", "").strip() or None
|
||||
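A hedged sketch of the resolution order implemented by `get_token` (Colab secret, then the `HF_TOKEN`/`HUGGING_FACE_HUB_TOKEN` environment variables, then the token file); the token value below is a placeholder, not a real credential:

```py
# Sketch: outside Colab, an HF_TOKEN env var wins over the on-disk token file.
import os
from huggingface_hub.utils import get_token

os.environ["HF_TOKEN"] = "hf_placeholder"   # placeholder value for illustration only
assert get_token() == "hf_placeholder"

del os.environ["HF_TOKEN"]
print(get_token())   # token file at constants.HF_TOKEN_PATH, or None if not logged in
```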
@@ -0,0 +1,135 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ..constants import HF_ASSETS_CACHE
|
||||
|
||||
|
||||
def cached_assets_path(
|
||||
library_name: str,
|
||||
namespace: str = "default",
|
||||
subfolder: str = "default",
|
||||
*,
|
||||
assets_dir: Union[str, Path, None] = None,
|
||||
):
|
||||
"""Return a folder path to cache arbitrary files.
|
||||
|
||||
`huggingface_hub` provides a canonical folder path to store assets. This is the
|
||||
recommended way to integrate cache in a downstream library as it will benefit from
|
||||
the built-in tools to scan and delete the cache properly.
|
||||
|
||||
The distinction is made between files cached from the Hub and assets. Files from the
|
||||
Hub are cached in a git-aware manner and entirely managed by `huggingface_hub`. See
|
||||
[related documentation](https://huggingface.co/docs/huggingface_hub/how-to-cache).
|
||||
All other files that a downstream library caches are considered to be "assets"
|
||||
(files downloaded from external sources, extracted from a .tar archive, preprocessed
|
||||
for training,...).
|
||||
|
||||
Once the folder path is generated, it is guaranteed to exist and to be a directory.
|
||||
The path is based on 3 levels of depth: the library name, a namespace and a
|
||||
subfolder. Those 3 levels grant flexibility while allowing `huggingface_hub` to
|
||||
expect folders when scanning/deleting parts of the assets cache. Within a library,
|
||||
it is expected that all namespaces share the same subset of subfolder names but this
|
||||
is not a mandatory rule. The downstream library then has full control over which file
|
||||
structure to adopt within its cache. Namespace and subfolder are optional (would
|
||||
default to a `"default/"` subfolder) but library name is mandatory as we want every
|
||||
downstream library to manage its own cache.
|
||||
|
||||
Expected tree:
|
||||
```text
|
||||
assets/
|
||||
└── datasets/
|
||||
│ ├── SQuAD/
|
||||
│ │ ├── downloaded/
|
||||
│ │ ├── extracted/
|
||||
│ │ └── processed/
|
||||
│ ├── Helsinki-NLP--tatoeba_mt/
|
||||
│ ├── downloaded/
|
||||
│ ├── extracted/
|
||||
│ └── processed/
|
||||
└── transformers/
|
||||
├── default/
|
||||
│ ├── something/
|
||||
├── bert-base-cased/
|
||||
│ ├── default/
|
||||
│ └── training/
|
||||
hub/
|
||||
└── models--julien-c--EsperBERTo-small/
|
||||
├── blobs/
|
||||
│ ├── (...)
|
||||
│ ├── (...)
|
||||
├── refs/
|
||||
│ └── (...)
|
||||
└── snapshots/
|
||||
├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/
|
||||
│ ├── (...)
|
||||
└── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/
|
||||
└── (...)
|
||||
```
|
||||
|
||||
|
||||
Args:
|
||||
library_name (`str`):
|
||||
Name of the library that will manage the cache folder. Example: `"dataset"`.
|
||||
namespace (`str`, *optional*, defaults to "default"):
|
||||
Namespace to which the data belongs. Example: `"SQuAD"`.
|
||||
subfolder (`str`, *optional*, defaults to "default"):
|
||||
Subfolder in which the data will be stored. Example: `extracted`.
|
||||
assets_dir (`str`, `Path`, *optional*):
|
||||
Path to the folder where assets are cached. This must not be the same folder
|
||||
where Hub files are cached. Defaults to `HF_HOME / "assets"` if not provided.
|
||||
Can also be set with `HF_ASSETS_CACHE` environment variable.
|
||||
|
||||
Returns:
|
||||
Path to the cache folder (`Path`).
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> from huggingface_hub import cached_assets_path
|
||||
|
||||
>>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download")
|
||||
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/download')
|
||||
|
||||
>>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="extracted")
|
||||
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/extracted')
|
||||
|
||||
>>> cached_assets_path(library_name="datasets", namespace="Helsinki-NLP/tatoeba_mt")
|
||||
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/Helsinki-NLP--tatoeba_mt/default')
|
||||
|
||||
>>> cached_assets_path(library_name="datasets", assets_dir="/tmp/tmp123456")
|
||||
PosixPath('/tmp/tmp123456/datasets/default/default')
|
||||
```
|
||||
"""
|
||||
# Resolve assets_dir
|
||||
if assets_dir is None:
|
||||
assets_dir = HF_ASSETS_CACHE
|
||||
assets_dir = Path(assets_dir).expanduser().resolve()
|
||||
|
||||
# Avoid names that could create path issues
|
||||
for part in (" ", "/", "\\"):
|
||||
library_name = library_name.replace(part, "--")
|
||||
namespace = namespace.replace(part, "--")
|
||||
subfolder = subfolder.replace(part, "--")
|
||||
|
||||
# Path to subfolder is created
|
||||
path = assets_dir / library_name / namespace / subfolder
|
||||
try:
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
except (FileExistsError, NotADirectoryError):
|
||||
raise ValueError(f"Corrupted assets folder: cannot create directory because of an existing file ({path}).")
|
||||
|
||||
# Return
|
||||
return path
|
||||
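A short sketch of how a downstream library might call `cached_assets_path`; the `mylib` / `my-corpus` names are hypothetical and only illustrate the three-level layout described in the docstring:

```py
# Sketch: persist a preprocessed artifact under the shared assets cache.
# "mylib" and "my-corpus" are made-up names, not part of huggingface_hub.
from huggingface_hub import cached_assets_path

folder = cached_assets_path(library_name="mylib", namespace="my-corpus", subfolder="processed")
(folder / "stats.json").write_text('{"num_rows": 123}')   # the returned folder is guaranteed to exist
print(folder)
```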
@@ -0,0 +1,841 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to manage the HF cache directory."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from huggingface_hub.errors import CacheNotFound, CorruptedCacheException
|
||||
|
||||
from ..constants import HF_HUB_CACHE
|
||||
from . import logging
|
||||
from ._parsing import format_timesince
|
||||
from ._terminal import tabulate
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
REPO_TYPE_T = Literal["model", "dataset", "space"]
|
||||
|
||||
# List of OS-created helper files that need to be ignored
|
||||
FILES_TO_IGNORE = [".DS_Store"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CachedFileInfo:
|
||||
"""Frozen data structure holding information about a single cached file.
|
||||
|
||||
Args:
|
||||
file_name (`str`):
|
||||
Name of the file. Example: `config.json`.
|
||||
file_path (`Path`):
|
||||
Path of the file in the `snapshots` directory. The file path is a symlink
|
||||
referring to a blob in the `blobs` folder.
|
||||
blob_path (`Path`):
|
||||
Path of the blob file. This is equivalent to `file_path.resolve()`.
|
||||
size_on_disk (`int`):
|
||||
Size of the blob file in bytes.
|
||||
blob_last_accessed (`float`):
|
||||
Timestamp of the last time the blob file has been accessed (from any
|
||||
revision).
|
||||
blob_last_modified (`float`):
|
||||
Timestamp of the last time the blob file has been modified/created.
|
||||
|
||||
> [!WARNING]
|
||||
> `blob_last_accessed` and `blob_last_modified` reliability can depend on the OS you
|
||||
> are using. See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
|
||||
> for more details.
|
||||
"""
|
||||
|
||||
file_name: str
|
||||
file_path: Path
|
||||
blob_path: Path
|
||||
size_on_disk: int
|
||||
|
||||
blob_last_accessed: float
|
||||
blob_last_modified: float
|
||||
|
||||
@property
|
||||
def blob_last_accessed_str(self) -> str:
|
||||
"""
|
||||
(property) Timestamp of the last time the blob file has been accessed (from any
|
||||
revision), returned as a human-readable string.
|
||||
|
||||
Example: "2 weeks ago".
|
||||
"""
|
||||
return format_timesince(self.blob_last_accessed)
|
||||
|
||||
@property
|
||||
def blob_last_modified_str(self) -> str:
|
||||
"""
|
||||
(property) Timestamp of the last time the blob file has been modified, returned
|
||||
as a human-readable string.
|
||||
|
||||
Example: "2 weeks ago".
|
||||
"""
|
||||
return format_timesince(self.blob_last_modified)
|
||||
|
||||
@property
|
||||
def size_on_disk_str(self) -> str:
|
||||
"""
|
||||
(property) Size of the blob file as a human-readable string.
|
||||
|
||||
Example: "42.2K".
|
||||
"""
|
||||
return _format_size(self.size_on_disk)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CachedRevisionInfo:
|
||||
"""Frozen data structure holding information about a revision.
|
||||
|
||||
A revision corresponds to a folder in the `snapshots` folder and is populated with
|
||||
the exact tree structure as the repo on the Hub but contains only symlinks. A
|
||||
revision can be either referenced by 1 or more `refs` or be "detached" (no refs).
|
||||
|
||||
Args:
|
||||
commit_hash (`str`):
|
||||
Hash of the revision (unique).
|
||||
Example: `"9338f7b671827df886678df2bdd7cc7b4f36dffd"`.
|
||||
snapshot_path (`Path`):
|
||||
Path to the revision directory in the `snapshots` folder. It contains the
|
||||
exact tree structure as the repo on the Hub.
|
||||
files: (`frozenset[CachedFileInfo]`):
|
||||
Set of [`~CachedFileInfo`] describing all files contained in the snapshot.
|
||||
refs (`frozenset[str]`):
|
||||
Set of `refs` pointing to this revision. If the revision has no `refs`, it
|
||||
is considered detached.
|
||||
Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`.
|
||||
size_on_disk (`int`):
|
||||
Sum of the blob file sizes that are symlink-ed by the revision.
|
||||
last_modified (`float`):
|
||||
Timestamp of the last time the revision has been created/modified.
|
||||
|
||||
> [!WARNING]
|
||||
> `last_accessed` cannot be determined correctly on a single revision as blob files
|
||||
> are shared across revisions.
|
||||
|
||||
> [!WARNING]
|
||||
> `size_on_disk` is not necessarily the sum of all file sizes because of possible
|
||||
> duplicated files. Besides, only blobs are taken into account, not the (negligible)
|
||||
> size of folders and symlinks.
|
||||
"""
|
||||
|
||||
commit_hash: str
|
||||
snapshot_path: Path
|
||||
size_on_disk: int
|
||||
files: frozenset[CachedFileInfo]
|
||||
refs: frozenset[str]
|
||||
|
||||
last_modified: float
|
||||
|
||||
@property
|
||||
def last_modified_str(self) -> str:
|
||||
"""
|
||||
(property) Timestamp of the last time the revision has been modified, returned
|
||||
as a human-readable string.
|
||||
|
||||
Example: "2 weeks ago".
|
||||
"""
|
||||
return format_timesince(self.last_modified)
|
||||
|
||||
@property
|
||||
def size_on_disk_str(self) -> str:
|
||||
"""
|
||||
(property) Sum of the blob file sizes as a human-readable string.
|
||||
|
||||
Example: "42.2K".
|
||||
"""
|
||||
return _format_size(self.size_on_disk)
|
||||
|
||||
@property
|
||||
def nb_files(self) -> int:
|
||||
"""
|
||||
(property) Total number of files in the revision.
|
||||
"""
|
||||
return len(self.files)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CachedRepoInfo:
|
||||
"""Frozen data structure holding information about a cached repository.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
Repo id of the repo on the Hub. Example: `"google/fleurs"`.
|
||||
repo_type (`Literal["dataset", "model", "space"]`):
|
||||
Type of the cached repo.
|
||||
repo_path (`Path`):
|
||||
Local path to the cached repo.
|
||||
size_on_disk (`int`):
|
||||
Sum of the blob file sizes in the cached repo.
|
||||
nb_files (`int`):
|
||||
Total number of blob files in the cached repo.
|
||||
revisions (`frozenset[CachedRevisionInfo]`):
|
||||
Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo.
|
||||
last_accessed (`float`):
|
||||
Timestamp of the last time a blob file of the repo has been accessed.
|
||||
last_modified (`float`):
|
||||
Timestamp of the last time a blob file of the repo has been modified/created.
|
||||
|
||||
> [!WARNING]
|
||||
> `size_on_disk` is not necessarily the sum of all revisions sizes because of
|
||||
> duplicated files. Besides, only blobs are taken into account, not the (negligible)
|
||||
> size of folders and symlinks.
|
||||
|
||||
> [!WARNING]
|
||||
> `last_accessed` and `last_modified` reliability can depend on the OS you are using.
|
||||
> See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
|
||||
> for more details.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
repo_type: REPO_TYPE_T
|
||||
repo_path: Path
|
||||
size_on_disk: int
|
||||
nb_files: int
|
||||
revisions: frozenset[CachedRevisionInfo]
|
||||
|
||||
last_accessed: float
|
||||
last_modified: float
|
||||
|
||||
@property
|
||||
def last_accessed_str(self) -> str:
|
||||
"""
|
||||
(property) Last time a blob file of the repo has been accessed, returned as a
|
||||
human-readable string.
|
||||
|
||||
Example: "2 weeks ago".
|
||||
"""
|
||||
return format_timesince(self.last_accessed)
|
||||
|
||||
@property
|
||||
def last_modified_str(self) -> str:
|
||||
"""
|
||||
(property) Last time a blob file of the repo has been modified, returned as a
|
||||
human-readable string.
|
||||
|
||||
Example: "2 weeks ago".
|
||||
"""
|
||||
return format_timesince(self.last_modified)
|
||||
|
||||
@property
|
||||
def size_on_disk_str(self) -> str:
|
||||
"""
|
||||
(property) Sum of the blob file sizes as a human-readable string.
|
||||
|
||||
Example: "42.2K".
|
||||
"""
|
||||
return _format_size(self.size_on_disk)
|
||||
|
||||
@property
|
||||
def cache_id(self) -> str:
|
||||
"""Canonical `type/id` identifier used across cache tooling."""
|
||||
return f"{self.repo_type}/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def refs(self) -> dict[str, CachedRevisionInfo]:
|
||||
"""
|
||||
(property) Mapping between `refs` and revision data structures.
|
||||
"""
|
||||
return {ref: revision for revision in self.revisions for ref in revision.refs}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeleteCacheStrategy:
|
||||
"""Frozen data structure holding the strategy to delete cached revisions.
|
||||
|
||||
This object is not meant to be instantiated programmatically but to be returned by
|
||||
[`~utils.HFCacheInfo.delete_revisions`]. See documentation for usage example.
|
||||
|
||||
Args:
|
||||
expected_freed_size (`float`):
|
||||
Expected freed size once strategy is executed.
|
||||
blobs (`frozenset[Path]`):
|
||||
Set of blob file paths to be deleted.
|
||||
refs (`frozenset[Path]`):
|
||||
Set of reference file paths to be deleted.
|
||||
repos (`frozenset[Path]`):
|
||||
Set of entire repo paths to be deleted.
|
||||
snapshots (`frozenset[Path]`):
|
||||
Set of snapshots to be deleted (directory of symlinks).
|
||||
"""
|
||||
|
||||
expected_freed_size: int
|
||||
blobs: frozenset[Path]
|
||||
refs: frozenset[Path]
|
||||
repos: frozenset[Path]
|
||||
snapshots: frozenset[Path]
|
||||
|
||||
@property
|
||||
def expected_freed_size_str(self) -> str:
|
||||
"""
|
||||
(property) Expected size that will be freed as a human-readable string.
|
||||
|
||||
Example: "42.2K".
|
||||
"""
|
||||
return _format_size(self.expected_freed_size)
|
||||
|
||||
def execute(self) -> None:
|
||||
"""Execute the defined strategy.
|
||||
|
||||
> [!WARNING]
|
||||
> If this method is interrupted, the cache might get corrupted. Deletion order is
|
||||
> implemented so that references and symlinks are deleted before the actual blob
|
||||
> files.
|
||||
|
||||
> [!WARNING]
|
||||
> This method is irreversible. If executed, cached files are erased and must be
|
||||
> downloaded again.
|
||||
"""
|
||||
# Deletion order matters. Blobs are deleted last so that the user can't end
|
||||
# up in a state where a `ref` refers to a missing snapshot or a snapshot
|
||||
# symlink refers to a deleted blob.
|
||||
|
||||
# Delete entire repos
|
||||
for path in self.repos:
|
||||
_try_delete_path(path, path_type="repo")
|
||||
|
||||
# Delete snapshot directories
|
||||
for path in self.snapshots:
|
||||
_try_delete_path(path, path_type="snapshot")
|
||||
|
||||
# Delete refs files
|
||||
for path in self.refs:
|
||||
_try_delete_path(path, path_type="ref")
|
||||
|
||||
# Delete blob files
|
||||
for path in self.blobs:
|
||||
_try_delete_path(path, path_type="blob")
|
||||
|
||||
logger.info(f"Cache deletion done. Saved {self.expected_freed_size_str}.")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HFCacheInfo:
|
||||
"""Frozen data structure holding information about the entire cache-system.
|
||||
|
||||
This data structure is returned by [`scan_cache_dir`] and is immutable.
|
||||
|
||||
Args:
|
||||
size_on_disk (`int`):
|
||||
Sum of all valid repo sizes in the cache-system.
|
||||
repos (`frozenset[CachedRepoInfo]`):
|
||||
Set of [`~CachedRepoInfo`] describing all valid cached repos found on the
|
||||
cache-system while scanning.
|
||||
warnings (`list[CorruptedCacheException]`):
|
||||
List of [`~CorruptedCacheException`] that occurred while scanning the cache.
|
||||
Those exceptions are captured so that the scan can continue. Corrupted repos
|
||||
are skipped from the scan.
|
||||
|
||||
> [!WARNING]
|
||||
> Here `size_on_disk` is equal to the sum of all repo sizes (only blobs). However if
|
||||
> some cached repos are corrupted, their sizes are not taken into account.
|
||||
"""
|
||||
|
||||
size_on_disk: int
|
||||
repos: frozenset[CachedRepoInfo]
|
||||
warnings: list[CorruptedCacheException]
|
||||
|
||||
@property
|
||||
def size_on_disk_str(self) -> str:
|
||||
"""
|
||||
(property) Sum of all valid repo sizes in the cache-system as a human-readable
|
||||
string.
|
||||
|
||||
Example: "42.2K".
|
||||
"""
|
||||
return _format_size(self.size_on_disk)
|
||||
|
||||
def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy:
|
||||
"""Prepare the strategy to delete one or more revisions cached locally.
|
||||
|
||||
Input revisions can be any revision hash. If a revision hash is not found in the
|
||||
local cache, a warning is thrown but no error is raised. Revisions can be from
|
||||
different cached repos since hashes are unique across repos.
|
||||
|
||||
Examples:
|
||||
```py
|
||||
>>> from huggingface_hub import scan_cache_dir
|
||||
>>> cache_info = scan_cache_dir()
|
||||
>>> delete_strategy = cache_info.delete_revisions(
|
||||
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa"
|
||||
... )
|
||||
>>> print(f"Will free {delete_strategy.expected_freed_size_str}.")
|
||||
Will free 7.9K.
|
||||
>>> delete_strategy.execute()
|
||||
Cache deletion done. Saved 7.9K.
|
||||
```
|
||||
|
||||
```py
|
||||
>>> from huggingface_hub import scan_cache_dir
|
||||
>>> scan_cache_dir().delete_revisions(
|
||||
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa",
|
||||
... "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
|
||||
... "6c0e6080953db56375760c0471a8c5f2929baf11",
|
||||
... ).execute()
|
||||
Cache deletion done. Saved 8.6G.
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> `delete_revisions` returns a [`~utils.DeleteCacheStrategy`] object that needs to
|
||||
> be executed. The [`~utils.DeleteCacheStrategy`] is not meant to be modified but
|
||||
> allows having a dry run before actually executing the deletion.
|
||||
"""
|
||||
hashes_to_delete: set[str] = set(revisions)
|
||||
|
||||
repos_with_revisions: dict[CachedRepoInfo, set[CachedRevisionInfo]] = defaultdict(set)
|
||||
|
||||
for repo in self.repos:
|
||||
for revision in repo.revisions:
|
||||
if revision.commit_hash in hashes_to_delete:
|
||||
repos_with_revisions[repo].add(revision)
|
||||
hashes_to_delete.remove(revision.commit_hash)
|
||||
|
||||
if len(hashes_to_delete) > 0:
|
||||
logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}")
|
||||
|
||||
delete_strategy_blobs: set[Path] = set()
|
||||
delete_strategy_refs: set[Path] = set()
|
||||
delete_strategy_repos: set[Path] = set()
|
||||
delete_strategy_snapshots: set[Path] = set()
|
||||
delete_strategy_expected_freed_size = 0
|
||||
|
||||
for affected_repo, revisions_to_delete in repos_with_revisions.items():
|
||||
other_revisions = affected_repo.revisions - revisions_to_delete
|
||||
|
||||
# If no other revisions, it means all revisions are deleted
|
||||
# -> delete the entire cached repo
|
||||
if len(other_revisions) == 0:
|
||||
delete_strategy_repos.add(affected_repo.repo_path)
|
||||
delete_strategy_expected_freed_size += affected_repo.size_on_disk
|
||||
continue
|
||||
|
||||
# Some revisions of the repo will be deleted but not all. We need to filter
|
||||
# which blob files will not be linked anymore.
|
||||
for revision_to_delete in revisions_to_delete:
|
||||
# Snapshot dir
|
||||
delete_strategy_snapshots.add(revision_to_delete.snapshot_path)
|
||||
|
||||
# Refs dir
|
||||
for ref in revision_to_delete.refs:
|
||||
delete_strategy_refs.add(affected_repo.repo_path / "refs" / ref)
|
||||
|
||||
# Blobs dir
|
||||
for file in revision_to_delete.files:
|
||||
if file.blob_path not in delete_strategy_blobs:
|
||||
is_file_alone = True
|
||||
for revision in other_revisions:
|
||||
for rev_file in revision.files:
|
||||
if file.blob_path == rev_file.blob_path:
|
||||
is_file_alone = False
|
||||
break
|
||||
if not is_file_alone:
|
||||
break
|
||||
|
||||
# Blob file not referenced by remaining revisions -> delete
|
||||
if is_file_alone:
|
||||
delete_strategy_blobs.add(file.blob_path)
|
||||
delete_strategy_expected_freed_size += file.size_on_disk
|
||||
|
||||
# Return the strategy instead of executing it.
|
||||
return DeleteCacheStrategy(
|
||||
blobs=frozenset(delete_strategy_blobs),
|
||||
refs=frozenset(delete_strategy_refs),
|
||||
repos=frozenset(delete_strategy_repos),
|
||||
snapshots=frozenset(delete_strategy_snapshots),
|
||||
expected_freed_size=delete_strategy_expected_freed_size,
|
||||
)
|
||||
|
||||
def export_as_table(self, *, verbosity: int = 0) -> str:
|
||||
"""Generate a table from the [`HFCacheInfo`] object.
|
||||
|
||||
Pass `verbosity=0` to get a table with a single row per repo, with columns
|
||||
"repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path".
|
||||
|
||||
Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns
|
||||
"repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path".
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> from huggingface_hub.utils import scan_cache_dir
|
||||
|
||||
>>> hf_cache_info = scan_cache_dir()
|
||||
HFCacheInfo(...)
|
||||
|
||||
>>> print(hf_cache_info.export_as_table())
|
||||
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
|
||||
--------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- --------------------------------------------------------------------------------------------------
|
||||
roberta-base model 2.7M 5 1 day ago 1 week ago main ~/.cache/huggingface/hub/models--roberta-base
|
||||
suno/bark model 8.8K 1 1 week ago 1 week ago main ~/.cache/huggingface/hub/models--suno--bark
|
||||
t5-base model 893.8M 4 4 days ago 7 months ago main ~/.cache/huggingface/hub/models--t5-base
|
||||
t5-large model 3.0G 4 5 weeks ago 5 months ago main ~/.cache/huggingface/hub/models--t5-large
|
||||
|
||||
>>> print(hf_cache_info.export_as_table(verbosity=1))
|
||||
REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH
|
||||
--------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- -----------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main ~/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b
|
||||
suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main ~/.cache/huggingface/hub/models--suno--bark/snapshots/70a8a7d34168586dc5d028fa9666aceade177992
|
||||
t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main ~/.cache/huggingface/hub/models--t5-base/snapshots/a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1
|
||||
t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main ~/.cache/huggingface/hub/models--t5-large/snapshots/150ebc2c4b72291e770f58e6057481c8d2ed331a
|
||||
```
|
||||
|
||||
Args:
|
||||
verbosity (`int`, *optional*):
|
||||
The verbosity level. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
`str`: The table as a string.
|
||||
"""
|
||||
if verbosity == 0:
|
||||
return tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
repo.repo_type,
|
||||
"{:>12}".format(repo.size_on_disk_str),
|
||||
repo.nb_files,
|
||||
repo.last_accessed_str,
|
||||
repo.last_modified_str,
|
||||
", ".join(sorted(repo.refs)),
|
||||
str(repo.repo_path),
|
||||
]
|
||||
for repo in sorted(self.repos, key=lambda repo: repo.repo_path)
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"REPO TYPE",
|
||||
"SIZE ON DISK",
|
||||
"NB FILES",
|
||||
"LAST_ACCESSED",
|
||||
"LAST_MODIFIED",
|
||||
"REFS",
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
else:
|
||||
return tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
repo.repo_type,
|
||||
revision.commit_hash,
|
||||
"{:>12}".format(revision.size_on_disk_str),
|
||||
revision.nb_files,
|
||||
revision.last_modified_str,
|
||||
", ".join(sorted(revision.refs)),
|
||||
str(revision.snapshot_path),
|
||||
]
|
||||
for repo in sorted(self.repos, key=lambda repo: repo.repo_path)
|
||||
for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"REPO TYPE",
|
||||
"REVISION",
|
||||
"SIZE ON DISK",
|
||||
"NB FILES",
|
||||
"LAST_MODIFIED",
|
||||
"REFS",
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo:
|
||||
"""Scan the entire HF cache-system and return a [`~HFCacheInfo`] structure.
|
||||
|
||||
Use `scan_cache_dir` in order to programmatically scan your cache-system. The cache
|
||||
will be scanned repo by repo. If a repo is corrupted, a [`~CorruptedCacheException`]
|
||||
will be thrown internally but captured and returned in the [`~HFCacheInfo`]
|
||||
structure. Only valid repos get a proper report.
|
||||
|
||||
```py
|
||||
>>> from huggingface_hub import scan_cache_dir
|
||||
|
||||
>>> hf_cache_info = scan_cache_dir()
|
||||
HFCacheInfo(
|
||||
size_on_disk=3398085269,
|
||||
repos=frozenset({
|
||||
CachedRepoInfo(
|
||||
repo_id='t5-small',
|
||||
repo_type='model',
|
||||
repo_path=PosixPath(...),
|
||||
size_on_disk=970726914,
|
||||
nb_files=11,
|
||||
revisions=frozenset({
|
||||
CachedRevisionInfo(
|
||||
commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5',
|
||||
size_on_disk=970726339,
|
||||
snapshot_path=PosixPath(...),
|
||||
files=frozenset({
|
||||
CachedFileInfo(
|
||||
file_name='config.json',
|
||||
size_on_disk=1197
|
||||
file_path=PosixPath(...),
|
||||
blob_path=PosixPath(...),
|
||||
),
|
||||
CachedFileInfo(...),
|
||||
...
|
||||
}),
|
||||
),
|
||||
CachedRevisionInfo(...),
|
||||
...
|
||||
}),
|
||||
),
|
||||
CachedRepoInfo(...),
|
||||
...
|
||||
}),
|
||||
warnings=[
|
||||
CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."),
|
||||
CorruptedCacheException(...),
|
||||
...
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
You can also print a detailed report directly from the `hf` command line using:
|
||||
```text
|
||||
> hf cache ls
|
||||
ID SIZE LAST_ACCESSED LAST_MODIFIED REFS
|
||||
--------------------------- -------- ------------- ------------- -----------
|
||||
dataset/nyu-mll/glue 157.4M 2 days ago 2 days ago main script
|
||||
model/LiquidAI/LFM2-VL-1.6B 3.2G 4 days ago 4 days ago main
|
||||
model/microsoft/UserLM-8b 32.1G 4 days ago 4 days ago main
|
||||
|
||||
Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
|
||||
Got 1 warning(s) while scanning. Use -vvv to print details.
|
||||
```
|
||||
|
||||
Args:
|
||||
cache_dir (`str` or `Path`, `optional`):
|
||||
Cache directory to scan. Defaults to the default HF cache directory.
|
||||
|
||||
> [!WARNING]
|
||||
> Raises:
|
||||
>
|
||||
> `CacheNotFound`
|
||||
> If the cache directory does not exist.
|
||||
>
|
||||
> [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
> If the cache directory is a file, instead of a directory.
|
||||
|
||||
Returns: a [`~HFCacheInfo`] object.
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = HF_HUB_CACHE
|
||||
|
||||
cache_dir = Path(cache_dir).expanduser().resolve()
|
||||
if not cache_dir.exists():
|
||||
raise CacheNotFound(
|
||||
f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
if cache_dir.is_file():
|
||||
raise ValueError(
|
||||
f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."
|
||||
)
|
||||
|
||||
repos: set[CachedRepoInfo] = set()
|
||||
warnings: list[CorruptedCacheException] = []
|
||||
for repo_path in cache_dir.iterdir():
|
||||
if repo_path.name == ".locks": # skip './.locks/' folder
|
||||
continue
|
||||
try:
|
||||
repos.add(_scan_cached_repo(repo_path))
|
||||
except CorruptedCacheException as e:
|
||||
warnings.append(e)
|
||||
|
||||
return HFCacheInfo(
|
||||
repos=frozenset(repos),
|
||||
size_on_disk=sum(repo.size_on_disk for repo in repos),
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
|
||||
def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
|
||||
"""Scan a single cache repo and return information about it.
|
||||
|
||||
Any unexpected behavior will raise a [`~CorruptedCacheException`].
|
||||
"""
|
||||
if not repo_path.is_dir():
|
||||
raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}")
|
||||
|
||||
if "--" not in repo_path.name:
|
||||
raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}")
|
||||
|
||||
repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
|
||||
repo_type = repo_type[:-1] # "models" -> "model"
|
||||
repo_id = repo_id.replace("--", "/")  # "google--fleurs" -> "google/fleurs"
|
||||
|
||||
if repo_type not in {"dataset", "model", "space"}:
|
||||
raise CorruptedCacheException(
|
||||
f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})."
|
||||
)
|
||||
|
||||
blob_stats: dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats
|
||||
|
||||
snapshots_path = repo_path / "snapshots"
|
||||
refs_path = repo_path / "refs"
|
||||
|
||||
if not snapshots_path.exists() or not snapshots_path.is_dir():
|
||||
raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}")
|
||||
|
||||
# Scan over `refs` directory
|
||||
|
||||
# key is revision hash, value is set of refs
|
||||
refs_by_hash: dict[str, set[str]] = defaultdict(set)
|
||||
if refs_path.exists():
|
||||
# Example of `refs` directory
|
||||
# ── refs
|
||||
# ├── main
|
||||
# └── refs
|
||||
# └── pr
|
||||
# └── 1
|
||||
if refs_path.is_file():
|
||||
raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}")
|
||||
|
||||
for ref_path in refs_path.glob("**/*"):
|
||||
# glob("**/*") iterates over all files and directories -> skip directories
|
||||
if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE:
|
||||
continue
|
||||
|
||||
ref_name = str(ref_path.relative_to(refs_path))
|
||||
with ref_path.open() as f:
|
||||
commit_hash = f.read()
|
||||
|
||||
refs_by_hash[commit_hash].add(ref_name)
|
||||
|
||||
# Scan snapshots directory
|
||||
cached_revisions: set[CachedRevisionInfo] = set()
|
||||
for revision_path in snapshots_path.iterdir():
|
||||
# Ignore OS-created helper files
|
||||
if revision_path.name in FILES_TO_IGNORE:
|
||||
continue
|
||||
if revision_path.is_file():
|
||||
raise CorruptedCacheException(f"Snapshots folder corrupted. Found a file: {revision_path}")
|
||||
|
||||
cached_files = set()
|
||||
for file_path in revision_path.glob("**/*"):
|
||||
# glob("**/*") iterates over all files and directories -> skip directories
|
||||
if file_path.is_dir():
|
||||
continue
|
||||
|
||||
blob_path = Path(file_path).resolve()
|
||||
if not blob_path.exists():
|
||||
raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}")
|
||||
|
||||
if blob_path not in blob_stats:
|
||||
blob_stats[blob_path] = blob_path.stat()
|
||||
|
||||
cached_files.add(
|
||||
CachedFileInfo(
|
||||
file_name=file_path.name,
|
||||
file_path=file_path,
|
||||
size_on_disk=blob_stats[blob_path].st_size,
|
||||
blob_path=blob_path,
|
||||
blob_last_accessed=blob_stats[blob_path].st_atime,
|
||||
blob_last_modified=blob_stats[blob_path].st_mtime,
|
||||
)
|
||||
)
|
||||
|
||||
# Last modified is either the last modified blob file or the revision folder
|
||||
# itself if it is empty
|
||||
if len(cached_files) > 0:
|
||||
revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files)
|
||||
else:
|
||||
revision_last_modified = revision_path.stat().st_mtime
|
||||
|
||||
cached_revisions.add(
|
||||
CachedRevisionInfo(
|
||||
commit_hash=revision_path.name,
|
||||
files=frozenset(cached_files),
|
||||
refs=frozenset(refs_by_hash.pop(revision_path.name, set())),
|
||||
size_on_disk=sum(
|
||||
blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files)
|
||||
),
|
||||
snapshot_path=revision_path,
|
||||
last_modified=revision_last_modified,
|
||||
)
|
||||
)
|
||||
|
||||
# Check that all refs referred to an existing revision
|
||||
if len(refs_by_hash) > 0:
|
||||
raise CorruptedCacheException(
|
||||
f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})."
|
||||
)
|
||||
|
||||
# Last modified is either the last modified blob file or the repo folder itself if
|
||||
# no blob files have been found. Same for last accessed.
|
||||
if len(blob_stats) > 0:
|
||||
repo_last_accessed = max(stat.st_atime for stat in blob_stats.values())
|
||||
repo_last_modified = max(stat.st_mtime for stat in blob_stats.values())
|
||||
else:
|
||||
repo_stats = repo_path.stat()
|
||||
repo_last_accessed = repo_stats.st_atime
|
||||
repo_last_modified = repo_stats.st_mtime
|
||||
|
||||
# Build and return frozen structure
|
||||
return CachedRepoInfo(
|
||||
nb_files=len(blob_stats),
|
||||
repo_id=repo_id,
|
||||
repo_path=repo_path,
|
||||
repo_type=repo_type, # type: ignore
|
||||
revisions=frozenset(cached_revisions),
|
||||
size_on_disk=sum(stat.st_size for stat in blob_stats.values()),
|
||||
last_accessed=repo_last_accessed,
|
||||
last_modified=repo_last_modified,
|
||||
)
|
||||
|
||||
|
||||
def _format_size(num: int) -> str:
|
||||
"""Format size in bytes into a human-readable string.
|
||||
|
||||
Taken from https://stackoverflow.com/a/1094933
|
||||
"""
|
||||
num_f = float(num)
|
||||
for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
|
||||
if abs(num_f) < 1000.0:
|
||||
return f"{num_f:3.1f}{unit}"
|
||||
num_f /= 1000.0
|
||||
return f"{num_f:.1f}Y"
|
||||
|
||||
|
||||
def _try_delete_path(path: Path, path_type: str) -> None:
|
||||
"""Try to delete a local file or folder.
|
||||
|
||||
If the path does not exist, error is logged as a warning and then ignored.
|
||||
|
||||
Args:
|
||||
path (`Path`)
|
||||
Path to delete. Can be a file or a folder.
|
||||
path_type (`str`)
|
||||
What path are we deleting? Only used for logging purposes. Example: "snapshot".
|
||||
"""
|
||||
logger.info(f"Delete {path_type}: {path}")
|
||||
try:
|
||||
if path.is_file():
|
||||
os.remove(path)
|
||||
else:
|
||||
shutil.rmtree(path)
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True)
|
||||
except PermissionError:
|
||||
logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True)
|
||||
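A hedged end-to-end sketch combining `scan_cache_dir`, `HFCacheInfo.delete_revisions` and `DeleteCacheStrategy.execute`; the 30-day threshold and the detached-revision filter are example policies, not library defaults, and the final `execute()` call is commented out because deletion is irreversible:

```py
# Sketch: dry-run deletion of detached revisions that have not been accessed for ~30 days.
import time
from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()
cutoff = time.time() - 30 * 24 * 3600   # example threshold, adjust as needed

stale_hashes = [
    revision.commit_hash
    for repo in cache_info.repos
    if repo.last_accessed < cutoff
    for revision in repo.revisions
    if not revision.refs   # detached revisions only (no branch/tag pointing at them)
]

strategy = cache_info.delete_revisions(*stale_hashes)   # dry run: nothing deleted yet
print(f"Would free {strategy.expected_freed_size_str}")
# strategy.execute()   # uncomment to actually delete the files
```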
@@ -0,0 +1,64 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains a utility to iterate by chunks over an iterator."""
|
||||
|
||||
import itertools
|
||||
from typing import Iterable, TypeVar
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]:
|
||||
"""Iterates over an iterator chunk by chunk.
|
||||
|
||||
Taken from https://stackoverflow.com/a/8998040.
|
||||
See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088.
|
||||
|
||||
Args:
|
||||
iterable (`Iterable`):
|
||||
The iterable on which we want to iterate.
|
||||
chunk_size (`int`):
|
||||
Size of the chunks. Must be a strictly positive integer (i.e. > 0).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from huggingface_hub.utils import chunk_iterable
|
||||
|
||||
>>> for items in chunk_iterable(range(17), chunk_size=8):
|
||||
... print(items)
|
||||
# [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
# [8, 9, 10, 11, 12, 13, 14, 15]
|
||||
# [16] # smaller last chunk
|
||||
```
|
||||
|
||||
Raises:
|
||||
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
If `chunk_size` <= 0.
|
||||
|
||||
> [!WARNING]
|
||||
> The last chunk can be smaller than `chunk_size`.
|
||||
"""
|
||||
if not isinstance(chunk_size, int) or chunk_size <= 0:
|
||||
raise ValueError("`chunk_size` must be a strictly positive integer (>0).")
|
||||
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
next_item = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1))
|
||||
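One subtlety worth noting: each yielded chunk is a lazy `itertools.chain` over the shared iterator, so a chunk should be consumed (for example with `list`) before moving on to the next one. A minimal sketch:

```py
# Sketch: materialize each lazy chunk before requesting the next one.
from huggingface_hub.utils import chunk_iterable

for chunk in chunk_iterable(range(10), chunk_size=4):
    items = list(chunk)   # consume the chunk here
    print(items)          # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```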
@@ -0,0 +1,67 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to handle datetimes in Huggingface Hub."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def parse_datetime(date_string: str) -> datetime:
|
||||
"""
|
||||
Parses a date_string returned from the server to a datetime object.
|
||||
|
||||
This parser is a weak parser in the sense that it handles only a single format of
|
||||
date_string. It is expected that the server format will never change. The
|
||||
implementation depends only on the standard lib to avoid an external dependency
|
||||
(python-dateutil). See full discussion about this decision on PR:
|
||||
https://github.com/huggingface/huggingface_hub/pull/999.
|
||||
|
||||
Example:
|
||||
```py
|
||||
> parse_datetime('2022-08-19T07:19:38.123Z')
|
||||
datetime.datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc)
|
||||
```
|
||||
|
||||
Args:
|
||||
date_string (`str`):
|
||||
A string representing a datetime returned by the Hub server.
|
||||
String is expected to follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern.
|
||||
|
||||
Returns:
|
||||
A python datetime object.
|
||||
|
||||
Raises:
|
||||
:class:`ValueError`:
|
||||
If `date_string` cannot be parsed.
|
||||
"""
|
||||
try:
|
||||
# Normalize the string to always have 6 digits of fractional seconds
|
||||
if date_string.endswith("Z"):
|
||||
# Case 1: No decimal point (e.g., "2024-11-16T00:27:02Z")
|
||||
if "." not in date_string:
|
||||
# No fractional seconds - insert .000000
|
||||
date_string = date_string[:-1] + ".000000Z"
|
||||
# Case 2: Has decimal point (e.g., "2022-08-19T07:19:38.123456789Z")
|
||||
else:
|
||||
# Get the fractional and base parts
|
||||
base, fraction = date_string[:-1].split(".")
|
||||
# fraction[:6] takes first 6 digits and :0<6 pads with zeros if less than 6 digits
|
||||
date_string = f"{base}.{fraction[:6]:0<6}Z"
|
||||
|
||||
return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Cannot parse '{date_string}' as a datetime. Date string is expected to"
|
||||
" follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern."
|
||||
) from e
|
||||
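A small sketch of the two normalization branches handled above (no fractional part, and more than six fractional digits):

```py
# Sketch: both server formats normalize to a timezone-aware UTC datetime.
from huggingface_hub.utils import parse_datetime

print(parse_datetime("2024-11-16T00:27:02Z"))
# 2024-11-16 00:27:02+00:00   (".000000" is inserted before parsing)
print(parse_datetime("2022-08-19T07:19:38.123456789Z"))
# 2022-08-19 07:19:38.123456+00:00   (fractional seconds truncated to 6 digits)
```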
@@ -0,0 +1,136 @@
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from inspect import Parameter, signature
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
def _deprecate_positional_args(*, version: str):
|
||||
"""Decorator for methods that issues warnings for positional arguments.
|
||||
Using the keyword-only argument syntax in PEP 3102, arguments after the
|
||||
* will issue a warning when passed as a positional argument.
|
||||
|
||||
Args:
|
||||
version (`str`):
|
||||
The version when positional arguments will result in error.
|
||||
"""
|
||||
|
||||
def _inner_deprecate_positional_args(f):
|
||||
sig = signature(f)
|
||||
kwonly_args = []
|
||||
all_args = []
|
||||
for name, param in sig.parameters.items():
|
||||
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
all_args.append(name)
|
||||
elif param.kind == Parameter.KEYWORD_ONLY:
|
||||
kwonly_args.append(name)
|
||||
|
||||
@wraps(f)
|
||||
def inner_f(*args, **kwargs):
|
||||
extra_args = len(args) - len(all_args)
|
||||
if extra_args <= 0:
|
||||
return f(*args, **kwargs)
|
||||
# extra_args > 0
|
||||
args_msg = [
|
||||
f"{name}='{arg}'" if isinstance(arg, str) else f"{name}={arg}"
|
||||
for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
|
||||
]
|
||||
args_msg = ", ".join(args_msg)
|
||||
warnings.warn(
|
||||
f"Deprecated positional argument(s) used in '{f.__name__}': pass"
|
||||
f" {args_msg} as keyword args. From version {version} passing these"
|
||||
" as positional arguments will result in an error,",
|
||||
FutureWarning,
|
||||
)
|
||||
kwargs.update(zip(sig.parameters, args))
|
||||
return f(**kwargs)
|
||||
|
||||
return inner_f
|
||||
|
||||
return _inner_deprecate_positional_args
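# Editor's note: an illustrative, hypothetical usage of `_deprecate_positional_args`.
# The function `upload` and its parameters are invented for this sketch and are not
# part of the module.
#
# >>> @_deprecate_positional_args(version="1.0")
# ... def upload(repo_id, *, private=False):
# ...     return repo_id, private
#
# >>> upload("my-repo", True)          # FutureWarning: pass private=True as a keyword argument
# ('my-repo', True)
# >>> upload("my-repo", private=True)  # no warning
# ('my-repo', True)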
|
||||
|
||||
|
||||
def _deprecate_arguments(
|
||||
*,
|
||||
version: str,
|
||||
deprecated_args: Iterable[str],
|
||||
custom_message: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to issue warnings when using deprecated arguments.
|
||||
|
||||
TODO: could be useful to be able to set a custom error message.
|
||||
|
||||
Args:
|
||||
version (`str`):
|
||||
The version when deprecated arguments will result in error.
|
||||
deprecated_args (`list[str]`):
|
||||
List of the arguments to be deprecated.
|
||||
custom_message (`str`, *optional*):
|
||||
Warning message that is raised. If not passed, a default warning message
|
||||
will be created.
|
||||
"""
|
||||
|
||||
def _inner_deprecate_positional_args(f):
|
||||
sig = signature(f)
|
||||
|
||||
@wraps(f)
|
||||
def inner_f(*args, **kwargs):
|
||||
# Check for used deprecated arguments
|
||||
used_deprecated_args = []
|
||||
for _, parameter in zip(args, sig.parameters.values()):
|
||||
if parameter.name in deprecated_args:
|
||||
used_deprecated_args.append(parameter.name)
|
||||
for kwarg_name, kwarg_value in kwargs.items():
|
||||
if (
|
||||
# If argument is deprecated but still used
|
||||
kwarg_name in deprecated_args
|
||||
# And then the value is not the default value
|
||||
and kwarg_value != sig.parameters[kwarg_name].default
|
||||
):
|
||||
used_deprecated_args.append(kwarg_name)
|
||||
|
||||
# Warn and proceed
|
||||
if len(used_deprecated_args) > 0:
|
||||
message = (
|
||||
f"Deprecated argument(s) used in '{f.__name__}':"
|
||||
f" {', '.join(used_deprecated_args)}. Will not be supported from"
|
||||
f" version '{version}'."
|
||||
)
|
||||
if custom_message is not None:
|
||||
message += "\n\n" + custom_message
|
||||
warnings.warn(message, FutureWarning)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return inner_f
|
||||
|
||||
return _inner_deprecate_positional_args
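# Editor's note: an illustrative, hypothetical usage of `_deprecate_arguments`. The
# function `export` and the `legacy_format` argument are invented; the warning fires
# only when a deprecated argument is passed with a non-default value.
#
# >>> @_deprecate_arguments(version="1.0", deprecated_args=["legacy_format"])
# ... def export(path, legacy_format=None):
# ...     return path
#
# >>> export("model.bin")                      # no warning
# 'model.bin'
# >>> export("model.bin", legacy_format="v1")  # FutureWarning about 'legacy_format'
# 'model.bin'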
|
||||
|
||||
|
||||
def _deprecate_method(*, version: str, message: Optional[str] = None):
|
||||
"""Decorator to issue warnings when using a deprecated method.
|
||||
|
||||
Args:
|
||||
version (`str`):
|
||||
The version when deprecated arguments will result in error.
|
||||
message (`str`, *optional*):
|
||||
Warning message that is raised. If not passed, a default warning message
|
||||
will be created.
|
||||
"""
|
||||
|
||||
def _inner_deprecate_method(f):
|
||||
name = f.__name__
|
||||
if name == "__init__":
|
||||
name = f.__qualname__.split(".")[0] # class name instead of method name
|
||||
|
||||
@wraps(f)
|
||||
def inner_f(*args, **kwargs):
|
||||
warning_message = (
|
||||
f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'."
|
||||
)
|
||||
if message is not None:
|
||||
warning_message += " " + message
|
||||
warnings.warn(warning_message, FutureWarning)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return inner_f
|
||||
|
||||
return _inner_deprecate_method
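# Editor's note: an illustrative, hypothetical usage of `_deprecate_method`. The
# function `old_helper` is invented for this sketch.
#
# >>> @_deprecate_method(version="1.0", message="Use `new_helper` instead.")
# ... def old_helper():
# ...     return 42
#
# >>> old_helper()  # FutureWarning: 'old_helper' is deprecated ... Use `new_helper` instead.
# 42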
|
||||
@@ -0,0 +1,55 @@
|
||||
# AI-generated module (ChatGPT)
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def load_dotenv(dotenv_str: str, environ: Optional[dict[str, str]] = None) -> dict[str, str]:
|
||||
"""
|
||||
Parse a DOTENV-format string and return a dictionary of key-value pairs.
|
||||
Handles quoted values, comments, export keyword, and blank lines.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
line_pattern = re.compile(
|
||||
r"""
|
||||
^\s*
|
||||
(?:export[^\S\n]+)? # optional export
|
||||
([A-Za-z_][A-Za-z0-9_]*) # key
|
||||
[^\S\n]*(=)?[^\S\n]*
|
||||
( # value group
|
||||
(?:
|
||||
'(?:\\'|[^'])*' # single-quoted value
|
||||
| \"(?:\\\"|[^\"])*\" # double-quoted value
|
||||
| [^#\n\r]+? # unquoted value
|
||||
)
|
||||
)?
|
||||
[^\S\n]*(?:\#.*)?$ # optional inline comment
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
for line in dotenv_str.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue # Skip comments and empty lines
|
||||
|
||||
match = line_pattern.match(line)
|
||||
if match:
|
||||
key = match.group(1)
|
||||
val = None
|
||||
if match.group(2): # if there is '='
|
||||
raw_val = match.group(3) or ""
|
||||
val = raw_val.strip()
|
||||
# Remove surrounding quotes if quoted
|
||||
if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
|
||||
val = val[1:-1]
|
||||
val = val.replace(r"\n", "\n").replace(r"\t", "\t").replace(r"\"", '"').replace(r"\\", "\\")
|
||||
if raw_val.startswith('"'):
|
||||
val = val.replace(r"\$", "$") # only in double quotes
|
||||
elif environ is not None:
|
||||
# Get it from the current environment
|
||||
val = environ.get(key)
|
||||
|
||||
if val is not None:
|
||||
env[key] = val
|
||||
|
||||
return env
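# Editor's note: an illustrative call to `load_dotenv` with an inline DOTENV string.
# Quoted values, inline comments and the `export` keyword are handled as documented above.
#
# >>> load_dotenv('export TOKEN="abc"  # inline comment\nDEBUG=1\n')
# {'TOKEN': 'abc', 'DEBUG': '1'}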
|
||||
@@ -0,0 +1,68 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to flag a feature as "experimental" in Huggingface Hub."""
|
||||
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from .. import constants
|
||||
|
||||
|
||||
def experimental(fn: Callable) -> Callable:
|
||||
"""Decorator to flag a feature as experimental.
|
||||
|
||||
An experimental feature triggers a warning when used as it might be subject to breaking changes without prior notice
|
||||
in the future.
|
||||
|
||||
Warnings can be disabled by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable.
|
||||
|
||||
Args:
|
||||
fn (`Callable`):
|
||||
The function to flag as experimental.
|
||||
|
||||
Returns:
|
||||
`Callable`: The decorated function.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from huggingface_hub.utils import experimental
|
||||
|
||||
>>> @experimental
|
||||
... def my_function():
|
||||
... print("Hello world!")
|
||||
|
||||
>>> my_function()
|
||||
UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future without prior
|
||||
notice. You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable.
|
||||
Hello world!
|
||||
```
|
||||
"""
|
||||
# For classes, put the "experimental" around the "__new__" method => __new__ will be removed in warning message
|
||||
name = fn.__qualname__[: -len(".__new__")] if fn.__qualname__.endswith(".__new__") else fn.__qualname__
|
||||
|
||||
@wraps(fn)
|
||||
def _inner_fn(*args, **kwargs):
|
||||
if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING:
|
||||
warnings.warn(
|
||||
f"'{name}' is experimental and might be subject to breaking changes in the future without prior notice."
|
||||
" You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment"
|
||||
" variable.",
|
||||
UserWarning,
|
||||
)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _inner_fn
|
||||
@@ -0,0 +1,123 @@
|
||||
import contextlib
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import tempfile
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, Optional, Union
|
||||
|
||||
import yaml
|
||||
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
|
||||
|
||||
from .. import constants
|
||||
from . import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Wrap `yaml.dump` to set `allow_unicode=True` by default.
|
||||
#
|
||||
# Example:
|
||||
# ```py
|
||||
# >>> yaml.dump({"emoji": "👀", "some unicode": "日本か"})
|
||||
# 'emoji: "\\U0001F440"\nsome unicode: "\\u65E5\\u672C\\u304B"\n'
|
||||
#
|
||||
# >>> yaml_dump({"emoji": "👀", "some unicode": "日本か"})
|
||||
# 'emoji: "👀"\nsome unicode: "日本か"\n'
|
||||
# ```
|
||||
yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def SoftTemporaryDirectory(
|
||||
suffix: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
dir: Optional[Union[Path, str]] = None,
|
||||
**kwargs,
|
||||
) -> Generator[Path, None, None]:
|
||||
"""
|
||||
Context manager to create a temporary directory and safely delete it.
|
||||
|
||||
If tmp directory cannot be deleted normally, we set the WRITE permission and retry.
|
||||
If cleanup still fails, we give up but don't raise an exception. This is equivalent
|
||||
to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in
|
||||
Python 3.10.
|
||||
|
||||
See https://www.scivision.dev/python-tempfile-permission-error-windows/.
|
||||
"""
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs)
|
||||
yield Path(tmpdir.name).resolve()
|
||||
|
||||
try:
|
||||
# First once with normal cleanup
|
||||
shutil.rmtree(tmpdir.name)
|
||||
except Exception:
|
||||
# If failed, try to set write permission and retry
|
||||
try:
|
||||
shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# And finally, cleanup the tmpdir.
|
||||
# If it fails again, give up but do not throw error
|
||||
try:
|
||||
tmpdir.cleanup()
|
||||
except Exception:
|
||||
pass
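# Editor's note: an illustrative use of `SoftTemporaryDirectory`. The directory is
# yielded as a `Path`; cleanup errors (e.g. locked files on Windows) are swallowed on exit.
#
# >>> with SoftTemporaryDirectory(prefix="hf-") as tmpdir:
# ...     (tmpdir / "data.txt").write_text("hello")
# 5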
|
||||
|
||||
|
||||
def _set_write_permission_and_retry(func, path, excinfo):
|
||||
os.chmod(path, stat.S_IWRITE)
|
||||
func(path)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def WeakFileLock(
|
||||
lock_file: Union[str, Path], *, timeout: Optional[float] = None
|
||||
) -> Generator[BaseFileLock, None, None]:
|
||||
"""A filelock with some custom logic.
|
||||
|
||||
This filelock is weaker than the default filelock in that:
|
||||
1. It won't raise an exception if release fails.
|
||||
2. It will default to a SoftFileLock if the filesystem does not support flock.
|
||||
|
||||
An INFO log message is emitted every 10 seconds if the lock is not acquired immediately.
|
||||
If a timeout is provided, a `filelock.Timeout` exception is raised if the lock is not acquired within the timeout.
|
||||
"""
|
||||
log_interval = constants.FILELOCK_LOG_EVERY_SECONDS
|
||||
lock = FileLock(lock_file, timeout=log_interval)
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
elapsed_time = time.time() - start_time
|
||||
if timeout is not None and elapsed_time >= timeout:
|
||||
raise Timeout(str(lock_file))
|
||||
|
||||
try:
|
||||
lock.acquire(timeout=min(log_interval, timeout - elapsed_time) if timeout else log_interval)
|
||||
except Timeout:
|
||||
logger.info(
|
||||
f"Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)"
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
if "use SoftFileLock instead" in str(e):
|
||||
logger.warning(
|
||||
"FileSystem does not appear to support flock. Falling back to SoftFileLock for %s", lock_file
|
||||
)
|
||||
lock = SoftFileLock(lock_file, timeout=log_interval)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
yield lock
|
||||
finally:
|
||||
try:
|
||||
lock.release()
|
||||
except OSError:
|
||||
try:
|
||||
Path(lock_file).unlink()
|
||||
except OSError:
|
||||
pass
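# Editor's note: an illustrative use of `WeakFileLock` to guard a critical section.
# The lock file path is hypothetical; with `timeout=None` the call waits indefinitely,
# logging an INFO message at regular intervals.
#
# >>> with WeakFileLock("/tmp/my-app.lock", timeout=30):
# ...     ...  # work that must not run concurrently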
|
||||
@@ -0,0 +1,121 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to manage Git credentials."""
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
from ..constants import ENDPOINT
|
||||
from ._subprocess import run_interactive_subprocess, run_subprocess
|
||||
|
||||
|
||||
GIT_CREDENTIAL_REGEX = re.compile(
|
||||
r"""
|
||||
^\s* # start of line
|
||||
credential\.helper # credential.helper value
|
||||
\s*=\s* # separator
|
||||
([\w\-\/]+) # the helper name or absolute path (group 1)
|
||||
(\s|$) # whitespace or end of line
|
||||
""",
|
||||
flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
def list_credential_helpers(folder: Optional[str] = None) -> list[str]:
|
||||
"""Return the list of git credential helpers configured.
|
||||
|
||||
See https://git-scm.com/docs/gitcredentials.
|
||||
|
||||
Helpers are parsed from the output of "`git config --list`" (e.g. `store`, `cache`,
`osxkeychain`, ...). See https://git-scm.com/docs/git-credential.
|
||||
|
||||
Args:
|
||||
folder (`str`, *optional*):
|
||||
The folder in which to check the configured helpers.
|
||||
"""
|
||||
try:
|
||||
output = run_subprocess("git config --list", folder=folder).stdout
|
||||
parsed = _parse_credential_output(output)
|
||||
return parsed
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise EnvironmentError(exc.stderr)
|
||||
|
||||
|
||||
def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None:
|
||||
"""Save a username/token pair in git credential for HF Hub registry.
|
||||
|
||||
Credentials are saved in all configured helpers (store, cache, macOS keychain,...).
|
||||
Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential.
|
||||
|
||||
Args:
|
||||
username (`str`, defaults to `"hf_user"`):
|
||||
A git username. Defaults to `"hf_user"`, the default user used in the Hub.
|
||||
token (`str`):
|
||||
A git password. In practice, the User Access Token for the Hub.
|
||||
See https://huggingface.co/settings/tokens.
|
||||
folder (`str`, *optional*):
|
||||
The folder in which to check the configured helpers.
|
||||
"""
|
||||
with run_interactive_subprocess("git credential approve", folder=folder) as (
|
||||
stdin,
|
||||
_,
|
||||
):
|
||||
stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n")
|
||||
stdin.flush()
|
||||
|
||||
|
||||
def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None:
|
||||
"""Erase credentials from git credential for HF Hub registry.
|
||||
|
||||
Credentials are erased from the configured helpers (store, cache, macOS
|
||||
keychain,...), if any. If `username` is not provided, any credential configured for
|
||||
HF Hub endpoint is erased.
|
||||
Calls "`git credential erase`" internally. See https://git-scm.com/docs/git-credential.
|
||||
|
||||
Args:
|
||||
username (`str`, defaults to `"hf_user"`):
|
||||
A git username. Defaults to `"hf_user"`, the default user used in the Hub.
|
||||
folder (`str`, *optional*):
|
||||
The folder in which to check the configured helpers.
|
||||
"""
|
||||
with run_interactive_subprocess("git credential reject", folder=folder) as (
|
||||
stdin,
|
||||
_,
|
||||
):
|
||||
standard_input = f"url={ENDPOINT}\n"
|
||||
if username is not None:
|
||||
standard_input += f"username={username.lower()}\n"
|
||||
standard_input += "\n"
|
||||
|
||||
stdin.write(standard_input)
|
||||
stdin.flush()
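# Editor's note: an illustrative round-trip with the helpers above. It assumes a `store`
# credential helper is configured and that `hf_xxx` stands for a real User Access Token.
#
# >>> list_credential_helpers()
# ['store']
# >>> set_git_credential("hf_xxx")  # save the token for the Hub endpoint
# >>> unset_git_credential()        # erase it again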
|
||||
|
||||
|
||||
def _parse_credential_output(output: str) -> list[str]:
|
||||
"""Parse the output of `git credential fill` to extract the password.
|
||||
|
||||
Args:
|
||||
output (`str`):
|
||||
The output of `git credential fill`.
|
||||
"""
|
||||
# NOTE: If user has set a helper for a custom URL, it will not be caught here.
|
||||
# Example: `credential.https://huggingface.co.helper=store`
|
||||
# See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508
|
||||
return sorted( # Sort for nice printing
|
||||
set( # Might have some duplicates
|
||||
match[0] for match in GIT_CREDENTIAL_REGEX.findall(output)
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,206 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to handle headers to send in calls to Huggingface Hub."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from huggingface_hub.errors import LocalTokenNotFoundError
|
||||
|
||||
from .. import constants
|
||||
from ._auth import get_token
|
||||
from ._runtime import (
|
||||
get_hf_hub_version,
|
||||
get_python_version,
|
||||
get_torch_version,
|
||||
is_torch_available,
|
||||
)
|
||||
from ._validators import validate_hf_hub_args
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def build_hf_headers(
|
||||
*,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
library_name: Optional[str] = None,
|
||||
library_version: Optional[str] = None,
|
||||
user_agent: Union[dict, str, None] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Build headers dictionary to send in a HF Hub call.
|
||||
|
||||
By default, authorization token is always provided either from argument (explicit
|
||||
use) or retrieved from the cache (implicit use). To explicitly avoid sending the
|
||||
token to the Hub, set `token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN`
|
||||
environment variable.
|
||||
|
||||
In case of an API call that requires write access, an error is thrown if token is
|
||||
`None` or token is an organization token (starting with `"api_org***"`).
|
||||
|
||||
In addition to the auth header, a user-agent is added to provide information about
|
||||
the installed packages (versions of python, huggingface_hub, torch).
|
||||
|
||||
Args:
|
||||
token (`str`, `bool`, *optional*):
|
||||
The token to be sent in authorization header for the Hub call:
|
||||
- if a string, it is used as the Hugging Face token
|
||||
- if `True`, the token is read from the machine (cache or env variable)
|
||||
- if `False`, authorization header is not set
|
||||
- if `None`, the token is read from the machine only except if
|
||||
`HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
|
||||
library_name (`str`, *optional*):
|
||||
The name of the library that is making the HTTP request. Will be added to
|
||||
the user-agent header.
|
||||
library_version (`str`, *optional*):
|
||||
The version of the library that is making the HTTP request. Will be added
|
||||
to the user-agent header.
|
||||
user_agent (`str`, `dict`, *optional*):
|
||||
The user agent info in the form of a dictionary or a single string. It will
|
||||
be completed with information about the installed packages.
|
||||
headers (`dict`, *optional*):
|
||||
Additional headers to include in the request. Those headers take precedence
|
||||
over the ones generated by this function.
|
||||
|
||||
Returns:
|
||||
A `dict` of headers to pass in your API call.
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> build_hf_headers(token="hf_***") # explicit token
|
||||
{"authorization": "Bearer hf_***", "user-agent": ""}
|
||||
|
||||
>>> build_hf_headers(token=True) # explicitly use cached token
|
||||
{"authorization": "Bearer hf_***",...}
|
||||
|
||||
>>> build_hf_headers(token=False) # explicitly don't use cached token
|
||||
{"user-agent": ...}
|
||||
|
||||
>>> build_hf_headers() # implicit use of the cached token
|
||||
{"authorization": "Bearer hf_***",...}
|
||||
|
||||
# HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable
|
||||
>>> build_hf_headers() # token is not sent
|
||||
{"user-agent": ...}
|
||||
|
||||
>>> build_hf_headers(library_name="transformers", library_version="1.2.3")
|
||||
{"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"}
|
||||
```
|
||||
|
||||
Raises:
|
||||
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
If organization token is passed and "write" access is required.
|
||||
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
If "write" access is required but token is not passed and not saved locally.
|
||||
[`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
||||
If `token=True` but token is not saved locally.
|
||||
"""
|
||||
# Get auth token to send
|
||||
token_to_send = get_token_to_send(token)
|
||||
|
||||
# Combine headers
|
||||
hf_headers = {
|
||||
"user-agent": _http_user_agent(
|
||||
library_name=library_name,
|
||||
library_version=library_version,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
}
|
||||
if token_to_send is not None:
|
||||
hf_headers["authorization"] = f"Bearer {token_to_send}"
|
||||
if headers is not None:
|
||||
hf_headers.update(headers)
|
||||
return hf_headers
|
||||
|
||||
|
||||
def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]:
|
||||
"""Select the token to send from either `token` or the cache."""
|
||||
# Case token is explicitly provided
|
||||
if isinstance(token, str):
|
||||
return token
|
||||
|
||||
# Case token is explicitly forbidden
|
||||
if token is False:
|
||||
return None
|
||||
|
||||
# Token is not provided: we get it from local cache
|
||||
cached_token = get_token()
|
||||
|
||||
# Case token is explicitly required
|
||||
if token is True:
|
||||
if cached_token is None:
|
||||
raise LocalTokenNotFoundError(
|
||||
"Token is required (`token=True`), but no token found. You"
|
||||
" need to provide a token or be logged in to Hugging Face with"
|
||||
" `hf auth login` or `huggingface_hub.login`. See"
|
||||
" https://huggingface.co/settings/tokens."
|
||||
)
|
||||
return cached_token
|
||||
|
||||
# Case implicit use of the token is forbidden by env variable
|
||||
if constants.HF_HUB_DISABLE_IMPLICIT_TOKEN:
|
||||
return None
|
||||
|
||||
# Otherwise: we use the cached token as the user has not explicitly forbidden it
|
||||
return cached_token
|
||||
|
||||
|
||||
def _http_user_agent(
|
||||
*,
|
||||
library_name: Optional[str] = None,
|
||||
library_version: Optional[str] = None,
|
||||
user_agent: Union[dict, str, None] = None,
|
||||
) -> str:
|
||||
"""Format a user-agent string containing information about the installed packages.
|
||||
|
||||
Args:
|
||||
library_name (`str`, *optional*):
|
||||
The name of the library that is making the HTTP request.
|
||||
library_version (`str`, *optional*):
|
||||
The version of the library that is making the HTTP request.
|
||||
user_agent (`str`, `dict`, *optional*):
|
||||
The user agent info in the form of a dictionary or a single string.
|
||||
|
||||
Returns:
|
||||
The formatted user-agent string.
|
||||
"""
|
||||
if library_name is not None:
|
||||
ua = f"{library_name}/{library_version}"
|
||||
else:
|
||||
ua = "unknown/None"
|
||||
ua += f"; hf_hub/{get_hf_hub_version()}"
|
||||
ua += f"; python/{get_python_version()}"
|
||||
|
||||
if not constants.HF_HUB_DISABLE_TELEMETRY:
|
||||
if is_torch_available():
|
||||
ua += f"; torch/{get_torch_version()}"
|
||||
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
|
||||
# Retrieve user-agent origin headers from environment variable
|
||||
origin = constants.HF_HUB_USER_AGENT_ORIGIN
|
||||
if origin is not None:
|
||||
ua += "; origin/" + origin
|
||||
|
||||
return _deduplicate_user_agent(ua)
|
||||
|
||||
|
||||
def _deduplicate_user_agent(user_agent: str) -> str:
|
||||
"""Deduplicate redundant information in the generated user-agent."""
|
||||
# Split around ";" > Strip whitespaces > Store as dict keys (ensure unicity) > format back as string
|
||||
# Order is implicitly preserved by dictionary structure (see https://stackoverflow.com/a/53657523).
|
||||
return "; ".join({key.strip(): None for key in user_agent.split(";")}.keys())
|
||||
@@ -0,0 +1,796 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to handle HTTP requests in huggingface_hub."""
|
||||
|
||||
import atexit
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from http import HTTPStatus
|
||||
from shlex import quote
|
||||
from typing import Any, Callable, Generator, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
|
||||
from .. import constants
|
||||
from ..errors import (
|
||||
BadRequestError,
|
||||
DisabledRepoError,
|
||||
GatedRepoError,
|
||||
HfHubHTTPError,
|
||||
RemoteEntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
from . import logging
|
||||
from ._lfs import SliceFileObj
|
||||
from ._typing import HTTP_METHOD_T
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Both headers are used by the Hub to debug failed requests.
|
||||
# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
|
||||
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
|
||||
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
|
||||
X_REQUEST_ID = "x-request-id"
|
||||
|
||||
REPO_API_REGEX = re.compile(
|
||||
r"""
|
||||
# staging or production endpoint
|
||||
^https://[^/]+
|
||||
(
|
||||
# on /api/repo_type/repo_id
|
||||
/api/(models|datasets|spaces)/(.+)
|
||||
|
|
||||
# or /repo_id/resolve/revision/...
|
||||
/(.+)/resolve/(.+)
|
||||
)
|
||||
""",
|
||||
flags=re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
def hf_request_event_hook(request: httpx.Request) -> None:
|
||||
"""
|
||||
Event hook that will be used to make HTTP requests to the Hugging Face Hub.
|
||||
|
||||
What it does:
|
||||
- Block requests if offline mode is enabled
|
||||
- Add a request ID to the request headers
|
||||
- Log the request if debug mode is enabled
|
||||
"""
|
||||
if constants.HF_HUB_OFFLINE:
|
||||
raise OfflineModeIsEnabled(
|
||||
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
|
||||
)
|
||||
|
||||
# Add random request ID => easier for server-side debugging
|
||||
if X_AMZN_TRACE_ID not in request.headers:
|
||||
request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
|
||||
request_id = request.headers.get(X_AMZN_TRACE_ID)
|
||||
|
||||
# Debug log
|
||||
logger.debug(
|
||||
"Request %s: %s %s (authenticated: %s)",
|
||||
request_id,
|
||||
request.method,
|
||||
request.url,
|
||||
request.headers.get("authorization") is not None,
|
||||
)
|
||||
if constants.HF_DEBUG:
|
||||
logger.debug("Send: %s", _curlify(request))
|
||||
|
||||
return request_id
|
||||
|
||||
|
||||
async def async_hf_request_event_hook(request: httpx.Request) -> None:
|
||||
"""
|
||||
Async version of `hf_request_event_hook`.
|
||||
"""
|
||||
return hf_request_event_hook(request)
|
||||
|
||||
|
||||
async def async_hf_response_event_hook(response: httpx.Response) -> None:
|
||||
if response.status_code >= 400:
|
||||
# If response will raise, read content from stream to have it available when raising the exception
|
||||
# If content-length is not set or is too large, skip reading the content to avoid OOM
|
||||
if "Content-length" in response.headers:
|
||||
try:
|
||||
length = int(response.headers["Content-length"])
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
if length < 1_000_000:
|
||||
await response.aread()
|
||||
|
||||
|
||||
def default_client_factory() -> httpx.Client:
|
||||
"""
|
||||
Factory function to create a `httpx.Client` with the default transport.
|
||||
"""
|
||||
return httpx.Client(
|
||||
event_hooks={"request": [hf_request_event_hook]},
|
||||
follow_redirects=True,
|
||||
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
|
||||
)
|
||||
|
||||
|
||||
def default_async_client_factory() -> httpx.AsyncClient:
|
||||
"""
|
||||
Factory function to create a `httpx.AsyncClient` with the default transport.
|
||||
"""
|
||||
return httpx.AsyncClient(
|
||||
event_hooks={"request": [async_hf_request_event_hook], "response": [async_hf_response_event_hook]},
|
||||
follow_redirects=True,
|
||||
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
|
||||
)
|
||||
|
||||
|
||||
CLIENT_FACTORY_T = Callable[[], httpx.Client]
|
||||
ASYNC_CLIENT_FACTORY_T = Callable[[], httpx.AsyncClient]
|
||||
|
||||
_CLIENT_LOCK = threading.Lock()
|
||||
_GLOBAL_CLIENT_FACTORY: CLIENT_FACTORY_T = default_client_factory
|
||||
_GLOBAL_ASYNC_CLIENT_FACTORY: ASYNC_CLIENT_FACTORY_T = default_async_client_factory
|
||||
_GLOBAL_CLIENT: Optional[httpx.Client] = None
|
||||
|
||||
|
||||
def set_client_factory(client_factory: CLIENT_FACTORY_T) -> None:
|
||||
"""
|
||||
Set the HTTP client factory to be used by `huggingface_hub`.
|
||||
|
||||
The client factory is a method that returns a `httpx.Client` object. On the first call to [`get_session`] the client factory
will be used to create a new `httpx.Client` object that will be shared between all calls made by `huggingface_hub`.
|
||||
|
||||
This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. a custom proxy or certificates).
|
||||
|
||||
Use [`get_session`] to get a correctly configured `httpx.Client`.
|
||||
"""
|
||||
global _GLOBAL_CLIENT_FACTORY
|
||||
with _CLIENT_LOCK:
|
||||
close_session()
|
||||
_GLOBAL_CLIENT_FACTORY = client_factory
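# Editor's note: an illustrative, hypothetical factory, e.g. to trust a corporate CA
# bundle. The certificate path is invented; any callable returning a `httpx.Client` works.
#
# >>> import httpx
# >>> def my_factory() -> httpx.Client:
# ...     return httpx.Client(verify="/etc/ssl/corp-ca.pem", follow_redirects=True)
#
# >>> set_client_factory(my_factory)
# >>> get_session()  # built by `my_factory` and shared by all huggingface_hub calls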
|
||||
|
||||
|
||||
def set_async_client_factory(async_client_factory: ASYNC_CLIENT_FACTORY_T) -> None:
|
||||
"""
|
||||
Set the HTTP async client factory to be used by `huggingface_hub`.
|
||||
|
||||
The async client factory is a method that returns a `httpx.AsyncClient` object.
|
||||
This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. a custom proxy or certificates).
|
||||
Use [`get_async_session`] to get a correctly configured `httpx.AsyncClient`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Contrary to the `httpx.Client` that is shared between all calls made by `huggingface_hub`, the `httpx.AsyncClient` is not shared.
|
||||
It is recommended to use an async context manager to ensure the client is properly closed when the context is exited.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
global _GLOBAL_ASYNC_CLIENT_FACTORY
|
||||
_GLOBAL_ASYNC_CLIENT_FACTORY = async_client_factory
|
||||
|
||||
|
||||
def get_session() -> httpx.Client:
|
||||
"""
|
||||
Get a `httpx.Client` object, using the transport factory from the user.
|
||||
|
||||
This client is shared between all calls made by `huggingface_hub`. Therefore you should not close it manually.
|
||||
|
||||
Use [`set_client_factory`] to customize the `httpx.Client`.
|
||||
"""
|
||||
global _GLOBAL_CLIENT
|
||||
if _GLOBAL_CLIENT is None:
|
||||
with _CLIENT_LOCK:
|
||||
_GLOBAL_CLIENT = _GLOBAL_CLIENT_FACTORY()
|
||||
return _GLOBAL_CLIENT
|
||||
|
||||
|
||||
def get_async_session() -> httpx.AsyncClient:
|
||||
"""
|
||||
Return a `httpx.AsyncClient` object, using the transport factory from the user.
|
||||
|
||||
Use [`set_async_client_factory`] to customize the `httpx.AsyncClient`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Contrary to the `httpx.Client` that is shared between all calls made by `huggingface_hub`, the `httpx.AsyncClient` is not shared.
|
||||
It is recommended to use an async context manager to ensure the client is properly closed when the context is exited.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
return _GLOBAL_ASYNC_CLIENT_FACTORY()
|
||||
|
||||
|
||||
def close_session() -> None:
|
||||
"""
|
||||
Close the global `httpx.Client` used by `huggingface_hub`.
|
||||
|
||||
If a Client is closed, it will be recreated on the next call to [`get_session`].
|
||||
|
||||
Can be useful if e.g. an SSL certificate has been updated.
|
||||
"""
|
||||
global _GLOBAL_CLIENT
|
||||
client = _GLOBAL_CLIENT
|
||||
|
||||
# First, set global client to None
|
||||
_GLOBAL_CLIENT = None
|
||||
|
||||
# Then, close the clients
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing client: {e}")
|
||||
|
||||
|
||||
atexit.register(close_session)
|
||||
if hasattr(os, "register_at_fork"):
|
||||
os.register_at_fork(after_in_child=close_session)
|
||||
|
||||
|
||||
def _http_backoff_base(
|
||||
method: HTTP_METHOD_T,
|
||||
url: str,
|
||||
*,
|
||||
max_retries: int = 5,
|
||||
base_wait_time: float = 1,
|
||||
max_wait_time: float = 8,
|
||||
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
),
|
||||
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Generator[httpx.Response, None, None]:
|
||||
"""Internal implementation of HTTP backoff logic shared between `http_backoff` and `http_stream_backoff`."""
|
||||
if isinstance(retry_on_exceptions, type): # Tuple from single exception type
|
||||
retry_on_exceptions = (retry_on_exceptions,)
|
||||
|
||||
if isinstance(retry_on_status_codes, int): # Tuple from single status code
|
||||
retry_on_status_codes = (retry_on_status_codes,)
|
||||
|
||||
nb_tries = 0
|
||||
sleep_time = base_wait_time
|
||||
|
||||
# If `data` is used and is a file object (or any IO), it will be consumed on the
|
||||
# first HTTP request. We need to save the initial position so that the full content
|
||||
# of the file is re-sent on http backoff. See warning tip in docstring.
|
||||
io_obj_initial_pos = None
|
||||
if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)):
|
||||
io_obj_initial_pos = kwargs["data"].tell()
|
||||
|
||||
client = get_session()
|
||||
while True:
|
||||
nb_tries += 1
|
||||
try:
|
||||
# If `data` is used and is a file object (or any IO), set back cursor to
|
||||
# initial position.
|
||||
if io_obj_initial_pos is not None:
|
||||
kwargs["data"].seek(io_obj_initial_pos)
|
||||
|
||||
# Perform request and handle response
|
||||
def _should_retry(response: httpx.Response) -> bool:
|
||||
"""Handle response and return True if should retry, False if should return/yield."""
|
||||
if response.status_code not in retry_on_status_codes:
|
||||
return False # Success, don't retry
|
||||
|
||||
# Wrong status code returned (HTTP 503 for instance)
|
||||
logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}")
|
||||
if nb_tries > max_retries:
|
||||
hf_raise_for_status(response) # Will raise uncaught exception
|
||||
# Return/yield response to avoid infinite loop in the corner case where the
|
||||
# user asks for retry on a status code that doesn't raise_for_status.
|
||||
return False # Don't retry, return/yield response
|
||||
|
||||
return True # Should retry
|
||||
|
||||
if stream:
|
||||
with client.stream(method=method, url=url, **kwargs) as response:
|
||||
if not _should_retry(response):
|
||||
yield response
|
||||
return
|
||||
else:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
if not _should_retry(response):
|
||||
yield response
|
||||
return
|
||||
|
||||
except retry_on_exceptions as err:
|
||||
logger.warning(f"'{err}' thrown while requesting {method} {url}")
|
||||
|
||||
if isinstance(err, httpx.ConnectError):
|
||||
close_session() # In case of SSLError it's best to close the shared httpx.Client objects
|
||||
|
||||
if nb_tries > max_retries:
|
||||
raise err
|
||||
|
||||
# Sleep for X seconds
|
||||
logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# Update sleep time for next retry
|
||||
sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
|
||||
|
||||
|
||||
def http_backoff(
|
||||
method: HTTP_METHOD_T,
|
||||
url: str,
|
||||
*,
|
||||
max_retries: int = 5,
|
||||
base_wait_time: float = 1,
|
||||
max_wait_time: float = 8,
|
||||
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
),
|
||||
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
|
||||
|
||||
Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
|
||||
and/or on specific status codes (ex: service unavailable). If the call failed more
|
||||
than `max_retries`, the exception is thrown or `raise_for_status` is called on the
|
||||
response object.
|
||||
|
||||
Re-implement mechanisms from the `backoff` library to avoid adding an external
dependency to `huggingface_hub`. See https://github.com/litl/backoff.
|
||||
|
||||
Args:
|
||||
method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
|
||||
HTTP method to perform.
|
||||
url (`str`):
|
||||
The URL of the resource to fetch.
|
||||
max_retries (`int`, *optional*, defaults to `5`):
|
||||
Maximum number of retries.
|
||||
base_wait_time (`float`, *optional*, defaults to `1`):
|
||||
Duration (in seconds) to wait before retrying the first time.
|
||||
Wait time between retries then grows exponentially, capped by
|
||||
`max_wait_time`.
|
||||
max_wait_time (`float`, *optional*, defaults to `8`):
|
||||
Maximum duration (in seconds) to wait before retrying.
|
||||
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
|
||||
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
|
||||
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
|
||||
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
|
||||
Define on which status codes the request must be retried. By default, only
|
||||
HTTP 503 Service Unavailable is retried.
|
||||
**kwargs (`dict`, *optional*):
|
||||
kwargs to pass to `httpx.request`.
|
||||
|
||||
Example:
|
||||
```
|
||||
>>> from huggingface_hub.utils import http_backoff
|
||||
|
||||
# Same usage as "httpx.request".
|
||||
>>> response = http_backoff("GET", "https://www.google.com")
|
||||
>>> response.raise_for_status()
|
||||
|
||||
# If you expect a Gateway Timeout from time to time
|
||||
>>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504)
|
||||
>>> response.raise_for_status()
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> When using `httpx` it is possible to stream data by passing an iterator to the
|
||||
> `data` argument. On http backoff this is a problem as the iterator is not reset
|
||||
> after a failed call. This issue is mitigated for file objects or any IO streams
|
||||
> by saving the initial position of the cursor (with `data.tell()`) and resetting the
|
||||
> cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
|
||||
> will fail. If this is a hard constraint for you, please let us know by opening an
|
||||
> issue on [Github](https://github.com/huggingface/huggingface_hub).
|
||||
"""
|
||||
return next(
|
||||
_http_backoff_base(
|
||||
method=method,
|
||||
url=url,
|
||||
max_retries=max_retries,
|
||||
base_wait_time=base_wait_time,
|
||||
max_wait_time=max_wait_time,
|
||||
retry_on_exceptions=retry_on_exceptions,
|
||||
retry_on_status_codes=retry_on_status_codes,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def http_stream_backoff(
|
||||
method: HTTP_METHOD_T,
|
||||
url: str,
|
||||
*,
|
||||
max_retries: int = 5,
|
||||
base_wait_time: float = 1,
|
||||
max_wait_time: float = 8,
|
||||
retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
),
|
||||
retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
**kwargs,
|
||||
) -> Generator[httpx.Response, None, None]:
|
||||
"""Wrapper around httpx to retry calls on an endpoint, with exponential backoff.
|
||||
|
||||
Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
|
||||
and/or on specific status codes (ex: service unavailable). If the call failed more
|
||||
than `max_retries`, the exception is thrown or `raise_for_status` is called on the
|
||||
response object.
|
||||
|
||||
Re-implement mechanisms from the `backoff` library to avoid adding an external
dependency to `huggingface_hub`. See https://github.com/litl/backoff.
|
||||
|
||||
Args:
|
||||
method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
|
||||
HTTP method to perform.
|
||||
url (`str`):
|
||||
The URL of the resource to fetch.
|
||||
max_retries (`int`, *optional*, defaults to `5`):
|
||||
Maximum number of retries.
|
||||
base_wait_time (`float`, *optional*, defaults to `1`):
|
||||
Duration (in seconds) to wait before retrying the first time.
|
||||
Wait time between retries then grows exponentially, capped by
|
||||
`max_wait_time`.
|
||||
max_wait_time (`float`, *optional*, defaults to `8`):
|
||||
Maximum duration (in seconds) to wait before retrying.
|
||||
retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*):
|
||||
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
|
||||
By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`.
|
||||
retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`):
|
||||
Define on which status codes the request must be retried. By default, only
|
||||
HTTP 503 Service Unavailable is retried.
|
||||
**kwargs (`dict`, *optional*):
|
||||
kwargs to pass to `httpx.request`.
|
||||
|
||||
Example:
|
||||
```
|
||||
>>> from huggingface_hub.utils import http_stream_backoff
|
||||
|
||||
# Same usage as "httpx.stream".
|
||||
>>> with http_stream_backoff("GET", "https://www.google.com") as response:
|
||||
... for chunk in response.iter_bytes():
|
||||
... print(chunk)
|
||||
|
||||
# If you expect a Gateway Timeout from time to time
|
||||
>>> with http_stream_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) as response:
|
||||
... response.raise_for_status()
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using `httpx` it is possible to stream data by passing an iterator to the
|
||||
`data` argument. On http backoff this is a problem as the iterator is not reset
|
||||
after a failed call. This issue is mitigated for file objects or any IO streams
|
||||
by saving the initial position of the cursor (with `data.tell()`) and resetting the
|
||||
cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
|
||||
will fail. If this is a hard constraint for you, please let us know by opening an
|
||||
issue on [Github](https://github.com/huggingface/huggingface_hub).
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
yield from _http_backoff_base(
|
||||
method=method,
|
||||
url=url,
|
||||
max_retries=max_retries,
|
||||
base_wait_time=base_wait_time,
|
||||
max_wait_time=max_wait_time,
|
||||
retry_on_exceptions=retry_on_exceptions,
|
||||
retry_on_status_codes=retry_on_status_codes,
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str:
|
||||
"""Replace the default endpoint in a URL by a custom one.
|
||||
|
||||
This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint.
|
||||
"""
|
||||
endpoint = endpoint.rstrip("/") if endpoint else constants.ENDPOINT
|
||||
# check if a proxy has been set => if yes, update the returned URL to use the proxy
|
||||
if endpoint not in (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT):
|
||||
url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint)
|
||||
url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint)
|
||||
return url
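# Editor's note: an illustrative call to `fix_hf_endpoint_in_url`. The proxy endpoint
# is hypothetical.
#
# >>> fix_hf_endpoint_in_url(
# ...     "https://huggingface.co/gpt2/resolve/main/config.json",
# ...     endpoint="https://hub-proxy.internal",
# ... )
# 'https://hub-proxy.internal/gpt2/resolve/main/config.json'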
|
||||
|
||||
|
||||
def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Internal version of `response.raise_for_status()` that will refine a potential HTTPError.
|
||||
Raised exception will be an instance of [`~errors.HfHubHTTPError`].
|
||||
|
||||
This helper is meant to be the unique method to raise_for_status when making a call to the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
response (`Response`):
|
||||
Response from the server.
|
||||
endpoint_name (`str`, *optional*):
|
||||
Name of the endpoint that has been called. If provided, the error message will be more complete.
|
||||
|
||||
> [!WARNING]
|
||||
> Raises when the request has failed:
|
||||
>
|
||||
> - [`~utils.RepositoryNotFoundError`]
|
||||
> If the repository to download from cannot be found. This may be because it
|
||||
> doesn't exist, because `repo_type` is not set correctly, or because the repo
|
||||
> is `private` and you do not have access.
|
||||
> - [`~utils.GatedRepoError`]
|
||||
> If the repository exists but is gated and the user is not on the authorized
|
||||
> list.
|
||||
> - [`~utils.RevisionNotFoundError`]
|
||||
> If the repository exists but the revision couldn't be found.
|
||||
> - [`~utils.EntryNotFoundError`]
|
||||
> If the repository exists but the entry (e.g. the requested file) couldn't be
> found.
|
||||
> - [`~utils.BadRequestError`]
|
||||
> If request failed with a HTTP 400 BadRequest error.
|
||||
> - [`~utils.HfHubHTTPError`]
|
||||
> If request failed for a reason not listed above.
|
||||
"""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if response.status_code // 100 == 3:
|
||||
return # Do not raise on redirects to stay consistent with `requests`
|
||||
|
||||
error_code = response.headers.get("X-Error-Code")
|
||||
error_message = response.headers.get("X-Error-Message")
|
||||
|
||||
if error_code == "RevisionNotFound":
|
||||
message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}."
|
||||
raise _format(RevisionNotFoundError, message, response) from e
|
||||
|
||||
elif error_code == "EntryNotFound":
|
||||
message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
|
||||
raise _format(RemoteEntryNotFoundError, message, response) from e
|
||||
|
||||
elif error_code == "GatedRepo":
|
||||
message = (
|
||||
f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}."
|
||||
)
|
||||
raise _format(GatedRepoError, message, response) from e
|
||||
|
||||
elif error_message == "Access to this resource is disabled.":
|
||||
message = (
|
||||
f"{response.status_code} Client Error."
|
||||
+ "\n\n"
|
||||
+ f"Cannot access repository for url {response.url}."
|
||||
+ "\n"
|
||||
+ "Access to this resource is disabled."
|
||||
)
|
||||
raise _format(DisabledRepoError, message, response) from e
|
||||
|
||||
elif error_code == "RepoNotFound" or (
|
||||
response.status_code == 401
|
||||
and error_message != "Invalid credentials in Authorization header"
|
||||
and response.request is not None
|
||||
and response.request.url is not None
|
||||
and REPO_API_REGEX.search(str(response.request.url)) is not None
|
||||
):
|
||||
# 401 is misleading as it is returned for:
|
||||
# - private and gated repos if user is not authenticated
|
||||
# - missing repos
|
||||
# => for now, we process them as `RepoNotFound` anyway.
|
||||
# See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
|
||||
message = (
|
||||
f"{response.status_code} Client Error."
|
||||
+ "\n\n"
|
||||
+ f"Repository Not Found for url: {response.url}."
|
||||
+ "\nPlease make sure you specified the correct `repo_id` and"
|
||||
" `repo_type`.\nIf you are trying to access a private or gated repo,"
|
||||
" make sure you are authenticated. For more details, see"
|
||||
" https://huggingface.co/docs/huggingface_hub/authentication"
|
||||
)
|
||||
raise _format(RepositoryNotFoundError, message, response) from e
|
||||
|
||||
elif response.status_code == 400:
|
||||
message = (
|
||||
f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:"
|
||||
)
|
||||
raise _format(BadRequestError, message, response) from e
|
||||
|
||||
elif response.status_code == 403:
|
||||
message = (
|
||||
f"\n\n{response.status_code} Forbidden: {error_message}."
|
||||
+ f"\nCannot access content at: {response.url}."
|
||||
+ "\nMake sure your token has the correct permissions."
|
||||
)
|
||||
raise _format(HfHubHTTPError, message, response) from e
|
||||
|
||||
elif response.status_code == 416:
|
||||
range_header = response.request.headers.get("Range")
|
||||
message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
|
||||
raise _format(HfHubHTTPError, message, response) from e
|
||||
|
||||
# Convert `HTTPError` into a `HfHubHTTPError` to display request information
|
||||
# as well (request id and/or server error message)
|
||||
raise _format(HfHubHTTPError, str(e), response) from e
|
||||
|
||||
|
||||
def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError:
|
||||
server_errors = []
|
||||
|
||||
# Retrieve server error from header
|
||||
from_headers = response.headers.get("X-Error-Message")
|
||||
if from_headers is not None:
|
||||
server_errors.append(from_headers)
|
||||
|
||||
# Retrieve server error from body
|
||||
try:
|
||||
# Case errors are returned in a JSON format
|
||||
try:
|
||||
data = response.json()
|
||||
except httpx.ResponseNotRead:
|
||||
try:
|
||||
response.read() # In case of streaming response, we need to read the response first
|
||||
data = response.json()
|
||||
except RuntimeError:
|
||||
# In case of async streaming response, we can't read the stream here.
|
||||
# In practice, if the user is using the default async client from `get_async_session`, the stream will have
|
||||
# already been read in the async event hook `async_hf_response_event_hook`.
|
||||
#
|
||||
# Here, we are skipping reading the response to avoid RuntimeError but it happens only if async + stream + used httpx.AsyncClient directly.
|
||||
data = {}
|
||||
|
||||
error = data.get("error")
|
||||
if error is not None:
|
||||
if isinstance(error, list):
|
||||
# Case {'error': ['my error 1', 'my error 2']}
|
||||
server_errors.extend(error)
|
||||
else:
|
||||
# Case {'error': 'my error'}
|
||||
server_errors.append(error)
|
||||
|
||||
errors = data.get("errors")
|
||||
if errors is not None:
|
||||
# Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]}
|
||||
for error in errors:
|
||||
if "message" in error:
|
||||
server_errors.append(error["message"])
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If content is not JSON and not HTML, append the text
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if response.text and "html" not in content_type.lower():
|
||||
server_errors.append(response.text)
|
||||
|
||||
# Strip all server messages
|
||||
server_errors = [str(line).strip() for line in server_errors if str(line).strip()]
|
||||
|
||||
# Deduplicate server messages (keep order)
|
||||
# taken from https://stackoverflow.com/a/17016257
|
||||
server_errors = list(dict.fromkeys(server_errors))
|
||||
|
||||
# Format server error
|
||||
server_message = "\n".join(server_errors)
|
||||
|
||||
# Add server error to custom message
|
||||
final_error_message = custom_message
|
||||
if server_message and server_message.lower() not in custom_message.lower():
|
||||
if "\n\n" in custom_message:
|
||||
final_error_message += "\n" + server_message
|
||||
else:
|
||||
final_error_message += "\n\n" + server_message
|
||||
# Add Request ID
|
||||
request_id = str(response.headers.get(X_REQUEST_ID, ""))
|
||||
if request_id:
|
||||
request_id_message = f" (Request ID: {request_id})"
|
||||
else:
|
||||
# Fallback to X-Amzn-Trace-Id
|
||||
request_id = str(response.headers.get(X_AMZN_TRACE_ID, ""))
|
||||
if request_id:
|
||||
request_id_message = f" (Amzn Trace ID: {request_id})"
|
||||
if request_id and request_id.lower() not in final_error_message.lower():
|
||||
if "\n" in final_error_message:
|
||||
newline_index = final_error_message.index("\n")
|
||||
final_error_message = (
|
||||
final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:]
|
||||
)
|
||||
else:
|
||||
final_error_message += request_id_message
|
||||
|
||||
# Return
|
||||
return error_type(final_error_message.strip(), response=response, server_message=server_message or None)
|
||||
|
||||
|
||||
def _curlify(request: httpx.Request) -> str:
|
||||
"""Convert a `httpx.Request` into a curl command (str).
|
||||
|
||||
Used for debug purposes only.
|
||||
|
||||
Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py.
|
||||
MIT License Copyright (c) 2016 Egor.
|
||||
"""
|
||||
parts: list[tuple[Any, Any]] = [
|
||||
("curl", None),
|
||||
("-X", request.method),
|
||||
]
|
||||
|
||||
for k, v in sorted(request.headers.items()):
|
||||
if k.lower() == "authorization":
|
||||
v = "<TOKEN>" # Hide authorization header, no matter its value (can be Bearer, Key, etc.)
|
||||
parts += [("-H", f"{k}: {v}")]
|
||||
|
||||
body: Optional[str] = None
|
||||
if request.content is not None:
|
||||
body = request.content.decode("utf-8", errors="ignore")
|
||||
if len(body) > 1000:
|
||||
body = f"{body[:1000]} ... [truncated]"
|
||||
elif request.stream is not None:
|
||||
body = "<streaming body>"
|
||||
if body is not None:
|
||||
parts += [("-d", body.replace("\n", ""))]
|
||||
|
||||
parts += [(None, request.url)]
|
||||
|
||||
flat_parts = []
|
||||
for k, v in parts:
|
||||
if k:
|
||||
flat_parts.append(quote(str(k)))
|
||||
if v:
|
||||
flat_parts.append(quote(str(v)))
|
||||
|
||||
return " ".join(flat_parts)
|
||||
|
||||
|
||||
# Regex to parse HTTP Range header
|
||||
RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE)
|
||||
|
||||
|
||||
def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]:
|
||||
"""
|
||||
Adjust HTTP Range header to account for resume position.
|
||||
"""
|
||||
if not original_range:
|
||||
return f"bytes={resume_size}-"
|
||||
|
||||
if "," in original_range:
|
||||
raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.")
|
||||
|
||||
match = RANGE_REGEX.match(original_range)
|
||||
if not match:
|
||||
raise RuntimeError(f"Invalid range format - {original_range!r}.")
|
||||
start, end = match.groups()
|
||||
|
||||
if not start:
|
||||
if not end:
|
||||
raise RuntimeError(f"Invalid range format - {original_range!r}.")
|
||||
|
||||
new_suffix = int(end) - resume_size
|
||||
new_range = f"bytes=-{new_suffix}"
|
||||
if new_suffix <= 0:
|
||||
raise RuntimeError(f"Empty new range - {new_range!r}.")
|
||||
return new_range
|
||||
|
||||
start = int(start)
|
||||
new_start = start + resume_size
|
||||
if end:
|
||||
end = int(end)
|
||||
new_range = f"bytes={new_start}-{end}"
|
||||
if new_start > end:
|
||||
raise RuntimeError(f"Empty new range - {new_range!r}.")
|
||||
return new_range
|
||||
|
||||
return f"bytes={new_start}-"
|
||||
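# Worked examples derived from the logic above (resume_size in bytes):
#   _adjust_range_header(None, 1024)           -> "bytes=1024-"
#   _adjust_range_header("bytes=0-4095", 100)  -> "bytes=100-4095"
#   _adjust_range_header("bytes=-500", 100)    -> "bytes=-400"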
@@ -0,0 +1,110 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Git LFS related utilities"""
|
||||
|
||||
import io
|
||||
import os
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import BinaryIO
|
||||
|
||||
|
||||
class SliceFileObj(AbstractContextManager):
|
||||
"""
|
||||
Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object.
|
||||
|
||||
This is NOT thread safe
|
||||
|
||||
Inspired by stackoverflow.com/a/29838711/593036
|
||||
|
||||
Credits to @julien-c
|
||||
|
||||
Args:
|
||||
fileobj (`BinaryIO`):
|
||||
A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course).
|
||||
`fileobj` will be reset to its original position when exiting the context manager.
|
||||
seek_from (`int`):
|
||||
The start of the slice (offset from position 0 in bytes).
|
||||
read_limit (`int`):
|
||||
The maximum number of bytes to read from the slice.
|
||||
|
||||
Attributes:
|
||||
previous_position (`int`):
|
||||
The previous position
|
||||
|
||||
Examples:
|
||||
|
||||
Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327):
|
||||
```python
|
||||
>>> with open("path/to/file", "rb") as file:
|
||||
... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice:
|
||||
... fslice.read(...)
|
||||
```
|
||||
|
||||
Reading a file in chunks of 512 bytes
|
||||
```python
|
||||
>>> import os
|
||||
>>> from math import ceil
|
||||
>>> chunk_size = 512
|
||||
>>> file_size = os.path.getsize("path/to/file")
|
||||
>>> with open("path/to/file", "rb") as file:
|
||||
... for chunk_idx in range(ceil(file_size / chunk_size)):
|
||||
... with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice:
|
||||
... chunk = fslice.read(...)
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int):
|
||||
self.fileobj = fileobj
|
||||
self.seek_from = seek_from
|
||||
self.read_limit = read_limit
|
||||
|
||||
def __enter__(self):
|
||||
self._previous_position = self.fileobj.tell()
|
||||
end_of_stream = self.fileobj.seek(0, os.SEEK_END)
|
||||
self._len = min(self.read_limit, end_of_stream - self.seek_from)
|
||||
# ^^ The actual number of bytes that can be read from the slice
|
||||
self.fileobj.seek(self.seek_from, io.SEEK_SET)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.fileobj.seek(self._previous_position, io.SEEK_SET)
|
||||
|
||||
def read(self, n: int = -1):
|
||||
pos = self.tell()
|
||||
if pos >= self._len:
|
||||
return b""
|
||||
remaining_amount = self._len - pos
|
||||
data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount))
|
||||
return data
|
||||
|
||||
def tell(self) -> int:
|
||||
return self.fileobj.tell() - self.seek_from
|
||||
|
||||
def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
|
||||
start = self.seek_from
|
||||
end = start + self._len
|
||||
if whence in (os.SEEK_SET, os.SEEK_END):
|
||||
offset = start + offset if whence == os.SEEK_SET else end + offset
|
||||
offset = max(start, min(offset, end))
|
||||
whence = os.SEEK_SET
|
||||
elif whence == os.SEEK_CUR:
|
||||
cur_pos = self.fileobj.tell()
|
||||
offset = max(start - cur_pos, min(offset, end - cur_pos))
|
||||
else:
|
||||
raise ValueError(f"whence value {whence} is not supported")
|
||||
return self.fileobj.seek(offset, whence) - self.seek_from
|
||||
|
||||
def __iter__(self):
|
||||
yield self.read(n=4 * 1024 * 1024)
|
||||
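# Quick sanity sketch (in-memory buffer; not part of the original module):
#   import io
#   buf = io.BytesIO(b"0123456789")
#   with SliceFileObj(buf, seek_from=2, read_limit=4) as s:
#       assert s.read() == b"2345"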
@@ -0,0 +1,52 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to handle pagination on Huggingface Hub."""
|
||||
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from . import get_session, hf_raise_for_status, http_backoff, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def paginate(path: str, params: dict, headers: dict) -> Iterable:
|
||||
"""Fetch a list of models/datasets/spaces and paginate through results.
|
||||
|
||||
This is using the same "Link" header format as GitHub.
|
||||
See:
|
||||
- https://requests.readthedocs.io/en/latest/api/#requests.Response.links
|
||||
- https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header
|
||||
"""
|
||||
session = get_session()
|
||||
r = session.get(path, params=params, headers=headers)
|
||||
hf_raise_for_status(r)
|
||||
yield from r.json()
|
||||
|
||||
# Follow pages
|
||||
# Next link already contains query params
|
||||
next_page = _get_next_page(r)
|
||||
while next_page is not None:
|
||||
logger.debug(f"Pagination detected. Requesting next page: {next_page}")
|
||||
r = http_backoff("GET", next_page, max_retries=20, retry_on_status_codes=429, headers=headers)
|
||||
hf_raise_for_status(r)
|
||||
yield from r.json()
|
||||
next_page = _get_next_page(r)
|
||||
|
||||
|
||||
def _get_next_page(response: httpx.Response) -> Optional[str]:
|
||||
return response.links.get("next", {}).get("url")
|
||||
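# Hypothetical usage sketch (endpoint and params are illustrative only):
#   for model in paginate("https://huggingface.co/api/models", params={"limit": 100}, headers={}):
#       print(model["id"])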
@@ -0,0 +1,98 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Parsing helpers shared across modules."""
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
|
||||
RE_NUMBER_WITH_UNIT = re.compile(r"(\d+)([a-z]+)", re.IGNORECASE)
|
||||
|
||||
BYTE_UNITS: Dict[str, int] = {
|
||||
"k": 1_000,
|
||||
"m": 1_000_000,
|
||||
"g": 1_000_000_000,
|
||||
"t": 1_000_000_000_000,
|
||||
"p": 1_000_000_000_000_000,
|
||||
}
|
||||
|
||||
TIME_UNITS: Dict[str, int] = {
|
||||
"s": 1,
|
||||
"m": 60,
|
||||
"h": 60 * 60,
|
||||
"d": 24 * 60 * 60,
|
||||
"w": 7 * 24 * 60 * 60,
|
||||
"mo": 30 * 24 * 60 * 60,
|
||||
"y": 365 * 24 * 60 * 60,
|
||||
}
|
||||
|
||||
|
||||
def parse_size(value: str) -> int:
|
||||
"""Parse a size expressed as a string with digits and unit (like `"10MB"`) to an integer (in bytes)."""
|
||||
return _parse_with_unit(value, BYTE_UNITS)
|
||||
|
||||
|
||||
def parse_duration(value: str) -> int:
|
||||
"""Parse a duration expressed as a string with digits and unit (like `"10s"`) to an integer (in seconds)."""
|
||||
return _parse_with_unit(value, TIME_UNITS)
|
||||
|
||||
|
||||
def _parse_with_unit(value: str, units: Dict[str, int]) -> int:
|
||||
"""Parse a numeric value with optional unit."""
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Value cannot be empty.")
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
match = RE_NUMBER_WITH_UNIT.fullmatch(stripped)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid value '{value}'. Must match pattern '\\d+[a-z]+' or be a plain number.")
|
||||
|
||||
number = int(match.group(1))
|
||||
unit = match.group(2).lower()
|
||||
|
||||
if unit not in units:
|
||||
raise ValueError(f"Unknown unit '{unit}'. Must be one of {list(units.keys())}.")
|
||||
|
||||
return number * units[unit]
|
||||
|
||||
|
||||
def format_timesince(ts: float) -> str:
|
||||
"""Format timestamp in seconds into a human-readable string, relative to now.
|
||||
|
||||
Vaguely inspired by Django's `timesince` formatter.
|
||||
"""
|
||||
_TIMESINCE_CHUNKS = (
|
||||
# Label, divider, max value
|
||||
("second", 1, 60),
|
||||
("minute", 60, 60),
|
||||
("hour", 60 * 60, 24),
|
||||
("day", 60 * 60 * 24, 6),
|
||||
("week", 60 * 60 * 24 * 7, 6),
|
||||
("month", 60 * 60 * 24 * 30, 11),
|
||||
("year", 60 * 60 * 24 * 365, None),
|
||||
)
|
||||
delta = time.time() - ts
|
||||
if delta < 20:
|
||||
return "a few seconds ago"
|
||||
for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007
|
||||
value = round(delta / divider)
|
||||
if max_value is not None and value <= max_value:
|
||||
break
|
||||
return f"{value} {label}{'s' if value > 1 else ''} ago"
|
||||
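# Illustrative values (derived from the unit tables above):
#   parse_size("10m")     -> 10_000_000 (bytes)
#   parse_duration("2h")  -> 7_200 (seconds)
#   format_timesince(time.time() - 90)  -> "2 minutes ago" (approximate)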
@@ -0,0 +1,141 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to handle paths in Huggingface Hub."""
|
||||
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, Iterable, Optional, TypeVar, Union
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Always ignore `.git` and `.cache/huggingface` folders in commits
|
||||
DEFAULT_IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".git/*",
|
||||
"*/.git",
|
||||
"**/.git/**",
|
||||
".cache/huggingface",
|
||||
".cache/huggingface/*",
|
||||
"*/.cache/huggingface",
|
||||
"**/.cache/huggingface/**",
|
||||
]
|
||||
# Forbidden to commit these folders
|
||||
FORBIDDEN_FOLDERS = [".git", ".cache"]
|
||||
|
||||
|
||||
def filter_repo_objects(
|
||||
items: Iterable[T],
|
||||
*,
|
||||
allow_patterns: Optional[Union[list[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[list[str], str]] = None,
|
||||
key: Optional[Callable[[T], str]] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""Filter repo objects based on an allowlist and a denylist.
|
||||
|
||||
Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
|
||||
In the latter case, `key` must be provided and specifies a function of one argument
|
||||
that is used to extract a path from each element of the iterable.
|
||||
|
||||
Patterns are Unix shell-style wildcards which are NOT regular expressions. See
|
||||
https://docs.python.org/3/library/fnmatch.html for more details.
|
||||
|
||||
Args:
|
||||
items (`Iterable`):
|
||||
List of items to filter.
|
||||
allow_patterns (`str` or `list[str]`, *optional*):
|
||||
Patterns constituting the allowlist. If provided, item paths must match at
|
||||
least one pattern from the allowlist.
|
||||
ignore_patterns (`str` or `list[str]`, *optional*):
|
||||
Patterns constituting the denylist. If provided, item paths must not match
|
||||
any patterns from the denylist.
|
||||
key (`Callable[[T], str]`, *optional*):
|
||||
Single-argument function to extract a path from each item. If not provided,
|
||||
the `items` must already be `str` or `Path`.
|
||||
|
||||
Returns:
|
||||
Filtered list of objects, as a generator.
|
||||
|
||||
Raises:
|
||||
:class:`ValueError`:
|
||||
If `key` is not provided and items are not `str` or `Path`.
|
||||
|
||||
Example usage with paths:
|
||||
```python
|
||||
>>> # Filter only PDFs that are not hidden.
|
||||
>>> list(filter_repo_objects(
|
||||
... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
|
||||
... allow_patterns=["*.pdf"],
|
||||
... ignore_patterns=[".*"],
|
||||
... ))
|
||||
["aaa.pdf"]
|
||||
```
|
||||
|
||||
Example usage with objects:
|
||||
```python
|
||||
>>> list(filter_repo_objects(
|
||||
... [
|
||||
... CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf"),
|
||||
... CommitOperationAdd(path_or_fileobj="/tmp/bbb.jpg", path_in_repo="bbb.jpg"),
|
||||
... CommitOperationAdd(path_or_fileobj="/tmp/.ccc.pdf", path_in_repo=".ccc.pdf"),
|
||||
... CommitOperationAdd(path_or_fileobj="/tmp/.ddd.png", path_in_repo=".ddd.png"),
|
||||
... ],
|
||||
... allow_patterns=["*.pdf"],
|
||||
... ignore_patterns=[".*"],
|
||||
... key=lambda x: x.path_in_repo
|
||||
... ))
|
||||
[CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf")]
|
||||
```
|
||||
"""
|
||||
if isinstance(allow_patterns, str):
|
||||
allow_patterns = [allow_patterns]
|
||||
|
||||
if isinstance(ignore_patterns, str):
|
||||
ignore_patterns = [ignore_patterns]
|
||||
|
||||
if allow_patterns is not None:
|
||||
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
|
||||
if ignore_patterns is not None:
|
||||
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
|
||||
|
||||
if key is None:
|
||||
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
|
||||
|
||||
key = _identity # Items must be `str` or `Path`, otherwise raise ValueError
|
||||
|
||||
for item in items:
|
||||
path = key(item)
|
||||
|
||||
# Skip if there's an allowlist and path doesn't match any
|
||||
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
|
||||
continue
|
||||
|
||||
# Skip if there's a denylist and path matches any
|
||||
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
|
||||
def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
if pattern[-1] == "/":
|
||||
return pattern + "*"
|
||||
return pattern
|
||||
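# Sketch: combining the module-level defaults with a hypothetical file list:
#   files = ["model.safetensors", ".git/config", ".cache/huggingface/token"]
#   list(filter_repo_objects(files, ignore_patterns=DEFAULT_IGNORE_PATTERNS))
#   # -> ["model.safetensors"]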
@@ -0,0 +1,431 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Check presence of installed packages at runtime."""
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from .. import __version__, constants
|
||||
|
||||
|
||||
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
|
||||
|
||||
_package_versions = {}
|
||||
|
||||
_CANDIDATES = {
|
||||
"aiohttp": {"aiohttp"},
|
||||
"fastai": {"fastai"},
|
||||
"fastapi": {"fastapi"},
|
||||
"fastcore": {"fastcore"},
|
||||
"gradio": {"gradio"},
|
||||
"graphviz": {"graphviz"},
|
||||
"hf_xet": {"hf_xet"},
|
||||
"jinja": {"Jinja2"},
|
||||
"httpx": {"httpx"},
|
||||
"keras": {"keras"},
|
||||
"numpy": {"numpy"},
|
||||
"pillow": {"Pillow"},
|
||||
"pydantic": {"pydantic"},
|
||||
"pydot": {"pydot"},
|
||||
"safetensors": {"safetensors"},
|
||||
"tensorboard": {"tensorboardX"},
|
||||
"tensorflow": (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
"tensorflow-gpu",
|
||||
"tf-nightly",
|
||||
"tf-nightly-cpu",
|
||||
"tf-nightly-gpu",
|
||||
"intel-tensorflow",
|
||||
"intel-tensorflow-avx512",
|
||||
"tensorflow-rocm",
|
||||
"tensorflow-macos",
|
||||
),
|
||||
"torch": {"torch"},
|
||||
}
|
||||
|
||||
# Check once at runtime
|
||||
for candidate_name, package_names in _CANDIDATES.items():
|
||||
_package_versions[candidate_name] = "N/A"
|
||||
for name in package_names:
|
||||
try:
|
||||
_package_versions[candidate_name] = importlib.metadata.version(name)
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def _get_version(package_name: str) -> str:
|
||||
return _package_versions.get(package_name, "N/A")
|
||||
|
||||
|
||||
def is_package_available(package_name: str) -> bool:
|
||||
return _get_version(package_name) != "N/A"
|
||||
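# Behaviour sketch (package names are illustrative; versions are resolved once at import time above):
#   is_package_available("numpy")   -> True if numpy metadata was found, else False
#   _get_version("not_a_package")   -> "N/A"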
|
||||
|
||||
# Python
|
||||
def get_python_version() -> str:
|
||||
return _PY_VERSION
|
||||
|
||||
|
||||
# Huggingface Hub
|
||||
def get_hf_hub_version() -> str:
|
||||
return __version__
|
||||
|
||||
|
||||
# aiohttp
|
||||
def is_aiohttp_available() -> bool:
|
||||
return is_package_available("aiohttp")
|
||||
|
||||
|
||||
def get_aiohttp_version() -> str:
|
||||
return _get_version("aiohttp")
|
||||
|
||||
|
||||
# FastAI
|
||||
def is_fastai_available() -> bool:
|
||||
return is_package_available("fastai")
|
||||
|
||||
|
||||
def get_fastai_version() -> str:
|
||||
return _get_version("fastai")
|
||||
|
||||
|
||||
# FastAPI
|
||||
def is_fastapi_available() -> bool:
|
||||
return is_package_available("fastapi")
|
||||
|
||||
|
||||
def get_fastapi_version() -> str:
|
||||
return _get_version("fastapi")
|
||||
|
||||
|
||||
# Fastcore
|
||||
def is_fastcore_available() -> bool:
|
||||
return is_package_available("fastcore")
|
||||
|
||||
|
||||
def get_fastcore_version() -> str:
|
||||
return _get_version("fastcore")
|
||||
|
||||
|
||||
# Gradio
|
||||
def is_gradio_available() -> bool:
|
||||
return is_package_available("gradio")
|
||||
|
||||
|
||||
def get_gradio_version() -> str:
|
||||
return _get_version("gradio")
|
||||
|
||||
|
||||
# Graphviz
|
||||
def is_graphviz_available() -> bool:
|
||||
return is_package_available("graphviz")
|
||||
|
||||
|
||||
def get_graphviz_version() -> str:
|
||||
return _get_version("graphviz")
|
||||
|
||||
|
||||
# httpx
|
||||
def is_httpx_available() -> bool:
|
||||
return is_package_available("httpx")
|
||||
|
||||
|
||||
def get_httpx_version() -> str:
|
||||
return _get_version("httpx")
|
||||
|
||||
|
||||
# xet
|
||||
def is_xet_available() -> bool:
|
||||
# since hf_xet is automatically used if available, allow explicit disabling via environment variable
|
||||
if constants.HF_HUB_DISABLE_XET:
|
||||
return False
|
||||
|
||||
return is_package_available("hf_xet")
|
||||
|
||||
|
||||
def get_xet_version() -> str:
|
||||
return _get_version("hf_xet")
|
||||
|
||||
|
||||
# keras
|
||||
def is_keras_available() -> bool:
|
||||
return is_package_available("keras")
|
||||
|
||||
|
||||
def get_keras_version() -> str:
|
||||
return _get_version("keras")
|
||||
|
||||
|
||||
# Numpy
|
||||
def is_numpy_available() -> bool:
|
||||
return is_package_available("numpy")
|
||||
|
||||
|
||||
def get_numpy_version() -> str:
|
||||
return _get_version("numpy")
|
||||
|
||||
|
||||
# Jinja
|
||||
def is_jinja_available() -> bool:
|
||||
return is_package_available("jinja")
|
||||
|
||||
|
||||
def get_jinja_version() -> str:
|
||||
return _get_version("jinja")
|
||||
|
||||
|
||||
# Pillow
|
||||
def is_pillow_available() -> bool:
|
||||
return is_package_available("pillow")
|
||||
|
||||
|
||||
def get_pillow_version() -> str:
|
||||
return _get_version("pillow")
|
||||
|
||||
|
||||
# Pydantic
|
||||
def is_pydantic_available() -> bool:
|
||||
if not is_package_available("pydantic"):
|
||||
return False
|
||||
# For Pydantic, we add an extra check to test whether it is correctly installed or not. If both pydantic 2.x and
|
||||
# typing_extensions<=4.5.0 are installed, then pydantic will fail at import time. This should not happen when
|
||||
# it is installed with `pip install huggingface_hub[inference]` but it can happen when it is installed manually
|
||||
# by the user in an environment that we don't control.
|
||||
#
|
||||
# Usually we won't need to do this kind of check on optional dependencies. However, pydantic is a special case
|
||||
# as it is automatically imported when doing `from huggingface_hub import ...` even if the user doesn't use it.
|
||||
#
|
||||
# See https://github.com/huggingface/huggingface_hub/pull/1829 for more details.
|
||||
try:
|
||||
from pydantic import validator # noqa: F401
|
||||
except ImportError as e:
|
||||
# Example: "ImportError: cannot import name 'TypeAliasType' from 'typing_extensions'"
|
||||
warnings.warn(
|
||||
"Pydantic is installed but cannot be imported. Please check your installation. `huggingface_hub` will "
|
||||
"default to not using Pydantic. Error message: '{e}'"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_pydantic_version() -> str:
|
||||
return _get_version("pydantic")
|
||||
|
||||
|
||||
# Pydot
|
||||
def is_pydot_available() -> bool:
|
||||
return is_package_available("pydot")
|
||||
|
||||
|
||||
def get_pydot_version() -> str:
|
||||
return _get_version("pydot")
|
||||
|
||||
|
||||
# Tensorboard
|
||||
def is_tensorboard_available() -> bool:
|
||||
return is_package_available("tensorboard")
|
||||
|
||||
|
||||
def get_tensorboard_version() -> str:
|
||||
return _get_version("tensorboard")
|
||||
|
||||
|
||||
# Tensorflow
|
||||
def is_tf_available() -> bool:
|
||||
return is_package_available("tensorflow")
|
||||
|
||||
|
||||
def get_tf_version() -> str:
|
||||
return _get_version("tensorflow")
|
||||
|
||||
|
||||
# Torch
|
||||
def is_torch_available() -> bool:
|
||||
return is_package_available("torch")
|
||||
|
||||
|
||||
def get_torch_version() -> str:
|
||||
return _get_version("torch")
|
||||
|
||||
|
||||
# Safetensors
|
||||
def is_safetensors_available() -> bool:
|
||||
return is_package_available("safetensors")
|
||||
|
||||
|
||||
# Shell-related helpers
|
||||
try:
|
||||
# Set to `True` if script is running in a Google Colab notebook.
|
||||
# If running in Google Colab, git credential store is set globally which makes the
|
||||
# warning disappear. See https://github.com/huggingface/huggingface_hub/issues/1043
|
||||
#
|
||||
# Taken from https://stackoverflow.com/a/63519730.
|
||||
_is_google_colab = "google.colab" in str(get_ipython()) # type: ignore # noqa: F821
|
||||
except NameError:
|
||||
_is_google_colab = False
|
||||
|
||||
|
||||
def is_notebook() -> bool:
|
||||
"""Return `True` if code is executed in a notebook (Jupyter, Colab, QTconsole).
|
||||
|
||||
Taken from https://stackoverflow.com/a/39662359.
|
||||
Adapted to make it work with Google colab as well.
|
||||
"""
|
||||
try:
|
||||
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
|
||||
for parent_class in shell_class.__mro__: # e.g. "is subclass of"
|
||||
if parent_class.__name__ == "ZMQInteractiveShell":
|
||||
return True # Jupyter notebook, Google colab or qtconsole
|
||||
return False
|
||||
except NameError:
|
||||
return False # Probably standard Python interpreter
|
||||
|
||||
|
||||
def is_google_colab() -> bool:
|
||||
"""Return `True` if code is executed in a Google colab.
|
||||
|
||||
Taken from https://stackoverflow.com/a/63519730.
|
||||
"""
|
||||
return _is_google_colab
|
||||
|
||||
|
||||
def is_colab_enterprise() -> bool:
|
||||
"""Return `True` if code is executed in a Google Colab Enterprise environment."""
|
||||
return os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE"
|
||||
|
||||
|
||||
# Check how huggingface_hub has been installed
|
||||
|
||||
|
||||
def installation_method() -> Literal["brew", "hf_installer", "unknown"]:
|
||||
"""Return the installation method of the current environment.
|
||||
|
||||
- "hf_installer" if installed via the official installer script
|
||||
- "brew" if installed via Homebrew
|
||||
- "unknown" otherwise
|
||||
"""
|
||||
if _is_brew_installation():
|
||||
return "brew"
|
||||
elif _is_hf_installer_installation():
|
||||
return "hf_installer"
|
||||
else:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _is_brew_installation() -> bool:
|
||||
"""Check if running from a Homebrew installation.
|
||||
|
||||
Note: AI-generated by Claude.
|
||||
"""
|
||||
exe_path = Path(sys.executable).resolve()
|
||||
exe_str = str(exe_path)
|
||||
|
||||
# Check common Homebrew paths
|
||||
# /opt/homebrew (Apple Silicon), /usr/local (Intel)
|
||||
return "/Cellar/" in exe_str or "/opt/homebrew/" in exe_str or exe_str.startswith("/usr/local/Cellar/")
|
||||
|
||||
|
||||
def _is_hf_installer_installation() -> bool:
|
||||
"""Return `True` if the current environment was set up via the official hf installer script.
|
||||
|
||||
i.e. using one of
|
||||
curl -LsSf https://hf.co/cli/install.sh | bash
|
||||
powershell -ExecutionPolicy ByPass -c "irm https://hf.co/cli/install.ps1 | iex"
|
||||
"""
|
||||
venv = sys.prefix # points to venv root if active
|
||||
marker = Path(venv) / ".hf_installer_marker"
|
||||
return marker.exists()
|
||||
|
||||
|
||||
def dump_environment_info() -> dict[str, Any]:
|
||||
"""Dump information about the machine to help debugging issues.
|
||||
|
||||
Similar helper exist in:
|
||||
- `datasets` (https://github.com/huggingface/datasets/blob/main/src/datasets/commands/env.py)
|
||||
- `diffusers` (https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/env.py)
|
||||
- `transformers` (https://github.com/huggingface/transformers/blob/main/src/transformers/commands/env.py)
|
||||
"""
|
||||
from huggingface_hub import get_token, whoami
|
||||
from huggingface_hub.utils import list_credential_helpers
|
||||
|
||||
token = get_token()
|
||||
|
||||
# Generic machine info
|
||||
info: dict[str, Any] = {
|
||||
"huggingface_hub version": get_hf_hub_version(),
|
||||
"Platform": platform.platform(),
|
||||
"Python version": get_python_version(),
|
||||
}
|
||||
|
||||
# Interpreter info
|
||||
try:
|
||||
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
|
||||
info["Running in iPython ?"] = "Yes"
|
||||
info["iPython shell"] = shell_class.__name__
|
||||
except NameError:
|
||||
info["Running in iPython ?"] = "No"
|
||||
info["Running in notebook ?"] = "Yes" if is_notebook() else "No"
|
||||
info["Running in Google Colab ?"] = "Yes" if is_google_colab() else "No"
|
||||
info["Running in Google Colab Enterprise ?"] = "Yes" if is_colab_enterprise() else "No"
|
||||
# Login info
|
||||
info["Token path ?"] = constants.HF_TOKEN_PATH
|
||||
info["Has saved token ?"] = token is not None
|
||||
if token is not None:
|
||||
try:
|
||||
info["Who am I ?"] = whoami()["name"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
info["Configured git credential helpers"] = ", ".join(list_credential_helpers())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# How huggingface_hub has been installed?
|
||||
info["Installation method"] = installation_method()
|
||||
|
||||
# Installed dependencies
|
||||
info["httpx"] = get_httpx_version()
|
||||
info["hf_xet"] = get_xet_version()
|
||||
info["gradio"] = get_gradio_version()
|
||||
info["tensorboard"] = get_tensorboard_version()
|
||||
|
||||
# Environment variables
|
||||
info["ENDPOINT"] = constants.ENDPOINT
|
||||
info["HF_HUB_CACHE"] = constants.HF_HUB_CACHE
|
||||
info["HF_ASSETS_CACHE"] = constants.HF_ASSETS_CACHE
|
||||
info["HF_TOKEN_PATH"] = constants.HF_TOKEN_PATH
|
||||
info["HF_STORED_TOKENS_PATH"] = constants.HF_STORED_TOKENS_PATH
|
||||
info["HF_HUB_OFFLINE"] = constants.HF_HUB_OFFLINE
|
||||
info["HF_HUB_DISABLE_TELEMETRY"] = constants.HF_HUB_DISABLE_TELEMETRY
|
||||
info["HF_HUB_DISABLE_PROGRESS_BARS"] = constants.HF_HUB_DISABLE_PROGRESS_BARS
|
||||
info["HF_HUB_DISABLE_SYMLINKS_WARNING"] = constants.HF_HUB_DISABLE_SYMLINKS_WARNING
|
||||
info["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING
|
||||
info["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = constants.HF_HUB_DISABLE_IMPLICIT_TOKEN
|
||||
info["HF_HUB_DISABLE_XET"] = constants.HF_HUB_DISABLE_XET
|
||||
info["HF_HUB_ETAG_TIMEOUT"] = constants.HF_HUB_ETAG_TIMEOUT
|
||||
info["HF_HUB_DOWNLOAD_TIMEOUT"] = constants.HF_HUB_DOWNLOAD_TIMEOUT
|
||||
info["HF_XET_HIGH_PERFORMANCE"] = constants.HF_XET_HIGH_PERFORMANCE
|
||||
|
||||
print("\nCopy-and-paste the text below in your GitHub issue.\n")
|
||||
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
|
||||
return info
|
||||
@@ -0,0 +1,111 @@
|
||||
import functools
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
FILENAME_T = str
|
||||
TENSOR_NAME_T = str
|
||||
DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
"""Information about a tensor.
|
||||
|
||||
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
|
||||
|
||||
Attributes:
|
||||
dtype (`str`):
|
||||
The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
|
||||
shape (`list[int]`):
|
||||
The shape of the tensor.
|
||||
data_offsets (`tuple[int, int]`):
|
||||
The offsets of the data in the file as a tuple `[BEGIN, END]`.
|
||||
parameter_count (`int`):
|
||||
The number of parameters in the tensor.
|
||||
"""
|
||||
|
||||
dtype: DTYPE_T
|
||||
shape: list[int]
|
||||
data_offsets: tuple[int, int]
|
||||
parameter_count: int = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Taken from https://stackoverflow.com/a/13840436
|
||||
try:
|
||||
self.parameter_count = functools.reduce(operator.mul, self.shape)
|
||||
except TypeError:
|
||||
self.parameter_count = 1 # scalar value has no shape
|
||||
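# Illustrative instance (offsets assume 4-byte F32 values):
#   TensorInfo(dtype="F32", shape=[768, 768], data_offsets=(0, 2_359_296)).parameter_count  -> 589_824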
|
||||
|
||||
@dataclass
|
||||
class SafetensorsFileMetadata:
|
||||
"""Metadata for a Safetensors file hosted on the Hub.
|
||||
|
||||
This class is returned by [`parse_safetensors_file_metadata`].
|
||||
|
||||
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
|
||||
|
||||
Attributes:
|
||||
metadata (`dict`):
|
||||
The metadata contained in the file.
|
||||
tensors (`dict[str, TensorInfo]`):
|
||||
A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a
|
||||
[`TensorInfo`] object.
|
||||
parameter_count (`dict[str, int]`):
|
||||
A map of the number of parameters per data type. Keys are data types and values are the number of parameters
|
||||
of that data type.
|
||||
"""
|
||||
|
||||
metadata: dict[str, str]
|
||||
tensors: dict[TENSOR_NAME_T, TensorInfo]
|
||||
parameter_count: dict[DTYPE_T, int] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
parameter_count: dict[DTYPE_T, int] = defaultdict(int)
|
||||
for tensor in self.tensors.values():
|
||||
parameter_count[tensor.dtype] += tensor.parameter_count
|
||||
self.parameter_count = dict(parameter_count)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SafetensorsRepoMetadata:
|
||||
"""Metadata for a Safetensors repo.
|
||||
|
||||
A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-shared
|
||||
model) or a 'model.safetensors.index.json' index file (sharded model) at its root.
|
||||
|
||||
This class is returned by [`get_safetensors_metadata`].
|
||||
|
||||
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
|
||||
|
||||
Attributes:
|
||||
metadata (`dict`, *optional*):
|
||||
The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded
|
||||
models.
|
||||
sharded (`bool`):
|
||||
Whether the repo contains a sharded model or not.
|
||||
weight_map (`dict[str, str]`):
|
||||
A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors.
|
||||
files_metadata (`dict[str, SafetensorsFileMetadata]`):
|
||||
A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as
|
||||
a [`SafetensorsFileMetadata`] object.
|
||||
parameter_count (`dict[str, int]`):
|
||||
A map of the number of parameters per data type. Keys are data types and values are the number of parameters
|
||||
of that data type.
|
||||
"""
|
||||
|
||||
metadata: Optional[dict]
|
||||
sharded: bool
|
||||
weight_map: dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename
|
||||
files_metadata: dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata
|
||||
parameter_count: dict[DTYPE_T, int] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
parameter_count: dict[DTYPE_T, int] = defaultdict(int)
|
||||
for file_metadata in self.files_metadata.values():
|
||||
for dtype, nb_parameters_ in file_metadata.parameter_count.items():
|
||||
parameter_count[dtype] += nb_parameters_
|
||||
self.parameter_count = dict(parameter_count)
|
||||
@@ -0,0 +1,144 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
"""Contains utilities to easily handle subprocesses in `huggingface_hub`."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import IO, Generator, Optional, Union
|
||||
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def capture_output() -> Generator[StringIO, None, None]:
|
||||
"""Capture output that is printed to terminal.
|
||||
|
||||
Taken from https://stackoverflow.com/a/34738440
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> with capture_output() as output:
|
||||
... print("hello world")
|
||||
>>> assert output.getvalue() == "hello world\n"
|
||||
```
|
||||
"""
|
||||
output = StringIO()
|
||||
previous_output = sys.stdout
|
||||
sys.stdout = output
|
||||
try:
|
||||
yield output
|
||||
finally:
|
||||
sys.stdout = previous_output
|
||||
|
||||
|
||||
def run_subprocess(
|
||||
command: Union[str, list[str]],
|
||||
folder: Optional[Union[str, Path]] = None,
|
||||
check=True,
|
||||
**kwargs,
|
||||
) -> subprocess.CompletedProcess:
|
||||
"""
|
||||
Method to run subprocesses. Calling this will capture the `stderr` and `stdout`;
|
||||
call `subprocess.run` directly if you prefer that they not
|
||||
be captured.
|
||||
|
||||
Args:
|
||||
command (`str` or `list[str]`):
|
||||
The command to execute as a string or list of strings.
|
||||
folder (`str`, *optional*):
|
||||
The folder in which to run the command. Defaults to current working
|
||||
directory (from `os.getcwd()`).
|
||||
check (`bool`, *optional*, defaults to `True`):
|
||||
Setting `check` to `True` will raise a `subprocess.CalledProcessError`
|
||||
when the subprocess has a non-zero exit code.
|
||||
kwargs (`dict[str]`):
|
||||
Keyword arguments to be passed to the `subprocess.run` underlying command.
|
||||
|
||||
Returns:
|
||||
`subprocess.CompletedProcess`: The completed process.
|
||||
"""
|
||||
if isinstance(command, str):
|
||||
command = command.split()
|
||||
|
||||
if isinstance(folder, Path):
|
||||
folder = str(folder)
|
||||
|
||||
return subprocess.run(
|
||||
command,
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
check=check,
|
||||
encoding="utf-8",
|
||||
errors="replace", # if not utf-8, replace char by <20>
|
||||
cwd=folder or os.getcwd(),
|
||||
**kwargs,
|
||||
)
|
||||
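# Hedged usage sketch (command and folder are illustrative):
#   completed = run_subprocess("git status", folder="/tmp")
#   print(completed.stdout)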
|
||||
|
||||
@contextmanager
|
||||
def run_interactive_subprocess(
|
||||
command: Union[str, list[str]],
|
||||
folder: Optional[Union[str, Path]] = None,
|
||||
**kwargs,
|
||||
) -> Generator[tuple[IO[str], IO[str]], None, None]:
|
||||
"""Run a subprocess in an interactive mode in a context manager.
|
||||
|
||||
Args:
|
||||
command (`str` or `list[str]`):
|
||||
The command to execute as a string or list of strings.
|
||||
folder (`str`, *optional*):
|
||||
The folder in which to run the command. Defaults to current working
|
||||
directory (from `os.getcwd()`).
|
||||
kwargs (`dict[str]`):
|
||||
Keyword arguments to be passed to the `subprocess.run` underlying command.
|
||||
|
||||
Returns:
|
||||
`tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact
|
||||
with the process (input and output are utf-8 encoded).
|
||||
|
||||
Example:
|
||||
```python
|
||||
with run_interactive_subprocess("git credential-store get") as (stdin, stdout):
|
||||
# Write to stdin
|
||||
stdin.write("url=hf.co\nusername=obama\n")  # pipes are opened in text mode (encoding="utf-8")
|
||||
stdin.flush()
|
||||
|
||||
# Read from stdout
|
||||
output = stdout.read()
|
||||
```
|
||||
"""
|
||||
if isinstance(command, str):
|
||||
command = command.split()
|
||||
|
||||
with subprocess.Popen(
|
||||
command,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
encoding="utf-8",
|
||||
errors="replace", # if not utf-8, replace char by <20>
|
||||
cwd=folder or os.getcwd(),
|
||||
**kwargs,
|
||||
) as process:
|
||||
assert process.stdin is not None, "subprocess is opened as subprocess.PIPE"
|
||||
assert process.stdout is not None, "subprocess is opened as subprocess.PIPE"
|
||||
yield process.stdin, process.stdout
|
||||
@@ -0,0 +1,126 @@
|
||||
from queue import Queue
|
||||
from threading import Lock, Thread
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from .. import constants, logging
|
||||
from . import build_hf_headers, get_session, hf_raise_for_status
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Telemetry is sent by a separate thread to avoid blocking the main thread.
|
||||
# A daemon thread is started once and consume tasks from the _TELEMETRY_QUEUE.
|
||||
# If the thread stops for some reason (which shouldn't happen), a new one is started.
|
||||
_TELEMETRY_THREAD: Optional[Thread] = None
|
||||
_TELEMETRY_THREAD_LOCK = Lock() # Lock to avoid starting multiple threads in parallel
|
||||
_TELEMETRY_QUEUE: Queue = Queue()
|
||||
|
||||
|
||||
def send_telemetry(
|
||||
topic: str,
|
||||
*,
|
||||
library_name: Optional[str] = None,
|
||||
library_version: Optional[str] = None,
|
||||
user_agent: Union[dict, str, None] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sends telemetry that helps track usage of different HF libraries.
|
||||
|
||||
This usage data helps us debug issues and prioritize new features. However, we understand that not everyone wants
|
||||
to share additional information, and we respect your privacy. You can disable telemetry collection by setting the
|
||||
`HF_HUB_DISABLE_TELEMETRY=1` as environment variable. Telemetry is also disabled in offline mode (i.e. when setting
|
||||
`HF_HUB_OFFLINE=1`).
|
||||
|
||||
Telemetry collection is run in a separate thread to minimize impact for the user.
|
||||
|
||||
Args:
|
||||
topic (`str`):
|
||||
Name of the topic that is monitored. The topic is directly used to build the URL. If you want to monitor
|
||||
subtopics, just use "/" separation. Examples: "gradio", "transformers/examples",...
|
||||
library_name (`str`, *optional*):
|
||||
The name of the library that is making the HTTP request. Will be added to the user-agent header.
|
||||
library_version (`str`, *optional*):
|
||||
The version of the library that is making the HTTP request. Will be added to the user-agent header.
|
||||
user_agent (`str`, `dict`, *optional*):
|
||||
The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages.
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> from huggingface_hub.utils import send_telemetry
|
||||
|
||||
# Send telemetry without library information
|
||||
>>> send_telemetry("ping")
|
||||
|
||||
# Send telemetry to subtopic with library information
|
||||
>>> send_telemetry("gradio/local_link", library_name="gradio", library_version="3.22.1")
|
||||
|
||||
# Send telemetry with additional data
|
||||
>>> send_telemetry(
|
||||
... topic="examples",
|
||||
... library_name="transformers",
|
||||
... library_version="4.26.0",
|
||||
... user_agent={"pipeline": "text_classification", "framework": "flax"},
|
||||
... )
|
||||
```
|
||||
"""
|
||||
if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY:
|
||||
return
|
||||
|
||||
_start_telemetry_thread() # starts thread only if doesn't exist yet
|
||||
_TELEMETRY_QUEUE.put(
|
||||
{"topic": topic, "library_name": library_name, "library_version": library_version, "user_agent": user_agent}
|
||||
)
|
||||
|
||||
|
||||
def _start_telemetry_thread():
|
||||
"""Start a daemon thread to consume tasks from the telemetry queue.
|
||||
|
||||
If the thread is interrupted, start a new one.
|
||||
"""
|
||||
with _TELEMETRY_THREAD_LOCK:  # avoid starting multiple threads if called concurrently
|
||||
global _TELEMETRY_THREAD
|
||||
if _TELEMETRY_THREAD is None or not _TELEMETRY_THREAD.is_alive():
|
||||
_TELEMETRY_THREAD = Thread(target=_telemetry_worker, daemon=True)
|
||||
_TELEMETRY_THREAD.start()
|
||||
|
||||
|
||||
def _telemetry_worker():
|
||||
"""Wait for a task and consume it."""
|
||||
while True:
|
||||
kwargs = _TELEMETRY_QUEUE.get()
|
||||
_send_telemetry_in_thread(**kwargs)
|
||||
_TELEMETRY_QUEUE.task_done()
|
||||
|
||||
|
||||
def _send_telemetry_in_thread(
|
||||
topic: str,
|
||||
*,
|
||||
library_name: Optional[str] = None,
|
||||
library_version: Optional[str] = None,
|
||||
user_agent: Union[dict, str, None] = None,
|
||||
) -> None:
|
||||
"""Contains the actual data sending data to the Hub.
|
||||
|
||||
This function is called directly in gradio's analytics because
|
||||
it is not possible to send telemetry from a daemon thread.
|
||||
|
||||
See here: https://github.com/gradio-app/gradio/pull/8180
|
||||
|
||||
Please do not rename or remove this function.
|
||||
"""
|
||||
path = "/".join(quote(part) for part in topic.split("/") if len(part) > 0)
|
||||
try:
|
||||
r = get_session().head(
|
||||
f"{constants.ENDPOINT}/api/telemetry/{path}",
|
||||
headers=build_hf_headers(
|
||||
token=False, # no need to send a token for telemetry
|
||||
library_name=library_name,
|
||||
library_version=library_version,
|
||||
user_agent=user_agent,
|
||||
),
|
||||
)
|
||||
hf_raise_for_status(r)
|
||||
except Exception as e:
|
||||
# We don't want to error in case of connection errors of any kind.
|
||||
logger.debug(f"Error while sending telemetry: {e}")
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to print stuff to the terminal (styling, helpers)."""
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ANSI:
|
||||
"""
|
||||
Helper for en.wikipedia.org/wiki/ANSI_escape_code
|
||||
"""
|
||||
|
||||
_bold = "\u001b[1m"
|
||||
_gray = "\u001b[90m"
|
||||
_red = "\u001b[31m"
|
||||
_reset = "\u001b[0m"
|
||||
_yellow = "\u001b[33m"
|
||||
|
||||
@classmethod
|
||||
def bold(cls, s: str) -> str:
|
||||
return cls._format(s, cls._bold)
|
||||
|
||||
@classmethod
|
||||
def gray(cls, s: str) -> str:
|
||||
return cls._format(s, cls._gray)
|
||||
|
||||
@classmethod
|
||||
def red(cls, s: str) -> str:
|
||||
return cls._format(s, cls._bold + cls._red)
|
||||
|
||||
@classmethod
|
||||
def yellow(cls, s: str) -> str:
|
||||
return cls._format(s, cls._yellow)
|
||||
|
||||
@classmethod
|
||||
def _format(cls, s: str, code: str) -> str:
|
||||
if os.environ.get("NO_COLOR"):
|
||||
# See https://no-color.org/
|
||||
return s
|
||||
return f"{code}{s}{cls._reset}"
|
||||
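# Quick sketch (colors are suppressed when the NO_COLOR environment variable is set, as handled above):
#   print(ANSI.red("error:"), ANSI.gray("repo not found"))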
|
||||
|
||||
def tabulate(rows: list[list[Union[str, int]]], headers: list[str]) -> str:
|
||||
"""
|
||||
Inspired by:
|
||||
|
||||
- stackoverflow.com/a/8356620/593036
|
||||
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
|
||||
"""
|
||||
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
|
||||
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
|
||||
lines = []
|
||||
lines.append(row_format.format(*headers))
|
||||
lines.append(row_format.format(*["-" * w for w in col_widths]))
|
||||
for row in rows:
|
||||
lines.append(row_format.format(*row))
|
||||
return "\n".join(lines)
|
||||
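# Minimal sketch (rows and headers are illustrative):
#   print(tabulate(rows=[["bert-base", 420], ["gpt2", 124]], headers=["model", "size_mb"]))
#   # model     size_mb
#   # --------- -------
#   # bert-base 420
#   # gpt2      124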
@@ -0,0 +1,95 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Handle typing imports based on system compatibility."""
|
||||
|
||||
import sys
|
||||
from typing import Any, Callable, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
|
||||
|
||||
|
||||
UNION_TYPES: list[Any] = [Union]
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
|
||||
UNION_TYPES += [UnionType]
|
||||
|
||||
|
||||
HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
|
||||
|
||||
# type hint meaning "function signature not changed by decorator"
|
||||
CallableT = TypeVar("CallableT", bound=Callable)
|
||||
|
||||
_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None))
|
||||
|
||||
|
||||
def is_jsonable(obj: Any, _visited: Optional[set[int]] = None) -> bool:
|
||||
"""Check if an object is JSON serializable.
|
||||
|
||||
This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object.
|
||||
It works correctly for basic use cases but does not guarantee an exhaustive check.
|
||||
|
||||
Object is considered to be recursively json serializable if:
|
||||
- it is an instance of int, float, str, bool, or NoneType
|
||||
- it is a list or tuple and all its items are json serializable
|
||||
- it is a dict and all its keys are strings and all its values are json serializable
|
||||
|
||||
Uses a visited set to avoid infinite recursion on circular references. If object has already been visited, it is
|
||||
considered not json serializable.
|
||||
"""
|
||||
# Initialize visited set to track object ids and detect circular references
|
||||
if _visited is None:
|
||||
_visited = set()
|
||||
|
||||
# Detect circular reference
|
||||
obj_id = id(obj)
|
||||
if obj_id in _visited:
|
||||
return False
|
||||
|
||||
# Add current object to visited before recursive checks
|
||||
_visited.add(obj_id)
|
||||
try:
|
||||
if isinstance(obj, _JSON_SERIALIZABLE_TYPES):
|
||||
return True
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return all(is_jsonable(item, _visited) for item in obj)
|
||||
if isinstance(obj, dict):
|
||||
return all(
|
||||
isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value, _visited)
|
||||
for key, value in obj.items()
|
||||
)
|
||||
if hasattr(obj, "__json__"):
|
||||
return True
|
||||
return False
|
||||
except RecursionError:
|
||||
return False
|
||||
finally:
|
||||
# Remove the object id from visited to avoid side-effects for other branches
|
||||
_visited.discard(obj_id)
|
||||
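# Behaviour sketch:
#   is_jsonable({"a": [1, 2.5, None, "x"]})  -> True
#   d = {}; d["self"] = d; is_jsonable(d)    -> False (circular reference detected)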
|
||||
|
||||
def is_simple_optional_type(type_: Type) -> bool:
|
||||
"""Check if a type is optional, i.e. Optional[Type] or Union[Type, None] or Type | None, where Type is a non-composite type."""
|
||||
if get_origin(type_) in UNION_TYPES:
|
||||
union_args = get_args(type_)
|
||||
if len(union_args) == 2 and type(None) in union_args:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def unwrap_simple_optional_type(optional_type: Type) -> Type:
|
||||
"""Unwraps a simple optional type, i.e. returns Type from Optional[Type]."""
|
||||
for arg in get_args(optional_type):
|
||||
if arg is not type(None):
|
||||
return arg
|
||||
raise ValueError(f"'{optional_type}' is not an optional type")
|
||||
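# Sketch:
#   is_simple_optional_type(Optional[int])      -> True
#   is_simple_optional_type(Union[int, str])    -> False
#   unwrap_simple_optional_type(Optional[int])  -> int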
@@ -0,0 +1,207 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains utilities to validate argument values in `huggingface_hub`."""
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub.errors import HFValidationError
|
||||
|
||||
from ._typing import CallableT
|
||||
|
||||
|
||||
REPO_ID_REGEX = re.compile(
|
||||
r"""
|
||||
^
|
||||
(\b[\w\-.]+\b/)? # optional namespace (username or organization)
|
||||
\b # starts with a word boundary
|
||||
[\w\-.]{1,96} # repo_name: alphanumeric + . _ -
|
||||
\b # ends with a word boundary
|
||||
$
|
||||
""",
|
||||
flags=re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
def validate_hf_hub_args(fn: CallableT) -> CallableT:
|
||||
"""Validate values received as argument for any public method of `huggingface_hub`.
|
||||
|
||||
The goal of this decorator is to harmonize validation of arguments reused
|
||||
everywhere. By default, all defined validators are tested.
|
||||
|
||||
Validators:
|
||||
- [`~utils.validate_repo_id`]: `repo_id` must be `"repo_name"`
|
||||
or `"namespace/repo_name"`. Namespace is a username or an organization.
|
||||
- [`~utils.smoothly_deprecate_legacy_arguments`]: Ignore `proxies` when downloading files (should be set globally).
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
>>> @validate_hf_hub_args
|
||||
... def my_cool_method(repo_id: str):
|
||||
... print(repo_id)
|
||||
|
||||
>>> my_cool_method(repo_id="valid_repo_id")
|
||||
valid_repo_id
|
||||
|
||||
>>> my_cool_method("other..repo..id")
|
||||
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
|
||||
|
||||
>>> my_cool_method(repo_id="other..repo..id")
|
||||
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
|
||||
```
|
||||
|
||||
Raises:
|
||||
[`~utils.HFValidationError`]:
|
||||
If an input is not valid.
|
||||
"""
|
||||
# TODO: add an argument to opt-out validation for specific argument?
|
||||
signature = inspect.signature(fn)
|
||||
|
||||
@wraps(fn)
|
||||
def _inner_fn(*args, **kwargs):
|
||||
for arg_name, arg_value in chain(
|
||||
zip(signature.parameters, args), # Args values
|
||||
kwargs.items(), # Kwargs values
|
||||
):
|
||||
if arg_name in ["repo_id", "from_id", "to_id"]:
|
||||
validate_repo_id(arg_value)
|
||||
|
||||
kwargs = smoothly_deprecate_legacy_arguments(fn_name=fn.__name__, kwargs=kwargs)
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _inner_fn # type: ignore
|
||||
|
||||
|
||||
def validate_repo_id(repo_id: str) -> None:
|
||||
"""Validate `repo_id` is valid.
|
||||
|
||||
This is not meant to replace the proper validation made on the Hub but rather to
|
||||
avoid local inconsistencies whenever possible (example: passing `repo_type` in the
|
||||
`repo_id` is forbidden).
|
||||
|
||||
Rules:
|
||||
- Between 1 and 96 characters.
|
||||
- Either "repo_name" or "namespace/repo_name"
|
||||
- [a-zA-Z0-9] or "-", "_", "."
|
||||
- "--" and ".." are forbidden
|
||||
|
||||
Valid: `"foo"`, `"foo/bar"`, `"123"`, `"Foo-BAR_foo.bar123"`
|
||||
|
||||
Not valid: `"datasets/foo/bar"`, `".repo_id"`, `"foo--bar"`, `"foo.git"`
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> from huggingface_hub.utils import validate_repo_id
|
||||
>>> validate_repo_id(repo_id="valid_repo_id")
|
||||
>>> validate_repo_id(repo_id="other..repo..id")
|
||||
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
|
||||
```
|
||||
|
||||
Discussed in https://github.com/huggingface/huggingface_hub/issues/1008.
|
||||
In moon-landing (internal repository):
|
||||
- https://github.com/huggingface/moon-landing/blob/main/server/lib/Names.ts#L27
|
||||
- https://github.com/huggingface/moon-landing/blob/main/server/views/components/NewRepoForm/NewRepoForm.svelte#L138
|
||||
"""
|
||||
if not isinstance(repo_id, str):
|
||||
# Typically, a Path is not a repo_id
|
||||
raise HFValidationError(f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'.")
|
||||
|
||||
if repo_id.count("/") > 1:
|
||||
raise HFValidationError(
|
||||
"Repo id must be in the form 'repo_name' or 'namespace/repo_name':"
|
||||
f" '{repo_id}'. Use `repo_type` argument if needed."
|
||||
)
|
||||
|
||||
if not REPO_ID_REGEX.match(repo_id):
|
||||
raise HFValidationError(
|
||||
"Repo id must use alphanumeric chars, '-', '_' or '.'."
|
||||
" The name cannot start or end with '-' or '.' and the maximum length is 96:"
|
||||
f" '{repo_id}'."
|
||||
)
|
||||
|
||||
if "--" in repo_id or ".." in repo_id:
|
||||
raise HFValidationError(f"Cannot have -- or .. in repo_id: '{repo_id}'.")
|
||||
|
||||
if repo_id.endswith(".git"):
|
||||
raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.")
|
||||
|
||||
|
||||
def smoothly_deprecate_legacy_arguments(fn_name: str, kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Smoothly deprecate legacy arguments in the `huggingface_hub` codebase.
|
||||
|
||||
This function ignores some deprecated arguments from the kwargs and warns the user they are ignored.
|
||||
The goal is to avoid breaking existing code while guiding the user to the new way of doing things.
|
||||
|
||||
List of deprecated arguments:
|
||||
- `proxies`:
|
||||
To set up proxies, users must either use the HTTP_PROXY environment variable or configure the `httpx.Client`
|
||||
manually using the [`set_client_factory`] function.
|
||||
|
||||
In huggingface_hub 0.x, `proxies` was a dictionary directly passed to `requests.request`.
|
||||
In huggingface_hub 1.x, we migrated to `httpx` which does not support `proxies` the same way.
|
||||
In particular, it is not possible to configure proxies on a per-request basis. The solution is to configure
|
||||
it globally using the [`set_client_factory`] function or using the HTTP_PROXY environment variable.
|
||||
|
||||
For more details, see:
|
||||
- https://www.python-httpx.org/advanced/proxies/
|
||||
- https://www.python-httpx.org/compatibility/#proxy-keys.
|
||||
|
||||
- `resume_download`: deprecated without replacement. `huggingface_hub` always resumes downloads whenever possible.
|
||||
- `force_filename`: deprecated without replacement. Filename is always the same as on the Hub.
|
||||
- `local_dir_use_symlinks`: deprecated without replacement. Downloading to a local directory does not use symlinks anymore.
|
||||
"""
|
||||
new_kwargs = kwargs.copy() # do not mutate input !
|
||||
|
||||
# proxies
|
||||
proxies = new_kwargs.pop("proxies", None) # remove from kwargs
|
||||
if proxies is not None:
|
||||
warnings.warn(
|
||||
f"The `proxies` argument is ignored in `{fn_name}`. To set up proxies, use the HTTP_PROXY / HTTPS_PROXY"
|
||||
" environment variables or configure the `httpx.Client` manually using `huggingface_hub.set_client_factory`."
|
||||
" See https://www.python-httpx.org/advanced/proxies/ for more details."
|
||||
)
|
||||
|
||||
# resume_download
|
||||
resume_download = new_kwargs.pop("resume_download", None) # remove from kwargs
|
||||
if resume_download is not None:
|
||||
warnings.warn(
|
||||
f"The `resume_download` argument is deprecated and ignored in `{fn_name}`. Downloads always resume"
|
||||
" whenever possible."
|
||||
)
|
||||
|
||||
# force_filename
|
||||
force_filename = new_kwargs.pop("force_filename", None) # remove from kwargs
|
||||
if force_filename is not None:
|
||||
warnings.warn(
|
||||
f"The `force_filename` argument is deprecated and ignored in `{fn_name}`. Filename is always the same "
|
||||
"as on the Hub."
|
||||
)
|
||||
|
||||
# local_dir_use_symlinks
|
||||
local_dir_use_symlinks = new_kwargs.pop("local_dir_use_symlinks", None) # remove from kwargs
|
||||
if local_dir_use_symlinks is not None:
|
||||
warnings.warn(
|
||||
f"The `local_dir_use_symlinks` argument is deprecated and ignored in `{fn_name}`. Downloading to a local"
|
||||
" directory does not use symlinks anymore."
|
||||
)
|
||||
|
||||
return new_kwargs
|
||||
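# Hedged usage sketch (not part of the library): since per-call `proxies` kwargs are
# now ignored, a caller configures proxies globally instead. The docstring above
# references `huggingface_hub.set_client_factory` and `httpx.Client`; the factory is
# assumed to be a zero-argument callable returning the client, and the `proxy`
# keyword is assumed to be supported by the installed httpx version.
def _example_configure_global_proxy(proxy_url: str = "http://localhost:3128") -> None:
    import httpx

    from huggingface_hub import set_client_factory

    # Every session created by `get_session()` will now route through the proxy.
    set_client_factory(lambda: httpx.Client(proxy=proxy_url))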
@@ -0,0 +1,167 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
|
||||
|
||||
from .. import constants
|
||||
from ..file_download import repo_folder_name
|
||||
from .sha import git_hash, sha_fileobj
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hf_api import RepoFile, RepoFolder
|
||||
|
||||
# using fullmatch for clarity and strictness
|
||||
_REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
|
||||
|
||||
|
||||
# Typed structure describing a checksum mismatch
|
||||
class Mismatch(TypedDict):
|
||||
path: str
|
||||
expected: str
|
||||
actual: str
|
||||
algorithm: str
|
||||
|
||||
|
||||
HashAlgo = Literal["sha256", "git-sha1"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FolderVerification:
|
||||
revision: str
|
||||
checked_count: int
|
||||
mismatches: list[Mismatch]
|
||||
missing_paths: list[str]
|
||||
extra_paths: list[str]
|
||||
verified_path: Path
|
||||
|
||||
|
||||
def collect_local_files(root: Path) -> dict[str, Path]:
|
||||
"""
|
||||
Return a mapping of repo-relative path -> absolute path for all files under `root`.
|
||||
"""
|
||||
return {p.relative_to(root).as_posix(): p for p in root.rglob("*") if p.is_file()}
|
||||
|
||||
|
||||
def _resolve_commit_hash_from_cache(storage_folder: Path, revision: Optional[str]) -> str:
|
||||
"""
|
||||
Resolve a commit hash from a cache repo folder and an optional revision.
|
||||
"""
|
||||
if revision and _REGEX_COMMIT_HASH.fullmatch(revision):
|
||||
return revision
|
||||
|
||||
refs_dir = storage_folder / "refs"
|
||||
snapshots_dir = storage_folder / "snapshots"
|
||||
|
||||
if revision:
|
||||
ref_path = refs_dir / revision
|
||||
if ref_path.is_file():
|
||||
return ref_path.read_text(encoding="utf-8").strip()
|
||||
raise ValueError(f"Revision '{revision}' could not be resolved in cache (expected file '{ref_path}').")
|
||||
|
||||
# No revision provided: try common defaults
|
||||
main_ref = refs_dir / "main"
|
||||
if main_ref.is_file():
|
||||
return main_ref.read_text(encoding="utf-8").strip()
|
||||
|
||||
if not snapshots_dir.is_dir():
|
||||
raise ValueError(f"Cache repo is missing snapshots directory: {snapshots_dir}. Provide --revision explicitly.")
|
||||
|
||||
candidates = [p.name for p in snapshots_dir.iterdir() if p.is_dir() and _REGEX_COMMIT_HASH.fullmatch(p.name)]
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
|
||||
raise ValueError(
|
||||
"Ambiguous cached revision: multiple snapshots found and no refs to disambiguate. Please pass --revision."
|
||||
)
|
||||
|
||||
|
||||
def compute_file_hash(path: Path, algorithm: HashAlgo) -> str:
|
||||
"""
|
||||
Compute the checksum of a local file using the requested algorithm.
|
||||
"""
|
||||
|
||||
with path.open("rb") as stream:
|
||||
if algorithm == "sha256":
|
||||
return sha_fileobj(stream).hex()
|
||||
if algorithm == "git-sha1":
|
||||
return git_hash(stream.read())
|
||||
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
|
||||
|
||||
|
||||
def verify_maps(
|
||||
*,
|
||||
remote_by_path: dict[str, Union["RepoFile", "RepoFolder"]],
|
||||
local_by_path: dict[str, Path],
|
||||
revision: str,
|
||||
verified_path: Path,
|
||||
) -> FolderVerification:
|
||||
"""Compare remote entries and local files and return a verification result."""
|
||||
remote_paths = set(remote_by_path)
|
||||
local_paths = set(local_by_path)
|
||||
|
||||
missing = sorted(remote_paths - local_paths)
|
||||
extra = sorted(local_paths - remote_paths)
|
||||
both = sorted(remote_paths & local_paths)
|
||||
|
||||
mismatches: list[Mismatch] = []
|
||||
|
||||
for rel_path in both:
|
||||
remote_entry = remote_by_path[rel_path]
|
||||
local_path = local_by_path[rel_path]
|
||||
|
||||
lfs = getattr(remote_entry, "lfs", None)
|
||||
lfs_sha = getattr(lfs, "sha256", None) if lfs is not None else None
|
||||
if lfs_sha is None and isinstance(lfs, dict):
|
||||
lfs_sha = lfs.get("sha256")
|
||||
if lfs_sha:
|
||||
algorithm: HashAlgo = "sha256"
|
||||
expected = str(lfs_sha).lower()
|
||||
else:
|
||||
blob_id = remote_entry.blob_id # type: ignore
|
||||
algorithm = "git-sha1"
|
||||
expected = str(blob_id).lower()
|
||||
|
||||
actual = compute_file_hash(local_path, algorithm)
|
||||
|
||||
if actual != expected:
|
||||
mismatches.append(Mismatch(path=rel_path, expected=expected, actual=actual, algorithm=algorithm))
|
||||
|
||||
return FolderVerification(
|
||||
revision=revision,
|
||||
checked_count=len(both),
|
||||
mismatches=mismatches,
|
||||
missing_paths=missing,
|
||||
extra_paths=extra,
|
||||
verified_path=verified_path,
|
||||
)
|
||||
|
||||
|
||||
def resolve_local_root(
|
||||
*,
|
||||
repo_id: str,
|
||||
repo_type: str,
|
||||
revision: Optional[str],
|
||||
cache_dir: Optional[Path],
|
||||
local_dir: Optional[Path],
|
||||
) -> tuple[Path, str]:
|
||||
"""
|
||||
Resolve the root directory to scan locally and the remote revision to verify.
|
||||
"""
|
||||
if local_dir is not None:
|
||||
root = Path(local_dir).expanduser().resolve()
|
||||
if not root.is_dir():
|
||||
raise ValueError(f"Local directory does not exist or is not a directory: {root}")
|
||||
return root, (revision or constants.DEFAULT_REVISION)
|
||||
|
||||
cache_root = Path(cache_dir or constants.HF_HUB_CACHE).expanduser().resolve()
|
||||
storage_folder = cache_root / repo_folder_name(repo_id=repo_id, repo_type=repo_type)
|
||||
if not storage_folder.exists():
|
||||
raise ValueError(
|
||||
f"Repo is not present in cache: {storage_folder}. Use 'hf download' first or pass --local-dir."
|
||||
)
|
||||
commit = _resolve_commit_hash_from_cache(storage_folder, revision)
|
||||
snapshot_dir = storage_folder / "snapshots" / commit
|
||||
if not snapshot_dir.is_dir():
|
||||
raise ValueError(f"Snapshot directory does not exist for revision '{commit}': {snapshot_dir}.")
|
||||
return snapshot_dir, commit
|
||||
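# Hedged end-to-end sketch (not part of the library) of how these helpers compose.
# `HfApi.list_repo_tree` is assumed to yield the `RepoFile` / `RepoFolder` entries
# referenced in the TYPE_CHECKING import above (only file entries carry a `blob_id`),
# and `expand=True` is assumed to be needed for LFS checksums.
def _example_verify_cached_repo(repo_id: str = "gpt2") -> FolderVerification:
    from ..hf_api import HfApi

    root, revision = resolve_local_root(
        repo_id=repo_id, repo_type="model", revision=None, cache_dir=None, local_dir=None
    )
    remote_by_path = {
        entry.path: entry
        for entry in HfApi().list_repo_tree(repo_id, revision=revision, recursive=True, expand=True)
        if getattr(entry, "blob_id", None) is not None  # keep files, skip folders
    }
    return verify_maps(
        remote_by_path=remote_by_path,
        local_by_path=collect_local_files(root),
        revision=revision,
        verified_path=root,
    )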
@@ -0,0 +1,235 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .. import constants
|
||||
from . import get_session, hf_raise_for_status, validate_hf_hub_args
|
||||
|
||||
|
||||
XET_CONNECTION_INFO_SAFETY_PERIOD = 60 # seconds
|
||||
XET_CONNECTION_INFO_CACHE_SIZE = 1_000
|
||||
XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = {}
|
||||
|
||||
|
||||
class XetTokenType(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class XetFileData:
|
||||
file_hash: str
|
||||
refresh_route: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class XetConnectionInfo:
|
||||
access_token: str
|
||||
expiration_unix_epoch: int
|
||||
endpoint: str
|
||||
|
||||
|
||||
def parse_xet_file_data_from_response(
|
||||
response: httpx.Response, endpoint: Optional[str] = None
|
||||
) -> Optional[XetFileData]:
|
||||
"""
|
||||
Parse XET file metadata from an HTTP response.
|
||||
|
||||
This function extracts XET file metadata from the HTTP headers or HTTP links
|
||||
of a given response object. If the required metadata is not found, it returns `None`.
|
||||
|
||||
Args:
|
||||
response (`httpx.Response`):
|
||||
The HTTP response object whose headers and links are inspected for the XET metadata.
|
||||
Returns:
|
||||
`Optional[XetFileData]`:
|
||||
An instance of `XetFileData` containing the file hash and refresh route if the metadata
|
||||
is found. Returns `None` if the required metadata is missing.
|
||||
"""
|
||||
if response is None:
|
||||
return None
|
||||
try:
|
||||
file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]
|
||||
|
||||
if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links:
|
||||
refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"]
|
||||
else:
|
||||
refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
|
||||
except KeyError:
|
||||
return None
|
||||
endpoint = endpoint if endpoint is not None else constants.ENDPOINT
|
||||
if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
|
||||
refresh_route = refresh_route.replace(constants.HUGGINGFACE_CO_URL_HOME.rstrip("/"), endpoint.rstrip("/"))
|
||||
return XetFileData(
|
||||
file_hash=file_hash,
|
||||
refresh_route=refresh_route,
|
||||
)
|
||||
|
||||
|
||||
def parse_xet_connection_info_from_headers(headers: dict[str, str]) -> Optional[XetConnectionInfo]:
|
||||
"""
|
||||
Parse XET connection info from the HTTP headers or return None if not found.
|
||||
Args:
|
||||
headers (`dict`):
|
||||
HTTP headers to extract the XET metadata from.
|
||||
Returns:
|
||||
`XetConnectionInfo` or `None`:
|
||||
The information needed to connect to the XET storage service.
|
||||
Returns `None` if the headers do not contain the XET connection info.
|
||||
"""
|
||||
try:
|
||||
endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
|
||||
access_token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
|
||||
expiration_unix_epoch = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
|
||||
except (KeyError, ValueError, TypeError):
|
||||
return None
|
||||
|
||||
return XetConnectionInfo(
|
||||
endpoint=endpoint,
|
||||
access_token=access_token,
|
||||
expiration_unix_epoch=expiration_unix_epoch,
|
||||
)
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def refresh_xet_connection_info(
|
||||
*,
|
||||
file_data: XetFileData,
|
||||
headers: dict[str, str],
|
||||
) -> XetConnectionInfo:
|
||||
"""
|
||||
Utilizes the information in the parsed metadata to request the Hub xet connection information.
|
||||
This includes the access token, expiration, and XET service URL.
|
||||
Args:
|
||||
file_data (`XetFileData`):
|
||||
The file data needed to refresh the xet connection information.
|
||||
headers (`dict[str, str]`):
|
||||
Headers to use for the request, including authorization headers and user agent.
|
||||
Returns:
|
||||
`XetConnectionInfo`:
|
||||
The connection information needed to make the request to the xet storage service.
|
||||
Raises:
|
||||
[`~utils.HfHubHTTPError`]
|
||||
If the Hub API returned an error.
|
||||
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
If the Hub API response is improperly formatted.
|
||||
"""
|
||||
if file_data.refresh_route is None:
|
||||
raise ValueError("The provided xet metadata does not contain a refresh endpoint.")
|
||||
return _fetch_xet_connection_info_with_url(file_data.refresh_route, headers)
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def fetch_xet_connection_info_from_repo_info(
|
||||
*,
|
||||
token_type: XetTokenType,
|
||||
repo_id: str,
|
||||
repo_type: str,
|
||||
revision: Optional[str] = None,
|
||||
headers: dict[str, str],
|
||||
endpoint: Optional[str] = None,
|
||||
params: Optional[dict[str, str]] = None,
|
||||
) -> XetConnectionInfo:
|
||||
"""
|
||||
Uses the repo info to request a xet access token from Hub.
|
||||
Args:
|
||||
token_type (`XetTokenType`):
|
||||
Type of the token to request: `"read"` or `"write"`.
|
||||
repo_id (`str`):
|
||||
A namespace (user or an organization) and a repo name separated by a `/`.
|
||||
repo_type (`str`):
|
||||
Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
|
||||
revision (`str`, `optional`):
|
||||
The revision of the repo to get the token for.
|
||||
headers (`dict[str, str]`):
|
||||
Headers to use for the request, including authorization headers and user agent.
|
||||
endpoint (`str`, `optional`):
|
||||
The endpoint to use for the request. Defaults to the Hub endpoint.
|
||||
params (`dict[str, str]`, `optional`):
|
||||
Additional parameters to pass with the request.
|
||||
Returns:
|
||||
`XetConnectionInfo`:
|
||||
The connection information needed to make the request to the xet storage service.
|
||||
Raises:
|
||||
[`~utils.HfHubHTTPError`]
|
||||
If the Hub API returned an error.
|
||||
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
If the Hub API response is improperly formatted.
|
||||
"""
|
||||
endpoint = endpoint if endpoint is not None else constants.ENDPOINT
|
||||
url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token/{revision}"
|
||||
return _fetch_xet_connection_info_with_url(url, headers, params)
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def _fetch_xet_connection_info_with_url(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
params: Optional[dict[str, str]] = None,
|
||||
) -> XetConnectionInfo:
|
||||
"""
|
||||
Requests the xet connection info from the supplied URL. This includes the
|
||||
access token, expiration time, and endpoint to use for the xet storage service.
|
||||
|
||||
Result is cached to avoid redundant requests.
|
||||
|
||||
Args:
|
||||
url (`str`):
|
||||
The access token endpoint URL.
|
||||
headers (`dict[str, str]`):
|
||||
Headers to use for the request, including authorization headers and user agent.
|
||||
params (`dict[str, str]`, `optional`):
|
||||
Additional parameters to pass with the request.
|
||||
Returns:
|
||||
`XetConnectionInfo`:
|
||||
The connection information needed to make the request to the xet storage service.
|
||||
Raises:
|
||||
[`~utils.HfHubHTTPError`]
|
||||
If the Hub API returned an error.
|
||||
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
If the Hub API response is improperly formatted.
|
||||
"""
|
||||
# Check cache first
|
||||
cache_key = _cache_key(url, headers, params)
|
||||
cached_info = XET_CONNECTION_INFO_CACHE.get(cache_key)
|
||||
if cached_info is not None:
|
||||
if not _is_expired(cached_info):
|
||||
return cached_info
|
||||
|
||||
# Fetch from server
|
||||
resp = get_session().get(headers=headers, url=url, params=params)
|
||||
hf_raise_for_status(resp)
|
||||
|
||||
metadata = parse_xet_connection_info_from_headers(resp.headers) # type: ignore
|
||||
if metadata is None:
|
||||
raise ValueError("Xet headers have not been correctly set by the server.")
|
||||
|
||||
# Delete expired cache entries
|
||||
for k, v in list(XET_CONNECTION_INFO_CACHE.items()):
|
||||
if _is_expired(v):
|
||||
XET_CONNECTION_INFO_CACHE.pop(k, None)
|
||||
|
||||
# Enforce cache size limit
|
||||
if len(XET_CONNECTION_INFO_CACHE) >= XET_CONNECTION_INFO_CACHE_SIZE:
|
||||
XET_CONNECTION_INFO_CACHE.pop(next(iter(XET_CONNECTION_INFO_CACHE)))
|
||||
|
||||
# Update cache
|
||||
XET_CONNECTION_INFO_CACHE[cache_key] = metadata
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def _cache_key(url: str, headers: dict[str, str], params: Optional[dict[str, str]]) -> str:
|
||||
"""Return a unique cache key for the given request parameters."""
|
||||
lower_headers = {k.lower(): v for k, v in headers.items()} # casing is not guaranteed here
|
||||
auth_header = lower_headers.get("authorization", "")
|
||||
params_str = "&".join(f"{k}={v}" for k, v in sorted((params or {}).items(), key=lambda x: x[0]))
|
||||
return f"{url}|{auth_header}|{params_str}"
|
||||
|
||||
|
||||
def _is_expired(connection_info: XetConnectionInfo) -> bool:
|
||||
"""Check if the given XET connection info is expired."""
|
||||
return connection_info.expiration_unix_epoch <= int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD
|
||||
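# Hedged sketch (not part of the library): two otherwise identical requests that
# differ only in header-name casing map to the same cache entry, because
# `_cache_key` lowercases header names before reading the Authorization value.
def _example_cache_key_ignores_header_casing() -> bool:
    key_a = _cache_key("https://example.com/api/token", {"Authorization": "Bearer abc"}, None)
    key_b = _cache_key("https://example.com/api/token", {"authorization": "Bearer abc"}, None)
    return key_a == key_b  # True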
@@ -0,0 +1,162 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate
|
||||
|
||||
from . import is_google_colab, is_notebook
|
||||
from .tqdm import tqdm
|
||||
|
||||
|
||||
class XetProgressReporter:
|
||||
"""
|
||||
Reports on progress for Xet uploads.
|
||||
|
||||
Shows only summary progress bars when running in (non-Colab) notebooks or GUIs, and detailed per-file progress in console environments and Google Colab.
|
||||
"""
|
||||
|
||||
def __init__(self, n_lines: int = 10, description_width: int = 30):
|
||||
self.n_lines = n_lines
|
||||
self.description_width = description_width
|
||||
|
||||
self.per_file_progress = is_google_colab() or not is_notebook()
|
||||
|
||||
self.tqdm_settings = {
|
||||
"unit": "B",
|
||||
"unit_scale": True,
|
||||
"leave": True,
|
||||
"unit_divisor": 1000,
|
||||
"nrows": n_lines + 3 if self.per_file_progress else 3,
|
||||
"miniters": 1,
|
||||
"bar_format": "{l_bar}{bar}| {n_fmt:>5}B / {total_fmt:>5}B{postfix:>12}",
|
||||
}
|
||||
|
||||
# Overall progress bars
|
||||
self.data_processing_bar = tqdm(
|
||||
total=0, desc=self.format_desc("Processing Files (0 / 0)", False), position=0, **self.tqdm_settings
|
||||
)
|
||||
|
||||
self.upload_bar = tqdm(
|
||||
total=0, desc=self.format_desc("New Data Upload", False), position=1, **self.tqdm_settings
|
||||
)
|
||||
|
||||
self.known_items: set[str] = set()
|
||||
self.completed_items: set[str] = set()
|
||||
|
||||
# Item bars (scrolling view)
|
||||
self.item_state: OrderedDict[str, PyItemProgressUpdate] = OrderedDict()
|
||||
self.current_bars: List = [None] * self.n_lines
|
||||
|
||||
def format_desc(self, name: str, indent: bool) -> str:
|
||||
"""
|
||||
If `name` is longer than the available width, prefix it with "..." followed by its last `width - 3` characters; otherwise
|
||||
left-justify the whole name within `description_width` characters. Also adds a small indent when requested.
|
||||
"""
|
||||
|
||||
if not self.per_file_progress:
|
||||
# Here we just use the defaults.
|
||||
return name
|
||||
|
||||
padding = " " if indent else ""
|
||||
width = self.description_width - len(padding)
|
||||
|
||||
if len(name) > width:
|
||||
name = f"...{name[-(width - 3) :]}"
|
||||
|
||||
return f"{padding}{name.ljust(width)}"
|
||||
|
||||
def update_progress(self, total_update: PyTotalProgressUpdate, item_updates: list[PyItemProgressUpdate]):
|
||||
# Update all the per-item values.
|
||||
for item in item_updates:
|
||||
item_name = item.item_name
|
||||
|
||||
self.known_items.add(item_name)
|
||||
|
||||
# Only care about items where the processing has already started.
|
||||
if item.bytes_completed == 0:
|
||||
continue
|
||||
|
||||
# Overwrite the existing value in there.
|
||||
self.item_state[item_name] = item
|
||||
|
||||
bar_idx = 0
|
||||
new_completed = []
|
||||
|
||||
# Now, go through and update all the bars
|
||||
for name, item in self.item_state.items():
|
||||
# Is this ready to be removed on the next update?
|
||||
if item.bytes_completed == item.total_bytes:
|
||||
self.completed_items.add(name)
|
||||
new_completed.append(name)
|
||||
|
||||
# If we're only showing summary information, then don't update the individual bars
|
||||
if not self.per_file_progress:
|
||||
continue
|
||||
|
||||
# If we've run out of bars to use, then collapse the last ones together.
|
||||
if bar_idx >= len(self.current_bars):
|
||||
bar = self.current_bars[-1]
|
||||
in_final_bar_mode = True
|
||||
final_bar_aggregation_count = bar_idx + 1 - len(self.current_bars)
|
||||
else:
|
||||
bar = self.current_bars[bar_idx]
|
||||
in_final_bar_mode = False
|
||||
|
||||
if bar is None:
|
||||
self.current_bars[bar_idx] = tqdm(
|
||||
desc=self.format_desc(name, True),
|
||||
position=2 + bar_idx, # Set to the position past the initial bars.
|
||||
total=item.total_bytes,
|
||||
initial=item.bytes_completed,
|
||||
**self.tqdm_settings,
|
||||
)
|
||||
|
||||
elif in_final_bar_mode:
|
||||
bar.n += item.bytes_completed
|
||||
bar.total += item.total_bytes
|
||||
bar.set_description(self.format_desc(f"[+ {final_bar_aggregation_count} files]", True), refresh=False)
|
||||
else:
|
||||
bar.set_description(self.format_desc(name, True), refresh=False)
|
||||
bar.n = item.bytes_completed
|
||||
bar.total = item.total_bytes
|
||||
|
||||
bar_idx += 1
|
||||
|
||||
# Remove all the completed ones from the ordered dictionary
|
||||
for name in new_completed:
|
||||
# Only remove ones from consideration to make room for more items coming in.
|
||||
if len(self.item_state) <= self.n_lines:
|
||||
break
|
||||
|
||||
del self.item_state[name]
|
||||
|
||||
if self.per_file_progress:
|
||||
# Now manually refresh each of the bars
|
||||
for bar in self.current_bars:
|
||||
if bar:
|
||||
bar.refresh()
|
||||
|
||||
# Update overall bars
|
||||
def postfix(speed):
|
||||
s = tqdm.format_sizeof(speed) if speed is not None else "???"
|
||||
return f"{s}B/s ".rjust(10, " ")
|
||||
|
||||
self.data_processing_bar.total = total_update.total_bytes
|
||||
self.data_processing_bar.set_description(
|
||||
self.format_desc(f"Processing Files ({len(self.completed_items)} / {len(self.known_items)})", False),
|
||||
refresh=False,
|
||||
)
|
||||
self.data_processing_bar.set_postfix_str(postfix(total_update.total_bytes_completion_rate), refresh=False)
|
||||
self.data_processing_bar.update(total_update.total_bytes_completion_increment)
|
||||
|
||||
self.upload_bar.total = total_update.total_transfer_bytes
|
||||
self.upload_bar.set_postfix_str(postfix(total_update.total_transfer_bytes_completion_rate), refresh=False)
|
||||
self.upload_bar.update(total_update.total_transfer_bytes_completion_increment)
|
||||
|
||||
def close(self, _success):
|
||||
self.data_processing_bar.close()
|
||||
self.upload_bar.close()
|
||||
|
||||
if self.per_file_progress:
|
||||
for bar in self.current_bars:
|
||||
if bar:
|
||||
bar.close()
|
||||
@@ -0,0 +1,66 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Helpful utility functions and classes in relation to exploring API endpoints
|
||||
with the aim for a user-friendly interface.
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..repocard_data import ModelCardData
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hf_api import ModelInfo
|
||||
|
||||
|
||||
def _is_emission_within_threshold(model_info: "ModelInfo", minimum_threshold: float, maximum_threshold: float) -> bool:
|
||||
"""Checks if a model's emission is within a given threshold.
|
||||
|
||||
Args:
|
||||
model_info (`ModelInfo`):
|
||||
A model info object containing the model's emission information.
|
||||
minimum_threshold (`float`):
|
||||
A minimum carbon threshold to filter by, such as 1.
|
||||
maximum_threshold (`float`):
|
||||
A maximum carbon threshold to filter by, such as 10.
|
||||
|
||||
Returns:
|
||||
`bool`: Whether the model's emission is within the given threshold.
|
||||
"""
|
||||
if minimum_threshold is None and maximum_threshold is None:
|
||||
raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`")
|
||||
if minimum_threshold is None:
|
||||
minimum_threshold = -1
|
||||
if maximum_threshold is None:
|
||||
maximum_threshold = math.inf
|
||||
|
||||
card_data = getattr(model_info, "card_data", None)
|
||||
if card_data is None or not isinstance(card_data, (dict, ModelCardData)):
|
||||
return False
|
||||
|
||||
# Get CO2 emission metadata
|
||||
emission = card_data.get("co2_eq_emissions", None)
|
||||
if isinstance(emission, dict):
|
||||
emission = emission["emissions"]
|
||||
if not emission:
|
||||
return False
|
||||
|
||||
# Filter out if value is missing or out of range
|
||||
matched = re.search(r"\d+\.\d+|\d+", str(emission))
|
||||
if matched is None:
|
||||
return False
|
||||
|
||||
emission_value = float(matched.group(0))
|
||||
return minimum_threshold <= emission_value <= maximum_threshold
|
||||
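# Hedged sketch (not part of the library): the threshold check only needs an object
# exposing `card_data`, illustrated here with a stand-in instead of a real `ModelInfo`.
def _example_emission_filter() -> bool:
    from types import SimpleNamespace

    fake_model = SimpleNamespace(card_data={"co2_eq_emissions": {"emissions": "4.2"}})
    # An emission value of 4.2 falls inside the [1, 10] window, so this returns True.
    return _is_emission_within_threshold(fake_model, minimum_threshold=1, maximum_threshold=10)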
@@ -0,0 +1,32 @@
|
||||
# Taken from https://github.com/mlflow/mlflow/pull/10119
|
||||
#
|
||||
# DO NOT use this function for security purposes (e.g., password hashing).
|
||||
#
|
||||
# In Python >= 3.9, insecure hashing algorithms such as MD5 fail in FIPS-compliant
|
||||
# environments unless `usedforsecurity=False` is explicitly passed.
|
||||
#
|
||||
# References:
|
||||
# - https://github.com/mlflow/mlflow/issues/9905
|
||||
# - https://github.com/mlflow/mlflow/pull/10119
|
||||
# - https://docs.python.org/3/library/hashlib.html
|
||||
# - https://github.com/huggingface/transformers/pull/27038
|
||||
#
|
||||
# Usage:
|
||||
# ```python
|
||||
# # Use
|
||||
# from huggingface_hub.utils.insecure_hashlib import sha256
|
||||
# # instead of
|
||||
# from hashlib import sha256
|
||||
#
|
||||
# # Use
|
||||
# from huggingface_hub.utils import insecure_hashlib
|
||||
# # instead of
|
||||
# import hashlib
|
||||
# ```
|
||||
import functools
|
||||
import hashlib
|
||||
|
||||
|
||||
md5 = functools.partial(hashlib.md5, usedforsecurity=False)
|
||||
sha1 = functools.partial(hashlib.sha1, usedforsecurity=False)
|
||||
sha256 = functools.partial(hashlib.sha256, usedforsecurity=False)
|
||||
@@ -0,0 +1,185 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Optuna, Hugging Face
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Logging utilities."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging import (
|
||||
CRITICAL, # NOQA
|
||||
DEBUG, # NOQA
|
||||
ERROR, # NOQA
|
||||
FATAL, # NOQA
|
||||
INFO, # NOQA
|
||||
NOTSET, # NOQA
|
||||
WARN, # NOQA
|
||||
WARNING, # NOQA
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
from .. import constants
|
||||
|
||||
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
_default_log_level = logging.WARNING
|
||||
|
||||
|
||||
def _get_library_name() -> str:
|
||||
return __name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_library_root_logger() -> logging.Logger:
|
||||
return logging.getLogger(_get_library_name())
|
||||
|
||||
|
||||
def _get_default_logging_level():
|
||||
"""
|
||||
If the `HF_HUB_VERBOSITY` env var is set to one of the valid choices, return that as the default level;
|
||||
otherwise, fall back to `_default_log_level`.
|
||||
"""
|
||||
env_level_str = os.getenv("HF_HUB_VERBOSITY", None)
|
||||
if env_level_str:
|
||||
if env_level_str in log_levels:
|
||||
return log_levels[env_level_str]
|
||||
else:
|
||||
logging.getLogger().warning(
|
||||
f"Unknown option HF_HUB_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
|
||||
)
|
||||
return _default_log_level
|
||||
|
||||
|
||||
def _configure_library_root_logger() -> None:
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.addHandler(logging.StreamHandler())
|
||||
library_root_logger.setLevel(_get_default_logging_level())
|
||||
|
||||
|
||||
def _reset_library_root_logger() -> None:
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||
"""
|
||||
Returns a logger with the specified name. This function is not supposed
|
||||
to be directly accessed by library users.
|
||||
|
||||
Args:
|
||||
name (`str`, *optional*):
|
||||
The name of the logger to get, usually the filename
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from huggingface_hub import get_logger
|
||||
|
||||
>>> logger = get_logger(__file__)
|
||||
>>> logger.setLevel("INFO")
|
||||
```
|
||||
"""
|
||||
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def get_verbosity() -> int:
|
||||
"""Return the current level for the HuggingFace Hub's root logger.
|
||||
|
||||
Returns:
|
||||
Logging level, e.g., `huggingface_hub.logging.DEBUG` and
|
||||
`huggingface_hub.logging.INFO`.
|
||||
|
||||
> [!TIP]
|
||||
> HuggingFace Hub has following logging levels:
|
||||
>
|
||||
> - `huggingface_hub.logging.CRITICAL`, `huggingface_hub.logging.FATAL`
|
||||
> - `huggingface_hub.logging.ERROR`
|
||||
> - `huggingface_hub.logging.WARNING`, `huggingface_hub.logging.WARN`
|
||||
> - `huggingface_hub.logging.INFO`
|
||||
> - `huggingface_hub.logging.DEBUG`
|
||||
"""
|
||||
return _get_library_root_logger().getEffectiveLevel()
|
||||
|
||||
|
||||
def set_verbosity(verbosity: int) -> None:
|
||||
"""
|
||||
Sets the level for the HuggingFace Hub's root logger.
|
||||
|
||||
Args:
|
||||
verbosity (`int`):
|
||||
Logging level, e.g., `huggingface_hub.logging.DEBUG` and
|
||||
`huggingface_hub.logging.INFO`.
|
||||
"""
|
||||
_get_library_root_logger().setLevel(verbosity)
|
||||
|
||||
|
||||
def set_verbosity_info():
|
||||
"""
|
||||
Sets the verbosity to `logging.INFO`.
|
||||
"""
|
||||
return set_verbosity(INFO)
|
||||
|
||||
|
||||
def set_verbosity_warning():
|
||||
"""
|
||||
Sets the verbosity to `logging.WARNING`.
|
||||
"""
|
||||
return set_verbosity(WARNING)
|
||||
|
||||
|
||||
def set_verbosity_debug():
|
||||
"""
|
||||
Sets the verbosity to `logging.DEBUG`.
|
||||
"""
|
||||
return set_verbosity(DEBUG)
|
||||
|
||||
|
||||
def set_verbosity_error():
|
||||
"""
|
||||
Sets the verbosity to `logging.ERROR`.
|
||||
"""
|
||||
return set_verbosity(ERROR)
|
||||
|
||||
|
||||
def disable_propagation() -> None:
|
||||
"""
|
||||
Disable propagation of the library log outputs. Note that log propagation is
|
||||
disabled by default.
|
||||
"""
|
||||
_get_library_root_logger().propagate = False
|
||||
|
||||
|
||||
def enable_propagation() -> None:
|
||||
"""
|
||||
Enable propagation of the library log outputs. Please disable the
|
||||
HuggingFace Hub's default handler to prevent double logging if the root
|
||||
logger has been configured.
|
||||
"""
|
||||
_get_library_root_logger().propagate = True
|
||||
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
if constants.HF_DEBUG:
|
||||
# If `HF_DEBUG` environment variable is set, set the verbosity of `huggingface_hub` logger to `DEBUG`.
|
||||
set_verbosity_debug()
|
||||
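# Hedged usage sketch (not part of the library): verbosity can be driven either by the
# `HF_HUB_VERBOSITY` environment variable (read once at import time) or
# programmatically through the setters defined above.
def _example_configure_verbosity() -> int:
    set_verbosity_info()    # immediate, programmatic override
    return get_verbosity()  # returns logging.INFO (20)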
@@ -0,0 +1,64 @@
|
||||
"""Utilities to efficiently compute the SHA 256 hash of a bunch of bytes."""
|
||||
|
||||
from typing import BinaryIO, Optional
|
||||
|
||||
from .insecure_hashlib import sha1, sha256
|
||||
|
||||
|
||||
def sha_fileobj(fileobj: BinaryIO, chunk_size: Optional[int] = None) -> bytes:
|
||||
"""
|
||||
Computes the sha256 hash of the given file object, by chunks of size `chunk_size`.
|
||||
|
||||
Args:
|
||||
fileobj (file-like object):
|
||||
The File object to compute sha256 for, typically obtained with `open(path, "rb")`
|
||||
chunk_size (`int`, *optional*):
|
||||
The number of bytes to read from `fileobj` at once, defaults to 1MB.
|
||||
|
||||
Returns:
|
||||
`bytes`: `fileobj`'s sha256 hash as bytes
|
||||
"""
|
||||
chunk_size = chunk_size if chunk_size is not None else 1024 * 1024
|
||||
|
||||
sha = sha256()
|
||||
while True:
|
||||
chunk = fileobj.read(chunk_size)
|
||||
sha.update(chunk)
|
||||
if not chunk:
|
||||
break
|
||||
return sha.digest()
|
||||
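# Hedged sketch (not part of the library): hashing an in-memory buffer in small
# chunks yields the same digest as hashing it in a single pass.
def _example_sha_fileobj_hex() -> str:
    import io

    digest = sha_fileobj(io.BytesIO(b"Hello, World!"), chunk_size=4)
    return digest.hex()  # "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"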
|
||||
|
||||
def git_hash(data: bytes) -> str:
|
||||
"""
|
||||
Computes the git-sha1 hash of the given bytes, using the same algorithm as git.
|
||||
|
||||
This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
|
||||
for more details.
|
||||
|
||||
Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
|
||||
pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
|
||||
the LFS file content when we want to compare LFS files.
|
||||
|
||||
Args:
|
||||
data (`bytes`):
|
||||
The data to compute the git-hash for.
|
||||
|
||||
Returns:
|
||||
`str`: the git-hash of `data` as an hexadecimal string.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from huggingface_hub.utils.sha import git_hash
|
||||
>>> git_hash(b"Hello, World!")
|
||||
'b45ef6fec89518d314f546fd6c3025367b721684'
|
||||
```
|
||||
"""
|
||||
# Taken from https://gist.github.com/msabramo/763200
|
||||
# Note: no need to optimize by reading the file in chunks as we're not supposed to hash huge files (5MB maximum).
|
||||
sha = sha1()
|
||||
sha.update(b"blob ")
|
||||
sha.update(str(len(data)).encode())
|
||||
sha.update(b"\0")
|
||||
sha.update(data)
|
||||
return sha.hexdigest()
|
||||
@@ -0,0 +1,308 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
"""Utility helpers to handle progress bars in `huggingface_hub`.
|
||||
|
||||
Example:
|
||||
1. Use `huggingface_hub.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`.
|
||||
2. To disable progress bars, either use `disable_progress_bars()` helper or set the
|
||||
environment variable `HF_HUB_DISABLE_PROGRESS_BARS` to 1.
|
||||
3. To re-enable progress bars, use `enable_progress_bars()`.
|
||||
4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`.
|
||||
|
||||
NOTE: Environment variable `HF_HUB_DISABLE_PROGRESS_BARS` has the priority.
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm
|
||||
|
||||
# Disable progress bars globally
|
||||
>>> disable_progress_bars()
|
||||
|
||||
# Use as normal `tqdm`
|
||||
>>> for _ in tqdm(range(5)):
|
||||
... pass
|
||||
|
||||
# Still not showing progress bars, as `disable=False` is overwritten to `True`.
|
||||
>>> for _ in tqdm(range(5), disable=False):
|
||||
... pass
|
||||
|
||||
>>> are_progress_bars_disabled()
|
||||
True
|
||||
|
||||
# Re-enable progress bars globally
|
||||
>>> enable_progress_bars()
|
||||
|
||||
# Progress bar will be shown !
|
||||
>>> for _ in tqdm(range(5)):
|
||||
... pass
|
||||
100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s]
|
||||
```
|
||||
|
||||
Group-based control:
|
||||
```python
|
||||
# Disable progress bars for a specific group
|
||||
>>> disable_progress_bars("peft.foo")
|
||||
|
||||
# Check state of different groups
|
||||
>>> assert not are_progress_bars_disabled("peft"))
|
||||
>>> assert not are_progress_bars_disabled("peft.something")
|
||||
>>> assert are_progress_bars_disabled("peft.foo"))
|
||||
>>> assert are_progress_bars_disabled("peft.foo.bar"))
|
||||
|
||||
# Enable progress bars for a subgroup
|
||||
>>> enable_progress_bars("peft.foo.bar")
|
||||
|
||||
# Check if enabling a subgroup affects the parent group
|
||||
>>> assert are_progress_bars_disabled("peft.foo"))
|
||||
>>> assert not are_progress_bars_disabled("peft.foo.bar"))
|
||||
|
||||
# No progress bar for `name="peft.foo"`
|
||||
>>> for _ in tqdm(range(5), name="peft.foo"):
|
||||
... pass
|
||||
|
||||
# Progress bar will be shown for `name="peft.foo.bar"`
|
||||
>>> for _ in tqdm(range(5), name="peft.foo.bar"):
|
||||
... pass
|
||||
100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s]
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import ContextManager, Iterator, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm as old_tqdm
|
||||
|
||||
from ..constants import HF_HUB_DISABLE_PROGRESS_BARS
|
||||
|
||||
|
||||
# The `HF_HUB_DISABLE_PROGRESS_BARS` environment variable can be True, False, or not set (None),
|
||||
# allowing for control over progress bar visibility. When set, this variable takes precedence
|
||||
# over programmatic settings, dictating whether progress bars should be shown or hidden globally.
|
||||
# Essentially, the environment variable's setting overrides any code-based configurations.
|
||||
#
|
||||
# If `HF_HUB_DISABLE_PROGRESS_BARS` is not defined (None), it implies that users can manage
|
||||
# progress bar visibility through code. By default, progress bars are turned on.
|
||||
|
||||
|
||||
progress_bar_states: dict[str, bool] = {}
|
||||
|
||||
|
||||
def disable_progress_bars(name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Disable progress bars either globally or for a specified group.
|
||||
|
||||
This function updates the state of progress bars based on a group name.
|
||||
If no group name is provided, all progress bars are disabled. The operation
|
||||
respects the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable's setting.
|
||||
|
||||
Args:
|
||||
name (`str`, *optional*):
|
||||
The name of the group for which to disable the progress bars. If None,
|
||||
progress bars are disabled globally.
|
||||
|
||||
Raises:
|
||||
Warning: If the environment variable precludes changes.
|
||||
"""
|
||||
if HF_HUB_DISABLE_PROGRESS_BARS is False:
|
||||
warnings.warn(
|
||||
"Cannot disable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=0` is set and has priority."
|
||||
)
|
||||
return
|
||||
|
||||
if name is None:
|
||||
progress_bar_states.clear()
|
||||
progress_bar_states["_global"] = False
|
||||
else:
|
||||
keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")]
|
||||
for key in keys_to_remove:
|
||||
del progress_bar_states[key]
|
||||
progress_bar_states[name] = False
|
||||
|
||||
|
||||
def enable_progress_bars(name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Enable progress bars either globally or for a specified group.
|
||||
|
||||
This function sets the progress bars to enabled for the specified group or globally
|
||||
if no group is specified. The operation is subject to the `HF_HUB_DISABLE_PROGRESS_BARS`
|
||||
environment setting.
|
||||
|
||||
Args:
|
||||
name (`str`, *optional*):
|
||||
The name of the group for which to enable the progress bars. If None,
|
||||
progress bars are enabled globally.
|
||||
|
||||
Raises:
|
||||
Warning: If the environment variable precludes changes.
|
||||
"""
|
||||
if HF_HUB_DISABLE_PROGRESS_BARS is True:
|
||||
warnings.warn(
|
||||
"Cannot enable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=1` is set and has priority."
|
||||
)
|
||||
return
|
||||
|
||||
if name is None:
|
||||
progress_bar_states.clear()
|
||||
progress_bar_states["_global"] = True
|
||||
else:
|
||||
keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")]
|
||||
for key in keys_to_remove:
|
||||
del progress_bar_states[key]
|
||||
progress_bar_states[name] = True
|
||||
|
||||
|
||||
def are_progress_bars_disabled(name: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Check if progress bars are disabled globally or for a specific group.
|
||||
|
||||
This function returns whether progress bars are disabled for a given group or globally.
|
||||
It checks the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable first, then the programmatic
|
||||
settings.
|
||||
|
||||
Args:
|
||||
name (`str`, *optional*):
|
||||
The group name to check; if None, checks the global setting.
|
||||
|
||||
Returns:
|
||||
`bool`: True if progress bars are disabled, False otherwise.
|
||||
"""
|
||||
if HF_HUB_DISABLE_PROGRESS_BARS is True:
|
||||
return True
|
||||
|
||||
if name is None:
|
||||
return not progress_bar_states.get("_global", True)
|
||||
|
||||
while name:
|
||||
if name in progress_bar_states:
|
||||
return not progress_bar_states[name]
|
||||
name = ".".join(name.split(".")[:-1])
|
||||
|
||||
return not progress_bar_states.get("_global", True)
|
||||
|
||||
|
||||
def is_tqdm_disabled(log_level: int) -> Optional[bool]:
|
||||
"""
|
||||
Determine if tqdm progress bars should be disabled based on logging level and environment settings.
|
||||
|
||||
see https://github.com/huggingface/huggingface_hub/pull/2000 and https://github.com/huggingface/huggingface_hub/pull/2698.
|
||||
"""
|
||||
if log_level == logging.NOTSET:
|
||||
return True
|
||||
if os.getenv("TQDM_POSITION") == "-1":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
class tqdm(old_tqdm):
|
||||
"""
|
||||
Class to override `disable` argument in case progress bars are globally disabled.
|
||||
|
||||
Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = kwargs.pop("name", None) # do not pass `name` to `tqdm`
|
||||
if are_progress_bars_disabled(name):
|
||||
kwargs["disable"] = True
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __delattr__(self, attr: str) -> None:
|
||||
"""Fix for https://github.com/huggingface/huggingface_hub/issues/1603"""
|
||||
try:
|
||||
super().__delattr__(attr)
|
||||
except AttributeError:
|
||||
if attr != "_lock":
|
||||
raise
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]:
|
||||
"""
|
||||
Open a file as binary and wrap the `read` method to display a progress bar when it's streamed.
|
||||
|
||||
First implemented in `transformers` in 2019 but removed when switched to git-lfs. Used in `huggingface_hub` to show
|
||||
progress bar when uploading an LFS file to the Hub. See github.com/huggingface/transformers/pull/2078#discussion_r354739608
|
||||
for implementation details.
|
||||
|
||||
Note: currently implementation handles only files stored on disk as it is the most common use case. Could be
|
||||
extended to stream any `BinaryIO` object but we might have to debug some corner cases.
|
||||
|
||||
Example:
|
||||
```py
|
||||
>>> with tqdm_stream_file("config.json") as f:
|
||||
>>> httpx.put(url, data=f)
|
||||
config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
|
||||
```
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
with path.open("rb") as f:
|
||||
total_size = path.stat().st_size
|
||||
pbar = tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
total=total_size,
|
||||
initial=0,
|
||||
desc=path.name,
|
||||
)
|
||||
|
||||
f_read = f.read
|
||||
|
||||
def _inner_read(size: Optional[int] = -1) -> bytes:
|
||||
data = f_read(size)
|
||||
pbar.update(len(data))
|
||||
return data
|
||||
|
||||
f.read = _inner_read # type: ignore
|
||||
|
||||
yield f
|
||||
|
||||
pbar.close()
|
||||
|
||||
|
||||
def _get_progress_bar_context(
|
||||
*,
|
||||
desc: str,
|
||||
log_level: int,
|
||||
total: Optional[int] = None,
|
||||
initial: int = 0,
|
||||
unit: str = "B",
|
||||
unit_scale: bool = True,
|
||||
name: Optional[str] = None,
|
||||
tqdm_class: Optional[type[old_tqdm]] = None,
|
||||
_tqdm_bar: Optional[tqdm] = None,
|
||||
) -> ContextManager[tqdm]:
|
||||
if _tqdm_bar is not None:
|
||||
return nullcontext(_tqdm_bar)
|
||||
# ^ `contextlib.nullcontext` mimics a context manager that does nothing
|
||||
# Makes it easier to use the same code path for both cases; when an existing bar
|
||||
# is passed in, it is not closed when exiting the context manager.
|
||||
|
||||
return (tqdm_class or tqdm)( # type: ignore[return-value]
|
||||
unit=unit,
|
||||
unit_scale=unit_scale,
|
||||
total=total,
|
||||
initial=initial,
|
||||
desc=desc,
|
||||
disable=is_tqdm_disabled(log_level=log_level),
|
||||
name=name,
|
||||
)
|
||||
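# Hedged usage sketch (not part of the library): how download code might obtain a
# progress bar context from the helper above, letting it reuse an externally
# provided bar or create a fresh one.
def _example_progress_context(total_bytes: int = 1024) -> None:
    with _get_progress_bar_context(desc="example.bin", log_level=logging.INFO, total=total_bytes) as bar:
        bar.update(total_bytes)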