增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -33,10 +33,14 @@ from chromadb.execution.expression import ( # noqa: F401, F403
Sub,
Sum,
Val,
Aggregate,
MinK,
MaxK,
GroupBy,
)
from abc import ABC, abstractmethod
from typing import Sequence, Optional, List, Dict, Any
from typing import Sequence, Optional, List, Dict, Any, Tuple
from uuid import UUID
from overrides import override
@@ -824,7 +828,7 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to a collection.
Args:
@@ -837,14 +841,40 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
database: The database name
Returns:
AttachedFunction: Object representing the attached function
Tuple of (AttachedFunction, created) where created is True if newly created,
False if already existed (idempotent request)
"""
pass
@abstractmethod
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Get an attached function by name for a specific collection.
Args:
name: Name of the attached function
input_collection_id: The collection ID
tenant: The tenant name
database: The database name
Returns:
AttachedFunction: The attached function object
Raises:
NotFoundError: If the attached function doesn't exist
"""
pass
@abstractmethod
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
@@ -852,7 +882,8 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
"""Detach a function and prevent any further runs.
Args:
attached_function_id: ID of the attached function to remove
name: Name of the attached function to remove
input_collection_id: ID of the input collection
delete_output: Whether to also delete the output collection
tenant: The tenant name
database: The database name

View File

@@ -2,7 +2,7 @@ import asyncio
from uuid import UUID
import urllib.parse
import orjson
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
from typing import Any, Mapping, Optional, cast, Tuple, Sequence, Dict, List
import logging
import httpx
from overrides import override
@@ -130,16 +130,23 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
+ " (https://github.com/chroma-core/chroma)"
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._clients[loop_hash] = httpx.AsyncClient(
timeout=None,
headers=headers,
verify=self._settings.chroma_server_ssl_verify or False,
limits=limits,
limits=self.http_limits,
)
return self._clients[loop_hash]
@override
def get_request_headers(self) -> Mapping[str, str]:
return dict(self._get_client().headers)
@override
def get_api_url(self) -> str:
return self._api_url
async def _make_request(
self, method: str, path: str, **kwargs: Dict[str, Any]
) -> Any:
@@ -527,7 +534,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
@@ -723,7 +730,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,

View File

@@ -1,19 +1,51 @@
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, Mapping, Optional, TypeVar
from urllib.parse import quote, urlparse, urlunparse
import logging
import orjson as json
import httpx
import chromadb.errors as errors
from chromadb.config import Settings
from chromadb.config import Component, Settings, System
logger = logging.getLogger(__name__)
class BaseHTTPClient:
# inherits from Component so that it can create an init function to use system
# this way it can build limits from the settings in System
class BaseHTTPClient(Component):
_settings: Settings
pre_flight_checks: Any = None
keepalive_secs: int = 40
DEFAULT_KEEPALIVE_SECS: float = 40.0
def __init__(self, system: System):
super().__init__(system)
self._settings = system.settings
keepalive_setting = self._settings.chroma_http_keepalive_secs
self.keepalive_secs: Optional[float] = (
keepalive_setting
if keepalive_setting is not None
else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
)
self._http_limits = self._build_limits()
def _build_limits(self) -> httpx.Limits:
limit_kwargs: Dict[str, Any] = {}
if self.keepalive_secs is not None:
limit_kwargs["keepalive_expiry"] = self.keepalive_secs
max_connections = self._settings.chroma_http_max_connections
if max_connections is not None:
limit_kwargs["max_connections"] = max_connections
max_keepalive_connections = self._settings.chroma_http_max_keepalive_connections
if max_keepalive_connections is not None:
limit_kwargs["max_keepalive_connections"] = max_keepalive_connections
return httpx.Limits(**limit_kwargs)
@property
def http_limits(self) -> httpx.Limits:
return self._http_limits
@staticmethod
def _validate_host(host: str) -> None:
@@ -103,3 +135,11 @@ class BaseHTTPClient:
if trace_id:
raise Exception(f"{resp.text} (trace ID: {trace_id})")
raise (Exception(resp.text))
def get_request_headers(self) -> Mapping[str, str]:
"""Return headers used for HTTP requests."""
return {}
def get_api_url(self) -> str:
"""Return the API URL for this client."""
return ""

View File

@@ -1,6 +1,6 @@
import orjson
import logging
from typing import Any, Dict, Optional, cast, Tuple, List
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
from typing import Sequence
from uuid import UUID
import httpx
@@ -79,8 +79,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
default_api_path=system.settings.chroma_server_api_default_path,
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._session = httpx.Client(timeout=None, limits=limits)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(
timeout=None,
limits=self.http_limits,
verify=self._settings.chroma_server_ssl_verify,
)
else:
self._session = httpx.Client(timeout=None, limits=self.http_limits)
self._header = system.settings.chroma_server_headers or {}
self._header["Content-Type"] = "application/json"
@@ -90,8 +96,6 @@ class FastAPI(BaseHTTPClient, ServerAPI):
+ " (https://github.com/chroma-core/chroma)"
)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
if self._header is not None:
self._session.headers.update(self._header)
@@ -101,6 +105,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
for header, value in _headers.items():
self._session.headers[header] = value.get_secret_value()
@override
def get_request_headers(self) -> Mapping[str, str]:
return dict(self._session.headers)
@override
def get_api_url(self) -> str:
return self._api_url
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
# If the request has json in kwargs, use orjson to serialize it,
# remove it from kwargs, and add it to the content parameter
@@ -492,7 +504,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
@@ -700,7 +712,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
@@ -761,7 +773,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to a collection."""
resp_json = self._make_request(
"post",
@@ -774,23 +786,56 @@ class FastAPI(BaseHTTPClient, ServerAPI):
},
)
return AttachedFunction(
attached_function = AttachedFunction(
client=self,
id=UUID(resp_json["attached_function"]["id"]),
name=resp_json["attached_function"]["name"],
function_id=resp_json["attached_function"]["function_id"],
function_name=resp_json["attached_function"]["function_name"],
input_collection_id=input_collection_id,
output_collection=output_collection,
params=params,
tenant=tenant,
database=database,
)
created = resp_json.get(
"created", True
) # Default to True for backwards compatibility
return (attached_function, created)
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Get an attached function by name for a specific collection."""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/{name}",
)
af = resp_json["attached_function"]
return AttachedFunction(
client=self,
id=UUID(af["id"]),
name=af["name"],
function_name=af["function_name"],
input_collection_id=input_collection_id,
output_collection=af["output_collection"],
params=af.get("params"),
tenant=tenant,
database=database,
)
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
@@ -798,7 +843,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
"""Detach a function and prevent any further runs."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
json={
"delete_output": delete_output,
},

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Optional, Dict, Any
from uuid import UUID
import json
if TYPE_CHECKING:
from chromadb.api import ServerAPI # noqa: F401
@@ -13,7 +14,7 @@ class AttachedFunction:
client: "ServerAPI",
id: UUID,
name: str,
function_id: str,
function_name: str,
input_collection_id: UUID,
output_collection: str,
params: Optional[Dict[str, Any]],
@@ -26,7 +27,7 @@ class AttachedFunction:
client: The API client
id: Unique identifier for this attached function
name: Name of this attached function instance
function_id: The function identifier (e.g., "record_counter")
function_name: The function name (e.g., "record_counter", "statistics")
input_collection_id: ID of the input collection
output_collection: Name of the output collection
params: Function-specific parameters
@@ -36,7 +37,7 @@ class AttachedFunction:
self._client = client
self._id = id
self._name = name
self._function_id = function_id
self._function_name = function_name
self._input_collection_id = input_collection_id
self._output_collection = output_collection
self._params = params
@@ -54,9 +55,9 @@ class AttachedFunction:
return self._name
@property
def function_id(self) -> str:
"""The function identifier."""
return self._function_id
def function_name(self) -> str:
"""The function name."""
return self._function_name
@property
def input_collection_id(self) -> UUID:
@@ -73,29 +74,69 @@ class AttachedFunction:
"""The function parameters."""
return self._params
def detach(self, delete_output_collection: bool = False) -> bool:
"""Detach this function and prevent any further runs.
@staticmethod
def _normalize_params(params: Optional[Any]) -> Dict[str, Any]:
"""Normalize params to a consistent dict format.
Args:
delete_output_collection: Whether to also delete the output collection. Defaults to False.
Returns:
bool: True if successful
Example:
>>> success = attached_fn.detach(delete_output_collection=True)
Handles None, empty strings, JSON strings, and dicts.
"""
return self._client.detach_function(
attached_function_id=self._id,
delete_output=delete_output_collection,
tenant=self._tenant,
database=self._database,
)
if params is None:
return {}
if isinstance(params, str):
try:
result = json.loads(params) if params else {}
return result if isinstance(result, dict) else {}
except json.JSONDecodeError:
return {}
if isinstance(params, dict):
return params
return {}
def __repr__(self) -> str:
return (
f"AttachedFunction(id={self._id}, name='{self._name}', "
f"function_id='{self._function_id}', "
f"function_name='{self._function_name}', "
f"input_collection_id={self._input_collection_id}, "
f"output_collection='{self._output_collection}')"
)
def __eq__(self, other: object) -> bool:
"""Compare two AttachedFunction objects for equality."""
if not isinstance(other, AttachedFunction):
return False
# Normalize params: handle None, {}, and JSON strings
self_params = self._normalize_params(self._params)
other_params = self._normalize_params(other._params)
return (
self._id == other._id
and self._name == other._name
and self._function_name == other._function_name
and self._input_collection_id == other._input_collection_id
and self._output_collection == other._output_collection
and self_params == other_params
and self._tenant == other._tenant
and self._database == other._database
)
def __hash__(self) -> int:
"""Return hash of the AttachedFunction."""
# Normalize params using the same logic as __eq__
normalized_params = self._normalize_params(self._params)
params_tuple = (
tuple(sorted(normalized_params.items())) if normalized_params else ()
)
return hash(
(
self._id,
self._name,
self._function_name,
self._input_collection_id,
self._output_collection,
params_tuple,
self._tenant,
self._database,
)
)

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any, Tuple
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
@@ -25,6 +25,8 @@ from chromadb.execution.expression.plan import Search
import logging
from chromadb.api.functions import Function
if TYPE_CHECKING:
from chromadb.api.models.AttachedFunction import AttachedFunction
@@ -500,30 +502,36 @@ class Collection(CollectionCommon["ServerAPI"]):
def attach_function(
self,
function_id: str,
function: Function,
name: str,
output_collection: str,
params: Optional[Dict[str, Any]] = None,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to this collection.
Args:
function_id: Built-in function identifier (e.g., "record_counter")
function: A Function enum value (e.g., STATISTICS_FUNCTION, RECORD_COUNTER_FUNCTION)
name: Unique name for this attached function
output_collection: Name of the collection where function output will be stored
params: Optional dictionary with function-specific parameters
Returns:
AttachedFunction: Object representing the attached function
Tuple of (AttachedFunction, created) where created is True if newly created,
False if already existed (idempotent request)
Example:
>>> from chromadb.api.functions import STATISTICS_FUNCTION
>>> attached_fn = collection.attach_function(
... function_id="record_counter",
... function=STATISTICS_FUNCTION,
... name="mycoll_stats_fn",
... output_collection="mycoll_stats",
... params={"threshold": 100}
... )
>>> if created:
... print("New function attached")
... else:
... print("Function already existed")
"""
function_id = function.value if isinstance(function, Function) else function
return self._client.attach_function(
function_id=function_id,
name=name,
@@ -533,3 +541,47 @@ class Collection(CollectionCommon["ServerAPI"]):
tenant=self.tenant,
database=self.database,
)
def get_attached_function(self, name: str) -> "AttachedFunction":
"""Get an attached function by name for this collection.
Args:
name: Name of the attached function
Returns:
AttachedFunction: The attached function object
Raises:
NotFoundError: If the attached function doesn't exist
"""
return self._client.get_attached_function(
name=name,
input_collection_id=self.id,
tenant=self.tenant,
database=self.database,
)
def detach_function(
self,
name: str,
delete_output_collection: bool = False,
) -> bool:
"""Detach a function from this collection.
Args:
name: The name of the attached function
delete_output_collection: Whether to also delete the output collection. Defaults to False.
Returns:
bool: True if successful
Example:
>>> success = collection.detach_function("my_function", delete_output_collection=True)
"""
return self._client.detach_function(
name=name,
input_collection_id=self.id,
delete_output=delete_output_collection,
tenant=self.tenant,
database=self.database,
)

View File

@@ -1021,6 +1021,7 @@ class CollectionCommon(Generic[ClientT]):
return Search(
where=search._where,
rank=embedded_rank,
group_by=search._group_by,
limit=search._limit,
select=search._select,
)

View File

@@ -49,7 +49,7 @@ from chromadb.execution.expression.plan import Search
import chromadb_rust_bindings
from typing import Optional, Sequence, List, Dict, Any
from typing import Optional, Sequence, List, Dict, Any, Tuple
from overrides import override
from uuid import UUID
import json
@@ -613,6 +613,20 @@ class RustBindingsAPI(ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple["AttachedFunction", bool]:
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
"The Rust bindings (embedded mode) do not support attached function operations."
)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
@@ -623,7 +637,8 @@ class RustBindingsAPI(ServerAPI):
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,

View File

@@ -75,6 +75,7 @@ from typing import (
Dict,
Callable,
TypeVar,
Tuple,
)
from overrides import override
from uuid import UUID, uuid4
@@ -922,6 +923,20 @@ class SegmentAPI(ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple["AttachedFunction", bool]:
"""Attached functions are not supported in the Segment API (local embedded mode)."""
raise NotImplementedError(
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
"The Segment API (embedded mode) does not support attached function operations."
)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attached functions are not supported in the Segment API (local embedded mode)."""
raise NotImplementedError(
@@ -932,7 +947,8 @@ class SegmentAPI(ServerAPI):
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,

View File

@@ -1,11 +1,14 @@
from typing import ClassVar, Dict
from typing import ClassVar, Dict, Optional
import logging
import uuid
from chromadb.api import ServerAPI
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.config import Settings, System
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import ClientStartEvent
logger = logging.getLogger(__name__)
class SharedSystemClient:
_identifier_to_system: ClassVar[Dict[str, System]] = {}
@@ -94,3 +97,59 @@ class SharedSystemClient:
def _submit_client_start_event(self) -> None:
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ClientStartEvent())
@staticmethod
def get_chroma_cloud_api_key_from_clients() -> Optional[str]:
"""
Try to extract api key from existing client instances by checking httpx session headers.
Requirements to pull api key:
- must be a BaseHTTPClient instance (ignore RustBindingsAPI and SegmentAPI)
- must have "api.trychroma.com" or "gcp.trychroma.com" in the _api_url (ignore local/self-hosted instances)
- must have "x-chroma-token" or "X-Chroma-Token" in the headers
Returns:
The first api key found, or None if no client instances have api keys set.
"""
api_keys: list[str] = []
systems_snapshot = list(SharedSystemClient._identifier_to_system.values())
for system in systems_snapshot:
try:
server_api = system.instance(ServerAPI)
if not isinstance(server_api, BaseHTTPClient):
# RustBindingsAPI and SegmentAPI don't have HTTP headers
continue
# Only pull api key if the url contains the chroma cloud url
api_url = server_api.get_api_url()
if (
"api.trychroma.com" not in api_url
and "gcp.trychroma.com" not in api_url
):
continue
headers = server_api.get_request_headers()
api_key = None
for key, value in headers.items():
if key.lower() == "x-chroma-token":
api_key = value
break
if api_key:
api_keys.append(api_key)
except Exception:
# If we can't access the ServerAPI instance, continue to the next
continue
if not api_keys:
return None
# log if multiple viable api keys found
if len(api_keys) > 1:
logger.info(
f"Multiple Chroma Cloud clients found, using API key starting with {api_keys[0][:8]}..."
)
return api_keys[0]

View File

@@ -12,6 +12,7 @@ from typing import (
get_args,
TYPE_CHECKING,
Final,
Type,
)
from copy import deepcopy
from typing_extensions import TypeAlias
@@ -20,7 +21,8 @@ from numpy.typing import NDArray
import numpy as np
import warnings
from typing_extensions import TypedDict, Protocol, runtime_checkable
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator
from pydantic_core import PydanticCustomError
import chromadb.errors as errors
from chromadb.base_types import (
@@ -52,6 +54,8 @@ import pybase64
from functools import lru_cache
import struct
import math
import re
from enum import Enum
# Re-export types from chromadb.types
__all__ = [
@@ -112,6 +116,9 @@ __all__ = [
"EMBEDDING_KEY",
"TYPE_KEY",
"SPARSE_VECTOR_TYPE_VALUE",
# CMEK Support
"Cmek",
"CmekProvider",
# Space type
"Space",
# Embedding Functions
@@ -213,7 +220,7 @@ def optional_base64_strings_to_embeddings(
def normalize_embeddings(
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]]
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]],
) -> Optional[Embeddings]:
if target is None:
return None
@@ -448,7 +455,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
)
zero_lengths = [
key for key, lst in record_set.items() if lst is not None and len(lst) == 0 # type: ignore[arg-type]
key
for key, lst in record_set.items()
if lst is not None and len(lst) == 0 # type: ignore[arg-type]
]
if zero_lengths:
@@ -456,7 +465,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
if len(set(lengths)) > 1:
error_str = ", ".join(
f"{key}: {len(lst)}" for key, lst in record_set.items() if lst is not None # type: ignore[arg-type]
f"{key}: {len(lst)}"
for key, lst in record_set.items()
if lst is not None # type: ignore[arg-type]
)
raise ValueError(f"Unequal lengths for fields: {error_str}")
@@ -756,8 +767,7 @@ class EmbeddingFunction(Protocol[D]):
"""
@abstractmethod
def __call__(self, input: D) -> Embeddings:
...
def __call__(self, input: D) -> Embeddings: ...
def embed_query(self, input: D) -> Embeddings:
"""
@@ -941,8 +951,7 @@ def validate_embedding_function(
class DataLoader(Protocol[L]):
def __call__(self, uris: URIs) -> L:
...
def __call__(self, uris: URIs) -> L: ...
def validate_ids(ids: IDs) -> IDs:
@@ -1063,7 +1072,7 @@ def serialize_metadata(metadata: Optional[Metadata]) -> Optional[Dict[str, Any]]
def deserialize_metadata(
metadata: Optional[Dict[str, Any]]
metadata: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""Deserialize metadata from transport, converting dicts with #type=sparse_vector to dataclass instances.
@@ -1399,8 +1408,7 @@ class SparseEmbeddingFunction(Protocol[D]):
"""
@abstractmethod
def __call__(self, input: D) -> SparseVectors:
...
def __call__(self, input: D) -> SparseVectors: ...
def embed_query(self, input: D) -> SparseVectors:
"""
@@ -1493,15 +1501,57 @@ def validate_sparse_embedding_function(
# Index Configuration Types for Collection Schema
def _create_extra_fields_validator(valid_fields: list[str]) -> Any:
"""Create a model validator that provides helpful error messages for invalid fields."""
@model_validator(mode="before")
def validate_extra_fields(cls: Type[BaseModel], data: Any) -> Any:
if isinstance(data, dict):
invalid_fields = [k for k in data.keys() if k not in valid_fields]
if invalid_fields:
invalid_fields_str = ", ".join(f"'{f}'" for f in invalid_fields)
class_name = cls.__name__
# Create a clear, actionable error message
if len(invalid_fields) == 1:
msg = (
f"'{invalid_fields[0]}' is not a valid field for {class_name}. "
)
else:
msg = f"Invalid fields for {class_name}: {invalid_fields_str}. "
raise PydanticCustomError(
"invalid_field",
msg,
{"invalid_fields": invalid_fields},
)
return data
return validate_extra_fields
class FtsIndexConfig(BaseModel):
"""Configuration for Full-Text Search index. No parameters required."""
model_config = {"extra": "forbid"}
pass
class HnswIndexConfig(BaseModel):
"""Configuration for HNSW vector index."""
_validate_extra_fields = _create_extra_fields_validator(
[
"ef_construction",
"max_neighbors",
"ef_search",
"num_threads",
"batch_size",
"sync_threshold",
"resize_factor",
]
)
ef_construction: Optional[int] = None
max_neighbors: Optional[int] = None
ef_search: Optional[int] = None
@@ -1514,6 +1564,27 @@ class HnswIndexConfig(BaseModel):
class SpannIndexConfig(BaseModel):
"""Configuration for SPANN vector index."""
_validate_extra_fields = _create_extra_fields_validator(
[
"search_nprobe",
"search_rng_factor",
"search_rng_epsilon",
"nreplica_count",
"write_nprobe",
"write_rng_factor",
"write_rng_epsilon",
"split_threshold",
"num_samples_kmeans",
"initial_lambda",
"reassign_neighbor_count",
"merge_threshold",
"num_centers_to_merge_to",
"ef_construction",
"ef_search",
"max_neighbors",
]
)
search_nprobe: Optional[int] = None
write_nprobe: Optional[int] = None
ef_construction: Optional[int] = None
@@ -1527,10 +1598,13 @@ class SpannIndexConfig(BaseModel):
class VectorIndexConfig(BaseModel):
"""Configuration for vector index with space, embedding function, and algorithm config."""
model_config = {"arbitrary_types_allowed": True}
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
space: Optional[Space] = None
embedding_function: Optional[Any] = DefaultEmbeddingFunction()
source_key: Optional[str] = None # key to source the vector from (accepts str or Key)
source_key: Optional[str] = (
None # key to source the vector from (accepts str or Key)
)
hnsw: Optional[HnswIndexConfig] = None
spann: Optional[SpannIndexConfig] = None
@@ -1542,6 +1616,7 @@ class VectorIndexConfig(BaseModel):
return None
# Import Key at runtime to avoid circular import
from chromadb.execution.expression.operator import Key as KeyType
if isinstance(v, KeyType):
v = v.name # Extract string from Key
elif isinstance(v, str):
@@ -1574,10 +1649,13 @@ class VectorIndexConfig(BaseModel):
class SparseVectorIndexConfig(BaseModel):
"""Configuration for sparse vector index."""
model_config = {"arbitrary_types_allowed": True}
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
# TODO(Sanket): Change this to the appropriate sparse ef and use a default here.
embedding_function: Optional[Any] = None
source_key: Optional[str] = None # key to source the sparse vector from (accepts str or Key)
source_key: Optional[str] = (
None # key to source the sparse vector from (accepts str or Key)
)
bm25: Optional[bool] = None
@field_validator("source_key", mode="before")
@@ -1588,6 +1666,7 @@ class SparseVectorIndexConfig(BaseModel):
return None
# Import Key at runtime to avoid circular import
from chromadb.execution.expression.operator import Key as KeyType
if isinstance(v, KeyType):
v = v.name # Extract string from Key
elif isinstance(v, str):
@@ -1622,24 +1701,32 @@ class SparseVectorIndexConfig(BaseModel):
class StringInvertedIndexConfig(BaseModel):
"""Configuration for string inverted index."""
model_config = {"extra": "forbid"}
pass
class IntInvertedIndexConfig(BaseModel):
"""Configuration for integer inverted index."""
model_config = {"extra": "forbid"}
pass
class FloatInvertedIndexConfig(BaseModel):
"""Configuration for float inverted index."""
model_config = {"extra": "forbid"}
pass
class BoolInvertedIndexConfig(BaseModel):
"""Configuration for boolean inverted index."""
model_config = {"extra": "forbid"}
pass
@@ -1683,6 +1770,148 @@ TYPE_KEY: Final[str] = "#type"
SPARSE_VECTOR_TYPE_VALUE: Final[str] = "sparse_vector"
# ============================================================================
# CMEK (Customer-Managed Encryption Key) Support
# ============================================================================
class CmekProvider(str, Enum):
"""Supported cloud providers for customer-managed encryption keys.
Currently only Google Cloud Platform (GCP) is supported.
"""
GCP = "gcp"
# Regex pattern for validating GCP KMS resource names
_CMEK_GCP_PATTERN = re.compile(r"^projects/.+/locations/.+/keyRings/.+/cryptoKeys/.+$")
@dataclass
class Cmek:
"""Customer-managed encryption key (CMEK) for collection data encryption.
CMEK allows you to use your own encryption keys managed by cloud providers'
key management services (KMS) instead of default provider-managed keys. This
gives you greater control over key lifecycle, access policies, and audit logging.
Attributes:
provider: The cloud provider (currently only 'gcp' supported)
resource: The provider-specific resource identifier for the encryption key
Example:
>>> # Create a CMEK for GCP
>>> cmek = Cmek.gcp(
... "projects/my-project/locations/us-central1/"
... "keyRings/my-ring/cryptoKeys/my-key"
... )
>>>
>>> # Validate the resource name format
>>> if cmek.validate_pattern():
... print("Valid CMEK format")
Note:
Pattern validation only checks format correctness. It does not verify
that the key exists or is accessible. Key permissions and access control
must be configured separately in your cloud provider's console.
"""
provider: CmekProvider
resource: str
@classmethod
def gcp(cls, resource: str) -> "Cmek":
"""Create a CMEK instance for Google Cloud Platform.
Args:
resource: GCP Cloud KMS resource name in the format:
projects/{project-id}/locations/{location}/keyRings/{key-ring}/cryptoKeys/{key-name}
Example: "projects/my-project/locations/us-central1/keyRings/my-ring/cryptoKeys/my-key"
Returns:
Cmek: A new CMEK instance configured for GCP
Raises:
ValueError: If the resource format is invalid (when validate_pattern() is called)
Example:
>>> cmek = Cmek.gcp(
... "projects/my-project/locations/us-central1/"
... "keyRings/my-ring/cryptoKeys/my-key"
... )
"""
return cls(provider=CmekProvider.GCP, resource=resource)
def validate_pattern(self) -> bool:
"""Validate the CMEK resource name format.
Validates that the resource name matches the expected pattern for the
provider. This is a format check only and does not verify that the key
exists or that you have access to it.
For GCP, the expected format is:
projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}
Returns:
bool: True if the resource name format is valid, False otherwise
Example:
>>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
>>> cmek.validate_pattern() # Returns True
>>>
>>> bad_cmek = Cmek.gcp("invalid-format")
>>> bad_cmek.validate_pattern() # Returns False
"""
if self.provider == CmekProvider.GCP:
return _CMEK_GCP_PATTERN.match(self.resource) is not None
return False
def to_dict(self) -> Dict[str, Any]:
"""Serialize CMEK to dictionary format for API transport.
Returns:
Dict containing the provider variant and resource identifier.
Example:
>>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
>>> cmek.to_dict()
{'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
"""
if self.provider == CmekProvider.GCP:
return {"gcp": self.resource}
# Unreachable with current providers, but future-proof
raise ValueError(f"Unknown CMEK provider: {self.provider}")
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Cmek":
"""Deserialize CMEK from dictionary format.
Args:
data: Dictionary containing provider variant and resource.
Format matches Rust serde enum serialization with snake_case.
Returns:
Cmek: Deserialized CMEK instance
Raises:
ValueError: If the provider is unsupported or data is malformed
Example:
>>> data = {'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
>>> cmek = Cmek.from_dict(data)
"""
if "gcp" in data:
resource = data["gcp"]
if not isinstance(resource, str):
raise ValueError(
f"CMEK gcp resource must be a string, got: {type(resource)}"
)
return cls.gcp(resource)
raise ValueError(f"Unsupported or missing CMEK provider in data: {data}")
# Index Type Classes
@@ -1774,24 +2003,40 @@ class ValueTypes:
@dataclass
class Schema:
"""Collection schema for configuring indexes and encryption.
The schema controls how data is indexed and can optionally specify
customer-managed encryption keys (CMEK) for data at rest.
Attributes:
defaults: Default index configurations for each value type
keys: Key-specific index overrides
cmek: Optional customer-managed encryption key for collection data
"""
defaults: ValueTypes
keys: Dict[str, ValueTypes]
cmek: Optional[Cmek] = None
def __init__(self) -> None:
# Initialize the dataclass fields first
self.defaults = ValueTypes()
self.keys: Dict[str, ValueTypes] = {}
self.cmek: Optional[Cmek] = None
# Populate with sensible defaults automatically
self._initialize_defaults()
self._initialize_keys()
def create_index(
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
self,
config: Optional[IndexConfig] = None,
key: Optional[Union[str, "Key"]] = None,
) -> "Schema":
"""Create an index configuration."""
# Convert Key to string if provided
from chromadb.execution.expression.operator import Key as KeyType
if key is not None and isinstance(key, KeyType):
key = key.name
@@ -1869,11 +2114,14 @@ class Schema:
return self
def delete_index(
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
self,
config: Optional[IndexConfig] = None,
key: Optional[Union[str, "Key"]] = None,
) -> "Schema":
"""Disable an index configuration (set enabled=False)."""
# Convert Key to string if provided
from chromadb.execution.expression.operator import Key as KeyType
if key is not None and isinstance(key, KeyType):
key = key.name
@@ -1930,6 +2178,23 @@ class Schema:
return self
def set_cmek(self, cmek: Optional[Cmek]) -> "Schema":
"""Set customer-managed encryption key for the collection (fluent interface).
Args:
cmek: Customer-managed encryption key configuration, or None to remove CMEK
Returns:
Self for method chaining
Example:
>>> schema = Schema().set_cmek(
... Cmek.gcp("projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key")
... )
"""
self.cmek = cmek
return self
def _get_config_class_name(self, config: IndexConfig) -> str:
"""Get the class name for a config."""
return config.__class__.__name__
@@ -1943,11 +2208,15 @@ class Schema:
"""
# Update the config in defaults (preserve enabled=False)
current_enabled = self.defaults.float_list.vector_index.enabled # type: ignore[union-attr]
self.defaults.float_list.vector_index = VectorIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
self.defaults.float_list.vector_index = VectorIndexType(
enabled=current_enabled, config=config
) # type: ignore[union-attr]
# Update the config on #embedding key (preserve enabled=True and source_key="#document")
current_enabled = self.keys[EMBEDDING_KEY].float_list.vector_index.enabled # type: ignore[union-attr]
current_source_key = self.keys[EMBEDDING_KEY].float_list.vector_index.config.source_key # type: ignore[union-attr]
current_source_key = self.keys[
EMBEDDING_KEY
].float_list.vector_index.config.source_key # type: ignore[union-attr]
# Create a new config with user settings but preserve the original source_key
embedding_config = VectorIndexConfig(
@@ -1957,7 +2226,9 @@ class Schema:
spann=config.spann,
source_key=current_source_key, # Preserve original source_key (should be "#document")
)
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(enabled=current_enabled, config=embedding_config) # type: ignore[union-attr]
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(
enabled=current_enabled, config=embedding_config
) # type: ignore[union-attr]
def _set_fts_index_config(self, config: FtsIndexConfig) -> None:
"""
@@ -1967,11 +2238,15 @@ class Schema:
"""
# Update the config in defaults (preserve enabled=False)
current_enabled = self.defaults.string.fts_index.enabled # type: ignore[union-attr]
self.defaults.string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
self.defaults.string.fts_index = FtsIndexType(
enabled=current_enabled, config=config
) # type: ignore[union-attr]
# Update the config on #document key (preserve enabled=True)
current_enabled = self.keys[DOCUMENT_KEY].string.fts_index.enabled # type: ignore[union-attr]
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(
enabled=current_enabled, config=config
) # type: ignore[union-attr]
def _set_index_in_defaults(self, config: IndexConfig, enabled: bool) -> None:
"""Set an index configuration in the defaults."""
@@ -2071,37 +2346,51 @@ class Schema:
if config_name == "FtsIndexConfig":
if self.keys[key].string is None:
self.keys[key].string = StringValueType()
self.keys[key].string.fts_index = FtsIndexType(enabled=enabled, config=cast(FtsIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].string.fts_index = FtsIndexType(
enabled=enabled, config=cast(FtsIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "StringInvertedIndexConfig":
if self.keys[key].string is None:
self.keys[key].string = StringValueType()
self.keys[key].string.string_inverted_index = StringInvertedIndexType(enabled=enabled, config=cast(StringInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].string.string_inverted_index = StringInvertedIndexType(
enabled=enabled, config=cast(StringInvertedIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "VectorIndexConfig":
if self.keys[key].float_list is None:
self.keys[key].float_list = FloatListValueType()
self.keys[key].float_list.vector_index = VectorIndexType(enabled=enabled, config=cast(VectorIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].float_list.vector_index = VectorIndexType(
enabled=enabled, config=cast(VectorIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "SparseVectorIndexConfig":
if self.keys[key].sparse_vector is None:
self.keys[key].sparse_vector = SparseVectorValueType()
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(enabled=enabled, config=cast(SparseVectorIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(
enabled=enabled, config=cast(SparseVectorIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "IntInvertedIndexConfig":
if self.keys[key].int_value is None:
self.keys[key].int_value = IntValueType()
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(enabled=enabled, config=cast(IntInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(
enabled=enabled, config=cast(IntInvertedIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "FloatInvertedIndexConfig":
if self.keys[key].float_value is None:
self.keys[key].float_value = FloatValueType()
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(enabled=enabled, config=cast(FloatInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(
enabled=enabled, config=cast(FloatInvertedIndexConfig, config)
) # type: ignore[union-attr]
elif config_name == "BoolInvertedIndexConfig":
if self.keys[key].boolean is None:
self.keys[key].boolean = BoolValueType()
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(enabled=enabled, config=cast(BoolInvertedIndexConfig, config)) # type: ignore[union-attr]
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(
enabled=enabled, config=cast(BoolInvertedIndexConfig, config)
) # type: ignore[union-attr]
def _enable_all_indexes_for_key(self, key: str) -> None:
"""Enable all possible index types for a specific key."""
@@ -2249,7 +2538,13 @@ class Schema:
for key, value_types in self.keys.items():
keys_json[key] = self._serialize_value_types(value_types)
return {"defaults": defaults_json, "keys": keys_json}
result: Dict[str, Any] = {"defaults": defaults_json, "keys": keys_json}
# Add CMEK if present
if self.cmek is not None:
result["cmek"] = self.cmek.to_dict()
return result
@classmethod
def deserialize_from_json(cls, json_data: Dict[str, Any]) -> "Schema":
@@ -2263,6 +2558,11 @@ class Schema:
for key, value_types_json in json_data.get("keys", {}).items():
instance.keys[key] = cls._deserialize_value_types(value_types_json)
# Deserialize CMEK if present
instance.cmek = None
if "cmek" in json_data and json_data["cmek"] is not None:
instance.cmek = Cmek.from_dict(json_data["cmek"])
return instance
def _serialize_value_types(self, value_types: ValueTypes) -> Dict[str, Any]:
@@ -2410,6 +2710,10 @@ class Schema:
if embedding_func.is_legacy():
config_dict["embedding_function"] = {"type": "legacy"}
else:
if hasattr(embedding_func, "validate_config"):
embedding_func.validate_config(
embedding_func.get_config()
)
config_dict["embedding_function"] = {
"name": embedding_func.name(),
"type": "known",
@@ -2439,6 +2743,8 @@ class Schema:
config_dict["embedding_function"] = {"type": "unknown"}
else:
embedding_func = cast(SparseEmbeddingFunction, embedding_func) # type: ignore
if hasattr(embedding_func, "validate_config"):
embedding_func.validate_config(embedding_func.get_config())
config_dict["embedding_function"] = {
"name": embedding_func.name(),
"type": "known",