chore: add virtual environment to the repository

- Add the backend_service/venv virtual environment
- Include all Python dependency packages
- Note: the virtual environment is ~393 MB and contains 12,655 files
This commit is contained in:
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
import importlib
from typing import Type, TypeVar, cast
C = TypeVar("C")
def get_class(fqn: str, type: Type[C]) -> Type[C]:
"""Given a fully qualifed class name, import the module and return the class"""
module_name, class_name = fqn.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
return cast(Type[C], cls)
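
A minimal usage sketch of get_class; collections.OrderedDict is used purely for illustration:

from collections import OrderedDict

# Resolve a class from its fully qualified name and check the result.
cls = get_class("collections.OrderedDict", dict)
assert cls is OrderedDict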

View File

@@ -0,0 +1,63 @@
import inspect
import asyncio
from typing import Any, Callable, Coroutine, TypeVar
from typing_extensions import ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
def async_to_sync(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, R]:
"""A function decorator that converts an async function to a sync function.
This should generally not be used in production code paths.
"""
def sync_wrapper(*args, **kwargs): # type: ignore
loop = None
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if loop.is_running():
return func(*args, **kwargs)
result = loop.run_until_complete(func(*args, **kwargs))
def convert_result(result: Any) -> Any:
if isinstance(result, list):
return [convert_result(r) for r in result]
if isinstance(result, object):
return async_class_to_sync(result)
if callable(result):
return async_to_sync(result)
return result
return convert_result(result)
return sync_wrapper
T = TypeVar("T")
def async_class_to_sync(cls: T) -> T:
"""A decorator that converts a class with async methods to a class with sync methods.
This should generally not be used in production code paths.
"""
for attr, value in inspect.getmembers(cls):
if (
callable(value)
and inspect.iscoroutinefunction(value)
and not attr.startswith("__")
):
setattr(cls, attr, async_to_sync(value))
return cls
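
A quick sketch of the decorator on a toy coroutine (fetch_value is hypothetical; as the docstring warns, this is for tests, not production paths):

import asyncio

async def fetch_value(x: int) -> int:
    await asyncio.sleep(0)  # simulate async work
    return x * 2

# Wrap the coroutine function so it can be called synchronously.
fetch_value_sync = async_to_sync(fetch_value)
print(fetch_value_sync(21))  # 42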

View File

@@ -0,0 +1,36 @@
from typing import Optional, Tuple, List
from chromadb.api import BaseAPI
from chromadb.api.types import (
Documents,
Embeddings,
IDs,
Metadatas,
)
def create_batches(
api: BaseAPI,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
) -> List[Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]]]:
_batches: List[
Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]]
] = []
if len(ids) > api.get_max_batch_size():
# create split batches
for i in range(0, len(ids), api.get_max_batch_size()):
_batches.append(
(
ids[i : i + api.get_max_batch_size()],
embeddings[i : i + api.get_max_batch_size()]
if embeddings is not None
else None,
metadatas[i : i + api.get_max_batch_size()] if metadatas else None,
documents[i : i + api.get_max_batch_size()] if documents else None,
)
)
else:
_batches.append((ids, embeddings, metadatas, documents))
return _batches
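
A sketch of how the helper splits an oversized insert; the stub below stands in for a real BaseAPI client and implements only the one method create_batches calls:

class _StubAPI:
    def get_max_batch_size(self) -> int:
        return 2  # force splitting for the demo

ids = ["id1", "id2", "id3", "id4", "id5"]
docs = ["a", "b", "c", "d", "e"]
batches = create_batches(_StubAPI(), ids=ids, documents=docs)
print(len(batches))  # 3 batches of at most 2 ids each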

View File

@@ -0,0 +1,31 @@
import importlib
import multiprocessing
from typing import Optional, Sequence, List, Tuple
import numpy as np
from chromadb.api.types import URI, DataLoader, Image, URIs
from concurrent.futures import ThreadPoolExecutor
class ImageLoader(DataLoader[List[Optional[Image]]]):
def __init__(self, max_workers: int = multiprocessing.cpu_count()) -> None:
try:
self._PILImage = importlib.import_module("PIL.Image")
self._max_workers = max_workers
except ImportError:
raise ValueError(
"The PIL python package is not installed. Please install it with `pip install pillow`"
)
def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
return np.array(self._PILImage.open(uri)) if uri is not None else None
def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
return list(executor.map(self._load_image, uris))
class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]):
# This is a simple pass-through data loader that returns the input data with an "images"
# flag, which lets the langchain embedding function know that the data is image URIs
def __call__(self, uris: URIs) -> Tuple[str, URIs]: # type: ignore
return ("images", uris)

View File

@@ -0,0 +1,38 @@
import os
import random
import gc
import time
# Borrowed from https://github.com/rogerbinns/apsw/blob/master/apsw/tests.py#L224
# Used to delete sqlite files on Windows, since Windows file locking
# behaves differently from other operating systems
# This should only be used for test or non-production code, such as in reset_state.
def delete_file(name: str) -> None:
try:
os.remove(name)
except Exception:
pass
chars = list("abcdefghijklmn")
random.shuffle(chars)
newname = name + "-n-" + "".join(chars)
count = 0
while os.path.exists(name):
count += 1
try:
os.rename(name, newname)
except Exception:
if count > 30:
n = list("abcdefghijklmnopqrstuvwxyz")
random.shuffle(n)
final_name = "".join(n)
try:
os.rename(
name, "chroma-to-clean" + final_name + ".deletememanually"
)
except Exception:
pass
break
time.sleep(0.1)
gc.collect()

View File

@@ -0,0 +1,22 @@
import os
def get_directory_size(directory: str) -> int:
"""
Calculate the total size of the directory by walking through each file.
Parameters:
directory (str): The path of the directory for which to calculate the size.
Returns:
total_size (int): The total size of the directory in bytes.
"""
total_size = 0
for dirpath, _, filenames in os.walk(directory):
for f in filenames:
fp = os.path.join(dirpath, f)
# skip if it is symbolic link
if not os.path.islink(fp):
total_size += os.path.getsize(fp)
return total_size
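
A one-line usage sketch, reporting the size of the current directory:

size_bytes = get_directory_size(".")
print(f"{size_bytes / (1024 ** 2):.1f} MiB")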

View File

@@ -0,0 +1,32 @@
"""
These functions match the spec of hnswlib.
"""
from typing import Union, cast
import numpy as np
from numpy.typing import NDArray
Vector = NDArray[Union[np.int32, np.float32, np.int16, np.float16]]
def l2(x: Vector, y: Vector) -> float:
return (np.linalg.norm(x - y) ** 2).item()
def cosine(x: Vector, y: Vector) -> float:
# This epsilon is used to prevent division by zero, and the value is the same as in hnswlib:
# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238
# We need to adapt the epsilon to the precision of the input
NORM_EPS = 1e-30
if x.dtype == np.float16 or y.dtype == np.float16:
NORM_EPS = 1e-7
return cast(
float,
(
1.0 - np.dot(x, y) / ((np.linalg.norm(x) * np.linalg.norm(y)) + NORM_EPS)
).item(),
)
def ip(x: Vector, y: Vector) -> float:
return cast(float, (1.0 - np.dot(x, y)).item())
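
A worked example on a pair of orthogonal unit vectors, showing what each distance returns:

import numpy as np

x = np.array([1.0, 0.0], dtype=np.float32)
y = np.array([0.0, 1.0], dtype=np.float32)

print(l2(x, y))      # ~2.0: squared Euclidean distance
print(cosine(x, y))  # ~1.0: orthogonal, so cosine similarity is 0
print(ip(x, y))      # 1.0: 1 - dot product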

View File

@@ -0,0 +1,285 @@
from typing import Dict, Any, Type, Set
from chromadb.api.types import (
EmbeddingFunction,
DefaultEmbeddingFunction,
SparseEmbeddingFunction,
)
# Import all embedding functions
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
HuggingFaceEmbeddingServer,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
JinaQueryConfig,
)
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
VoyageAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import (
ChromaLangchainEmbeddingFunction,
)
from chromadb.utils.embedding_functions.baseten_embedding_function import (
BasetenEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import (
CloudflareWorkersAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.together_ai_embedding_function import (
TogetherAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.mistral_embedding_function import (
MistralEmbeddingFunction,
)
from chromadb.utils.embedding_functions.morph_embedding_function import (
MorphEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
HuggingFaceSparseEmbeddingFunction,
)
from chromadb.utils.embedding_functions.fastembed_sparse_embedding_function import (
FastembedSparseEmbeddingFunction,
)
from chromadb.utils.embedding_functions.bm25_embedding_function import (
Bm25EmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
ChromaCloudQwenEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_cloud_splade_embedding_function import (
ChromaCloudSpladeEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
ChromaBm25EmbeddingFunction,
)
# Get all the class names for backward compatibility
_all_classes: Set[str] = {
"CohereEmbeddingFunction",
"OpenAIEmbeddingFunction",
"HuggingFaceEmbeddingFunction",
"HuggingFaceEmbeddingServer",
"SentenceTransformerEmbeddingFunction",
"GooglePalmEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OpenCLIPEmbeddingFunction",
"RoboflowEmbeddingFunction",
"Text2VecEmbeddingFunction",
"AmazonBedrockEmbeddingFunction",
"ChromaLangchainEmbeddingFunction",
"BasetenEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
"TogetherAIEmbeddingFunction",
"DefaultEmbeddingFunction",
"HuggingFaceSparseEmbeddingFunction",
"FastembedSparseEmbeddingFunction",
"Bm25EmbeddingFunction",
"ChromaCloudQwenEmbeddingFunction",
"ChromaCloudSpladeEmbeddingFunction",
"ChromaBm25EmbeddingFunction",
}
def get_builtins() -> Set[str]:
return _all_classes
# Dictionary of supported embedding functions
known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignore
"cohere": CohereEmbeddingFunction,
"openai": OpenAIEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"huggingface_server": HuggingFaceEmbeddingServer,
"sentence_transformer": SentenceTransformerEmbeddingFunction,
"google_palm": GooglePalmEmbeddingFunction,
"google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
"google_vertex": GoogleVertexEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"mistral": MistralEmbeddingFunction,
"morph": MorphEmbeddingFunction,
"voyageai": VoyageAIEmbeddingFunction,
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
"open_clip": OpenCLIPEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"amazon_bedrock": AmazonBedrockEmbeddingFunction,
"chroma_langchain": ChromaLangchainEmbeddingFunction,
"baseten": BasetenEmbeddingFunction,
"default": DefaultEmbeddingFunction,
"cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction,
"together_ai": TogetherAIEmbeddingFunction,
"chroma-cloud-qwen": ChromaCloudQwenEmbeddingFunction,
}
sparse_known_embedding_functions: Dict[str, Type[SparseEmbeddingFunction]] = { # type: ignore
"huggingface_sparse": HuggingFaceSparseEmbeddingFunction,
"fastembed_sparse": FastembedSparseEmbeddingFunction,
"bm25": Bm25EmbeddingFunction,
"chroma-cloud-splade": ChromaCloudSpladeEmbeddingFunction,
"chroma_bm25": ChromaBm25EmbeddingFunction,
}
def register_embedding_function(ef_class=None): # type: ignore
"""Register a custom embedding function.
Can be used as a decorator:
@register_embedding_function
class MyEmbedding(EmbeddingFunction):
@classmethod
def name(cls): return "my_embedding"
Or directly:
register_embedding_function(MyEmbedding)
Args:
ef_class: The embedding function class to register.
"""
def _register(cls): # type: ignore
try:
name = cls.name()
known_embedding_functions[name] = cls
except Exception as e:
raise ValueError(f"Failed to register embedding function: {e}")
return cls # Return the class unchanged
# If called with a class, register it immediately
if ef_class is not None:
return _register(ef_class) # type: ignore
# If called without arguments, return a decorator
return _register
def register_sparse_embedding_function(ef_class=None): # type: ignore
"""Register a custom sparse embedding function.
Can be used as a decorator:
@register_sparse_embedding_function
class MySparseEmbeddingFunction(SparseEmbeddingFunction):
@classmethod
def name(cls): return "my_sparse_embedding"
"""
def _register(cls): # type: ignore
try:
name = cls.name()
sparse_known_embedding_functions[name] = cls
except Exception as e:
raise ValueError(f"Failed to register sparse embedding function: {e}")
return cls # Return the class unchanged
if ef_class is not None:
return _register(ef_class) # type: ignore
return _register
# Function to convert config to embedding function
def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction: # type: ignore
"""Convert a config dictionary to an embedding function.
Args:
config: The config dictionary.
Returns:
The embedding function.
"""
if "name" not in config:
raise ValueError("Config must contain a 'name' field.")
name = config["name"]
if name not in known_embedding_functions:
raise ValueError(f"Unsupported embedding function: {name}")
ef_config = config.get("config", {})
if known_embedding_functions[name] is None:
raise ValueError(f"Unsupported embedding function: {name}")
return known_embedding_functions[name].build_from_config(ef_config)
__all__ = [
"EmbeddingFunction",
"DefaultEmbeddingFunction",
"CohereEmbeddingFunction",
"OpenAIEmbeddingFunction",
"BasetenEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
"HuggingFaceEmbeddingFunction",
"HuggingFaceEmbeddingServer",
"SentenceTransformerEmbeddingFunction",
"GooglePalmEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"JinaQueryConfig",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OpenCLIPEmbeddingFunction",
"RoboflowEmbeddingFunction",
"Text2VecEmbeddingFunction",
"AmazonBedrockEmbeddingFunction",
"ChromaLangchainEmbeddingFunction",
"TogetherAIEmbeddingFunction",
"HuggingFaceSparseEmbeddingFunction",
"FastembedSparseEmbeddingFunction",
"Bm25EmbeddingFunction",
"ChromaCloudQwenEmbeddingFunction",
"ChromaCloudSpladeEmbeddingFunction",
"ChromaBm25EmbeddingFunction",
"register_embedding_function",
"config_to_embedding_function",
"known_embedding_functions",
]
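
A toy sketch of the registration path; MyHashEmbedding is hypothetical, not a real model, and the exact EmbeddingFunction protocol requirements may vary by chromadb version:

import numpy as np
from chromadb.api.types import Documents, Embeddings

@register_embedding_function
class MyHashEmbedding(EmbeddingFunction[Documents]):
    @staticmethod
    def name() -> str:
        return "my_hash_embedding"

    def __call__(self, input: Documents) -> Embeddings:
        # Toy embedding: one deterministic float per document.
        return [np.array([float(hash(d) % 997)], dtype=np.float32) for d in input]

    @staticmethod
    def build_from_config(config):
        return MyHashEmbedding()

    def get_config(self):
        return {}

# The registered name can now be resolved from a config dictionary.
ef = config_to_embedding_function({"name": "my_hash_embedding", "config": {}})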

View File

@@ -0,0 +1,138 @@
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction
from typing import Dict, Any, cast
import json
import numpy as np
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using Amazon Bedrock.
"""
def __init__(
self,
session: Any,
model_name: str = "amazon.titan-embed-text-v1",
**kwargs: Any,
):
"""Initialize AmazonBedrockEmbeddingFunction.
Args:
session (boto3.Session): The boto3 session to use. Requires boto3 to be
installed (`pip install boto3`). Passing access and secret keys directly is not supported.
model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1"
**kwargs: Additional arguments to pass to the boto3 client.
Example:
>>> import boto3
>>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
>>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = bedrock(texts)
"""
self.model_name = model_name
# check kwargs are primitives only
for key, value in kwargs.items():
if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
raise ValueError(f"Keyword argument {key} is not a primitive type")
self.kwargs = kwargs
# Store the session for serialization
self._session_args = {}
if hasattr(session, "region_name") and session.region_name:
self._session_args["region_name"] = session.region_name
if hasattr(session, "profile_name") and session.profile_name:
self._session_args["profile_name"] = session.profile_name
self._client = session.client(
service_name="bedrock-runtime",
**kwargs,
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
accept = "application/json"
content_type = "application/json"
embeddings = []
for text in input:
input_body = {"inputText": text}
body = json.dumps(input_body)
response = self._client.invoke_model(
body=body,
modelId=self.model_name,
accept=accept,
contentType=content_type,
)
response_body = json.loads(response.get("body").read())
embedding = response_body.get("embedding")
embeddings.append(np.array(embedding, dtype=np.float32))
# Convert to the expected Embeddings type
return cast(Embeddings, embeddings)
@staticmethod
def name() -> str:
return "amazon_bedrock"
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
try:
import boto3
except ImportError:
raise ValueError(
"The boto3 python package is not installed. Please install it with `pip install boto3`"
)
model_name = config.get("model_name")
session_args = config.get("session_args")
if model_name is None:
assert False, "This code should not be reached"
kwargs = config.get("kwargs", {})
if session_args is None:
session = boto3.Session()
else:
session = boto3.Session(**session_args)
return AmazonBedrockEmbeddingFunction(
session=session, model_name=model_name, **kwargs
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"session_args": self._session_args,
"kwargs": self.kwargs,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "amazon_bedrock")

View File

@@ -0,0 +1,98 @@
import os
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from typing import Dict, Any, Optional, List
from chromadb.api.types import Space
import warnings
class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
def __init__(
self,
api_key: Optional[str],
api_base: str,
api_key_env_var: str = "CHROMA_BASETEN_API_KEY",
):
"""
Initialize the BasetenEmbeddingFunction.
Args:
api_key (str, optional): The API key for your Baseten account
api_base (str, required): The Baseten URL of the deployment
api_key_env_var (str, optional): The environment variable to use for the API key. Defaults to "CHROMA_BASETEN_API_KEY".
"""
try:
import openai
except ImportError:
raise ValueError(
"The openai python package is not installed. Please install it with `pip install openai`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
# Prioritize api_key argument, then environment variable
resolved_api_key = api_key or os.getenv(api_key_env_var)
if not resolved_api_key:
raise ValueError(
f"API key not provided and {api_key_env_var} environment variable is not set."
)
self.api_key = resolved_api_key
if not api_base:
raise ValueError("The api_base argument must be provided.")
self.api_base = api_base
self.model_name = "baseten-embedding-model"
self.dimensions = None
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)
@staticmethod
def name() -> str:
return "baseten"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
def get_config(self) -> Dict[str, Any]:
return {"api_base": self.api_base, "api_key_env_var": self.api_key_env_var}
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "BasetenEmbeddingFunction":
"""
Build the BasetenEmbeddingFunction from a configuration dictionary.
Args:
config (Dict[str, Any]): A dictionary containing the configuration parameters.
Expected keys: 'api_key', 'api_base', 'api_key_env_var'.
Returns:
BasetenEmbeddingFunction: An instance of BasetenEmbeddingFunction.
"""
api_key_env_var = config.get("api_key_env_var")
api_base = config.get("api_base")
if api_key_env_var is None or api_base is None:
raise ValueError(
"Missing 'api_key_env_var' or 'api_base' in configuration for BasetenEmbeddingFunction."
)
# Note: We rely on the __init__ method to handle potential missing api_key
# by checking the environment variable if the config value is None.
# However, api_base must be present either in config or have a default.
return BasetenEmbeddingFunction(
api_key=None, # Pass None if not in config, __init__ will check env var
api_base=api_base,
api_key_env_var=api_key_env_var,
)

View File

@@ -0,0 +1,234 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseVectors,
Documents,
)
from typing import Dict, Any, TypedDict, Optional
from typing import cast, Literal
import warnings
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
TaskType = Literal["document", "query"]
class Bm25EmbeddingFunctionQueryConfig(TypedDict):
task: TaskType
class Bm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
def __init__(
self,
avg_len: Optional[float] = None,
task: Optional[TaskType] = "document",
cache_dir: Optional[str] = None,
k: Optional[float] = None,
b: Optional[float] = None,
language: Optional[str] = None,
token_max_length: Optional[int] = None,
disable_stemmer: Optional[bool] = None,
specific_model_path: Optional[str] = None,
query_config: Optional[Bm25EmbeddingFunctionQueryConfig] = None,
**kwargs: Any,
):
"""Initialize SparseEncoderEmbeddingFunction.
Args:
avg_len(float, optional): The average length of the documents in the corpus.
task (str, optional): Task to perform, can be "document" or "query"
cache_dir (str, optional): The path to the cache directory.
k (float, optional): The k parameter in the BM25 formula. Defines the saturation of the term frequency.
b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
language (str, optional): Specifies the language for the stemmer.
token_max_length (int, optional): The maximum length of the tokens.
disable_stemmer (bool, optional): Disable the stemmer.
specific_model_path (str, optional): The path to the specific model.
query_config (dict, optional): Configuration applied by embed_query; the supported key is "task".
**kwargs: Additional arguments to pass to the Bm25 model.
"""
warnings.warn(
"Bm25EmbeddingFunction is deprecated. Please use ChromaBm25EmbeddingFunction instead.",
DeprecationWarning,
stacklevel=2,
)
try:
from fastembed.sparse.bm25 import Bm25
except ImportError:
raise ValueError(
"The fastembed python package is not installed. Please install it with `pip install fastembed`"
)
self.task = task
self.query_config = query_config
self.cache_dir = cache_dir
self.k = k
self.b = b
self.avg_len = avg_len
self.language = language
self.token_max_length = token_max_length
self.disable_stemmer = disable_stemmer
self.specific_model_path = specific_model_path
for key, value in kwargs.items():
if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
raise ValueError(f"Keyword argument {key} is not a primitive type")
self.kwargs = kwargs
bm25_kwargs = {
"model_name": "Qdrant/bm25",
}
optional_params = {
"cache_dir": cache_dir,
"k": k,
"b": b,
"avg_len": avg_len,
"language": language,
"token_max_length": token_max_length,
"disable_stemmer": disable_stemmer,
"specific_model_path": specific_model_path,
}
for key, value in optional_params.items():
if value is not None:
bm25_kwargs[key] = value
bm25_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
self._model = Bm25(**bm25_kwargs)
def __call__(self, input: Documents) -> SparseVectors:
"""Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
try:
from fastembed.sparse.bm25 import Bm25
except ImportError:
raise ValueError(
"The fastembed python package is not installed. Please install it with `pip install fastembed`"
)
model = cast(Bm25, self._model)
if self.task == "document":
embeddings = model.embed(
list(input),
)
elif self.task == "query":
embeddings = model.query_embed(
list(input),
)
else:
raise ValueError(f"Invalid task: {self.task}")
sparse_vectors: SparseVectors = []
for vec in embeddings:
sparse_vectors.append(
normalize_sparse_vector(
indices=vec.indices.tolist(), values=vec.values.tolist()
)
)
return sparse_vectors
def embed_query(self, input: Documents) -> SparseVectors:
try:
from fastembed.sparse.bm25 import Bm25
except ImportError:
raise ValueError(
"The fastembed python package is not installed. Please install it with `pip install fastembed`"
)
model = cast(Bm25, self._model)
if self.query_config is not None:
task = self.query_config.get("task")
if task == "document":
embeddings = model.embed(
list(input),
)
elif task == "query":
embeddings = model.query_embed(
list(input),
)
else:
raise ValueError(f"Invalid task: {task}")
sparse_vectors: SparseVectors = []
for vec in embeddings:
sparse_vectors.append(
normalize_sparse_vector(
indices=vec.indices.tolist(), values=vec.values.tolist()
)
)
return sparse_vectors
else:
return self.__call__(input)
@staticmethod
def name() -> str:
return "bm25"
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "SparseEmbeddingFunction[Documents]":
task = config.get("task")
query_config = config.get("query_config")
cache_dir = config.get("cache_dir")
k = config.get("k")
b = config.get("b")
avg_len = config.get("avg_len")
language = config.get("language")
token_max_length = config.get("token_max_length")
disable_stemmer = config.get("disable_stemmer")
specific_model_path = config.get("specific_model_path")
kwargs = config.get("kwargs", {})
return Bm25EmbeddingFunction(
task=task,
query_config=query_config,
cache_dir=cache_dir,
k=k,
b=b,
avg_len=avg_len,
language=language,
token_max_length=token_max_length,
disable_stemmer=disable_stemmer,
specific_model_path=specific_model_path,
**kwargs,
)
def get_config(self) -> Dict[str, Any]:
return {
"task": self.task,
"query_config": self.query_config,
"cache_dir": self.cache_dir,
"k": self.k,
"b": self.b,
"avg_len": self.avg_len,
"language": self.language,
"token_max_length": self.token_max_length,
"disable_stemmer": self.disable_stemmer,
"specific_model_path": self.specific_model_path,
"kwargs": self.kwargs,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# Users should be able to change the path if needed, so we should not validate that.
# e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
return
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "bm25")

View File

@@ -0,0 +1,135 @@
from __future__ import annotations
from collections import Counter
from typing import Any, Dict, Iterable, List, Optional, TypedDict
from chromadb.api.types import Documents, SparseEmbeddingFunction, SparseVectors
from chromadb.base_types import SparseVector
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.embedding_functions.schemas.bm25_tokenizer import (
Bm25Tokenizer,
DEFAULT_CHROMA_BM25_STOPWORDS as _DEFAULT_STOPWORDS,
get_english_stemmer,
Murmur3AbsHasher,
)
NAME = "chroma_bm25"
DEFAULT_K = 1.2
DEFAULT_B = 0.75
DEFAULT_AVG_DOC_LENGTH = 256.0
DEFAULT_TOKEN_MAX_LENGTH = 40
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
class ChromaBm25Config(TypedDict, total=False):
k: float
b: float
avg_doc_length: float
token_max_length: int
stopwords: List[str]
class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
def __init__(
self,
k: float = DEFAULT_K,
b: float = DEFAULT_B,
avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
stopwords: Optional[Iterable[str]] = None,
) -> None:
"""Initialize the BM25 sparse embedding function."""
self.k = float(k)
self.b = float(b)
self.avg_doc_length = float(avg_doc_length)
self.token_max_length = int(token_max_length)
if stopwords is not None:
self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
stopword_list: Iterable[str] = self.stopwords
else:
self.stopwords = None
stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
stemmer = get_english_stemmer()
self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
self._hasher = Murmur3AbsHasher()
def _encode(self, text: str) -> SparseVector:
tokens = self._tokenizer.tokenize(text)
if not tokens:
return SparseVector(indices=[], values=[])
doc_len = float(len(tokens))
counts = Counter(self._hasher.hash(token) for token in tokens)
indices = sorted(counts.keys())
values: List[float] = []
for idx in indices:
tf = float(counts[idx])
denominator = tf + self.k * (
1 - self.b + (self.b * doc_len) / self.avg_doc_length
)
score = tf * (self.k + 1) / denominator
values.append(score)
return SparseVector(indices=indices, values=values)
def __call__(self, input: Documents) -> SparseVectors:
sparse_vectors: SparseVectors = []
if not input:
return sparse_vectors
for document in input:
sparse_vectors.append(self._encode(document))
return sparse_vectors
def embed_query(self, input: Documents) -> SparseVectors:
return self.__call__(input)
@staticmethod
def name() -> str:
return NAME
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "SparseEmbeddingFunction[Documents]":
return ChromaBm25EmbeddingFunction(
k=config.get("k", DEFAULT_K),
b=config.get("b", DEFAULT_B),
avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
stopwords=config.get("stopwords"),
)
def get_config(self) -> Dict[str, Any]:
config: Dict[str, Any] = {
"k": self.k,
"b": self.b,
"avg_doc_length": self.avg_doc_length,
"token_max_length": self.token_max_length,
}
if self.stopwords is not None:
config["stopwords"] = list(self.stopwords)
return config
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
for key in new_config:
if key not in mutable_keys:
raise ValueError(f"Updating '{key}' is not supported for {NAME}")
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
validate_config_schema(config, NAME)
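
The per-term score in _encode is the standard BM25 term-frequency saturation; a worked instance with the defaults (k=1.2, b=0.75, avg_doc_length=256) for a term appearing twice in a 10-token document:

tf, k, b = 2.0, 1.2, 0.75
doc_len, avg_doc_length = 10.0, 256.0
denominator = tf + k * (1 - b + (b * doc_len) / avg_doc_length)
score = tf * (k + 1) / denominator
print(round(score, 4))  # ~1.8842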

View File

@@ -0,0 +1,213 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Union
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from enum import Enum
class ChromaCloudQwenEmbeddingModel(Enum):
QWEN3_EMBEDDING_0p6B = "Qwen/Qwen3-Embedding-0.6B"
class ChromaCloudQwenEmbeddingTarget(Enum):
DOCUMENTS = "documents"
QUERY = "query"
ChromaCloudQwenEmbeddingInstructions = Dict[
str, Dict[ChromaCloudQwenEmbeddingTarget, str]
]
CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS: ChromaCloudQwenEmbeddingInstructions = {
"nl_to_code": {
ChromaCloudQwenEmbeddingTarget.DOCUMENTS: "",
# Taken from https://github.com/QwenLM/Qwen3-Embedding/blob/main/evaluation/task_prompts.json
ChromaCloudQwenEmbeddingTarget.QUERY: "Given a question about coding, retrieval code or passage that can solve user's question",
}
}
class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model: ChromaCloudQwenEmbeddingModel,
task: str,
instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
api_key_env_var: str = "CHROMA_API_KEY",
):
"""
Initialize the ChromaCloudQwenEmbeddingFunction.
Args:
model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
task (str): The task for which embeddings are being generated.
instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
api_key_env_var (str, optional): Environment variable name that contains your API key.
Defaults to "CHROMA_API_KEY".
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
self.api_key_env_var = api_key_env_var
self.api_key = os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model = model
self.task = task
self.instructions = instructions
self._api_url = "https://embed.trychroma.com"
self._session = httpx.Client()
self._session.headers.update(
{
"x-chroma-token": self.api_key,
"x-chroma-embedding-model": self.model.value,
}
)
def _parse_response(self, response: Any) -> Embeddings:
"""
Convert the response from the Chroma Embedding API to a list of numpy arrays.
Args:
response (Any): The response from the Chroma Embedding API.
Returns:
Embeddings: A list of numpy arrays representing the embeddings.
"""
if "embeddings" not in response:
raise RuntimeError(response.get("error", "Unknown error"))
embeddings: List[List[float]] = response["embeddings"]
return np.array(embeddings, dtype=np.float32)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.DOCUMENTS
],
"texts": input,
}
response = self._session.post(self._api_url, json=payload, timeout=60).json()
return self._parse_response(response)
def embed_query(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a query input.
"""
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.QUERY
],
"texts": input,
}
response = self._session.post(self._api_url, json=payload, timeout=60).json()
return self._parse_response(response)
@staticmethod
def name() -> str:
return "chroma-cloud-qwen"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model = config.get("model")
task = config.get("task")
instructions = config.get("instructions")
api_key_env_var = config.get("api_key_env_var")
if model is None or task is None:
assert False, "Config is missing a required field"
# Deserialize instructions dict from string keys to enum keys
deserialized_instructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS
if instructions is not None:
deserialized_instructions = {}
for task_key, targets in instructions.items():
deserialized_instructions[task_key] = {}
for target_key, instruction in targets.items():
# Convert string key to enum
target_enum = ChromaCloudQwenEmbeddingTarget(target_key)
deserialized_instructions[task_key][target_enum] = instruction
return ChromaCloudQwenEmbeddingFunction(
model=ChromaCloudQwenEmbeddingModel(model),
task=task,
instructions=deserialized_instructions,
api_key_env_var=api_key_env_var or "CHROMA_API_KEY",
)
def get_config(self) -> Dict[str, Any]:
# Serialize instructions dict with enum keys to string keys for JSON compatibility
serialized_instructions = {
task: {target.value: instruction for target, instruction in targets.items()}
for task, targets in self.instructions.items()
}
return {
"api_key_env_var": self.api_key_env_var,
"model": self.model.value,
"task": self.task,
"instructions": serialized_instructions,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model" in new_config:
raise ValueError(
"The model cannot be changed after the embedding function has been initialized."
)
elif "task" in new_config:
raise ValueError(
"The task cannot be changed after the embedding function has been initialized."
)
elif "instructions" in new_config:
raise ValueError(
"The instructions cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "chroma-cloud-qwen")

View File

@@ -0,0 +1,159 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseVectors,
Documents,
)
from typing import Dict, Any
from enum import Enum
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
from chromadb.base_types import SparseVector
import os
from typing import Union
class ChromaCloudSpladeEmbeddingModel(Enum):
SPLADE_PP_EN_V1 = "prithivida/Splade_PP_en_v1"
class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
def __init__(
self,
api_key_env_var: str = "CHROMA_API_KEY",
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
):
"""
Initialize the ChromaCloudSpladeEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key.
Defaults to "CHROMA_API_KEY".
model (ChromaCloudSpladeEmbeddingModel, optional): The SPLADE model to use.
Defaults to ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1.
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
self.api_key_env_var = api_key_env_var
self.api_key = os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {self.api_key_env_var}"
)
self.model = model
self._api_url = "https://embed.trychroma.com/embed_sparse"
self._session = httpx.Client()
self._session.headers.update(
{
"x-chroma-token": self.api_key,
"x-chroma-embedding-model": self.model.value,
}
)
def __del__(self) -> None:
"""
Cleanup the HTTP client session when the object is destroyed.
"""
if hasattr(self, "_session"):
self._session.close()
def close(self) -> None:
"""
Explicitly close the HTTP client session.
Call this method when you're done using the embedding function.
"""
if hasattr(self, "_session"):
self._session.close()
def __call__(self, input: Documents) -> SparseVectors:
"""
Generate embeddings for the given documents.
Args:
input (Documents): The documents to generate embeddings for.
"""
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"texts": list(input),
"task": "",
"target": "",
}
try:
import httpx
response = self._session.post(self._api_url, json=payload, timeout=60)
response.raise_for_status()
json_response = response.json()
return self._parse_response(json_response)
except httpx.HTTPStatusError as e:
raise RuntimeError(
f"Failed to get embeddings from Chroma Cloud API: HTTP {e.response.status_code} - {e.response.text}"
)
except httpx.TimeoutException:
raise RuntimeError("Request to Chroma Cloud API timed out after 60 seconds")
except httpx.HTTPError as e:
raise RuntimeError(f"Failed to get embeddings from Chroma Cloud API: {e}")
except Exception as e:
raise RuntimeError(f"Unexpected error calling Chroma Cloud API: {e}")
def _parse_response(self, response: Any) -> SparseVectors:
"""
Parse the response from the Chroma Cloud Sparse Embedding API.
"""
raw_embeddings = response["embeddings"]
# Normalize each sparse vector (sort indices and validate)
normalized_vectors: SparseVectors = []
for emb in raw_embeddings:
# Handle both dict format and SparseVector format
if isinstance(emb, dict):
indices = emb.get("indices", [])
values = emb.get("values", [])
else:
# Already a SparseVector, extract its data
indices = emb.indices
values = emb.values
normalized_vectors.append(
normalize_sparse_vector(indices=indices, values=values)
)
return normalized_vectors
@staticmethod
def name() -> str:
return "chroma-cloud-splade"
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "SparseEmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model = config.get("model")
if model is None:
raise ValueError("model must be provided in config")
if not api_key_env_var:
raise ValueError("api_key_env_var must be provided in config")
return ChromaCloudSpladeEmbeddingFunction(
api_key_env_var=api_key_env_var,
model=ChromaCloudSpladeEmbeddingModel(model),
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model" in new_config:
raise ValueError(
"model cannot be changed after the embedding function has been initialized"
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
validate_config_schema(config, "chroma-cloud-splade")

View File

@@ -0,0 +1,171 @@
from chromadb.api.types import (
Documents,
Embeddings,
Images,
Embeddable,
EmbeddingFunction,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Union, cast, Sequence
import numpy as np
def create_langchain_embedding(
langchain_embedding_fn: Any,
) -> "ChromaLangchainEmbeddingFunction":
"""
Create a ChromaLangchainEmbeddingFunction from a langchain embedding function.
Args:
langchain_embedding_fn: The langchain embedding function to use.
Returns:
A ChromaLangchainEmbeddingFunction that wraps the langchain embedding function.
"""
return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embedding_fn)
class ChromaLangchainEmbeddingFunction(EmbeddingFunction[Embeddable]):
"""
This class is used as a bridge between langchain embedding functions and custom chroma embedding functions.
"""
def __init__(self, embedding_function: Any) -> None:
"""
Initialize the ChromaLangchainEmbeddingFunction
Args:
embedding_function: The embedding function implementing Embeddings from langchain_core.
"""
try:
import langchain_core.embeddings
LangchainEmbeddings = langchain_core.embeddings.Embeddings
except ImportError:
raise ValueError(
"The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
)
if not isinstance(embedding_function, LangchainEmbeddings):
raise ValueError(
"The embedding_function must implement the Embeddings interface from langchain_core."
)
self.embedding_function = embedding_function
# Store the class name for serialization
self._embedding_function_class = embedding_function.__class__.__name__
def embed_documents(self, documents: Sequence[str]) -> List[List[float]]:
"""
Embed documents using the langchain embedding function.
Args:
documents: The documents to embed.
Returns:
The embeddings for the documents.
"""
return cast(
List[List[float]], self.embedding_function.embed_documents(list(documents))
)
def embed_query(self, query: str) -> List[float]:
"""
Embed a query using the langchain embedding function.
Args:
query: The query to embed.
Returns:
The embedding for the query.
"""
return cast(List[float], self.embedding_function.embed_query(query))
def embed_image(self, uris: List[str]) -> List[List[float]]:
"""
Embed images using the langchain embedding function.
Args:
uris: The URIs of the images to embed.
Returns:
The embeddings for the images.
"""
if hasattr(self.embedding_function, "embed_image"):
return cast(List[List[float]], self.embedding_function.embed_image(uris))
else:
raise ValueError(
"The provided embedding function does not support image embeddings."
)
def __call__(self, input: Union[Documents, Images]) -> Embeddings:
"""
Get the embeddings for a list of texts or images.
Args:
input: A list of texts or images to get embeddings for.
Images should be provided as a list of URIs passed through the langchain data loader
Returns:
The embeddings for the texts or images.
Example:
>>> from langchain_openai import OpenAIEmbeddings
>>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large"))
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = langchain_embedding(texts)
"""
# Due to langchain quirks, the data loader returns a tuple if the input is URIs of images
if isinstance(input, tuple) and len(input) == 2 and input[0] == "images":
embeddings = self.embed_image(list(input[1]))
else:
# Cast to Sequence[str] to satisfy the type checker
embeddings = self.embed_documents(cast(Sequence[str], input))
# Convert to numpy arrays
return [np.array(embedding, dtype=np.float32) for embedding in embeddings]
@staticmethod
def name() -> str:
return "langchain"
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "EmbeddingFunction[Union[Documents, Images]]":
# This is a placeholder implementation since we can't easily serialize and deserialize
# langchain embedding functions. Users will need to recreate the langchain embedding function
# and pass it to create_langchain_embedding.
raise NotImplementedError(
"Building a ChromaLangchainEmbeddingFunction from config is not supported. "
"Please recreate the langchain embedding function and pass it to create_langchain_embedding."
)
def get_config(self) -> Dict[str, Any]:
return {
"embedding_function_class": self._embedding_function_class,
"note": "This is a placeholder config. You will need to recreate the langchain embedding function.",
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
raise NotImplementedError(
"Updating a ChromaLangchainEmbeddingFunction config is not supported. "
"Please recreate the langchain embedding function and pass it to create_langchain_embedding."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "chroma_langchain")

View File

@@ -0,0 +1,152 @@
from chromadb.api.types import (
Embeddings,
Documents,
EmbeddingFunction,
Space,
)
from typing import List, Dict, Any, Optional
import os
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import cast
import warnings
BASE_URL = "https://api.cloudflare.com/client/v4/accounts"
GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1"
class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to get embeddings for a list of texts using the Cloudflare Workers AI API.
It requires an API key and a model name.
"""
def __init__(
self,
model_name: str,
account_id: str,
api_key: Optional[str] = None,
api_key_env_var: str = "CHROMA_CLOUDFLARE_API_KEY",
gateway_id: Optional[str] = None,
):
"""
Initialize the CloudflareWorkersAIEmbeddingFunction. See the docs for supported models here:
https://developers.cloudflare.com/workers-ai/models/
Args:
model_name: The name of the model to use for text embeddings.
account_id: The account ID for the Cloudflare Workers AI API.
api_key: The API key for the Cloudflare Workers AI API.
api_key_env_var: The environment variable name for the Cloudflare Workers AI API key.
gateway_id: Optional Cloudflare AI Gateway ID. When set, requests are routed through the AI Gateway.
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.model_name = model_name
self.account_id = account_id
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
self.gateway_id = gateway_id
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
if self.gateway_id:
self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"
else:
self._api_url = f"{BASE_URL}/{self.account_id}/ai/run/{self.model_name}"
self._session = httpx.Client()
self._session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
if not all(isinstance(item, str) for item in input):
raise ValueError(
"Cloudflare Workers AI only supports text documents, not images"
)
payload: Dict[str, Any] = {
"text": input,
}
resp = self._session.post(self._api_url, json=payload).json()
if "result" not in resp and "data" not in resp["result"]:
raise RuntimeError(resp.get("detail", "Unknown error"))
return cast(Embeddings, resp["result"]["data"])
@staticmethod
def name() -> str:
return "cloudflare_workers_ai"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
account_id = config.get("account_id")
gateway_id = config.get("gateway_id", None)
if api_key_env_var is None or model_name is None or account_id is None:
assert False, "This code should not be reached"
return CloudflareWorkersAIEmbeddingFunction(
api_key_env_var=api_key_env_var,
model_name=model_name,
account_id=account_id,
gateway_id=gateway_id,
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"account_id": self.account_id,
"gateway_id": self.gateway_id,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "cloudflare_workers_ai")

View File

@@ -0,0 +1,176 @@
from chromadb.api.types import (
Embeddings,
Embeddable,
EmbeddingFunction,
Space,
is_image,
is_document,
)
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import base64
import io
import importlib
import warnings
class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "large",
api_key_env_var: str = "CHROMA_COHERE_API_KEY",
):
try:
import cohere
except ImportError:
raise ValueError(
"The cohere python package is not installed. Please install it with `pip install cohere`"
)
try:
self._PILImage = importlib.import_module("PIL.Image")
except ImportError:
raise ValueError(
"The PIL python package is not installed. Please install it with `pip install pillow`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self.client = cohere.Client(self.api_key)
def __call__(self, input: Embeddable) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents or images to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Cohere supports text and image inputs. If all items are texts, return text embeddings.
if all(is_document(item) for item in input):
return [
np.array(embeddings, dtype=np.float32)
for embeddings in self.client.embed(
texts=[str(item) for item in input],
model=self.model_name,
input_type="search_document",
).embeddings
]
elif all(is_image(item) for item in input):
base64_images = []
for image_np in input:
if not isinstance(image_np, np.ndarray):
raise ValueError(
f"Expected image input to be a numpy array, got {type(image_np)}"
)
try:
pil_image = self._PILImage.fromarray(image_np)
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
img_bytes = buffer.getvalue()
# Encode bytes to base64 string
base64_string = base64.b64encode(img_bytes).decode("utf-8")
data_uri = f"data:image/png;base64,{base64_string}"
base64_images.append(data_uri)
except Exception as e:
raise ValueError(
f"Failed to convert image numpy array to base64 data URI: {e}"
) from e
return [
np.array(embeddings, dtype=np.float32)
for embeddings in self.client.embed(
images=base64_images,
model=self.model_name,
input_type="image",
).embeddings
]
else:
# Check if it's a mix or neither
has_texts = any(is_document(item) for item in input)
has_images = any(is_image(item) for item in input)
if has_texts and has_images:
raise ValueError(
"Input contains a mix of text documents and images, which is not supported. Provide either all texts or all images."
)
else:
raise ValueError(
"Input must be a list of text documents (str) or a list of images (numpy arrays)."
)
@staticmethod
def name() -> str:
return "cohere"
def default_space(self) -> Space:
if self.model_name == "embed-multilingual-v2.0":
return "ip"
return "cosine"
def supported_spaces(self) -> List[Space]:
if self.model_name == "embed-english-v2.0":
return ["cosine"]
elif self.model_name == "embed-english-light-v2.0":
return ["cosine"]
elif self.model_name == "embed-multilingual-v2.0":
return ["ip"]
else:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"
return CohereEmbeddingFunction(
api_key_env_var=api_key_env_var, model_name=model_name
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "cohere")

View File

@@ -0,0 +1,205 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseVectors,
Documents,
)
from typing import Dict, Any, TypedDict, Optional
from typing import cast, Literal
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
TaskType = Literal["document", "query"]
class FastembedSparseEmbeddingFunctionQueryConfig(TypedDict):
task: TaskType
class FastembedSparseEmbeddingFunction(SparseEmbeddingFunction[Documents]):
def __init__(
self,
model_name: str,
task: Optional[TaskType] = "document",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
cuda: Optional[bool] = None,
device_ids: Optional[list[int]] = None,
lazy_load: Optional[bool] = None,
query_config: Optional[FastembedSparseEmbeddingFunctionQueryConfig] = None,
**kwargs: Any,
):
"""Initialize SparseEncoderEmbeddingFunction.
Args:
            model_name (str): Identifier of the Fastembed model.
List of commonly used models: Qdrant/bm25, prithivida/Splade_PP_en_v1, Qdrant/minicoil-v1
task (str, optional): Task to perform, can be "document" or "query"
cache_dir (str, optional): The path to the cache directory.
threads (int, optional): The number of threads to use for the model.
cuda (bool, optional): Whether to use CUDA.
device_ids (list[int], optional): The device IDs to use for the model.
lazy_load (bool, optional): Whether to lazy load the model.
            query_config (dict, optional): Query-time configuration override; supports the "task" key.
**kwargs: Additional arguments to pass to the model.
"""
try:
from fastembed import SparseTextEmbedding
except ImportError:
raise ValueError(
"The fastembed python package is not installed. Please install it with `pip install fastembed`"
)
self.task = task
self.query_config = query_config
self.model_name = model_name
self.cache_dir = cache_dir
self.threads = threads
self.cuda = cuda
self.device_ids = device_ids
self.lazy_load = lazy_load
for key, value in kwargs.items():
if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
raise ValueError(f"Keyword argument {key} is not a primitive type")
self.kwargs = kwargs
        # Pass arguments by keyword to avoid depending on the positional order
        # of the SparseTextEmbedding constructor.
        self._model = SparseTextEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            **kwargs,
        )
def __call__(self, input: Documents) -> SparseVectors:
"""Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
try:
from fastembed import SparseTextEmbedding
except ImportError:
raise ValueError(
"The fastembed python package is not installed. Please install it with `pip install fastembed`"
)
model = cast(SparseTextEmbedding, self._model)
if self.task == "document":
embeddings = model.embed(
list(input),
)
elif self.task == "query":
embeddings = model.query_embed(
list(input),
)
else:
raise ValueError(f"Invalid task: {self.task}")
sparse_vectors: SparseVectors = []
for vec in embeddings:
sparse_vectors.append(
normalize_sparse_vector(
indices=vec.indices.tolist(), values=vec.values.tolist()
)
)
return sparse_vectors
def embed_query(self, input: Documents) -> SparseVectors:
try:
from fastembed import SparseTextEmbedding
except ImportError:
raise ValueError(
"The fastembed python package is not installed. Please install it with `pip install fastembed`"
)
model = cast(SparseTextEmbedding, self._model)
if self.query_config is not None:
task = self.query_config.get("task")
if task == "document":
embeddings = model.embed(
list(input),
)
elif task == "query":
embeddings = model.query_embed(
list(input),
)
else:
raise ValueError(f"Invalid task: {task}")
sparse_vectors: SparseVectors = []
for vec in embeddings:
sparse_vectors.append(
normalize_sparse_vector(
indices=vec.indices.tolist(), values=vec.values.tolist()
)
)
return sparse_vectors
else:
return self.__call__(input)
@staticmethod
def name() -> str:
return "fastembed_sparse"
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "SparseEmbeddingFunction[Documents]":
model_name = config.get("model_name")
task = config.get("task")
query_config = config.get("query_config")
cache_dir = config.get("cache_dir")
threads = config.get("threads")
cuda = config.get("cuda")
device_ids = config.get("device_ids")
lazy_load = config.get("lazy_load")
kwargs = config.get("kwargs", {})
if model_name is None:
assert False, "This code should not be reached"
return FastembedSparseEmbeddingFunction(
model_name=model_name,
task=task,
query_config=query_config,
cache_dir=cache_dir,
threads=threads,
cuda=cuda,
device_ids=device_ids,
lazy_load=lazy_load,
**kwargs,
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"task": self.task,
"query_config": self.query_config,
"cache_dir": self.cache_dir,
"threads": self.threads,
"cuda": self.cuda,
"device_ids": self.device_ids,
"lazy_load": self.lazy_load,
"kwargs": self.kwargs,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# model_name is also used as the identifier for model path if stored locally.
# Users should be able to change the path if needed, so we should not validate that.
# e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
return
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "fastembed_sparse")

View File

@@ -0,0 +1,395 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, cast, Optional
import os
import numpy as np
import numpy.typing as npt
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings
class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "models/embedding-gecko-001",
api_key_env_var: str = "CHROMA_GOOGLE_PALM_API_KEY",
):
"""
Initialize the GooglePalmEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google PaLM API.
Defaults to "CHROMA_GOOGLE_PALM_API_KEY".
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "models/embedding-gecko-001".
"""
try:
import google.generativeai as palm
except ImportError:
raise ValueError(
"The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
palm.configure(api_key=self.api_key)
self._palm = palm
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
            input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Google PaLM only works with text documents
if not all(isinstance(item, str) for item in input):
raise ValueError("Google PaLM only supports text documents, not images")
return [
np.array(
self._palm.generate_embeddings(model=self.model_name, text=text)[
"embedding"
],
dtype=np.float32,
)
for text in input
]
@staticmethod
def name() -> str:
return "google_palm"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"
return GooglePalmEmbeddingFunction(
api_key_env_var=api_key_env_var, model_name=model_name
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "google_palm")
class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key."""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "models/embedding-001",
task_type: str = "RETRIEVAL_DOCUMENT",
api_key_env_var: str = "CHROMA_GOOGLE_GENAI_API_KEY",
):
"""
Initialize the GoogleGenerativeAiEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google Generative AI API.
Defaults to "CHROMA_GOOGLE_GENAI_API_KEY".
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "models/embedding-001".
task_type (str, optional): The task type for the embeddings.
Use "RETRIEVAL_DOCUMENT" for embedding documents and "RETRIEVAL_QUERY" for embedding queries.
Defaults to "RETRIEVAL_DOCUMENT".
"""
try:
import google.generativeai as genai
except ImportError:
raise ValueError(
"The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self.task_type = task_type
genai.configure(api_key=self.api_key)
self._genai = genai
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
            input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Google Generative AI only works with text documents
if not all(isinstance(item, str) for item in input):
raise ValueError(
"Google Generative AI only supports text documents, not images"
)
embeddings_list: List[npt.NDArray[np.float32]] = []
for text in input:
embedding_result = self._genai.embed_content(
model=self.model_name,
content=text,
task_type=self.task_type,
)
embeddings_list.append(
np.array(embedding_result["embedding"], dtype=np.float32)
)
# Convert to the expected Embeddings type (List[Vector])
return cast(Embeddings, embeddings_list)
@staticmethod
def name() -> str:
return "google_generative_ai"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
task_type = config.get("task_type")
if api_key_env_var is None or model_name is None or task_type is None:
assert False, "This code should not be reached"
return GoogleGenerativeAiEmbeddingFunction(
api_key_env_var=api_key_env_var, model_name=model_name, task_type=task_type
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"task_type": self.task_type,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
if "task_type" in new_config:
raise ValueError(
"The task type cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "google_generative_ai")
class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the vertexai Python package installed and have Google Cloud credentials configured."""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "textembedding-gecko",
project_id: str = "cloud-large-language-models",
region: str = "us-central1",
api_key_env_var: str = "CHROMA_GOOGLE_VERTEX_API_KEY",
):
"""
Initialize the GoogleVertexEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google Vertex AI API.
Defaults to "CHROMA_GOOGLE_VERTEX_API_KEY".
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "textembedding-gecko".
project_id (str, optional): The Google Cloud project ID.
Defaults to "cloud-large-language-models".
region (str, optional): The Google Cloud region.
Defaults to "us-central1".
"""
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
except ImportError:
raise ValueError(
"The vertexai python package is not installed. Please install it with `pip install google-cloud-aiplatform`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self.project_id = project_id
self.region = region
vertexai.init(project=project_id, location=region)
self._model = TextEmbeddingModel.from_pretrained(model_name)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
            input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Google Vertex only works with text documents
if not all(isinstance(item, str) for item in input):
raise ValueError("Google Vertex only supports text documents, not images")
embeddings_list: List[npt.NDArray[np.float32]] = []
for text in input:
embedding_result = self._model.get_embeddings([text])
embeddings_list.append(
np.array(embedding_result[0].values, dtype=np.float32)
)
# Convert to the expected Embeddings type (List[Vector])
return cast(Embeddings, embeddings_list)
@staticmethod
def name() -> str:
return "google_vertex"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
project_id = config.get("project_id")
region = config.get("region")
if (
api_key_env_var is None
or model_name is None
or project_id is None
or region is None
):
assert False, "This code should not be reached"
return GoogleVertexEmbeddingFunction(
api_key_env_var=api_key_env_var,
model_name=model_name,
project_id=project_id,
region=region,
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"project_id": self.project_id,
"region": self.region,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
if "project_id" in new_config:
raise ValueError(
"The project ID cannot be changed after the embedding function has been initialized."
)
if "region" in new_config:
raise ValueError(
"The region cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "google_vertex")

View File

@@ -0,0 +1,236 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings
class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to get embeddings for a list of texts using the HuggingFace API.
It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2".
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
api_key_env_var: str = "CHROMA_HUGGINGFACE_API_KEY",
):
"""
Initialize the HuggingFaceEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the HuggingFace API.
Defaults to "CHROMA_HUGGINGFACE_API_KEY".
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "sentence-transformers/all-MiniLM-L6-v2".
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
self._session = httpx.Client()
self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
def __call__(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a list of texts.
Args:
input (Documents): A list of texts to get embeddings for.
Returns:
Embeddings: The embeddings for the texts.
Example:
>>> hugging_face = HuggingFaceEmbeddingFunction(api_key_env_var="CHROMA_HUGGINGFACE_API_KEY")
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = hugging_face(texts)
"""
# Call HuggingFace Embedding API for each document
response = self._session.post(
self._api_url,
json={"inputs": input, "options": {"wait_for_model": True}},
).json()
# Convert to numpy arrays
return [np.array(embedding, dtype=np.float32) for embedding in response]
@staticmethod
def name() -> str:
return "huggingface"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"
return HuggingFaceEmbeddingFunction(
api_key_env_var=api_key_env_var, model_name=model_name
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "huggingface")
class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
"""
This class is used to get embeddings for a list of texts using the HuggingFace Embedding server
(https://github.com/huggingface/text-embeddings-inference).
The embedding model is configured in the server.
"""
def __init__(
self,
url: str,
api_key_env_var: Optional[str] = None,
api_key: Optional[str] = None,
):
"""
Initialize the HuggingFaceEmbeddingServer.
Args:
url (str): The URL of the HuggingFace Embedding Server.
api_key (Optional[str]): The API key for the HuggingFace Embedding Server.
api_key_env_var (str, optional): Environment variable name that contains your API key for the HuggingFace API.
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.url = url
self.api_key_env_var = api_key_env_var
if self.api_key_env_var is not None:
self.api_key = api_key or os.getenv(self.api_key_env_var)
else:
self.api_key = api_key
self._api_url = f"{url}"
self._session = httpx.Client()
if self.api_key is not None:
self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
def __call__(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a list of texts.
Args:
input (Documents): A list of texts to get embeddings for.
Returns:
Embeddings: The embeddings for the texts.
Example:
>>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed")
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = hugging_face(texts)
"""
# Call HuggingFace Embedding Server API for each document
response = self._session.post(self._api_url, json={"inputs": input}).json()
# Convert to numpy arrays
return [np.array(embedding, dtype=np.float32) for embedding in response]
@staticmethod
def name() -> str:
return "huggingface_server"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
url = config.get("url")
api_key_env_var = config.get("api_key_env_var")
if url is None:
raise ValueError("URL must be provided for HuggingFaceEmbeddingServer")
return HuggingFaceEmbeddingServer(url=url, api_key_env_var=api_key_env_var)
def get_config(self) -> Dict[str, Any]:
return {"url": self.url, "api_key_env_var": self.api_key_env_var}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "url" in new_config and new_config["url"] != self.url:
raise ValueError(
"The URL cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "huggingface_server")

View File

@@ -0,0 +1,200 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseVectors,
Documents,
)
from typing import Dict, Any, TypedDict, Optional
import numpy as np
from typing import cast, Literal
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
TaskType = Literal["document", "query"]
class HuggingFaceSparseEmbeddingFunctionQueryConfig(TypedDict):
task: TaskType
class HuggingFaceSparseEmbeddingFunction(SparseEmbeddingFunction[Documents]):
    # Since we do dynamic imports, we have to type this as Any
models: Dict[str, Any] = {}
def __init__(
self,
model_name: str,
device: str,
task: Optional[TaskType] = "document",
query_config: Optional[HuggingFaceSparseEmbeddingFunctionQueryConfig] = None,
**kwargs: Any,
):
"""Initialize SparseEncoderEmbeddingFunction.
Args:
model_name (str, optional): Identifier of the Huggingface SparseEncoder model
Some common models: prithivida/Splade_PP_en_v1, naver/splade-cocondenser-ensembledistil, naver/splade-v3
device (str, optional): Device used for computation
**kwargs: Additional arguments to pass to the Splade model.
"""
try:
from sentence_transformers import SparseEncoder
except ImportError:
raise ValueError(
"The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
)
self.model_name = model_name
self.device = device
self.task = task
self.query_config = query_config
for key, value in kwargs.items():
if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
raise ValueError(f"Keyword argument {key} is not a primitive type")
self.kwargs = kwargs
if model_name not in self.models:
self.models[model_name] = SparseEncoder(
model_name_or_path=model_name, device=device, **kwargs
)
self._model = self.models[model_name]
def __call__(self, input: Documents) -> SparseVectors:
"""Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
try:
from sentence_transformers import SparseEncoder
except ImportError:
raise ValueError(
"The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
)
model = cast(SparseEncoder, self._model)
if self.task == "document":
embeddings = model.encode_document(
list(input),
)
elif self.task == "query":
embeddings = model.encode_query(
list(input),
)
else:
raise ValueError(f"Invalid task: {self.task}")
sparse_vectors: SparseVectors = []
for vec in embeddings:
# Convert sparse tensor to dense array if needed
if hasattr(vec, "to_dense"):
vec_dense = vec.to_dense().numpy()
else:
vec_dense = vec.numpy() if hasattr(vec, "numpy") else np.array(vec)
nz = np.where(vec_dense != 0)[0]
sparse_vectors.append(
normalize_sparse_vector(
indices=nz.tolist(), values=vec_dense[nz].tolist()
)
)
return sparse_vectors
def embed_query(self, input: Documents) -> SparseVectors:
try:
from sentence_transformers import SparseEncoder
except ImportError:
raise ValueError(
"The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
)
model = cast(SparseEncoder, self._model)
if self.query_config is not None:
if self.query_config.get("task") == "document":
embeddings = model.encode_document(
list(input),
)
elif self.query_config.get("task") == "query":
embeddings = model.encode_query(
list(input),
)
else:
raise ValueError(f"Invalid task: {self.query_config.get('task')}")
sparse_vectors: SparseVectors = []
for vec in embeddings:
# Convert sparse tensor to dense array if needed
if hasattr(vec, "to_dense"):
vec_dense = vec.to_dense().numpy()
else:
vec_dense = vec.numpy() if hasattr(vec, "numpy") else np.array(vec)
nz = np.where(vec_dense != 0)[0]
sparse_vectors.append(
normalize_sparse_vector(
indices=nz.tolist(), values=vec_dense[nz].tolist()
)
)
return sparse_vectors
else:
return self.__call__(input)
@staticmethod
def name() -> str:
return "huggingface_sparse"
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "SparseEmbeddingFunction[Documents]":
model_name = config.get("model_name")
device = config.get("device")
task = config.get("task")
query_config = config.get("query_config")
kwargs = config.get("kwargs", {})
if model_name is None or device is None:
assert False, "This code should not be reached"
return HuggingFaceSparseEmbeddingFunction(
model_name=model_name,
device=device,
task=task,
query_config=query_config,
**kwargs,
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"device": self.device,
"task": self.task,
"query_config": self.query_config,
"kwargs": self.kwargs,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# model_name is also used as the identifier for model path if stored locally.
# Users should be able to change the path if needed, so we should not validate that.
# e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
return
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "huggingface_sparse")

View File

@@ -0,0 +1,118 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Optional
import numpy as np
class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Instructor embedding model.
"""
# If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda"
# for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list
def __init__(
self,
model_name: str = "hkunlp/instructor-base",
device: str = "cpu",
instruction: Optional[str] = None,
):
"""
Initialize the InstructorEmbeddingFunction.
Args:
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "hkunlp/instructor-base".
device (str, optional): The device to use for computation.
Defaults to "cpu".
instruction (str, optional): The instruction to use for the embeddings.
Defaults to None.
"""
try:
from InstructorEmbedding import INSTRUCTOR
except ImportError:
raise ValueError(
"The InstructorEmbedding python package is not installed. Please install it with `pip install InstructorEmbedding`"
)
self.model_name = model_name
self.device = device
self.instruction = instruction
self._model = INSTRUCTOR(model_name_or_path=model_name, device=device)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
            input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Instructor only works with text documents
if not all(isinstance(item, str) for item in input):
raise ValueError("Instructor only supports text documents, not images")
if self.instruction is None:
embeddings = self._model.encode(input, convert_to_numpy=True)
else:
texts_with_instructions = [[self.instruction, text] for text in input]
embeddings = self._model.encode(
texts_with_instructions, convert_to_numpy=True
)
# Convert to numpy arrays
return [np.array(embedding, dtype=np.float32) for embedding in embeddings]
@staticmethod
def name() -> str:
return "instructor"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model_name = config.get("model_name")
device = config.get("device")
instruction = config.get("instruction")
if model_name is None or device is None:
assert False, "This code should not be reached"
return InstructorEmbeddingFunction(
model_name=model_name, device=device, instruction=instruction
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"device": self.device,
"instruction": self.instruction,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# model_name is also used as the identifier for model path if stored locally.
# Users should be able to change the path if needed, so we should not validate that.
# e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
return
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "instructor")

View File

@@ -0,0 +1,277 @@
from chromadb.api.types import (
Embeddings,
EmbeddingFunction,
Space,
Embeddable,
is_image,
is_document,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Union, Optional, TypedDict
import os
import numpy as np
import warnings
import importlib
import base64
import io
class JinaQueryConfig(TypedDict):
task: str
class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
"""
This class is used to get embeddings for a list of texts using the Jina AI API.
It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "jina-embeddings-v2-base-en",
api_key_env_var: str = "CHROMA_JINA_API_KEY",
task: Optional[str] = None,
late_chunking: Optional[bool] = None,
truncate: Optional[bool] = None,
dimensions: Optional[int] = None,
embedding_type: Optional[str] = None,
normalized: Optional[bool] = None,
query_config: Optional[JinaQueryConfig] = None,
):
"""
Initialize the JinaEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the Jina AI API.
Defaults to "CHROMA_JINA_API_KEY".
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "jina-embeddings-v2-base-en".
task (str, optional): The task to use for the Jina AI API.
Defaults to None.
late_chunking (bool, optional): Whether to use late chunking for the Jina AI API.
Defaults to None.
            truncate (bool, optional): Whether to truncate the input before sending it to the Jina AI API.
            Defaults to None.
dimensions (int, optional): The number of dimensions to use for the Jina AI API.
Defaults to None.
embedding_type (str, optional): The type of embedding to use for the Jina AI API.
Defaults to None.
            normalized (bool, optional): Whether the Jina AI API should return normalized embeddings.
            Defaults to None.
            query_config (dict, optional): Parameter overrides applied when the input is embedded as a query.
            Defaults to None.
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
try:
self._PILImage = importlib.import_module("PIL.Image")
except ImportError:
raise ValueError(
"The PIL python package is not installed. Please install it with `pip install pillow`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
# Initialize optional attributes to None
self.task = task
self.late_chunking = late_chunking
self.truncate = truncate
self.dimensions = dimensions
self.embedding_type = embedding_type
self.normalized = normalized
self.query_config = query_config
self._api_url = "https://api.jina.ai/v1/embeddings"
self._session = httpx.Client()
self._session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"input": [],
"model": self.model_name,
}
if all(is_document(item) for item in input):
payload["input"] = input
else:
for item in input:
if is_document(item):
payload["input"].append({"text": item})
elif is_image(item):
try:
pil_image = self._PILImage.fromarray(item)
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
img_bytes = buffer.getvalue()
# Encode bytes to base64 string
base64_string = base64.b64encode(img_bytes).decode("utf-8")
except Exception as e:
raise ValueError(
f"Failed to convert image numpy array to base64 data URI: {e}"
) from e
payload["input"].append({"image": base64_string})
if self.task is not None:
payload["task"] = self.task
if self.late_chunking is not None:
payload["late_chunking"] = self.late_chunking
if self.truncate is not None:
payload["truncate"] = self.truncate
if self.dimensions is not None:
payload["dimensions"] = self.dimensions
if self.embedding_type is not None:
payload["embedding_type"] = self.embedding_type
if self.normalized is not None:
payload["normalized"] = self.normalized
        # overwrite parameters when the query payload is used
if is_query and self.query_config is not None:
for key, value in self.query_config.items():
payload[key] = value
return payload
def _convert_resp(self, resp: Any, is_query: bool = False) -> Embeddings:
"""
Convert the response from the Jina AI API to a list of numpy arrays.
Args:
resp (Any): The response from the Jina AI API.
Returns:
Embeddings: A list of numpy arrays representing the embeddings.
"""
if "data" not in resp:
raise RuntimeError(resp.get("detail", "Unknown error"))
embeddings_data: List[Dict[str, Union[int, List[float]]]] = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings_data, key=lambda e: e["index"])
# Return embeddings as numpy arrays
return [
np.array(result["embedding"], dtype=np.float32)
for result in sorted_embeddings
]
def __call__(self, input: Embeddable) -> Embeddings:
"""
Get the embeddings for a list of texts.
Args:
input (Embeddable): A list of texts and/or images to get embeddings for.
Returns:
Embeddings: The embeddings for the texts.
Example:
>>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
>>> input = ["Hello, world!", "How are you?"]
"""
payload = self._build_payload(input, is_query=False)
# Call Jina AI Embedding API
resp = self._session.post(self._api_url, json=payload, timeout=60).json()
return self._convert_resp(resp)
def embed_query(self, input: Embeddable) -> Embeddings:
payload = self._build_payload(input, is_query=True)
# Call Jina AI Embedding API
resp = self._session.post(self._api_url, json=payload, timeout=60).json()
return self._convert_resp(resp, is_query=True)
@staticmethod
def name() -> str:
return "jina"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
task = config.get("task")
late_chunking = config.get("late_chunking")
truncate = config.get("truncate")
dimensions = config.get("dimensions")
embedding_type = config.get("embedding_type")
normalized = config.get("normalized")
query_config = config.get("query_config")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached" # this is for type checking
return JinaEmbeddingFunction(
api_key_env_var=api_key_env_var,
model_name=model_name,
task=task,
late_chunking=late_chunking,
truncate=truncate,
dimensions=dimensions,
embedding_type=embedding_type,
normalized=normalized,
query_config=query_config,
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"task": self.task,
"late_chunking": self.late_chunking,
"truncate": self.truncate,
"dimensions": self.dimensions,
"embedding_type": self.embedding_type,
"normalized": self.normalized,
"query_config": self.query_config,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "jina")

View File

@@ -0,0 +1,92 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any
import os
import numpy as np
class MistralEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model: str,
api_key_env_var: str = "MISTRAL_API_KEY",
):
"""
Initialize the MistralEmbeddingFunction.
Args:
model (str): The name of the model to use for text embeddings.
api_key_env_var (str): The environment variable name for the Mistral API key.
"""
try:
from mistralai import Mistral
except ImportError:
raise ValueError(
"The mistralai python package is not installed. Please install it with `pip install mistralai`"
)
self.model = model
self.api_key_env_var = api_key_env_var
self.api_key = os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.client = Mistral(api_key=self.api_key)
def __call__(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a list of texts.
Args:
input (Documents): A list of texts to get embeddings for.
"""
if not all(isinstance(item, str) for item in input):
raise ValueError("Mistral only supports text documents, not images")
output = self.client.embeddings.create(
model=self.model,
inputs=input,
)
# Extract embeddings from the response
        return [np.array(data.embedding, dtype=np.float32) for data in output.data]
@staticmethod
def name() -> str:
return "mistral"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model = config.get("model")
api_key_env_var = config.get("api_key_env_var")
if model is None or api_key_env_var is None:
assert False, "This code should not be reached" # this is for type checking
return MistralEmbeddingFunction(model=model, api_key_env_var=api_key_env_var)
def get_config(self) -> Dict[str, Any]:
return {
"model": self.model,
"api_key_env_var": self.api_key_env_var,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model" in new_config:
raise ValueError(
"The model cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
"""
validate_config_schema(config, "mistral")

View File

@@ -0,0 +1,147 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings
class MorphEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "morph-embedding-v2",
api_base: str = "https://api.morphllm.com/v1",
encoding_format: str = "float",
api_key_env_var: str = "MORPH_API_KEY",
):
"""
Initialize the MorphEmbeddingFunction.
Args:
api_key (str, optional): The API key for the Morph API. If not provided,
it will be read from the environment variable specified by api_key_env_var.
model_name (str, optional): The name of the model to use for embeddings.
Defaults to "morph-embedding-v2".
api_base (str, optional): The base URL for the Morph API.
Defaults to "https://api.morphllm.com/v1".
encoding_format (str, optional): The format for embeddings (float or base64).
Defaults to "float".
api_key_env_var (str, optional): Environment variable name that contains your API key.
Defaults to "MORPH_API_KEY".
"""
try:
import openai
except ImportError:
raise ValueError(
"The openai python package is not installed. Please install it with `pip install openai`. "
"Note: Morph uses the OpenAI client library for API communication."
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self.api_base = api_base
self.encoding_format = encoding_format
# Initialize the OpenAI client with Morph's base URL
self.client = openai.OpenAI(
api_key=self.api_key,
base_url=self.api_base,
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Handle empty input
if not input:
return []
# Prepare embedding parameters
embedding_params: Dict[str, Any] = {
"model": self.model_name,
"input": input,
"encoding_format": self.encoding_format,
}
# Get embeddings from Morph API
response = self.client.embeddings.create(**embedding_params)
# Extract embeddings from response
return [np.array(data.embedding, dtype=np.float32) for data in response.data]
@staticmethod
def name() -> str:
return "morph"
def default_space(self) -> Space:
# Morph embeddings work best with cosine similarity
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
# Extract parameters from config
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
api_base = config.get("api_base")
encoding_format = config.get("encoding_format")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"
# Create and return the embedding function
return MorphEmbeddingFunction(
api_key_env_var=api_key_env_var,
model_name=model_name,
api_base=api_base if api_base is not None else "https://api.morphllm.com/v1",
encoding_format=encoding_format if encoding_format is not None else "float",
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"api_base": self.api_base,
"encoding_format": self.encoding_format,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "morph")

View File

@@ -0,0 +1,117 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any
import numpy as np
from urllib.parse import urlparse
DEFAULT_MODEL_NAME = "chroma/all-minilm-l6-v2-f32"
class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API
(https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
"""
def __init__(
self,
url: str = "http://localhost:11434",
model_name: str = DEFAULT_MODEL_NAME,
timeout: int = 60,
) -> None:
"""
Initialize the Ollama Embedding Function.
Args:
url (str): The Base URL of the Ollama Server (default: "http://localhost:11434").
model_name (str): The name of the model to use for text embeddings.
Defaults to "chroma/all-minilm-l6-v2-f32", for available models see https://ollama.com/library.
timeout (int): The timeout for the API call in seconds. Defaults to 60.
"""
try:
from ollama import Client
except ImportError:
raise ValueError(
"The ollama python package is not installed. Please install it with `pip install ollama`"
)
self.url = url
self.model_name = model_name
self.timeout = timeout
# Adding this for backwards compatibility with the old version of the EF
self._base_url = url
if self._base_url.endswith("/api/embeddings"):
parsed_url = urlparse(url)
self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
self._client = Client(host=self._base_url, timeout=timeout)
def __call__(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a list of texts.
Args:
input (Documents): A list of texts to get embeddings for.
Returns:
Embeddings: The embeddings for the texts.
Example:
>>> ollama_ef = OllamaEmbeddingFunction()
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = ollama_ef(texts)
"""
# Call Ollama client
response = self._client.embed(model=self.model_name, input=input)
# Convert to numpy arrays
return [
np.array(embedding, dtype=np.float32)
for embedding in response["embeddings"]
]
@staticmethod
def name() -> str:
return "ollama"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
url = config.get("url")
model_name = config.get("model_name")
timeout = config.get("timeout")
if url is None or model_name is None or timeout is None:
assert False, "This code should not be reached"
return OllamaEmbeddingFunction(url=url, model_name=model_name, timeout=timeout)
def get_config(self) -> Dict[str, Any]:
return {"url": self.url, "model_name": self.model_name, "timeout": self.timeout}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "ollama")

View File

@@ -0,0 +1,364 @@
import hashlib
import importlib
import logging
import os
import tarfile
import sys
from functools import cached_property
from pathlib import Path
from typing import List, Dict, Any, Optional, cast
import numpy as np
import numpy.typing as npt
import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random
from chromadb.api.types import Documents, Embeddings, EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
logger = logging.getLogger(__name__)
def _verify_sha256(fname: str, expected_sha256: str) -> bool:
sha256_hash = hashlib.sha256()
with open(fname, "rb") as f:
# Read and update hash in chunks to avoid using too much memory
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest() == expected_sha256
# In order to remove dependencies on sentence-transformers, which in turn depends on
# pytorch and sentence-piece, we have created a default ONNX embedding function that
# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers.
# visit https://github.com/chroma-core/onnx-embedding for the source code to generate
# and verify the ONNX model.
class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]):
MODEL_NAME = "all-MiniLM-L6-v2"
DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME
EXTRACTED_FOLDER_NAME = "onnx"
ARCHIVE_FILENAME = "onnx.tar.gz"
MODEL_DOWNLOAD_URL = (
"https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
)
_MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3"
def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:
"""
Initialize the ONNXMiniLM_L6_V2 embedding function.
Args:
preferred_providers (List[str], optional): The preferred ONNX runtime providers.
Defaults to None.
"""
        # validate that the preferred providers are all strings
if preferred_providers and not all(
[isinstance(i, str) for i in preferred_providers]
):
raise ValueError("Preferred providers must be a list of strings")
# check for duplicate providers
if preferred_providers and len(preferred_providers) != len(
set(preferred_providers)
):
raise ValueError("Preferred providers must be unique")
self._preferred_providers = preferred_providers
try:
# Equivalent to import onnxruntime
self.ort = importlib.import_module("onnxruntime")
except ImportError:
raise ValueError(
"The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`"
)
try:
# Equivalent to from tokenizers import Tokenizer
self.Tokenizer = importlib.import_module("tokenizers").Tokenizer
except ImportError:
raise ValueError(
"The tokenizers python package is not installed. Please install it with `pip install tokenizers`"
)
try:
# Equivalent to from tqdm import tqdm
self.tqdm = importlib.import_module("tqdm").tqdm
except ImportError:
raise ValueError(
"The tqdm python package is not installed. Please install it with `pip install tqdm`"
)
# Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
# Download with tqdm to preserve the sentence-transformers experience
@retry( # type: ignore
reraise=True,
stop=stop_after_attempt(3),
wait=wait_random(min=1, max=3),
retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)),
)
def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
"""
Download the onnx model from the URL and save it to the file path.
Args:
url: The URL to download the model from.
fname: The path to save the model to.
chunk_size: The chunk size to use when downloading.
"""
with httpx.stream("GET", url) as resp:
total = int(resp.headers.get("content-length", 0))
with open(fname, "wb") as file, self.tqdm(
desc=str(fname),
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_bytes(chunk_size=chunk_size):
size = file.write(data)
bar.update(size)
if not _verify_sha256(fname, self._MODEL_SHA256):
os.remove(fname)
raise ValueError(
f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file."
)
    # Use PyTorch's default epsilon for division by zero
# https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
def _normalize(self, v: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
"""
Normalize a vector.
Args:
v: The vector to normalize.
Returns:
The normalized vector.
"""
norm = np.linalg.norm(v, axis=1)
# Handle division by zero
norm[norm == 0] = 1e-12
return cast(npt.NDArray[np.float32], v / norm[:, np.newaxis])
def _forward(
self, documents: List[str], batch_size: int = 32
) -> npt.NDArray[np.float32]:
"""
Generate embeddings for a list of documents.
Args:
documents: The documents to generate embeddings for.
batch_size: The batch size to use when generating embeddings.
Returns:
The embeddings for the documents.
"""
all_embeddings = []
for i in range(0, len(documents), batch_size):
batch = documents[i : i + batch_size]
# Encode each document separately
encoded = [self.tokenizer.encode(d) for d in batch]
# Check if any document exceeds the max tokens
for doc_tokens in encoded:
if len(doc_tokens.ids) > self.max_tokens():
raise ValueError(
f"Document length {len(doc_tokens.ids)} is greater than the max tokens {self.max_tokens()}"
)
input_ids = np.array([e.ids for e in encoded])
attention_mask = np.array([e.attention_mask for e in encoded])
onnx_input = {
"input_ids": np.array(input_ids, dtype=np.int64),
"attention_mask": np.array(attention_mask, dtype=np.int64),
"token_type_ids": np.array(
[np.zeros(len(e), dtype=np.int64) for e in input_ids],
dtype=np.int64,
),
}
model_output = self.model.run(None, onnx_input)
last_hidden_state = model_output[0]
# Perform mean pooling with attention weighting
input_mask_expanded = np.broadcast_to(
np.expand_dims(attention_mask, -1), last_hidden_state.shape
)
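            # Mean pooling: sum the token vectors wherever attention_mask == 1,
            # then divide by the number of real (non-padding) tokens; the clip
            # guards against division by zero when a mask is all zeros.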
embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
input_mask_expanded.sum(1), a_min=1e-9, a_max=None
)
embeddings = self._normalize(embeddings).astype(np.float32)
all_embeddings.append(embeddings)
return np.concatenate(all_embeddings)
@cached_property
def tokenizer(self) -> Any:
"""
Get the tokenizer for the model.
Returns:
The tokenizer for the model.
"""
tokenizer = self.Tokenizer.from_file(
os.path.join(
self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
)
)
# max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
# https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
tokenizer.enable_truncation(max_length=256)
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
return tokenizer
@cached_property
def model(self) -> Any:
"""
Get the model.
Returns:
The model.
"""
if self._preferred_providers is None or len(self._preferred_providers) == 0:
if len(self.ort.get_available_providers()) > 0:
logger.debug(
f"WARNING: No ONNX providers provided, defaulting to available providers: "
f"{self.ort.get_available_providers()}"
)
self._preferred_providers = self.ort.get_available_providers()
elif not set(self._preferred_providers).issubset(
set(self.ort.get_available_providers())
):
raise ValueError(
f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}"
)
# Suppress onnxruntime warnings
so = self.ort.SessionOptions()
so.log_severity_level = 3
so.graph_optimization_level = self.ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if (
self._preferred_providers
and "CoreMLExecutionProvider" in self._preferred_providers
):
            # Remove CoreMLExecutionProvider from the list; it is not as well optimized as the CPU provider.
self._preferred_providers.remove("CoreMLExecutionProvider")
return self.ort.InferenceSession(
os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
            # Since 1.9 onnxruntime requires providers to be specified when there are multiple available
providers=self._preferred_providers,
sess_options=so,
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Only download the model when it is actually used
self._download_model_if_not_exists()
# Generate embeddings
embeddings = self._forward(input)
# Convert to list of numpy arrays for the expected Embeddings type
return cast(
Embeddings,
[np.array(embedding, dtype=np.float32) for embedding in embeddings],
)
def _download_model_if_not_exists(self) -> None:
"""
Download the model if it doesn't exist.
"""
onnx_files = [
"config.json",
"model.onnx",
"special_tokens_map.json",
"tokenizer_config.json",
"tokenizer.json",
"vocab.txt",
]
extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
onnx_files_exist = True
for f in onnx_files:
if not os.path.exists(os.path.join(extracted_folder, f)):
onnx_files_exist = False
break
# Model is not downloaded yet
if not onnx_files_exist:
os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
if not os.path.exists(
os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
) or not _verify_sha256(
os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
self._MODEL_SHA256,
):
self._download(
url=self.MODEL_DOWNLOAD_URL,
fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
)
with tarfile.open(
name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
mode="r:gz",
) as tar:
if sys.version_info >= (3, 12):
tar.extractall(path=self.DOWNLOAD_PATH, filter="data")
else:
# filter argument was added in Python 3.12
# https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall
                    # Versions prior to 3.12 have no filter and extract members unmodified
tar.extractall(path=self.DOWNLOAD_PATH)
@staticmethod
def name() -> str:
return "onnx_mini_lm_l6_v2"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
def max_tokens(self) -> int:
# Default token limit for ONNX Mini LM L6 V2 model
return 256
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
preferred_providers = config.get("preferred_providers")
return ONNXMiniLM_L6_V2(preferred_providers=preferred_providers)
def get_config(self) -> Dict[str, Any]:
return {"preferred_providers": self._preferred_providers}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# Preferred providers can be changed, so no validation needed
pass
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "onnx_mini_lm_l6_v2")

View File

@@ -0,0 +1,187 @@
from chromadb.api.types import EmbeddingFunction, Space
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import (
Document,
Documents,
Embedding,
Embeddings,
Image,
Images,
is_document,
is_image,
Embeddable,
)
from typing import List, Dict, Any, Union, Optional, cast
import numpy as np
import importlib
class OpenCLIPEmbeddingFunction(EmbeddingFunction[Embeddable]):
"""
This class is used to generate embeddings for a list of texts or images using the Open CLIP model.
"""
def __init__(
self,
model_name: str = "ViT-B-32",
checkpoint: str = "laion2b_s34b_b79k",
device: Optional[str] = "cpu",
) -> None:
"""
Initialize the OpenCLIPEmbeddingFunction.
Args:
model_name (str, optional): The name of the model to use for embeddings.
Defaults to "ViT-B-32".
checkpoint (str, optional): The checkpoint to use for the model.
Defaults to "laion2b_s34b_b79k".
device (str, optional): The device to use for computation.
Defaults to "cpu".
"""
try:
import open_clip
except ImportError:
raise ValueError(
"The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip"
)
try:
self._torch = importlib.import_module("torch")
except ImportError:
raise ValueError(
"The torch python package is not installed. Please install it with `pip install torch`"
)
try:
self._PILImage = importlib.import_module("PIL.Image")
except ImportError:
raise ValueError(
"The PIL python package is not installed. Please install it with `pip install pillow`"
)
self.model_name = model_name
self.checkpoint = checkpoint
self.device = device
model, _, preprocess = open_clip.create_model_and_transforms(
model_name=model_name, pretrained=checkpoint
)
self._model = model
self._model.to(device)
self._preprocess = preprocess
self._tokenizer = open_clip.get_tokenizer(model_name=model_name)
def _encode_image(self, image: Image) -> Embedding:
"""
Encode an image using the Open CLIP model.
Args:
image: The image to encode.
Returns:
The embedding for the image.
"""
pil_image = self._PILImage.fromarray(image)
with self._torch.no_grad():
image_features = self._model.encode_image(
self._preprocess(pil_image).unsqueeze(0).to(self.device)
)
image_features /= image_features.norm(dim=-1, keepdim=True)
return cast(Embedding, image_features.squeeze().cpu().numpy())
def _encode_text(self, text: Document) -> Embedding:
"""
Encode a text using the Open CLIP model.
Args:
text: The text to encode.
Returns:
The embedding for the text.
"""
with self._torch.no_grad():
text_features = self._model.encode_text(
self._tokenizer(text).to(self.device)
)
text_features /= text_features.norm(dim=-1, keepdim=True)
return cast(Embedding, text_features.squeeze().cpu().numpy())
def __call__(self, input: Embeddable) -> Embeddings:
"""
Generate embeddings for the given documents or images.
Args:
input: Documents or images to generate embeddings for.
Returns:
Embeddings for the documents or images.
"""
embeddings: Embeddings = []
for item in input:
if is_image(item):
embeddings.append(
np.array(self._encode_image(cast(Image, item)), dtype=np.float32)
)
elif is_document(item):
embeddings.append(
np.array(self._encode_text(cast(Document, item)), dtype=np.float32)
)
return embeddings
@staticmethod
def name() -> str:
return "open_clip"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "EmbeddingFunction[Union[Documents, Images]]":
model_name = config.get("model_name")
checkpoint = config.get("checkpoint")
device = config.get("device")
if model_name is None or checkpoint is None or device is None:
assert False, "This code should not be reached"
return OpenCLIPEmbeddingFunction(
model_name=model_name, checkpoint=checkpoint, device=device
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"checkpoint": self.checkpoint,
"device": self.device,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
if "checkpoint" in new_config:
raise ValueError(
"The checkpoint cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "open_clip")

View File

@@ -0,0 +1,204 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings
class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-ada-002",
organization_id: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
deployment_id: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
dimensions: Optional[int] = None,
api_key_env_var: str = "CHROMA_OPENAI_API_KEY",
):
"""
Initialize the OpenAIEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the OpenAI API.
Defaults to "CHROMA_OPENAI_API_KEY".
model_name (str, optional): The name of the model to use for text
embeddings. Defaults to "text-embedding-ada-002".
organization_id(str, optional): The OpenAI organization ID if applicable
api_base (str, optional): The base path for the API. If not provided,
it will use the base path for the OpenAI API. This can be used to
point to a different deployment, such as an Azure deployment.
api_type (str, optional): The type of the API deployment. This can be
used to specify a different deployment, such as 'azure'. If not
provided, it will use the default OpenAI deployment.
api_version (str, optional): The api version for the API. If not provided,
it will use the api version for the OpenAI API. This can be used to
point to a different deployment, such as an Azure deployment.
deployment_id (str, optional): Deployment ID for Azure OpenAI.
default_headers (Dict[str, str], optional): A mapping of default headers to be sent with each API request.
dimensions (int, optional): The number of dimensions for the embeddings.
Only supported for `text-embedding-3` or later models from OpenAI.
https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
"""
try:
import openai
except ImportError:
raise ValueError(
"The openai python package is not installed. Please install it with `pip install openai`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self.organization_id = organization_id
self.api_base = api_base
self.api_type = api_type
self.api_version = api_version
self.deployment_id = deployment_id
self.default_headers = default_headers
self.dimensions = dimensions
# Initialize the OpenAI client
client_params: Dict[str, Any] = {"api_key": self.api_key}
if self.organization_id is not None:
client_params["organization"] = self.organization_id
if self.api_base is not None:
client_params["base_url"] = self.api_base
if self.default_headers is not None:
client_params["default_headers"] = self.default_headers
self.client = openai.OpenAI(**client_params)
# For Azure OpenAI
if self.api_type == "azure":
if self.api_version is None:
raise ValueError("api_version must be specified for Azure OpenAI")
if self.deployment_id is None:
raise ValueError("deployment_id must be specified for Azure OpenAI")
if self.api_base is None:
raise ValueError("api_base must be specified for Azure OpenAI")
from openai import AzureOpenAI
self.client = AzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.api_base,
azure_deployment=self.deployment_id,
default_headers=self.default_headers,
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Handle batching
if not input:
return []
# Prepare embedding parameters
embedding_params: Dict[str, Any] = {
"model": self.model_name,
"input": input,
}
if self.dimensions is not None and "text-embedding-3" in self.model_name:
embedding_params["dimensions"] = self.dimensions
# Get embeddings
response = self.client.embeddings.create(**embedding_params)
# Extract embeddings from response
return [np.array(data.embedding, dtype=np.float32) for data in response.data]
@staticmethod
def name() -> str:
return "openai"
def default_space(self) -> Space:
# OpenAI embeddings work best with cosine similarity
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
# Extract parameters from config
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
organization_id = config.get("organization_id")
api_base = config.get("api_base")
api_type = config.get("api_type")
api_version = config.get("api_version")
deployment_id = config.get("deployment_id")
default_headers = config.get("default_headers")
dimensions = config.get("dimensions")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"
# Create and return the embedding function
return OpenAIEmbeddingFunction(
api_key_env_var=api_key_env_var,
model_name=model_name,
organization_id=organization_id,
api_base=api_base,
api_type=api_type,
api_version=api_version,
deployment_id=deployment_id,
default_headers=default_headers,
dimensions=dimensions,
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"organization_id": self.organization_id,
"api_base": self.api_base,
"api_type": self.api_type,
"api_version": self.api_version,
"deployment_id": self.deployment_id,
"default_headers": self.default_headers,
"dimensions": self.dimensions,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "openai")

View File

@@ -0,0 +1,159 @@
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import (
Documents,
Embeddings,
Images,
is_document,
is_image,
Embeddable,
EmbeddingFunction,
Space,
)
from typing import List, Dict, Any, Union, cast, Optional
import os
import importlib
import base64
from io import BytesIO
import numpy as np
import warnings
class RoboflowEmbeddingFunction(EmbeddingFunction[Embeddable]):
"""
This class is used to generate embeddings for a list of texts or images using the Roboflow API.
"""
def __init__(
self,
api_key: Optional[str] = None,
api_url: str = "https://infer.roboflow.com",
api_key_env_var: str = "CHROMA_ROBOFLOW_API_KEY",
) -> None:
"""
Create a RoboflowEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the Roboflow API.
Defaults to "CHROMA_ROBOFLOW_API_KEY".
api_url (str, optional): The URL of the Roboflow API.
Defaults to "https://infer.roboflow.com".
"""
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.api_url = api_url
try:
self._PILImage = importlib.import_module("PIL.Image")
except ImportError:
raise ValueError(
"The PIL python package is not installed. Please install it with `pip install pillow`"
)
self._httpx = importlib.import_module("httpx")
def __call__(self, input: Embeddable) -> Embeddings:
"""
Generate embeddings for the given documents or images.
Args:
input: Documents or images to generate embeddings for.
Returns:
Embeddings for the documents or images.
"""
embeddings = []
for item in input:
if is_image(item):
image = self._PILImage.fromarray(item)
buffer = BytesIO()
image.save(buffer, format="JPEG")
base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
infer_clip_payload_image = {
"image": {
"type": "base64",
"value": base64_image,
},
}
res = self._httpx.post(
f"{self.api_url}/clip/embed_image?api_key={self.api_key}",
json=infer_clip_payload_image,
)
result = res.json()["embeddings"]
embeddings.append(np.array(result[0], dtype=np.float32))
elif is_document(item):
infer_clip_payload_text = {
"text": item,
}
res = self._httpx.post(
f"{self.api_url}/clip/embed_text?api_key={self.api_key}",
json=infer_clip_payload_text,
)
result = res.json()["embeddings"]
embeddings.append(np.array(result[0], dtype=np.float32))
# Cast to the expected Embeddings type
return cast(Embeddings, embeddings)
@staticmethod
def name() -> str:
return "roboflow"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "EmbeddingFunction[Union[Documents, Images]]":
api_key_env_var = config.get("api_key_env_var")
api_url = config.get("api_url")
if api_key_env_var is None or api_url is None:
assert False, "This code should not be reached"
return RoboflowEmbeddingFunction(
api_key_env_var=api_key_env_var, api_url=api_url
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "api_url": self.api_url}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# API URL can be changed, so no validation needed
pass
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "roboflow")

View File

@@ -0,0 +1,41 @@
# Embedding Function Schemas
This directory contains JSON schemas for all embedding functions in Chroma. These schemas support cross-language compatibility and make it possible to validate that changes in one client library do not accidentally diverge from the others.
## Schema Structure
Each schema follows the JSON Schema Draft-07 specification and includes:
- `version`: The version of the schema
- `title`: The title of the schema
- `description`: A description of the schema
- `properties`: The properties that can be configured for the embedding function
- `required`: The properties that are required for the embedding function
- `additionalProperties`: Whether additional properties are allowed (always set to `false` to ensure strict validation)
## Usage
The schemas can be used to validate the configuration of embedding functions using the `validate_config_schema` function:
```python
from chromadb.utils.embedding_functions.schemas import validate_config_schema
# Validate a configuration
config = {
    "api_key_env_var": "CHROMA_OPENAI_API_KEY",
    "model_name": "text-embedding-ada-002"
}
validate_config_schema(config, "openai")
```
## Adding New Schemas
To add a new schema:
1. Create a new JSON file in this directory with the name of the embedding function (e.g., `new_function.json`)
2. Define the schema following the JSON Schema Draft-07 specification (a sketch follows below)
3. Update the embedding function to use the schema for validation
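As a sketch, a minimal schema for a hypothetical `new_function` might look like the following (shown as the Python dict that `json.load` would produce; the property names are illustrative, not prescriptive):
```python
minimal_schema = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "version": "1.0.0",
    "title": "new_function",
    "description": "Configuration for the hypothetical new_function embedding function",
    "type": "object",
    "properties": {
        "api_key_env_var": {"type": "string"},
        "model_name": {"type": "string"},
    },
    "required": ["api_key_env_var", "model_name"],
    "additionalProperties": False,
}
```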
## Schema Versioning
Each schema includes a version number to support future changes to embedding function configurations. When changing a schema, increment the version number so that clients can detect the change and remain compatible with older configurations.

View File

@@ -0,0 +1,19 @@
from chromadb.utils.embedding_functions.schemas.schema_utils import (
validate_config_schema,
load_schema,
get_schema_version,
)
from chromadb.utils.embedding_functions.schemas.registry import (
get_available_schemas,
get_schema_info,
get_embedding_function_names,
)
__all__ = [
"validate_config_schema",
"load_schema",
"get_schema_version",
"get_available_schemas",
"get_schema_info",
"get_embedding_function_names",
]

View File

@@ -0,0 +1,285 @@
from __future__ import annotations
import re
from functools import lru_cache
from typing import Iterable, List, Protocol, cast
DEFAULT_ENGLISH_STOPWORDS: List[str] = [
"a",
"about",
"above",
"after",
"again",
"against",
"ain",
"all",
"am",
"an",
"and",
"any",
"are",
"aren",
"aren't",
"as",
"at",
"be",
"because",
"been",
"before",
"being",
"below",
"between",
"both",
"but",
"by",
"can",
"couldn",
"couldn't",
"d",
"did",
"didn",
"didn't",
"do",
"does",
"doesn",
"doesn't",
"doing",
"don",
"don't",
"down",
"during",
"each",
"few",
"for",
"from",
"further",
"had",
"hadn",
"hadn't",
"has",
"hasn",
"hasn't",
"have",
"haven",
"haven't",
"having",
"he",
"her",
"here",
"hers",
"herself",
"him",
"himself",
"his",
"how",
"i",
"if",
"in",
"into",
"is",
"isn",
"isn't",
"it",
"it's",
"its",
"itself",
"just",
"ll",
"m",
"ma",
"me",
"mightn",
"mightn't",
"more",
"most",
"mustn",
"mustn't",
"my",
"myself",
"needn",
"needn't",
"no",
"nor",
"not",
"now",
"o",
"of",
"off",
"on",
"once",
"only",
"or",
"other",
"our",
"ours",
"ourselves",
"out",
"over",
"own",
"re",
"s",
"same",
"shan",
"shan't",
"she",
"she's",
"should",
"should've",
"shouldn",
"shouldn't",
"so",
"some",
"such",
"t",
"than",
"that",
"that'll",
"the",
"their",
"theirs",
"them",
"themselves",
"then",
"there",
"these",
"they",
"this",
"those",
"through",
"to",
"too",
"under",
"until",
"up",
"ve",
"very",
"was",
"wasn",
"wasn't",
"we",
"were",
"weren",
"weren't",
"what",
"when",
"where",
"which",
"while",
"who",
"whom",
"why",
"will",
"with",
"won",
"won't",
"wouldn",
"wouldn't",
"y",
"you",
"you'd",
"you'll",
"you're",
"you've",
"your",
"yours",
"yourself",
"yourselves",
]
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(DEFAULT_ENGLISH_STOPWORDS)
class SnowballStemmer(Protocol):
def stem(self, token: str) -> str: # pragma: no cover - protocol definition
...
class _SnowballStemmerAdapter:
"""Adapter that provides the uniform `stem` API used across languages."""
def __init__(self) -> None:
try:
import snowballstemmer
except ImportError:
raise ValueError(
"The snowballstemmer python package is not installed. Please install it with `pip install snowballstemmer`"
)
self._stemmer = snowballstemmer.stemmer("english")
def stem(self, token: str) -> str:
return cast(str, self._stemmer.stemWord(token))
@lru_cache(maxsize=1)
def get_english_stemmer() -> SnowballStemmer:
"""Return a cached Snowball stemmer for English."""
return _SnowballStemmerAdapter()
class Bm25Tokenizer:
"""Tokenizer with stopword filtering and stemming used by BM25 embeddings."""
def __init__(
self,
stemmer: SnowballStemmer,
stopwords: Iterable[str],
token_max_length: int,
) -> None:
self._stemmer = stemmer
self._stopwords = {word.lower() for word in stopwords}
self._token_max_length = token_max_length
self._non_alphanumeric_pattern = re.compile(r"[^\w\s]+", flags=re.UNICODE)
def _remove_non_alphanumeric(self, text: str) -> str:
return self._non_alphanumeric_pattern.sub(" ", text)
@staticmethod
def _simple_tokenize(text: str) -> List[str]:
return [token for token in text.lower().split() if token]
def tokenize(self, text: str) -> List[str]:
cleaned = self._remove_non_alphanumeric(text)
raw_tokens = self._simple_tokenize(cleaned)
tokens: List[str] = []
for token in raw_tokens:
if token in self._stopwords:
continue
if len(token) > self._token_max_length:
continue
stemmed = self._stemmer.stem(token).strip()
if stemmed:
tokens.append(stemmed)
return tokens
class Murmur3AbsHasher:
def __init__(self, seed: int = 0) -> None:
try:
import mmh3
except ImportError:
            raise ValueError(
                "The mmh3 python package is not installed. Please install it with `pip install mmh3`"
            )
self.hasher = mmh3.hash
self.seed = seed
def hash(self, token: str) -> int:
return cast(int, abs(self.hasher(token, seed=self.seed)))
__all__ = [
"Bm25Tokenizer",
"DEFAULT_CHROMA_BM25_STOPWORDS",
"DEFAULT_ENGLISH_STOPWORDS",
"SnowballStemmer",
"get_english_stemmer",
"Murmur3AbsHasher",
]
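# Illustrative usage sketch (not part of the original file): tokenize a
# sentence and hash the resulting terms, as a BM25 pipeline would. Assumes
# snowballstemmer and mmh3 are installed.
if __name__ == "__main__":
    tokenizer = Bm25Tokenizer(
        stemmer=get_english_stemmer(),
        stopwords=DEFAULT_CHROMA_BM25_STOPWORDS,
        token_max_length=40,
    )
    tokens = tokenizer.tokenize("The hashers are hashing the tokens quickly!")
    hasher = Murmur3AbsHasher()
    # Stopwords ("the", "are") are dropped and the rest are stemmed,
    # e.g. "hashing" -> "hash".
    print({token: hasher.hash(token) for token in tokens})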

View File

@@ -0,0 +1,55 @@
"""
Schema Registry for Embedding Functions
This module provides a registry of all available schemas for embedding functions.
It can be used to get information about available schemas and their versions.
"""
from typing import Dict, List, Set
import os
import json
from .schema_utils import SCHEMAS_DIR
def get_available_schemas() -> List[str]:
"""
Get a list of all available schemas.
Returns:
A list of schema names (without .json extension)
"""
schemas = []
for filename in os.listdir(SCHEMAS_DIR):
if filename.endswith(".json") and filename != "base_schema.json":
schemas.append(filename[:-5]) # Remove .json extension
return schemas
def get_schema_info() -> Dict[str, Dict[str, str]]:
"""
Get information about all available schemas.
Returns:
A dictionary mapping schema names to information about the schema
"""
schema_info = {}
for schema_name in get_available_schemas():
schema_path = os.path.join(SCHEMAS_DIR, f"{schema_name}.json")
with open(schema_path, "r") as f:
schema = json.load(f)
schema_info[schema_name] = {
"version": schema.get("version", "1.0.0"),
"title": schema.get("title", ""),
"description": schema.get("description", ""),
}
return schema_info
def get_embedding_function_names() -> Set[str]:
"""
Get a set of all embedding function names that have schemas.
Returns:
A set of embedding function names
"""
return set(get_available_schemas())
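# Illustrative usage sketch (not part of the original file): enumerate the
# schemas shipped with the client. Run via `python -m ...registry` so the
# relative import above resolves.
if __name__ == "__main__":
    for name, info in get_schema_info().items():
        print(f"{name}: v{info['version']} - {info['title']}")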

View File

@@ -0,0 +1,85 @@
import json
import os
from typing import Dict, Any, cast
import jsonschema
from jsonschema import ValidationError
# Path to the schemas directory
SCHEMAS_DIR = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
),
"schemas",
"embedding_functions",
)
cached_schemas: Dict[str, Dict[str, Any]] = {}
def load_schema(schema_name: str) -> Dict[str, Any]:
"""
Load a JSON schema from the schemas directory.
Args:
schema_name: Name of the schema file (without .json extension)
Returns:
The loaded schema as a dictionary
Raises:
FileNotFoundError: If the schema file does not exist
json.JSONDecodeError: If the schema file is not valid JSON
"""
if schema_name in cached_schemas:
return cached_schemas[schema_name]
schema_path = os.path.join(SCHEMAS_DIR, f"{schema_name}.json")
with open(schema_path, "r") as f:
schema = cast(Dict[str, Any], json.load(f))
cached_schemas[schema_name] = schema
return schema
def validate_config_schema(config: Dict[str, Any], schema_name: str) -> None:
"""
Validate a configuration against a schema.
Args:
config: Configuration to validate
schema_name: Name of the schema file (without .json extension)
Raises:
ValidationError: If the configuration does not match the schema
FileNotFoundError: If the schema file does not exist
json.JSONDecodeError: If the schema file is not valid JSON
"""
schema = load_schema(schema_name)
try:
jsonschema.validate(instance=config, schema=schema)
except ValidationError as e:
# Enhance the error message with more context
error_path = "/".join(str(path) for path in e.path)
error_message = (
f"Config validation failed for schema '{schema_name}': {e.message}"
)
if error_path:
error_message += f" at path '{error_path}'"
raise ValidationError(error_message) from e
def get_schema_version(schema_name: str) -> str:
"""
Get the version of a schema.
Args:
schema_name: Name of the schema file (without .json extension)
Returns:
The schema version as a string
Raises:
FileNotFoundError: If the schema file does not exist
json.JSONDecodeError: If the schema file is not valid JSON
KeyError: If the schema does not have a version
"""
schema = load_schema(schema_name)
return cast(str, schema.get("version", "1.0.0"))

View File

@@ -0,0 +1,121 @@
from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents
from typing import List, Dict, Any
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
# Since we do dynamic imports we have to type this as Any
models: Dict[str, Any] = {}
# If you have a beefier machine, try "gtr-t5-large".
# for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
def __init__(
self,
model_name: str = "all-MiniLM-L6-v2",
device: str = "cpu",
normalize_embeddings: bool = False,
**kwargs: Any,
):
"""Initialize SentenceTransformerEmbeddingFunction.
Args:
model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2"
device (str, optional): Device used for computation, defaults to "cpu"
normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False
**kwargs: Additional arguments to pass to the SentenceTransformer model.
"""
try:
from sentence_transformers import SentenceTransformer
except ImportError:
raise ValueError(
"The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
)
self.model_name = model_name
self.device = device
self.normalize_embeddings = normalize_embeddings
for key, value in kwargs.items():
if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
raise ValueError(f"Keyword argument {key} is not a primitive type")
self.kwargs = kwargs
if model_name not in self.models:
self.models[model_name] = SentenceTransformer(
model_name_or_path=model_name, device=device, **kwargs
)
self._model = self.models[model_name]
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
embeddings = self._model.encode(
list(input),
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings,
)
return [np.array(embedding, dtype=np.float32) for embedding in embeddings]
@staticmethod
def name() -> str:
return "sentence_transformer"
def default_space(self) -> Space:
# If normalize_embeddings is True, cosine is equivalent to dot product
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model_name = config.get("model_name")
device = config.get("device")
normalize_embeddings = config.get("normalize_embeddings")
kwargs = config.get("kwargs", {})
if model_name is None or device is None or normalize_embeddings is None:
assert False, "This code should not be reached"
return SentenceTransformerEmbeddingFunction(
model_name=model_name,
device=device,
normalize_embeddings=normalize_embeddings,
**kwargs,
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"device": self.device,
"normalize_embeddings": self.normalize_embeddings,
"kwargs": self.kwargs,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# model_name is also used as the identifier for model path if stored locally.
# Users should be able to change the path if needed, so we should not validate that.
# e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
return
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "sentence_transformer")

View File

@@ -0,0 +1,90 @@
from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any
import numpy as np
class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Text2Vec model.
"""
def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
"""
Initialize the Text2VecEmbeddingFunction.
Args:
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "shibing624/text2vec-base-chinese".
"""
try:
from text2vec import SentenceModel
except ImportError:
raise ValueError(
"The text2vec python package is not installed. Please install it with `pip install text2vec`"
)
self.model_name = model_name
self._model = SentenceModel(model_name_or_path=model_name)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
            input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
# Text2Vec only works with text documents
if not all(isinstance(item, str) for item in input):
raise ValueError("Text2Vec only supports text documents, not images")
embeddings = self._model.encode(list(input), convert_to_numpy=True)
# Convert to numpy arrays
return [np.array(embedding, dtype=np.float32) for embedding in embeddings]
@staticmethod
def name() -> str:
return "text2vec"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model_name = config.get("model_name")
if model_name is None:
assert False, "This code should not be reached"
return Text2VecEmbeddingFunction(model_name=model_name)
def get_config(self) -> Dict[str, Any]:
return {"model_name": self.model_name}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
# model_name is also used as the identifier for model path if stored locally.
# Users should be able to change the path if needed, so we should not validate that.
# e.g. moving file path from /v1/my-model.bin to /v2/my-model.bin
return
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "text2vec")

View File

@@ -0,0 +1,145 @@
from chromadb.api.types import (
Embeddings,
Documents,
EmbeddingFunction,
Space,
)
from typing import List, Dict, Any, Optional
import os
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import cast
import warnings
ENDPOINT = "https://api.together.xyz/v1/embeddings"
class TogetherAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to get embeddings for a list of texts using the Together AI API.
"""
def __init__(
self,
model_name: str,
api_key: Optional[str] = None,
api_key_env_var: str = "CHROMA_TOGETHER_AI_API_KEY",
):
"""
Initialize the TogetherAIEmbeddingFunction. See the docs for supported models here:
https://docs.together.ai/docs/serverless-models#embedding-models
Args:
model_name: The name of the model to use for text embeddings.
api_key: The API key to use for the Together AI API.
api_key_env_var: The environment variable to use for the Together AI API key.
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.model_name = model_name
self.api_key = api_key
self.api_key_env_var = api_key_env_var
if not self.api_key:
self.api_key = os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {self.api_key_env_var}"
)
self._session = httpx.Client()
self._session.headers.update(
{
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"accept": "application/json",
}
)
def __call__(self, input: Documents) -> Embeddings:
"""
Embed a list of texts using the Together AI API.
Args:
input: A list of texts to embed.
"""
if not input:
raise ValueError("Input is required")
if not isinstance(input, list):
raise ValueError("Input must be a list")
if not all(isinstance(item, str) for item in input):
raise ValueError("All items in input must be strings")
response = self._session.post(
ENDPOINT,
json={"model": self.model_name, "input": input},
)
response.raise_for_status()
data = response.json()
embeddings = [item["embedding"] for item in data["data"]]
return cast(Embeddings, embeddings)
@staticmethod
def name() -> str:
return "together_ai"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
if api_key_env_var is None or model_name is None:
raise ValueError("api_key_env_var and model_name must be provided")
return TogetherAIEmbeddingFunction(
model_name=model_name, api_key_env_var=api_key_env_var
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
"""
validate_config_schema(config, "together_ai")

View File

@@ -0,0 +1,136 @@
from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Optional
import os
import numpy as np
import warnings
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the VoyageAI API.
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "voyage-large-2",
api_key_env_var: str = "CHROMA_VOYAGE_API_KEY",
input_type: Optional[str] = None,
truncation: bool = True,
):
"""
Initialize the VoyageAIEmbeddingFunction.
Args:
api_key_env_var (str, optional): Environment variable name that contains your API key for the VoyageAI API.
Defaults to "CHROMA_VOYAGE_API_KEY".
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "voyage-large-2".
api_key (str, optional): API key for the VoyageAI API. If not provided, will look for it in the environment variable.
input_type (str, optional): The type of input to use for the VoyageAI API.
Defaults to None.
truncation (bool): Whether to truncate the input text.
Defaults to True.
"""
try:
import voyageai
except ImportError:
raise ValueError(
"The voyageai python package is not installed. Please install it with `pip install voyageai`"
)
if api_key is not None:
warnings.warn(
"Direct api_key configuration will not be persisted. "
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
self.model_name = model_name
self.input_type = input_type
self.truncation = truncation
self._client = voyageai.Client(api_key=self.api_key)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents to generate embeddings for.
Returns:
Embeddings for the documents.
"""
embeddings = self._client.embed(
texts=input,
model=self.model_name,
input_type=self.input_type,
truncation=self.truncation,
)
# Convert to numpy arrays
return [
np.array(embedding, dtype=np.float32) for embedding in embeddings.embeddings
]
@staticmethod
def name() -> str:
return "voyageai"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
input_type = config.get("input_type")
truncation = config.get("truncation")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"
return VoyageAIEmbeddingFunction(
api_key_env_var=api_key_env_var,
model_name=model_name,
input_type=input_type,
truncation=truncation if truncation is not None else True,
)
def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
"input_type": self.input_type,
"truncation": self.truncation,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "voyageai")

View File

@@ -0,0 +1,18 @@
from uuid import UUID
from starlette.responses import JSONResponse
from chromadb.errors import ChromaError, InvalidUUIDError
def fastapi_json_response(error: ChromaError) -> JSONResponse:
return JSONResponse(
content={"error": error.name(), "message": error.message()},
status_code=error.code(),
)
def string_to_uuid(uuid_str: str) -> UUID:
try:
return UUID(uuid_str)
except ValueError:
raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID")

View File

@@ -0,0 +1,32 @@
from collections import OrderedDict
from typing import Any, Callable, Generic, Optional, TypeVar
K = TypeVar("K")
V = TypeVar("V")
class LRUCache(Generic[K, V]):
"""A simple LRU cache implementation, based on the OrderedDict class, which allows
for a callback to be invoked when an item is evicted from the cache."""
def __init__(self, capacity: int, callback: Optional[Callable[[K, V], Any]] = None):
self.capacity = capacity
self.cache: OrderedDict[K, V] = OrderedDict()
self.callback = callback
def get(self, key: K) -> Optional[V]:
if key not in self.cache:
return None
value = self.cache.pop(key)
self.cache[key] = value
return value
def set(self, key: K, value: V) -> None:
if key in self.cache:
self.cache.pop(key)
elif len(self.cache) == self.capacity:
evicted_key, evicted_value = self.cache.popitem(last=False)
if self.callback:
self.callback(evicted_key, evicted_value)
self.cache[key] = value
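# Illustrative usage sketch (not part of the original file): the eviction
# callback fires when capacity is exceeded, and get() refreshes recency.
if __name__ == "__main__":
    evicted = []
    cache: LRUCache[str, int] = LRUCache(
        capacity=2, callback=lambda k, v: evicted.append(k)
    )
    cache.set("a", 1)
    cache.set("b", 2)
    cache.get("a")  # "a" becomes the most recently used entry
    cache.set("c", 3)  # evicts "b", the least recently used entry
    print(evicted)  # ["b"]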

View File

@@ -0,0 +1,8 @@
def int_to_bytes(value: int) -> bytes:
    """Convert an int to a 24 byte big endian byte string"""
    return value.to_bytes(24, "big")
def bytes_to_int(data: bytes) -> int:
    """Convert a 24 byte big endian byte string to an int"""
    return int.from_bytes(data, "big")

View File

@@ -0,0 +1,74 @@
import threading
from types import TracebackType
from typing import Optional, Type
class ReadWriteLock:
"""A lock object that allows many simultaneous "read locks", but
only one "write lock." """
def __init__(self) -> None:
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0
def acquire_read(self) -> None:
"""Acquire a read lock. Blocks only if a thread has
acquired the write lock."""
self._read_ready.acquire()
try:
self._readers += 1
finally:
self._read_ready.release()
def release_read(self) -> None:
"""Release a read lock."""
self._read_ready.acquire()
try:
self._readers -= 1
if not self._readers:
self._read_ready.notify_all()
finally:
self._read_ready.release()
def acquire_write(self) -> None:
"""Acquire a write lock. Blocks until there are no
acquired read or write locks."""
self._read_ready.acquire()
while self._readers > 0:
self._read_ready.wait()
def release_write(self) -> None:
"""Release a write lock."""
self._read_ready.release()
class ReadRWLock:
def __init__(self, rwLock: ReadWriteLock):
self.rwLock = rwLock
def __enter__(self) -> None:
self.rwLock.acquire_read()
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.rwLock.release_read()
class WriteRWLock:
def __init__(self, rwLock: ReadWriteLock):
self.rwLock = rwLock
def __enter__(self) -> None:
self.rwLock.acquire_write()
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.rwLock.release_write()
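# Illustrative usage sketch (not part of the original file): many readers may
# hold the lock at once; a writer blocks until the reader count drops to zero.
if __name__ == "__main__":
    lock = ReadWriteLock()
    shared = {"value": 0}
    with WriteRWLock(lock):
        shared["value"] += 1  # exclusive access
    with ReadRWLock(lock):
        print(shared["value"])  # may run concurrently with other readers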

View File

@@ -0,0 +1,68 @@
# An implementation of https://en.wikipedia.org/wiki/Rendezvous_hashing
from typing import Callable, List, Tuple
import mmh3
import heapq
Hasher = Callable[[str, str], int]
Member = str
Members = List[str]
Key = str
def assign(
key: Key, members: Members, hasher: Hasher, replication: int
) -> List[Member]:
"""Assigns a key to a member using the rendezvous hashing algorithm
Args:
key: The key to assign
members: The list of members to assign the key to
hasher: The hashing function to use
replication: The number of members to assign the key to
Returns:
A list of members that the key has been assigned to
"""
if replication > len(members):
raise ValueError(
"Replication factor cannot be greater than the number of members"
)
if len(members) == 0:
raise ValueError("Cannot assign key to empty memberlist")
if len(members) == 1:
        # Return a fresh single-element list rather than the caller's list for safety
return [members[0]]
if key == "":
raise ValueError("Cannot assign empty key")
member_score_heap: List[Tuple[int, Member]] = []
for member in members:
        # Invert the score since heapq is a min heap
        score = -hasher(member, key)
        heapq.heappush(member_score_heap, (score, member))
output_members: List[Member] = []
for _ in range(replication):
member_and_score = heapq.heappop(member_score_heap)
output_members.append(member_and_score[1])
return output_members
def merge_hashes(x: int, y: int) -> int:
"""murmurhash3 mix 64-bit"""
acc = x ^ y
acc ^= acc >> 33
acc = (
acc * 0xFF51AFD7ED558CCD
) % 2**64 # We need to mod here to prevent python from using arbitrary size int
acc ^= acc >> 33
acc = (acc * 0xC4CEB9FE1A85EC53) % 2**64
acc ^= acc >> 33
return acc
def murmur3hasher(member: Member, key: Key) -> int:
"""Hashes the key and member using the murmur3 hashing algorithm"""
member_hash = mmh3.hash64(member, signed=False)[0]
key_hash = mmh3.hash64(key, signed=False)[0]
return merge_hashes(member_hash, key_hash)
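# Illustrative usage sketch (not part of the original file): each key lands on
# a deterministic member; removing a member only remaps the keys it owned.
# Assumes mmh3 is installed (imported above).
if __name__ == "__main__":
    members = ["node-a", "node-b", "node-c"]
    for key in ["collection-1", "collection-2", "collection-3"]:
        print(key, "->", assign(key, members, murmur3hasher, replication=1))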

View File

@@ -0,0 +1,125 @@
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
    import pandas as pd
from chromadb.api.types import QueryResult, GetResult
def _transform_embeddings(
embeddings: Optional[List[np.ndarray]], # type: ignore
) -> Optional[Union[List[List[float]], List[np.ndarray]]]: # type: ignore
"""
Transform embeddings from numpy arrays to lists of floats.
This is a shared helper function to avoid duplicating the transformation logic.
"""
    if embeddings is None or len(embeddings) == 0:
        return embeddings
    return (
        [emb.tolist() for emb in embeddings]
        if isinstance(embeddings[0], np.ndarray)
        else embeddings
    )
def _add_query_fields(
data_dict: Dict[str, Any],
query_result: QueryResult,
query_idx: int,
) -> None:
"""
Helper function to add fields from a query result to a dictionary.
Handles the nested array structure specific to query results.
Args:
data_dict: Dictionary to add the fields to
query_result: QueryResult containing the data
query_idx: Index of the current query being processed
"""
for field in query_result["included"]:
value = query_result.get(field)
if value is not None:
key = field.rstrip("s") # DF naming convention is not plural
if field == "embeddings":
value = _transform_embeddings(value) # type: ignore
if isinstance(value, list) and len(value) > 0:
value = value[query_idx] # type: ignore
data_dict[key] = value
def _add_get_fields(
data_dict: Dict[str, Any],
get_result: GetResult,
) -> None:
"""
Helper function to add fields from a get result to a dictionary.
Handles the flat array structure specific to get results.
Args:
data_dict: Dictionary to add the fields to
get_result: GetResult containing the data
"""
for field in get_result["included"]:
value = get_result.get(field)
if value is not None:
key = field.rstrip("s") # DF naming convention is not plural
if field == "embeddings":
value = _transform_embeddings(value) # type: ignore
data_dict[key] = value
def query_result_to_dfs(query_result: QueryResult) -> List["pd.DataFrame"]:
"""
Function to convert QueryResult to list of DataFrames.
Handles the nested array structure specific to query results.
Column order is defined by the order of the fields in the QueryResult.
Args:
query_result: QueryResult to convert to DataFrames.
Returns:
List of DataFrames.
"""
try:
import pandas as pd
except ImportError:
raise ImportError("pandas is required to convert query results to DataFrames.")
dfs = []
num_queries = len(query_result["ids"])
for i in range(num_queries):
data_for_df: Dict[str, Any] = {}
data_for_df["id"] = query_result["ids"][i]
_add_query_fields(data_for_df, query_result, i)
df = pd.DataFrame(data_for_df)
df.set_index("id", inplace=True)
dfs.append(df)
return dfs
def get_result_to_df(get_result: GetResult) -> "pd.DataFrame":
"""
Function to convert GetResult to a DataFrame.
Handles the flat array structure specific to get results.
Column order is defined by the order of the fields in the GetResult.
Args:
get_result: GetResult to convert to a DataFrame.
Returns:
DataFrame.
"""
try:
import pandas as pd
except ImportError:
raise ImportError("pandas is required to convert get results to a DataFrame.")
data_for_df: Dict[str, Any] = {}
data_for_df["id"] = get_result["ids"]
_add_get_fields(data_for_df, get_result)
df = pd.DataFrame(data_for_df)
df.set_index("id", inplace=True)
return df
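# Illustrative usage sketch (not part of the original file): a hand-built dict
# shaped like a single-query QueryResult. Assumes pandas is installed; the
# "included" list controls which columns appear in the DataFrame.
if __name__ == "__main__":
    fake_result = {
        "ids": [["id1", "id2"]],
        "documents": [["doc one", "doc two"]],
        "distances": [[0.1, 0.4]],
        "included": ["documents", "distances"],
    }
    df = query_result_to_dfs(fake_result)[0]  # type: ignore[arg-type]
    print(df)  # indexed by id, with "document" and "distance" columns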

View File

@@ -0,0 +1,36 @@
from typing import List, Tuple
from chromadb.base_types import SparseVector
def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVector:
"""Normalize and create a SparseVector by sorting indices and values together.
This function takes raw indices and values (which may be unsorted or have duplicates)
and returns a properly constructed SparseVector with sorted indices.
Args:
indices: List of dimension indices (may be unsorted)
values: List of values corresponding to each index
Returns:
SparseVector with indices sorted in ascending order
Raises:
ValueError: If indices and values have different lengths
ValueError: If there are duplicate indices (after sorting)
ValueError: If indices are negative
ValueError: If values are not numeric
"""
    if len(indices) != len(values):
        raise ValueError("indices and values must have the same length")
    if not indices:
        return SparseVector(indices=[], values=[])
# Sort indices and values together by index
sorted_pairs = sorted(zip(indices, values), key=lambda x: x[0])
sorted_indices, sorted_values = zip(*sorted_pairs)
# Create SparseVector which will validate:
# - indices are sorted
# - no duplicate indices
# - indices are non-negative
# - values are numeric
return SparseVector(indices=list(sorted_indices), values=list(sorted_values))
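# Illustrative usage sketch (not part of the original file): unsorted input is
# reordered by index before SparseVector's own validation runs.
if __name__ == "__main__":
    sv = normalize_sparse_vector(indices=[5, 1, 3], values=[0.5, 0.1, 0.3])
    print(sv)  # indices [1, 3, 5] paired with values [0.1, 0.3, 0.5]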