Add adaptation for the orbital reconnaissance scenario

2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -11,6 +11,7 @@ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Cmek,
CollectionMetadata,
UpdateMetadata,
Documents,
@@ -58,6 +59,7 @@ import os
# Re-export types from chromadb.types
__all__ = [
"Cmek",
"Collection",
"Metadata",
"Metadatas",
@@ -105,7 +107,7 @@ logger = logging.getLogger(__name__)
__settings = Settings()
__version__ = "1.3.4"
__version__ = "1.4.0"
# Workaround to deal with Colab's old sqlite3 version

View File

@@ -33,10 +33,14 @@ from chromadb.execution.expression import ( # noqa: F401, F403
Sub,
Sum,
Val,
Aggregate,
MinK,
MaxK,
GroupBy,
)
from abc import ABC, abstractmethod
from typing import Sequence, Optional, List, Dict, Any
from typing import Sequence, Optional, List, Dict, Any, Tuple
from uuid import UUID
from overrides import override
@@ -824,7 +828,7 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to a collection.
Args:
@@ -837,14 +841,40 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
database: The database name
Returns:
AttachedFunction: Object representing the attached function
Tuple of (AttachedFunction, created) where created is True if newly created,
False if already existed (idempotent request)
"""
pass
@abstractmethod
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Get an attached function by name for a specific collection.
Args:
name: Name of the attached function
input_collection_id: The collection ID
tenant: The tenant name
database: The database name
Returns:
AttachedFunction: The attached function object
Raises:
NotFoundError: If the attached function doesn't exist
"""
pass
@abstractmethod
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
@@ -852,7 +882,8 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
"""Detach a function and prevent any further runs.
Args:
attached_function_id: ID of the attached function to remove
name: Name of the attached function to remove
input_collection_id: ID of the input collection
delete_output: Whether to also delete the output collection
tenant: The tenant name
database: The database name
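
A usage sketch of the reworked attached-function API above (assuming a running Chroma server, an HttpClient, and the illustrative names "docs", "stats_fn", and "docs_stats"):

import chromadb
from chromadb.api.functions import STATISTICS_FUNCTION

client = chromadb.HttpClient(host="localhost", port=8000)
collection = client.get_or_create_collection("docs")

# attach_function now returns (AttachedFunction, created); repeating the same
# call is idempotent and simply reports created=False the second time.
attached_fn, created = collection.attach_function(
    function=STATISTICS_FUNCTION,
    name="stats_fn",
    output_collection="docs_stats",
)
if not created:
    print("function was already attached")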

View File

@@ -2,7 +2,7 @@ import asyncio
from uuid import UUID
import urllib.parse
import orjson
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
from typing import Any, Mapping, Optional, cast, Tuple, Sequence, Dict, List
import logging
import httpx
from overrides import override
@@ -130,16 +130,23 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
+ " (https://github.com/chroma-core/chroma)"
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._clients[loop_hash] = httpx.AsyncClient(
timeout=None,
headers=headers,
verify=self._settings.chroma_server_ssl_verify or False,
limits=limits,
limits=self.http_limits,
)
return self._clients[loop_hash]
@override
def get_request_headers(self) -> Mapping[str, str]:
return dict(self._get_client().headers)
@override
def get_api_url(self) -> str:
return self._api_url
async def _make_request(
self, method: str, path: str, **kwargs: Dict[str, Any]
) -> Any:
@@ -527,7 +534,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
@@ -723,7 +730,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,

View File

@@ -1,19 +1,51 @@
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, Mapping, Optional, TypeVar
from urllib.parse import quote, urlparse, urlunparse
import logging
import orjson as json
import httpx
import chromadb.errors as errors
from chromadb.config import Settings
from chromadb.config import Component, Settings, System
logger = logging.getLogger(__name__)
class BaseHTTPClient:
# Inherits from Component so that it can be constructed from a System,
# which lets it build httpx connection limits from that System's settings.
class BaseHTTPClient(Component):
_settings: Settings
pre_flight_checks: Any = None
keepalive_secs: int = 40
DEFAULT_KEEPALIVE_SECS: float = 40.0
def __init__(self, system: System):
super().__init__(system)
self._settings = system.settings
keepalive_setting = self._settings.chroma_http_keepalive_secs
self.keepalive_secs: Optional[float] = (
keepalive_setting
if keepalive_setting is not None
else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
)
self._http_limits = self._build_limits()
def _build_limits(self) -> httpx.Limits:
limit_kwargs: Dict[str, Any] = {}
if self.keepalive_secs is not None:
limit_kwargs["keepalive_expiry"] = self.keepalive_secs
max_connections = self._settings.chroma_http_max_connections
if max_connections is not None:
limit_kwargs["max_connections"] = max_connections
max_keepalive_connections = self._settings.chroma_http_max_keepalive_connections
if max_keepalive_connections is not None:
limit_kwargs["max_keepalive_connections"] = max_keepalive_connections
return httpx.Limits(**limit_kwargs)
@property
def http_limits(self) -> httpx.Limits:
return self._http_limits
@staticmethod
def _validate_host(host: str) -> None:
@@ -103,3 +135,11 @@ class BaseHTTPClient:
if trace_id:
raise Exception(f"{resp.text} (trace ID: {trace_id})")
raise (Exception(resp.text))
def get_request_headers(self) -> Mapping[str, str]:
"""Return headers used for HTTP requests."""
return {}
def get_api_url(self) -> str:
"""Return the API URL for this client."""
return ""

View File

@@ -1,6 +1,6 @@
import orjson
import logging
from typing import Any, Dict, Optional, cast, Tuple, List
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
from typing import Sequence
from uuid import UUID
import httpx
@@ -79,8 +79,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
default_api_path=system.settings.chroma_server_api_default_path,
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._session = httpx.Client(timeout=None, limits=limits)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(
timeout=None,
limits=self.http_limits,
verify=self._settings.chroma_server_ssl_verify,
)
else:
self._session = httpx.Client(timeout=None, limits=self.http_limits)
self._header = system.settings.chroma_server_headers or {}
self._header["Content-Type"] = "application/json"
@@ -90,8 +96,6 @@ class FastAPI(BaseHTTPClient, ServerAPI):
+ " (https://github.com/chroma-core/chroma)"
)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
if self._header is not None:
self._session.headers.update(self._header)
@@ -101,6 +105,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
for header, value in _headers.items():
self._session.headers[header] = value.get_secret_value()
@override
def get_request_headers(self) -> Mapping[str, str]:
return dict(self._session.headers)
@override
def get_api_url(self) -> str:
return self._api_url
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
# If the request has json in kwargs, use orjson to serialize it,
# remove it from kwargs, and add it to the content parameter
@@ -492,7 +504,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
@@ -700,7 +712,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
@@ -761,7 +773,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to a collection."""
resp_json = self._make_request(
"post",
@@ -774,23 +786,56 @@ class FastAPI(BaseHTTPClient, ServerAPI):
},
)
return AttachedFunction(
attached_function = AttachedFunction(
client=self,
id=UUID(resp_json["attached_function"]["id"]),
name=resp_json["attached_function"]["name"],
function_id=resp_json["attached_function"]["function_id"],
function_name=resp_json["attached_function"]["function_name"],
input_collection_id=input_collection_id,
output_collection=output_collection,
params=params,
tenant=tenant,
database=database,
)
created = resp_json.get(
"created", True
) # Default to True for backwards compatibility
return (attached_function, created)
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Get an attached function by name for a specific collection."""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/{name}",
)
af = resp_json["attached_function"]
return AttachedFunction(
client=self,
id=UUID(af["id"]),
name=af["name"],
function_name=af["function_name"],
input_collection_id=input_collection_id,
output_collection=af["output_collection"],
params=af.get("params"),
tenant=tenant,
database=database,
)
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
@@ -798,7 +843,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
"""Detach a function and prevent any further runs."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
json={
"delete_output": delete_output,
},

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Optional, Dict, Any
from uuid import UUID
import json
if TYPE_CHECKING:
from chromadb.api import ServerAPI # noqa: F401
@@ -13,7 +14,7 @@ class AttachedFunction:
client: "ServerAPI",
id: UUID,
name: str,
function_id: str,
function_name: str,
input_collection_id: UUID,
output_collection: str,
params: Optional[Dict[str, Any]],
@@ -26,7 +27,7 @@ class AttachedFunction:
client: The API client
id: Unique identifier for this attached function
name: Name of this attached function instance
function_id: The function identifier (e.g., "record_counter")
function_name: The function name (e.g., "record_counter", "statistics")
input_collection_id: ID of the input collection
output_collection: Name of the output collection
params: Function-specific parameters
@@ -36,7 +37,7 @@ class AttachedFunction:
self._client = client
self._id = id
self._name = name
self._function_id = function_id
self._function_name = function_name
self._input_collection_id = input_collection_id
self._output_collection = output_collection
self._params = params
@@ -54,9 +55,9 @@ class AttachedFunction:
return self._name
@property
def function_id(self) -> str:
"""The function identifier."""
return self._function_id
def function_name(self) -> str:
"""The function name."""
return self._function_name
@property
def input_collection_id(self) -> UUID:
@@ -73,29 +74,69 @@ class AttachedFunction:
"""The function parameters."""
return self._params
def detach(self, delete_output_collection: bool = False) -> bool:
"""Detach this function and prevent any further runs.
@staticmethod
def _normalize_params(params: Optional[Any]) -> Dict[str, Any]:
"""Normalize params to a consistent dict format.
Args:
delete_output_collection: Whether to also delete the output collection. Defaults to False.
Returns:
bool: True if successful
Example:
>>> success = attached_fn.detach(delete_output_collection=True)
Handles None, empty strings, JSON strings, and dicts.
"""
return self._client.detach_function(
attached_function_id=self._id,
delete_output=delete_output_collection,
tenant=self._tenant,
database=self._database,
)
if params is None:
return {}
if isinstance(params, str):
try:
result = json.loads(params) if params else {}
return result if isinstance(result, dict) else {}
except json.JSONDecodeError:
return {}
if isinstance(params, dict):
return params
return {}
def __repr__(self) -> str:
return (
f"AttachedFunction(id={self._id}, name='{self._name}', "
f"function_id='{self._function_id}', "
f"function_name='{self._function_name}', "
f"input_collection_id={self._input_collection_id}, "
f"output_collection='{self._output_collection}')"
)
def __eq__(self, other: object) -> bool:
"""Compare two AttachedFunction objects for equality."""
if not isinstance(other, AttachedFunction):
return False
# Normalize params: handle None, {}, and JSON strings
self_params = self._normalize_params(self._params)
other_params = self._normalize_params(other._params)
return (
self._id == other._id
and self._name == other._name
and self._function_name == other._function_name
and self._input_collection_id == other._input_collection_id
and self._output_collection == other._output_collection
and self_params == other_params
and self._tenant == other._tenant
and self._database == other._database
)
def __hash__(self) -> int:
"""Return hash of the AttachedFunction."""
# Normalize params using the same logic as __eq__
normalized_params = self._normalize_params(self._params)
params_tuple = (
tuple(sorted(normalized_params.items())) if normalized_params else ()
)
return hash(
(
self._id,
self._name,
self._function_name,
self._input_collection_id,
self._output_collection,
params_tuple,
self._tenant,
self._database,
)
)
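
A quick sketch of the params normalization that __eq__ and __hash__ rely on: a JSON string, an equivalent dict, and empty/None values all normalize consistently.

from chromadb.api.models.AttachedFunction import AttachedFunction

assert AttachedFunction._normalize_params(None) == {}
assert AttachedFunction._normalize_params("") == {}
assert AttachedFunction._normalize_params('{"threshold": 100}') == {"threshold": 100}
assert AttachedFunction._normalize_params({"threshold": 100}) == {"threshold": 100}
assert AttachedFunction._normalize_params("not json") == {}  # unparseable -> {}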

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any, Tuple
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
@@ -25,6 +25,8 @@ from chromadb.execution.expression.plan import Search
import logging
from chromadb.api.functions import Function
if TYPE_CHECKING:
from chromadb.api.models.AttachedFunction import AttachedFunction
@@ -500,30 +502,36 @@ class Collection(CollectionCommon["ServerAPI"]):
def attach_function(
self,
function_id: str,
function: Function,
name: str,
output_collection: str,
params: Optional[Dict[str, Any]] = None,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to this collection.
Args:
function_id: Built-in function identifier (e.g., "record_counter")
function: A Function enum value (e.g., STATISTICS_FUNCTION, RECORD_COUNTER_FUNCTION)
name: Unique name for this attached function
output_collection: Name of the collection where function output will be stored
params: Optional dictionary with function-specific parameters
Returns:
AttachedFunction: Object representing the attached function
Tuple of (AttachedFunction, created) where created is True if newly created,
False if already existed (idempotent request)
Example:
>>> from chromadb.api.functions import STATISTICS_FUNCTION
>>> attached_fn, created = collection.attach_function(
... function_id="record_counter",
... function=STATISTICS_FUNCTION,
... name="mycoll_stats_fn",
... output_collection="mycoll_stats",
... params={"threshold": 100}
... )
>>> if created:
... print("New function attached")
... else:
... print("Function already existed")
"""
function_id = function.value if isinstance(function, Function) else function
return self._client.attach_function(
function_id=function_id,
name=name,
@@ -533,3 +541,47 @@ class Collection(CollectionCommon["ServerAPI"]):
tenant=self.tenant,
database=self.database,
)
def get_attached_function(self, name: str) -> "AttachedFunction":
"""Get an attached function by name for this collection.
Args:
name: Name of the attached function
Returns:
AttachedFunction: The attached function object
Raises:
NotFoundError: If the attached function doesn't exist
"""
return self._client.get_attached_function(
name=name,
input_collection_id=self.id,
tenant=self.tenant,
database=self.database,
)
def detach_function(
self,
name: str,
delete_output_collection: bool = False,
) -> bool:
"""Detach a function from this collection.
Args:
name: The name of the attached function
delete_output_collection: Whether to also delete the output collection. Defaults to False.
Returns:
bool: True if successful
Example:
>>> success = collection.detach_function("my_function", delete_output_collection=True)
"""
return self._client.detach_function(
name=name,
input_collection_id=self.id,
delete_output=delete_output_collection,
tenant=self.tenant,
database=self.database,
)
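
A sketch of the collection-level lookup and detach flow added above; `collection` is assumed to come from an HttpClient, and "stats_fn" is an illustrative name:

from chromadb.errors import NotFoundError

try:
    attached_fn = collection.get_attached_function("stats_fn")
    print(attached_fn.name, attached_fn.function_name)
    collection.detach_function("stats_fn", delete_output_collection=True)
except NotFoundError:
    print("no function named 'stats_fn' is attached to this collection")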

View File

@@ -1021,6 +1021,7 @@ class CollectionCommon(Generic[ClientT]):
return Search(
where=search._where,
rank=embedded_rank,
group_by=search._group_by,
limit=search._limit,
select=search._select,
)

View File

@@ -49,7 +49,7 @@ from chromadb.execution.expression.plan import Search
import chromadb_rust_bindings
from typing import Optional, Sequence, List, Dict, Any
from typing import Optional, Sequence, List, Dict, Any, Tuple
from overrides import override
from uuid import UUID
import json
@@ -613,6 +613,20 @@ class RustBindingsAPI(ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple["AttachedFunction", bool]:
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
"The Rust bindings (embedded mode) do not support attached function operations."
)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
@@ -623,7 +637,8 @@ class RustBindingsAPI(ServerAPI):
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,

View File

@@ -75,6 +75,7 @@ from typing import (
Dict,
Callable,
TypeVar,
Tuple,
)
from overrides import override
from uuid import UUID, uuid4
@@ -922,6 +923,20 @@ class SegmentAPI(ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple["AttachedFunction", bool]:
"""Attached functions are not supported in the Segment API (local embedded mode)."""
raise NotImplementedError(
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
"The Segment API (embedded mode) does not support attached function operations."
)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attached functions are not supported in the Segment API (local embedded mode)."""
raise NotImplementedError(
@@ -932,7 +947,8 @@ class SegmentAPI(ServerAPI):
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,

View File

@@ -1,11 +1,14 @@
from typing import ClassVar, Dict
from typing import ClassVar, Dict, Optional
import logging
import uuid
from chromadb.api import ServerAPI
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.config import Settings, System
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import ClientStartEvent
logger = logging.getLogger(__name__)
class SharedSystemClient:
_identifier_to_system: ClassVar[Dict[str, System]] = {}
@@ -94,3 +97,59 @@ class SharedSystemClient:
def _submit_client_start_event(self) -> None:
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ClientStartEvent())
@staticmethod
def get_chroma_cloud_api_key_from_clients() -> Optional[str]:
"""
Try to extract api key from existing client instances by checking httpx session headers.
Requirements to pull api key:
- must be a BaseHTTPClient instance (ignore RustBindingsAPI and SegmentAPI)
- must have "api.trychroma.com" or "gcp.trychroma.com" in the _api_url (ignore local/self-hosted instances)
- must have "x-chroma-token" or "X-Chroma-Token" in the headers
Returns:
The first api key found, or None if no client instances have api keys set.
"""
api_keys: list[str] = []
systems_snapshot = list(SharedSystemClient._identifier_to_system.values())
for system in systems_snapshot:
try:
server_api = system.instance(ServerAPI)
if not isinstance(server_api, BaseHTTPClient):
# RustBindingsAPI and SegmentAPI don't have HTTP headers
continue
# Only pull api key if the url contains the chroma cloud url
api_url = server_api.get_api_url()
if (
"api.trychroma.com" not in api_url
and "gcp.trychroma.com" not in api_url
):
continue
headers = server_api.get_request_headers()
api_key = None
for key, value in headers.items():
if key.lower() == "x-chroma-token":
api_key = value
break
if api_key:
api_keys.append(api_key)
except Exception:
# If we can't access the ServerAPI instance, continue to the next
continue
if not api_keys:
return None
# log if multiple viable api keys found
if len(api_keys) > 1:
logger.info(
f"Multiple Chroma Cloud clients found, using API key starting with {api_keys[0][:8]}..."
)
return api_keys[0]
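
A usage sketch for the helper above, assuming the usual module path for SharedSystemClient; it only yields a key when some live HTTP client points at a *.trychroma.com URL and carries an x-chroma-token header:

from chromadb.api.shared_system_client import SharedSystemClient

api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
if api_key is None:
    print("no Chroma Cloud client with an API key is active in this process")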

View File

@@ -12,6 +12,7 @@ from typing import (
get_args,
TYPE_CHECKING,
Final,
Type,
)
from copy import deepcopy
from typing_extensions import TypeAlias
@@ -20,7 +21,8 @@ from numpy.typing import NDArray
import numpy as np
import warnings
from typing_extensions import TypedDict, Protocol, runtime_checkable
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator
from pydantic_core import PydanticCustomError
import chromadb.errors as errors
from chromadb.base_types import (
@@ -52,6 +54,8 @@ import pybase64
from functools import lru_cache
import struct
import math
import re
from enum import Enum
# Re-export types from chromadb.types
__all__ = [
@@ -112,6 +116,9 @@ __all__ = [
"EMBEDDING_KEY",
"TYPE_KEY",
"SPARSE_VECTOR_TYPE_VALUE",
# CMEK Support
"Cmek",
"CmekProvider",
# Space type
"Space",
# Embedding Functions
@@ -213,7 +220,7 @@ def optional_base64_strings_to_embeddings(
def normalize_embeddings(
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]]
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]],
) -> Optional[Embeddings]:
if target is None:
return None
@@ -448,7 +455,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
)
zero_lengths = [
key for key, lst in record_set.items() if lst is not None and len(lst) == 0 # type: ignore[arg-type]
key
for key, lst in record_set.items()
if lst is not None and len(lst) == 0 # type: ignore[arg-type]
]
if zero_lengths:
@@ -456,7 +465,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
if len(set(lengths)) > 1:
error_str = ", ".join(
f"{key}: {len(lst)}" for key, lst in record_set.items() if lst is not None # type: ignore[arg-type]
f"{key}: {len(lst)}"
for key, lst in record_set.items()
if lst is not None # type: ignore[arg-type]
)
raise ValueError(f"Unequal lengths for fields: {error_str}")
@@ -756,8 +767,7 @@ class EmbeddingFunction(Protocol[D]):
"""
@abstractmethod
def __call__(self, input: D) -> Embeddings:
...
def __call__(self, input: D) -> Embeddings: ...
def embed_query(self, input: D) -> Embeddings:
"""
@@ -941,8 +951,7 @@ def validate_embedding_function(
class DataLoader(Protocol[L]):
def __call__(self, uris: URIs) -> L:
...
def __call__(self, uris: URIs) -> L: ...
def validate_ids(ids: IDs) -> IDs:
@@ -1063,7 +1072,7 @@ def serialize_metadata(metadata: Optional[Metadata]) -> Optional[Dict[str, Any]]
def deserialize_metadata(
metadata: Optional[Dict[str, Any]]
metadata: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""Deserialize metadata from transport, converting dicts with #type=sparse_vector to dataclass instances.
@@ -1399,8 +1408,7 @@ class SparseEmbeddingFunction(Protocol[D]):
"""
@abstractmethod
def __call__(self, input: D) -> SparseVectors:
...
def __call__(self, input: D) -> SparseVectors: ...
def embed_query(self, input: D) -> SparseVectors:
"""
@@ -1493,15 +1501,57 @@ def validate_sparse_embedding_function(
# Index Configuration Types for Collection Schema
def _create_extra_fields_validator(valid_fields: list[str]) -> Any:
"""Create a model validator that provides helpful error messages for invalid fields."""
@model_validator(mode="before")
def validate_extra_fields(cls: Type[BaseModel], data: Any) -> Any:
if isinstance(data, dict):
invalid_fields = [k for k in data.keys() if k not in valid_fields]
if invalid_fields:
invalid_fields_str = ", ".join(f"'{f}'" for f in invalid_fields)
class_name = cls.__name__
# Create a clear, actionable error message
if len(invalid_fields) == 1:
msg = (
f"'{invalid_fields[0]}' is not a valid field for {class_name}. "
)
else:
msg = f"Invalid fields for {class_name}: {invalid_fields_str}. "
raise PydanticCustomError(
"invalid_field",
msg,
{"invalid_fields": invalid_fields},
)
return data
return validate_extra_fields
class FtsIndexConfig(BaseModel):
"""Configuration for Full-Text Search index. No parameters required."""
model_config = {"extra": "forbid"}
pass
class HnswIndexConfig(BaseModel):
"""Configuration for HNSW vector index."""
_validate_extra_fields = _create_extra_fields_validator(
[
"ef_construction",
"max_neighbors",
"ef_search",
"num_threads",
"batch_size",
"sync_threshold",
"resize_factor",
]
)
ef_construction: Optional[int] = None
max_neighbors: Optional[int] = None
ef_search: Optional[int] = None
@@ -1514,6 +1564,27 @@ class HnswIndexConfig(BaseModel):
class SpannIndexConfig(BaseModel):
"""Configuration for SPANN vector index."""
_validate_extra_fields = _create_extra_fields_validator(
[
"search_nprobe",
"search_rng_factor",
"search_rng_epsilon",
"nreplica_count",
"write_nprobe",
"write_rng_factor",
"write_rng_epsilon",
"split_threshold",
"num_samples_kmeans",
"initial_lambda",
"reassign_neighbor_count",
"merge_threshold",
"num_centers_to_merge_to",
"ef_construction",
"ef_search",
"max_neighbors",
]
)
search_nprobe: Optional[int] = None
write_nprobe: Optional[int] = None
ef_construction: Optional[int] = None
@@ -1527,10 +1598,13 @@ class SpannIndexConfig(BaseModel):
class VectorIndexConfig(BaseModel):
"""Configuration for vector index with space, embedding function, and algorithm config."""
model_config = {"arbitrary_types_allowed": True}
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
space: Optional[Space] = None
embedding_function: Optional[Any] = DefaultEmbeddingFunction()
source_key: Optional[str] = None # key to source the vector from (accepts str or Key)
source_key: Optional[str] = (
None # key to source the vector from (accepts str or Key)
)
hnsw: Optional[HnswIndexConfig] = None
spann: Optional[SpannIndexConfig] = None
@@ -1542,6 +1616,7 @@ class VectorIndexConfig(BaseModel):
return None
# Import Key at runtime to avoid circular import
from chromadb.execution.expression.operator import Key as KeyType
if isinstance(v, KeyType):
v = v.name # Extract string from Key
elif isinstance(v, str):
@@ -1574,10 +1649,13 @@ class VectorIndexConfig(BaseModel):
class SparseVectorIndexConfig(BaseModel):
"""Configuration for sparse vector index."""
model_config = {"arbitrary_types_allowed": True}
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
# TODO(Sanket): Change this to the appropriate sparse ef and use a default here.
embedding_function: Optional[Any] = None
source_key: Optional[str] = None # key to source the sparse vector from (accepts str or Key)
source_key: Optional[str] = (
None # key to source the sparse vector from (accepts str or Key)
)
bm25: Optional[bool] = None
@field_validator("source_key", mode="before")
@@ -1588,6 +1666,7 @@ class SparseVectorIndexConfig(BaseModel):
return None
# Import Key at runtime to avoid circular import
from chromadb.execution.expression.operator import Key as KeyType
if isinstance(v, KeyType):
v = v.name # Extract string from Key
elif isinstance(v, str):
@@ -1622,24 +1701,32 @@ class SparseVectorIndexConfig(BaseModel):
class StringInvertedIndexConfig(BaseModel):
"""Configuration for string inverted index."""
model_config = {"extra": "forbid"}
pass
class IntInvertedIndexConfig(BaseModel):
"""Configuration for integer inverted index."""
model_config = {"extra": "forbid"}
pass
class FloatInvertedIndexConfig(BaseModel):
"""Configuration for float inverted index."""
model_config = {"extra": "forbid"}
pass
class BoolInvertedIndexConfig(BaseModel):
"""Configuration for boolean inverted index."""
model_config = {"extra": "forbid"}
pass
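
With the extra-fields validator and "extra": "forbid" configs above, misspelled index parameters now fail fast instead of being silently ignored; a small sketch:

from pydantic import ValidationError
from chromadb.api.types import HnswIndexConfig

HnswIndexConfig(ef_construction=128, max_neighbors=32)  # valid fields pass

try:
    HnswIndexConfig(ef_constructionn=128)  # typo in field name
except ValidationError as e:
    print(e)  # "'ef_constructionn' is not a valid field for HnswIndexConfig."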
@@ -1683,6 +1770,148 @@ TYPE_KEY: Final[str] = "#type"
SPARSE_VECTOR_TYPE_VALUE: Final[str] = "sparse_vector"
# ============================================================================
# CMEK (Customer-Managed Encryption Key) Support
# ============================================================================
class CmekProvider(str, Enum):
"""Supported cloud providers for customer-managed encryption keys.
Currently only Google Cloud Platform (GCP) is supported.
"""
GCP = "gcp"
# Regex pattern for validating GCP KMS resource names
_CMEK_GCP_PATTERN = re.compile(r"^projects/.+/locations/.+/keyRings/.+/cryptoKeys/.+$")
@dataclass
class Cmek:
"""Customer-managed encryption key (CMEK) for collection data encryption.
CMEK allows you to use your own encryption keys managed by cloud providers'
key management services (KMS) instead of default provider-managed keys. This
gives you greater control over key lifecycle, access policies, and audit logging.
Attributes:
provider: The cloud provider (currently only 'gcp' supported)
resource: The provider-specific resource identifier for the encryption key
Example:
>>> # Create a CMEK for GCP
>>> cmek = Cmek.gcp(
... "projects/my-project/locations/us-central1/"
... "keyRings/my-ring/cryptoKeys/my-key"
... )
>>>
>>> # Validate the resource name format
>>> if cmek.validate_pattern():
... print("Valid CMEK format")
Note:
Pattern validation only checks format correctness. It does not verify
that the key exists or is accessible. Key permissions and access control
must be configured separately in your cloud provider's console.
"""
provider: CmekProvider
resource: str
@classmethod
def gcp(cls, resource: str) -> "Cmek":
"""Create a CMEK instance for Google Cloud Platform.
Args:
resource: GCP Cloud KMS resource name in the format:
projects/{project-id}/locations/{location}/keyRings/{key-ring}/cryptoKeys/{key-name}
Example: "projects/my-project/locations/us-central1/keyRings/my-ring/cryptoKeys/my-key"
Returns:
Cmek: A new CMEK instance configured for GCP
Raises:
ValueError: If the resource format is invalid (when validate_pattern() is called)
Example:
>>> cmek = Cmek.gcp(
... "projects/my-project/locations/us-central1/"
... "keyRings/my-ring/cryptoKeys/my-key"
... )
"""
return cls(provider=CmekProvider.GCP, resource=resource)
def validate_pattern(self) -> bool:
"""Validate the CMEK resource name format.
Validates that the resource name matches the expected pattern for the
provider. This is a format check only and does not verify that the key
exists or that you have access to it.
For GCP, the expected format is:
projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}
Returns:
bool: True if the resource name format is valid, False otherwise
Example:
>>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
>>> cmek.validate_pattern() # Returns True
>>>
>>> bad_cmek = Cmek.gcp("invalid-format")
>>> bad_cmek.validate_pattern() # Returns False
"""
if self.provider == CmekProvider.GCP:
return _CMEK_GCP_PATTERN.match(self.resource) is not None
return False
def to_dict(self) -> Dict[str, Any]:
"""Serialize CMEK to dictionary format for API transport.
Returns:
Dict containing the provider variant and resource identifier.
Example:
>>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
>>> cmek.to_dict()
{'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
"""
if self.provider == CmekProvider.GCP:
return {"gcp": self.resource}
# Unreachable with current providers, but future-proof
raise ValueError(f"Unknown CMEK provider: {self.provider}")
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Cmek":
"""Deserialize CMEK from dictionary format.
Args:
data: Dictionary containing provider variant and resource.
Format matches Rust serde enum serialization with snake_case.
Returns:
Cmek: Deserialized CMEK instance
Raises:
ValueError: If the provider is unsupported or data is malformed
Example:
>>> data = {'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
>>> cmek = Cmek.from_dict(data)
"""
if "gcp" in data:
resource = data["gcp"]
if not isinstance(resource, str):
raise ValueError(
f"CMEK gcp resource must be a string, got: {type(resource)}"
)
return cls.gcp(resource)
raise ValueError(f"Unsupported or missing CMEK provider in data: {data}")
# Index Type Classes
@@ -1774,24 +2003,40 @@ class ValueTypes:
@dataclass
class Schema:
"""Collection schema for configuring indexes and encryption.
The schema controls how data is indexed and can optionally specify
customer-managed encryption keys (CMEK) for data at rest.
Attributes:
defaults: Default index configurations for each value type
keys: Key-specific index overrides
cmek: Optional customer-managed encryption key for collection data
"""
defaults: ValueTypes
keys: Dict[str, ValueTypes]
cmek: Optional[Cmek] = None
def __init__(self) -> None:
# Initialize the dataclass fields first
self.defaults = ValueTypes()
self.keys: Dict[str, ValueTypes] = {}
self.cmek: Optional[Cmek] = None
# Populate with sensible defaults automatically
self._initialize_defaults()
self._initialize_keys()
def create_index(
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
self,
config: Optional[IndexConfig] = None,
key: Optional[Union[str, "Key"]] = None,
) -> "Schema":
"""Create an index configuration."""
# Convert Key to string if provided
from chromadb.execution.expression.operator import Key as KeyType
if key is not None and isinstance(key, KeyType):
key = key.name
@@ -1869,11 +2114,14 @@ class Schema:
return self
def delete_index(
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
self,
config: Optional[IndexConfig] = None,
key: Optional[Union[str, "Key"]] = None,
) -> "Schema":
"""Disable an index configuration (set enabled=False)."""
# Convert Key to string if provided
from chromadb.execution.expression.operator import Key as KeyType
if key is not None and isinstance(key, KeyType):
key = key.name
@@ -1930,6 +2178,23 @@ class Schema:
return self
def set_cmek(self, cmek: Optional[Cmek]) -> "Schema":
"""Set customer-managed encryption key for the collection (fluent interface).
Args:
cmek: Customer-managed encryption key configuration, or None to remove CMEK
Returns:
Self for method chaining
Example:
>>> schema = Schema().set_cmek(
... Cmek.gcp("projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key")
... )
"""
self.cmek = cmek
return self
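
A short sketch combining Cmek with the fluent set_cmek above; only the key format is validated locally, while access control is configured in GCP:

from chromadb.api.types import Cmek, Schema

cmek = Cmek.gcp("projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key")
assert cmek.validate_pattern()

schema = Schema().set_cmek(cmek)
assert schema.cmek is cmek
# When the schema is serialized for transport, the payload gains:
#   "cmek": {"gcp": "projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key"}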
def _get_config_class_name(self, config: IndexConfig) -> str:
"""Get the class name for a config."""
return config.__class__.__name__
@@ -1943,11 +2208,15 @@ class Schema:
"""
# Update the config in defaults (preserve enabled=False)
current_enabled = self.defaults.float_list.vector_index.enabled # type: ignore[union-attr]
self.defaults.float_list.vector_index = VectorIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
self.defaults.float_list.vector_index = VectorIndexType(
enabled=current_enabled, config=config
) # type: ignore[union-attr]
# Update the config on #embedding key (preserve enabled=True and source_key="#document")
current_enabled = self.keys[EMBEDDING_KEY].float_list.vector_index.enabled # type: ignore[union-attr]
current_source_key = self.keys[EMBEDDING_KEY].float_list.vector_index.config.source_key # type: ignore[union-attr]
current_source_key = self.keys[
EMBEDDING_KEY
].float_list.vector_index.config.source_key # type: ignore[union-attr]
# Create a new config with user settings but preserve the original source_key
embedding_config = VectorIndexConfig(
@@ -1957,7 +2226,9 @@ class Schema:
spann=config.spann,
source_key=current_source_key, # Preserve original source_key (should be "#document")
)
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(enabled=current_enabled, config=embedding_config) # type: ignore[union-attr]
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(
enabled=current_enabled, config=embedding_config
) # type: ignore[union-attr]
def _set_fts_index_config(self, config: FtsIndexConfig) -> None:
"""
@@ -1967,11 +2238,15 @@ class Schema:
"""
# Update the config in defaults (preserve enabled=False)
current_enabled = self.defaults.string.fts_index.enabled # type: ignore[union-attr]
self.defaults.string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
self.defaults.string.fts_index = FtsIndexType(
enabled=current_enabled, config=config
) # type: ignore[union-attr]
# Update the config on #document key (preserve enabled=True)
current_enabled = self.keys[DOCUMENT_KEY].string.fts_index.enabled # type: ignore[union-attr]
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(
enabled=current_enabled, config=config
) # type: ignore[union-attr]
def _set_index_in_defaults(self, config: IndexConfig, enabled: bool) -> None:
"""Set an index configuration in the defaults."""
@@ -2071,37 +2346,51 @@ class Schema:
if config_name == "FtsIndexConfig":
if self.keys[key].string is None:
self.keys[key].string = StringValueType()
self.keys[key].string.fts_index = FtsIndexType(enabled=enabled, config=cast(FtsIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].string.fts_index = FtsIndexType(
enabled=enabled, config=cast(FtsIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "StringInvertedIndexConfig":
if self.keys[key].string is None:
self.keys[key].string = StringValueType()
self.keys[key].string.string_inverted_index = StringInvertedIndexType(enabled=enabled, config=cast(StringInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].string.string_inverted_index = StringInvertedIndexType(
enabled=enabled, config=cast(StringInvertedIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "VectorIndexConfig":
if self.keys[key].float_list is None:
self.keys[key].float_list = FloatListValueType()
self.keys[key].float_list.vector_index = VectorIndexType(enabled=enabled, config=cast(VectorIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].float_list.vector_index = VectorIndexType(
enabled=enabled, config=cast(VectorIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "SparseVectorIndexConfig":
if self.keys[key].sparse_vector is None:
self.keys[key].sparse_vector = SparseVectorValueType()
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(enabled=enabled, config=cast(SparseVectorIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(
enabled=enabled, config=cast(SparseVectorIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "IntInvertedIndexConfig":
if self.keys[key].int_value is None:
self.keys[key].int_value = IntValueType()
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(enabled=enabled, config=cast(IntInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(
enabled=enabled, config=cast(IntInvertedIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "FloatInvertedIndexConfig":
if self.keys[key].float_value is None:
self.keys[key].float_value = FloatValueType()
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(enabled=enabled, config=cast(FloatInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(
enabled=enabled, config=cast(FloatInvertedIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "BoolInvertedIndexConfig":
if self.keys[key].boolean is None:
self.keys[key].boolean = BoolValueType()
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(enabled=enabled, config=cast(BoolInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(
enabled=enabled, config=cast(BoolInvertedIndexConfig, config)
) # type: ignore[union-attr]
def _enable_all_indexes_for_key(self, key: str) -> None:
"""Enable all possible index types for a specific key."""
@@ -2249,7 +2538,13 @@ class Schema:
for key, value_types in self.keys.items():
keys_json[key] = self._serialize_value_types(value_types)
return {"defaults": defaults_json, "keys": keys_json}
result: Dict[str, Any] = {"defaults": defaults_json, "keys": keys_json}
# Add CMEK if present
if self.cmek is not None:
result["cmek"] = self.cmek.to_dict()
return result
@classmethod
def deserialize_from_json(cls, json_data: Dict[str, Any]) -> "Schema":
@@ -2263,6 +2558,11 @@ class Schema:
for key, value_types_json in json_data.get("keys", {}).items():
instance.keys[key] = cls._deserialize_value_types(value_types_json)
# Deserialize CMEK if present
instance.cmek = None
if "cmek" in json_data and json_data["cmek"] is not None:
instance.cmek = Cmek.from_dict(json_data["cmek"])
return instance
def _serialize_value_types(self, value_types: ValueTypes) -> Dict[str, Any]:
@@ -2410,6 +2710,10 @@ class Schema:
if embedding_func.is_legacy():
config_dict["embedding_function"] = {"type": "legacy"}
else:
if hasattr(embedding_func, "validate_config"):
embedding_func.validate_config(
embedding_func.get_config()
)
config_dict["embedding_function"] = {
"name": embedding_func.name(),
"type": "known",
@@ -2439,6 +2743,8 @@ class Schema:
config_dict["embedding_function"] = {"type": "unknown"}
else:
embedding_func = cast(SparseEmbeddingFunction, embedding_func) # type: ignore
if hasattr(embedding_func, "validate_config"):
embedding_func.validate_config(embedding_func.get_config())
config_dict["embedding_function"] = {
"name": embedding_func.name(),
"type": "known",

View File

@@ -16,15 +16,18 @@ class SparseVector:
Attributes:
indices: List of dimension indices (must be non-negative integers, sorted in strictly ascending order)
values: List of values corresponding to each index (floats)
labels: Optional list of string labels corresponding to each index
Note:
- Indices must be sorted in strictly ascending order (no duplicates)
- Indices and values must have the same length
- If labels is provided, it must have the same length as indices and values
- All validations are performed in __post_init__
"""
indices: List[int]
values: List[float]
labels: Optional[List[str]] = None
def __post_init__(self) -> None:
"""Validate the sparse vector structure."""
@@ -44,6 +47,17 @@ class SparseVector:
f"got {len(self.indices)} indices and {len(self.values)} values"
)
if self.labels is not None:
if not isinstance(self.labels, list):
raise ValueError(
f"Expected SparseVector labels to be a list, got {type(self.labels).__name__}"
)
if len(self.labels) != len(self.indices):
raise ValueError(
f"SparseVector labels must have the same length as indices and values, "
f"got {len(self.labels)} labels, {len(self.indices)} indices"
)
for i, idx in enumerate(self.indices):
if not isinstance(idx, int):
raise ValueError(
@@ -70,21 +84,36 @@ class SparseVector:
)
def to_dict(self) -> Dict[str, Any]:
"""Serialize to transport format with type tag."""
return {
"""Serialize to transport format with type tag.
Note: Uses 'tokens' as the wire format key name for compatibility
with the protobuf schema, even though the Python attribute is 'labels'.
"""
result = {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": self.indices,
"values": self.values,
}
if self.labels is not None:
result["tokens"] = self.labels # Wire format uses 'tokens'
return result
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> "SparseVector":
"""Deserialize from transport format (strict - requires #type field)."""
"""Deserialize from transport format (strict - requires #type field).
Note: Reads from 'tokens' key in the wire format for compatibility
with the protobuf schema, mapping it to the 'labels' attribute.
"""
if d.get(TYPE_KEY) != SPARSE_VECTOR_TYPE_VALUE:
raise ValueError(
f"Expected {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {d.get(TYPE_KEY)}"
)
return cls(indices=d["indices"], values=d["values"])
return cls(
indices=d["indices"],
values=d["values"],
labels=d.get("tokens") # Wire format uses 'tokens'
)
Metadata = Mapping[str, Optional[Union[str, int, float, bool, SparseVector]]]
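
A sketch of the labels round trip described above: the Python attribute is `labels`, while the wire format stores it under `tokens`.

from chromadb.api.types import SparseVector

sv = SparseVector(indices=[1, 5, 9], values=[0.2, 0.7, 0.1], labels=["a", "b", "c"])
d = sv.to_dict()
assert d["tokens"] == ["a", "b", "c"]  # serialized under 'tokens'
assert SparseVector.from_dict(d).labels == ["a", "b", "c"]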

View File

@@ -154,6 +154,10 @@ class Settings(BaseSettings): # type: ignore
# eg ["http://localhost:8000"]
chroma_server_cors_allow_origins: List[str] = []
chroma_http_keepalive_secs: Optional[float] = 40.0
chroma_http_max_connections: Optional[int] = None
chroma_http_max_keepalive_connections: Optional[int] = None
# ==================
# Server config
# ==================

View File

@@ -39,6 +39,11 @@ from chromadb.execution.expression.operator import (
Sub,
Sum,
Val,
# GroupBy and Aggregate expressions
Aggregate,
MinK,
MaxK,
GroupBy,
)
from chromadb.execution.expression.plan import (
@@ -87,4 +92,9 @@ __all__ = [
"Sub",
"Sum",
"Val",
# GroupBy and Aggregate expressions
"Aggregate",
"MinK",
"MaxK",
"GroupBy",
]

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Set, Any, Union
from typing import Optional, List, Dict, Set, Any, Union, cast
import numpy as np
from numpy.typing import NDArray
@@ -7,9 +7,11 @@ from chromadb.api.types import (
Embeddings,
IDs,
Include,
OneOrMany,
SparseVector,
TYPE_KEY,
SPARSE_VECTOR_TYPE_VALUE,
maybe_cast_one_to_many,
normalize_embeddings,
validate_embeddings,
)
@@ -1024,7 +1026,7 @@ class Knn(Rank):
- A dense vector (list or numpy array)
- A sparse vector (SparseVector dict)
key: The embedding key to search against. Can be:
- "#embedding" (default) - searches the main embedding field
- Key.EMBEDDING (default) - searches the main embedding field
- A metadata field name (e.g., "my_custom_field") - searches that metadata field
limit: Maximum number of results to consider (default: 16)
default: Default score for records not in KNN results (default: None)
@@ -1054,7 +1056,7 @@ class Knn(Rank):
"NDArray[np.float64]",
"NDArray[np.int32]",
]
key: str = "#embedding"
key: Union[Key, str] = K.EMBEDDING
limit: int = 16
default: Optional[float] = None
return_rank: bool = False
@@ -1069,8 +1071,12 @@ class Knn(Rank):
# Convert numpy array to list
query_value = query_value.tolist()
key_value = self.key
if isinstance(key_value, Key):
key_value = key_value.name
# Build result dict - only include non-default values to keep JSON clean
result = {"query": query_value, "key": self.key, "limit": self.limit}
result = {"query": query_value, "key": key_value, "limit": self.limit}
# Only include optional fields if they're set to non-default values
if self.default is not None:
@@ -1291,3 +1297,216 @@ class Select:
# Convert to set while preserving the Key instances
return Select(keys=set(key_list))
# GroupBy and Aggregate types for grouping search results
def _keys_to_strings(keys: OneOrMany[Union[Key, str]]) -> List[str]:
"""Convert OneOrMany[Key|str] to List[str] for serialization."""
keys_list = cast(List[Union[Key, str]], maybe_cast_one_to_many(keys))
return [k.name if isinstance(k, Key) else k for k in keys_list]
def _strings_to_keys(keys: Union[List[Any], tuple[Any, ...]]) -> List[Union[Key, str]]:
"""Convert List[str] to List[Key] for deserialization."""
return [Key(k) if isinstance(k, str) else k for k in keys]
def _parse_k_aggregate(
op: str, data: Dict[str, Any]
) -> tuple[List[Union[Key, str]], int]:
"""Parse common fields for MinK/MaxK from dict.
Args:
op: The operator name (e.g., "$min_k" or "$max_k")
data: The dict containing the operator
Returns:
Tuple of (keys, k) where keys is List[Union[Key, str]] and k is int
Raises:
TypeError: If data types are invalid
ValueError: If required fields are missing or invalid
"""
agg_data = data[op]
if not isinstance(agg_data, dict):
raise TypeError(f"{op} requires a dict, got {type(agg_data).__name__}")
if "keys" not in agg_data:
raise ValueError(f"{op} requires 'keys' field")
if "k" not in agg_data:
raise ValueError(f"{op} requires 'k' field")
keys = agg_data["keys"]
if not isinstance(keys, (list, tuple)):
raise TypeError(f"{op} keys must be a list, got {type(keys).__name__}")
if not keys:
raise ValueError(f"{op} keys cannot be empty")
k = agg_data["k"]
if not isinstance(k, int):
raise TypeError(f"{op} k must be an integer, got {type(k).__name__}")
if k <= 0:
raise ValueError(f"{op} k must be positive, got {k}")
return _strings_to_keys(keys), k
@dataclass
class Aggregate:
"""Base class for aggregation expressions within groups.
Aggregations determine which records to keep from each group:
- MinK: Keep k records with minimum values (ascending order)
- MaxK: Keep k records with maximum values (descending order)
Examples:
# Keep top 3 by score per group (single key)
MinK(keys=Key.SCORE, k=3)
# Keep top 5 by priority, then score as tiebreaker (multiple keys)
MinK(keys=[Key("priority"), Key.SCORE], k=5)
# Keep bottom 2 by score per group
MaxK(keys=Key.SCORE, k=2)
"""
def to_dict(self) -> Dict[str, Any]:
"""Convert the Aggregate expression to a dictionary for JSON serialization"""
raise NotImplementedError("Subclasses must implement to_dict()")
@staticmethod
def from_dict(data: Dict[str, Any]) -> "Aggregate":
"""Create Aggregate expression from dictionary.
Supports:
- {"$min_k": {"keys": [...], "k": n}} -> MinK(keys=[...], k=n)
- {"$max_k": {"keys": [...], "k": n}} -> MaxK(keys=[...], k=n)
"""
if not isinstance(data, dict):
raise TypeError(f"Expected dict for Aggregate, got {type(data).__name__}")
if not data:
raise ValueError("Aggregate dict cannot be empty")
if len(data) != 1:
raise ValueError(
f"Aggregate dict must contain exactly one operator, got {len(data)}"
)
op = next(iter(data.keys()))
if op == "$min_k":
keys, k = _parse_k_aggregate(op, data)
return MinK(keys=keys, k=k)
elif op == "$max_k":
keys, k = _parse_k_aggregate(op, data)
return MaxK(keys=keys, k=k)
else:
raise ValueError(f"Unknown aggregate operator: {op}")
@dataclass
class MinK(Aggregate):
"""Keep k records with minimum aggregate key values per group"""
keys: OneOrMany[Union[Key, str]]
k: int
def to_dict(self) -> Dict[str, Any]:
return {"$min_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
@dataclass
class MaxK(Aggregate):
"""Keep k records with maximum aggregate key values per group"""
keys: OneOrMany[Union[Key, str]]
k: int
def to_dict(self) -> Dict[str, Any]:
return {"$max_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
@dataclass
class GroupBy:
"""Group results by metadata keys and aggregate within each group.
Groups search results by one or more metadata fields, then applies an
aggregation (MinK or MaxK) to select records within each group.
The final output is flattened and sorted by score.
Args:
keys: Metadata key(s) to group by. Can be a single key or a list of keys.
E.g., Key("category") or [Key("category"), Key("author")]
aggregate: Aggregation to apply within each group (MinK or MaxK)
Note: Both keys and aggregate must be specified together.
Examples:
# Top 3 documents per category (single key)
GroupBy(
keys=Key("category"),
aggregate=MinK(keys=Key.SCORE, k=3)
)
# Top 2 per (year, category) combination (multiple keys)
GroupBy(
keys=[Key("year"), Key("category")],
aggregate=MinK(keys=Key.SCORE, k=2)
)
# Top 1 per category by priority, score as tiebreaker
GroupBy(
keys=Key("category"),
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=1)
)
"""
keys: OneOrMany[Union[Key, str]] = field(default_factory=list)
aggregate: Optional[Aggregate] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert the GroupBy to a dictionary for JSON serialization"""
# Default GroupBy (no keys, no aggregate) serializes to {}
if not self.keys or self.aggregate is None:
return {}
result: Dict[str, Any] = {"keys": _keys_to_strings(self.keys)}
result["aggregate"] = self.aggregate.to_dict()
return result
@staticmethod
def from_dict(data: Dict[str, Any]) -> "GroupBy":
"""Create GroupBy from dictionary.
Examples:
- {} -> GroupBy() (default, no grouping)
- {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
"""
if not isinstance(data, dict):
raise TypeError(f"Expected dict for GroupBy, got {type(data).__name__}")
# Empty dict returns default GroupBy (no grouping)
if not data:
return GroupBy()
# Non-empty dict requires keys and aggregate
if "keys" not in data:
raise ValueError("GroupBy requires 'keys' field")
if "aggregate" not in data:
raise ValueError("GroupBy requires 'aggregate' field")
keys = data["keys"]
if not isinstance(keys, (list, tuple)):
raise TypeError(f"GroupBy keys must be a list, got {type(keys).__name__}")
if not keys:
raise ValueError("GroupBy keys cannot be empty")
aggregate_data = data["aggregate"]
if not isinstance(aggregate_data, dict):
raise TypeError(
f"GroupBy aggregate must be a dict, got {type(aggregate_data).__name__}"
)
aggregate = Aggregate.from_dict(aggregate_data)
return GroupBy(keys=_strings_to_keys(keys), aggregate=aggregate)
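
A serialization sketch for the new grouping types: group by a metadata key, keep the three lowest scores per group, and round-trip through the dict form.

from chromadb.execution.expression.operator import GroupBy, Key, MinK

group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
payload = group_by.to_dict()
# {'keys': ['category'], 'aggregate': {'$min_k': {'keys': ['#score'], 'k': 3}}}
assert GroupBy.from_dict(payload).to_dict() == payload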

View File

@@ -4,12 +4,12 @@ from typing import List, Dict, Any, Union, Set, Optional
from chromadb.execution.expression.operator import (
KNN,
Filter,
GroupBy,
Limit,
Projection,
Scan,
Rank,
Select,
Val,
Where,
Key,
)
@@ -77,9 +77,18 @@ class Search:
Combined with metadata filtering:
Search().where((Key.ID.is_in(["id1", "id2"])) & (Key("status") == "active"))
With group_by:
(Search()
.rank(Knn(query=[0.1, 0.2]))
.group_by(GroupBy(
keys=[Key("category")],
aggregate=MinK(keys=[Key.SCORE], k=3)
)))
Empty Search() is valid and will use defaults:
- where: None (no filtering)
- rank: None (no ranking - results ordered by default order)
- group_by: None (no grouping)
- limit: No limit
- select: Empty selection
"""
@@ -88,6 +97,7 @@ class Search:
self,
where: Optional[Union[Where, Dict[str, Any]]] = None,
rank: Optional[Union[Rank, Dict[str, Any]]] = None,
group_by: Optional[Union[GroupBy, Dict[str, Any]]] = None,
limit: Optional[Union[Limit, Dict[str, Any], int]] = None,
select: Optional[Union[Select, Dict[str, Any], List[str], Set[str]]] = None,
):
@@ -99,11 +109,13 @@ class Search:
rank: Rank expression or dict for scoring (defaults to None - no ranking)
Dict will be converted using Rank.from_dict()
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
group_by: GroupBy configuration for grouping and aggregating results (defaults to None)
Dict will be converted using GroupBy.from_dict()
limit: Limit configuration for pagination (defaults to no limit)
Can be a Limit object, a dict for Limit.from_dict(), or an int
When passing an int, it creates Limit(limit=value, offset=0)
select: Select configuration for keys (defaults to empty selection)
Can be a Select object, a dict for Select.from_dict(),
Can be a Select object, a dict for Select.from_dict(),
or a list/set of strings (e.g., ["#document", "#score"])
"""
# Handle where parameter
@@ -117,7 +129,7 @@ class Search:
raise TypeError(
f"where must be a Where object, dict, or None, got {type(where).__name__}"
)
# Handle rank parameter
if rank is None:
self._rank = None
@@ -129,7 +141,19 @@ class Search:
raise TypeError(
f"rank must be a Rank object, dict, or None, got {type(rank).__name__}"
)
# Handle group_by parameter
if group_by is None:
self._group_by = GroupBy()
elif isinstance(group_by, GroupBy):
self._group_by = group_by
elif isinstance(group_by, dict):
self._group_by = GroupBy.from_dict(group_by)
else:
raise TypeError(
f"group_by must be a GroupBy object, dict, or None, got {type(group_by).__name__}"
)
# Handle limit parameter
if limit is None:
self._limit = Limit()
@@ -143,7 +167,7 @@ class Search:
raise TypeError(
f"limit must be a Limit object, dict, int, or None, got {type(limit).__name__}"
)
# Handle select parameter
if select is None:
self._select = Select()
@@ -164,6 +188,7 @@ class Search:
return {
"filter": self._where.to_dict() if self._where is not None else None,
"rank": self._rank.to_dict() if self._rank is not None else None,
"group_by": self._group_by.to_dict(),
"limit": self._limit.to_dict(),
"select": self._select.to_dict(),
}
@@ -173,7 +198,11 @@ class Search:
"""Select all predefined keys (document, embedding, metadata, score)"""
new_select = Select(keys={Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE})
return Search(
where=self._where, rank=self._rank, limit=self._limit, select=new_select
where=self._where,
rank=self._rank,
group_by=self._group_by,
limit=self._limit,
select=new_select,
)
def select(self, *keys: Union[Key, str]) -> "Search":
@@ -187,7 +216,11 @@ class Search:
"""
new_select = Select(keys=set(keys))
return Search(
where=self._where, rank=self._rank, limit=self._limit, select=new_select
where=self._where,
rank=self._rank,
group_by=self._group_by,
limit=self._limit,
select=new_select,
)
def where(self, where: Optional[Union[Where, Dict[str, Any]]]) -> "Search":
@@ -202,20 +235,12 @@ class Search:
search.where({"status": "active"})
search.where({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
"""
# Convert dict to Where if needed
if where is None:
converted_where = None
elif isinstance(where, Where):
converted_where = where
elif isinstance(where, dict):
converted_where = Where.from_dict(where)
else:
raise TypeError(
f"where must be a Where object, dict, or None, got {type(where).__name__}"
)
return Search(
where=converted_where, rank=self._rank, limit=self._limit, select=self._select
where=where,
rank=self._rank,
group_by=self._group_by,
limit=self._limit,
select=self._select,
)
def rank(self, rank_expr: Optional[Union[Rank, Dict[str, Any]]]) -> "Search":
@@ -231,20 +256,37 @@ class Search:
search.rank({"$knn": {"query": [0.1, 0.2]}})
search.rank({"$sum": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.5}]})
"""
# Convert dict to Rank if needed
if rank_expr is None:
converted_rank = None
elif isinstance(rank_expr, Rank):
converted_rank = rank_expr
elif isinstance(rank_expr, dict):
converted_rank = Rank.from_dict(rank_expr)
else:
raise TypeError(
f"rank_expr must be a Rank object, dict, or None, got {type(rank_expr).__name__}"
)
return Search(
where=self._where, rank=converted_rank, limit=self._limit, select=self._select
where=self._where,
rank=rank_expr,
group_by=self._group_by,
limit=self._limit,
select=self._select,
)
def group_by(self, group_by: Optional[Union[GroupBy, Dict[str, Any]]]) -> "Search":
"""Set the group_by configuration for grouping and aggregating results
Args:
group_by: A GroupBy object, dict, or None for grouping
Dicts will be converted using GroupBy.from_dict()
Example:
search.group_by(GroupBy(
keys=[Key("category")],
aggregate=MinK(keys=[Key.SCORE], k=3)
))
search.group_by({
"keys": ["category"],
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}
})
"""
return Search(
where=self._where,
rank=self._rank,
group_by=group_by,
limit=self._limit,
select=self._select,
)
def limit(self, limit: int, offset: int = 0) -> "Search":
@@ -259,5 +301,9 @@ class Search:
"""
new_limit = Limit(offset=offset, limit=limit)
return Search(
where=self._where, rank=self._rank, limit=new_limit, select=self._select
where=self._where,
rank=self._rank,
group_by=self._group_by,
limit=new_limit,
select=self._select,
)
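A short sketch of the builder flow this file now supports, assuming the Knn rank operator imported elsewhere in this diff:
from chromadb.execution.expression import Knn, Search
from chromadb.execution.expression.operator import GroupBy, MinK, Key
plan = (
    Search()
    .where({"status": "active"})
    .rank(Knn(query=[0.1, 0.2]))
    .group_by(GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3)))
    .limit(10)
)
# group_by serializes alongside filter, rank, limit, and select:
# plan.to_dict()["group_by"] == {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}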

View File

@@ -425,6 +425,20 @@ class FastAPI(Server):
response_model=None,
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/functions/attach",
self.attach_function,
methods=["POST"],
response_model=None,
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/functions/{function_name}",
self.get_attached_function,
methods=["GET"],
response_model=None,
)
def shutdown(self) -> None:
self._system.stop()
@@ -938,6 +952,103 @@ class FastAPI(Server):
limiter=self._capacity_limiter,
)
@trace_method("FastAPI.attach_function", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def attach_function(
self,
request: Request,
tenant: str,
database_name: str,
collection_id: str,
) -> Dict[str, Any]:
try:
def process_attach_function(request: Request, raw_body: bytes) -> Dict[str, Any]:
body = orjson.loads(raw_body)
# NOTE: Auth check for attaching functions
self.sync_auth_request(
request.headers,
AuthzAction.UPDATE_COLLECTION, # Using UPDATE_COLLECTION as the auth action
tenant,
database_name,
collection_id,
)
self._set_request_context(request=request)
name = body.get("name")
function_id = body.get("function_id")
output_collection = body.get("output_collection")
params = body.get("params")
attached_fn = self._api.attach_function(
function_id=function_id,
name=name,
input_collection_id=_uuid(collection_id),
output_collection=output_collection,
params=params,
tenant=tenant,
database=database_name,
)
return {
"attached_function": {
"id": str(attached_fn.id),
"name": attached_fn.name,
"function_name": attached_fn.function_name,
"output_collection": attached_fn.output_collection,
"params": attached_fn.params,
}
}
raw_body = await request.body()
return await to_thread.run_sync(
process_attach_function,
request,
raw_body,
limiter=self._capacity_limiter,
)
except Exception:
raise
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_attached_function(
self,
request: Request,
tenant: str,
database_name: str,
collection_id: str,
function_name: str,
) -> Dict[str, Any]:
# NOTE: Auth check for getting attached functions
await self.auth_request(
request.headers,
AuthzAction.GET_COLLECTION, # Using GET_COLLECTION as the auth action
tenant,
database_name,
collection_id,
)
add_attributes_to_current_span({"tenant": tenant})
attached_fn = await to_thread.run_sync(
self._api.get_attached_function,
function_name,
_uuid(collection_id),
tenant,
database_name,
limiter=self._capacity_limiter,
)
return {
"attached_function": {
"id": str(attached_fn.id),
"name": attached_fn.name,
"function_name": attached_fn.function_name,
"output_collection": attached_fn.output_collection,
"params": attached_fn.params,
}
}
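For illustration, a hedged sketch of a raw request against the attach route registered above; the URL segments and payload keys follow the handler, while the host, tenant, database, and collection UUID are placeholders:
import httpx
collection_id = "00000000-0000-0000-0000-000000000000"  # placeholder collection UUID
url = (
    "http://localhost:8000/api/v2/tenants/default_tenant/databases/default_database"
    f"/collections/{collection_id}/functions/attach"
)
resp = httpx.post(
    url,
    json={
        "name": "count_my_docs",
        "function_id": "record_counter",
        "output_collection": "my_documents_counts",
        "params": None,
    },
)
resp.raise_for_status()
# The response mirrors the handler above:
# {"attached_function": {"id": ..., "name": ..., "function_name": ..., "output_collection": ..., "params": ...}}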
@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def add(

View File

@@ -31,7 +31,7 @@ from chromadb.utils.embedding_functions import (
)
from chromadb.api.models.Collection import Collection
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.errors import InvalidArgumentError, InternalError
from chromadb.errors import InvalidArgumentError
from chromadb.execution.expression import Knn, Search
from chromadb.types import Collection as CollectionModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
@@ -509,7 +509,7 @@ def test_sparse_vector_not_allowed_locally(
schema = Schema()
schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig())
with pytest.raises(
InternalError, match="Sparse vector indexing is not enabled in local"
InvalidArgumentError, match="Sparse vector indexing is not enabled in local"
):
_create_isolated_collection(client_factories, schema=schema)
@@ -1011,12 +1011,26 @@ def test_collection_fork_inherits_and_isolates_schema(
client_factories,
schema=schema,
)
parent_version_before_add = get_collection_version(client, collection.name)
parent_ids = [f"parent-{i}" for i in range(251)]
parent_docs = [f"parent doc {i}" for i in range(251)]
parent_metadatas: List[Mapping[str, Any]] = [
{"shared_key": f"parent_{i}"} for i in range(251)
]
collection.add(
ids=["parent-1"],
documents=["parent doc"],
metadatas=[{"shared_key": "parent"}],
ids=parent_ids,
documents=parent_docs,
metadatas=parent_metadatas,
)
# Wait for parent to compact before forking. Otherwise, the fork inherits
# uncompacted logs, and compaction of those inherited logs could increment
# the fork's version before the fork's own data is compacted.
wait_for_version_increase(client, collection.name, parent_version_before_add)
assert collection.schema is not None
parent_schema_json = collection.schema.serialize_to_json()
@@ -1050,8 +1064,8 @@ def test_collection_fork_inherits_and_isolates_schema(
assert reloaded_parent.schema is not None
assert "child_only" not in reloaded_parent.schema.keys
parent_results = reloaded_parent.get(where={"shared_key": "parent"})
assert set(parent_results["ids"]) == {"parent-1"}
parent_results = reloaded_parent.get(where={"shared_key": "parent_10"})
assert set(parent_results["ids"]) == {"parent-10"}
child_results = forked.get(where={"child_only": "value_10"})
assert set(child_results["ids"]) == {"fork-10"}

View File

@@ -7,11 +7,21 @@ for automatically processing collections.
import pytest
from chromadb.api.client import Client as ClientCreator
from chromadb.api.functions import (
RECORD_COUNTER_FUNCTION,
STATISTICS_FUNCTION,
Function,
)
from chromadb.config import System
from chromadb.errors import ChromaError, NotFoundError
from chromadb.test.utils.wait_for_version_increase import (
get_collection_version,
wait_for_version_increase,
)
from time import sleep
def test_function_attach_and_detach(basic_http_client: System) -> None:
def test_count_function_attach_and_detach(basic_http_client: System) -> None:
"""Test creating and removing a function with the record_counter operator"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
@@ -22,45 +32,39 @@ def test_function_attach_and_detach(basic_http_client: System) -> None:
metadata={"description": "Sample documents for task processing"},
)
# Add initial documents
collection.add(
ids=["doc1", "doc2", "doc3"],
documents=[
"The quick brown fox jumps over the lazy dog",
"Machine learning is a subset of artificial intelligence",
"Python is a popular programming language",
],
metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}],
)
# Verify collection has documents
assert collection.count() == 3
# Create a task that counts records in the collection
attached_fn = collection.attach_function(
attached_fn, created = collection.attach_function(
name="count_my_docs",
function_id="record_counter", # Built-in operator that counts records
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
# Verify task creation succeeded
assert attached_fn is not None
assert created is True
initial_version = get_collection_version(client, collection.name)
# Add more documents
# Add documents
collection.add(
ids=["doc4", "doc5"],
documents=[
"Chroma is a vector database",
"Tasks automate data processing",
],
ids=["doc_{}".format(i) for i in range(0, 300)],
documents=["test document"] * 300,
)
# Verify documents were added
assert collection.count() == 5
assert collection.count() == 300
wait_for_version_increase(client, collection.name, initial_version)
# Give some time to invalidate the frontend query cache
sleep(60)
result = client.get_collection("my_documents_counts").get("function_output")
assert result["metadatas"] is not None
assert result["metadatas"][0]["total_count"] == 300
# Remove the task
success = attached_fn.detach(
success = collection.detach_function(
attached_fn.name,
delete_output_collection=True,
)
@@ -79,13 +83,42 @@ def test_task_with_invalid_function(basic_http_client: System) -> None:
# Attempt to create task with non-existent function should raise ChromaError
with pytest.raises(ChromaError, match="function not found"):
collection.attach_function(
function=Function._NONEXISTENT_TEST_ONLY,
name="invalid_task",
function_id="nonexistent_function",
output_collection="output_collection",
params=None,
)
def test_attach_function_returns_function_name(basic_http_client: System) -> None:
"""Test that attach_function and get_attached_function return function_name field instead of UUID"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
collection = client.create_collection(name="test_function_name")
collection.add(ids=["id1"], documents=["doc1"])
# Attach a function and verify function_name field in response
attached_fn, created = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="my_counter",
output_collection="output_collection",
params=None,
)
# Verify the attached function has function_name (not function_id UUID)
assert created is True
assert attached_fn.function_name == "record_counter"
assert attached_fn.name == "my_counter"
# Get the attached function and verify function_name field is also present
retrieved_fn = collection.get_attached_function("my_counter")
assert retrieved_fn == attached_fn
# Clean up
collection.detach_function(attached_fn.name, delete_output_collection=True)
def test_function_multiple_collections(basic_http_client: System) -> None:
"""Test attaching functions on multiple collections"""
client = ClientCreator.from_system(basic_http_client)
@@ -95,93 +128,163 @@ def test_function_multiple_collections(basic_http_client: System) -> None:
collection1 = client.create_collection(name="collection_1")
collection1.add(ids=["id1", "id2"], documents=["doc1", "doc2"])
attached_fn1 = collection1.attach_function(
attached_fn1, created1 = collection1.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_1",
function_id="record_counter",
output_collection="output_1",
params=None,
)
assert attached_fn1 is not None
assert created1 is True
# Create second collection and task
collection2 = client.create_collection(name="collection_2")
collection2.add(ids=["id3", "id4"], documents=["doc3", "doc4"])
attached_fn2 = collection2.attach_function(
attached_fn2, created2 = collection2.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_2",
function_id="record_counter",
output_collection="output_2",
params=None,
)
assert attached_fn2 is not None
assert created2 is True
# Task IDs should be different
assert attached_fn1.id != attached_fn2.id
# Clean up
assert attached_fn1.detach(delete_output_collection=True) is True
assert attached_fn2.detach(delete_output_collection=True) is True
assert (
collection1.detach_function(attached_fn1.name, delete_output_collection=True)
is True
)
assert (
collection2.detach_function(attached_fn2.name, delete_output_collection=True)
is True
)
def test_functions_multiple_attached_functions(basic_http_client: System) -> None:
"""Test attaching multiple functions on the same collection"""
def test_functions_one_attached_function_per_collection(
basic_http_client: System,
) -> None:
"""Test that only one attached function is allowed per collection"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create a single collection
collection = client.create_collection(name="multi_task_collection")
collection = client.create_collection(name="single_task_collection")
collection.add(ids=["id1", "id2", "id3"], documents=["doc1", "doc2", "doc3"])
# Create first task on the collection
attached_fn1 = collection.attach_function(
attached_fn1, created = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_1",
function_id="record_counter",
output_collection="output_1",
params=None,
)
assert attached_fn1 is not None
assert created is True
# Create second task on the SAME collection with a different name
attached_fn2 = collection.attach_function(
# Attempt to create a second task with a different name should fail
# (only one attached function allowed per collection)
with pytest.raises(
ChromaError,
match="collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
):
collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_2",
output_collection="output_2",
params=None,
)
# Attempt to create a task with the same name but different function_id should also fail
with pytest.raises(
ChromaError,
match=r"collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
):
collection.attach_function(
function=STATISTICS_FUNCTION,
name="task_1",
output_collection="output_different", # Different output collection
params=None,
)
# Detach the first function
assert (
collection.detach_function(attached_fn1.name, delete_output_collection=True)
is True
)
# Now we should be able to attach a new function
attached_fn2, created2 = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_2",
function_id="record_counter",
output_collection="output_2",
params=None,
)
assert attached_fn2 is not None
assert created2 is True
assert attached_fn2.id != attached_fn1.id
# Task IDs should be different even though they're on the same collection
assert attached_fn1.id != attached_fn2.id
# Create third task on the same collection
attached_fn3 = collection.attach_function(
name="task_3",
function_id="record_counter",
output_collection="output_3",
params=None,
# Clean up
assert (
collection.detach_function(attached_fn2.name, delete_output_collection=True)
is True
)
assert attached_fn3 is not None
assert attached_fn3.id != attached_fn1.id
assert attached_fn3.id != attached_fn2.id
# Attempt to create a task with duplicate name on same collection should fail
with pytest.raises(ChromaError, match="already exists"):
def test_attach_function_with_invalid_params(basic_http_client: System) -> None:
"""Test that attach_function with non-empty params raises an error"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
collection = client.create_collection(name="test_invalid_params")
collection.add(ids=["id1"], documents=["test document"])
# Attempt to create task with non-empty params should fail
# (no functions currently accept parameters)
with pytest.raises(
ChromaError,
match="params must be empty - no functions currently accept parameters",
):
collection.attach_function(
name="task_1", # Duplicate name
function_id="record_counter",
output_collection="output_duplicate",
params=None,
name="invalid_params_task",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params={"some_key": "some_value"},
)
# Clean up - remove each task individually
assert attached_fn1.detach(delete_output_collection=True) is True
assert attached_fn2.detach(delete_output_collection=True) is True
assert attached_fn3.detach(delete_output_collection=True) is True
def test_attach_function_output_collection_already_exists(
basic_http_client: System,
) -> None:
"""Test that attach_function fails when output collection name already exists"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create a collection that will be used as input
input_collection = client.create_collection(name="input_collection")
input_collection.add(ids=["id1"], documents=["test document"])
# Create another collection with the name we want to use for output
client.create_collection(name="existing_output_collection")
# Attempt to create task with output collection name that already exists
with pytest.raises(
ChromaError,
match=r"Output collection \[existing_output_collection\] already exists",
):
input_collection.attach_function(
name="my_task",
function=RECORD_COUNTER_FUNCTION,
output_collection="existing_output_collection",
params=None,
)
def test_function_remove_nonexistent(basic_http_client: System) -> None:
@@ -191,15 +294,294 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None:
collection = client.create_collection(name="test_collection")
collection.add(ids=["id1"], documents=["test"])
attached_fn = collection.attach_function(
attached_fn, _ = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="test_function",
function_id="record_counter",
output_collection="output_collection",
params=None,
)
attached_fn.detach(delete_output_collection=True)
collection.detach_function(attached_fn.name, delete_output_collection=True)
# Trying to detach this function again should raise NotFoundError
with pytest.raises(NotFoundError, match="does not exist"):
attached_fn.detach(delete_output_collection=True)
collection.detach_function(attached_fn.name, delete_output_collection=True)
def test_attach_to_output_collection_fails(basic_http_client: System) -> None:
"""Test that attaching a function to an output collection fails"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create input collection
input_collection = client.create_collection(name="input_collection")
input_collection.add(ids=["id1"], documents=["test"])
_, _ = input_collection.attach_function(
name="test_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params=None,
)
output_collection = client.get_collection(name="output_collection")
with pytest.raises(
ChromaError, match="cannot attach function to an output collection"
):
_ = output_collection.attach_function(
name="test_function_2",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection_2",
params=None,
)
def test_delete_output_collection_detaches_function(basic_http_client: System) -> None:
"""Test that deleting an output collection also detaches the attached function"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create input collection and attach a function
input_collection = client.create_collection(name="input_collection")
input_collection.add(ids=["id1"], documents=["test"])
attached_fn, created = input_collection.attach_function(
name="my_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params=None,
)
assert attached_fn is not None
assert created is True
# Delete the output collection directly
client.delete_collection("output_collection")
# The attached function should now be gone - trying to get it should raise NotFoundError
with pytest.raises(NotFoundError):
input_collection.get_attached_function("my_function")
def test_delete_orphaned_output_collection(basic_http_client: System) -> None:
"""Test that deleting an output collection from a recently detached function works"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create input collection and attach a function
input_collection = client.create_collection(name="input_collection")
input_collection.add(ids=["id1"], documents=["test"])
attached_fn, created = input_collection.attach_function(
name="my_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params=None,
)
assert attached_fn is not None
assert created is True
input_collection.detach_function(attached_fn.name, delete_output_collection=False)
# Delete the output collection directly
client.delete_collection("output_collection")
# The attached function was detached, so retrieving it should raise NotFoundError
with pytest.raises(NotFoundError):
input_collection.get_attached_function("my_function")
with pytest.raises(NotFoundError):
# The output collection was deleted above, so fetching it should also fail
client.get_collection("output_collection")
def test_partial_attach_function_repair(
basic_http_client: System,
) -> None:
"""Test creating and removing a function with the record_counter operator"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create a collection
collection = client.get_or_create_collection(
name="my_document",
)
# Create a task that counts records in the collection
attached_fn, created = collection.attach_function(
name="count_my_docs",
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
assert created is True
# Verify task creation succeeded
assert attached_fn is not None
collection2 = client.get_or_create_collection(
name="my_document2",
)
# Create a task that counts records in the collection
# This should fail
with pytest.raises(
ChromaError, match=r"Output collection \[my_documents_counts\] already exists"
):
attached_fn, _ = collection2.attach_function(
name="count_my_docs",
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
# Detach the function
assert (
collection.detach_function(attached_fn.name, delete_output_collection=True)
is True
)
# Create a task that counts records in the collection
attached_fn, created = collection2.attach_function(
name="count_my_docs",
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
assert attached_fn is not None
assert created is True
def test_output_collection_created_with_schema(basic_http_client: System) -> None:
"""Test that output collections are created with the source_attached_function_id in the schema"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create input collection and attach a function
input_collection = client.create_collection(name="input_collection")
input_collection.add(ids=["id1"], documents=["test"])
attached_fn, created = input_collection.attach_function(
name="my_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params=None,
)
assert attached_fn is not None
assert created is True
# Get the output collection - it should exist
output_collection = client.get_collection(name="output_collection")
assert output_collection is not None
# The source_attached_function_id is stored in the schema (not metadata)
# The schema is not exposed through the public client API, so we check the internal model's
# pretty-printed schema for the field the attached-function orchestrator uses internally
assert "source_attached_function_id" in output_collection._model.pretty_schema()
# Clean up
input_collection.detach_function(attached_fn.name, delete_output_collection=True)
def test_count_function_attach_and_detach_attach_attach(
basic_http_client: System,
) -> None:
"""Test creating and removing a function with the record_counter operator"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create a collection
collection = client.get_or_create_collection(
name="my_document",
metadata={"description": "Sample documents for task processing"},
)
# Create a task that counts records in the collection
attached_fn, created = collection.attach_function(
name="count_my_docs",
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
# Verify task creation succeeded
assert created is True
assert attached_fn is not None
initial_version = get_collection_version(client, collection.name)
# Add documents
collection.add(
ids=["doc_{}".format(i) for i in range(0, 300)],
documents=["test document"] * 300,
)
# Verify documents were added
assert collection.count() == 300
wait_for_version_increase(client, collection.name, initial_version)
# Give some time to invalidate the frontend query cache
sleep(60)
result = client.get_collection("my_documents_counts").get("function_output")
assert result["metadatas"] is not None
assert result["metadatas"][0]["total_count"] == 300
# Remove the task
success = collection.detach_function(
attached_fn.name, delete_output_collection=True
)
# Verify task removal succeeded
assert success is True
# Attach a function that counts records in the collection
attached_fn, created = collection.attach_function(
name="count_my_docs",
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
assert attached_fn is not None
assert created is True
# Attach a function that counts records in the collection
attached_fn, created = collection.attach_function(
name="count_my_docs",
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
assert created is False
assert attached_fn is not None
def test_attach_function_idempotency(basic_http_client: System) -> None:
"""Test that attach_function is idempotent - calling it twice with same params returns created=False"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
collection = client.create_collection(name="idempotency_test")
collection.add(ids=["id1"], documents=["test document"])
# First attach - should be newly created
attached_fn1, created1 = collection.attach_function(
name="my_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params=None,
)
assert attached_fn1 is not None
assert created1 is True
# Second attach with identical params - should be idempotent (created=False)
attached_fn2, created2 = collection.attach_function(
name="my_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params=None,
)
assert attached_fn2 is not None
assert created2 is False
# Both should return the same function ID
assert attached_fn1.id == attached_fn2.id
# Clean up
collection.detach_function(attached_fn1.name, delete_output_collection=True)
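Condensed from the tests above, a sketch of the attach / get / detach lifecycle through the Collection API, assuming a running Chroma server and the built-in record_counter function:
import chromadb
from chromadb.api.functions import RECORD_COUNTER_FUNCTION
client = chromadb.HttpClient(host="localhost", port=8000)
collection = client.get_or_create_collection("my_documents")
# attach_function is idempotent: created is False if an identical attachment already exists
attached_fn, created = collection.attach_function(
    name="count_my_docs",
    function=RECORD_COUNTER_FUNCTION,
    output_collection="my_documents_counts",
    params=None,
)
retrieved = collection.get_attached_function("count_my_docs")
assert retrieved.name == attached_fn.name
collection.detach_function(attached_fn.name, delete_output_collection=True)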

View File

@@ -1,7 +1,9 @@
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from chromadb import SparseVector
from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
DEFAULT_CHROMA_BM25_STOPWORDS,
ChromaBm25EmbeddingFunction,
@@ -137,3 +139,64 @@ def test_validate_config_update_allows_known_keys() -> None:
embedder.validate_config_update(
embedder.get_config(), {"k": 1.1, "stopwords": ["custom"]}
)
def test_multithreaded_usage() -> None:
embedder = ChromaBm25EmbeddingFunction()
base_texts = [
"""The gravitational wave background from massive black hole binaries emit bursts of
gravitational waves at periapse. Such events may be directly resolvable in the Galactic
centre. However, if the star does not spiral in, the emitted GWs are not resolvable for
extra-galactic MBHs, but constitute a source of background noise. We estimate the power
spectrum of this extreme mass ratio burst background.""",
"""Dynamics of planets in exoplanetary systems with multiple stars showing how the
gravitational interactions between the stars and planets affect the orbital stability
and long-term evolution of the planetary system architectures.""",
"""Diurnal Thermal Tides in a Non-rotating atmosphere with realistic heating profiles
and temperature gradients that demonstrate the complex interplay between radiation
and atmospheric dynamics in planetary atmospheres.""",
"""Intermittent turbulence, noise and waves in stellar atmospheres create complex
patterns of energy transport and momentum deposition that influence the structure
and evolution of stellar interiors and surfaces.""",
"""Superconductivity in quantum materials and condensed matter physics systems
exhibiting novel quantum phenomena including topological phases, strongly correlated
electron systems, and exotic superconducting pairing mechanisms.""",
"""Machine learning models require careful tuning of hyperparameters including learning
rates, regularization coefficients, and architectural choices that demonstrate the
complex interplay between optimization algorithms and model capacity.""",
"""Natural language processing enables text understanding through sophisticated
algorithms that analyze semantic relationships, syntactic structures, and contextual
information to extract meaningful representations from unstructured textual data.""",
"""Vector databases store high-dimensional embeddings efficiently using advanced
indexing techniques including approximate nearest neighbor search algorithms that
balance accuracy and computational efficiency for large-scale similarity search.""",
]
texts = base_texts * 30
num_threads = 10
def process_single_text(text: str) -> SparseVector:
return embedder([text])[0]
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(process_single_text, text) for text in texts]
all_results = []
for future in as_completed(futures):
try:
embedding = future.result()
all_results.append(embedding)
except Exception as e:
pytest.fail(
f"Threading error detected: {type(e).__name__}: {e}. "
"This indicates the stemmer is not thread-safe when cached."
)
assert len(all_results) == len(texts)
for embedding in all_results:
assert embedding.indices
assert len(embedding.indices) == len(embedding.values)
assert _is_sorted(embedding.indices)
for value in embedding.values:
assert value > 0
assert math.isfinite(value)

View File

@@ -33,12 +33,14 @@ def test_get_builtins_holds() -> None:
"GoogleGenerativeAiEmbeddingFunction",
"GooglePalmEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"GoogleGenaiEmbeddingFunction",
"HuggingFaceEmbeddingFunction",
"HuggingFaceEmbeddingServer",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"NomicEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OllamaEmbeddingFunction",
"OpenAIEmbeddingFunction",

View File

@@ -1,7 +1,7 @@
import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing import Any, Optional, List, Dict, Union, cast, Tuple
from typing_extensions import TypedDict
import uuid
import numpy as np
@@ -9,6 +9,10 @@ import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from chromadb.test.api.test_schema_e2e import (
SimpleEmbeddingFunction,
DeterministicSparseEmbeddingFunction,
)
from chromadb.test.conftest import NOT_CLUSTER_ONLY
from dataclasses import dataclass
from chromadb.api.types import (
@@ -17,12 +21,27 @@ from chromadb.api.types import (
EmbeddingFunction,
Embeddings,
Metadata,
Schema,
CollectionMetadata,
VectorIndexConfig,
SparseVectorIndexConfig,
StringInvertedIndexConfig,
IntInvertedIndexConfig,
FloatInvertedIndexConfig,
BoolInvertedIndexConfig,
HnswIndexConfig,
SpannIndexConfig,
Space,
)
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
from chromadb.test.conftest import is_spann_disabled_mode
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
CreateSpannConfiguration,
CreateHNSWConfiguration,
)
from chromadb.utils.embedding_functions import (
register_embedding_function,
)
# Set the random seed for reproducibility
@@ -266,6 +285,365 @@ class ExternalCollection:
embedding_function: Optional[types.EmbeddingFunction[Embeddable]]
@register_embedding_function
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
"""Simple embedding function with ip space for persistence tests."""
def default_space(self) -> str: # type: ignore[override]
return "ip"
@st.composite
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
"""Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann."""
space = None
embedding_function = None
source_key = None
hnsw = None
spann = None
if draw(st.booleans()):
space = draw(st.sampled_from(["cosine", "l2", "ip"]))
if draw(st.booleans()):
embedding_function = SimpleIpEmbeddingFunction(
dim=draw(st.integers(min_value=1, max_value=1000))
)
if draw(st.booleans()):
source_key = draw(st.one_of(st.just("#document"), safe_text))
index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))
if index_choice == "hnsw":
hnsw = HnswIndexConfig(
ef_construction=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
max_neighbors=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
ef_search=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
sync_threshold=draw(st.integers(min_value=2, max_value=10000))
if draw(st.booleans())
else None,
resize_factor=draw(st.floats(min_value=1.0, max_value=5.0))
if draw(st.booleans())
else None,
)
elif index_choice == "spann":
spann = SpannIndexConfig(
search_nprobe=draw(st.integers(min_value=1, max_value=128))
if draw(st.booleans())
else None,
write_nprobe=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
ef_construction=draw(st.integers(min_value=1, max_value=200))
if draw(st.booleans())
else None,
ef_search=draw(st.integers(min_value=1, max_value=200))
if draw(st.booleans())
else None,
max_neighbors=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
split_threshold=draw(st.integers(min_value=50, max_value=200))
if draw(st.booleans())
else None,
merge_threshold=draw(st.integers(min_value=25, max_value=100))
if draw(st.booleans())
else None,
)
return VectorIndexConfig(
space=cast(Space, space),
embedding_function=embedding_function,
source_key=source_key,
hnsw=hnsw,
spann=spann,
)
@st.composite
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
"""Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25."""
embedding_function = None
source_key = None
bm25 = None
if draw(st.booleans()):
embedding_function = DeterministicSparseEmbeddingFunction()
source_key = draw(st.one_of(st.just("#document"), safe_text))
if draw(st.booleans()):
bm25 = draw(st.booleans())
return SparseVectorIndexConfig(
embedding_function=embedding_function,
source_key=source_key,
bm25=bm25,
)
@st.composite
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
"""Generate a Schema object with various create_index/delete_index operations."""
if draw(st.booleans()):
return None
schema = Schema()
# Decide how many operations to perform
num_operations = draw(st.integers(min_value=0, max_value=5))
sparse_index_created = False
for _ in range(num_operations):
operation = draw(st.sampled_from(["create_index", "delete_index"]))
config_type = draw(
st.sampled_from(
[
"string_inverted",
"int_inverted",
"float_inverted",
"bool_inverted",
"vector",
"sparse_vector",
]
)
)
# Decide if we're setting on a key or globally
use_key = draw(st.booleans())
key = None
if use_key and config_type != "vector":
# Vector indexes can't be set on specific keys, only globally
key = draw(safe_text)
if operation == "create_index":
if config_type == "string_inverted":
schema.create_index(config=StringInvertedIndexConfig(), key=key)
elif config_type == "int_inverted":
schema.create_index(config=IntInvertedIndexConfig(), key=key)
elif config_type == "float_inverted":
schema.create_index(config=FloatInvertedIndexConfig(), key=key)
elif config_type == "bool_inverted":
schema.create_index(config=BoolInvertedIndexConfig(), key=key)
elif config_type == "vector":
vector_config = draw(vector_index_config_strategy())
schema.create_index(config=vector_config, key=None)
elif (
config_type == "sparse_vector"
and not is_spann_disabled_mode
and not sparse_index_created
):
sparse_config = draw(sparse_vector_index_config_strategy())
# Sparse vector MUST have a key
if key is None:
key = draw(safe_text)
schema.create_index(config=sparse_config, key=key)
sparse_index_created = True
elif operation == "delete_index":
if config_type == "string_inverted":
schema.delete_index(config=StringInvertedIndexConfig(), key=key)
elif config_type == "int_inverted":
schema.delete_index(config=IntInvertedIndexConfig(), key=key)
elif config_type == "float_inverted":
schema.delete_index(config=FloatInvertedIndexConfig(), key=key)
elif config_type == "bool_inverted":
schema.delete_index(config=BoolInvertedIndexConfig(), key=key)
# Vector, FTS, and sparse_vector deletion is not currently supported
return schema
@st.composite
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
"""Generate metadata with hnsw parameters."""
metadata: CollectionMetadata = {}
if draw(st.booleans()):
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
if draw(st.booleans()):
metadata["hnsw:construction_ef"] = draw(
st.integers(min_value=1, max_value=1000)
)
if draw(st.booleans()):
metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans()):
metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans()):
metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
if draw(st.booleans()):
metadata["hnsw:sync_threshold"] = draw(
st.integers(min_value=2, max_value=10000)
)
return metadata if metadata else None
@st.composite
def create_configuration_strategy(
draw: st.DrawFn,
) -> Optional[CreateCollectionConfiguration]:
"""Generate CreateCollectionConfiguration with mutual exclusivity rules."""
configuration: CreateCollectionConfiguration = {}
# Optionally set embedding_function (independent)
if draw(st.booleans()):
configuration["embedding_function"] = SimpleIpEmbeddingFunction(
dim=draw(st.integers(min_value=1, max_value=1000))
)
# Decide: set space only, OR set hnsw config, OR set spann config
config_choice = draw(
st.sampled_from(
["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
)
)
if config_choice == "space_only_hnsw":
configuration["hnsw"] = CreateHNSWConfiguration(
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
)
elif config_choice == "space_only_spann":
configuration["spann"] = CreateSpannConfiguration(
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
)
elif config_choice == "hnsw":
# Set hnsw config (optionally with space)
hnsw_config: CreateHNSWConfiguration = {}
if draw(st.booleans()):
hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000))
hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
configuration["hnsw"] = hnsw_config
elif config_choice == "spann":
# Set spann config (optionally with space)
spann_config: CreateSpannConfiguration = {}
if draw(st.booleans()):
spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128))
spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64))
spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200))
spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200))
spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64))
spann_config["reassign_neighbor_count"] = draw(
st.integers(min_value=1, max_value=64)
)
spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200))
spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100))
configuration["spann"] = spann_config
return configuration if configuration else None
@dataclass
class CollectionInputCombination:
"""
Input tuple for collection creation tests.
"""
metadata: Optional[CollectionMetadata]
configuration: Optional[CreateCollectionConfiguration]
schema: Optional[Schema]
schema_vector_info: Optional[Dict[str, Any]]
kind: str
def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in items.items() if v is not None}
def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
embedding_default_space: Optional[str] = None
if config.embedding_function is not None and hasattr(
config.embedding_function, "default_space"
):
embedding_default_space = cast(str, config.embedding_function.default_space())
return {
"space": config.space,
"hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None,
"spann": config.spann.model_dump(exclude_none=True) if config.spann else None,
"embedding_function_default_space": embedding_default_space,
}
@st.composite
def _schema_input_strategy(
draw: st.DrawFn,
) -> Tuple[Schema, Dict[str, Any]]:
schema = Schema()
vector_config = draw(vector_index_config_strategy())
schema.create_index(config=vector_config, key=None)
return schema, vector_index_to_dict(vector_config)
@st.composite
def metadata_configuration_schema_strategy(
draw: st.DrawFn,
) -> CollectionInputCombination:
"""
Generate compatible combinations of metadata, configuration, and schema inputs.
"""
choice = draw(
st.sampled_from(
[
"none",
"metadata",
"configuration",
"metadata_configuration",
"schema",
]
)
)
metadata: Optional[CollectionMetadata] = None
configuration: Optional[CreateCollectionConfiguration] = None
schema: Optional[Schema] = None
schema_info: Optional[Dict[str, Any]] = None
if choice in ("metadata", "metadata_configuration"):
metadata = draw(
metadata_with_hnsw_strategy().filter(
lambda value: value is not None and len(value) > 0
)
)
if choice in ("configuration", "metadata_configuration"):
configuration = draw(
create_configuration_strategy().filter(
lambda value: value is not None
and (
(value.get("hnsw") is not None and len(value["hnsw"]) > 0)
or (value.get("spann") is not None and len(value["spann"]) > 0)
)
)
)
if choice == "schema":
schema, schema_info = draw(_schema_input_strategy())
return CollectionInputCombination(
metadata=metadata,
configuration=configuration,
schema=schema,
schema_vector_info=schema_info,
kind=choice,
)
@dataclass
class Collection(ExternalCollection):
"""
@@ -344,7 +722,7 @@ def collections(
spann_config: CreateSpannConfiguration = {
"space": spann_space,
"write_nprobe": 4,
"reassign_neighbor_count": 4
"reassign_neighbor_count": 4,
}
collection_config = {
"spann": spann_config,
@@ -395,7 +773,7 @@ def collections(
known_document_keywords=known_document_keywords,
has_embeddings=has_embeddings,
embedding_function=embedding_function,
collection_config=collection_config
collection_config=collection_config,
)
@@ -421,7 +799,9 @@ def metadata(
del metadata[key] # type: ignore
# Finally, add in some of the known keys for the collection
sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
k: st.just(v) for k, v in collection.known_metadata_keys.items()
k: st.just(v)
for k, v in collection.known_metadata_keys.items()
if isinstance(v, (str, int, float))
}
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore
# We don't allow submitting empty metadata

View File

@@ -31,7 +31,7 @@ collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="co
normal=hypothesis.settings(max_examples=500),
fast=hypothesis.settings(max_examples=200),
),
max_examples=2
max_examples=2,
)
def test_add_miniscule(
client: ClientAPI,
@@ -332,7 +332,8 @@ def test_out_of_order_ids(client: ClientAPI) -> None:
]
coll = client.create_collection(
"test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore
"test",
embedding_function=lambda input: [[1, 2, 3] for _ in input], # type: ignore
)
embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
coll.add(ids=ooo_ids, embeddings=embeddings)
@@ -369,3 +370,155 @@ def test_add_partial(client: ClientAPI) -> None:
assert results["ids"] == ["1", "2", "3"]
assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
assert results["documents"] == ["a", "b", None]
@pytest.mark.skipif(
NOT_CLUSTER_ONLY,
reason="GroupBy is only supported in distributed mode",
)
def test_search_group_by(client: ClientAPI) -> None:
"""Test GroupBy with single key, multiple keys, and multiple ranking keys."""
from chromadb.execution.expression.operator import GroupBy, MinK, Key
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression import Knn
create_isolated_database(client)
coll = client.create_collection(name="test_group_by")
# Test data: 12 records across 3 categories and 2 years
# Embeddings are designed so science docs are closest to query [1,0,0,0]
ids = [
"sci_2023_1",
"sci_2023_2",
"sci_2024_1",
"sci_2024_2",
"tech_2023_1",
"tech_2023_2",
"tech_2024_1",
"tech_2024_2",
"arts_2023_1",
"arts_2023_2",
"arts_2024_1",
"arts_2024_2",
]
embeddings = cast(
Embeddings,
[
# Science - closest to [1,0,0,0]
[1.0, 0.0, 0.0, 0.0], # sci_2023_1: score ~0.0
[0.9, 0.1, 0.0, 0.0], # sci_2023_2: score ~0.141
[0.8, 0.2, 0.0, 0.0], # sci_2024_1: score ~0.283
[0.7, 0.3, 0.0, 0.0], # sci_2024_2: score ~0.424
# Tech - farther from [1,0,0,0]
[0.0, 1.0, 0.0, 0.0], # tech_2023_1: score ~1.414
[0.0, 0.9, 0.1, 0.0], # tech_2023_2: score ~1.345
[0.0, 0.8, 0.2, 0.0], # tech_2024_1: score ~1.281
[0.0, 0.7, 0.3, 0.0], # tech_2024_2: score ~1.221
# Arts - farther from [1,0,0,0]
[0.0, 0.0, 1.0, 0.0], # arts_2023_1: score ~1.414
[0.0, 0.0, 0.9, 0.1], # arts_2023_2: score ~1.345
[0.0, 0.0, 0.8, 0.2], # arts_2024_1: score ~1.281
[0.0, 0.0, 0.7, 0.3], # arts_2024_2: score ~1.221
],
)
metadatas: Metadatas = [
{"category": "science", "year": 2023, "priority": 1},
{"category": "science", "year": 2023, "priority": 2},
{"category": "science", "year": 2024, "priority": 1},
{"category": "science", "year": 2024, "priority": 3},
{"category": "tech", "year": 2023, "priority": 2},
{"category": "tech", "year": 2023, "priority": 1},
{"category": "tech", "year": 2024, "priority": 1},
{"category": "tech", "year": 2024, "priority": 2},
{"category": "arts", "year": 2023, "priority": 3},
{"category": "arts", "year": 2023, "priority": 1},
{"category": "arts", "year": 2024, "priority": 2},
{"category": "arts", "year": 2024, "priority": 1},
]
documents = [f"doc_{id}" for id in ids]
coll.add(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
)
query = [1.0, 0.0, 0.0, 0.0]
# Test 1: Single key grouping - top 2 per category by score
# Expected: 2 best from each category (science, tech, arts)
# - science: sci_2023_1 (0.0), sci_2023_2 (0.141)
# - tech: tech_2024_2 (1.221), tech_2024_1 (1.281)
# - arts: arts_2024_2 (1.221), arts_2024_1 (1.281)
results1 = coll.search(
Search()
.rank(Knn(query=query, limit=12))
.group_by(GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2)))
.limit(12)
)
assert results1["ids"] is not None
result1_ids = results1["ids"][0]
assert len(result1_ids) == 6
expected1 = {
"sci_2023_1",
"sci_2023_2",
"tech_2024_2",
"tech_2024_1",
"arts_2024_2",
"arts_2024_1",
}
assert set(result1_ids) == expected1
# Test 2: Multiple key grouping - top 1 per (category, year) combination
# 6 groups: (science,2023), (science,2024), (tech,2023), (tech,2024), (arts,2023), (arts,2024)
results2 = coll.search(
Search()
.rank(Knn(query=query, limit=12))
.group_by(
GroupBy(
keys=[Key("category"), Key("year")],
aggregate=MinK(keys=Key.SCORE, k=1),
)
)
.limit(12)
)
assert results2["ids"] is not None
result2_ids = results2["ids"][0]
assert len(result2_ids) == 6
expected2 = {
"sci_2023_1",
"sci_2024_1",
"tech_2023_2",
"tech_2024_2",
"arts_2023_2",
"arts_2024_2",
}
assert set(result2_ids) == expected2
# Test 3: Multiple ranking keys - priority first, then score as tiebreaker
# Top 2 per category, sorted by priority (ascending), then score (ascending)
results3 = coll.search(
Search()
.rank(Knn(query=query, limit=12))
.group_by(
GroupBy(
keys=Key("category"),
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
)
)
.limit(12)
)
assert results3["ids"] is not None
result3_ids = results3["ids"][0]
assert len(result3_ids) == 6
expected3 = {
"sci_2023_1",
"sci_2024_1",
"tech_2024_1",
"tech_2023_2",
"arts_2024_2",
"arts_2023_2",
}
assert set(result3_ids) == expected3

View File

@@ -1983,7 +1983,9 @@ def test_sparse_vector_in_metadata_validation():
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
invalid_metadata_4 = {
"text": "non-numeric value",
"sparse_embedding": SparseVector(indices=[0, 1], values=[0.1, "not_a_number"]), # type: ignore
"sparse_embedding": SparseVector(
indices=[0, 1], values=[0.1, "not_a_number"]
), # type: ignore
}
# Test 7: Multiple sparse vectors in metadata
@@ -2683,6 +2685,59 @@ def test_rrf_to_dict() -> None:
print("All RRF tests passed!")
def test_group_by_serialization() -> None:
"""Test GroupBy, MinK, and MaxK serialization and deserialization."""
import pytest
from chromadb.execution.expression.operator import (
GroupBy,
MinK,
MaxK,
Key,
Aggregate,
)
# to_dict with OneOrMany keys
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
assert group_by.to_dict() == {
"keys": ["category"],
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
}
# to_dict with multiple keys and MaxK
group_by = GroupBy(
keys=[Key("year"), Key("category")],
aggregate=MaxK(keys=[Key.SCORE, Key("priority")], k=5),
)
assert group_by.to_dict() == {
"keys": ["year", "category"],
"aggregate": {"$max_k": {"keys": ["#score", "priority"], "k": 5}},
}
# Round-trip
original = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
assert GroupBy.from_dict(original.to_dict()).to_dict() == original.to_dict()
# Empty GroupBy serializes to {} and from_dict({}) returns default GroupBy
empty_group_by = GroupBy()
assert empty_group_by.to_dict() == {}
assert GroupBy.from_dict({}).to_dict() == {}
# Error cases
with pytest.raises(ValueError, match="requires 'keys' field"):
GroupBy.from_dict({"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}})
with pytest.raises(ValueError, match="requires 'aggregate' field"):
GroupBy.from_dict({"keys": ["category"]})
with pytest.raises(ValueError, match="keys cannot be empty"):
GroupBy.from_dict(
{"keys": [], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
)
with pytest.raises(ValueError, match="Unknown aggregate operator"):
Aggregate.from_dict({"$unknown": {"keys": ["#score"], "k": 3}})
# Expression API Tests - Testing dict support and from_dict methods
class TestSearchDictSupport:
"""Test Search class dict input support."""
@@ -2786,6 +2841,49 @@ class TestSearchDictSupport:
with pytest.raises(TypeError, match="select must be"):
Search(select=123)
def test_search_with_group_by(self):
"""Test Search accepts group_by as dict, object, and builder method."""
import pytest
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import GroupBy, MinK, Key
# Dict input
search = Search(
group_by={
"keys": ["category"],
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
}
)
assert isinstance(search._group_by, GroupBy)
# Object input and builder method
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
assert Search(group_by=group_by)._group_by is group_by
assert Search().group_by(group_by)._group_by.aggregate is not None
# Invalid inputs
with pytest.raises(TypeError, match="group_by must be"):
Search(group_by="invalid")
with pytest.raises(ValueError, match="requires 'aggregate' field"):
Search(group_by={"keys": ["category"]})
def test_search_group_by_serialization(self):
"""Test Search serializes group_by correctly."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import GroupBy, MinK, Key, Knn
# Without group_by - empty dict
search = Search().rank(Knn(query=[0.1, 0.2])).limit(10)
assert search.to_dict()["group_by"] == {}
# With group_by - has keys and aggregate
search = Search().group_by(
GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
)
result = search.to_dict()["group_by"]
assert result["keys"] == ["category"]
assert result["aggregate"] == {"$min_k": {"keys": ["#score"], "k": 3}}
class TestWhereFromDict:
"""Test Where.from_dict() conversion."""
@@ -3310,3 +3408,27 @@ class TestRoundTripConversion:
return d1 == d2
assert compare_search_dicts(new_dict, search_dict)
def test_search_round_trip_with_group_by(self):
"""Test Search round-trip with group_by."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import Key, GroupBy, MinK
original = Search(
where=Key("status") == "active",
group_by=GroupBy(
keys=[Key("category")],
aggregate=MinK(keys=[Key.SCORE], k=3),
),
)
# Verify group_by round-trip
search_dict = original.to_dict()
assert search_dict["group_by"]["keys"] == ["category"]
assert search_dict["group_by"]["aggregate"] == {
"$min_k": {"keys": ["#score"], "k": 3}
}
# Reconstruct and compare group_by
restored = Search(group_by=GroupBy.from_dict(search_dict["group_by"]))
assert restored.to_dict()["group_by"] == search_dict["group_by"]

View File

@@ -1,10 +1,11 @@
import asyncio
from typing import Any, Callable, Generator, cast
from unittest.mock import patch
from typing import Any, Callable, Generator, cast, Dict, Tuple
from unittest.mock import MagicMock, patch
import chromadb
from chromadb.config import Settings
from chromadb.config import Settings, System
from chromadb.api import ClientAPI
import chromadb.server.fastapi
from chromadb.api.fastapi import FastAPI
import pytest
import tempfile
import os
@@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings(
str(e)
== "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
)
def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
captured: Dict[str, Any] = {}
# takes any positional args to match httpx.Client
def factory(*_: Any, **kwargs: Any) -> Any:
captured.update(kwargs)
session = MagicMock()
session.headers = {}
return session
return factory, captured
def test_fastapi_uses_http_limits_from_settings() -> None:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host="localhost",
chroma_server_http_port=9000,
chroma_server_ssl_verify=True,
chroma_http_keepalive_secs=12.5,
chroma_http_max_connections=64,
chroma_http_max_keepalive_connections=16,
)
system = System(settings)
factory, captured = make_sync_client_factory()
with patch.object(FastAPI, "require", side_effect=[MagicMock(), MagicMock()]):
with patch("chromadb.api.fastapi.httpx.Client", side_effect=factory):
api = FastAPI(system)
api.stop()
limits = captured["limits"]
assert limits.keepalive_expiry == 12.5
assert limits.max_connections == 64
assert limits.max_keepalive_connections == 16
assert captured["timeout"] is None
assert captured["verify"] is True

View File

@@ -189,3 +189,21 @@ def test_runtime_dependencies() -> None:
assert data.starts == ["D", "C"]
system.stop()
assert data.stops == ["C", "D"]
def test_http_client_setting_defaults() -> None:
settings = Settings()
assert settings.chroma_http_keepalive_secs == 40.0
assert settings.chroma_http_max_connections is None
assert settings.chroma_http_max_keepalive_connections is None
def test_http_client_setting_overrides() -> None:
settings = Settings(
chroma_http_keepalive_secs=5.5,
chroma_http_max_connections=123,
chroma_http_max_keepalive_connections=17,
)
assert settings.chroma_http_keepalive_secs == 5.5
assert settings.chroma_http_max_connections == 123
assert settings.chroma_http_max_keepalive_connections == 17
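For reference, a minimal sketch of how these pooling settings could be supplied to a client, assuming the standard chromadb.HttpClient constructor and a server on localhost:8000 (the values are illustrative):

import chromadb
from chromadb.config import Settings

client = chromadb.HttpClient(
    host="localhost",
    port=8000,
    settings=Settings(
        chroma_http_keepalive_secs=30.0,  # keep idle connections alive for 30 seconds
        chroma_http_max_connections=64,  # cap total pooled connections
        chroma_http_max_keepalive_connections=16,  # cap idle keep-alive connections
    ),
)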

View File

@@ -1,5 +1,5 @@
import pytest
from typing import List, Any, Callable
from typing import List, Any, Callable, Dict
from jsonschema import ValidationError
from unittest.mock import MagicMock, create_autospec
from chromadb.utils.embedding_functions.schemas import (
@@ -7,7 +7,10 @@ from chromadb.utils.embedding_functions.schemas import (
load_schema,
get_available_schemas,
)
from chromadb.utils.embedding_functions import known_embedding_functions
from chromadb.utils.embedding_functions import (
known_embedding_functions,
sparse_known_embedding_functions,
)
from chromadb.api.types import Documents, Embeddings
from pytest import MonkeyPatch
@@ -143,3 +146,306 @@ class TestEmbeddingFunctionSchemas:
if schema.get("additionalProperties", True) is False:
with pytest.raises(ValidationError):
validate_config_schema(test_config, schema_name)
def _create_valid_config_from_schema(
self, schema: Dict[str, Any]
) -> Dict[str, Any]:
"""Create a valid config from a schema by filling in required fields"""
config: Dict[str, Any] = {}
if "required" in schema and "properties" in schema:
for field in schema["required"]:
if field in schema["properties"]:
field_schema = schema["properties"][field]
config[field] = self._get_value_from_field_schema(field_schema)
return config
def _get_value_from_field_schema(self, field_schema: Dict[str, Any]) -> Any:
"""Get a valid value from a field schema"""
# Handle enums - use first enum value
if "enum" in field_schema:
return field_schema["enum"][0]
# Handle type (could be a list or single value)
field_type = field_schema.get("type")
if field_type is None:
return "dummy" # Fallback if no type specified
if isinstance(field_type, list):
# If null is in the type list, prefer non-null type
non_null_types = [t for t in field_type if t != "null"]
field_type = non_null_types[0] if non_null_types else field_type[0]
if field_type == "object":
# Handle nested objects
nested_config = {}
if "properties" in field_schema:
nested_required = field_schema.get("required", [])
for prop in nested_required:
if prop in field_schema["properties"]:
nested_config[prop] = self._get_value_from_field_schema(
field_schema["properties"][prop]
)
return nested_config if nested_config else {}
if field_type == "array":
# Return empty array for arrays
return []
# Use the existing dummy value method for primitive types
return self._get_dummy_value(field_type)
def _has_custom_validation(self, ef_class: Any) -> bool:
"""Check if validate_config actually validates (not just base implementation)"""
try:
# Try with an obviously invalid config - if it doesn't raise, it's base implementation
invalid_config = {"__invalid_test_config__": True}
try:
ef_class.validate_config(invalid_config)
# If we get here without exception, it's using base implementation
return False
except (ValidationError, ValueError, FileNotFoundError):
# If it raises any validation-related error, it's actually validating
return True
except Exception:
# Any other exception means it's trying to validate (e.g., schema not found)
return True
def _setup_env_vars_for_ef(
self, ef_name: str, mock_common_deps: MonkeyPatch
) -> None:
"""Set up environment variables needed for embedding function instantiation"""
# Map of embedding function names to their default API key environment variable names
api_key_env_vars = {
"cohere": "CHROMA_COHERE_API_KEY",
"openai": "CHROMA_OPENAI_API_KEY",
"huggingface": "CHROMA_HUGGINGFACE_API_KEY",
"huggingface_server": "CHROMA_HUGGINGFACE_API_KEY",
"google_palm": "CHROMA_GOOGLE_PALM_API_KEY",
"google_generative_ai": "CHROMA_GOOGLE_GENAI_API_KEY",
"google_vertex": "CHROMA_GOOGLE_VERTEX_API_KEY",
"jina": "CHROMA_JINA_API_KEY",
"mistral": "MISTRAL_API_KEY",
"morph": "MORPH_API_KEY",
"voyageai": "CHROMA_VOYAGE_API_KEY",
"cloudflare_workers_ai": "CHROMA_CLOUDFLARE_API_KEY",
"together_ai": "CHROMA_TOGETHER_AI_API_KEY",
"baseten": "CHROMA_BASETEN_API_KEY",
"roboflow": "CHROMA_ROBOFLOW_API_KEY",
"amazon_bedrock": "AWS_ACCESS_KEY_ID", # AWS uses different env vars
"chroma-cloud-qwen": "CHROMA_API_KEY",
# Sparse embedding functions
"chroma-cloud-splade": "CHROMA_API_KEY",
}
# Set API key environment variable if needed
if ef_name in api_key_env_vars:
mock_common_deps.setenv(api_key_env_vars[ef_name], "test-api-key")
# Special cases that need additional environment variables
if ef_name == "amazon_bedrock":
mock_common_deps.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
mock_common_deps.setenv("AWS_REGION", "us-east-1")
def _create_ef_instance(
self, ef_name: str, ef_class: Any, mock_common_deps: MonkeyPatch
) -> Any:
"""Create an embedding function instance, handling special cases"""
# Set up environment variables first
self._setup_env_vars_for_ef(ef_name, mock_common_deps)
# Mock missing modules that are imported inside __init__ methods
import sys
# Create mock modules
mock_pil = MagicMock()
mock_pil_image = MagicMock()
mock_google_genai = MagicMock()
mock_vertexai = MagicMock()
mock_vertexai_lm = MagicMock()
mock_boto3 = MagicMock()
mock_jina = MagicMock()
mock_mistralai = MagicMock()
# Mock boto3.Session for amazon_bedrock
mock_boto3_session = MagicMock()
mock_session_instance = MagicMock()
mock_session_instance.region_name = "us-east-1"
mock_session_instance.profile_name = None
mock_session_instance.client.return_value = MagicMock()
mock_boto3_session.return_value = mock_session_instance
mock_boto3.Session = mock_boto3_session
# Mock vertexai.init and TextEmbeddingModel
mock_text_embedding_model = MagicMock()
mock_text_embedding_model.from_pretrained.return_value = MagicMock()
mock_vertexai_lm.TextEmbeddingModel = mock_text_embedding_model
mock_vertexai.language_models = mock_vertexai_lm
mock_vertexai.init = MagicMock()
# Mock google.generativeai - need to set up google module first
mock_google = MagicMock()
mock_google_genai.configure = MagicMock() # For palm.configure()
mock_google_genai.GenerativeModel = MagicMock(return_value=MagicMock())
mock_google.generativeai = mock_google_genai
# Mock jina Client
mock_jina.Client = MagicMock()
# Mock mistralai
mock_mistral_client = MagicMock()
mock_mistral_client.return_value.embeddings.create.return_value.data = [
MagicMock(embedding=[0.1, 0.2, 0.3])
]
mock_mistralai.Mistral = mock_mistral_client
# Add missing modules to sys.modules using monkeypatch
modules_to_mock = {
"PIL": mock_pil,
"PIL.Image": mock_pil_image,
"google": mock_google,
"google.generativeai": mock_google_genai,
"vertexai": mock_vertexai,
"vertexai.language_models": mock_vertexai_lm,
"boto3": mock_boto3,
"jina": mock_jina,
"mistralai": mock_mistralai,
}
for module_name, mock_module in modules_to_mock.items():
mock_common_deps.setitem(sys.modules, module_name, mock_module)
# Special cases that need additional arguments
if ef_name == "cloudflare_workers_ai":
return ef_class(
model_name="test-model",
account_id="test-account-id",
)
elif ef_name == "baseten":
# Baseten needs api_key explicitly passed even with env var
return ef_class(
api_key="test-api-key",
api_base="https://test.api.baseten.co",
)
elif ef_name == "amazon_bedrock":
# Amazon Bedrock needs a boto3 session - create a mock session
# boto3 is already mocked in sys.modules above
mock_session = mock_boto3.Session(region_name="us-east-1")
return ef_class(
session=mock_session,
model_name="amazon.titan-embed-text-v1",
)
elif ef_name == "huggingface_server":
return ef_class(url="http://localhost:8080")
elif ef_name == "google_vertex":
return ef_class(project_id="test-project", region="us-central1")
elif ef_name == "mistral":
return ef_class(model="mistral-embed")
elif ef_name == "roboflow":
return ef_class() # No model_name needed
elif ef_name == "chroma-cloud-qwen":
from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
ChromaCloudQwenEmbeddingModel,
)
return ef_class(
model=ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
task="nl_to_code",
)
else:
# Try with no args first
try:
return ef_class()
except Exception:
# If that fails, try with common minimal args
return ef_class(model_name="test-model")
@pytest.mark.parametrize("ef_name", get_embedding_function_names())
def test_validate_config_with_schema(
self,
ef_name: str,
mock_embeddings: Callable[[Documents], Embeddings],
mock_common_deps: MonkeyPatch,
) -> None:
"""Test that validate_config works correctly with actual configs from embedding functions"""
ef_class = known_embedding_functions[ef_name]
# Skip if the embedding function doesn't have a validate_config method
if not hasattr(ef_class, "validate_config"):
pytest.skip(f"{ef_name} does not have validate_config method")
# Check if it's callable (static methods are callable on the class)
if not callable(getattr(ef_class, "validate_config", None)):
pytest.skip(f"{ef_name} validate_config is not callable")
# Skip if using base implementation (doesn't actually validate)
if not self._has_custom_validation(ef_class):
pytest.skip(
f"{ef_name} uses base validate_config implementation (no validation)"
)
# Create a real instance to get the actual config
# We'll mock __call__ to avoid needing to actually generate embeddings
try:
ef_instance = self._create_ef_instance(ef_name, ef_class, mock_common_deps)
except Exception as e:
pytest.skip(
f"{ef_name} requires arguments that we cannot provide without external deps: {e}"
)
# Mock only __call__ to avoid needing to actually generate embeddings
mock_call = MagicMock(return_value=mock_embeddings(["test"]))
mock_common_deps.setattr(ef_instance, "__call__", mock_call)
# Get the actual config from the embedding function (this uses the real get_config method)
config = ef_instance.get_config()
# Filter out None values - optional fields with None shouldn't be included in validation
# This matches common JSON schema practice where optional fields are omitted rather than null
config = {k: v for k, v in config.items() if v is not None}
# Validate the actual config using the embedding function's validate_config method
ef_class.validate_config(config)
def test_validate_config_sparse_embedding_functions(
self,
mock_embeddings: Callable[[Documents], Embeddings],
mock_common_deps: MonkeyPatch,
) -> None:
"""Test validate_config for sparse embedding functions with actual configs"""
for ef_name, ef_class in sparse_known_embedding_functions.items():
# Skip if the embedding function doesn't have a validate_config method
if not hasattr(ef_class, "validate_config"):
continue
# Check if it's callable (static methods are callable on the class)
if not callable(getattr(ef_class, "validate_config", None)):
continue
# Skip if using base implementation (doesn't actually validate)
if not self._has_custom_validation(ef_class):
continue
# Create a real instance to get the actual config
# We'll mock __call__ to avoid needing to actually generate embeddings
try:
ef_instance = self._create_ef_instance(
ef_name, ef_class, mock_common_deps
)
except Exception:
continue # Skip if we can't create instance
# Mock only __call__ to avoid needing to actually generate embeddings
mock_call = MagicMock(return_value=mock_embeddings(["test"]))
mock_common_deps.setattr(ef_instance, "__call__", mock_call)
# Get the actual config from the embedding function (this uses the real get_config method)
config = ef_instance.get_config()
# Filter out None values - optional fields with None shouldn't be included in validation
# This matches common JSON schema practice where optional fields are omitted rather than null
config = {k: v for k, v in config.items() if v is not None}
# Validate the actual config using the embedding function's validate_config method
ef_class.validate_config(config)

View File

@@ -23,6 +23,7 @@ from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
GoogleGenaiEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
@@ -68,6 +69,10 @@ from chromadb.utils.embedding_functions.mistral_embedding_function import (
from chromadb.utils.embedding_functions.morph_embedding_function import (
MorphEmbeddingFunction,
)
from chromadb.utils.embedding_functions.nomic_embedding_function import (
NomicEmbeddingFunction,
NomicQueryConfig,
)
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
HuggingFaceSparseEmbeddingFunction,
)
@@ -98,11 +103,13 @@ _all_classes: Set[str] = {
"GooglePalmEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"GoogleGenaiEmbeddingFunction",
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"NomicEmbeddingFunction",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OpenCLIPEmbeddingFunction",
@@ -137,11 +144,13 @@ known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignor
"google_palm": GooglePalmEmbeddingFunction,
"google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
"google_vertex": GoogleVertexEmbeddingFunction,
"google_genai": GoogleGenaiEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"mistral": MistralEmbeddingFunction,
"morph": MorphEmbeddingFunction,
"nomic": NomicEmbeddingFunction,
"voyageai": VoyageAIEmbeddingFunction,
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
"open_clip": OpenCLIPEmbeddingFunction,
@@ -259,12 +268,15 @@ __all__ = [
"GooglePalmEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"GoogleGenaiEmbeddingFunction",
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"JinaQueryConfig",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"NomicEmbeddingFunction",
"NomicQueryConfig",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OpenCLIPEmbeddingFunction",

View File

@@ -2,6 +2,7 @@ import os
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import Dict, Any, Optional, List
from chromadb.api.types import Space
import warnings
@@ -35,12 +36,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
if os.getenv("BASETEN_API_KEY") is not None:
self.api_key_env_var = "BASETEN_API_KEY"
else:
self.api_key_env_var = api_key_env_var
# Prioritize api_key argument, then environment variable
resolved_api_key = api_key or os.getenv(api_key_env_var)
resolved_api_key = api_key or os.getenv(self.api_key_env_var)
if not resolved_api_key:
raise ValueError(
f"API key not provided and {api_key_env_var} environment variable is not set."
f"API key not provided and {self.api_key_env_var} environment variable is not set."
)
self.api_key = resolved_api_key
if not api_base:
@@ -96,3 +101,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
api_base=api_base,
api_key_env_var=api_key_env_var,
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "baseten")

View File

@@ -231,4 +231,4 @@ class Bm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "bm25")
validate_config_schema(config, "bm25")

View File

@@ -23,12 +23,32 @@ DEFAULT_TOKEN_MAX_LENGTH = 40
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
class _HashedToken:
__slots__ = ("hash", "label")
def __init__(self, hash: int, label: Optional[str]):
self.hash = hash
self.label = label
def __hash__(self) -> int:
return self.hash
def __eq__(self, other: object) -> bool:
if not isinstance(other, _HashedToken):
return NotImplemented
return self.hash == other.hash
def __lt__(self, other: "_HashedToken") -> bool:
return self.hash < other.hash
class ChromaBm25Config(TypedDict, total=False):
k: float
b: float
avg_doc_length: float
token_max_length: int
stopwords: List[str]
include_tokens: bool
class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@@ -39,6 +59,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
stopwords: Optional[Iterable[str]] = None,
include_tokens: bool = False,
) -> None:
"""Initialize the BM25 sparse embedding function."""
@@ -46,38 +67,51 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
self.b = float(b)
self.avg_doc_length = float(avg_doc_length)
self.token_max_length = int(token_max_length)
self.include_tokens = bool(include_tokens)
if stopwords is not None:
self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
stopword_list: Iterable[str] = self.stopwords
self._stopword_list: Iterable[str] = self.stopwords
else:
self.stopwords = None
stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
self._stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
stemmer = get_english_stemmer()
self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
self._hasher = Murmur3AbsHasher()
def _encode(self, text: str) -> SparseVector:
tokens = self._tokenizer.tokenize(text)
stemmer = get_english_stemmer()
tokenizer = Bm25Tokenizer(stemmer, self._stopword_list, self.token_max_length)
tokens = tokenizer.tokenize(text)
if not tokens:
return SparseVector(indices=[], values=[])
doc_len = float(len(tokens))
counts = Counter(self._hasher.hash(token) for token in tokens)
counts = Counter(
_HashedToken(
self._hasher.hash(token), token if self.include_tokens else None
)
for token in tokens
)
indices = sorted(counts.keys())
sorted_keys = sorted(counts.keys())
indices: List[int] = []
values: List[float] = []
for idx in indices:
tf = float(counts[idx])
labels: Optional[List[str]] = [] if self.include_tokens else None
for key in sorted_keys:
tf = float(counts[key])
denominator = tf + self.k * (
1 - self.b + (self.b * doc_len) / self.avg_doc_length
)
score = tf * (self.k + 1) / denominator
values.append(score)
return SparseVector(indices=indices, values=values)
indices.append(key.hash)
values.append(score)
if labels is not None and key.label is not None:
labels.append(key.label)
return SparseVector(indices=indices, values=values, labels=labels)
def __call__(self, input: Documents) -> SparseVectors:
sparse_vectors: SparseVectors = []
@@ -99,7 +133,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@staticmethod
def build_from_config(
config: Dict[str, Any]
config: Dict[str, Any],
) -> "SparseEmbeddingFunction[Documents]":
return ChromaBm25EmbeddingFunction(
k=config.get("k", DEFAULT_K),
@@ -107,6 +141,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
stopwords=config.get("stopwords"),
include_tokens=config.get("include_tokens", False),
)
def get_config(self) -> Dict[str, Any]:
@@ -115,6 +150,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
"b": self.b,
"avg_doc_length": self.avg_doc_length,
"token_max_length": self.token_max_length,
"include_tokens": self.include_tokens,
}
if self.stopwords is not None:
@@ -125,7 +161,14 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
mutable_keys = {
"k",
"b",
"avg_doc_length",
"token_max_length",
"stopwords",
"include_tokens",
}
for key in new_config:
if key not in mutable_keys:
raise ValueError(f"Updating '{key}' is not supported for {NAME}")

View File

@@ -1,8 +1,9 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Union
from typing import List, Dict, Any, Union, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
from enum import Enum
@@ -32,7 +33,7 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model: ChromaCloudQwenEmbeddingModel,
task: str,
task: Optional[str],
instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
api_key_env_var: str = "CHROMA_API_KEY",
):
@@ -41,7 +42,8 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
Args:
model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
task (str): The task for which embeddings are being generated.
task (str, optional): The task for which embeddings are being generated. If None or empty,
empty instructions will be used for both documents and queries.
instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
api_key_env_var (str, optional): Environment variable name that contains your API key.
@@ -55,9 +57,18 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
)
self.api_key_env_var = api_key_env_var
# First, try to get API key from environment variable
self.api_key = os.getenv(api_key_env_var)
# If not found in env var, try to get it from existing client instances
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
SharedSystemClient = _get_shared_system_client()
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
# Raise error if still no API key found
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {api_key_env_var} "
f"or in any existing client instances"
)
self.model = model
self.task = task
@@ -102,10 +113,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
instruction = ""
if self.task and self.task in self.instructions:
instruction = self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.DOCUMENTS
],
]
payload: Dict[str, Union[str, Documents]] = {
"instructions": instruction,
"texts": input,
}
@@ -120,10 +135,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
instruction = ""
if self.task and self.task in self.instructions:
instruction = self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.QUERY
],
]
payload: Dict[str, Union[str, Documents]] = {
"instructions": instruction,
"texts": input,
}

View File

@@ -1,15 +1,16 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseVector,
SparseVectors,
Documents,
)
from typing import Dict, Any
from typing import Dict, Any, List, Optional
from enum import Enum
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
from chromadb.base_types import SparseVector
import os
from typing import Union
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
class ChromaCloudSpladeEmbeddingModel(Enum):
@@ -21,6 +22,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
self,
api_key_env_var: str = "CHROMA_API_KEY",
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
include_tokens: bool = False,
):
"""
Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -36,12 +38,20 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
self.api_key_env_var = api_key_env_var
# First, try to get API key from environment variable
self.api_key = os.getenv(self.api_key_env_var)
# If not found in env var, try to get it from existing client instances
if not self.api_key:
SharedSystemClient = _get_shared_system_client()
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
# Raise error if still no API key found
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {self.api_key_env_var}"
f"API key not found in environment variable {self.api_key_env_var} "
f"or in any existing client instances"
)
self.model = model
self.include_tokens = bool(include_tokens)
self._api_url = "https://embed.trychroma.com/embed_sparse"
self._session = httpx.Client()
self._session.headers.update(
@@ -80,6 +90,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
"texts": list(input),
"task": "",
"target": "",
"fetch_tokens": "true" if self.include_tokens is True else "false",
}
try:
@@ -113,13 +124,17 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
if isinstance(emb, dict):
indices = emb.get("indices", [])
values = emb.get("values", [])
raw_labels = emb.get("labels") if self.include_tokens else None
labels: Optional[List[str]] = raw_labels if raw_labels else None
else:
# Already a SparseVector, extract its data
assert isinstance(emb, SparseVector)
indices = emb.indices
values = emb.values
labels = emb.labels if self.include_tokens else None
normalized_vectors.append(
normalize_sparse_vector(indices=indices, values=values)
normalize_sparse_vector(indices=indices, values=values, labels=labels)
)
return normalized_vectors
@@ -141,18 +156,25 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
return ChromaCloudSpladeEmbeddingFunction(
api_key_env_var=api_key_env_var,
model=ChromaCloudSpladeEmbeddingModel(model),
include_tokens=config.get("include_tokens", False),
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}
return {
"api_key_env_var": self.api_key_env_var,
"model": self.model.value,
"include_tokens": self.include_tokens,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model" in new_config:
raise ValueError(
"model cannot be changed after the embedding function has been initialized"
)
immutable_keys = {"include_tokens", "model"}
for key in immutable_keys:
if key in new_config and new_config[key] != old_config.get(key):
raise ValueError(
f"Updating '{key}' is not supported for chroma-cloud-splade"
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:

View File

@@ -53,12 +53,19 @@ class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
)
self.model_name = model_name
self.account_id = account_id
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("CLOUDFLARE_API_KEY") is not None:
self.api_key_env_var = "CLOUDFLARE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
self.gateway_id = gateway_id
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
if self.gateway_id:
self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"

View File

@@ -43,10 +43,16 @@ class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("COHERE_API_KEY") is not None:
self.api_key_env_var = "COHERE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name

View File

@@ -7,6 +7,150 @@ from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings
class GoogleGenaiEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model_name: str,
vertexai: Optional[bool] = None,
project: Optional[str] = None,
location: Optional[str] = None,
api_key_env_var: str = "GOOGLE_API_KEY",
):
"""
Initialize the GoogleGenaiEmbeddingFunction.
Args:
model_name (str): The name of the model to use for text embeddings.
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google GenAI API.
Defaults to "GOOGLE_API_KEY".
"""
try:
import google.genai as genai
except ImportError:
raise ValueError(
"The google-genai python package is not installed. Please install it with `pip install google-genai`"
)
self.model_name = model_name
self.api_key_env_var = api_key_env_var
self.vertexai = vertexai
self.project = project
self.location = location
self.api_key = os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.client = genai.Client(
api_key=self.api_key, vertexai=vertexai, project=project, location=location
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents or images to generate embeddings for.
Returns:
Embeddings for the documents.
"""
if not input:
raise ValueError("Input documents cannot be empty")
if not isinstance(input, (list, tuple)):
raise ValueError("Input must be a list or tuple of documents")
if not all(isinstance(doc, str) for doc in input):
raise ValueError("All input documents must be strings")
try:
response = self.client.models.embed_content(
model=self.model_name, contents=input
)
except Exception as e:
raise ValueError(f"Failed to generate embeddings: {str(e)}") from e
# Validate response structure
if not hasattr(response, "embeddings") or not response.embeddings:
raise ValueError("No embeddings returned from the API")
embeddings_list = []
for ce in response.embeddings:
if not hasattr(ce, "values"):
raise ValueError("Malformed embedding response: missing 'values'")
embeddings_list.append(np.array(ce.values, dtype=np.float32))
return cast(Embeddings, embeddings_list)
@staticmethod
def name() -> str:
return "google_genai"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model_name = config.get("model_name")
vertexai = config.get("vertexai")
project = config.get("project")
location = config.get("location")
if model_name is None:
raise ValueError("The model name is required.")
return GoogleGenaiEmbeddingFunction(
model_name=model_name,
vertexai=vertexai,
project=project,
location=location,
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"vertexai": self.vertexai,
"project": self.project,
"location": self.location,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
if "vertexai" in new_config:
raise ValueError(
"The vertexai cannot be changed after the embedding function has been initialized."
)
if "project" in new_config:
raise ValueError(
"The project cannot be changed after the embedding function has been initialized."
)
if "location" in new_config:
raise ValueError(
"The location cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "google_genai")
class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
@@ -38,10 +182,16 @@ class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("GOOGLE_API_KEY") is not None:
self.api_key_env_var = "GOOGLE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
@@ -154,10 +304,16 @@ class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("GOOGLE_API_KEY") is not None:
self.api_key_env_var = "GOOGLE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
self.task_type = task_type
@@ -289,10 +445,16 @@ class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("GOOGLE_API_KEY") is not None:
self.api_key_env_var = "GOOGLE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
self.project_id = project_id
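For the GoogleGenaiEmbeddingFunction added above, a minimal usage sketch; the model name and project values are illustrative, and GOOGLE_API_KEY must be set in the environment:

from chromadb.utils.embedding_functions import GoogleGenaiEmbeddingFunction

# Gemini Developer API path: the key is read from GOOGLE_API_KEY by default.
ef = GoogleGenaiEmbeddingFunction(model_name="gemini-embedding-001")
embeddings = ef(["hello world", "chroma embeddings"])

# Vertex AI path: route requests through a Google Cloud project and region.
vertex_ef = GoogleGenaiEmbeddingFunction(
    model_name="gemini-embedding-001",
    vertexai=True,
    project="my-gcp-project",
    location="us-central1",
)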

View File

@@ -40,10 +40,16 @@ class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("HUGGINGFACE_API_KEY") is not None:
self.api_key_env_var = "HUGGINGFACE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
@@ -160,6 +166,9 @@ class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
self.url = url
self.api_key_env_var = api_key_env_var
if os.getenv("HUGGINGFACE_API_KEY") is not None:
self.api_key_env_var = "HUGGINGFACE_API_KEY"
if self.api_key_env_var is not None:
self.api_key = api_key or os.getenv(self.api_key_env_var)
else:

View File

@@ -81,10 +81,16 @@ class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("JINA_API_KEY") is not None:
self.api_key_env_var = "JINA_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name

Some files were not shown because too many files have changed in this diff.