chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""The embedding module in agentscope."""

from ._embedding_base import EmbeddingModelBase
from ._embedding_usage import EmbeddingUsage
from ._embedding_response import EmbeddingResponse
from ._dashscope_embedding import DashScopeTextEmbedding
from ._dashscope_multimodal_embedding import DashScopeMultiModalEmbedding
from ._openai_embedding import OpenAITextEmbedding
from ._gemini_embedding import GeminiTextEmbedding
from ._ollama_embedding import OllamaTextEmbedding
from ._cache_base import EmbeddingCacheBase
from ._file_cache import FileEmbeddingCache


__all__ = [
    "EmbeddingModelBase",
    "EmbeddingUsage",
    "EmbeddingResponse",
    "DashScopeTextEmbedding",
    "DashScopeMultiModalEmbedding",
    "OpenAITextEmbedding",
    "GeminiTextEmbedding",
    "OllamaTextEmbedding",
    "EmbeddingCacheBase",
    "FileEmbeddingCache",
]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""The embedding cache base class."""
from abc import abstractmethod
from typing import List, Any

from ..types import (
    JSONSerializableObject,
    Embedding,
)


class EmbeddingCacheBase:
    """Base class for embedding caches, which is responsible for storing and
    retrieving embeddings."""

    @abstractmethod
    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings.
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `True`, existing embeddings will be replaced.
        """

    @abstractmethod
    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not
        found, return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings.
        """

    @abstractmethod
    async def remove(
        self,
        identifier: JSONSerializableObject,
    ) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to remove the embeddings.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Clear all cached embeddings."""
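
For reference, the interface above can be satisfied by something as small as an in-memory dict. The sketch below is illustrative only and not part of this commit; the InMemoryEmbeddingCache name, the agentscope.embedding / agentscope.types import paths, and the json.dumps keying are assumptions based on the code in this diff.

import json
from typing import Any, List

from agentscope.embedding import EmbeddingCacheBase  # assumed import path
from agentscope.types import Embedding, JSONSerializableObject  # assumed


class InMemoryEmbeddingCache(EmbeddingCacheBase):
    """Illustrative cache keeping embeddings in a dict keyed by the
    JSON-serialized identifier."""

    def __init__(self) -> None:
        self._data: dict[str, List[Embedding]] = {}

    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        # Only overwrite an existing entry when asked to, mirroring the
        # semantics documented in the base class.
        key = json.dumps(identifier, ensure_ascii=False)
        if overwrite or key not in self._data:
            self._data[key] = embeddings

    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        return self._data.get(json.dumps(identifier, ensure_ascii=False))

    async def remove(self, identifier: JSONSerializableObject) -> None:
        self._data.pop(json.dumps(identifier, ensure_ascii=False), None)

    async def clear(self) -> None:
        self._data.clear()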
@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""The dashscope embedding module in agentscope."""
from datetime import datetime
from typing import Any, List, Literal

from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from .._logging import logger
from ..message import TextBlock


class DashScopeTextEmbedding(EmbeddingModelBase):
    """DashScope text embedding API class.

    .. note:: From the `official documentation
     <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_:

     - The max batch size that DashScope text embedding API
       supports is 10 for `text-embedding-v4` and `text-embedding-v3` models,
       and 25 for `text-embedding-v2` and `text-embedding-v1` models.
     - The max token limit for a single input is 8192 tokens for `v4` and `v3`
       models, and 2048 tokens for `v2` and `v1` models.

    """

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Initialize the DashScope text embedding model class.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector, refer to the
                `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        super().__init__(model_name, dimensions)

        self.api_key = api_key
        self.embedding_cache = embedding_cache
        self.batch_size_limit = 10

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """Call the DashScope embedding API by the given keyword arguments."""
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        import dashscope

        start_time = datetime.now()
        response = dashscope.embeddings.TextEmbedding.call(
            api_key=self.api_key,
            **kwargs,
        )
        time = (datetime.now() - start_time).total_seconds()

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {response}",
            )

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=[
                    _["embedding"] for _ in response.output["embeddings"]
                ],
            )

        return EmbeddingResponse(
            embeddings=[_["embedding"] for _ in response.output["embeddings"]],
            usage=EmbeddingUsage(
                tokens=response.usage["total_tokens"],
                time=time,
            ),
        )

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the DashScope embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
        """
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        if len(gather_text) > self.batch_size_limit:
            logger.info(
                "The input texts (%d) will be embedded with %d API calls due "
                f"to the batch size limit of {self.batch_size_limit} for "
                f"DashScope embedding API.",
                len(gather_text),
                (len(gather_text) + self.batch_size_limit - 1)
                // self.batch_size_limit,
            )

        # Handle the batch size limit for DashScope embedding API
        collected_embeddings = []
        collected_time = 0.0
        collected_tokens = 0
        collected_source: Literal["cache", "api"] = "cache"
        for _ in range(0, len(gather_text), self.batch_size_limit):
            batch_texts = gather_text[_ : _ + self.batch_size_limit]
            batch_kwargs = {
                "input": batch_texts,
                "model": self.model_name,
                "dimension": self.dimensions,
                **kwargs,
            }

            res = await self._call_api(batch_kwargs)

            collected_embeddings.extend(res.embeddings)
            collected_time += res.usage.time
            if res.usage.tokens:
                collected_tokens += res.usage.tokens
            if res.source == "api":
                collected_source = "api"

        return EmbeddingResponse(
            embeddings=collected_embeddings,
            usage=EmbeddingUsage(
                tokens=collected_tokens,
                time=collected_time,
            ),
            source=collected_source,
        )
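
A minimal usage sketch of the class above, not part of this commit: the API key and model name are placeholders, and the agentscope.embedding import path is inferred from the __init__.py shown earlier. With 23 inputs and a batch size limit of 10, the class issues three API calls internally and merges them into one EmbeddingResponse.

import asyncio

from agentscope.embedding import DashScopeTextEmbedding, FileEmbeddingCache


async def main() -> None:
    model = DashScopeTextEmbedding(
        api_key="sk-...",                # placeholder key
        model_name="text-embedding-v4",  # one of the models named in the docstring
        dimensions=1024,
        embedding_cache=FileEmbeddingCache(),
    )
    texts = [f"document {i}" for i in range(23)]
    res = await model(texts)
    # 23 embeddings; tokens and source reflect whether the cache was hit.
    print(len(res.embeddings), res.usage.tokens, res.source)


asyncio.run(main())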
@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
"""The dashscope multimodal embedding model in agentscope."""
from datetime import datetime
from typing import Any, Literal

from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from ..message import (
    VideoBlock,
    ImageBlock,
    TextBlock,
)


class DashScopeMultiModalEmbedding(EmbeddingModelBase):
    """The DashScope multimodal embedding API, supporting text, image and
    video embedding."""

    supported_modalities: list[str] = ["text", "image", "video"]
    """This class supports text, image and video input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Initialize the DashScope multimodal embedding model class.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model, e.g.
                "multimodal-embedding-v1", "tongyi-embedding-vision-plus".
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector, refer to the
                `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712517>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        path_doc = (
            "https://bailian.console.aliyun.com/?tab=api#/api/?type=model&"
            "url=2712517"
        )
        self.batch_size_limit = 1

        if model_name.startswith("tongyi-embedding-vision-plus"):
            self.batch_size_limit = 8
            if dimensions is None:
                dimensions = 1152
            elif dimensions != 1152:
                raise ValueError(
                    f"The dimension of model {model_name} must be 1152, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
        if model_name.startswith("tongyi-embedding-vision-flash"):
            self.batch_size_limit = 8
            if dimensions is None:
                dimensions = 768
            elif dimensions != 768:
                raise ValueError(
                    f"The dimension of model {model_name} must be 768, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
        if model_name.startswith("multimodal-embedding-v"):
            if dimensions is None:
                dimensions = 1024
            elif dimensions != 1024:
                raise ValueError(
                    f"The dimension of model {model_name} must be 1024, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
        refined_dimensions: int = 1024
        if dimensions is not None:
            refined_dimensions = dimensions
        super().__init__(model_name, refined_dimensions)

        self.api_key = api_key
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        inputs: list[TextBlock | ImageBlock | VideoBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the DashScope multimodal embedding API, which accepts text,
        image, and video data.

        Args:
            inputs (`list[TextBlock | ImageBlock | VideoBlock]`):
                The input data to be embedded. It can be a list of text,
                image, and video blocks.

        Returns:
            `EmbeddingResponse`:
                The embedding response object, which contains the embeddings
                and usage information.
        """
        # check data type
        formatted_data = []
        for _ in inputs:
            if (
                not isinstance(_, dict)
                or "type" not in _
                or _["type"]
                not in [
                    "text",
                    "image",
                    "video",
                ]
            ):
                raise ValueError(
                    f"Invalid data : {_}. It should be a list of "
                    "TextBlock, ImageBlock, or VideoBlock.",
                )
            if (
                _["type"] == "video"
                and _.get("source", {}).get("type") != "url"
            ):
                raise ValueError(
                    f"The multimodal embedding API only supports URL input "
                    f"for video data, but got {_}.",
                )

            if _["type"] == "text":
                assert "text" in _, (
                    f"Invalid text block: {_}. It should contain a "
                    f"'text' field.",
                )
                formatted_data.append({"text": _["text"]})

            elif _["type"] == "video":
                formatted_data.append({"video": _["source"]["url"]})

            elif (
                _["type"] == "image"
                and "source" in _
                and _["source"].get("type") in ["base64", "url"]
            ):
                typ = _["source"]["type"]
                if typ == "base64":
                    formatted_data.append(
                        {
                            "image": f'data:{_["source"]["media_type"]};'
                            f'base64,{_["source"]["data"]}',
                        },
                    )
                elif typ == "url":
                    formatted_data.append(
                        {"image": _["source"]["url"]},
                    )
            else:
                raise ValueError(
                    f"Invalid block {_}. It should be a valid TextBlock, "
                    f"ImageBlock, or VideoBlock.",
                )

        # Handle the batch size limit of the DashScope multimodal embedding API
        collected_embeddings = []
        collected_time = 0.0
        collected_tokens = 0
        collected_source: Literal["cache", "api"] = "cache"
        for _ in range(0, len(formatted_data), self.batch_size_limit):
            batch_data = formatted_data[_ : _ + self.batch_size_limit]
            batch_kwargs = {
                "input": batch_data,
                "model": self.model_name,
                **kwargs,
            }
            res = await self._call_api(batch_kwargs)

            collected_embeddings.extend(res.embeddings)
            collected_time += res.usage.time
            if res.usage.tokens:
                collected_tokens += res.usage.tokens
            if res.source == "api":
                collected_source = "api"

        return EmbeddingResponse(
            embeddings=collected_embeddings,
            usage=EmbeddingUsage(
                tokens=collected_tokens,
                time=collected_time,
            ),
            source=collected_source,
        )

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """
        Call the DashScope multimodal embedding API by the given arguments.
        """
        # Search in cache first
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        import dashscope

        kwargs["api_key"] = self.api_key

        start_time = datetime.now()
        res = dashscope.MultiModalEmbedding.call(**kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if res.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {res}",
            )

        return EmbeddingResponse(
            embeddings=[_["embedding"] for _ in res.output["embeddings"]],
            usage=EmbeddingUsage(
                tokens=res.usage.get(
                    "image_tokens",
                    0,
                )
                + res.usage.get(
                    "input_tokens",
                    0,
                ),
                time=time,
            ),
            source="api",
        )
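
For illustration only, the input shapes the validation in __call__ above expects. The exact TextBlock/ImageBlock/VideoBlock TypedDict definitions live in ..message and are not part of this diff, so the literal dicts below are assumptions reconstructed from the checks in the code.

# Text block: needs "type" == "text" and a "text" field.
text_block = {"type": "text", "text": "a cat on a sofa"}

# Image block: "source" may carry a URL or a base64 payload.
image_block = {
    "type": "image",
    "source": {"type": "url", "url": "https://example.com/cat.png"},
}

# Video block: only URL sources pass the check above.
video_block = {
    "type": "video",
    "source": {"type": "url", "url": "https://example.com/cat.mp4"},
}

# res = await DashScopeMultiModalEmbedding(...)(
#     [text_block, image_block, video_block],
# )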
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
"""The embedding model base class."""
from typing import Any

from ._embedding_response import EmbeddingResponse


class EmbeddingModelBase:
    """Base class for embedding models."""

    model_name: str
    """The embedding model name"""

    supported_modalities: list[str]
    """The supported data modalities, e.g. "text", "image", "video"."""

    dimensions: int
    """The dimensions of the embedding vector."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
    ) -> None:
        """Initialize the embedding model base class.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector.
        """
        self.model_name = model_name
        self.dimensions = dimensions

    async def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the embedding API with the given arguments."""
        raise NotImplementedError(
            f"The {self.__class__.__name__} class does not implement "
            f"the __call__ method.",
        )
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""The embedding response class."""
from dataclasses import dataclass, field
from typing import Literal, List

from ._embedding_usage import EmbeddingUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..types import Embedding


@dataclass
class EmbeddingResponse(DictMixin):
    """The embedding response class."""

    embeddings: List[Embedding]
    """The embedding data"""

    id: str = field(default_factory=lambda: _get_timestamp(True))
    """The identity of the embedding response"""

    created_at: str = field(default_factory=_get_timestamp)
    """The timestamp of the embedding response creation"""

    type: Literal["embedding"] = field(default_factory=lambda: "embedding")
    """The type of the response, must be `embedding`."""

    usage: EmbeddingUsage | None = field(default_factory=lambda: None)
    """The usage of the embedding model API invocation, if available."""

    source: Literal["cache", "api"] = field(default_factory=lambda: "api")
    """If the response comes from the cache or the API."""
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The embedding usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal

from .._utils._mixin import DictMixin


@dataclass
class EmbeddingUsage(DictMixin):
    """The usage of an embedding model API invocation."""

    time: float
    """The time used in seconds."""

    tokens: int | None = field(default_factory=lambda: None)
    """The number of tokens used, if available."""

    type: Literal["embedding"] = field(default_factory=lambda: "embedding")
    """The type of the usage, must be `embedding`."""
@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
"""A file embedding cache implementation for storing and retrieving
embeddings in binary files."""
import hashlib
import json
import os
from typing import Any, List

import numpy as np

from ._cache_base import EmbeddingCacheBase
from .._logging import logger
from ..types import (
    Embedding,
    JSONSerializableObject,
)


class FileEmbeddingCache(EmbeddingCacheBase):
    """The embedding cache class that stores each embeddings vector in
    binary files."""

    def __init__(
        self,
        cache_dir: str = "./.cache/embeddings",
        max_file_number: int | None = None,
        max_cache_size: int | None = None,
    ) -> None:
        """Initialize the file embedding cache class.

        Args:
            cache_dir (`str`, defaults to `"./.cache/embeddings"`):
                The directory to store the embedding files.
            max_file_number (`int | None`, defaults to `None`):
                The maximum number of files to keep in the cache directory. If
                exceeded, the oldest files will be removed.
            max_cache_size (`int | None`, defaults to `None`):
                The maximum size of the cache directory in MB. If exceeded,
                the oldest files will be removed until the size is within the
                limit.
        """
        self._cache_dir = os.path.abspath(cache_dir)
        self.max_file_number = max_file_number
        self.max_cache_size = max_cache_size

    @property
    def cache_dir(self) -> str:
        """The cache directory where the embedding files are stored."""
        if not os.path.exists(self._cache_dir):
            os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir

    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `True`, existing embeddings will be replaced.
        """
        filename = self._get_filename(identifier)
        path_file = os.path.join(self.cache_dir, filename)

        if os.path.exists(path_file):
            if not os.path.isfile(path_file):
                raise RuntimeError(
                    f"Path {path_file} exists but is not a file.",
                )

            if overwrite:
                np.save(path_file, embeddings)
                await self._maintain_cache_dir()
        else:
            np.save(path_file, embeddings)
            await self._maintain_cache_dir()

    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not found,
        return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
        """
        filename = self._get_filename(identifier)
        path_file = os.path.join(self.cache_dir, filename)

        if os.path.exists(path_file):
            return np.load(os.path.join(self.cache_dir, filename)).tolist()
        return None

    async def remove(self, identifier: JSONSerializableObject) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifiers to remove the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
        """
        filename = self._get_filename(identifier)
        path_file = os.path.join(self.cache_dir, filename)

        if os.path.exists(path_file):
            os.remove(path_file)
        else:
            raise FileNotFoundError(f"File {path_file} does not exist.")

    async def clear(self) -> None:
        """Clear the cache directory by removing all files."""
        for filename in os.listdir(self.cache_dir):
            if filename.endswith(".npy"):
                os.remove(os.path.join(self.cache_dir, filename))

    def _get_cache_size(self) -> float:
        """Get the current size of the cache directory in MB."""
        total_size = 0
        for filename in os.listdir(self.cache_dir):
            if filename.endswith(".npy"):
                path_file = os.path.join(self.cache_dir, filename)
                if os.path.isfile(path_file):
                    total_size += os.path.getsize(path_file)
        return total_size / (1024.0 * 1024.0)

    @staticmethod
    def _get_filename(identifier: JSONSerializableObject) -> str:
        """Generate a filename based on the identifier."""
        json_str = json.dumps(identifier, ensure_ascii=False)
        return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + ".npy"

    async def _maintain_cache_dir(self) -> None:
        """Maintain the cache directory by removing old files if the number of
        files exceeds the maximum limit or if the cache size exceeds the
        maximum size."""
        files = [
            (_.name, _.stat().st_mtime)
            for _ in os.scandir(self.cache_dir)
            if _.is_file() and _.name.endswith(".npy")
        ]
        files.sort(key=lambda x: x[1])

        if self.max_file_number and len(files) > self.max_file_number:
            for file_name, _ in files[: 0 - self.max_file_number]:
                os.remove(os.path.join(self.cache_dir, file_name))
                logger.info(
                    "Remove cached embedding file %s for limited number "
                    "of files (%d).",
                    file_name,
                    self.max_file_number,
                )
            files = files[0 - self.max_file_number :]

        if (
            self.max_cache_size is not None
            and self._get_cache_size() > self.max_cache_size
        ):
            removed_files = []
            for filename, _ in files:
                os.remove(os.path.join(self.cache_dir, filename))
                removed_files.append(filename)
                if self._get_cache_size() <= self.max_cache_size:
                    break

            if removed_files:
                logger.info(
                    "Remove %d cached embedding file(s) for limited "
                    "cache size (%d MB).",
                    len(removed_files),
                    self.max_cache_size,
                )
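
A short usage sketch of the file cache above, not part of this commit. It assumes the identifier is JSON-serializable, which is how the model classes in this diff call it (they pass their request kwargs).

import asyncio

from agentscope.embedding import FileEmbeddingCache  # assumed import path


async def main() -> None:
    cache = FileEmbeddingCache(
        cache_dir="./.cache/embeddings",
        max_file_number=1000,  # keep at most 1000 .npy files
        max_cache_size=100,    # and at most ~100 MB on disk
    )
    identifier = {"model": "text-embedding-v4", "input": ["hello"]}
    await cache.store([[0.1, 0.2, 0.3]], identifier=identifier)

    # The identifier is json.dumps-ed and sha256-hashed into "<digest>.npy".
    print(FileEmbeddingCache._get_filename(identifier))
    print(await cache.retrieve(identifier))  # [[0.1, 0.2, 0.3]]


asyncio.run(main())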
@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The gemini text embedding model class."""
from datetime import datetime
from typing import Any, List

from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock


class GeminiTextEmbedding(EmbeddingModelBase):
    """The Gemini text embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 3072,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Gemini text embedding model class.

        Args:
            api_key (`str`):
                The Gemini API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 3072):
                The dimension of the embedding vector, refer to the
                `official documentation
                <https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        from google import genai

        super().__init__(model_name, dimensions)

        self.client = genai.Client(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """The Gemini embedding API call.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.

        # TODO: handle the batch size limit
        """
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        kwargs = {
            "model": self.model_name,
            "contents": gather_text,
            "config": kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = self.client.models.embed_content(**kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=[_.values for _ in response.embeddings],
            )

        return EmbeddingResponse(
            embeddings=[_.values for _ in response.embeddings],
            usage=EmbeddingUsage(
                time=time,
            ),
        )
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""The ollama text embedding model class."""
from datetime import datetime
from typing import List, Any

from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ..embedding import EmbeddingModelBase
from ..message import TextBlock


class OllamaTextEmbedding(EmbeddingModelBase):
    """The Ollama embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
        host: str | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Ollama text embedding model class.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector, the parameter should be
                provided according to the model used.
            host (`str | None`, defaults to `None`):
                The host URL for the Ollama API.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        import ollama

        super().__init__(model_name, dimensions)

        self.client = ollama.AsyncClient(host=host, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the Ollama embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
        """
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        kwargs = {
            "input": gather_text,
            "model": self.model_name,
            "dimensions": self.dimensions,
            **kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = await self.client.embed(**kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=response.embeddings,
            )

        return EmbeddingResponse(
            embeddings=response.embeddings,
            usage=EmbeddingUsage(
                time=time,
            ),
        )
@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The OpenAI text embedding model class."""
from datetime import datetime
from typing import Any, List

from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock


class OpenAITextEmbedding(EmbeddingModelBase):
    """OpenAI text embedding model class."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI text embedding model class.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.

        # TODO: handle batch size limit and token limit
        """
        import openai

        super().__init__(model_name, dimensions)

        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
        """
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        kwargs = {
            "input": gather_text,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = await self.client.embeddings.create(**kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=[_.embedding for _ in response.data],
            )

        return EmbeddingResponse(
            embeddings=[_.embedding for _ in response.data],
            usage=EmbeddingUsage(
                tokens=response.usage.total_tokens,
                time=time,
            ),
        )
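
A minimal usage sketch mirroring the DashScope example earlier, not part of this commit. The API key is a placeholder and the model name is an assumed choice (one of the OpenAI embedding models that accepts a dimensions parameter); a second identical call is shown to illustrate the cache hit path.

import asyncio

from agentscope.embedding import FileEmbeddingCache, OpenAITextEmbedding


async def main() -> None:
    model = OpenAITextEmbedding(
        api_key="sk-...",                     # placeholder key
        model_name="text-embedding-3-small",  # assumed model choice
        dimensions=1024,
        embedding_cache=FileEmbeddingCache(),
    )
    # The second call hits the file cache: source == "cache", tokens == 0.
    for _ in range(2):
        res = await model(["hello world", {"type": "text", "text": "agentscope"}])
        print(res.source, res.usage.tokens, len(res.embeddings[0]))


asyncio.run(main())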