chore: add the virtual environment to the repository

- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12655 files
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""The embedding module in agentscope."""
from ._embedding_base import EmbeddingModelBase
from ._embedding_usage import EmbeddingUsage
from ._embedding_response import EmbeddingResponse
from ._dashscope_embedding import DashScopeTextEmbedding
from ._dashscope_multimodal_embedding import DashScopeMultiModalEmbedding
from ._openai_embedding import OpenAITextEmbedding
from ._gemini_embedding import GeminiTextEmbedding
from ._ollama_embedding import OllamaTextEmbedding
from ._cache_base import EmbeddingCacheBase
from ._file_cache import FileEmbeddingCache
__all__ = [
"EmbeddingModelBase",
"EmbeddingUsage",
"EmbeddingResponse",
"DashScopeTextEmbedding",
"DashScopeMultiModalEmbedding",
"OpenAITextEmbedding",
"GeminiTextEmbedding",
"OllamaTextEmbedding",
"EmbeddingCacheBase",
"FileEmbeddingCache",
]

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""The embedding cache base class."""
from abc import abstractmethod
from typing import List, Any
from ..types import (
JSONSerializableObject,
Embedding,
)
class EmbeddingCacheBase:
"""Base class for embedding caches, which is responsible for storing and
retrieving embeddings."""
@abstractmethod
async def store(
self,
embeddings: List[Embedding],
identifier: JSONSerializableObject,
overwrite: bool = False,
**kwargs: Any,
) -> None:
"""Store the embeddings with the given identifier.
Args:
embeddings (`List[Embedding]`):
The embeddings to store.
identifier (`JSONSerializableObject`):
The identifier to distinguish the embeddings.
overwrite (`bool`, defaults to `False`):
Whether to overwrite existing embeddings with the same
identifier. If `True`, existing embeddings will be replaced.
"""
@abstractmethod
async def retrieve(
self,
identifier: JSONSerializableObject,
) -> List[Embedding] | None:
"""Retrieve the embeddings with the given identifier. If not
found, return `None`.
Args:
identifier (`JSONSerializableObject`):
The identifier to retrieve the embeddings.
"""
@abstractmethod
async def remove(
self,
identifier: JSONSerializableObject,
) -> None:
"""Remove the embeddings with the given identifier.
Args:
identifier (`JSONSerializableObject`):
The identifier to remove the embeddings.
"""
@abstractmethod
async def clear(self) -> None:
"""Clear all cached embeddings."""

View File

@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""The dashscope embedding module in agentscope."""
from datetime import datetime
from typing import Any, List, Literal
from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from .._logging import logger
from ..message import TextBlock
class DashScopeTextEmbedding(EmbeddingModelBase):
"""DashScope text embedding API class.
.. note:: From the `official documentation
<https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_:
    - The maximum batch size supported by the DashScope text embedding API
      is 10 for the `text-embedding-v4` and `text-embedding-v3` models, and
      25 for the `text-embedding-v2` and `text-embedding-v1` models.
    - The maximum number of tokens in a single input is 8192 for the `v4`
      and `v3` models, and 2048 for the `v2` and `v1` models.
"""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int = 1024,
embedding_cache: EmbeddingCacheBase | None = None,
) -> None:
"""Initialize the DashScope text embedding model class.
Args:
api_key (`str`):
The dashscope API key.
model_name (`str`):
The name of the embedding model.
dimensions (`int`, defaults to 1024):
The dimension of the embedding vector, refer to the
`official documentation
<https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_
for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
"""
super().__init__(model_name, dimensions)
self.api_key = api_key
self.embedding_cache = embedding_cache
self.batch_size_limit = 10
async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
"""Call the DashScope embedding API by the given keyword arguments."""
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
import dashscope
start_time = datetime.now()
response = dashscope.embeddings.TextEmbedding.call(
api_key=self.api_key,
**kwargs,
)
time = (datetime.now() - start_time).total_seconds()
if response.status_code != 200:
raise RuntimeError(
f"Failed to get embedding from DashScope API: {response}",
)
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=[
_["embedding"] for _ in response.output["embeddings"]
],
)
return EmbeddingResponse(
embeddings=[_["embedding"] for _ in response.output["embeddings"]],
usage=EmbeddingUsage(
tokens=response.usage["total_tokens"],
time=time,
),
)
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the DashScope embedding API.
Args:
text (`List[str | TextBlock]`):
                The input text to be embedded, given as a list of strings
                or TextBlock dicts.
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
if len(gather_text) > self.batch_size_limit:
            logger.info(
                "The %d input texts will be embedded in %d API calls due to "
                "the batch size limit (%d) of the DashScope embedding API.",
                len(gather_text),
                (len(gather_text) + self.batch_size_limit - 1)
                // self.batch_size_limit,
                self.batch_size_limit,
            )
# Handle the batch size limit for DashScope embedding API
collected_embeddings = []
collected_time = 0.0
collected_tokens = 0
collected_source: Literal["cache", "api"] = "cache"
for _ in range(0, len(gather_text), self.batch_size_limit):
batch_texts = gather_text[_ : _ + self.batch_size_limit]
batch_kwargs = {
"input": batch_texts,
"model": self.model_name,
"dimension": self.dimensions,
**kwargs,
}
res = await self._call_api(batch_kwargs)
collected_embeddings.extend(res.embeddings)
collected_time += res.usage.time
if res.usage.tokens:
collected_tokens += res.usage.tokens
if res.source == "api":
collected_source = "api"
return EmbeddingResponse(
embeddings=collected_embeddings,
usage=EmbeddingUsage(
tokens=collected_tokens,
time=collected_time,
),
source=collected_source,
)
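
A minimal usage sketch of the class above, not part of this commit. It assumes the package exposes these names as agentscope.embedding (as the __init__ file above suggests); the API key is a placeholder, and "text-embedding-v4" is one of the models named in the docstring note. With 25 inputs and a batch size limit of 10, __call__ splits the work into three API requests and merges the results.

# Hypothetical usage sketch -- not part of the committed code.
import asyncio

from agentscope.embedding import DashScopeTextEmbedding, FileEmbeddingCache


async def main() -> None:
    model = DashScopeTextEmbedding(
        api_key="sk-...",  # placeholder DashScope API key
        model_name="text-embedding-v4",  # batch size limit of 10 applies
        dimensions=1024,
        embedding_cache=FileEmbeddingCache(),  # optional on-disk cache
    )
    texts = [f"sentence {i}" for i in range(25)]  # 25 texts -> 3 API calls
    response = await model(texts)
    print(len(response.embeddings), response.usage.tokens, response.source)


asyncio.run(main())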

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
"""The dashscope multimodal embedding model in agentscope."""
from datetime import datetime
from typing import Any, Literal
from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from ..message import (
VideoBlock,
ImageBlock,
TextBlock,
)
class DashScopeMultiModalEmbedding(EmbeddingModelBase):
"""The DashScope multimodal embedding API, supporting text, image and
video embedding."""
supported_modalities: list[str] = ["text", "image", "video"]
"""This class supports text, image and video input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int | None = None,
embedding_cache: EmbeddingCacheBase | None = None,
) -> None:
"""Initialize the DashScope multimodal embedding model class.
Args:
api_key (`str`):
The dashscope API key.
model_name (`str`):
                The name of the embedding model, e.g.
                "multimodal-embedding-v1" or "tongyi-embedding-vision-plus".
            dimensions (`int | None`, defaults to `None`):
                The dimension of the embedding vector. If `None`, a
                model-specific default is used (1152 for
                "tongyi-embedding-vision-plus", 768 for
                "tongyi-embedding-vision-flash", and 1024 otherwise); refer
                to the `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712517>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
"""
path_doc = (
"https://bailian.console.aliyun.com/?tab=api#/api/?type=model&"
"url=2712517"
)
self.batch_size_limit = 1
if model_name.startswith("tongyi-embedding-vision-plus"):
self.batch_size_limit = 8
if dimensions is None:
dimensions = 1152
elif dimensions != 1152:
raise ValueError(
f"The dimension of model {model_name} must be 1152, "
"refer to the official documentation for more details: "
f"{path_doc}",
)
if model_name.startswith("tongyi-embedding-vision-flash"):
self.batch_size_limit = 8
if dimensions is None:
dimensions = 768
elif dimensions != 768:
raise ValueError(
f"The dimension of model {model_name} must be 768, "
"refer to the official documentation for more details: "
f"{path_doc}",
)
if model_name.startswith("multimodal-embedding-v"):
if dimensions is None:
dimensions = 1024
elif dimensions != 1024:
raise ValueError(
f"The dimension of model {model_name} must be 1024, "
"refer to the official documentation for more details: "
f"{path_doc}",
)
refined_dimensions: int = 1024
if dimensions is not None:
refined_dimensions = dimensions
super().__init__(model_name, refined_dimensions)
self.api_key = api_key
self.embedding_cache = embedding_cache
async def __call__(
self,
inputs: list[TextBlock | ImageBlock | VideoBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the DashScope multimodal embedding API, which accepts text,
image, and video data.
Args:
inputs (`list[TextBlock | ImageBlock | VideoBlock]`):
The input data to be embedded. It can be a list of text,
image, and video blocks.
Returns:
`EmbeddingResponse`:
The embedding response object, which contains the embeddings
and usage information.
"""
# check data type
formatted_data = []
for _ in inputs:
if (
not isinstance(_, dict)
or "type" not in _
or _["type"]
not in [
"text",
"image",
"video",
]
):
                raise ValueError(
                    f"Invalid input block: {_}. Each item must be a "
                    "TextBlock, ImageBlock, or VideoBlock dict.",
                )
if (
_["type"] == "video"
and _.get("source", {}).get("type") != "url"
):
raise ValueError(
f"The multimodal embedding API only supports URL input "
f"for video data, but got {_}.",
)
if _["type"] == "text":
assert "text" in _, (
f"Invalid text block: {_}. It should contain a "
f"'text' field.",
)
formatted_data.append({"text": _["text"]})
elif _["type"] == "video":
formatted_data.append({"video": _["source"]["url"]})
elif (
_["type"] == "image"
and "source" in _
and _["source"].get("type") in ["base64", "url"]
):
typ = _["source"]["type"]
if typ == "base64":
formatted_data.append(
{
"image": f'data:{_["source"]["media_type"]};'
f'base64,{_["source"]["data"]}',
},
)
elif typ == "url":
formatted_data.append(
{"image": _["source"]["url"]},
)
else:
raise ValueError(
f"Invalid block {_}. It should be a valid TextBlock, "
f"ImageBlock, or VideoBlock.",
)
# Handle the batch size limit of the DashScope multimodal embedding API
collected_embeddings = []
collected_time = 0.0
collected_tokens = 0
collected_source: Literal["cache", "api"] = "cache"
for _ in range(0, len(formatted_data), self.batch_size_limit):
batch_data = formatted_data[_ : _ + self.batch_size_limit]
batch_kwargs = {
"input": batch_data,
"model": self.model_name,
**kwargs,
}
res = await self._call_api(batch_kwargs)
collected_embeddings.extend(res.embeddings)
collected_time += res.usage.time
if res.usage.tokens:
collected_tokens += res.usage.tokens
if res.source == "api":
collected_source = "api"
return EmbeddingResponse(
embeddings=collected_embeddings,
usage=EmbeddingUsage(
tokens=collected_tokens,
time=collected_time,
),
source=collected_source,
)
async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
"""
Call the DashScope multimodal embedding API by the given arguments.
"""
# Search in cache first
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
import dashscope
kwargs["api_key"] = self.api_key
start_time = datetime.now()
res = dashscope.MultiModalEmbedding.call(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if res.status_code != 200:
raise RuntimeError(
f"Failed to get embedding from DashScope API: {res}",
)
return EmbeddingResponse(
embeddings=[_["embedding"] for _ in res.output["embeddings"]],
usage=EmbeddingUsage(
tokens=res.usage.get(
"image_tokens",
0,
)
+ res.usage.get(
"input_tokens",
0,
),
time=time,
),
source="api",
)
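
A hedged usage sketch for the multimodal class, not part of this commit. The block shapes follow the validation in __call__ above: text blocks carry a "text" field, image blocks a URL or base64 source, and video blocks must use a URL source. The API key, model name, and URL are placeholders.

# Hypothetical usage sketch -- not part of the committed code.
import asyncio

from agentscope.embedding import DashScopeMultiModalEmbedding


async def main() -> None:
    model = DashScopeMultiModalEmbedding(
        api_key="sk-...",  # placeholder DashScope API key
        model_name="multimodal-embedding-v1",  # dimensions defaults to 1024
    )
    response = await model(
        [
            {"type": "text", "text": "a photo of a cat"},
            {
                "type": "image",
                "source": {"type": "url", "url": "https://example.com/cat.png"},
            },
            # Video blocks must use a URL source, per the check above:
            # {"type": "video", "source": {"type": "url", "url": "https://..."}},
        ],
    )
    print(len(response.embeddings), response.source)


asyncio.run(main())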

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
"""The embedding model base class."""
from typing import Any
from ._embedding_response import EmbeddingResponse
class EmbeddingModelBase:
"""Base class for embedding models."""
model_name: str
"""The embedding model name"""
supported_modalities: list[str]
"""The supported data modalities, e.g. "text", "image", "video"."""
dimensions: int
"""The dimensions of the embedding vector."""
def __init__(
self,
model_name: str,
dimensions: int,
) -> None:
"""Initialize the embedding model base class.
Args:
model_name (`str`):
The name of the embedding model.
dimensions (`int`):
The dimension of the embedding vector.
"""
self.model_name = model_name
self.dimensions = dimensions
async def __call__(
self,
*args: Any,
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the embedding API with the given arguments."""
raise NotImplementedError(
f"The {self.__class__.__name__} class does not implement "
f"the __call__ method.",
)
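
Concrete models plug in by subclassing this base and overriding the async __call__. As a hypothetical illustration (not part of this commit), a deterministic stub like the one below can stand in for a real API during tests; the import path is an assumption based on the __init__ exports above.

# Hypothetical illustration -- not part of the committed code.
from typing import Any

from agentscope.embedding import EmbeddingModelBase, EmbeddingResponse


class FakeTextEmbedding(EmbeddingModelBase):
    """A deterministic stand-in model, useful in unit tests."""

    supported_modalities: list[str] = ["text"]

    def __init__(self, dimensions: int = 8) -> None:
        super().__init__(model_name="fake-embedding", dimensions=dimensions)

    async def __call__(
        self,
        text: list[str],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        # Return one constant vector per input; no external API involved.
        return EmbeddingResponse(
            embeddings=[[0.0] * self.dimensions for _ in text],
        )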

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""The embedding response class."""
from dataclasses import dataclass, field
from typing import Literal, List
from ._embedding_usage import EmbeddingUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..types import Embedding
@dataclass
class EmbeddingResponse(DictMixin):
"""The embedding response class."""
embeddings: List[Embedding]
"""The embedding data"""
id: str = field(default_factory=lambda: _get_timestamp(True))
"""The identity of the embedding response"""
created_at: str = field(default_factory=_get_timestamp)
"""The timestamp of the embedding response creation"""
type: Literal["embedding"] = field(default_factory=lambda: "embedding")
"""The type of the response, must be `embedding`."""
usage: EmbeddingUsage | None = field(default_factory=lambda: None)
"""The usage of the embedding model API invocation, if available."""
source: Literal["cache", "api"] = field(default_factory=lambda: "api")
"""If the response comes from the cache or the API."""

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The embedding usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal
from .._utils._mixin import DictMixin
@dataclass
class EmbeddingUsage(DictMixin):
"""The usage of an embedding model API invocation."""
time: float
"""The time used in seconds."""
tokens: int | None = field(default_factory=lambda: None)
"""The number of tokens used, if available."""
type: Literal["embedding"] = field(default_factory=lambda: "embedding")
"""The type of the usage, must be `embedding`."""

View File

@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
"""A file embedding cache implementation for storing and retrieving
embeddings in binary files."""
import hashlib
import json
import os
from typing import Any, List
import numpy as np
from ._cache_base import EmbeddingCacheBase
from .._logging import logger
from ..types import (
Embedding,
JSONSerializableObject,
)
class FileEmbeddingCache(EmbeddingCacheBase):
"""The embedding cache class that stores each embeddings vector in
binary files."""
def __init__(
self,
cache_dir: str = "./.cache/embeddings",
max_file_number: int | None = None,
max_cache_size: int | None = None,
) -> None:
"""Initialize the file embedding cache class.
Args:
cache_dir (`str`, defaults to `"./.cache/embeddings"`):
The directory to store the embedding files.
max_file_number (`int | None`, defaults to `None`):
The maximum number of files to keep in the cache directory. If
exceeded, the oldest files will be removed.
max_cache_size (`int | None`, defaults to `None`):
The maximum size of the cache directory in MB. If exceeded,
the oldest files will be removed until the size is within the
limit.
"""
self._cache_dir = os.path.abspath(cache_dir)
self.max_file_number = max_file_number
self.max_cache_size = max_cache_size
@property
def cache_dir(self) -> str:
"""The cache directory where the embedding files are stored."""
if not os.path.exists(self._cache_dir):
os.makedirs(self._cache_dir, exist_ok=True)
return self._cache_dir
async def store(
self,
embeddings: List[Embedding],
identifier: JSONSerializableObject,
overwrite: bool = False,
**kwargs: Any,
) -> None:
"""Store the embeddings with the given identifier.
Args:
embeddings (`List[Embedding]`):
The embeddings to store.
identifier (`JSONSerializableObject`):
The identifier to distinguish the embeddings, which will be
used to generate a hashable filename, so it should be
JSON serializable (e.g. a string, number, list, dict).
overwrite (`bool`, defaults to `False`):
Whether to overwrite existing embeddings with the same
identifier. If `True`, existing embeddings will be replaced.
"""
filename = self._get_filename(identifier)
path_file = os.path.join(self.cache_dir, filename)
        if os.path.exists(path_file) and not os.path.isfile(path_file):
            raise RuntimeError(
                f"Path {path_file} exists but is not a file.",
            )
        # Skip writing when the file already exists and overwrite is False.
        if overwrite or not os.path.exists(path_file):
            np.save(path_file, embeddings)
            await self._maintain_cache_dir()
async def retrieve(
self,
identifier: JSONSerializableObject,
) -> List[Embedding] | None:
"""Retrieve the embeddings with the given identifier. If not found,
return `None`.
Args:
identifier (`JSONSerializableObject`):
The identifier to retrieve the embeddings, which will be
used to generate a hashable filename, so it should be
JSON serializable (e.g. a string, number, list, dict).
"""
filename = self._get_filename(identifier)
path_file = os.path.join(self.cache_dir, filename)
if os.path.exists(path_file):
            return np.load(path_file).tolist()
return None
async def remove(self, identifier: JSONSerializableObject) -> None:
"""Remove the embeddings with the given identifier.
Args:
            identifier (`JSONSerializableObject`):
                The identifier of the embeddings to remove, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
"""
filename = self._get_filename(identifier)
path_file = os.path.join(self.cache_dir, filename)
if os.path.exists(path_file):
os.remove(path_file)
else:
raise FileNotFoundError(f"File {path_file} does not exist.")
async def clear(self) -> None:
"""Clear the cache directory by removing all files."""
for filename in os.listdir(self.cache_dir):
if filename.endswith(".npy"):
os.remove(os.path.join(self.cache_dir, filename))
def _get_cache_size(self) -> float:
"""Get the current size of the cache directory in MB."""
total_size = 0
for filename in os.listdir(self.cache_dir):
if filename.endswith(".npy"):
path_file = os.path.join(self.cache_dir, filename)
if os.path.isfile(path_file):
total_size += os.path.getsize(path_file)
return total_size / (1024.0 * 1024.0)
@staticmethod
def _get_filename(identifier: JSONSerializableObject) -> str:
"""Generate a filename based on the identifier."""
json_str = json.dumps(identifier, ensure_ascii=False)
return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + ".npy"
async def _maintain_cache_dir(self) -> None:
"""Maintain the cache directory by removing old files if the number of
files exceeds the maximum limit or if the cache size exceeds the
maximum size."""
files = [
(_.name, _.stat().st_mtime)
for _ in os.scandir(self.cache_dir)
if _.is_file() and _.name.endswith(".npy")
]
files.sort(key=lambda x: x[1])
if self.max_file_number and len(files) > self.max_file_number:
for file_name, _ in files[: 0 - self.max_file_number]:
os.remove(os.path.join(self.cache_dir, file_name))
                logger.info(
                    "Removed cached embedding file %s to keep the number of "
                    "cached files within the limit (%d).",
                    file_name,
                    self.max_file_number,
                )
files = files[0 - self.max_file_number :]
if (
self.max_cache_size is not None
and self._get_cache_size() > self.max_cache_size
):
removed_files = []
for filename, _ in files:
os.remove(os.path.join(self.cache_dir, filename))
removed_files.append(filename)
if self._get_cache_size() <= self.max_cache_size:
break
if removed_files:
            logger.info(
                "Removed %d cached embedding file(s) to keep the cache size "
                "within the limit (%d MB).",
                len(removed_files),
                self.max_cache_size,
            )
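
A usage sketch of the file cache on its own, not part of this commit. Any JSON-serializable identifier works as a key; the cache hashes it into a .npy filename, and the max_file_number / max_cache_size values shown are illustrative settings that trigger the pruning logic in _maintain_cache_dir.

# Hypothetical usage sketch -- not part of the committed code.
import asyncio

from agentscope.embedding import FileEmbeddingCache


async def main() -> None:
    cache = FileEmbeddingCache(
        cache_dir="./.cache/embeddings",
        max_file_number=1000,  # prune the oldest .npy files beyond this count
        max_cache_size=100,  # prune oldest files once the directory exceeds 100 MB
    )
    identifier = {"model": "text-embedding-v4", "input": ["hello"]}
    await cache.store([[0.1, 0.2, 0.3]], identifier=identifier)
    print(await cache.retrieve(identifier))  # [[0.1, 0.2, 0.3]]
    print(await cache.retrieve({"other": 1}))  # None -> cache miss
    await cache.remove(identifier)


asyncio.run(main())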

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The gemini text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class GeminiTextEmbedding(EmbeddingModelBase):
"""The Gemini text embedding model."""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int = 3072,
embedding_cache: EmbeddingCacheBase | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Gemini text embedding model class.
Args:
api_key (`str`):
The Gemini API key.
model_name (`str`):
The name of the embedding model.
dimensions (`int`, defaults to 3072):
The dimension of the embedding vector, refer to the
`official documentation
<https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
for more details.
embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
"""
from google import genai
super().__init__(model_name, dimensions)
self.client = genai.Client(api_key=api_key, **kwargs)
self.embedding_cache = embedding_cache
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""The Gemini embedding API call.
Args:
text (`List[str | TextBlock]`):
                The input text to be embedded, given as a list of strings
                or TextBlock dicts.
        """
        # TODO: handle the batch size limit
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
kwargs = {
"model": self.model_name,
"contents": gather_text,
"config": kwargs,
}
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
start_time = datetime.now()
response = self.client.models.embed_content(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=[_.values for _ in response.embeddings],
)
return EmbeddingResponse(
embeddings=[_.values for _ in response.embeddings],
usage=EmbeddingUsage(
time=time,
),
)
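
A hedged usage sketch, not part of this commit. The model name "gemini-embedding-001" and the output_dimensionality field are assumptions about the google-genai SDK rather than anything defined in this file; the point illustrated is that extra keyword arguments to __call__ are forwarded as the config dict of client.models.embed_content.

# Hypothetical usage sketch -- not part of the committed code.
import asyncio

from agentscope.embedding import GeminiTextEmbedding


async def main() -> None:
    model = GeminiTextEmbedding(
        api_key="...",  # placeholder Gemini API key
        model_name="gemini-embedding-001",  # assumed model name
    )
    # Keyword arguments are passed through as the `config` dict of
    # `client.models.embed_content`; `output_dimensionality` is an assumed
    # field of the google-genai embedding config.
    response = await model(
        ["hello world", {"type": "text", "text": "agentscope"}],
        output_dimensionality=768,
    )
    print(len(response.embeddings), response.usage.time)


asyncio.run(main())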

View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""The ollama text embedding model class."""
from datetime import datetime
from typing import List, Any
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ..embedding import EmbeddingModelBase
from ..message import TextBlock
class OllamaTextEmbedding(EmbeddingModelBase):
"""The Ollama embedding model."""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
model_name: str,
dimensions: int,
host: str | None = None,
embedding_cache: EmbeddingCacheBase | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Ollama text embedding model class.
Args:
model_name (`str`):
The name of the embedding model.
dimensions (`int`):
                The dimension of the embedding vector; it should match the
                embedding size of the model being used.
host (`str | None`, defaults to `None`):
The host URL for the Ollama API.
embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
"""
import ollama
super().__init__(model_name, dimensions)
self.client = ollama.AsyncClient(host=host, **kwargs)
self.embedding_cache = embedding_cache
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the Ollama embedding API.
Args:
text (`List[str | TextBlock]`):
                The input text to be embedded, given as a list of strings
                or TextBlock dicts.
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
kwargs = {
"input": gather_text,
"model": self.model_name,
"dimensions": self.dimensions,
**kwargs,
}
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
start_time = datetime.now()
response = await self.client.embed(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=response.embeddings,
)
return EmbeddingResponse(
embeddings=response.embeddings,
usage=EmbeddingUsage(
time=time,
),
)
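
A usage sketch for the Ollama model, not part of this commit. The model name and dimension are assumptions (dimensions has no default here and must match the local model); note that __call__ forwards a "dimensions" entry to client.embed, so whether the call succeeds also depends on the installed ollama client accepting that keyword.

# Hypothetical usage sketch -- not part of the committed code.
import asyncio

from agentscope.embedding import OllamaTextEmbedding


async def main() -> None:
    model = OllamaTextEmbedding(
        model_name="nomic-embed-text",  # assumed locally pulled model
        dimensions=768,  # must match the embedding size of the chosen model
        host="http://localhost:11434",  # default local Ollama endpoint
    )
    response = await model(["hello", "world"])
    print(len(response.embeddings), response.usage.time)


asyncio.run(main())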

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The OpenAI text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class OpenAITextEmbedding(EmbeddingModelBase):
"""OpenAI text embedding model class."""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int = 1024,
embedding_cache: EmbeddingCacheBase | None = None,
**kwargs: Any,
) -> None:
"""Initialize the OpenAI text embedding model class.
Args:
api_key (`str`):
The OpenAI API key.
model_name (`str`):
The name of the embedding model.
dimensions (`int`, defaults to 1024):
The dimension of the embedding vector.
embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
        """
        # TODO: handle batch size limit and token limit
import openai
super().__init__(model_name, dimensions)
self.client = openai.AsyncClient(api_key=api_key, **kwargs)
self.embedding_cache = embedding_cache
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the OpenAI embedding API.
Args:
text (`List[str | TextBlock]`):
                The input text to be embedded, given as a list of strings
                or TextBlock dicts.
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
kwargs = {
"input": gather_text,
"model": self.model_name,
"dimensions": self.dimensions,
"encoding_format": "float",
**kwargs,
}
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
start_time = datetime.now()
response = await self.client.embeddings.create(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=[_.embedding for _ in response.data],
)
return EmbeddingResponse(
embeddings=[_.embedding for _ in response.data],
usage=EmbeddingUsage(
tokens=response.usage.total_tokens,
time=time,
),
)
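
A usage sketch of the OpenAI class with a file cache attached, not part of this commit. "text-embedding-3-small" is an assumed model name that accepts the dimensions parameter; the second identical call is served from the cache, so its source is "cache" and its token count is 0.

# Hypothetical usage sketch -- not part of the committed code.
import asyncio

from agentscope.embedding import FileEmbeddingCache, OpenAITextEmbedding


async def main() -> None:
    model = OpenAITextEmbedding(
        api_key="sk-...",  # placeholder OpenAI API key
        model_name="text-embedding-3-small",  # assumed; supports `dimensions`
        dimensions=1024,
        embedding_cache=FileEmbeddingCache(cache_dir="./.cache/embeddings"),
    )
    first = await model(["hello world"])
    second = await model(["hello world"])  # identical kwargs -> cache hit
    print(first.source, second.source)  # "api", then "cache"
    print(second.usage.tokens)  # 0 for the cached response


asyncio.run(main())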