chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment - Include all Python dependency packages - Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,12 @@
import importlib
from typing import Type, TypeVar, cast

C = TypeVar("C")


def get_class(fqn: str, type: Type[C]) -> Type[C]:
    """Given a fully qualified class name, import the module and return the class"""
    module_name, class_name = fqn.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cast(Type[C], cls)
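A minimal usage sketch for get_class, assuming the helper above is in scope; OrderedDict serves only as a convenient, real target class:

from collections import OrderedDict

cls = get_class("collections.OrderedDict", dict)
assert cls is OrderedDict
d = cls(a=1)  # instantiate the dynamically resolved class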
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,63 @@
import inspect
import asyncio
from typing import Any, Callable, Coroutine, TypeVar
from typing_extensions import ParamSpec


P = ParamSpec("P")
R = TypeVar("R")


def async_to_sync(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, R]:
    """A function decorator that converts an async function to a sync function.

    This should generally not be used in production code paths.
    """

    def sync_wrapper(*args, **kwargs):  # type: ignore
        loop = None
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            return func(*args, **kwargs)

        result = loop.run_until_complete(func(*args, **kwargs))

        def convert_result(result: Any) -> Any:
            if isinstance(result, list):
                return [convert_result(r) for r in result]

            # Check coroutine functions before the `isinstance(result, object)`
            # catch-all below; that catch-all is True for every Python value,
            # so a later `callable` branch would never be reached.
            if inspect.iscoroutinefunction(result):
                return async_to_sync(result)

            if isinstance(result, object):
                return async_class_to_sync(result)

            return result

        return convert_result(result)

    return sync_wrapper


T = TypeVar("T")


def async_class_to_sync(cls: T) -> T:
    """A decorator that converts a class with async methods to a class with sync methods.

    This should generally not be used in production code paths.
    """
    for attr, value in inspect.getmembers(cls):
        if (
            callable(value)
            and inspect.iscoroutinefunction(value)
            and not attr.startswith("__")
        ):
            setattr(cls, attr, async_to_sync(value))

    return cls
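A hedged usage sketch of async_to_sync; the coroutine below is hypothetical:

import asyncio

async def double(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2

double_sync = async_to_sync(double)
print(double_sync(21))  # 42, executed via run_until_complete on an event loop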
@@ -0,0 +1,36 @@
from typing import Optional, Tuple, List
from chromadb.api import BaseAPI
from chromadb.api.types import (
    Documents,
    Embeddings,
    IDs,
    Metadatas,
)


def create_batches(
    api: BaseAPI,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
) -> List[Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]]]:
    _batches: List[
        Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]]
    ] = []
    if len(ids) > api.get_max_batch_size():
        # create split batches
        for i in range(0, len(ids), api.get_max_batch_size()):
            _batches.append(
                (
                    ids[i : i + api.get_max_batch_size()],
                    embeddings[i : i + api.get_max_batch_size()]
                    if embeddings is not None
                    else None,
                    metadatas[i : i + api.get_max_batch_size()] if metadatas else None,
                    documents[i : i + api.get_max_batch_size()] if documents else None,
                )
            )
    else:
        _batches.append((ids, embeddings, metadatas, documents))
    return _batches
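A hedged usage sketch of create_batches; `client` and `collection` are assumed to already exist (e.g. client = chromadb.Client()):

ids = [f"id-{i}" for i in range(100_000)]
docs = [f"document {i}" for i in range(100_000)]
for batch_ids, batch_embeddings, batch_metadatas, batch_documents in create_batches(
    api=client, ids=ids, documents=docs
):
    # each batch is guaranteed to fit within api.get_max_batch_size()
    collection.add(
        ids=batch_ids,
        embeddings=batch_embeddings,
        metadatas=batch_metadatas,
        documents=batch_documents,
    )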
@@ -0,0 +1,31 @@
import importlib
import multiprocessing
from typing import Optional, Sequence, List, Tuple
import numpy as np
from chromadb.api.types import URI, DataLoader, Image, URIs
from concurrent.futures import ThreadPoolExecutor


class ImageLoader(DataLoader[List[Optional[Image]]]):
    def __init__(self, max_workers: int = multiprocessing.cpu_count()) -> None:
        try:
            self._PILImage = importlib.import_module("PIL.Image")
            self._max_workers = max_workers
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

    def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
        return np.array(self._PILImage.open(uri)) if uri is not None else None

    def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
        with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
            return list(executor.map(self._load_image, uris))


class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]):
    # This is a simple pass through data loader that just returns the input data with "images"
    # flag which lets the langchain embedding function know that the data is image uris
    def __call__(self, uris: URIs) -> Tuple[str, URIs]:  # type: ignore
        return ("images", uris)
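A hedged usage sketch of ImageLoader (requires `pip install pillow`); the file paths are placeholders:

loader = ImageLoader(max_workers=4)
images = loader(["cat.png", None, "dog.jpg"])
# images[1] is None; the others are numpy arrays decoded from the files in parallel.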
@@ -0,0 +1,38 @@
import os
import random
import gc
import time


# Borrowed from https://github.com/rogerbinns/apsw/blob/master/apsw/tests.py#L224
# Used to delete sqlite files on Windows, since Windows file locking
# behaves differently to other operating systems
# This should only be used for test or non-production code, such as in reset_state.
def delete_file(name: str) -> None:
    try:
        os.remove(name)
    except Exception:
        pass

    chars = list("abcdefghijklmn")
    random.shuffle(chars)
    newname = name + "-n-" + "".join(chars)
    count = 0
    while os.path.exists(name):
        count += 1
        try:
            os.rename(name, newname)
        except Exception:
            if count > 30:
                n = list("abcdefghijklmnopqrstuvwxyz")
                random.shuffle(n)
                final_name = "".join(n)
                try:
                    os.rename(
                        name, "chroma-to-clean" + final_name + ".deletememanually"
                    )
                except Exception:
                    pass
                break
            time.sleep(0.1)
            gc.collect()
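A hedged sketch of how this helper is meant to be called; the path is a placeholder. If the file is still locked (common on Windows), the function retries by renaming it out of the way before giving up:

delete_file("./chroma/chroma.sqlite3")  # hypothetical sqlite file path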
@@ -0,0 +1,22 @@
import os


def get_directory_size(directory: str) -> int:
    """
    Calculate the total size of the directory by walking through each file.

    Parameters:
        directory (str): The path of the directory for which to calculate the size.

    Returns:
        total_size (int): The total size of the directory in bytes.
    """
    total_size = 0
    for dirpath, _, filenames in os.walk(directory):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            # skip if it is symbolic link
            if not os.path.islink(fp):
                total_size += os.path.getsize(fp)

    return total_size
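A minimal usage sketch; the directory path is a placeholder:

size_bytes = get_directory_size("./chroma")  # hypothetical persist directory
print(f"{size_bytes / (1024 * 1024):.1f} MiB")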
@@ -0,0 +1,32 @@
"""
These functions match the spec of hnswlib.
"""
from typing import Union, cast
import numpy as np
from numpy.typing import NDArray

Vector = NDArray[Union[np.int32, np.float32, np.int16, np.float16]]


def l2(x: Vector, y: Vector) -> float:
    return (np.linalg.norm(x - y) ** 2).item()


def cosine(x: Vector, y: Vector) -> float:
    # This epsilon is used to prevent division by zero, and the value is the same
    # https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238

    # We need to adapt the epsilon to the precision of the input
    NORM_EPS = 1e-30
    if x.dtype == np.float16 or y.dtype == np.float16:
        NORM_EPS = 1e-7
    return cast(
        float,
        (
            1.0 - np.dot(x, y) / ((np.linalg.norm(x) * np.linalg.norm(y)) + NORM_EPS)
        ).item(),
    )


def ip(x: Vector, y: Vector) -> float:
    return cast(float, (1.0 - np.dot(x, y)).item())
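A small worked check of the three distance functions on orthogonal 2-d vectors:

import numpy as np

a = np.array([1.0, 0.0], dtype=np.float32)
b = np.array([0.0, 1.0], dtype=np.float32)
print(l2(a, b))      # ~2.0: squared euclidean distance, per the hnswlib "l2" space
print(cosine(a, b))  # ~1.0: orthogonal vectors have cosine similarity 0
print(ip(a, b))      # 1.0: inner-product distance is 1 - dot(x, y)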
@@ -0,0 +1,285 @@
from typing import Dict, Any, Type, Set
from chromadb.api.types import (
    EmbeddingFunction,
    DefaultEmbeddingFunction,
    SparseEmbeddingFunction,
)

# Import all embedding functions
from chromadb.utils.embedding_functions.cohere_embedding_function import (
    CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
    OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
    HuggingFaceEmbeddingFunction,
    HuggingFaceEmbeddingServer,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
    SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
    GooglePalmEmbeddingFunction,
    GoogleGenerativeAiEmbeddingFunction,
    GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
    OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
    InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
    JinaEmbeddingFunction,
    JinaQueryConfig,
)
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
    VoyageAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
    OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
    RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
    Text2VecEmbeddingFunction,
)
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
    AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import (
    ChromaLangchainEmbeddingFunction,
)
from chromadb.utils.embedding_functions.baseten_embedding_function import (
    BasetenEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import (
    CloudflareWorkersAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.together_ai_embedding_function import (
    TogetherAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.mistral_embedding_function import (
    MistralEmbeddingFunction,
)
from chromadb.utils.embedding_functions.morph_embedding_function import (
    MorphEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
    HuggingFaceSparseEmbeddingFunction,
)
from chromadb.utils.embedding_functions.fastembed_sparse_embedding_function import (
    FastembedSparseEmbeddingFunction,
)
from chromadb.utils.embedding_functions.bm25_embedding_function import (
    Bm25EmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
    ChromaCloudQwenEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_cloud_splade_embedding_function import (
    ChromaCloudSpladeEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
    ChromaBm25EmbeddingFunction,
)


# Get all the class names for backward compatibility
_all_classes: Set[str] = {
    "CohereEmbeddingFunction",
    "OpenAIEmbeddingFunction",
    "HuggingFaceEmbeddingFunction",
    "HuggingFaceEmbeddingServer",
    "SentenceTransformerEmbeddingFunction",
    "GooglePalmEmbeddingFunction",
    "GoogleGenerativeAiEmbeddingFunction",
    "GoogleVertexEmbeddingFunction",
    "OllamaEmbeddingFunction",
    "InstructorEmbeddingFunction",
    "JinaEmbeddingFunction",
    "MistralEmbeddingFunction",
    "MorphEmbeddingFunction",
    "VoyageAIEmbeddingFunction",
    "ONNXMiniLM_L6_V2",
    "OpenCLIPEmbeddingFunction",
    "RoboflowEmbeddingFunction",
    "Text2VecEmbeddingFunction",
    "AmazonBedrockEmbeddingFunction",
    "ChromaLangchainEmbeddingFunction",
    "BasetenEmbeddingFunction",
    "CloudflareWorkersAIEmbeddingFunction",
    "TogetherAIEmbeddingFunction",
    "DefaultEmbeddingFunction",
    "HuggingFaceSparseEmbeddingFunction",
    "FastembedSparseEmbeddingFunction",
    "Bm25EmbeddingFunction",
    "ChromaCloudQwenEmbeddingFunction",
    "ChromaCloudSpladeEmbeddingFunction",
    "ChromaBm25EmbeddingFunction",
}


def get_builtins() -> Set[str]:
    return _all_classes


# Dictionary of supported embedding functions
known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = {  # type: ignore
    "cohere": CohereEmbeddingFunction,
    "openai": OpenAIEmbeddingFunction,
    "huggingface": HuggingFaceEmbeddingFunction,
    "huggingface_server": HuggingFaceEmbeddingServer,
    "sentence_transformer": SentenceTransformerEmbeddingFunction,
    "google_palm": GooglePalmEmbeddingFunction,
    "google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
    "google_vertex": GoogleVertexEmbeddingFunction,
    "ollama": OllamaEmbeddingFunction,
    "instructor": InstructorEmbeddingFunction,
    "jina": JinaEmbeddingFunction,
    "mistral": MistralEmbeddingFunction,
    "morph": MorphEmbeddingFunction,
    "voyageai": VoyageAIEmbeddingFunction,
    "onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
    "open_clip": OpenCLIPEmbeddingFunction,
    "roboflow": RoboflowEmbeddingFunction,
    "text2vec": Text2VecEmbeddingFunction,
    "amazon_bedrock": AmazonBedrockEmbeddingFunction,
    "chroma_langchain": ChromaLangchainEmbeddingFunction,
    "baseten": BasetenEmbeddingFunction,
    "default": DefaultEmbeddingFunction,
    "cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction,
    "together_ai": TogetherAIEmbeddingFunction,
    "chroma-cloud-qwen": ChromaCloudQwenEmbeddingFunction,
}

sparse_known_embedding_functions: Dict[str, Type[SparseEmbeddingFunction]] = {  # type: ignore
    "huggingface_sparse": HuggingFaceSparseEmbeddingFunction,
    "fastembed_sparse": FastembedSparseEmbeddingFunction,
    "bm25": Bm25EmbeddingFunction,
    "chroma-cloud-splade": ChromaCloudSpladeEmbeddingFunction,
    "chroma_bm25": ChromaBm25EmbeddingFunction,
}


def register_embedding_function(ef_class=None):  # type: ignore
    """Register a custom embedding function.

    Can be used as a decorator:
        @register_embedding_function
        class MyEmbedding(EmbeddingFunction):
            @classmethod
            def name(cls): return "my_embedding"

    Or directly:
        register_embedding_function(MyEmbedding)

    Args:
        ef_class: The embedding function class to register.
    """

    def _register(cls):  # type: ignore
        try:
            name = cls.name()
            known_embedding_functions[name] = cls
        except Exception as e:
            raise ValueError(f"Failed to register embedding function: {e}")
        return cls  # Return the class unchanged

    # If called with a class, register it immediately
    if ef_class is not None:
        return _register(ef_class)  # type: ignore

    # If called without arguments, return a decorator
    return _register


def register_sparse_embedding_function(ef_class=None):  # type: ignore
    """Register a custom sparse embedding function.

    Can be used as a decorator:
        @register_sparse_embedding_function
        class MySparseEmbeddingFunction(SparseEmbeddingFunction):
            @classmethod
            def name(cls): return "my_sparse_embedding"
    """

    def _register(cls):  # type: ignore
        try:
            name = cls.name()
            sparse_known_embedding_functions[name] = cls
        except Exception as e:
            raise ValueError(f"Failed to register sparse embedding function: {e}")
        return cls  # Return the class unchanged

    if ef_class is not None:
        return _register(ef_class)  # type: ignore

    return _register


# Function to convert config to embedding function
def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:  # type: ignore
    """Convert a config dictionary to an embedding function.

    Args:
        config: The config dictionary.

    Returns:
        The embedding function.
    """
    if "name" not in config:
        raise ValueError("Config must contain a 'name' field.")

    name = config["name"]
    if name not in known_embedding_functions:
        raise ValueError(f"Unsupported embedding function: {name}")

    ef_config = config.get("config", {})

    if known_embedding_functions[name] is None:
        raise ValueError(f"Unsupported embedding function: {name}")

    return known_embedding_functions[name].build_from_config(ef_config)


__all__ = [
    "EmbeddingFunction",
    "DefaultEmbeddingFunction",
    "CohereEmbeddingFunction",
    "OpenAIEmbeddingFunction",
    "BasetenEmbeddingFunction",
    "CloudflareWorkersAIEmbeddingFunction",
    "HuggingFaceEmbeddingFunction",
    "HuggingFaceEmbeddingServer",
    "SentenceTransformerEmbeddingFunction",
    "GooglePalmEmbeddingFunction",
    "GoogleGenerativeAiEmbeddingFunction",
    "GoogleVertexEmbeddingFunction",
    "OllamaEmbeddingFunction",
    "InstructorEmbeddingFunction",
    "JinaEmbeddingFunction",
    "JinaQueryConfig",
    "MistralEmbeddingFunction",
    "MorphEmbeddingFunction",
    "VoyageAIEmbeddingFunction",
    "ONNXMiniLM_L6_V2",
    "OpenCLIPEmbeddingFunction",
    "RoboflowEmbeddingFunction",
    "Text2VecEmbeddingFunction",
    "AmazonBedrockEmbeddingFunction",
    "ChromaLangchainEmbeddingFunction",
    "TogetherAIEmbeddingFunction",
    "HuggingFaceSparseEmbeddingFunction",
    "FastembedSparseEmbeddingFunction",
    "Bm25EmbeddingFunction",
    "ChromaCloudQwenEmbeddingFunction",
    "ChromaCloudSpladeEmbeddingFunction",
    "ChromaBm25EmbeddingFunction",
    "register_embedding_function",
    "config_to_embedding_function",
    "known_embedding_functions",
]
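To make the registry mechanics concrete, a small hedged sketch; the class and config are hypothetical, and a real subclass would also implement __call__ and build_from_config:

@register_embedding_function
class MyEmbeddingFunction(EmbeddingFunction):
    @classmethod
    def name(cls):
        return "my_embedding"

assert "my_embedding" in known_embedding_functions
# config_to_embedding_function would now accept
#   {"name": "my_embedding", "config": {...}}
# and dispatch to MyEmbeddingFunction.build_from_config({...}).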
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,138 @@
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction
from typing import Dict, Any, cast
import json
import numpy as np


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using Amazon Bedrock.
    """

    def __init__(
        self,
        session: Any,
        model_name: str = "amazon.titan-embed-text-v1",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use. You need to have boto3
                installed, `pip install boto3`. Access & secret key are not supported.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1"
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """

        self.model_name = model_name
        # check kwargs are primitives only
        for key, value in kwargs.items():
            if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
                raise ValueError(f"Keyword argument {key} is not a primitive type")
        self.kwargs = kwargs

        # Store the session for serialization
        self._session_args = {}
        if hasattr(session, "region_name") and session.region_name:
            self._session_args["region_name"] = session.region_name
        if hasattr(session, "profile_name") and session.profile_name:
            self._session_args["profile_name"] = session.profile_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        accept = "application/json"
        content_type = "application/json"
        embeddings = []

        for text in input:
            input_body = {"inputText": text}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self.model_name,
                accept=accept,
                contentType=content_type,
            )
            response_body = json.loads(response.get("body").read())
            embedding = response_body.get("embedding")
            embeddings.append(np.array(embedding, dtype=np.float32))

        # Convert to the expected Embeddings type
        return cast(Embeddings, embeddings)

    @staticmethod
    def name() -> str:
        return "amazon_bedrock"

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        try:
            import boto3
        except ImportError:
            raise ValueError(
                "The boto3 python package is not installed. Please install it with `pip install boto3`"
            )

        model_name = config.get("model_name")
        session_args = config.get("session_args")
        if model_name is None:
            assert False, "This code should not be reached"
        kwargs = config.get("kwargs", {})

        if session_args is None:
            session = boto3.Session()
        else:
            session = boto3.Session(**session_args)

        return AmazonBedrockEmbeddingFunction(
            session=session, model_name=model_name, **kwargs
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "session_args": self._session_args,
            "kwargs": self.kwargs,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "amazon_bedrock")
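A hedged sketch of the config round trip; the region is a placeholder and an AWS credential chain is assumed to be configured:

import boto3

ef = AmazonBedrockEmbeddingFunction(session=boto3.Session(region_name="us-east-1"))
cfg = ef.get_config()  # {"model_name": ..., "session_args": ..., "kwargs": ...}
ef2 = AmazonBedrockEmbeddingFunction.build_from_config(cfg)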
@@ -0,0 +1,98 @@
import os
from chromadb.utils.embedding_functions.openai_embedding_function import (
    OpenAIEmbeddingFunction,
)
from typing import Dict, Any, Optional, List
from chromadb.api.types import Space
import warnings


class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
    def __init__(
        self,
        api_key: Optional[str],
        api_base: str,
        api_key_env_var: str = "CHROMA_BASETEN_API_KEY",
    ):
        """
        Initialize the BasetenEmbeddingFunction.
        Args:
            api_key (str, optional): The API key for your Baseten account
            api_base (str, required): The Baseten URL of the deployment
            api_key_env_var (str, optional): The environment variable to use for the API key. Defaults to "CHROMA_BASETEN_API_KEY".
        """
        try:
            import openai
        except ImportError:
            raise ValueError(
                "The openai python package is not installed. Please install it with `pip install openai`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.api_key_env_var = api_key_env_var
        # Prioritize api_key argument, then environment variable
        resolved_api_key = api_key or os.getenv(api_key_env_var)
        if not resolved_api_key:
            raise ValueError(
                f"API key not provided and {api_key_env_var} environment variable is not set."
            )
        self.api_key = resolved_api_key
        if not api_base:
            raise ValueError("The api_base argument must be provided.")
        self.api_base = api_base
        self.model_name = "baseten-embedding-model"
        self.dimensions = None

        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)

    @staticmethod
    def name() -> str:
        return "baseten"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    def get_config(self) -> Dict[str, Any]:
        return {"api_base": self.api_base, "api_key_env_var": self.api_key_env_var}

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "BasetenEmbeddingFunction":
        """
        Build the BasetenEmbeddingFunction from a configuration dictionary.

        Args:
            config (Dict[str, Any]): A dictionary containing the configuration parameters.
                Expected keys: 'api_key', 'api_base', 'api_key_env_var'.

        Returns:
            BasetenEmbeddingFunction: An instance of BasetenEmbeddingFunction.
        """
        api_key_env_var = config.get("api_key_env_var")
        api_base = config.get("api_base")
        if api_key_env_var is None or api_base is None:
            raise ValueError(
                "Missing 'api_key_env_var' or 'api_base' in configuration for BasetenEmbeddingFunction."
            )

        # Note: We rely on the __init__ method to handle a potentially missing api_key
        # by checking the environment variable when the config value is None.
        # api_base was already validated above, so no further check is needed here.

        return BasetenEmbeddingFunction(
            api_key=None,  # Pass None if not in config, __init__ will check env var
            api_base=api_base,
            api_key_env_var=api_key_env_var,
        )
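A hedged usage sketch; the key and deployment URL are placeholders for your own Baseten values:

import os

os.environ.setdefault("CHROMA_BASETEN_API_KEY", "sk-...")  # placeholder key
ef = BasetenEmbeddingFunction(
    api_key=None,  # resolved from CHROMA_BASETEN_API_KEY
    api_base="https://model-abc123.api.baseten.co/sync/v1",  # placeholder URL
)
embeddings = ef(["Hello, world!"])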
@@ -0,0 +1,234 @@
from chromadb.api.types import (
    SparseEmbeddingFunction,
    SparseVectors,
    Documents,
)
from typing import Dict, Any, TypedDict, Optional
from typing import cast, Literal
import warnings
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector

TaskType = Literal["document", "query"]


class Bm25EmbeddingFunctionQueryConfig(TypedDict):
    task: TaskType


class Bm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
    def __init__(
        self,
        avg_len: Optional[float] = None,
        task: Optional[TaskType] = "document",
        cache_dir: Optional[str] = None,
        k: Optional[float] = None,
        b: Optional[float] = None,
        language: Optional[str] = None,
        token_max_length: Optional[int] = None,
        disable_stemmer: Optional[bool] = None,
        specific_model_path: Optional[str] = None,
        query_config: Optional[Bm25EmbeddingFunctionQueryConfig] = None,
        **kwargs: Any,
    ):
        """Initialize Bm25EmbeddingFunction.

        Args:
            avg_len (float, optional): The average length of the documents in the corpus.
            task (str, optional): Task to perform, can be "document" or "query"
            cache_dir (str, optional): The path to the cache directory.
            k (float, optional): The k parameter in the BM25 formula. Defines the saturation of the term frequency.
            b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
            language (str, optional): Specifies the language for the stemmer.
            token_max_length (int, optional): The maximum length of the tokens.
            disable_stemmer (bool, optional): Disable the stemmer.
            specific_model_path (str, optional): The path to the specific model.
            query_config (dict, optional): Configuration for the query, can be "task"
            **kwargs: Additional arguments to pass to the Bm25 model.
        """
        warnings.warn(
            "Bm25EmbeddingFunction is deprecated. Please use ChromaBm25EmbeddingFunction instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        try:
            from fastembed.sparse.bm25 import Bm25
        except ImportError:
            raise ValueError(
                "The fastembed python package is not installed. Please install it with `pip install fastembed`"
            )

        self.task = task
        self.query_config = query_config
        self.cache_dir = cache_dir
        self.k = k
        self.b = b
        self.avg_len = avg_len
        self.language = language
        self.token_max_length = token_max_length
        self.disable_stemmer = disable_stemmer
        self.specific_model_path = specific_model_path
        for key, value in kwargs.items():
            if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
                raise ValueError(f"Keyword argument {key} is not a primitive type")
        self.kwargs = kwargs
        bm25_kwargs = {
            "model_name": "Qdrant/bm25",
        }
        optional_params = {
            "cache_dir": cache_dir,
            "k": k,
            "b": b,
            "avg_len": avg_len,
            "language": language,
            "token_max_length": token_max_length,
            "disable_stemmer": disable_stemmer,
            "specific_model_path": specific_model_path,
        }
        for key, value in optional_params.items():
            if value is not None:
                bm25_kwargs[key] = value
        bm25_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
        self._model = Bm25(**bm25_kwargs)

    def __call__(self, input: Documents) -> SparseVectors:
        """Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        try:
            from fastembed.sparse.bm25 import Bm25
        except ImportError:
            raise ValueError(
                "The fastembed python package is not installed. Please install it with `pip install fastembed`"
            )
        model = cast(Bm25, self._model)
        if self.task == "document":
            embeddings = model.embed(
                list(input),
            )
        elif self.task == "query":
            embeddings = model.query_embed(
                list(input),
            )
        else:
            raise ValueError(f"Invalid task: {self.task}")

        sparse_vectors: SparseVectors = []

        for vec in embeddings:
            sparse_vectors.append(
                normalize_sparse_vector(
                    indices=vec.indices.tolist(), values=vec.values.tolist()
                )
            )

        return sparse_vectors

    def embed_query(self, input: Documents) -> SparseVectors:
        try:
            from fastembed.sparse.bm25 import Bm25
        except ImportError:
            raise ValueError(
                "The fastembed python package is not installed. Please install it with `pip install fastembed`"
            )
        model = cast(Bm25, self._model)
        if self.query_config is not None:
            task = self.query_config.get("task")
            if task == "document":
                embeddings = model.embed(
                    list(input),
                )
            elif task == "query":
                embeddings = model.query_embed(
                    list(input),
                )
            else:
                raise ValueError(f"Invalid task: {task}")

            sparse_vectors: SparseVectors = []

            for vec in embeddings:
                sparse_vectors.append(
                    normalize_sparse_vector(
                        indices=vec.indices.tolist(), values=vec.values.tolist()
                    )
                )

            return sparse_vectors

        else:
            return self.__call__(input)

    @staticmethod
    def name() -> str:
        return "bm25"

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "SparseEmbeddingFunction[Documents]":
        task = config.get("task")
        query_config = config.get("query_config")
        cache_dir = config.get("cache_dir")
        k = config.get("k")
        b = config.get("b")
        avg_len = config.get("avg_len")
        language = config.get("language")
        token_max_length = config.get("token_max_length")
        disable_stemmer = config.get("disable_stemmer")
        specific_model_path = config.get("specific_model_path")
        kwargs = config.get("kwargs", {})

        return Bm25EmbeddingFunction(
            task=task,
            query_config=query_config,
            cache_dir=cache_dir,
            k=k,
            b=b,
            avg_len=avg_len,
            language=language,
            token_max_length=token_max_length,
            disable_stemmer=disable_stemmer,
            specific_model_path=specific_model_path,
            **kwargs,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "task": self.task,
            "query_config": self.query_config,
            "cache_dir": self.cache_dir,
            "k": self.k,
            "b": self.b,
            "avg_len": self.avg_len,
            "language": self.language,
            "token_max_length": self.token_max_length,
            "disable_stemmer": self.disable_stemmer,
            "specific_model_path": self.specific_model_path,
            "kwargs": self.kwargs,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # Users should be able to change the path if needed, so we should not validate that.
        # e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
        return

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "bm25")
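A hedged usage sketch (requires `pip install fastembed`); note the class itself warns that it is deprecated in favor of ChromaBm25EmbeddingFunction:

ef = Bm25EmbeddingFunction(task="document", query_config={"task": "query"})
doc_vectors = ef(["chroma is a vector database"])        # BM25 document weighting
query_vectors = ef.embed_query(["what is a vector db"])  # query-side weighting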
@@ -0,0 +1,135 @@
from __future__ import annotations

from collections import Counter
from typing import Any, Dict, Iterable, List, Optional, TypedDict

from chromadb.api.types import Documents, SparseEmbeddingFunction, SparseVectors
from chromadb.base_types import SparseVector
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.embedding_functions.schemas.bm25_tokenizer import (
    Bm25Tokenizer,
    DEFAULT_CHROMA_BM25_STOPWORDS as _DEFAULT_STOPWORDS,
    get_english_stemmer,
    Murmur3AbsHasher,
)

NAME = "chroma_bm25"

DEFAULT_K = 1.2
DEFAULT_B = 0.75
DEFAULT_AVG_DOC_LENGTH = 256.0
DEFAULT_TOKEN_MAX_LENGTH = 40

DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)


class ChromaBm25Config(TypedDict, total=False):
    k: float
    b: float
    avg_doc_length: float
    token_max_length: int
    stopwords: List[str]


class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
    def __init__(
        self,
        k: float = DEFAULT_K,
        b: float = DEFAULT_B,
        avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
        token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
        stopwords: Optional[Iterable[str]] = None,
    ) -> None:
        """Initialize the BM25 sparse embedding function."""

        self.k = float(k)
        self.b = float(b)
        self.avg_doc_length = float(avg_doc_length)
        self.token_max_length = int(token_max_length)

        if stopwords is not None:
            self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
            stopword_list: Iterable[str] = self.stopwords
        else:
            self.stopwords = None
            stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS

        stemmer = get_english_stemmer()
        self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
        self._hasher = Murmur3AbsHasher()

    def _encode(self, text: str) -> SparseVector:
        tokens = self._tokenizer.tokenize(text)

        if not tokens:
            return SparseVector(indices=[], values=[])

        doc_len = float(len(tokens))
        counts = Counter(self._hasher.hash(token) for token in tokens)

        indices = sorted(counts.keys())
        values: List[float] = []
        for idx in indices:
            tf = float(counts[idx])
            denominator = tf + self.k * (
                1 - self.b + (self.b * doc_len) / self.avg_doc_length
            )
            score = tf * (self.k + 1) / denominator
            values.append(score)

        return SparseVector(indices=indices, values=values)

    def __call__(self, input: Documents) -> SparseVectors:
        sparse_vectors: SparseVectors = []

        if not input:
            return sparse_vectors

        for document in input:
            sparse_vectors.append(self._encode(document))

        return sparse_vectors

    def embed_query(self, input: Documents) -> SparseVectors:
        return self.__call__(input)

    @staticmethod
    def name() -> str:
        return NAME

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "SparseEmbeddingFunction[Documents]":
        return ChromaBm25EmbeddingFunction(
            k=config.get("k", DEFAULT_K),
            b=config.get("b", DEFAULT_B),
            avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
            token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
            stopwords=config.get("stopwords"),
        )

    def get_config(self) -> Dict[str, Any]:
        config: Dict[str, Any] = {
            "k": self.k,
            "b": self.b,
            "avg_doc_length": self.avg_doc_length,
            "token_max_length": self.token_max_length,
        }

        if self.stopwords is not None:
            config["stopwords"] = list(self.stopwords)

        return config

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
        for key in new_config:
            if key not in mutable_keys:
                raise ValueError(f"Updating '{key}' is not supported for {NAME}")

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        validate_config_schema(config, NAME)
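A minimal usage sketch; this encoder runs entirely locally with no network calls:

ef = ChromaBm25EmbeddingFunction()  # defaults: k=1.2, b=0.75, avg_doc_length=256.0
[vec] = ef(["the quick brown fox jumps over the lazy dog"])
# vec.indices are murmur3 hashes of the stemmed, stopword-filtered tokens;
# each value follows the BM25 term weight computed in _encode above:
#   tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_doc_length))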
@@ -0,0 +1,213 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Union
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from enum import Enum


class ChromaCloudQwenEmbeddingModel(Enum):
    QWEN3_EMBEDDING_0p6B = "Qwen/Qwen3-Embedding-0.6B"


class ChromaCloudQwenEmbeddingTarget(Enum):
    DOCUMENTS = "documents"
    QUERY = "query"


ChromaCloudQwenEmbeddingInstructions = Dict[
    str, Dict[ChromaCloudQwenEmbeddingTarget, str]
]

CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS: ChromaCloudQwenEmbeddingInstructions = {
    "nl_to_code": {
        ChromaCloudQwenEmbeddingTarget.DOCUMENTS: "",
        # Taken from https://github.com/QwenLM/Qwen3-Embedding/blob/main/evaluation/task_prompts.json
        ChromaCloudQwenEmbeddingTarget.QUERY: "Given a question about coding, retrieval code or passage that can solve user's question",
    }
}


class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        model: ChromaCloudQwenEmbeddingModel,
        task: str,
        instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
        api_key_env_var: str = "CHROMA_API_KEY",
    ):
        """
        Initialize the ChromaCloudQwenEmbeddingFunction.

        Args:
            model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
            task (str): The task for which embeddings are being generated.
            instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
                custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
            api_key_env_var (str, optional): Environment variable name that contains your API key.
                Defaults to "CHROMA_API_KEY".
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )

        self.api_key_env_var = api_key_env_var
        self.api_key = os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model = model
        self.task = task
        self.instructions = instructions

        self._api_url = "https://embed.trychroma.com"
        self._session = httpx.Client()
        self._session.headers.update(
            {
                "x-chroma-token": self.api_key,
                "x-chroma-embedding-model": self.model.value,
            }
        )

    def _parse_response(self, response: Any) -> Embeddings:
        """
        Convert the response from the Chroma Embedding API to a list of numpy arrays.

        Args:
            response (Any): The response from the Chroma Embedding API.

        Returns:
            Embeddings: A list of numpy arrays representing the embeddings.
        """
        if "embeddings" not in response:
            raise RuntimeError(response.get("error", "Unknown error"))

        embeddings: List[List[float]] = response["embeddings"]

        return np.array(embeddings, dtype=np.float32)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        if not input:
            return []

        payload: Dict[str, Union[str, Documents]] = {
            "instructions": self.instructions[self.task][
                ChromaCloudQwenEmbeddingTarget.DOCUMENTS
            ],
            "texts": input,
        }

        response = self._session.post(self._api_url, json=payload, timeout=60).json()

        return self._parse_response(response)

    def embed_query(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a query input.
        """
        if not input:
            return []

        payload: Dict[str, Union[str, Documents]] = {
            "instructions": self.instructions[self.task][
                ChromaCloudQwenEmbeddingTarget.QUERY
            ],
            "texts": input,
        }

        response = self._session.post(self._api_url, json=payload, timeout=60).json()

        return self._parse_response(response)

    @staticmethod
    def name() -> str:
        return "chroma-cloud-qwen"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        model = config.get("model")
        task = config.get("task")
        instructions = config.get("instructions")
        api_key_env_var = config.get("api_key_env_var")

        if model is None or task is None:
            assert False, "Config is missing a required field"

        # Deserialize instructions dict from string keys to enum keys
        deserialized_instructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS
        if instructions is not None:
            deserialized_instructions = {}
            for task_key, targets in instructions.items():
                deserialized_instructions[task_key] = {}
                for target_key, instruction in targets.items():
                    # Convert string key to enum
                    target_enum = ChromaCloudQwenEmbeddingTarget(target_key)
                    deserialized_instructions[task_key][target_enum] = instruction

        return ChromaCloudQwenEmbeddingFunction(
            model=ChromaCloudQwenEmbeddingModel(model),
            task=task,
            instructions=deserialized_instructions,
            api_key_env_var=api_key_env_var or "CHROMA_API_KEY",
        )

    def get_config(self) -> Dict[str, Any]:
        # Serialize instructions dict with enum keys to string keys for JSON compatibility
        serialized_instructions = {
            task: {target.value: instruction for target, instruction in targets.items()}
            for task, targets in self.instructions.items()
        }
        return {
            "api_key_env_var": self.api_key_env_var,
            "model": self.model.value,
            "task": self.task,
            "instructions": serialized_instructions,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model" in new_config:
            raise ValueError(
                "The model cannot be changed after the embedding function has been initialized."
            )
        elif "task" in new_config:
            raise ValueError(
                "The task cannot be changed after the embedding function has been initialized."
            )
        elif "instructions" in new_config:
            raise ValueError(
                "The instructions cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "chroma-cloud-qwen")
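A hedged usage sketch; the key is a placeholder and "nl_to_code" is the task defined in CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS above:

import os

os.environ.setdefault("CHROMA_API_KEY", "ck-...")  # placeholder key
ef = ChromaCloudQwenEmbeddingFunction(
    model=ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
    task="nl_to_code",  # must be a key of the instructions dict
)
doc_embeddings = ef(["def add(a, b):\n    return a + b"])
query_embeddings = ef.embed_query(["function that adds two numbers"])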
@@ -0,0 +1,159 @@
from chromadb.api.types import (
    SparseEmbeddingFunction,
    SparseVectors,
    Documents,
)
from typing import Dict, Any
from enum import Enum
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
from chromadb.base_types import SparseVector
import os
from typing import Union


class ChromaCloudSpladeEmbeddingModel(Enum):
    SPLADE_PP_EN_V1 = "prithivida/Splade_PP_en_v1"


class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
    def __init__(
        self,
        api_key_env_var: str = "CHROMA_API_KEY",
        model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
    ):
        """
        Initialize the ChromaCloudSpladeEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key.
                Defaults to "CHROMA_API_KEY".
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(
                f"API key not found in environment variable {self.api_key_env_var}"
            )
        self.model = model
        self._api_url = "https://embed.trychroma.com/embed_sparse"
        self._session = httpx.Client()
        self._session.headers.update(
            {
                "x-chroma-token": self.api_key,
                "x-chroma-embedding-model": self.model.value,
            }
        )

    def __del__(self) -> None:
        """
        Cleanup the HTTP client session when the object is destroyed.
        """
        if hasattr(self, "_session"):
            self._session.close()

    def close(self) -> None:
        """
        Explicitly close the HTTP client session.
        Call this method when you're done using the embedding function.
        """
        if hasattr(self, "_session"):
            self._session.close()

    def __call__(self, input: Documents) -> SparseVectors:
        """
        Generate embeddings for the given documents.

        Args:
            input (Documents): The documents to generate embeddings for.
        """
        if not input:
            return []

        payload: Dict[str, Union[str, Documents]] = {
            "texts": list(input),
            "task": "",
            "target": "",
        }

        try:
            import httpx

            response = self._session.post(self._api_url, json=payload, timeout=60)
            response.raise_for_status()
            json_response = response.json()
            return self._parse_response(json_response)
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"Failed to get embeddings from Chroma Cloud API: HTTP {e.response.status_code} - {e.response.text}"
            )
        except httpx.TimeoutException:
            raise RuntimeError("Request to Chroma Cloud API timed out after 60 seconds")
        except httpx.HTTPError as e:
            raise RuntimeError(f"Failed to get embeddings from Chroma Cloud API: {e}")
        except Exception as e:
            raise RuntimeError(f"Unexpected error calling Chroma Cloud API: {e}")

    def _parse_response(self, response: Any) -> SparseVectors:
        """
        Parse the response from the Chroma Cloud Sparse Embedding API.
        """
        raw_embeddings = response["embeddings"]

        # Normalize each sparse vector (sort indices and validate)
        normalized_vectors: SparseVectors = []
        for emb in raw_embeddings:
            # Handle both dict format and SparseVector format
            if isinstance(emb, dict):
                indices = emb.get("indices", [])
                values = emb.get("values", [])
            else:
                # Already a SparseVector, extract its data
                indices = emb.indices
                values = emb.values

            normalized_vectors.append(
                normalize_sparse_vector(indices=indices, values=values)
            )

        return normalized_vectors

    @staticmethod
    def name() -> str:
        return "chroma-cloud-splade"

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "SparseEmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model = config.get("model")
        if model is None:
            raise ValueError("model must be provided in config")
        if not api_key_env_var:
            raise ValueError("api_key_env_var must be provided in config")
        return ChromaCloudSpladeEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model=ChromaCloudSpladeEmbeddingModel(model),
        )

    def get_config(self) -> Dict[str, Any]:
        return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model" in new_config:
            raise ValueError(
                "model cannot be changed after the embedding function has been initialized"
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        validate_config_schema(config, "chroma-cloud-splade")
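A hedged usage sketch; the key is a placeholder:

import os

os.environ.setdefault("CHROMA_API_KEY", "ck-...")  # placeholder key
ef = ChromaCloudSpladeEmbeddingFunction()
vectors = ef(["sparse lexical expansion with SPLADE"])
ef.close()  # release the underlying httpx session explicitly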
@@ -0,0 +1,171 @@
from chromadb.api.types import (
    Documents,
    Embeddings,
    Images,
    Embeddable,
    EmbeddingFunction,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Union, cast, Sequence
import numpy as np


def create_langchain_embedding(
    langchain_embedding_fn: Any,
) -> "ChromaLangchainEmbeddingFunction":
    """
    Create a ChromaLangchainEmbeddingFunction from a langchain embedding function.

    Args:
        langchain_embedding_fn: The langchain embedding function to use.

    Returns:
        A ChromaLangchainEmbeddingFunction that wraps the langchain embedding function.
    """

    return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embedding_fn)


class ChromaLangchainEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """
    This class is used as a bridge between langchain embedding functions and custom chroma embedding functions.
    """

    def __init__(self, embedding_function: Any) -> None:
        """
        Initialize the ChromaLangchainEmbeddingFunction.

        Args:
            embedding_function: The embedding function implementing Embeddings from langchain_core.
        """
        try:
            import langchain_core.embeddings

            LangchainEmbeddings = langchain_core.embeddings.Embeddings
        except ImportError:
            raise ValueError(
                "The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
            )

        if not isinstance(embedding_function, LangchainEmbeddings):
            raise ValueError(
                "The embedding_function must implement the Embeddings interface from langchain_core."
            )

        self.embedding_function = embedding_function

        # Store the class name for serialization
        self._embedding_function_class = embedding_function.__class__.__name__

    def embed_documents(self, documents: Sequence[str]) -> List[List[float]]:
        """
        Embed documents using the langchain embedding function.

        Args:
            documents: The documents to embed.

        Returns:
            The embeddings for the documents.
        """
        return cast(
            List[List[float]], self.embedding_function.embed_documents(list(documents))
        )

    def embed_query(self, query: str) -> List[float]:
        """
        Embed a query using the langchain embedding function.

        Args:
            query: The query to embed.

        Returns:
            The embedding for the query.
        """
        return cast(List[float], self.embedding_function.embed_query(query))

    def embed_image(self, uris: List[str]) -> List[List[float]]:
        """
        Embed images using the langchain embedding function.

        Args:
            uris: The URIs of the images to embed.

        Returns:
            The embeddings for the images.
        """
        if hasattr(self.embedding_function, "embed_image"):
            return cast(List[List[float]], self.embedding_function.embed_image(uris))
        else:
            raise ValueError(
                "The provided embedding function does not support image embeddings."
            )

    def __call__(self, input: Union[Documents, Images]) -> Embeddings:
        """
        Get the embeddings for a list of texts or images.

        Args:
            input: A list of texts or images to get embeddings for.
                Images should be provided as a list of URIs passed through the langchain data loader.

        Returns:
            The embeddings for the texts or images.

        Example:
            >>> from langchain_openai import OpenAIEmbeddings
            >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large"))
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = langchain_embedding(texts)
        """
        # Due to langchain quirks, the data loader returns a tuple if the input is URIs of images
        if isinstance(input, tuple) and len(input) == 2 and input[0] == "images":
            embeddings = self.embed_image(list(input[1]))
        else:
            # Cast to Sequence[str] to satisfy the type checker
            embeddings = self.embed_documents(cast(Sequence[str], input))

        # Convert to numpy arrays
        return [np.array(embedding, dtype=np.float32) for embedding in embeddings]

    @staticmethod
    def name() -> str:
        return "langchain"

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "EmbeddingFunction[Union[Documents, Images]]":
        # This is a placeholder implementation since we can't easily serialize and deserialize
        # langchain embedding functions. Users will need to recreate the langchain embedding function
        # and pass it to create_langchain_embedding.
        raise NotImplementedError(
            "Building a ChromaLangchainEmbeddingFunction from config is not supported. "
            "Please recreate the langchain embedding function and pass it to create_langchain_embedding."
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "embedding_function_class": self._embedding_function_class,
            "note": "This is a placeholder config. You will need to recreate the langchain embedding function.",
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        raise NotImplementedError(
            "Updating a ChromaLangchainEmbeddingFunction config is not supported. "
            "Please recreate the langchain embedding function and pass it to create_langchain_embedding."
        )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "chroma_langchain")
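
# Usage sketch (illustrative, not part of the file above): wrapping a langchain
# embedding object. Assumes the optional `langchain-openai` package is installed
# and OPENAI_API_KEY is set; any langchain_core Embeddings implementation works.
#
#     from langchain_openai import OpenAIEmbeddings
#
#     langchain_ef = create_langchain_embedding(OpenAIEmbeddings(model="text-embedding-3-small"))
#     vectors = langchain_ef(["Hello, world!", "How are you?"])  # list of float32 numpy arrays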
@@ -0,0 +1,152 @@
from chromadb.api.types import (
    Embeddings,
    Documents,
    EmbeddingFunction,
    Space,
)
from typing import List, Dict, Any, Optional
import os
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import cast
import warnings

BASE_URL = "https://api.cloudflare.com/client/v4/accounts"
GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1"


class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the Cloudflare Workers AI API.
    It requires an API key and a model name.
    """

    def __init__(
        self,
        model_name: str,
        account_id: str,
        api_key: Optional[str] = None,
        api_key_env_var: str = "CHROMA_CLOUDFLARE_API_KEY",
        gateway_id: Optional[str] = None,
    ):
        """
        Initialize the CloudflareWorkersAIEmbeddingFunction. See the docs for supported models here:
        https://developers.cloudflare.com/workers-ai/models/

        Args:
            model_name: The name of the model to use for text embeddings.
            account_id: The account ID for the Cloudflare Workers AI API.
            api_key: The API key for the Cloudflare Workers AI API.
            api_key_env_var: The environment variable name for the Cloudflare Workers AI API key.
            gateway_id: The ID of a Cloudflare AI Gateway to route requests through, if provided.
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.model_name = model_name
        self.account_id = account_id
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        self.gateway_id = gateway_id

        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        if self.gateway_id:
            self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"
        else:
            self._api_url = f"{BASE_URL}/{self.account_id}/ai/run/{self.model_name}"

        self._session = httpx.Client()
        self._session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        if not all(isinstance(item, str) for item in input):
            raise ValueError(
                "Cloudflare Workers AI only supports text documents, not images"
            )

        payload: Dict[str, Any] = {
            "text": input,
        }

        resp = self._session.post(self._api_url, json=payload).json()

        # A successful response contains resp["result"]["data"]; anything else is
        # surfaced as an error. (Using `or` here so a missing "result" key does not
        # raise a KeyError before we can report the API's error detail.)
        if "result" not in resp or "data" not in resp["result"]:
            raise RuntimeError(resp.get("detail", "Unknown error"))

        return cast(Embeddings, resp["result"]["data"])

    @staticmethod
    def name() -> str:
        return "cloudflare_workers_ai"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        account_id = config.get("account_id")
        gateway_id = config.get("gateway_id", None)
        if api_key_env_var is None or model_name is None or account_id is None:
            assert False, "This code should not be reached"

        return CloudflareWorkersAIEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            account_id=account_id,
            gateway_id=gateway_id,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "account_id": self.account_id,
            "gateway_id": self.gateway_id,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "cloudflare_workers_ai")
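
# Usage sketch (illustrative, not part of the file above). The account ID is a
# placeholder for your own value, CHROMA_CLOUDFLARE_API_KEY is assumed to be set,
# and the model name is one commonly listed in the Workers AI model catalog.
#
#     cf_ef = CloudflareWorkersAIEmbeddingFunction(
#         model_name="@cf/baai/bge-base-en-v1.5",
#         account_id="<your-account-id>",
#     )
#     embeddings = cf_ef(["Hello, world!"])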
@@ -0,0 +1,176 @@
from chromadb.api.types import (
    Embeddings,
    Embeddable,
    EmbeddingFunction,
    Space,
    is_image,
    is_document,
)
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import base64
import io
import importlib
import warnings


class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "large",
        api_key_env_var: str = "CHROMA_COHERE_API_KEY",
    ):
        try:
            import cohere
        except ImportError:
            raise ValueError(
                "The cohere python package is not installed. Please install it with `pip install cohere`"
            )

        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name

        self.client = cohere.Client(self.api_key)

    def __call__(self, input: Embeddable) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents or images to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """

        # Cohere supports both text and images. If all inputs are texts, return text embeddings.
        if all(is_document(item) for item in input):
            return [
                np.array(embeddings, dtype=np.float32)
                for embeddings in self.client.embed(
                    texts=[str(item) for item in input],
                    model=self.model_name,
                    input_type="search_document",
                ).embeddings
            ]

        elif all(is_image(item) for item in input):
            base64_images = []
            for image_np in input:
                if not isinstance(image_np, np.ndarray):
                    raise ValueError(
                        f"Expected image input to be a numpy array, got {type(image_np)}"
                    )

                try:
                    pil_image = self._PILImage.fromarray(image_np)

                    buffer = io.BytesIO()
                    pil_image.save(buffer, format="PNG")
                    img_bytes = buffer.getvalue()

                    # Encode bytes to base64 string
                    base64_string = base64.b64encode(img_bytes).decode("utf-8")

                    data_uri = f"data:image/png;base64,{base64_string}"
                    base64_images.append(data_uri)

                except Exception as e:
                    raise ValueError(
                        f"Failed to convert image numpy array to base64 data URI: {e}"
                    ) from e

            return [
                np.array(embeddings, dtype=np.float32)
                for embeddings in self.client.embed(
                    images=base64_images,
                    model=self.model_name,
                    input_type="image",
                ).embeddings
            ]
        else:
            # Check if it's a mix or neither
            has_texts = any(is_document(item) for item in input)
            has_images = any(is_image(item) for item in input)
            if has_texts and has_images:
                raise ValueError(
                    "Input contains a mix of text documents and images, which is not supported. Provide either all texts or all images."
                )
            else:
                raise ValueError(
                    "Input must be a list of text documents (str) or a list of images (numpy arrays)."
                )

    @staticmethod
    def name() -> str:
        return "cohere"

    def default_space(self) -> Space:
        if self.model_name == "embed-multilingual-v2.0":
            return "ip"
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        if self.model_name == "embed-english-v2.0":
            return ["cosine"]
        elif self.model_name == "embed-english-light-v2.0":
            return ["cosine"]
        elif self.model_name == "embed-multilingual-v2.0":
            return ["ip"]
        else:
            return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"

        return CohereEmbeddingFunction(
            api_key_env_var=api_key_env_var, model_name=model_name
        )

    def get_config(self) -> Dict[str, Any]:
        return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "cohere")
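
# Usage sketch (illustrative, not part of the file above). Assumes
# CHROMA_COHERE_API_KEY is set; "embed-english-v3.0" is one of Cohere's published
# embedding models, shown here as an example model name.
#
#     cohere_ef = CohereEmbeddingFunction(model_name="embed-english-v3.0")
#     text_vectors = cohere_ef(["Hello, world!", "How are you?"])
#
#     # Image inputs must be numpy arrays; they are converted to PNG data URIs internally.
#     import numpy as np
#     image_vectors = cohere_ef([np.zeros((32, 32, 3), dtype=np.uint8)])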
@@ -0,0 +1,205 @@
from chromadb.api.types import (
    SparseEmbeddingFunction,
    SparseVectors,
    Documents,
)
from typing import Dict, Any, TypedDict, Optional
from typing import cast, Literal
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector

TaskType = Literal["document", "query"]


class FastembedSparseEmbeddingFunctionQueryConfig(TypedDict):
    task: TaskType


class FastembedSparseEmbeddingFunction(SparseEmbeddingFunction[Documents]):
    def __init__(
        self,
        model_name: str,
        task: Optional[TaskType] = "document",
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        cuda: Optional[bool] = None,
        device_ids: Optional[list[int]] = None,
        lazy_load: Optional[bool] = None,
        query_config: Optional[FastembedSparseEmbeddingFunctionQueryConfig] = None,
        **kwargs: Any,
    ):
        """Initialize FastembedSparseEmbeddingFunction.

        Args:
            model_name (str, optional): Identifier of the Fastembed model.
                List of commonly used models: Qdrant/bm25, prithivida/Splade_PP_en_v1, Qdrant/minicoil-v1
            task (str, optional): Task to perform, can be "document" or "query"
            cache_dir (str, optional): The path to the cache directory.
            threads (int, optional): The number of threads to use for the model.
            cuda (bool, optional): Whether to use CUDA.
            device_ids (list[int], optional): The device IDs to use for the model.
            lazy_load (bool, optional): Whether to lazy load the model.
            query_config (dict, optional): Configuration for the query, can be "task"
            **kwargs: Additional arguments to pass to the model.
        """
        try:
            from fastembed import SparseTextEmbedding
        except ImportError:
            raise ValueError(
                "The fastembed python package is not installed. Please install it with `pip install fastembed`"
            )

        self.task = task
        self.query_config = query_config
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        self.cuda = cuda
        self.device_ids = device_ids
        self.lazy_load = lazy_load
        for key, value in kwargs.items():
            if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
                raise ValueError(f"Keyword argument {key} is not a primitive type")
        self.kwargs = kwargs
        # Pass arguments by keyword: SparseTextEmbedding's signature has grown extra
        # parameters across fastembed releases, so positional passing is fragile.
        self._model = SparseTextEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            **kwargs,
        )

    def __call__(self, input: Documents) -> SparseVectors:
        """Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        try:
            from fastembed import SparseTextEmbedding
        except ImportError:
            raise ValueError(
                "The fastembed python package is not installed. Please install it with `pip install fastembed`"
            )
        model = cast(SparseTextEmbedding, self._model)
        if self.task == "document":
            embeddings = model.embed(
                list(input),
            )
        elif self.task == "query":
            embeddings = model.query_embed(
                list(input),
            )
        else:
            raise ValueError(f"Invalid task: {self.task}")

        sparse_vectors: SparseVectors = []

        for vec in embeddings:
            sparse_vectors.append(
                normalize_sparse_vector(
                    indices=vec.indices.tolist(), values=vec.values.tolist()
                )
            )

        return sparse_vectors

    def embed_query(self, input: Documents) -> SparseVectors:
        try:
            from fastembed import SparseTextEmbedding
        except ImportError:
            raise ValueError(
                "The fastembed python package is not installed. Please install it with `pip install fastembed`"
            )
        model = cast(SparseTextEmbedding, self._model)
        if self.query_config is not None:
            task = self.query_config.get("task")
            if task == "document":
                embeddings = model.embed(
                    list(input),
                )
            elif task == "query":
                embeddings = model.query_embed(
                    list(input),
                )
            else:
                raise ValueError(f"Invalid task: {task}")

            sparse_vectors: SparseVectors = []

            for vec in embeddings:
                sparse_vectors.append(
                    normalize_sparse_vector(
                        indices=vec.indices.tolist(), values=vec.values.tolist()
                    )
                )

            return sparse_vectors

        else:
            return self.__call__(input)

    @staticmethod
    def name() -> str:
        return "fastembed_sparse"

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "SparseEmbeddingFunction[Documents]":
        model_name = config.get("model_name")
        task = config.get("task")
        query_config = config.get("query_config")
        cache_dir = config.get("cache_dir")
        threads = config.get("threads")
        cuda = config.get("cuda")
        device_ids = config.get("device_ids")
        lazy_load = config.get("lazy_load")
        kwargs = config.get("kwargs", {})
        if model_name is None:
            assert False, "This code should not be reached"

        return FastembedSparseEmbeddingFunction(
            model_name=model_name,
            task=task,
            query_config=query_config,
            cache_dir=cache_dir,
            threads=threads,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            **kwargs,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "task": self.task,
            "query_config": self.query_config,
            "cache_dir": self.cache_dir,
            "threads": self.threads,
            "cuda": self.cuda,
            "device_ids": self.device_ids,
            "lazy_load": self.lazy_load,
            "kwargs": self.kwargs,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # model_name is also used as the identifier for model path if stored locally.
        # Users should be able to change the path if needed, so we should not validate that.
        # e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
        return

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "fastembed_sparse")
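
# Usage sketch (illustrative, not part of the file above): document vs. query
# embedding. Assumes the optional fastembed package is installed; "Qdrant/bm25"
# is one of the model names listed in the docstring.
#
#     sparse_ef = FastembedSparseEmbeddingFunction(
#         model_name="Qdrant/bm25",
#         query_config={"task": "query"},  # embed_query() will use query_embed()
#     )
#     doc_vectors = sparse_ef(["sparse vectors score exact term overlap"])
#     query_vectors = sparse_ef.embed_query(["term overlap"])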
@@ -0,0 +1,395 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, cast, Optional
import os
import numpy as np
import numpy.typing as npt
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings


class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "models/embedding-gecko-001",
        api_key_env_var: str = "CHROMA_GOOGLE_PALM_API_KEY",
    ):
        """
        Initialize the GooglePalmEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the Google PaLM API.
                Defaults to "CHROMA_GOOGLE_PALM_API_KEY".
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "models/embedding-gecko-001".
        """
        try:
            import google.generativeai as palm
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name

        palm.configure(api_key=self.api_key)
        self._palm = palm

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        # Google PaLM only works with text documents
        if not all(isinstance(item, str) for item in input):
            raise ValueError("Google PaLM only supports text documents, not images")

        return [
            np.array(
                self._palm.generate_embeddings(model=self.model_name, text=text)[
                    "embedding"
                ],
                dtype=np.float32,
            )
            for text in input
        ]

    @staticmethod
    def name() -> str:
        return "google_palm"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")

        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"

        return GooglePalmEmbeddingFunction(
            api_key_env_var=api_key_env_var, model_name=model_name
        )

    def get_config(self) -> Dict[str, Any]:
        return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "google_palm")


class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "models/embedding-001",
        task_type: str = "RETRIEVAL_DOCUMENT",
        api_key_env_var: str = "CHROMA_GOOGLE_GENAI_API_KEY",
    ):
        """
        Initialize the GoogleGenerativeAiEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the Google Generative AI API.
                Defaults to "CHROMA_GOOGLE_GENAI_API_KEY".
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "models/embedding-001".
            task_type (str, optional): The task type for the embeddings.
                Use "RETRIEVAL_DOCUMENT" for embedding documents and "RETRIEVAL_QUERY" for embedding queries.
                Defaults to "RETRIEVAL_DOCUMENT".
        """
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name
        self.task_type = task_type

        genai.configure(api_key=self.api_key)
        self._genai = genai

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        # Google Generative AI only works with text documents
        if not all(isinstance(item, str) for item in input):
            raise ValueError(
                "Google Generative AI only supports text documents, not images"
            )

        embeddings_list: List[npt.NDArray[np.float32]] = []
        for text in input:
            embedding_result = self._genai.embed_content(
                model=self.model_name,
                content=text,
                task_type=self.task_type,
            )
            embeddings_list.append(
                np.array(embedding_result["embedding"], dtype=np.float32)
            )

        # Convert to the expected Embeddings type (List[Vector])
        return cast(Embeddings, embeddings_list)

    @staticmethod
    def name() -> str:
        return "google_generative_ai"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        task_type = config.get("task_type")

        if api_key_env_var is None or model_name is None or task_type is None:
            assert False, "This code should not be reached"

        return GoogleGenerativeAiEmbeddingFunction(
            api_key_env_var=api_key_env_var, model_name=model_name, task_type=task_type
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "task_type": self.task_type,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )
        if "task_type" in new_config:
            raise ValueError(
                "The task type cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "google_generative_ai")


class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the vertexai Python package installed and have Google Cloud credentials configured."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "textembedding-gecko",
        project_id: str = "cloud-large-language-models",
        region: str = "us-central1",
        api_key_env_var: str = "CHROMA_GOOGLE_VERTEX_API_KEY",
    ):
        """
        Initialize the GoogleVertexEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the Google Vertex AI API.
                Defaults to "CHROMA_GOOGLE_VERTEX_API_KEY".
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "textembedding-gecko".
            project_id (str, optional): The Google Cloud project ID.
                Defaults to "cloud-large-language-models".
            region (str, optional): The Google Cloud region.
                Defaults to "us-central1".
        """
        try:
            import vertexai
            from vertexai.language_models import TextEmbeddingModel
        except ImportError:
            raise ValueError(
                "The vertexai python package is not installed. Please install it with `pip install google-cloud-aiplatform`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name
        self.project_id = project_id
        self.region = region

        vertexai.init(project=project_id, location=region)
        self._model = TextEmbeddingModel.from_pretrained(model_name)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        # Google Vertex only works with text documents
        if not all(isinstance(item, str) for item in input):
            raise ValueError("Google Vertex only supports text documents, not images")

        embeddings_list: List[npt.NDArray[np.float32]] = []
        for text in input:
            embedding_result = self._model.get_embeddings([text])
            embeddings_list.append(
                np.array(embedding_result[0].values, dtype=np.float32)
            )

        # Convert to the expected Embeddings type (List[Vector])
        return cast(Embeddings, embeddings_list)

    @staticmethod
    def name() -> str:
        return "google_vertex"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        project_id = config.get("project_id")
        region = config.get("region")

        if (
            api_key_env_var is None
            or model_name is None
            or project_id is None
            or region is None
        ):
            assert False, "This code should not be reached"

        return GoogleVertexEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            project_id=project_id,
            region=region,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "project_id": self.project_id,
            "region": self.region,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )
        if "project_id" in new_config:
            raise ValueError(
                "The project ID cannot be changed after the embedding function has been initialized."
            )
        if "region" in new_config:
            raise ValueError(
                "The region cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "google_vertex")
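
# Usage sketch (illustrative, not part of the file above): asymmetric retrieval
# with GoogleGenerativeAiEmbeddingFunction. Assumes CHROMA_GOOGLE_GENAI_API_KEY is
# set; task_type is fixed per instance, so documents and queries use two
# separately configured instances.
#
#     doc_ef = GoogleGenerativeAiEmbeddingFunction(task_type="RETRIEVAL_DOCUMENT")
#     query_ef = GoogleGenerativeAiEmbeddingFunction(task_type="RETRIEVAL_QUERY")
#     doc_vectors = doc_ef(["Chroma stores embeddings."])
#     query_vectors = query_ef(["What does Chroma store?"])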
@@ -0,0 +1,236 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings


class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the HuggingFace API.
    It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2".
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        api_key_env_var: str = "CHROMA_HUGGINGFACE_API_KEY",
    ):
        """
        Initialize the HuggingFaceEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the HuggingFace API.
                Defaults to "CHROMA_HUGGINGFACE_API_KEY".
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "sentence-transformers/all-MiniLM-L6-v2".
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name

        self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
        self._session = httpx.Client()
        self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> hugging_face = HuggingFaceEmbeddingFunction(api_key_env_var="CHROMA_HUGGINGFACE_API_KEY")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = hugging_face(texts)
        """
        # Call the HuggingFace Inference API with the batch of documents
        response = self._session.post(
            self._api_url,
            json={"inputs": input, "options": {"wait_for_model": True}},
        ).json()

        # Convert to numpy arrays
        return [np.array(embedding, dtype=np.float32) for embedding in response]

    @staticmethod
    def name() -> str:
        return "huggingface"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")

        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"

        return HuggingFaceEmbeddingFunction(
            api_key_env_var=api_key_env_var, model_name=model_name
        )

    def get_config(self) -> Dict[str, Any]:
        return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "huggingface")


class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the HuggingFace Embedding server
    (https://github.com/huggingface/text-embeddings-inference).
    The embedding model is configured in the server.
    """

    def __init__(
        self,
        url: str,
        api_key_env_var: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the HuggingFaceEmbeddingServer.

        Args:
            url (str): The URL of the HuggingFace Embedding Server.
            api_key (Optional[str]): The API key for the HuggingFace Embedding Server.
            api_key_env_var (str, optional): Environment variable name that contains your API key for the HuggingFace API.
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.url = url

        self.api_key_env_var = api_key_env_var
        if self.api_key_env_var is not None:
            self.api_key = api_key or os.getenv(self.api_key_env_var)
        else:
            self.api_key = api_key

        self._api_url = url
        self._session = httpx.Client()

        if self.api_key is not None:
            self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = hugging_face(texts)
        """
        # Call the HuggingFace Embedding Server with the batch of documents
        response = self._session.post(self._api_url, json={"inputs": input}).json()

        # Convert to numpy arrays
        return [np.array(embedding, dtype=np.float32) for embedding in response]

    @staticmethod
    def name() -> str:
        return "huggingface_server"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        url = config.get("url")
        api_key_env_var = config.get("api_key_env_var")
        if url is None:
            raise ValueError("URL must be provided for HuggingFaceEmbeddingServer")

        return HuggingFaceEmbeddingServer(url=url, api_key_env_var=api_key_env_var)

    def get_config(self) -> Dict[str, Any]:
        return {"url": self.url, "api_key_env_var": self.api_key_env_var}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "url" in new_config and new_config["url"] != self.url:
            raise ValueError(
                "The URL cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "huggingface_server")
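
# Usage sketch (illustrative, not part of the file above): the hosted Inference
# API vs. a self-hosted text-embeddings-inference server. Assumes
# CHROMA_HUGGINGFACE_API_KEY is set for the first variant and a local TEI server
# on port 8080 for the second.
#
#     hf_api_ef = HuggingFaceEmbeddingFunction()  # defaults to all-MiniLM-L6-v2
#     hf_server_ef = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed")
#     vectors = hf_server_ef(["Hello, world!", "How are you?"])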
@@ -0,0 +1,200 @@
from chromadb.api.types import (
    SparseEmbeddingFunction,
    SparseVectors,
    Documents,
)
from typing import Dict, Any, TypedDict, Optional
import numpy as np
from typing import cast, Literal
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector

TaskType = Literal["document", "query"]


class HuggingFaceSparseEmbeddingFunctionQueryConfig(TypedDict):
    task: TaskType


class HuggingFaceSparseEmbeddingFunction(SparseEmbeddingFunction[Documents]):
    # Since we do dynamic imports we have to type this as Any
    models: Dict[str, Any] = {}

    def __init__(
        self,
        model_name: str,
        device: str,
        task: Optional[TaskType] = "document",
        query_config: Optional[HuggingFaceSparseEmbeddingFunctionQueryConfig] = None,
        **kwargs: Any,
    ):
        """Initialize HuggingFaceSparseEmbeddingFunction.

        Args:
            model_name (str, optional): Identifier of the Huggingface SparseEncoder model.
                Some common models: prithivida/Splade_PP_en_v1, naver/splade-cocondenser-ensembledistil, naver/splade-v3
            device (str, optional): Device used for computation
            **kwargs: Additional arguments to pass to the Splade model.
        """
        try:
            from sentence_transformers import SparseEncoder
        except ImportError:
            raise ValueError(
                "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
            )

        self.model_name = model_name
        self.device = device
        self.task = task
        self.query_config = query_config
        for key, value in kwargs.items():
            if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
                raise ValueError(f"Keyword argument {key} is not a primitive type")
        self.kwargs = kwargs

        if model_name not in self.models:
            self.models[model_name] = SparseEncoder(
                model_name_or_path=model_name, device=device, **kwargs
            )
        self._model = self.models[model_name]

    def __call__(self, input: Documents) -> SparseVectors:
        """Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        try:
            from sentence_transformers import SparseEncoder
        except ImportError:
            raise ValueError(
                "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
            )
        model = cast(SparseEncoder, self._model)
        if self.task == "document":
            embeddings = model.encode_document(
                list(input),
            )
        elif self.task == "query":
            embeddings = model.encode_query(
                list(input),
            )
        else:
            raise ValueError(f"Invalid task: {self.task}")

        sparse_vectors: SparseVectors = []

        for vec in embeddings:
            # Convert sparse tensor to dense array if needed
            if hasattr(vec, "to_dense"):
                vec_dense = vec.to_dense().numpy()
            else:
                vec_dense = vec.numpy() if hasattr(vec, "numpy") else np.array(vec)

            nz = np.where(vec_dense != 0)[0]
            sparse_vectors.append(
                normalize_sparse_vector(
                    indices=nz.tolist(), values=vec_dense[nz].tolist()
                )
            )

        return sparse_vectors

    def embed_query(self, input: Documents) -> SparseVectors:
        try:
            from sentence_transformers import SparseEncoder
        except ImportError:
            raise ValueError(
                "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
            )
        model = cast(SparseEncoder, self._model)
        if self.query_config is not None:
            if self.query_config.get("task") == "document":
                embeddings = model.encode_document(
                    list(input),
                )
            elif self.query_config.get("task") == "query":
                embeddings = model.encode_query(
                    list(input),
                )
            else:
                raise ValueError(f"Invalid task: {self.query_config.get('task')}")

            sparse_vectors: SparseVectors = []

            for vec in embeddings:
                # Convert sparse tensor to dense array if needed
                if hasattr(vec, "to_dense"):
                    vec_dense = vec.to_dense().numpy()
                else:
                    vec_dense = vec.numpy() if hasattr(vec, "numpy") else np.array(vec)

                nz = np.where(vec_dense != 0)[0]
                sparse_vectors.append(
                    normalize_sparse_vector(
                        indices=nz.tolist(), values=vec_dense[nz].tolist()
                    )
                )

            return sparse_vectors

        else:
            return self.__call__(input)

    @staticmethod
    def name() -> str:
        return "huggingface_sparse"

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "SparseEmbeddingFunction[Documents]":
        model_name = config.get("model_name")
        device = config.get("device")
        task = config.get("task")
        query_config = config.get("query_config")
        kwargs = config.get("kwargs", {})

        if model_name is None or device is None:
            assert False, "This code should not be reached"

        return HuggingFaceSparseEmbeddingFunction(
            model_name=model_name,
            device=device,
            task=task,
            query_config=query_config,
            **kwargs,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "device": self.device,
            "task": self.task,
            "query_config": self.query_config,
            "kwargs": self.kwargs,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # model_name is also used as the identifier for model path if stored locally.
        # Users should be able to change the path if needed, so we should not validate that.
        # e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
        return

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "huggingface_sparse")
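
# Usage sketch (illustrative, not part of the file above): a SPLADE-style sparse
# encoder on CPU. Assumes a sentence-transformers release that ships SparseEncoder;
# the model name is one of those listed in the docstring.
#
#     hf_sparse_ef = HuggingFaceSparseEmbeddingFunction(
#         model_name="naver/splade-v3",
#         device="cpu",
#     )
#     sparse_vectors = hf_sparse_ef(["sparse encoders produce term-weighted vectors"])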
@@ -0,0 +1,118 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Optional
import numpy as np


class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using the Instructor embedding model.
    """

    # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda"
    # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list
    def __init__(
        self,
        model_name: str = "hkunlp/instructor-base",
        device: str = "cpu",
        instruction: Optional[str] = None,
    ):
        """
        Initialize the InstructorEmbeddingFunction.

        Args:
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "hkunlp/instructor-base".
            device (str, optional): The device to use for computation.
                Defaults to "cpu".
            instruction (str, optional): The instruction to use for the embeddings.
                Defaults to None.
        """
        try:
            from InstructorEmbedding import INSTRUCTOR
        except ImportError:
            raise ValueError(
                "The InstructorEmbedding python package is not installed. Please install it with `pip install InstructorEmbedding`"
            )

        self.model_name = model_name
        self.device = device
        self.instruction = instruction

        self._model = INSTRUCTOR(model_name_or_path=model_name, device=device)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        # Instructor only works with text documents
        if not all(isinstance(item, str) for item in input):
            raise ValueError("Instructor only supports text documents, not images")

        if self.instruction is None:
            embeddings = self._model.encode(input, convert_to_numpy=True)
        else:
            texts_with_instructions = [[self.instruction, text] for text in input]
            embeddings = self._model.encode(
                texts_with_instructions, convert_to_numpy=True
            )

        # Convert to numpy arrays
        return [np.array(embedding, dtype=np.float32) for embedding in embeddings]

    @staticmethod
    def name() -> str:
        return "instructor"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        model_name = config.get("model_name")
        device = config.get("device")
        instruction = config.get("instruction")

        if model_name is None or device is None:
            assert False, "This code should not be reached"

        return InstructorEmbeddingFunction(
            model_name=model_name, device=device, instruction=instruction
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "device": self.device,
            "instruction": self.instruction,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # model_name is also used as the identifier for model path if stored locally.
        # Users should be able to change the path if needed, so we should not validate that.
        # e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
        return

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "instructor")
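
# Usage sketch (illustrative, not part of the file above): instruction-conditioned
# embeddings. Assumes the optional InstructorEmbedding package is installed; the
# instruction string is an arbitrary example, not a fixed API value.
#
#     instructor_ef = InstructorEmbeddingFunction(
#         model_name="hkunlp/instructor-base",
#         instruction="Represent the document for retrieval:",
#     )
#     vectors = instructor_ef(["Chroma is a vector database."])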
@@ -0,0 +1,277 @@
from chromadb.api.types import (
    Embeddings,
    EmbeddingFunction,
    Space,
    Embeddable,
    is_image,
    is_document,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Union, Optional, TypedDict
import os
import numpy as np
import warnings
import importlib
import base64
import io


class JinaQueryConfig(TypedDict):
    task: str


class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """
    This class is used to get embeddings for a list of texts or images using the Jina AI API.
    It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "jina-embeddings-v2-base-en",
        api_key_env_var: str = "CHROMA_JINA_API_KEY",
        task: Optional[str] = None,
        late_chunking: Optional[bool] = None,
        truncate: Optional[bool] = None,
        dimensions: Optional[int] = None,
        embedding_type: Optional[str] = None,
        normalized: Optional[bool] = None,
        query_config: Optional[JinaQueryConfig] = None,
    ):
        """
        Initialize the JinaEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the Jina AI API.
                Defaults to "CHROMA_JINA_API_KEY".
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "jina-embeddings-v2-base-en".
            task (str, optional): The task to use for the Jina AI API.
                Defaults to None.
            late_chunking (bool, optional): Whether to use late chunking for the Jina AI API.
                Defaults to None.
            truncate (bool, optional): Whether to truncate inputs that exceed the model's maximum length.
                Defaults to None.
            dimensions (int, optional): The number of dimensions to request for the embeddings.
                Defaults to None.
            embedding_type (str, optional): The type of embedding to use for the Jina AI API.
                Defaults to None.
            normalized (bool, optional): Whether to normalize the returned embeddings.
                Defaults to None.
            query_config (JinaQueryConfig, optional): Payload overrides applied when embedding
                queries via embed_query. Defaults to None.

        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )
        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name

        # Optional request parameters; None values are omitted from the payload
        self.task = task
        self.late_chunking = late_chunking
        self.truncate = truncate
        self.dimensions = dimensions
        self.embedding_type = embedding_type
        self.normalized = normalized
        self.query_config = query_config

        self._api_url = "https://api.jina.ai/v1/embeddings"
        self._session = httpx.Client()
        self._session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
        )

    def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
        payload: Dict[str, Any] = {
            "input": [],
            "model": self.model_name,
        }
        if all(is_document(item) for item in input):
            payload["input"] = input
        else:
            for item in input:
                if is_document(item):
                    payload["input"].append({"text": item})
                elif is_image(item):
                    try:
                        pil_image = self._PILImage.fromarray(item)

                        buffer = io.BytesIO()
                        pil_image.save(buffer, format="PNG")
                        img_bytes = buffer.getvalue()

                        # Encode bytes to base64 string
                        base64_string = base64.b64encode(img_bytes).decode("utf-8")

                    except Exception as e:
                        raise ValueError(
                            f"Failed to convert image numpy array to base64 data URI: {e}"
                        ) from e
                    payload["input"].append({"image": base64_string})

        if self.task is not None:
            payload["task"] = self.task
        if self.late_chunking is not None:
            payload["late_chunking"] = self.late_chunking
        if self.truncate is not None:
            payload["truncate"] = self.truncate
        if self.dimensions is not None:
            payload["dimensions"] = self.dimensions
        if self.embedding_type is not None:
            payload["embedding_type"] = self.embedding_type
        if self.normalized is not None:
            payload["normalized"] = self.normalized

        # overwrite parameters when the query payload is used
        if is_query and self.query_config is not None:
            for key, value in self.query_config.items():
                payload[key] = value

        return payload

    def _convert_resp(self, resp: Any, is_query: bool = False) -> Embeddings:
        """
        Convert the response from the Jina AI API to a list of numpy arrays.

        Args:
            resp (Any): The response from the Jina AI API.

        Returns:
            Embeddings: A list of numpy arrays representing the embeddings.
        """
        if "data" not in resp:
            raise RuntimeError(resp.get("detail", "Unknown error"))

        embeddings_data: List[Dict[str, Union[int, List[float]]]] = resp["data"]

        # Sort resulting embeddings by index
        sorted_embeddings = sorted(embeddings_data, key=lambda e: e["index"])

        # Return embeddings as numpy arrays
        return [
            np.array(result["embedding"], dtype=np.float32)
            for result in sorted_embeddings
        ]

    def __call__(self, input: Embeddable) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Embeddable): A list of texts and/or images to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
            >>> input = ["Hello, world!", "How are you?"]
            >>> embeddings = jina_ai_fn(input)
        """

        payload = self._build_payload(input, is_query=False)

        # Call Jina AI Embedding API
        resp = self._session.post(self._api_url, json=payload, timeout=60).json()

        return self._convert_resp(resp)

    def embed_query(self, input: Embeddable) -> Embeddings:
        payload = self._build_payload(input, is_query=True)

        # Call Jina AI Embedding API
        resp = self._session.post(self._api_url, json=payload, timeout=60).json()

        return self._convert_resp(resp, is_query=True)

    @staticmethod
    def name() -> str:
        return "jina"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        task = config.get("task")
        late_chunking = config.get("late_chunking")
        truncate = config.get("truncate")
        dimensions = config.get("dimensions")
        embedding_type = config.get("embedding_type")
        normalized = config.get("normalized")
        query_config = config.get("query_config")

        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"  # this is for type checking

        return JinaEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            task=task,
            late_chunking=late_chunking,
            truncate=truncate,
            dimensions=dimensions,
            embedding_type=embedding_type,
            normalized=normalized,
            query_config=query_config,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "task": self.task,
            "late_chunking": self.late_chunking,
            "truncate": self.truncate,
            "dimensions": self.dimensions,
            "embedding_type": self.embedding_type,
            "normalized": self.normalized,
            "query_config": self.query_config,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "jina")
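Editor's note: a hedged sketch of the query/document asymmetry that `query_config` enables above. The task names are examples drawn from Jina's public model documentation, not pinned down by this file, and a valid `CHROMA_JINA_API_KEY` must be set.

```python
# Sketch only: "retrieval.passage" / "retrieval.query" are illustrative task
# names (check Jina's API docs for the supported set for your model).
ef = JinaEmbeddingFunction(
    model_name="jina-embeddings-v3",
    task="retrieval.passage",                     # used by __call__ (documents)
    query_config={"task": "retrieval.query"},     # overrides "task" in embed_query
)
doc_vecs = ef(["Chroma is a vector database."])   # document-side embeddings
query_vecs = ef.embed_query(["what is chroma?"])  # query-side embeddings
```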
@@ -0,0 +1,92 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any
import os
import numpy as np


class MistralEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        model: str,
        api_key_env_var: str = "MISTRAL_API_KEY",
    ):
        """
        Initialize the MistralEmbeddingFunction.

        Args:
            model (str): The name of the model to use for text embeddings.
            api_key_env_var (str): The environment variable name for the Mistral API key.
        """
        try:
            from mistralai import Mistral
        except ImportError:
            raise ValueError(
                "The mistralai python package is not installed. Please install it with `pip install mistralai`"
            )
        self.model = model
        self.api_key_env_var = api_key_env_var
        self.api_key = os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
        self.client = Mistral(api_key=self.api_key)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.
        """
        if not all(isinstance(item, str) for item in input):
            raise ValueError("Mistral only supports text documents, not images")
        output = self.client.embeddings.create(
            model=self.model,
            inputs=input,
        )

        # Extract embeddings from the response
        return [np.array(data.embedding) for data in output.data]

    @staticmethod
    def name() -> str:
        return "mistral"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        model = config.get("model")
        api_key_env_var = config.get("api_key_env_var")

        if model is None or api_key_env_var is None:
            assert False, "This code should not be reached"  # this is for type checking
        return MistralEmbeddingFunction(model=model, api_key_env_var=api_key_env_var)

    def get_config(self) -> Dict[str, Any]:
        return {
            "model": self.model,
            "api_key_env_var": self.api_key_env_var,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model" in new_config:
            raise ValueError(
                "The model cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate
        """
        validate_config_schema(config, "mistral")
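Editor's note: a minimal sketch of the class above, assuming `MISTRAL_API_KEY` is set. "mistral-embed" is Mistral's published embedding model name at the time of writing; verify against their current docs.

```python
# Requires MISTRAL_API_KEY in the environment; model name is illustrative.
ef = MistralEmbeddingFunction(model="mistral-embed")
embeddings = ef(["Hello, world!", "How are you?"])

# Non-string items are rejected up front by the isinstance check in __call__:
# ef([some_numpy_image])  -> ValueError("Mistral only supports text documents, not images")
```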
@@ -0,0 +1,147 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings


class MorphEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "morph-embedding-v2",
        api_base: str = "https://api.morphllm.com/v1",
        encoding_format: str = "float",
        api_key_env_var: str = "MORPH_API_KEY",
    ):
        """
        Initialize the MorphEmbeddingFunction.

        Args:
            api_key (str, optional): The API key for the Morph API. If not provided,
                it will be read from the environment variable specified by api_key_env_var.
            model_name (str, optional): The name of the model to use for embeddings.
                Defaults to "morph-embedding-v2".
            api_base (str, optional): The base URL for the Morph API.
                Defaults to "https://api.morphllm.com/v1".
            encoding_format (str, optional): The format for embeddings (float or base64).
                Defaults to "float".
            api_key_env_var (str, optional): Environment variable name that contains your API key.
                Defaults to "MORPH_API_KEY".
        """
        try:
            import openai
        except ImportError:
            raise ValueError(
                "The openai python package is not installed. Please install it with `pip install openai`. "
                "Note: Morph uses the OpenAI client library for API communication."
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name
        self.api_base = api_base
        self.encoding_format = encoding_format

        # Initialize the OpenAI client with Morph's base URL
        self.client = openai.OpenAI(
            api_key=self.api_key,
            base_url=self.api_base,
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        # Handle empty input
        if not input:
            return []

        # Prepare embedding parameters
        embedding_params: Dict[str, Any] = {
            "model": self.model_name,
            "input": input,
            "encoding_format": self.encoding_format,
        }

        # Get embeddings from Morph API
        response = self.client.embeddings.create(**embedding_params)

        # Extract embeddings from response
        return [np.array(data.embedding, dtype=np.float32) for data in response.data]

    @staticmethod
    def name() -> str:
        return "morph"

    def default_space(self) -> Space:
        # Morph embeddings work best with cosine similarity
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        # Extract parameters from config
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        api_base = config.get("api_base")
        encoding_format = config.get("encoding_format")

        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"

        # Create and return the embedding function
        return MorphEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            api_base=api_base if api_base is not None else "https://api.morphllm.com/v1",
            encoding_format=encoding_format if encoding_format is not None else "float",
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "api_base": self.api_base,
            "encoding_format": self.encoding_format,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "morph")
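Editor's note: a sketch of the persistence round-trip that `get_config` / `build_from_config` above exist for (only the env-var name is persisted, never the key itself). Assumes `MORPH_API_KEY` is set.

```python
# Config round-trip sketch (requires MORPH_API_KEY in the environment).
ef = MorphEmbeddingFunction(model_name="morph-embedding-v2")
config = ef.get_config()
# -> {"api_key_env_var": "MORPH_API_KEY", "model_name": "morph-embedding-v2",
#     "api_base": "https://api.morphllm.com/v1", "encoding_format": "float"}

MorphEmbeddingFunction.validate_config(config)
restored = MorphEmbeddingFunction.build_from_config(config)  # equivalent fresh instance
```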
@@ -0,0 +1,117 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any
import numpy as np
from urllib.parse import urlparse

DEFAULT_MODEL_NAME = "chroma/all-minilm-l6-v2-f32"


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using the Ollama Embedding API
    (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
    """

    def __init__(
        self,
        url: str = "http://localhost:11434",
        model_name: str = DEFAULT_MODEL_NAME,
        timeout: int = 60,
    ) -> None:
        """
        Initialize the Ollama Embedding Function.

        Args:
            url (str): The base URL of the Ollama server (default: "http://localhost:11434").
            model_name (str): The name of the model to use for text embeddings.
                Defaults to "chroma/all-minilm-l6-v2-f32". For available models, see https://ollama.com/library.
            timeout (int): The timeout for the API call in seconds. Defaults to 60.
        """
        try:
            from ollama import Client
        except ImportError:
            raise ValueError(
                "The ollama python package is not installed. Please install it with `pip install ollama`"
            )

        self.url = url
        self.model_name = model_name
        self.timeout = timeout

        # Adding this for backwards compatibility with the old version of the EF
        self._base_url = url
        if self._base_url.endswith("/api/embeddings"):
            parsed_url = urlparse(url)
            self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        self._client = Client(host=self._base_url, timeout=timeout)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> ollama_ef = OllamaEmbeddingFunction()
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = ollama_ef(texts)
        """
        # Call Ollama client
        response = self._client.embed(model=self.model_name, input=input)

        # Convert to numpy arrays
        return [
            np.array(embedding, dtype=np.float32)
            for embedding in response["embeddings"]
        ]

    @staticmethod
    def name() -> str:
        return "ollama"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        url = config.get("url")
        model_name = config.get("model_name")
        timeout = config.get("timeout")

        if url is None or model_name is None or timeout is None:
            assert False, "This code should not be reached"

        return OllamaEmbeddingFunction(url=url, model_name=model_name, timeout=timeout)

    def get_config(self) -> Dict[str, Any]:
        return {"url": self.url, "model_name": self.model_name, "timeout": self.timeout}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "ollama")
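Editor's note: a sketch of the backwards-compatibility URL handling above. Requires a running Ollama server with the model pulled; both URL forms resolve to the same client host.

```python
# Legacy callers passed the full /api/embeddings endpoint; the constructor
# strips it back to scheme://host:port before building the Client.
ef_new = OllamaEmbeddingFunction(url="http://localhost:11434")
ef_old = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings")  # legacy form

assert ef_old._base_url == "http://localhost:11434"
embeddings = ef_new(["Hello, world!"])
```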
@@ -0,0 +1,364 @@
import hashlib
import importlib
import logging
import os
import tarfile
import sys
from functools import cached_property
from pathlib import Path
from typing import List, Dict, Any, Optional, cast

import numpy as np
import numpy.typing as npt
import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random

from chromadb.api.types import Documents, Embeddings, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema

logger = logging.getLogger(__name__)


def _verify_sha256(fname: str, expected_sha256: str) -> bool:
    sha256_hash = hashlib.sha256()
    with open(fname, "rb") as f:
        # Read and update hash in chunks to avoid using too much memory
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)

    return sha256_hash.hexdigest() == expected_sha256


# In order to remove dependencies on sentence-transformers, which in turn depends on
# pytorch and sentence-piece, we have created a default ONNX embedding function that
# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers.
# Visit https://github.com/chroma-core/onnx-embedding for the source code to generate
# and verify the ONNX model.
class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]):
    MODEL_NAME = "all-MiniLM-L6-v2"
    DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME
    EXTRACTED_FOLDER_NAME = "onnx"
    ARCHIVE_FILENAME = "onnx.tar.gz"
    MODEL_DOWNLOAD_URL = (
        "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
    )
    _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3"

    def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:
        """
        Initialize the ONNXMiniLM_L6_V2 embedding function.

        Args:
            preferred_providers (List[str], optional): The preferred ONNX runtime providers.
                Defaults to None.
        """
        # validate that the preferred providers are strings
        if preferred_providers and not all(
            [isinstance(i, str) for i in preferred_providers]
        ):
            raise ValueError("Preferred providers must be a list of strings")
        # check for duplicate providers
        if preferred_providers and len(preferred_providers) != len(
            set(preferred_providers)
        ):
            raise ValueError("Preferred providers must be unique")

        self._preferred_providers = preferred_providers

        try:
            # Equivalent to import onnxruntime
            self.ort = importlib.import_module("onnxruntime")
        except ImportError:
            raise ValueError(
                "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`"
            )
        try:
            # Equivalent to from tokenizers import Tokenizer
            self.Tokenizer = importlib.import_module("tokenizers").Tokenizer
        except ImportError:
            raise ValueError(
                "The tokenizers python package is not installed. Please install it with `pip install tokenizers`"
            )
        try:
            # Equivalent to from tqdm import tqdm
            self.tqdm = importlib.import_module("tqdm").tqdm
        except ImportError:
            raise ValueError(
                "The tqdm python package is not installed. Please install it with `pip install tqdm`"
            )

    # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
    # Download with tqdm to preserve the sentence-transformers experience
    @retry(  # type: ignore
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)),
    )
    def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
        """
        Download the onnx model from the URL and save it to the file path.

        Args:
            url: The URL to download the model from.
            fname: The path to save the model to.
            chunk_size: The chunk size to use when downloading.
        """
        with httpx.stream("GET", url) as resp:
            total = int(resp.headers.get("content-length", 0))
            with open(fname, "wb") as file, self.tqdm(
                desc=str(fname),
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_bytes(chunk_size=chunk_size):
                    size = file.write(data)
                    bar.update(size)

        if not _verify_sha256(fname, self._MODEL_SHA256):
            os.remove(fname)
            raise ValueError(
                f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file."
            )

    # Use PyTorch's default epsilon for division by zero
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    def _normalize(self, v: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """
        Normalize a vector.

        Args:
            v: The vector to normalize.

        Returns:
            The normalized vector.
        """
        norm = np.linalg.norm(v, axis=1)
        # Handle division by zero
        norm[norm == 0] = 1e-12
        return cast(npt.NDArray[np.float32], v / norm[:, np.newaxis])

    def _forward(
        self, documents: List[str], batch_size: int = 32
    ) -> npt.NDArray[np.float32]:
        """
        Generate embeddings for a list of documents.

        Args:
            documents: The documents to generate embeddings for.
            batch_size: The batch size to use when generating embeddings.

        Returns:
            The embeddings for the documents.
        """
        all_embeddings = []
        for i in range(0, len(documents), batch_size):
            batch = documents[i : i + batch_size]

            # Encode each document separately
            encoded = [self.tokenizer.encode(d) for d in batch]

            # Check if any document exceeds the max tokens
            for doc_tokens in encoded:
                if len(doc_tokens.ids) > self.max_tokens():
                    raise ValueError(
                        f"Document length {len(doc_tokens.ids)} is greater than the max tokens {self.max_tokens()}"
                    )

            input_ids = np.array([e.ids for e in encoded])
            attention_mask = np.array([e.attention_mask for e in encoded])

            onnx_input = {
                "input_ids": np.array(input_ids, dtype=np.int64),
                "attention_mask": np.array(attention_mask, dtype=np.int64),
                "token_type_ids": np.array(
                    [np.zeros(len(e), dtype=np.int64) for e in input_ids],
                    dtype=np.int64,
                ),
            }

            model_output = self.model.run(None, onnx_input)
            last_hidden_state = model_output[0]

            # Perform mean pooling with attention weighting
            input_mask_expanded = np.broadcast_to(
                np.expand_dims(attention_mask, -1), last_hidden_state.shape
            )
            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
                input_mask_expanded.sum(1), a_min=1e-9, a_max=None
            )

            embeddings = self._normalize(embeddings).astype(np.float32)
            all_embeddings.append(embeddings)

        return np.concatenate(all_embeddings)

    @cached_property
    def tokenizer(self) -> Any:
        """
        Get the tokenizer for the model.

        Returns:
            The tokenizer for the model.
        """
        tokenizer = self.Tokenizer.from_file(
            os.path.join(
                self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
            )
        )
        # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
        # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
        tokenizer.enable_truncation(max_length=256)
        tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
        return tokenizer

    @cached_property
    def model(self) -> Any:
        """
        Get the model.

        Returns:
            The model.
        """
        if self._preferred_providers is None or len(self._preferred_providers) == 0:
            if len(self.ort.get_available_providers()) > 0:
                logger.debug(
                    f"WARNING: No ONNX providers provided, defaulting to available providers: "
                    f"{self.ort.get_available_providers()}"
                )
            self._preferred_providers = self.ort.get_available_providers()
        elif not set(self._preferred_providers).issubset(
            set(self.ort.get_available_providers())
        ):
            raise ValueError(
                f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}"
            )

        # Suppress onnxruntime warnings
        so = self.ort.SessionOptions()
        so.log_severity_level = 3
        so.graph_optimization_level = self.ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        if (
            self._preferred_providers
            and "CoreMLExecutionProvider" in self._preferred_providers
        ):
            # remove CoreMLExecutionProvider from the list, it is not as well optimized as CPU.
            self._preferred_providers.remove("CoreMLExecutionProvider")

        return self.ort.InferenceSession(
            os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
            # Since 1.9, the ONNX runtime requires providers to be specified when there are multiple available
            providers=self._preferred_providers,
            sess_options=so,
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """

        # Only download the model when it is actually used
        self._download_model_if_not_exists()

        # Generate embeddings
        embeddings = self._forward(input)

        # Convert to list of numpy arrays for the expected Embeddings type
        return cast(
            Embeddings,
            [np.array(embedding, dtype=np.float32) for embedding in embeddings],
        )

    def _download_model_if_not_exists(self) -> None:
        """
        Download the model if it doesn't exist.
        """
        onnx_files = [
            "config.json",
            "model.onnx",
            "special_tokens_map.json",
            "tokenizer_config.json",
            "tokenizer.json",
            "vocab.txt",
        ]
        extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
        onnx_files_exist = True
        for f in onnx_files:
            if not os.path.exists(os.path.join(extracted_folder, f)):
                onnx_files_exist = False
                break
        # Model is not downloaded yet
        if not onnx_files_exist:
            os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
            if not os.path.exists(
                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
            ) or not _verify_sha256(
                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                self._MODEL_SHA256,
            ):
                self._download(
                    url=self.MODEL_DOWNLOAD_URL,
                    fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                )
            with tarfile.open(
                name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                mode="r:gz",
            ) as tar:
                if sys.version_info >= (3, 12):
                    tar.extractall(path=self.DOWNLOAD_PATH, filter="data")
                else:
                    # filter argument was added in Python 3.12
                    # https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall
                    # In versions prior to 3.12, this provides the same behavior as filter="data"
                    tar.extractall(path=self.DOWNLOAD_PATH)

    @staticmethod
    def name() -> str:
        return "onnx_mini_lm_l6_v2"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    def max_tokens(self) -> int:
        # Default token limit for ONNX Mini LM L6 V2 model
        return 256

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        preferred_providers = config.get("preferred_providers")

        return ONNXMiniLM_L6_V2(preferred_providers=preferred_providers)

    def get_config(self) -> Dict[str, Any]:
        return {"preferred_providers": self._preferred_providers}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # Preferred providers can be changed, so no validation needed
        pass

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "onnx_mini_lm_l6_v2")
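Editor's note: a minimal sketch of the lazy-download path above. The model archive is fetched, SHA256-verified, and extracted on the first call, after which embeddings are computed locally. `CPUExecutionProvider` is a standard ONNX Runtime provider name.

```python
# Downloads and extracts the model archive on first use; subsequent calls
# reuse the cached files under ~/.cache/chroma/onnx_models/.
ef = ONNXMiniLM_L6_V2(preferred_providers=["CPUExecutionProvider"])
embeddings = ef(["Hello, world!"])
print(embeddings[0].shape)  # (384,) -- all-MiniLM-L6-v2 produces 384-dim vectors
```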
@@ -0,0 +1,187 @@
from chromadb.api.types import EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import (
    Document,
    Documents,
    Embedding,
    Embeddings,
    Image,
    Images,
    is_document,
    is_image,
    Embeddable,
)
from typing import List, Dict, Any, Union, Optional, cast
import numpy as np
import importlib


class OpenCLIPEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """
    This class is used to generate embeddings for a list of texts or images using the Open CLIP model.
    """

    def __init__(
        self,
        model_name: str = "ViT-B-32",
        checkpoint: str = "laion2b_s34b_b79k",
        device: Optional[str] = "cpu",
    ) -> None:
        """
        Initialize the OpenCLIPEmbeddingFunction.

        Args:
            model_name (str, optional): The name of the model to use for embeddings.
                Defaults to "ViT-B-32".
            checkpoint (str, optional): The checkpoint to use for the model.
                Defaults to "laion2b_s34b_b79k".
            device (str, optional): The device to use for computation.
                Defaults to "cpu".
        """
        try:
            import open_clip
        except ImportError:
            raise ValueError(
                "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip"
            )

        try:
            self._torch = importlib.import_module("torch")
        except ImportError:
            raise ValueError(
                "The torch python package is not installed. Please install it with `pip install torch`"
            )

        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

        self.model_name = model_name
        self.checkpoint = checkpoint
        self.device = device

        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name=model_name, pretrained=checkpoint
        )
        self._model = model
        self._model.to(device)
        self._preprocess = preprocess
        self._tokenizer = open_clip.get_tokenizer(model_name=model_name)

    def _encode_image(self, image: Image) -> Embedding:
        """
        Encode an image using the Open CLIP model.

        Args:
            image: The image to encode.

        Returns:
            The embedding for the image.
        """
        pil_image = self._PILImage.fromarray(image)
        with self._torch.no_grad():
            image_features = self._model.encode_image(
                self._preprocess(pil_image).unsqueeze(0).to(self.device)
            )
            image_features /= image_features.norm(dim=-1, keepdim=True)
            return cast(Embedding, image_features.squeeze().cpu().numpy())

    def _encode_text(self, text: Document) -> Embedding:
        """
        Encode a text using the Open CLIP model.

        Args:
            text: The text to encode.

        Returns:
            The embedding for the text.
        """
        with self._torch.no_grad():
            text_features = self._model.encode_text(
                self._tokenizer(text).to(self.device)
            )
            text_features /= text_features.norm(dim=-1, keepdim=True)
            return cast(Embedding, text_features.squeeze().cpu().numpy())

    def __call__(self, input: Embeddable) -> Embeddings:
        """
        Generate embeddings for the given documents or images.

        Args:
            input: Documents or images to generate embeddings for.

        Returns:
            Embeddings for the documents or images.
        """
        embeddings: Embeddings = []
        for item in input:
            if is_image(item):
                embeddings.append(
                    np.array(self._encode_image(cast(Image, item)), dtype=np.float32)
                )
            elif is_document(item):
                embeddings.append(
                    np.array(self._encode_text(cast(Document, item)), dtype=np.float32)
                )

        return embeddings

    @staticmethod
    def name() -> str:
        return "open_clip"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "EmbeddingFunction[Union[Documents, Images]]":
        model_name = config.get("model_name")
        checkpoint = config.get("checkpoint")
        device = config.get("device")

        if model_name is None or checkpoint is None or device is None:
            assert False, "This code should not be reached"

        return OpenCLIPEmbeddingFunction(
            model_name=model_name, checkpoint=checkpoint, device=device
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "checkpoint": self.checkpoint,
            "device": self.device,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )
        if "checkpoint" in new_config:
            raise ValueError(
                "The checkpoint cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "open_clip")
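Editor's note: a sketch of the mixed text/image dispatch in `__call__` above. Requires `open-clip-torch`, `torch`, and `pillow`; the first run downloads the `laion2b_s34b_b79k` checkpoint.

```python
import numpy as np

ef = OpenCLIPEmbeddingFunction()
image = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)  # dummy RGB array

# Texts go through _encode_text, ndarrays through _encode_image; both come
# back unit-normalized in the same CLIP embedding space.
embeddings = ef(["a photo of a cat", image])
```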
@@ -0,0 +1,204 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings


class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-ada-002",
        organization_id: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        deployment_id: Optional[str] = None,
        default_headers: Optional[Dict[str, str]] = None,
        dimensions: Optional[int] = None,
        api_key_env_var: str = "CHROMA_OPENAI_API_KEY",
    ):
        """
        Initialize the OpenAIEmbeddingFunction.
        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the OpenAI API.
                Defaults to "CHROMA_OPENAI_API_KEY".
            model_name (str, optional): The name of the model to use for text
                embeddings. Defaults to "text-embedding-ada-002".
            organization_id (str, optional): The OpenAI organization ID, if applicable.
            api_base (str, optional): The base path for the API. If not provided,
                it will use the base path for the OpenAI API. This can be used to
                point to a different deployment, such as an Azure deployment.
            api_type (str, optional): The type of the API deployment. This can be
                used to specify a different deployment, such as 'azure'. If not
                provided, it will use the default OpenAI deployment.
            api_version (str, optional): The api version for the API. If not provided,
                it will use the api version for the OpenAI API. This can be used to
                point to a different deployment, such as an Azure deployment.
            deployment_id (str, optional): Deployment ID for Azure OpenAI.
            default_headers (Dict[str, str], optional): A mapping of default headers to be sent with each API request.
            dimensions (int, optional): The number of dimensions for the embeddings.
                Only supported for `text-embedding-3` or later models from OpenAI.
                https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
        """
        try:
            import openai
        except ImportError:
            raise ValueError(
                "The openai python package is not installed. Please install it with `pip install openai`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name
        self.organization_id = organization_id
        self.api_base = api_base
        self.api_type = api_type
        self.api_version = api_version
        self.deployment_id = deployment_id
        self.default_headers = default_headers
        self.dimensions = dimensions

        # Initialize the OpenAI client
        client_params: Dict[str, Any] = {"api_key": self.api_key}

        if self.organization_id is not None:
            client_params["organization"] = self.organization_id
        if self.api_base is not None:
            client_params["base_url"] = self.api_base
        if self.default_headers is not None:
            client_params["default_headers"] = self.default_headers

        self.client = openai.OpenAI(**client_params)

        # For Azure OpenAI
        if self.api_type == "azure":
            if self.api_version is None:
                raise ValueError("api_version must be specified for Azure OpenAI")
            if self.deployment_id is None:
                raise ValueError("deployment_id must be specified for Azure OpenAI")
            if self.api_base is None:
                raise ValueError("api_base must be specified for Azure OpenAI")

            from openai import AzureOpenAI

            self.client = AzureOpenAI(
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.api_base,
                azure_deployment=self.deployment_id,
                default_headers=self.default_headers,
            )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.
        Args:
            input: Documents to generate embeddings for.
        Returns:
            Embeddings for the documents.
        """
        # Handle empty input
        if not input:
            return []

        # Prepare embedding parameters
        embedding_params: Dict[str, Any] = {
            "model": self.model_name,
            "input": input,
        }

        if self.dimensions is not None and "text-embedding-3" in self.model_name:
            embedding_params["dimensions"] = self.dimensions

        # Get embeddings
        response = self.client.embeddings.create(**embedding_params)

        # Extract embeddings from response
        return [np.array(data.embedding, dtype=np.float32) for data in response.data]

    @staticmethod
    def name() -> str:
        return "openai"

    def default_space(self) -> Space:
        # OpenAI embeddings work best with cosine similarity
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        # Extract parameters from config
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        organization_id = config.get("organization_id")
        api_base = config.get("api_base")
        api_type = config.get("api_type")
        api_version = config.get("api_version")
        deployment_id = config.get("deployment_id")
        default_headers = config.get("default_headers")
        dimensions = config.get("dimensions")

        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"

        # Create and return the embedding function
        return OpenAIEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            organization_id=organization_id,
            api_base=api_base,
            api_type=api_type,
            api_version=api_version,
            deployment_id=deployment_id,
            default_headers=default_headers,
            dimensions=dimensions,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "organization_id": self.organization_id,
            "api_base": self.api_base,
            "api_type": self.api_type,
            "api_version": self.api_version,
            "deployment_id": self.deployment_id,
            "default_headers": self.default_headers,
            "dimensions": self.dimensions,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "openai")
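Editor's note: a sketch of the `dimensions` guard in `__call__` above. `text-embedding-3-small` is a real OpenAI model; requires `CHROMA_OPENAI_API_KEY` to be set.

```python
# "dimensions" is only forwarded when the model name contains "text-embedding-3",
# matching the guard in __call__; older models silently ignore the setting here.
ef = OpenAIEmbeddingFunction(
    model_name="text-embedding-3-small",
    dimensions=256,  # request shortened embeddings
)
embeddings = ef(["Hello, world!"])
print(embeddings[0].shape)  # (256,)
```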
@@ -0,0 +1,159 @@
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import (
    Documents,
    Embeddings,
    Images,
    is_document,
    is_image,
    Embeddable,
    EmbeddingFunction,
    Space,
)
from typing import List, Dict, Any, Union, cast, Optional
import os
import importlib
import base64
from io import BytesIO
import numpy as np
import warnings


class RoboflowEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """
    This class is used to generate embeddings for a list of texts or images using the Roboflow API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: str = "https://infer.roboflow.com",
        api_key_env_var: str = "CHROMA_ROBOFLOW_API_KEY",
    ) -> None:
        """
        Create a RoboflowEmbeddingFunction.

        Args:
            api_key_env_var (str, optional): Environment variable name that contains your API key for the Roboflow API.
                Defaults to "CHROMA_ROBOFLOW_API_KEY".
            api_url (str, optional): The URL of the Roboflow API.
                Defaults to "https://infer.roboflow.com".
        """

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.api_url = api_url

        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

        self._httpx = importlib.import_module("httpx")

    def __call__(self, input: Embeddable) -> Embeddings:
        """
        Generate embeddings for the given documents or images.

        Args:
            input: Documents or images to generate embeddings for.

        Returns:
            Embeddings for the documents or images.
        """
        embeddings = []

        for item in input:
            if is_image(item):
                image = self._PILImage.fromarray(item)

                buffer = BytesIO()
                image.save(buffer, format="JPEG")
                base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

                infer_clip_payload_image = {
                    "image": {
                        "type": "base64",
                        "value": base64_image,
                    },
                }

                res = self._httpx.post(
                    f"{self.api_url}/clip/embed_image?api_key={self.api_key}",
                    json=infer_clip_payload_image,
                )

                result = res.json()["embeddings"]
                embeddings.append(np.array(result[0], dtype=np.float32))

            elif is_document(item):
                infer_clip_payload_text = {
                    "text": item,
                }

                res = self._httpx.post(
                    f"{self.api_url}/clip/embed_text?api_key={self.api_key}",
                    json=infer_clip_payload_text,
                )

                result = res.json()["embeddings"]
                embeddings.append(np.array(result[0], dtype=np.float32))

        # Cast to the expected Embeddings type
        return cast(Embeddings, embeddings)

    @staticmethod
    def name() -> str:
        return "roboflow"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "EmbeddingFunction[Union[Documents, Images]]":
        api_key_env_var = config.get("api_key_env_var")
        api_url = config.get("api_url")

        if api_key_env_var is None or api_url is None:
            assert False, "This code should not be reached"

        return RoboflowEmbeddingFunction(
            api_key_env_var=api_key_env_var, api_url=api_url
        )

    def get_config(self) -> Dict[str, Any]:
        return {"api_key_env_var": self.api_key_env_var, "api_url": self.api_url}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # API URL can be changed, so no validation needed
        pass

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "roboflow")
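Editor's note: a minimal sketch of the per-item request pattern above, assuming `CHROMA_ROBOFLOW_API_KEY` is set. Note that `__call__` issues one HTTP request per document or image.

```python
import numpy as np

ef = RoboflowEmbeddingFunction()
text_embeddings = ef(["a construction worker wearing a hard hat"])  # /clip/embed_text

image = np.random.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)  # dummy image
image_embeddings = ef([image])                                       # /clip/embed_image
```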
@@ -0,0 +1,41 @@
# Embedding Function Schemas

This directory contains JSON schemas for all embedding functions in Chroma. The purpose of these schemas is to support cross-language compatibility and to validate that changes in one client library do not accidentally diverge from the others.

## Schema Structure

Each schema follows the JSON Schema Draft-07 specification and includes:

- `version`: The version of the schema
- `title`: The title of the schema
- `description`: A description of the schema
- `properties`: The properties that can be configured for the embedding function
- `required`: The properties that are required for the embedding function
- `additionalProperties`: Whether additional properties are allowed (always set to `false` to ensure strict validation)

## Usage

The schemas can be used to validate the configuration of an embedding function using the `validate_config_schema` function:

```python
from chromadb.utils.embedding_functions.schemas import validate_config_schema

# Validate a configuration
config = {
    "api_key_env_var": "CHROMA_OPENAI_API_KEY",
    "model_name": "text-embedding-ada-002"
}
validate_config_schema(config, "openai")
```

## Adding New Schemas

To add a new schema:

1. Create a new JSON file in this directory with the name of the embedding function (e.g., `new_function.json`)
2. Define the schema following the JSON Schema Draft-07 specification
3. Update the embedding function to use the schema for validation

## Schema Versioning

Each schema includes a version number to support future changes to embedding function configurations. When making changes to a schema, increment the version number to ensure backward compatibility.
@@ -0,0 +1,19 @@
from chromadb.utils.embedding_functions.schemas.schema_utils import (
    validate_config_schema,
    load_schema,
    get_schema_version,
)
from chromadb.utils.embedding_functions.schemas.registry import (
    get_available_schemas,
    get_schema_info,
    get_embedding_function_names,
)

__all__ = [
    "validate_config_schema",
    "load_schema",
    "get_schema_version",
    "get_available_schemas",
    "get_schema_info",
    "get_embedding_function_names",
]
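Editor's note: a hedged sketch of the helpers re-exported above. Their exact signatures live in `schema_utils.py` / `registry.py`, which are not shown in this diff; the call shapes below are assumptions based on the names.

```python
from chromadb.utils.embedding_functions.schemas import (
    get_available_schemas,
    load_schema,
)

names = get_available_schemas()   # assumed: collection of schema names, e.g. ["openai", ...]
schema = load_schema("openai")    # assumed: parsed JSON schema as a dict
print(schema.get("version"), schema.get("title"))
```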
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,285 @@
from __future__ import annotations

import re
from functools import lru_cache
from typing import Iterable, List, Protocol, cast


DEFAULT_ENGLISH_STOPWORDS: List[str] = [
    "a",
    "about",
    "above",
    "after",
    "again",
    "against",
    "ain",
    "all",
    "am",
    "an",
    "and",
    "any",
    "are",
    "aren",
    "aren't",
    "as",
    "at",
    "be",
    "because",
    "been",
    "before",
    "being",
    "below",
    "between",
    "both",
    "but",
    "by",
    "can",
    "couldn",
    "couldn't",
    "d",
    "did",
    "didn",
    "didn't",
    "do",
    "does",
    "doesn",
    "doesn't",
    "doing",
    "don",
    "don't",
    "down",
    "during",
    "each",
    "few",
    "for",
    "from",
    "further",
    "had",
    "hadn",
    "hadn't",
    "has",
    "hasn",
    "hasn't",
    "have",
    "haven",
    "haven't",
    "having",
    "he",
    "her",
    "here",
    "hers",
    "herself",
    "him",
    "himself",
    "his",
    "how",
    "i",
    "if",
    "in",
    "into",
    "is",
    "isn",
    "isn't",
    "it",
    "it's",
    "its",
    "itself",
    "just",
    "ll",
    "m",
    "ma",
    "me",
    "mightn",
    "mightn't",
    "more",
    "most",
    "mustn",
    "mustn't",
    "my",
    "myself",
    "needn",
    "needn't",
    "no",
    "nor",
    "not",
    "now",
    "o",
    "of",
    "off",
    "on",
    "once",
    "only",
    "or",
    "other",
    "our",
    "ours",
    "ourselves",
    "out",
    "over",
    "own",
    "re",
    "s",
    "same",
    "shan",
    "shan't",
    "she",
    "she's",
    "should",
    "should've",
    "shouldn",
    "shouldn't",
    "so",
    "some",
    "such",
    "t",
    "than",
    "that",
    "that'll",
    "the",
    "their",
    "theirs",
    "them",
    "themselves",
    "then",
    "there",
    "these",
    "they",
    "this",
    "those",
    "through",
    "to",
    "too",
    "under",
    "until",
    "up",
    "ve",
    "very",
    "was",
    "wasn",
    "wasn't",
    "we",
    "were",
    "weren",
    "weren't",
    "what",
    "when",
    "where",
    "which",
    "while",
    "who",
    "whom",
    "why",
    "will",
    "with",
    "won",
    "won't",
    "wouldn",
    "wouldn't",
    "y",
    "you",
    "you'd",
    "you'll",
    "you're",
    "you've",
    "your",
    "yours",
    "yourself",
    "yourselves",
]


DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(DEFAULT_ENGLISH_STOPWORDS)


class SnowballStemmer(Protocol):
    def stem(self, token: str) -> str:  # pragma: no cover - protocol definition
        ...


class _SnowballStemmerAdapter:
    """Adapter that provides the uniform `stem` API used across languages."""

    def __init__(self) -> None:
        try:
            import snowballstemmer
        except ImportError:
            raise ValueError(
                "The snowballstemmer python package is not installed. Please install it with `pip install snowballstemmer`"
            )

        self._stemmer = snowballstemmer.stemmer("english")

    def stem(self, token: str) -> str:
        return cast(str, self._stemmer.stemWord(token))


@lru_cache(maxsize=1)
def get_english_stemmer() -> SnowballStemmer:
    """Return a cached Snowball stemmer for English."""

    return _SnowballStemmerAdapter()


class Bm25Tokenizer:
    """Tokenizer with stopword filtering and stemming used by BM25 embeddings."""

    def __init__(
        self,
        stemmer: SnowballStemmer,
        stopwords: Iterable[str],
        token_max_length: int,
    ) -> None:
        self._stemmer = stemmer
        self._stopwords = {word.lower() for word in stopwords}
        self._token_max_length = token_max_length
        self._non_alphanumeric_pattern = re.compile(r"[^\w\s]+", flags=re.UNICODE)

    def _remove_non_alphanumeric(self, text: str) -> str:
        return self._non_alphanumeric_pattern.sub(" ", text)

    @staticmethod
    def _simple_tokenize(text: str) -> List[str]:
        return [token for token in text.lower().split() if token]

    def tokenize(self, text: str) -> List[str]:
        cleaned = self._remove_non_alphanumeric(text)
        raw_tokens = self._simple_tokenize(cleaned)

        tokens: List[str] = []
        for token in raw_tokens:
            if token in self._stopwords:
                continue

            if len(token) > self._token_max_length:
                continue

            stemmed = self._stemmer.stem(token).strip()
            if stemmed:
                tokens.append(stemmed)

        return tokens
class Murmur3AbsHasher:
|
||||
def __init__(self, seed: int = 0) -> None:
|
||||
try:
|
||||
import mmh3
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"The murmurhash3 python package is not installed. Please install it with `pip install murmurhash3`"
|
||||
)
|
||||
self.hasher = mmh3.hash
|
||||
self.seed = seed
|
||||
|
||||
def hash(self, token: str) -> int:
|
||||
return cast(int, abs(self.hasher(token, seed=self.seed)))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Bm25Tokenizer",
|
||||
"DEFAULT_CHROMA_BM25_STOPWORDS",
|
||||
"DEFAULT_ENGLISH_STOPWORDS",
|
||||
"SnowballStemmer",
|
||||
"get_english_stemmer",
|
||||
"Murmur3AbsHasher",
|
||||
]
|
||||
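A quick usage sketch of the pieces above, assuming `snowballstemmer` and `mmh3` are installed (the `token_max_length` value is arbitrary here):

```
# Wire the tokenizer together with the default stopwords.
tokenizer = Bm25Tokenizer(
    stemmer=get_english_stemmer(),
    stopwords=DEFAULT_CHROMA_BM25_STOPWORDS,
    token_max_length=40,
)

# Stopwords are dropped, punctuation is stripped, remaining tokens are stemmed.
print(tokenizer.tokenize("The running dogs were chasing cars!"))
# likely ['run', 'dog', 'chase', 'car']

# Tokens are then typically mapped to sparse dimensions with the hasher.
hasher = Murmur3AbsHasher(seed=0)
dims = [hasher.hash(tok) for tok in tokenizer.tokenize("running dogs")]
```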
@@ -0,0 +1,55 @@
"""
Schema Registry for Embedding Functions

This module provides a registry of all available schemas for embedding functions.
It can be used to get information about available schemas and their versions.
"""

from typing import Dict, List, Set
import os
import json
from .schema_utils import SCHEMAS_DIR


def get_available_schemas() -> List[str]:
    """
    Get a list of all available schemas.

    Returns:
        A list of schema names (without .json extension)
    """
    schemas = []
    for filename in os.listdir(SCHEMAS_DIR):
        if filename.endswith(".json") and filename != "base_schema.json":
            schemas.append(filename[:-5])  # Remove .json extension
    return schemas


def get_schema_info() -> Dict[str, Dict[str, str]]:
    """
    Get information about all available schemas.

    Returns:
        A dictionary mapping schema names to information about the schema
    """
    schema_info = {}
    for schema_name in get_available_schemas():
        schema_path = os.path.join(SCHEMAS_DIR, f"{schema_name}.json")
        with open(schema_path, "r") as f:
            schema = json.load(f)
        schema_info[schema_name] = {
            "version": schema.get("version", "1.0.0"),
            "title": schema.get("title", ""),
            "description": schema.get("description", ""),
        }
    return schema_info


def get_embedding_function_names() -> Set[str]:
    """
    Get a set of all embedding function names that have schemas.

    Returns:
        A set of embedding function names
    """
    return set(get_available_schemas())
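For example, assuming the schemas directory contains `openai.json` and `cohere.json` (names illustrative), the registry functions behave like this:

```
print(get_available_schemas())         # e.g. ['openai', 'cohere']
print(get_embedding_function_names())  # e.g. {'openai', 'cohere'}

info = get_schema_info()
# e.g. {'openai': {'version': '1.0.0', 'title': '...', 'description': '...'}, ...}
print(info["openai"]["version"])
```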
@@ -0,0 +1,85 @@
import json
import os
from typing import Dict, Any, cast
import jsonschema
from jsonschema import ValidationError

# Path to the schemas directory
SCHEMAS_DIR = os.path.join(
    os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
    ),
    "schemas",
    "embedding_functions",
)

cached_schemas: Dict[str, Dict[str, Any]] = {}


def load_schema(schema_name: str) -> Dict[str, Any]:
    """
    Load a JSON schema from the schemas directory.

    Args:
        schema_name: Name of the schema file (without .json extension)

    Returns:
        The loaded schema as a dictionary

    Raises:
        FileNotFoundError: If the schema file does not exist
        json.JSONDecodeError: If the schema file is not valid JSON
    """
    if schema_name in cached_schemas:
        return cached_schemas[schema_name]
    schema_path = os.path.join(SCHEMAS_DIR, f"{schema_name}.json")
    with open(schema_path, "r") as f:
        schema = cast(Dict[str, Any], json.load(f))
    cached_schemas[schema_name] = schema
    return schema


def validate_config_schema(config: Dict[str, Any], schema_name: str) -> None:
    """
    Validate a configuration against a schema.

    Args:
        config: Configuration to validate
        schema_name: Name of the schema file (without .json extension)

    Raises:
        ValidationError: If the configuration does not match the schema
        FileNotFoundError: If the schema file does not exist
        json.JSONDecodeError: If the schema file is not valid JSON
    """
    schema = load_schema(schema_name)
    try:
        jsonschema.validate(instance=config, schema=schema)
    except ValidationError as e:
        # Enhance the error message with more context
        error_path = "/".join(str(path) for path in e.path)
        error_message = (
            f"Config validation failed for schema '{schema_name}': {e.message}"
        )
        if error_path:
            error_message += f" at path '{error_path}'"
        raise ValidationError(error_message) from e


def get_schema_version(schema_name: str) -> str:
    """
    Get the version of a schema.

    Args:
        schema_name: Name of the schema file (without .json extension)

    Returns:
        The schema version as a string, defaulting to "1.0.0" if the schema
        does not specify one

    Raises:
        FileNotFoundError: If the schema file does not exist
        json.JSONDecodeError: If the schema file is not valid JSON
    """
    schema = load_schema(schema_name)
    return cast(str, schema.get("version", "1.0.0"))
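A sketch of the intended call pattern, reusing the "openai" schema and config keys from the README example above:

```
from jsonschema import ValidationError

config = {
    "api_key_env_var": "CHROMA_OPENAI_API_KEY",
    "model_name": "text-embedding-ada-002",
}

try:
    validate_config_schema(config, "openai")
except ValidationError as e:
    # The message includes the schema name and, when available, the failing path.
    print(e.message)

print(get_schema_version("openai"))  # '1.0.0' unless the schema overrides it
```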
@@ -0,0 +1,121 @@
from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents
from typing import List, Dict, Any
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema


class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
    # Since we do dynamic imports we have to type this as Any
    models: Dict[str, Any] = {}

    # If you have a beefier machine, try "gtr-t5-large".
    # For a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        device: str = "cpu",
        normalize_embeddings: bool = False,
        **kwargs: Any,
    ):
        """Initialize SentenceTransformerEmbeddingFunction.

        Args:
            model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2"
            device (str, optional): Device used for computation, defaults to "cpu"
            normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False
            **kwargs: Additional arguments to pass to the SentenceTransformer model.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ValueError(
                "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
            )

        self.model_name = model_name
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        for key, value in kwargs.items():
            if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
                raise ValueError(f"Keyword argument {key} is not a primitive type")
        self.kwargs = kwargs

        if model_name not in self.models:
            self.models[model_name] = SentenceTransformer(
                model_name_or_path=model_name, device=device, **kwargs
            )
        self._model = self.models[model_name]

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        embeddings = self._model.encode(
            list(input),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
        )

        return [np.array(embedding, dtype=np.float32) for embedding in embeddings]

    @staticmethod
    def name() -> str:
        return "sentence_transformer"

    def default_space(self) -> Space:
        # If normalize_embeddings is True, cosine is equivalent to dot product
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        model_name = config.get("model_name")
        device = config.get("device")
        normalize_embeddings = config.get("normalize_embeddings")
        kwargs = config.get("kwargs", {})

        if model_name is None or device is None or normalize_embeddings is None:
            assert False, "This code should not be reached"

        return SentenceTransformerEmbeddingFunction(
            model_name=model_name,
            device=device,
            normalize_embeddings=normalize_embeddings,
            **kwargs,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "device": self.device,
            "normalize_embeddings": self.normalize_embeddings,
            "kwargs": self.kwargs,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # model_name is also used as the identifier for model path if stored locally.
        # Users should be able to change the path if needed, so we should not validate that.
        # e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
        return

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "sentence_transformer")
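Typical usage, assuming `sentence_transformers` is installed (the model downloads on first use, and repeated instantiations reuse the class-level `models` cache):

```
ef = SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",
    device="cpu",
    normalize_embeddings=True,  # unit-length vectors, so cosine == dot product
)

embeddings = ef(["hello world", "goodbye world"])
print(len(embeddings), embeddings[0].shape, embeddings[0].dtype)
# 2 (384,) float32 for this model
```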
@@ -0,0 +1,90 @@
from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any
import numpy as np


class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using the Text2Vec model.
    """

    def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
        """
        Initialize the Text2VecEmbeddingFunction.

        Args:
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "shibing624/text2vec-base-chinese".
        """
        try:
            from text2vec import SentenceModel
        except ImportError:
            raise ValueError(
                "The text2vec python package is not installed. Please install it with `pip install text2vec`"
            )

        self.model_name = model_name
        self._model = SentenceModel(model_name_or_path=model_name)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        # Text2Vec only works with text documents
        if not all(isinstance(item, str) for item in input):
            raise ValueError("Text2Vec only supports text documents, not images")

        embeddings = self._model.encode(list(input), convert_to_numpy=True)

        # Convert to numpy arrays
        return [np.array(embedding, dtype=np.float32) for embedding in embeddings]

    @staticmethod
    def name() -> str:
        return "text2vec"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        model_name = config.get("model_name")

        if model_name is None:
            assert False, "This code should not be reached"

        return Text2VecEmbeddingFunction(model_name=model_name)

    def get_config(self) -> Dict[str, Any]:
        return {"model_name": self.model_name}

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        # model_name is also used as the identifier for model path if stored locally.
        # Users should be able to change the path if needed, so we should not validate that.
        # e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
        return

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "text2vec")
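The `get_config` / `build_from_config` pair is what lets a configuration be persisted and later restored; a minimal roundtrip sketch (model download happens on construction):

```
ef = Text2VecEmbeddingFunction(model_name="shibing624/text2vec-base-chinese")

config = ef.get_config()  # {'model_name': 'shibing624/text2vec-base-chinese'}
Text2VecEmbeddingFunction.validate_config(config)
restored = Text2VecEmbeddingFunction.build_from_config(config)

assert restored.get_config() == config
```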
@@ -0,0 +1,145 @@
from chromadb.api.types import (
    Embeddings,
    Documents,
    EmbeddingFunction,
    Space,
)
from typing import List, Dict, Any, Optional
import os
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import cast
import warnings

ENDPOINT = "https://api.together.xyz/v1/embeddings"


class TogetherAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the Together AI API.
    """

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
        api_key_env_var: str = "CHROMA_TOGETHER_AI_API_KEY",
    ):
        """
        Initialize the TogetherAIEmbeddingFunction. See the docs for supported models here:
        https://docs.together.ai/docs/serverless-models#embedding-models

        Args:
            model_name: The name of the model to use for text embeddings.
            api_key: The API key to use for the Together AI API.
            api_key_env_var: The environment variable to use for the Together AI API key.
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.model_name = model_name
        self.api_key = api_key
        self.api_key_env_var = api_key_env_var

        if not self.api_key:
            self.api_key = os.getenv(self.api_key_env_var)

        if not self.api_key:
            raise ValueError(
                f"API key not found in environment variable {self.api_key_env_var}"
            )

        self._session = httpx.Client()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
                "accept": "application/json",
            }
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Embed a list of texts using the Together AI API.

        Args:
            input: A list of texts to embed.
        """

        if not input:
            raise ValueError("Input is required")

        if not isinstance(input, list):
            raise ValueError("Input must be a list")

        if not all(isinstance(item, str) for item in input):
            raise ValueError("All items in input must be strings")

        response = self._session.post(
            ENDPOINT,
            json={"model": self.model_name, "input": input},
        )

        response.raise_for_status()

        data = response.json()

        embeddings = [item["embedding"] for item in data["data"]]

        return cast(Embeddings, embeddings)

    @staticmethod
    def name() -> str:
        return "together_ai"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")

        if api_key_env_var is None or model_name is None:
            raise ValueError("api_key_env_var and model_name must be provided")

        return TogetherAIEmbeddingFunction(
            model_name=model_name, api_key_env_var=api_key_env_var
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate
        """
        validate_config_schema(config, "together_ai")
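A usage sketch; the model name is illustrative only, so check Together AI's embedding-model list linked in the docstring for current identifiers:

```
import os

os.environ["CHROMA_TOGETHER_AI_API_KEY"] = "sk-..."  # placeholder key

ef = TogetherAIEmbeddingFunction(
    model_name="togethercomputer/m2-bert-80M-8k-retrieval"  # illustrative
)
embeddings = ef(["hello world"])
```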
@@ -0,0 +1,136 @@
from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Optional
import os
import numpy as np
import warnings


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using the VoyageAI API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "voyage-large-2",
        api_key_env_var: str = "CHROMA_VOYAGE_API_KEY",
        input_type: Optional[str] = None,
        truncation: bool = True,
    ):
        """
        Initialize the VoyageAIEmbeddingFunction.

        Args:
            api_key (str, optional): API key for the VoyageAI API. If not provided, will look for it in the environment variable.
            model_name (str, optional): The name of the model to use for text embeddings.
                Defaults to "voyage-large-2".
            api_key_env_var (str, optional): Environment variable name that contains your API key for the VoyageAI API.
                Defaults to "CHROMA_VOYAGE_API_KEY".
            input_type (str, optional): The type of input to use for the VoyageAI API.
                Defaults to None.
            truncation (bool): Whether to truncate the input text.
                Defaults to True.
        """
        try:
            import voyageai
        except ImportError:
            raise ValueError(
                "The voyageai python package is not installed. Please install it with `pip install voyageai`"
            )

        if api_key is not None:
            warnings.warn(
                "Direct api_key configuration will not be persisted. "
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )

        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        self.model_name = model_name
        self.input_type = input_type
        self.truncation = truncation
        self._client = voyageai.Client(api_key=self.api_key)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        embeddings = self._client.embed(
            texts=input,
            model=self.model_name,
            input_type=self.input_type,
            truncation=self.truncation,
        )

        # Convert to numpy arrays
        return [
            np.array(embedding, dtype=np.float32) for embedding in embeddings.embeddings
        ]

    @staticmethod
    def name() -> str:
        return "voyageai"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        input_type = config.get("input_type")
        truncation = config.get("truncation")

        if api_key_env_var is None or model_name is None:
            assert False, "This code should not be reached"

        return VoyageAIEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            input_type=input_type,
            truncation=truncation if truncation is not None else True,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "input_type": self.input_type,
            "truncation": self.truncation,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "voyageai")
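The `input_type` hint distinguishes query-time from document-time embeddings. VoyageAI's API documents `"query"` and `"document"` as the accepted values; treat the exact strings as an assumption of this sketch and confirm against their docs:

```
doc_ef = VoyageAIEmbeddingFunction(model_name="voyage-large-2", input_type="document")
query_ef = VoyageAIEmbeddingFunction(model_name="voyage-large-2", input_type="query")

doc_vectors = doc_ef(["Chroma is a vector database."])
query_vectors = query_ef(["what is chroma?"])
```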
@@ -0,0 +1,18 @@
from uuid import UUID
from starlette.responses import JSONResponse

from chromadb.errors import ChromaError, InvalidUUIDError


def fastapi_json_response(error: ChromaError) -> JSONResponse:
    return JSONResponse(
        content={"error": error.name(), "message": error.message()},
        status_code=error.code(),
    )


def string_to_uuid(uuid_str: str) -> UUID:
    try:
        return UUID(uuid_str)
    except ValueError:
        raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID")
@@ -0,0 +1,32 @@
from collections import OrderedDict
from typing import Any, Callable, Generic, Optional, TypeVar


K = TypeVar("K")
V = TypeVar("V")


class LRUCache(Generic[K, V]):
    """A simple LRU cache implementation, based on the OrderedDict class, which allows
    for a callback to be invoked when an item is evicted from the cache."""

    def __init__(self, capacity: int, callback: Optional[Callable[[K, V], Any]] = None):
        self.capacity = capacity
        self.cache: OrderedDict[K, V] = OrderedDict()
        self.callback = callback

    def get(self, key: K) -> Optional[V]:
        if key not in self.cache:
            return None
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def set(self, key: K, value: V) -> None:
        if key in self.cache:
            self.cache.pop(key)
        elif len(self.cache) == self.capacity:
            evicted_key, evicted_value = self.cache.popitem(last=False)
            if self.callback:
                self.callback(evicted_key, evicted_value)
        self.cache[key] = value
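A small sketch of the eviction behavior:

```
def on_evict(key: str, value: int) -> None:
    print(f"evicted {key}={value}")

cache: LRUCache[str, int] = LRUCache(capacity=2, callback=on_evict)
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")     # touching "a" makes it most recently used
cache.set("c", 3)  # evicts "b" (least recently used); prints "evicted b=2"
assert cache.get("b") is None
```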
@@ -0,0 +1,8 @@
def int_to_bytes(int: int) -> bytes:
    """Convert int to a 24 byte big endian byte string"""
    return int.to_bytes(24, "big")


def bytes_to_int(bytes: bytes) -> int:
    """Convert a 24 byte big endian byte string to an int"""
    return int.from_bytes(bytes, "big")
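The fixed-width big-endian encoding means the byte strings sort lexicographically in the same order as the integers they encode; a roundtrip sketch:

```
assert bytes_to_int(int_to_bytes(42)) == 42
assert len(int_to_bytes(0)) == 24
# Byte-wise ordering matches integer ordering for equal-width keys.
assert int_to_bytes(1) < int_to_bytes(255) < int_to_bytes(2**64)
```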
@@ -0,0 +1,74 @@
import threading
from types import TracebackType
from typing import Optional, Type


class ReadWriteLock:
    """A lock object that allows many simultaneous "read locks", but
    only one "write lock." """

    def __init__(self) -> None:
        self._read_ready = threading.Condition(threading.RLock())
        self._readers = 0

    def acquire_read(self) -> None:
        """Acquire a read lock. Blocks only if a thread has
        acquired the write lock."""
        self._read_ready.acquire()
        try:
            self._readers += 1
        finally:
            self._read_ready.release()

    def release_read(self) -> None:
        """Release a read lock."""
        self._read_ready.acquire()
        try:
            self._readers -= 1
            if not self._readers:
                self._read_ready.notify_all()
        finally:
            self._read_ready.release()

    def acquire_write(self) -> None:
        """Acquire a write lock. Blocks until there are no
        acquired read or write locks."""
        self._read_ready.acquire()
        while self._readers > 0:
            self._read_ready.wait()

    def release_write(self) -> None:
        """Release a write lock."""
        self._read_ready.release()


class ReadRWLock:
    def __init__(self, rwLock: ReadWriteLock):
        self.rwLock = rwLock

    def __enter__(self) -> None:
        self.rwLock.acquire_read()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.rwLock.release_read()


class WriteRWLock:
    def __init__(self, rwLock: ReadWriteLock):
        self.rwLock = rwLock

    def __enter__(self) -> None:
        self.rwLock.acquire_write()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.rwLock.release_write()
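The context-manager wrappers are the intended entry points; a usage sketch:

```
lock = ReadWriteLock()
shared: dict = {}

def reader() -> None:
    with ReadRWLock(lock):  # many readers may hold this at once
        _ = shared.get("key")

def writer() -> None:
    with WriteRWLock(lock):  # blocks until all readers have released
        shared["key"] = "value"
```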
@@ -0,0 +1,68 @@
# An implementation of https://en.wikipedia.org/wiki/Rendezvous_hashing
from typing import Callable, List, Tuple
import mmh3
import heapq

Hasher = Callable[[str, str], int]
Member = str
Members = List[str]
Key = str


def assign(
    key: Key, members: Members, hasher: Hasher, replication: int
) -> List[Member]:
    """Assigns a key to a member using the rendezvous hashing algorithm

    Args:
        key: The key to assign
        members: The list of members to assign the key to
        hasher: The hashing function to use
        replication: The number of members to assign the key to
    Returns:
        A list of members that the key has been assigned to
    """

    if replication > len(members):
        raise ValueError(
            "Replication factor cannot be greater than the number of members"
        )
    if len(members) == 0:
        raise ValueError("Cannot assign key to empty memberlist")
    if len(members) == 1:
        # Return a fresh single-element list rather than the input list itself
        return [members[0]]
    if key == "":
        raise ValueError("Cannot assign empty key")

    member_score_heap: List[Tuple[int, Member]] = []
    for member in members:
        # Negate the score since heapq is a min heap and we want the highest scores first
        score = -hasher(member, key)
        heapq.heappush(member_score_heap, (score, member))

    output_members: List[Member] = []
    for _ in range(replication):
        member_and_score = heapq.heappop(member_score_heap)
        output_members.append(member_and_score[1])

    return output_members


def merge_hashes(x: int, y: int) -> int:
    """murmurhash3 mix 64-bit"""
    acc = x ^ y
    acc ^= acc >> 33
    acc = (
        acc * 0xFF51AFD7ED558CCD
    ) % 2**64  # We need to mod here to prevent Python from using arbitrary-size ints
    acc ^= acc >> 33
    acc = (acc * 0xC4CEB9FE1A85EC53) % 2**64
    acc ^= acc >> 33
    return acc


def murmur3hasher(member: Member, key: Key) -> int:
    """Hashes the key and member using the murmur3 hashing algorithm"""
    member_hash = mmh3.hash64(member, signed=False)[0]
    key_hash = mmh3.hash64(key, signed=False)[0]
    return merge_hashes(member_hash, key_hash)
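Usage sketch: each key is deterministically assigned to `replication` members, and the property that makes rendezvous hashing attractive is that removing one member only reassigns the keys it owned:

```
members = ["node-1", "node-2", "node-3"]

owners = assign("collection-abc", members, murmur3hasher, replication=2)
print(owners)  # e.g. ['node-3', 'node-1'], stable for this key and memberlist

# Dropping an unrelated member leaves this key's assignment intact;
# dropping an owner moves only the keys that owner held.
owners_after = assign("collection-abc", ["node-1", "node-3"], murmur3hasher, replication=2)
```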
@@ -0,0 +1,125 @@
from typing import List, Dict, Any, Optional, Union
import numpy as np
import pandas as pd
from chromadb.api.types import QueryResult, GetResult


def _transform_embeddings(
    embeddings: Optional[List[np.ndarray]],  # type: ignore
) -> Optional[Union[List[List[float]], List[np.ndarray]]]:  # type: ignore
    """
    Transform embeddings from numpy arrays to lists of floats.
    This is a shared helper function to avoid duplicating the transformation logic.
    """
    if embeddings is None:
        return None
    return (
        [emb.tolist() for emb in embeddings]
        if isinstance(embeddings[0], np.ndarray)
        else embeddings
    )


def _add_query_fields(
    data_dict: Dict[str, Any],
    query_result: QueryResult,
    query_idx: int,
) -> None:
    """
    Helper function to add fields from a query result to a dictionary.
    Handles the nested array structure specific to query results.

    Args:
        data_dict: Dictionary to add the fields to
        query_result: QueryResult containing the data
        query_idx: Index of the current query being processed
    """
    for field in query_result["included"]:
        value = query_result.get(field)
        if value is not None:
            key = field.rstrip("s")  # DF naming convention is not plural
            if field == "embeddings":
                value = _transform_embeddings(value)  # type: ignore
            if isinstance(value, list) and len(value) > 0:
                value = value[query_idx]  # type: ignore
            data_dict[key] = value


def _add_get_fields(
    data_dict: Dict[str, Any],
    get_result: GetResult,
) -> None:
    """
    Helper function to add fields from a get result to a dictionary.
    Handles the flat array structure specific to get results.

    Args:
        data_dict: Dictionary to add the fields to
        get_result: GetResult containing the data
    """
    for field in get_result["included"]:
        value = get_result.get(field)
        if value is not None:
            key = field.rstrip("s")  # DF naming convention is not plural
            if field == "embeddings":
                value = _transform_embeddings(value)  # type: ignore
            data_dict[key] = value


def query_result_to_dfs(query_result: QueryResult) -> List["pd.DataFrame"]:
    """
    Function to convert QueryResult to list of DataFrames.
    Handles the nested array structure specific to query results.
    Column order is defined by the order of the fields in the QueryResult.

    Args:
        query_result: QueryResult to convert to DataFrames.

    Returns:
        List of DataFrames.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required to convert query results to DataFrames.")

    dfs = []
    num_queries = len(query_result["ids"])

    for i in range(num_queries):
        data_for_df: Dict[str, Any] = {}
        data_for_df["id"] = query_result["ids"][i]

        _add_query_fields(data_for_df, query_result, i)

        df = pd.DataFrame(data_for_df)
        df.set_index("id", inplace=True)
        dfs.append(df)
    return dfs


def get_result_to_df(get_result: GetResult) -> "pd.DataFrame":
    """
    Function to convert GetResult to a DataFrame.
    Handles the flat array structure specific to get results.
    Column order is defined by the order of the fields in the GetResult.

    Args:
        get_result: GetResult to convert to a DataFrame.

    Returns:
        DataFrame.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required to convert get results to a DataFrame.")

    data_for_df: Dict[str, Any] = {}
    data_for_df["id"] = get_result["ids"]

    _add_get_fields(data_for_df, get_result)

    df = pd.DataFrame(data_for_df)
    df.set_index("id", inplace=True)
    return df
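A sketch with a hand-built, GetResult-shaped dict, trimmed to the fields the converter reads (`ids`, `included`, and one array per included field); field names follow the `rstrip("s")` convention above:

```
get_result = {
    "ids": ["id1", "id2"],
    "documents": ["doc one", "doc two"],
    "metadatas": [{"source": "a"}, {"source": "b"}],
    "included": ["documents", "metadatas"],
}

df = get_result_to_df(get_result)  # type: ignore[arg-type]
print(df)
# index "id" with columns "document" and "metadata" (layout approximate):
#      document         metadata
# id
# id1  doc one   {'source': 'a'}
# id2  doc two   {'source': 'b'}
```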
@@ -0,0 +1,36 @@
from typing import List, Tuple
from chromadb.base_types import SparseVector


def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVector:
    """Normalize and create a SparseVector by sorting indices and values together.

    This function takes raw indices and values (which may be unsorted or have duplicates)
    and returns a properly constructed SparseVector with sorted indices.

    Args:
        indices: List of dimension indices (may be unsorted)
        values: List of values corresponding to each index

    Returns:
        SparseVector with indices sorted in ascending order

    Raises:
        ValueError: If indices and values have different lengths
        ValueError: If there are duplicate indices (after sorting)
        ValueError: If indices are negative
        ValueError: If values are not numeric
    """
    if not indices:
        return SparseVector(indices=[], values=[])

    # Sort indices and values together by index
    sorted_pairs = sorted(zip(indices, values), key=lambda x: x[0])
    sorted_indices, sorted_values = zip(*sorted_pairs)

    # Create SparseVector which will validate:
    # - indices are sorted
    # - no duplicate indices
    # - indices are non-negative
    # - values are numeric
    return SparseVector(indices=list(sorted_indices), values=list(sorted_values))
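A sketch of the normalization behavior; validation itself is delegated to the `SparseVector` constructor as the comments above note:

```
sv = normalize_sparse_vector(indices=[5, 1, 3], values=[0.5, 0.1, 0.3])
# Indices come back sorted ascending with values permuted to match:
# indices [1, 3, 5], values [0.1, 0.3, 0.5]

# Duplicate indices survive sorting and are rejected by the constructor:
try:
    normalize_sparse_vector(indices=[2, 2], values=[1.0, 1.0])
except ValueError as e:
    print(e)
```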