增加环绕侦察场景适配
This commit is contained in:
@@ -33,10 +33,14 @@ from chromadb.execution.expression import ( # noqa: F401, F403
|
||||
Sub,
|
||||
Sum,
|
||||
Val,
|
||||
Aggregate,
|
||||
MinK,
|
||||
MaxK,
|
||||
GroupBy,
|
||||
)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence, Optional, List, Dict, Any
|
||||
from typing import Sequence, Optional, List, Dict, Any, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from overrides import override
|
||||
@@ -824,7 +828,7 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to a collection.
|
||||
|
||||
Args:
|
||||
@@ -837,14 +841,40 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
|
||||
database: The database name
|
||||
|
||||
Returns:
|
||||
AttachedFunction: Object representing the attached function
|
||||
Tuple of (AttachedFunction, created) where created is True if newly created,
|
||||
False if already existed (idempotent request)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def get_attached_function(
    self,
    name: str,
    input_collection_id: UUID,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
    """Look up a function attached to a collection by its name.

    Args:
        name: Name of the attached function.
        input_collection_id: ID of the collection the lookup is scoped to.
        tenant: The tenant name.
        database: The database name.

    Returns:
        AttachedFunction: The attached function object.

    Raises:
        NotFoundError: If the attached function doesn't exist.
    """
    ...
|
||||
|
||||
@abstractmethod
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
@@ -852,7 +882,8 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
|
||||
"""Detach a function and prevent any further runs.
|
||||
|
||||
Args:
|
||||
attached_function_id: ID of the attached function to remove
|
||||
name: Name of the attached function to remove
|
||||
input_collection_id: ID of the input collection
|
||||
delete_output: Whether to also delete the output collection
|
||||
tenant: The tenant name
|
||||
database: The database name
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,7 +2,7 @@ import asyncio
|
||||
from uuid import UUID
|
||||
import urllib.parse
|
||||
import orjson
|
||||
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
|
||||
from typing import Any, Mapping, Optional, cast, Tuple, Sequence, Dict, List
|
||||
import logging
|
||||
import httpx
|
||||
from overrides import override
|
||||
@@ -130,16 +130,23 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._clients[loop_hash] = httpx.AsyncClient(
|
||||
timeout=None,
|
||||
headers=headers,
|
||||
verify=self._settings.chroma_server_ssl_verify or False,
|
||||
limits=limits,
|
||||
limits=self.http_limits,
|
||||
)
|
||||
|
||||
return self._clients[loop_hash]
|
||||
|
||||
@override
def get_request_headers(self) -> Mapping[str, str]:
    """Return a plain-dict copy of the headers on the client from ``_get_client()``."""
    client = self._get_client()
    return dict(client.headers)
|
||||
|
||||
@override
def get_api_url(self) -> str:
    """Return the API URL held in ``self._api_url``."""
    api_url = self._api_url
    return api_url
|
||||
|
||||
async def _make_request(
|
||||
self, method: str, path: str, **kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
@@ -527,7 +534,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
@@ -723,7 +730,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
metadatas=metadata_batches,
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
|
||||
@@ -1,19 +1,51 @@
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
from typing import Any, Dict, Mapping, Optional, TypeVar
|
||||
from urllib.parse import quote, urlparse, urlunparse
|
||||
import logging
|
||||
import orjson as json
|
||||
import httpx
|
||||
|
||||
import chromadb.errors as errors
|
||||
from chromadb.config import Settings
|
||||
from chromadb.config import Component, Settings, System
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseHTTPClient:
|
||||
# inherits from Component so that it can create an init function to use system
|
||||
# this way it can build limits from the settings in System
|
||||
class BaseHTTPClient(Component):
|
||||
_settings: Settings
|
||||
pre_flight_checks: Any = None
|
||||
keepalive_secs: int = 40
|
||||
DEFAULT_KEEPALIVE_SECS: float = 40.0
|
||||
|
||||
def __init__(self, system: System):
    """Initialise the component and pre-build the HTTP connection limits.

    Args:
        system: The System whose ``settings`` configure this client.
    """
    super().__init__(system)
    self._settings = system.settings
    configured_keepalive = self._settings.chroma_http_keepalive_secs
    # Only an unset (None) setting falls back to the class default; any
    # explicitly configured value, including 0, is kept as-is.
    if configured_keepalive is None:
        self.keepalive_secs: Optional[float] = BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
    else:
        self.keepalive_secs = configured_keepalive
    self._http_limits = self._build_limits()
|
||||
|
||||
def _build_limits(self) -> httpx.Limits:
    """Build the ``httpx.Limits`` object from keepalive and pool settings.

    Only values that are not None are passed on, so ``httpx.Limits`` is
    called without the corresponding keyword for anything left unset.

    Returns:
        httpx.Limits: Limits constructed from the non-None options.
    """
    candidates = (
        ("keepalive_expiry", self.keepalive_secs),
        ("max_connections", self._settings.chroma_http_max_connections),
        (
            "max_keepalive_connections",
            self._settings.chroma_http_max_keepalive_connections,
        ),
    )
    limit_kwargs: Dict[str, Any] = {
        option: value for option, value in candidates if value is not None
    }
    return httpx.Limits(**limit_kwargs)
|
||||
|
||||
@property
def http_limits(self) -> httpx.Limits:
    """The ``httpx.Limits`` object stored in ``self._http_limits``."""
    limits = self._http_limits
    return limits
|
||||
|
||||
@staticmethod
|
||||
def _validate_host(host: str) -> None:
|
||||
@@ -103,3 +135,11 @@ class BaseHTTPClient:
|
||||
if trace_id:
|
||||
raise Exception(f"{resp.text} (trace ID: {trace_id})")
|
||||
raise (Exception(resp.text))
|
||||
|
||||
def get_request_headers(self) -> Mapping[str, str]:
    """Return headers used for HTTP requests.

    This base implementation always returns a new empty mapping.
    """
    return dict()
|
||||
|
||||
def get_api_url(self) -> str:
    """Return the API URL for this client.

    This base implementation always returns the empty string.
    """
    return str()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import orjson
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, cast, Tuple, List
|
||||
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
import httpx
|
||||
@@ -79,8 +79,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
default_api_path=system.settings.chroma_server_api_default_path,
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._session = httpx.Client(timeout=None, limits=limits)
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(
|
||||
timeout=None,
|
||||
limits=self.http_limits,
|
||||
verify=self._settings.chroma_server_ssl_verify,
|
||||
)
|
||||
else:
|
||||
self._session = httpx.Client(timeout=None, limits=self.http_limits)
|
||||
|
||||
self._header = system.settings.chroma_server_headers or {}
|
||||
self._header["Content-Type"] = "application/json"
|
||||
@@ -90,8 +96,6 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
|
||||
if self._header is not None:
|
||||
self._session.headers.update(self._header)
|
||||
|
||||
@@ -101,6 +105,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
for header, value in _headers.items():
|
||||
self._session.headers[header] = value.get_secret_value()
|
||||
|
||||
@override
def get_request_headers(self) -> Mapping[str, str]:
    """Return a plain-dict copy of the headers set on ``self._session``."""
    session_headers = self._session.headers
    return dict(session_headers)
|
||||
|
||||
@override
def get_api_url(self) -> str:
    """Return the API URL held in ``self._api_url``."""
    api_url = self._api_url
    return api_url
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
|
||||
# If the request has json in kwargs, use orjson to serialize it,
|
||||
# remove it from kwargs, and add it to the content parameter
|
||||
@@ -492,7 +504,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
@@ -700,7 +712,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
metadatas=metadata_batches,
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
@@ -761,7 +773,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to a collection."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
@@ -774,23 +786,56 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
},
|
||||
)
|
||||
|
||||
return AttachedFunction(
|
||||
attached_function = AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(resp_json["attached_function"]["id"]),
|
||||
name=resp_json["attached_function"]["name"],
|
||||
function_id=resp_json["attached_function"]["function_id"],
|
||||
function_name=resp_json["attached_function"]["function_name"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
created = resp_json.get(
|
||||
"created", True
|
||||
) # Default to True for backwards compatibility
|
||||
return (attached_function, created)
|
||||
|
||||
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
@override
def get_attached_function(
    self,
    name: str,
    input_collection_id: UUID,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
    """Fetch an attached function by name for a specific collection.

    Issues a GET request and builds an ``AttachedFunction`` from the
    ``"attached_function"`` object of the JSON response.

    Args:
        name: Name of the attached function.
        input_collection_id: ID of the collection used in the request path.
        tenant: The tenant name.
        database: The database name.

    Returns:
        AttachedFunction: Built from the response; ``input_collection_id``,
        ``tenant`` and ``database`` are taken from the arguments, not from
        the response.
    """
    endpoint = (
        f"/tenants/{tenant}/databases/{database}"
        f"/collections/{input_collection_id}/functions/{name}"
    )
    payload = self._make_request("get", endpoint)["attached_function"]

    return AttachedFunction(
        client=self,
        id=UUID(payload["id"]),
        name=payload["name"],
        function_name=payload["function_name"],
        input_collection_id=input_collection_id,
        output_collection=payload["output_collection"],
        # "params" is optional in the response; a missing key yields None.
        params=payload.get("params"),
        tenant=tenant,
        database=database,
    )
|
||||
|
||||
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
@@ -798,7 +843,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
"""Detach a function and prevent any further runs."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
|
||||
json={
|
||||
"delete_output": delete_output,
|
||||
},
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api import ServerAPI # noqa: F401
|
||||
@@ -13,7 +14,7 @@ class AttachedFunction:
|
||||
client: "ServerAPI",
|
||||
id: UUID,
|
||||
name: str,
|
||||
function_id: str,
|
||||
function_name: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]],
|
||||
@@ -26,7 +27,7 @@ class AttachedFunction:
|
||||
client: The API client
|
||||
id: Unique identifier for this attached function
|
||||
name: Name of this attached function instance
|
||||
function_id: The function identifier (e.g., "record_counter")
|
||||
function_name: The function name (e.g., "record_counter", "statistics")
|
||||
input_collection_id: ID of the input collection
|
||||
output_collection: Name of the output collection
|
||||
params: Function-specific parameters
|
||||
@@ -36,7 +37,7 @@ class AttachedFunction:
|
||||
self._client = client
|
||||
self._id = id
|
||||
self._name = name
|
||||
self._function_id = function_id
|
||||
self._function_name = function_name
|
||||
self._input_collection_id = input_collection_id
|
||||
self._output_collection = output_collection
|
||||
self._params = params
|
||||
@@ -54,9 +55,9 @@ class AttachedFunction:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def function_id(self) -> str:
|
||||
"""The function identifier."""
|
||||
return self._function_id
|
||||
def function_name(self) -> str:
|
||||
"""The function name."""
|
||||
return self._function_name
|
||||
|
||||
@property
|
||||
def input_collection_id(self) -> UUID:
|
||||
@@ -73,29 +74,69 @@ class AttachedFunction:
|
||||
"""The function parameters."""
|
||||
return self._params
|
||||
|
||||
def detach(self, delete_output_collection: bool = False) -> bool:
|
||||
"""Detach this function and prevent any further runs.
|
||||
@staticmethod
|
||||
def _normalize_params(params: Optional[Any]) -> Dict[str, Any]:
|
||||
"""Normalize params to a consistent dict format.
|
||||
|
||||
Args:
|
||||
delete_output_collection: Whether to also delete the output collection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
|
||||
Example:
|
||||
>>> success = attached_fn.detach(delete_output_collection=True)
|
||||
Handles None, empty strings, JSON strings, and dicts.
|
||||
"""
|
||||
return self._client.detach_function(
|
||||
attached_function_id=self._id,
|
||||
delete_output=delete_output_collection,
|
||||
tenant=self._tenant,
|
||||
database=self._database,
|
||||
)
|
||||
if params is None:
|
||||
return {}
|
||||
if isinstance(params, str):
|
||||
try:
|
||||
result = json.loads(params) if params else {}
|
||||
return result if isinstance(result, dict) else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
if isinstance(params, dict):
|
||||
return params
|
||||
return {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AttachedFunction(id={self._id}, name='{self._name}', "
|
||||
f"function_id='{self._function_id}', "
|
||||
f"function_name='{self._function_name}', "
|
||||
f"input_collection_id={self._input_collection_id}, "
|
||||
f"output_collection='{self._output_collection}')"
|
||||
)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare two AttachedFunction objects for equality."""
|
||||
if not isinstance(other, AttachedFunction):
|
||||
return False
|
||||
|
||||
# Normalize params: handle None, {}, and JSON strings
|
||||
self_params = self._normalize_params(self._params)
|
||||
other_params = self._normalize_params(other._params)
|
||||
|
||||
return (
|
||||
self._id == other._id
|
||||
and self._name == other._name
|
||||
and self._function_name == other._function_name
|
||||
and self._input_collection_id == other._input_collection_id
|
||||
and self._output_collection == other._output_collection
|
||||
and self_params == other_params
|
||||
and self._tenant == other._tenant
|
||||
and self._database == other._database
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the AttachedFunction."""
|
||||
# Normalize params using the same logic as __eq__
|
||||
normalized_params = self._normalize_params(self._params)
|
||||
params_tuple = (
|
||||
tuple(sorted(normalized_params.items())) if normalized_params else ()
|
||||
)
|
||||
|
||||
return hash(
|
||||
(
|
||||
self._id,
|
||||
self._name,
|
||||
self._function_name,
|
||||
self._input_collection_id,
|
||||
self._output_collection,
|
||||
params_tuple,
|
||||
self._tenant,
|
||||
self._database,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any, Tuple
|
||||
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.api.types import (
|
||||
@@ -25,6 +25,8 @@ from chromadb.execution.expression.plan import Search
|
||||
|
||||
import logging
|
||||
|
||||
from chromadb.api.functions import Function
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
|
||||
@@ -500,30 +502,36 @@ class Collection(CollectionCommon["ServerAPI"]):
|
||||
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
function: Function,
|
||||
name: str,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to this collection.
|
||||
|
||||
Args:
|
||||
function_id: Built-in function identifier (e.g., "record_counter")
|
||||
function: A Function enum value (e.g., STATISTICS_FUNCTION, RECORD_COUNTER_FUNCTION)
|
||||
name: Unique name for this attached function
|
||||
output_collection: Name of the collection where function output will be stored
|
||||
params: Optional dictionary with function-specific parameters
|
||||
|
||||
Returns:
|
||||
AttachedFunction: Object representing the attached function
|
||||
Tuple of (AttachedFunction, created) where created is True if newly created,
|
||||
False if already existed (idempotent request)
|
||||
|
||||
Example:
|
||||
>>> from chromadb.api.functions import STATISTICS_FUNCTION
|
||||
>>> attached_fn = collection.attach_function(
|
||||
... function_id="record_counter",
|
||||
... function=STATISTICS_FUNCTION,
|
||||
... name="mycoll_stats_fn",
|
||||
... output_collection="mycoll_stats",
|
||||
... params={"threshold": 100}
|
||||
... )
|
||||
>>> if created:
|
||||
... print("New function attached")
|
||||
... else:
|
||||
... print("Function already existed")
|
||||
"""
|
||||
function_id = function.value if isinstance(function, Function) else function
|
||||
return self._client.attach_function(
|
||||
function_id=function_id,
|
||||
name=name,
|
||||
@@ -533,3 +541,47 @@ class Collection(CollectionCommon["ServerAPI"]):
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def get_attached_function(self, name: str) -> "AttachedFunction":
    """Look up a function attached to this collection by its name.

    Delegates to the client, passing this collection's id, tenant and
    database.

    Args:
        name: Name of the attached function.

    Returns:
        AttachedFunction: The attached function object.

    Raises:
        NotFoundError: If the attached function doesn't exist.
    """
    lookup = self._client.get_attached_function
    return lookup(
        name=name,
        input_collection_id=self.id,
        tenant=self.tenant,
        database=self.database,
    )
|
||||
|
||||
def detach_function(
    self,
    name: str,
    delete_output_collection: bool = False,
) -> bool:
    """Detach a named function from this collection.

    Delegates to the client, passing this collection's id, tenant and
    database.

    Args:
        name: The name of the attached function.
        delete_output_collection: Whether to also delete the output
            collection. Defaults to False.

    Returns:
        bool: True if successful.

    Example:
        >>> success = collection.detach_function("my_function", delete_output_collection=True)
    """
    detach = self._client.detach_function
    return detach(
        name=name,
        input_collection_id=self.id,
        delete_output=delete_output_collection,
        tenant=self.tenant,
        database=self.database,
    )
|
||||
|
||||
@@ -1021,6 +1021,7 @@ class CollectionCommon(Generic[ClientT]):
|
||||
return Search(
|
||||
where=search._where,
|
||||
rank=embedded_rank,
|
||||
group_by=search._group_by,
|
||||
limit=search._limit,
|
||||
select=search._select,
|
||||
)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -49,7 +49,7 @@ from chromadb.execution.expression.plan import Search
|
||||
import chromadb_rust_bindings
|
||||
|
||||
|
||||
from typing import Optional, Sequence, List, Dict, Any
|
||||
from typing import Optional, Sequence, List, Dict, Any, Tuple
|
||||
from overrides import override
|
||||
from uuid import UUID
|
||||
import json
|
||||
@@ -613,6 +613,20 @@ class RustBindingsAPI(ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
|
||||
"The Rust bindings (embedded mode) do not support attached function operations."
|
||||
)
|
||||
|
||||
@override
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
@@ -623,7 +637,8 @@ class RustBindingsAPI(ServerAPI):
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
|
||||
@@ -75,6 +75,7 @@ from typing import (
|
||||
Dict,
|
||||
Callable,
|
||||
TypeVar,
|
||||
Tuple,
|
||||
)
|
||||
from overrides import override
|
||||
from uuid import UUID, uuid4
|
||||
@@ -922,6 +923,20 @@ class SegmentAPI(ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attached functions are not supported in the Segment API (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
|
||||
"The Segment API (embedded mode) does not support attached function operations."
|
||||
)
|
||||
|
||||
@override
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attached functions are not supported in the Segment API (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
@@ -932,7 +947,8 @@ class SegmentAPI(ServerAPI):
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import ClassVar, Dict
|
||||
from typing import ClassVar, Dict, Optional
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from chromadb.api import ServerAPI
|
||||
from chromadb.api.base_http_client import BaseHTTPClient
|
||||
from chromadb.config import Settings, System
|
||||
from chromadb.telemetry.product import ProductTelemetryClient
|
||||
from chromadb.telemetry.product.events import ClientStartEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedSystemClient:
|
||||
_identifier_to_system: ClassVar[Dict[str, System]] = {}
|
||||
@@ -94,3 +97,59 @@ class SharedSystemClient:
|
||||
def _submit_client_start_event(self) -> None:
|
||||
telemetry_client = self._system.instance(ProductTelemetryClient)
|
||||
telemetry_client.capture(ClientStartEvent())
|
||||
|
||||
@staticmethod
def get_chroma_cloud_api_key_from_clients() -> Optional[str]:
    """Try to extract an API key from existing client instances.

    Each registered System is inspected on a best-effort basis. A client
    contributes a key only if all of the following hold:
    - it is a BaseHTTPClient instance (RustBindingsAPI and SegmentAPI are
      ignored, they have no HTTP headers)
    - its API URL contains "api.trychroma.com" or "gcp.trychroma.com"
      (local/self-hosted instances are ignored)
    - its request headers contain "x-chroma-token" (matched
      case-insensitively) with a non-empty value

    Returns:
        The first api key found, or None if no client instances have api
        keys set.
    """
    cloud_hosts = ("api.trychroma.com", "gcp.trychroma.com")
    api_keys: list[str] = []

    # Iterate over a snapshot of the registry values rather than the live dict.
    for system in list(SharedSystemClient._identifier_to_system.values()):
        try:
            server_api = system.instance(ServerAPI)
            if not isinstance(server_api, BaseHTTPClient):
                continue

            api_url = server_api.get_api_url()
            if not any(host in api_url for host in cloud_hosts):
                continue

            api_key = next(
                (
                    value
                    for key, value in server_api.get_request_headers().items()
                    if key.lower() == "x-chroma-token"
                ),
                None,
            )
        except Exception:
            # Deliberate best effort: a client that cannot be inspected is
            # skipped instead of failing the whole lookup.
            continue

        if api_key:
            api_keys.append(api_key)

    if not api_keys:
        return None

    if len(api_keys) > 1:
        # Never write key material (not even a prefix) to the logs; the
        # count is enough to diagnose an ambiguous setup.
        logger.info(
            "Multiple Chroma Cloud clients found (%d); using the API key of the first one.",
            len(api_keys),
        )

    return api_keys[0]
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import (
|
||||
get_args,
|
||||
TYPE_CHECKING,
|
||||
Final,
|
||||
Type,
|
||||
)
|
||||
from copy import deepcopy
|
||||
from typing_extensions import TypeAlias
|
||||
@@ -20,7 +21,8 @@ from numpy.typing import NDArray
|
||||
import numpy as np
|
||||
import warnings
|
||||
from typing_extensions import TypedDict, Protocol, runtime_checkable
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
import chromadb.errors as errors
|
||||
from chromadb.base_types import (
|
||||
@@ -52,6 +54,8 @@ import pybase64
|
||||
from functools import lru_cache
|
||||
import struct
|
||||
import math
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Re-export types from chromadb.types
|
||||
__all__ = [
|
||||
@@ -112,6 +116,9 @@ __all__ = [
|
||||
"EMBEDDING_KEY",
|
||||
"TYPE_KEY",
|
||||
"SPARSE_VECTOR_TYPE_VALUE",
|
||||
# CMEK Support
|
||||
"Cmek",
|
||||
"CmekProvider",
|
||||
# Space type
|
||||
"Space",
|
||||
# Embedding Functions
|
||||
@@ -213,7 +220,7 @@ def optional_base64_strings_to_embeddings(
|
||||
|
||||
|
||||
def normalize_embeddings(
|
||||
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]]
|
||||
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]],
|
||||
) -> Optional[Embeddings]:
|
||||
if target is None:
|
||||
return None
|
||||
@@ -448,7 +455,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
|
||||
)
|
||||
|
||||
zero_lengths = [
|
||||
key for key, lst in record_set.items() if lst is not None and len(lst) == 0 # type: ignore[arg-type]
|
||||
key
|
||||
for key, lst in record_set.items()
|
||||
if lst is not None and len(lst) == 0 # type: ignore[arg-type]
|
||||
]
|
||||
|
||||
if zero_lengths:
|
||||
@@ -456,7 +465,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
|
||||
|
||||
if len(set(lengths)) > 1:
|
||||
error_str = ", ".join(
|
||||
f"{key}: {len(lst)}" for key, lst in record_set.items() if lst is not None # type: ignore[arg-type]
|
||||
f"{key}: {len(lst)}"
|
||||
for key, lst in record_set.items()
|
||||
if lst is not None # type: ignore[arg-type]
|
||||
)
|
||||
raise ValueError(f"Unequal lengths for fields: {error_str}")
|
||||
|
||||
@@ -756,8 +767,7 @@ class EmbeddingFunction(Protocol[D]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, input: D) -> Embeddings:
|
||||
...
|
||||
def __call__(self, input: D) -> Embeddings: ...
|
||||
|
||||
def embed_query(self, input: D) -> Embeddings:
|
||||
"""
|
||||
@@ -941,8 +951,7 @@ def validate_embedding_function(
|
||||
|
||||
|
||||
class DataLoader(Protocol[L]):
|
||||
def __call__(self, uris: URIs) -> L:
|
||||
...
|
||||
def __call__(self, uris: URIs) -> L: ...
|
||||
|
||||
|
||||
def validate_ids(ids: IDs) -> IDs:
|
||||
@@ -1063,7 +1072,7 @@ def serialize_metadata(metadata: Optional[Metadata]) -> Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
def deserialize_metadata(
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Deserialize metadata from transport, converting dicts with #type=sparse_vector to dataclass instances.
|
||||
|
||||
@@ -1399,8 +1408,7 @@ class SparseEmbeddingFunction(Protocol[D]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, input: D) -> SparseVectors:
|
||||
...
|
||||
def __call__(self, input: D) -> SparseVectors: ...
|
||||
|
||||
def embed_query(self, input: D) -> SparseVectors:
|
||||
"""
|
||||
@@ -1493,15 +1501,57 @@ def validate_sparse_embedding_function(
|
||||
|
||||
|
||||
# Index Configuration Types for Collection Schema
|
||||
def _create_extra_fields_validator(valid_fields: list[str]) -> Any:
    """Build a ``mode="before"`` model validator that rejects unknown fields.

    Args:
        valid_fields: Field names the model accepts.

    Returns:
        The decorated validator. For dict input it raises a
        ``PydanticCustomError`` of type ``"invalid_field"`` naming every key
        not in ``valid_fields``; any other input is passed through untouched.
    """

    @model_validator(mode="before")
    def validate_extra_fields(cls: Type[BaseModel], data: Any) -> Any:
        # Non-dict input is returned unchanged.
        if not isinstance(data, dict):
            return data

        unknown = [name for name in data if name not in valid_fields]
        if not unknown:
            return data

        class_name = cls.__name__
        if len(unknown) == 1:
            msg = f"'{unknown[0]}' is not a valid field for {class_name}. "
        else:
            quoted = ", ".join(f"'{name}'" for name in unknown)
            msg = f"Invalid fields for {class_name}: {quoted}. "

        raise PydanticCustomError(
            "invalid_field",
            msg,
            {"invalid_fields": unknown},
        )

    return validate_extra_fields
|
||||
|
||||
|
||||
class FtsIndexConfig(BaseModel):
    """Settings model for the full-text search index.

    The model declares no fields; ``model_config`` sets ``extra`` to
    ``"forbid"``.
    """

    model_config = {"extra": "forbid"}
|
||||
|
||||
|
||||
class HnswIndexConfig(BaseModel):
|
||||
"""Configuration for HNSW vector index."""
|
||||
|
||||
_validate_extra_fields = _create_extra_fields_validator(
|
||||
[
|
||||
"ef_construction",
|
||||
"max_neighbors",
|
||||
"ef_search",
|
||||
"num_threads",
|
||||
"batch_size",
|
||||
"sync_threshold",
|
||||
"resize_factor",
|
||||
]
|
||||
)
|
||||
|
||||
ef_construction: Optional[int] = None
|
||||
max_neighbors: Optional[int] = None
|
||||
ef_search: Optional[int] = None
|
||||
@@ -1514,6 +1564,27 @@ class HnswIndexConfig(BaseModel):
|
||||
class SpannIndexConfig(BaseModel):
|
||||
"""Configuration for SPANN vector index."""
|
||||
|
||||
_validate_extra_fields = _create_extra_fields_validator(
|
||||
[
|
||||
"search_nprobe",
|
||||
"search_rng_factor",
|
||||
"search_rng_epsilon",
|
||||
"nreplica_count",
|
||||
"write_nprobe",
|
||||
"write_rng_factor",
|
||||
"write_rng_epsilon",
|
||||
"split_threshold",
|
||||
"num_samples_kmeans",
|
||||
"initial_lambda",
|
||||
"reassign_neighbor_count",
|
||||
"merge_threshold",
|
||||
"num_centers_to_merge_to",
|
||||
"ef_construction",
|
||||
"ef_search",
|
||||
"max_neighbors",
|
||||
]
|
||||
)
|
||||
|
||||
search_nprobe: Optional[int] = None
|
||||
write_nprobe: Optional[int] = None
|
||||
ef_construction: Optional[int] = None
|
||||
@@ -1527,10 +1598,13 @@ class SpannIndexConfig(BaseModel):
|
||||
class VectorIndexConfig(BaseModel):
|
||||
"""Configuration for vector index with space, embedding function, and algorithm config."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
|
||||
|
||||
space: Optional[Space] = None
|
||||
embedding_function: Optional[Any] = DefaultEmbeddingFunction()
|
||||
source_key: Optional[str] = None # key to source the vector from (accepts str or Key)
|
||||
source_key: Optional[str] = (
|
||||
None # key to source the vector from (accepts str or Key)
|
||||
)
|
||||
hnsw: Optional[HnswIndexConfig] = None
|
||||
spann: Optional[SpannIndexConfig] = None
|
||||
|
||||
@@ -1542,6 +1616,7 @@ class VectorIndexConfig(BaseModel):
|
||||
return None
|
||||
# Import Key at runtime to avoid circular import
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if isinstance(v, KeyType):
|
||||
v = v.name # Extract string from Key
|
||||
elif isinstance(v, str):
|
||||
@@ -1574,10 +1649,13 @@ class VectorIndexConfig(BaseModel):
|
||||
class SparseVectorIndexConfig(BaseModel):
|
||||
"""Configuration for sparse vector index."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
|
||||
|
||||
# TODO(Sanket): Change this to the appropriate sparse ef and use a default here.
|
||||
embedding_function: Optional[Any] = None
|
||||
source_key: Optional[str] = None # key to source the sparse vector from (accepts str or Key)
|
||||
source_key: Optional[str] = (
|
||||
None # key to source the sparse vector from (accepts str or Key)
|
||||
)
|
||||
bm25: Optional[bool] = None
|
||||
|
||||
@field_validator("source_key", mode="before")
|
||||
@@ -1588,6 +1666,7 @@ class SparseVectorIndexConfig(BaseModel):
|
||||
return None
|
||||
# Import Key at runtime to avoid circular import
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if isinstance(v, KeyType):
|
||||
v = v.name # Extract string from Key
|
||||
elif isinstance(v, str):
|
||||
@@ -1622,24 +1701,32 @@ class SparseVectorIndexConfig(BaseModel):
|
||||
class StringInvertedIndexConfig(BaseModel):
    """Settings model for the string inverted index.

    The model declares no fields; ``model_config`` sets ``extra`` to
    ``"forbid"``.
    """

    model_config = {"extra": "forbid"}
|
||||
|
||||
|
||||
class IntInvertedIndexConfig(BaseModel):
    """Settings model for the integer inverted index.

    The model declares no fields; ``model_config`` sets ``extra`` to
    ``"forbid"``.
    """

    model_config = {"extra": "forbid"}
|
||||
|
||||
|
||||
class FloatInvertedIndexConfig(BaseModel):
    """Settings model for the float inverted index.

    The model declares no fields; ``model_config`` sets ``extra`` to
    ``"forbid"``.
    """

    model_config = {"extra": "forbid"}
|
||||
|
||||
|
||||
class BoolInvertedIndexConfig(BaseModel):
    """Settings model for the boolean inverted index.

    The model declares no fields; ``model_config`` sets ``extra`` to
    ``"forbid"``.
    """

    model_config = {"extra": "forbid"}
|
||||
|
||||
|
||||
@@ -1683,6 +1770,148 @@ TYPE_KEY: Final[str] = "#type"
|
||||
SPARSE_VECTOR_TYPE_VALUE: Final[str] = "sparse_vector"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CMEK (Customer-Managed Encryption Key) Support
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CmekProvider(str, Enum):
    """Supported cloud providers for customer-managed encryption keys.

    Currently only Google Cloud Platform (GCP) is supported.
    """

    GCP = "gcp"


# Regex pattern for validating GCP KMS resource names
_CMEK_GCP_PATTERN = re.compile(r"^projects/.+/locations/.+/keyRings/.+/cryptoKeys/.+$")


@dataclass
class Cmek:
    """Customer-managed encryption key (CMEK) for collection data encryption.

    CMEK allows you to use your own encryption keys managed by cloud providers'
    key management services (KMS) instead of default provider-managed keys. This
    gives you greater control over key lifecycle, access policies, and audit logging.

    Attributes:
        provider: The cloud provider (currently only 'gcp' supported)
        resource: The provider-specific resource identifier for the encryption key

    Example:
        >>> # Create a CMEK for GCP
        >>> cmek = Cmek.gcp(
        ...     "projects/my-project/locations/us-central1/"
        ...     "keyRings/my-ring/cryptoKeys/my-key"
        ... )
        >>>
        >>> # Validate the resource name format
        >>> if cmek.validate_pattern():
        ...     print("Valid CMEK format")

    Note:
        Pattern validation only checks format correctness. It does not verify
        that the key exists or is accessible. Key permissions and access control
        must be configured separately in your cloud provider's console.
    """

    provider: CmekProvider
    resource: str

    @classmethod
    def gcp(cls, resource: str) -> "Cmek":
        """Create a CMEK instance for Google Cloud Platform.

        The resource name is stored as given; no format check happens here.
        Call validate_pattern() to check the format.

        Args:
            resource: GCP Cloud KMS resource name in the format:
                projects/{project-id}/locations/{location}/keyRings/{key-ring}/cryptoKeys/{key-name}

                Example: "projects/my-project/locations/us-central1/keyRings/my-ring/cryptoKeys/my-key"

        Returns:
            Cmek: A new CMEK instance configured for GCP

        Example:
            >>> cmek = Cmek.gcp(
            ...     "projects/my-project/locations/us-central1/"
            ...     "keyRings/my-ring/cryptoKeys/my-key"
            ... )
        """
        return cls(provider=CmekProvider.GCP, resource=resource)

    def validate_pattern(self) -> bool:
        """Validate the CMEK resource name format.

        Validates that the resource name matches the expected pattern for the
        provider. This is a format check only and does not verify that the key
        exists or that you have access to it.

        For GCP, the expected format is:
            projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}

        The whole string must match: a resource with a trailing newline is
        rejected, and a non-string resource is reported as invalid rather
        than raising.

        Returns:
            bool: True if the resource name format is valid, False otherwise

        Example:
            >>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
            >>> cmek.validate_pattern()  # Returns True
            >>>
            >>> bad_cmek = Cmek.gcp("invalid-format")
            >>> bad_cmek.validate_pattern()  # Returns False
        """
        if self.provider == CmekProvider.GCP:
            if not isinstance(self.resource, str):
                return False
            # fullmatch, not match: with match() the pattern's "$" also
            # matches just before a trailing "\n", letting "…/key\n" through.
            return _CMEK_GCP_PATTERN.fullmatch(self.resource) is not None
        return False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize CMEK to dictionary format for API transport.

        Returns:
            Dict containing the provider variant and resource identifier.

        Raises:
            ValueError: If the provider is not one this method knows how to
                serialize.

        Example:
            >>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
            >>> cmek.to_dict()
            {'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
        """
        if self.provider == CmekProvider.GCP:
            return {"gcp": self.resource}
        # Unreachable with current providers, but future-proof
        raise ValueError(f"Unknown CMEK provider: {self.provider}")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Cmek":
        """Deserialize CMEK from dictionary format.

        Args:
            data: Dictionary containing provider variant and resource.
                Format matches Rust serde enum serialization with snake_case.

        Returns:
            Cmek: Deserialized CMEK instance

        Raises:
            ValueError: If the provider is unsupported or data is malformed
                (not a mapping, no known provider key, or a non-string
                resource)

        Example:
            >>> data = {'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
            >>> cmek = Cmek.from_dict(data)
        """
        # Local import keeps this block independent of the file's import list.
        from collections.abc import Mapping

        # Without this guard, None / str / list input would surface as a
        # TypeError from the "in" test or the subscript below, not the
        # documented ValueError.
        if not isinstance(data, Mapping):
            raise ValueError(
                f"CMEK data must be a mapping, got: {type(data).__name__}"
            )

        if "gcp" in data:
            resource = data["gcp"]
            if not isinstance(resource, str):
                raise ValueError(
                    f"CMEK gcp resource must be a string, got: {type(resource)}"
                )
            return cls.gcp(resource)
        raise ValueError(f"Unsupported or missing CMEK provider in data: {data}")
|
||||
|
||||
|
||||
# Index Type Classes
|
||||
|
||||
|
||||
@@ -1774,24 +2003,40 @@ class ValueTypes:
|
||||
|
||||
@dataclass
|
||||
class Schema:
|
||||
"""Collection schema for configuring indexes and encryption.
|
||||
|
||||
The schema controls how data is indexed and can optionally specify
|
||||
customer-managed encryption keys (CMEK) for data at rest.
|
||||
|
||||
Attributes:
|
||||
defaults: Default index configurations for each value type
|
||||
keys: Key-specific index overrides
|
||||
cmek: Optional customer-managed encryption key for collection data
|
||||
"""
|
||||
|
||||
defaults: ValueTypes
|
||||
keys: Dict[str, ValueTypes]
|
||||
cmek: Optional[Cmek] = None
|
||||
|
||||
def __init__(self) -> None:
    """Create a schema with no CMEK, then populate defaults and keys."""
    # Set every declared field before the populate helpers run.
    self.defaults = ValueTypes()
    self.keys: Dict[str, ValueTypes] = {}
    self.cmek: Optional[Cmek] = None

    # Populate defaults and keys via the two helper methods.
    self._initialize_defaults()
    self._initialize_keys()
|
||||
|
||||
def create_index(
|
||||
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
|
||||
self,
|
||||
config: Optional[IndexConfig] = None,
|
||||
key: Optional[Union[str, "Key"]] = None,
|
||||
) -> "Schema":
|
||||
"""Create an index configuration."""
|
||||
# Convert Key to string if provided
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if key is not None and isinstance(key, KeyType):
|
||||
key = key.name
|
||||
|
||||
@@ -1869,11 +2114,14 @@ class Schema:
|
||||
return self
|
||||
|
||||
def delete_index(
|
||||
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
|
||||
self,
|
||||
config: Optional[IndexConfig] = None,
|
||||
key: Optional[Union[str, "Key"]] = None,
|
||||
) -> "Schema":
|
||||
"""Disable an index configuration (set enabled=False)."""
|
||||
# Convert Key to string if provided
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if key is not None and isinstance(key, KeyType):
|
||||
key = key.name
|
||||
|
||||
@@ -1930,6 +2178,23 @@ class Schema:
|
||||
|
||||
return self
|
||||
|
||||
def set_cmek(self, cmek: Optional[Cmek]) -> "Schema":
    """Store the customer-managed encryption key on this schema.

    Args:
        cmek: The CMEK configuration to store, or None to clear it.

    Returns:
        This schema instance, so calls can be chained.

    Example:
        >>> schema = Schema().set_cmek(
        ...     Cmek.gcp("projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key")
        ... )
    """
    self.cmek = cmek
    return self
|
||||
|
||||
def _get_config_class_name(self, config: IndexConfig) -> str:
    """Return the unqualified class name of ``config``."""
    config_cls = config.__class__
    return config_cls.__name__
|
||||
@@ -1943,11 +2208,15 @@ class Schema:
|
||||
"""
|
||||
# Update the config in defaults (preserve enabled=False)
|
||||
current_enabled = self.defaults.float_list.vector_index.enabled # type: ignore[union-attr]
|
||||
self.defaults.float_list.vector_index = VectorIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
|
||||
self.defaults.float_list.vector_index = VectorIndexType(
|
||||
enabled=current_enabled, config=config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
# Update the config on #embedding key (preserve enabled=True and source_key="#document")
|
||||
current_enabled = self.keys[EMBEDDING_KEY].float_list.vector_index.enabled # type: ignore[union-attr]
|
||||
current_source_key = self.keys[EMBEDDING_KEY].float_list.vector_index.config.source_key # type: ignore[union-attr]
|
||||
current_source_key = self.keys[
|
||||
EMBEDDING_KEY
|
||||
].float_list.vector_index.config.source_key # type: ignore[union-attr]
|
||||
|
||||
# Create a new config with user settings but preserve the original source_key
|
||||
embedding_config = VectorIndexConfig(
|
||||
@@ -1957,7 +2226,9 @@ class Schema:
|
||||
spann=config.spann,
|
||||
source_key=current_source_key, # Preserve original source_key (should be "#document")
|
||||
)
|
||||
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(enabled=current_enabled, config=embedding_config) # type: ignore[union-attr]
|
||||
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(
|
||||
enabled=current_enabled, config=embedding_config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
def _set_fts_index_config(self, config: FtsIndexConfig) -> None:
|
||||
"""
|
||||
@@ -1967,11 +2238,15 @@ class Schema:
|
||||
"""
|
||||
# Update the config in defaults (preserve enabled=False)
|
||||
current_enabled = self.defaults.string.fts_index.enabled # type: ignore[union-attr]
|
||||
self.defaults.string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
|
||||
self.defaults.string.fts_index = FtsIndexType(
|
||||
enabled=current_enabled, config=config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
# Update the config on #document key (preserve enabled=True)
|
||||
current_enabled = self.keys[DOCUMENT_KEY].string.fts_index.enabled # type: ignore[union-attr]
|
||||
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
|
||||
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(
|
||||
enabled=current_enabled, config=config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
def _set_index_in_defaults(self, config: IndexConfig, enabled: bool) -> None:
|
||||
"""Set an index configuration in the defaults."""
|
||||
@@ -2071,37 +2346,51 @@ class Schema:
|
||||
if config_name == "FtsIndexConfig":
|
||||
if self.keys[key].string is None:
|
||||
self.keys[key].string = StringValueType()
|
||||
self.keys[key].string.fts_index = FtsIndexType(enabled=enabled, config=cast(FtsIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].string.fts_index = FtsIndexType(
|
||||
enabled=enabled, config=cast(FtsIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "StringInvertedIndexConfig":
|
||||
if self.keys[key].string is None:
|
||||
self.keys[key].string = StringValueType()
|
||||
self.keys[key].string.string_inverted_index = StringInvertedIndexType(enabled=enabled, config=cast(StringInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].string.string_inverted_index = StringInvertedIndexType(
|
||||
enabled=enabled, config=cast(StringInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "VectorIndexConfig":
|
||||
if self.keys[key].float_list is None:
|
||||
self.keys[key].float_list = FloatListValueType()
|
||||
self.keys[key].float_list.vector_index = VectorIndexType(enabled=enabled, config=cast(VectorIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].float_list.vector_index = VectorIndexType(
|
||||
enabled=enabled, config=cast(VectorIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "SparseVectorIndexConfig":
|
||||
if self.keys[key].sparse_vector is None:
|
||||
self.keys[key].sparse_vector = SparseVectorValueType()
|
||||
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(enabled=enabled, config=cast(SparseVectorIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(
|
||||
enabled=enabled, config=cast(SparseVectorIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "IntInvertedIndexConfig":
|
||||
if self.keys[key].int_value is None:
|
||||
self.keys[key].int_value = IntValueType()
|
||||
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(enabled=enabled, config=cast(IntInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(
|
||||
enabled=enabled, config=cast(IntInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "FloatInvertedIndexConfig":
|
||||
if self.keys[key].float_value is None:
|
||||
self.keys[key].float_value = FloatValueType()
|
||||
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(enabled=enabled, config=cast(FloatInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(
|
||||
enabled=enabled, config=cast(FloatInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "BoolInvertedIndexConfig":
|
||||
if self.keys[key].boolean is None:
|
||||
self.keys[key].boolean = BoolValueType()
|
||||
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(enabled=enabled, config=cast(BoolInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(
|
||||
enabled=enabled, config=cast(BoolInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
def _enable_all_indexes_for_key(self, key: str) -> None:
|
||||
"""Enable all possible index types for a specific key."""
|
||||
@@ -2249,7 +2538,13 @@ class Schema:
|
||||
for key, value_types in self.keys.items():
|
||||
keys_json[key] = self._serialize_value_types(value_types)
|
||||
|
||||
return {"defaults": defaults_json, "keys": keys_json}
|
||||
result: Dict[str, Any] = {"defaults": defaults_json, "keys": keys_json}
|
||||
|
||||
# Add CMEK if present
|
||||
if self.cmek is not None:
|
||||
result["cmek"] = self.cmek.to_dict()
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def deserialize_from_json(cls, json_data: Dict[str, Any]) -> "Schema":
|
||||
@@ -2263,6 +2558,11 @@ class Schema:
|
||||
for key, value_types_json in json_data.get("keys", {}).items():
|
||||
instance.keys[key] = cls._deserialize_value_types(value_types_json)
|
||||
|
||||
# Deserialize CMEK if present
|
||||
instance.cmek = None
|
||||
if "cmek" in json_data and json_data["cmek"] is not None:
|
||||
instance.cmek = Cmek.from_dict(json_data["cmek"])
|
||||
|
||||
return instance
|
||||
|
||||
def _serialize_value_types(self, value_types: ValueTypes) -> Dict[str, Any]:
|
||||
@@ -2410,6 +2710,10 @@ class Schema:
|
||||
if embedding_func.is_legacy():
|
||||
config_dict["embedding_function"] = {"type": "legacy"}
|
||||
else:
|
||||
if hasattr(embedding_func, "validate_config"):
|
||||
embedding_func.validate_config(
|
||||
embedding_func.get_config()
|
||||
)
|
||||
config_dict["embedding_function"] = {
|
||||
"name": embedding_func.name(),
|
||||
"type": "known",
|
||||
@@ -2439,6 +2743,8 @@ class Schema:
|
||||
config_dict["embedding_function"] = {"type": "unknown"}
|
||||
else:
|
||||
embedding_func = cast(SparseEmbeddingFunction, embedding_func) # type: ignore
|
||||
if hasattr(embedding_func, "validate_config"):
|
||||
embedding_func.validate_config(embedding_func.get_config())
|
||||
config_dict["embedding_function"] = {
|
||||
"name": embedding_func.name(),
|
||||
"type": "known",
|
||||
|
||||
Reference in New Issue
Block a user