Add surround-reconnaissance scenario adaptation
@@ -11,6 +11,7 @@ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
|
||||
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import (
|
||||
Cmek,
|
||||
CollectionMetadata,
|
||||
UpdateMetadata,
|
||||
Documents,
|
||||
@@ -58,6 +59,7 @@ import os
|
||||
|
||||
# Re-export types from chromadb.types
|
||||
__all__ = [
|
||||
"Cmek",
|
||||
"Collection",
|
||||
"Metadata",
|
||||
"Metadatas",
|
||||
@@ -105,7 +107,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
__settings = Settings()
|
||||
|
||||
__version__ = "1.3.4"
|
||||
__version__ = "1.4.0"
|
||||
|
||||
|
||||
# Workaround to deal with Colab's old sqlite3 version
|
||||
|
||||
@@ -33,10 +33,14 @@ from chromadb.execution.expression import ( # noqa: F401, F403
|
||||
Sub,
|
||||
Sum,
|
||||
Val,
|
||||
Aggregate,
|
||||
MinK,
|
||||
MaxK,
|
||||
GroupBy,
|
||||
)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence, Optional, List, Dict, Any
|
||||
from typing import Sequence, Optional, List, Dict, Any, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from overrides import override
|
||||
@@ -824,7 +828,7 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to a collection.
|
||||
|
||||
Args:
|
||||
@@ -837,14 +841,40 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
|
||||
database: The database name
|
||||
|
||||
Returns:
|
||||
AttachedFunction: Object representing the attached function
|
||||
Tuple of (AttachedFunction, created) where created is True if newly created,
|
||||
False if already existed (idempotent request)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Get an attached function by name for a specific collection.
|
||||
|
||||
Args:
|
||||
name: Name of the attached function
|
||||
input_collection_id: The collection ID
|
||||
tenant: The tenant name
|
||||
database: The database name
|
||||
|
||||
Returns:
|
||||
AttachedFunction: The attached function object
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the attached function doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
@@ -852,7 +882,8 @@ class ServerAPI(BaseAPI, AdminAPI, Component):
|
||||
"""Detach a function and prevent any further runs.
|
||||
|
||||
Args:
|
||||
attached_function_id: ID of the attached function to remove
|
||||
name: Name of the attached function to remove
|
||||
input_collection_id: ID of the input collection
|
||||
delete_output: Whether to also delete the output collection
|
||||
tenant: The tenant name
|
||||
database: The database name
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from uuid import UUID
|
||||
import urllib.parse
|
||||
import orjson
|
||||
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
|
||||
from typing import Any, Mapping, Optional, cast, Tuple, Sequence, Dict, List
|
||||
import logging
|
||||
import httpx
|
||||
from overrides import override
|
||||
@@ -130,16 +130,23 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._clients[loop_hash] = httpx.AsyncClient(
|
||||
timeout=None,
|
||||
headers=headers,
|
||||
verify=self._settings.chroma_server_ssl_verify or False,
|
||||
limits=limits,
|
||||
limits=self.http_limits,
|
||||
)
|
||||
|
||||
return self._clients[loop_hash]
|
||||
|
||||
@override
|
||||
def get_request_headers(self) -> Mapping[str, str]:
|
||||
return dict(self._get_client().headers)
|
||||
|
||||
@override
|
||||
def get_api_url(self) -> str:
|
||||
return self._api_url
|
||||
|
||||
async def _make_request(
|
||||
self, method: str, path: str, **kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
@@ -527,7 +534,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
@@ -723,7 +730,7 @@ class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
metadatas=metadata_batches,
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
|
||||
@@ -1,19 +1,51 @@
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
from typing import Any, Dict, Mapping, Optional, TypeVar
|
||||
from urllib.parse import quote, urlparse, urlunparse
|
||||
import logging
|
||||
import orjson as json
|
||||
import httpx
|
||||
|
||||
import chromadb.errors as errors
|
||||
from chromadb.config import Settings
|
||||
from chromadb.config import Component, Settings, System
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseHTTPClient:
|
||||
# inherits from Component so that it can create an init function to use system
|
||||
# this way it can build limits from the settings in System
|
||||
class BaseHTTPClient(Component):
|
||||
_settings: Settings
|
||||
pre_flight_checks: Any = None
|
||||
keepalive_secs: int = 40
|
||||
DEFAULT_KEEPALIVE_SECS: float = 40.0
|
||||
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
keepalive_setting = self._settings.chroma_http_keepalive_secs
|
||||
self.keepalive_secs: Optional[float] = (
|
||||
keepalive_setting
|
||||
if keepalive_setting is not None
|
||||
else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
|
||||
)
|
||||
self._http_limits = self._build_limits()
|
||||
|
||||
def _build_limits(self) -> httpx.Limits:
|
||||
limit_kwargs: Dict[str, Any] = {}
|
||||
if self.keepalive_secs is not None:
|
||||
limit_kwargs["keepalive_expiry"] = self.keepalive_secs
|
||||
|
||||
max_connections = self._settings.chroma_http_max_connections
|
||||
if max_connections is not None:
|
||||
limit_kwargs["max_connections"] = max_connections
|
||||
|
||||
max_keepalive_connections = self._settings.chroma_http_max_keepalive_connections
|
||||
if max_keepalive_connections is not None:
|
||||
limit_kwargs["max_keepalive_connections"] = max_keepalive_connections
|
||||
|
||||
return httpx.Limits(**limit_kwargs)
|
||||
|
||||
@property
|
||||
def http_limits(self) -> httpx.Limits:
|
||||
return self._http_limits
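A minimal sketch of the mapping performed by _build_limits() above, assuming only the three Settings fields introduced in this diff (chroma_http_keepalive_secs, chroma_http_max_connections, chroma_http_max_keepalive_connections); options left unset are simply omitted so httpx falls back to its own defaults.

import httpx
from chromadb.config import Settings

settings = Settings(
    chroma_http_keepalive_secs=15.0,
    chroma_http_max_connections=50,
    chroma_http_max_keepalive_connections=10,
)

# Collect only the options that were actually configured.
limit_kwargs = {}
if settings.chroma_http_keepalive_secs is not None:
    limit_kwargs["keepalive_expiry"] = settings.chroma_http_keepalive_secs
if settings.chroma_http_max_connections is not None:
    limit_kwargs["max_connections"] = settings.chroma_http_max_connections
if settings.chroma_http_max_keepalive_connections is not None:
    limit_kwargs["max_keepalive_connections"] = settings.chroma_http_max_keepalive_connections

limits = httpx.Limits(**limit_kwargs)  # shared by the sync and async HTTP sessions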
|
||||
|
||||
@staticmethod
|
||||
def _validate_host(host: str) -> None:
|
||||
@@ -103,3 +135,11 @@ class BaseHTTPClient:
|
||||
if trace_id:
|
||||
raise Exception(f"{resp.text} (trace ID: {trace_id})")
|
||||
raise (Exception(resp.text))
|
||||
|
||||
def get_request_headers(self) -> Mapping[str, str]:
|
||||
"""Return headers used for HTTP requests."""
|
||||
return {}
|
||||
|
||||
def get_api_url(self) -> str:
|
||||
"""Return the API URL for this client."""
|
||||
return ""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import orjson
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, cast, Tuple, List
|
||||
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
import httpx
|
||||
@@ -79,8 +79,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
default_api_path=system.settings.chroma_server_api_default_path,
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._session = httpx.Client(timeout=None, limits=limits)
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(
|
||||
timeout=None,
|
||||
limits=self.http_limits,
|
||||
verify=self._settings.chroma_server_ssl_verify,
|
||||
)
|
||||
else:
|
||||
self._session = httpx.Client(timeout=None, limits=self.http_limits)
|
||||
|
||||
self._header = system.settings.chroma_server_headers or {}
|
||||
self._header["Content-Type"] = "application/json"
|
||||
@@ -90,8 +96,6 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
|
||||
if self._header is not None:
|
||||
self._session.headers.update(self._header)
|
||||
|
||||
@@ -101,6 +105,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
for header, value in _headers.items():
|
||||
self._session.headers[header] = value.get_secret_value()
|
||||
|
||||
@override
|
||||
def get_request_headers(self) -> Mapping[str, str]:
|
||||
return dict(self._session.headers)
|
||||
|
||||
@override
|
||||
def get_api_url(self) -> str:
|
||||
return self._api_url
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
|
||||
# If the request has json in kwargs, use orjson to serialize it,
|
||||
# remove it from kwargs, and add it to the content parameter
|
||||
@@ -492,7 +504,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
@@ -700,7 +712,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
metadatas=metadata_batches,
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
@@ -761,7 +773,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to a collection."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
@@ -774,23 +786,56 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
},
|
||||
)
|
||||
|
||||
return AttachedFunction(
|
||||
attached_function = AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(resp_json["attached_function"]["id"]),
|
||||
name=resp_json["attached_function"]["name"],
|
||||
function_id=resp_json["attached_function"]["function_id"],
|
||||
function_name=resp_json["attached_function"]["function_name"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
created = resp_json.get(
|
||||
"created", True
|
||||
) # Default to True for backwards compatibility
|
||||
return (attached_function, created)
|
||||
|
||||
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Get an attached function by name for a specific collection."""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/{name}",
|
||||
)
|
||||
|
||||
af = resp_json["attached_function"]
|
||||
return AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(af["id"]),
|
||||
name=af["name"],
|
||||
function_name=af["function_name"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=af["output_collection"],
|
||||
params=af.get("params"),
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
@@ -798,7 +843,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
"""Detach a function and prevent any further runs."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
|
||||
json={
|
||||
"delete_output": delete_output,
|
||||
},
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api import ServerAPI # noqa: F401
|
||||
@@ -13,7 +14,7 @@ class AttachedFunction:
|
||||
client: "ServerAPI",
|
||||
id: UUID,
|
||||
name: str,
|
||||
function_id: str,
|
||||
function_name: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]],
|
||||
@@ -26,7 +27,7 @@ class AttachedFunction:
|
||||
client: The API client
|
||||
id: Unique identifier for this attached function
|
||||
name: Name of this attached function instance
|
||||
function_id: The function identifier (e.g., "record_counter")
|
||||
function_name: The function name (e.g., "record_counter", "statistics")
|
||||
input_collection_id: ID of the input collection
|
||||
output_collection: Name of the output collection
|
||||
params: Function-specific parameters
|
||||
@@ -36,7 +37,7 @@ class AttachedFunction:
|
||||
self._client = client
|
||||
self._id = id
|
||||
self._name = name
|
||||
self._function_id = function_id
|
||||
self._function_name = function_name
|
||||
self._input_collection_id = input_collection_id
|
||||
self._output_collection = output_collection
|
||||
self._params = params
|
||||
@@ -54,9 +55,9 @@ class AttachedFunction:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def function_id(self) -> str:
|
||||
"""The function identifier."""
|
||||
return self._function_id
|
||||
def function_name(self) -> str:
|
||||
"""The function name."""
|
||||
return self._function_name
|
||||
|
||||
@property
|
||||
def input_collection_id(self) -> UUID:
|
||||
@@ -73,29 +74,69 @@ class AttachedFunction:
|
||||
"""The function parameters."""
|
||||
return self._params
|
||||
|
||||
def detach(self, delete_output_collection: bool = False) -> bool:
|
||||
"""Detach this function and prevent any further runs.
|
||||
@staticmethod
|
||||
def _normalize_params(params: Optional[Any]) -> Dict[str, Any]:
|
||||
"""Normalize params to a consistent dict format.
|
||||
|
||||
Args:
|
||||
delete_output_collection: Whether to also delete the output collection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
|
||||
Example:
|
||||
>>> success = attached_fn.detach(delete_output_collection=True)
|
||||
Handles None, empty strings, JSON strings, and dicts.
|
||||
"""
|
||||
return self._client.detach_function(
|
||||
attached_function_id=self._id,
|
||||
delete_output=delete_output_collection,
|
||||
tenant=self._tenant,
|
||||
database=self._database,
|
||||
)
|
||||
if params is None:
|
||||
return {}
|
||||
if isinstance(params, str):
|
||||
try:
|
||||
result = json.loads(params) if params else {}
|
||||
return result if isinstance(result, dict) else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
if isinstance(params, dict):
|
||||
return params
|
||||
return {}
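A quick illustration of the normalization rules above; every transport shape (None, empty string, JSON string, dict, malformed JSON) collapses to a plain dict so that __eq__ and __hash__ can compare params reliably.

AttachedFunction._normalize_params(None)                  # {}
AttachedFunction._normalize_params("")                    # {}
AttachedFunction._normalize_params('{"threshold": 100}')  # {"threshold": 100}
AttachedFunction._normalize_params({"threshold": 100})    # {"threshold": 100}
AttachedFunction._normalize_params("not json")            # {} (JSONDecodeError swallowed)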
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AttachedFunction(id={self._id}, name='{self._name}', "
|
||||
f"function_id='{self._function_id}', "
|
||||
f"function_name='{self._function_name}', "
|
||||
f"input_collection_id={self._input_collection_id}, "
|
||||
f"output_collection='{self._output_collection}')"
|
||||
)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare two AttachedFunction objects for equality."""
|
||||
if not isinstance(other, AttachedFunction):
|
||||
return False
|
||||
|
||||
# Normalize params: handle None, {}, and JSON strings
|
||||
self_params = self._normalize_params(self._params)
|
||||
other_params = self._normalize_params(other._params)
|
||||
|
||||
return (
|
||||
self._id == other._id
|
||||
and self._name == other._name
|
||||
and self._function_name == other._function_name
|
||||
and self._input_collection_id == other._input_collection_id
|
||||
and self._output_collection == other._output_collection
|
||||
and self_params == other_params
|
||||
and self._tenant == other._tenant
|
||||
and self._database == other._database
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the AttachedFunction."""
|
||||
# Normalize params using the same logic as __eq__
|
||||
normalized_params = self._normalize_params(self._params)
|
||||
params_tuple = (
|
||||
tuple(sorted(normalized_params.items())) if normalized_params else ()
|
||||
)
|
||||
|
||||
return hash(
|
||||
(
|
||||
self._id,
|
||||
self._name,
|
||||
self._function_name,
|
||||
self._input_collection_id,
|
||||
self._output_collection,
|
||||
params_tuple,
|
||||
self._tenant,
|
||||
self._database,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any, Tuple
|
||||
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.api.types import (
|
||||
@@ -25,6 +25,8 @@ from chromadb.execution.expression.plan import Search
|
||||
|
||||
import logging
|
||||
|
||||
from chromadb.api.functions import Function
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
|
||||
@@ -500,30 +502,36 @@ class Collection(CollectionCommon["ServerAPI"]):
|
||||
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
function: Function,
|
||||
name: str,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to this collection.
|
||||
|
||||
Args:
|
||||
function_id: Built-in function identifier (e.g., "record_counter")
|
||||
function: A Function enum value (e.g., STATISTICS_FUNCTION, RECORD_COUNTER_FUNCTION)
|
||||
name: Unique name for this attached function
|
||||
output_collection: Name of the collection where function output will be stored
|
||||
params: Optional dictionary with function-specific parameters
|
||||
|
||||
Returns:
|
||||
AttachedFunction: Object representing the attached function
|
||||
Tuple of (AttachedFunction, created) where created is True if newly created,
|
||||
False if already existed (idempotent request)
|
||||
|
||||
Example:
|
||||
>>> from chromadb.api.functions import STATISTICS_FUNCTION
|
||||
>>> attached_fn, created = collection.attach_function(
|
||||
... function_id="record_counter",
|
||||
... function=STATISTICS_FUNCTION,
|
||||
... name="mycoll_stats_fn",
|
||||
... output_collection="mycoll_stats",
|
||||
... params={"threshold": 100}
|
||||
... )
|
||||
>>> if created:
|
||||
... print("New function attached")
|
||||
... else:
|
||||
... print("Function already existed")
|
||||
"""
|
||||
function_id = function.value if isinstance(function, Function) else function
|
||||
return self._client.attach_function(
|
||||
function_id=function_id,
|
||||
name=name,
|
||||
@@ -533,3 +541,47 @@ class Collection(CollectionCommon["ServerAPI"]):
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def get_attached_function(self, name: str) -> "AttachedFunction":
|
||||
"""Get an attached function by name for this collection.
|
||||
|
||||
Args:
|
||||
name: Name of the attached function
|
||||
|
||||
Returns:
|
||||
AttachedFunction: The attached function object
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the attached function doesn't exist
|
||||
"""
|
||||
return self._client.get_attached_function(
|
||||
name=name,
|
||||
input_collection_id=self.id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def detach_function(
|
||||
self,
|
||||
name: str,
|
||||
delete_output_collection: bool = False,
|
||||
) -> bool:
|
||||
"""Detach a function from this collection.
|
||||
|
||||
Args:
|
||||
name: The name of the attached function
|
||||
delete_output_collection: Whether to also delete the output collection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
|
||||
Example:
|
||||
>>> success = collection.detach_function("my_function", delete_output_collection=True)
|
||||
"""
|
||||
return self._client.detach_function(
|
||||
name=name,
|
||||
input_collection_id=self.id,
|
||||
delete_output=delete_output_collection,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
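A usage sketch of the collection-level lifecycle added in this diff, assuming an existing Collection object named collection; note that attach_function now returns an (AttachedFunction, created) tuple rather than the attached function alone.

from chromadb.api.functions import STATISTICS_FUNCTION

attached_fn, created = collection.attach_function(
    function=STATISTICS_FUNCTION,
    name="mycoll_stats_fn",
    output_collection="mycoll_stats",
    params={"threshold": 100},
)
if not created:
    print("already attached; the request was idempotent")

# Look the attachment up again by name, then detach it and drop its output.
same_fn = collection.get_attached_function("mycoll_stats_fn")
# __eq__ normalizes params, so dict vs. JSON-string transport does not matter.
assert same_fn == attached_fn
collection.detach_function("mycoll_stats_fn", delete_output_collection=True)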
|
||||
|
||||
@@ -1021,6 +1021,7 @@ class CollectionCommon(Generic[ClientT]):
|
||||
return Search(
|
||||
where=search._where,
|
||||
rank=embedded_rank,
|
||||
group_by=search._group_by,
|
||||
limit=search._limit,
|
||||
select=search._select,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ from chromadb.execution.expression.plan import Search
|
||||
import chromadb_rust_bindings
|
||||
|
||||
|
||||
from typing import Optional, Sequence, List, Dict, Any
|
||||
from typing import Optional, Sequence, List, Dict, Any, Tuple
|
||||
from overrides import override
|
||||
from uuid import UUID
|
||||
import json
|
||||
@@ -613,6 +613,20 @@ class RustBindingsAPI(ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
|
||||
"The Rust bindings (embedded mode) do not support attached function operations."
|
||||
)
|
||||
|
||||
@override
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
@@ -623,7 +637,8 @@ class RustBindingsAPI(ServerAPI):
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
|
||||
@@ -75,6 +75,7 @@ from typing import (
|
||||
Dict,
|
||||
Callable,
|
||||
TypeVar,
|
||||
Tuple,
|
||||
)
|
||||
from overrides import override
|
||||
from uuid import UUID, uuid4
|
||||
@@ -922,6 +923,20 @@ class SegmentAPI(ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attached functions are not supported in the Segment API (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
|
||||
"The Segment API (embedded mode) does not support attached function operations."
|
||||
)
|
||||
|
||||
@override
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attached functions are not supported in the Segment API (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
@@ -932,7 +947,8 @@ class SegmentAPI(ServerAPI):
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import ClassVar, Dict
|
||||
from typing import ClassVar, Dict, Optional
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from chromadb.api import ServerAPI
|
||||
from chromadb.api.base_http_client import BaseHTTPClient
|
||||
from chromadb.config import Settings, System
|
||||
from chromadb.telemetry.product import ProductTelemetryClient
|
||||
from chromadb.telemetry.product.events import ClientStartEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedSystemClient:
|
||||
_identifier_to_system: ClassVar[Dict[str, System]] = {}
|
||||
@@ -94,3 +97,59 @@ class SharedSystemClient:
|
||||
def _submit_client_start_event(self) -> None:
|
||||
telemetry_client = self._system.instance(ProductTelemetryClient)
|
||||
telemetry_client.capture(ClientStartEvent())
|
||||
|
||||
@staticmethod
|
||||
def get_chroma_cloud_api_key_from_clients() -> Optional[str]:
|
||||
"""
|
||||
Try to extract api key from existing client instances by checking httpx session headers.
|
||||
|
||||
Requirements to pull api key:
|
||||
- must be a BaseHTTPClient instance (ignore RustBindingsAPI and SegmentAPI)
|
||||
- must have "api.trychroma.com" or "gcp.trychroma.com" in the _api_url (ignore local/self-hosted instances)
|
||||
- must have "x-chroma-token" or "X-Chroma-Token" in the headers
|
||||
|
||||
Returns:
|
||||
The first api key found, or None if no client instances have api keys set.
|
||||
"""
|
||||
|
||||
api_keys: list[str] = []
|
||||
systems_snapshot = list(SharedSystemClient._identifier_to_system.values())
|
||||
for system in systems_snapshot:
|
||||
try:
|
||||
server_api = system.instance(ServerAPI)
|
||||
|
||||
if not isinstance(server_api, BaseHTTPClient):
|
||||
# RustBindingsAPI and SegmentAPI don't have HTTP headers
|
||||
continue
|
||||
|
||||
# Only pull api key if the url contains the chroma cloud url
|
||||
api_url = server_api.get_api_url()
|
||||
if (
|
||||
"api.trychroma.com" not in api_url
|
||||
and "gcp.trychroma.com" not in api_url
|
||||
):
|
||||
continue
|
||||
|
||||
headers = server_api.get_request_headers()
|
||||
api_key = None
|
||||
for key, value in headers.items():
|
||||
if key.lower() == "x-chroma-token":
|
||||
api_key = value
|
||||
break
|
||||
|
||||
if api_key:
|
||||
api_keys.append(api_key)
|
||||
except Exception:
|
||||
# If we can't access the ServerAPI instance, continue to the next
|
||||
continue
|
||||
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
# log if multiple viable api keys found
|
||||
if len(api_keys) > 1:
|
||||
logger.info(
|
||||
f"Multiple Chroma Cloud clients found, using API key starting with {api_keys[0][:8]}..."
|
||||
)
|
||||
|
||||
return api_keys[0]
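A short usage sketch, assuming at least one Chroma Cloud HttpClient has already been constructed in the current process (import path for SharedSystemClient assumed); the helper scans the cached systems shown above and returns the first x-chroma-token header value it finds.

from chromadb.api.client import SharedSystemClient

api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
if api_key is None:
    print("no Chroma Cloud client with an x-chroma-token header was found")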
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import (
|
||||
get_args,
|
||||
TYPE_CHECKING,
|
||||
Final,
|
||||
Type,
|
||||
)
|
||||
from copy import deepcopy
|
||||
from typing_extensions import TypeAlias
|
||||
@@ -20,7 +21,8 @@ from numpy.typing import NDArray
|
||||
import numpy as np
|
||||
import warnings
|
||||
from typing_extensions import TypedDict, Protocol, runtime_checkable
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
import chromadb.errors as errors
|
||||
from chromadb.base_types import (
|
||||
@@ -52,6 +54,8 @@ import pybase64
|
||||
from functools import lru_cache
|
||||
import struct
|
||||
import math
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Re-export types from chromadb.types
|
||||
__all__ = [
|
||||
@@ -112,6 +116,9 @@ __all__ = [
|
||||
"EMBEDDING_KEY",
|
||||
"TYPE_KEY",
|
||||
"SPARSE_VECTOR_TYPE_VALUE",
|
||||
# CMEK Support
|
||||
"Cmek",
|
||||
"CmekProvider",
|
||||
# Space type
|
||||
"Space",
|
||||
# Embedding Functions
|
||||
@@ -213,7 +220,7 @@ def optional_base64_strings_to_embeddings(
|
||||
|
||||
|
||||
def normalize_embeddings(
|
||||
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]]
|
||||
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]],
|
||||
) -> Optional[Embeddings]:
|
||||
if target is None:
|
||||
return None
|
||||
@@ -448,7 +455,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
|
||||
)
|
||||
|
||||
zero_lengths = [
|
||||
key for key, lst in record_set.items() if lst is not None and len(lst) == 0 # type: ignore[arg-type]
|
||||
key
|
||||
for key, lst in record_set.items()
|
||||
if lst is not None and len(lst) == 0 # type: ignore[arg-type]
|
||||
]
|
||||
|
||||
if zero_lengths:
|
||||
@@ -456,7 +465,9 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None:
|
||||
|
||||
if len(set(lengths)) > 1:
|
||||
error_str = ", ".join(
|
||||
f"{key}: {len(lst)}" for key, lst in record_set.items() if lst is not None # type: ignore[arg-type]
|
||||
f"{key}: {len(lst)}"
|
||||
for key, lst in record_set.items()
|
||||
if lst is not None # type: ignore[arg-type]
|
||||
)
|
||||
raise ValueError(f"Unequal lengths for fields: {error_str}")
|
||||
|
||||
@@ -756,8 +767,7 @@ class EmbeddingFunction(Protocol[D]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, input: D) -> Embeddings:
|
||||
...
|
||||
def __call__(self, input: D) -> Embeddings: ...
|
||||
|
||||
def embed_query(self, input: D) -> Embeddings:
|
||||
"""
|
||||
@@ -941,8 +951,7 @@ def validate_embedding_function(
|
||||
|
||||
|
||||
class DataLoader(Protocol[L]):
|
||||
def __call__(self, uris: URIs) -> L:
|
||||
...
|
||||
def __call__(self, uris: URIs) -> L: ...
|
||||
|
||||
|
||||
def validate_ids(ids: IDs) -> IDs:
|
||||
@@ -1063,7 +1072,7 @@ def serialize_metadata(metadata: Optional[Metadata]) -> Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
def deserialize_metadata(
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Deserialize metadata from transport, converting dicts with #type=sparse_vector to dataclass instances.
|
||||
|
||||
@@ -1399,8 +1408,7 @@ class SparseEmbeddingFunction(Protocol[D]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, input: D) -> SparseVectors:
|
||||
...
|
||||
def __call__(self, input: D) -> SparseVectors: ...
|
||||
|
||||
def embed_query(self, input: D) -> SparseVectors:
|
||||
"""
|
||||
@@ -1493,15 +1501,57 @@ def validate_sparse_embedding_function(
|
||||
|
||||
|
||||
# Index Configuration Types for Collection Schema
|
||||
def _create_extra_fields_validator(valid_fields: list[str]) -> Any:
|
||||
"""Create a model validator that provides helpful error messages for invalid fields."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_extra_fields(cls: Type[BaseModel], data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
invalid_fields = [k for k in data.keys() if k not in valid_fields]
|
||||
if invalid_fields:
|
||||
invalid_fields_str = ", ".join(f"'{f}'" for f in invalid_fields)
|
||||
class_name = cls.__name__
|
||||
# Create a clear, actionable error message
|
||||
if len(invalid_fields) == 1:
|
||||
msg = (
|
||||
f"'{invalid_fields[0]}' is not a valid field for {class_name}. "
|
||||
)
|
||||
else:
|
||||
msg = f"Invalid fields for {class_name}: {invalid_fields_str}. "
|
||||
|
||||
raise PydanticCustomError(
|
||||
"invalid_field",
|
||||
msg,
|
||||
{"invalid_fields": invalid_fields},
|
||||
)
|
||||
return data
|
||||
|
||||
return validate_extra_fields
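A hedged sketch of the validator in action, assuming HnswIndexConfig (defined below) is importable from chromadb.api.types; a misspelled keyword now fails with a message that names the offending field instead of a generic "extra fields not permitted" error.

from pydantic import ValidationError
from chromadb.api.types import HnswIndexConfig

try:
    HnswIndexConfig(ef_constructionn=128)  # typo: extra trailing 'n'
except ValidationError as exc:
    print(exc)  # ...'ef_constructionn' is not a valid field for HnswIndexConfig...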
|
||||
|
||||
|
||||
class FtsIndexConfig(BaseModel):
|
||||
"""Configuration for Full-Text Search index. No parameters required."""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class HnswIndexConfig(BaseModel):
|
||||
"""Configuration for HNSW vector index."""
|
||||
|
||||
_validate_extra_fields = _create_extra_fields_validator(
|
||||
[
|
||||
"ef_construction",
|
||||
"max_neighbors",
|
||||
"ef_search",
|
||||
"num_threads",
|
||||
"batch_size",
|
||||
"sync_threshold",
|
||||
"resize_factor",
|
||||
]
|
||||
)
|
||||
|
||||
ef_construction: Optional[int] = None
|
||||
max_neighbors: Optional[int] = None
|
||||
ef_search: Optional[int] = None
|
||||
@@ -1514,6 +1564,27 @@ class HnswIndexConfig(BaseModel):
|
||||
class SpannIndexConfig(BaseModel):
|
||||
"""Configuration for SPANN vector index."""
|
||||
|
||||
_validate_extra_fields = _create_extra_fields_validator(
|
||||
[
|
||||
"search_nprobe",
|
||||
"search_rng_factor",
|
||||
"search_rng_epsilon",
|
||||
"nreplica_count",
|
||||
"write_nprobe",
|
||||
"write_rng_factor",
|
||||
"write_rng_epsilon",
|
||||
"split_threshold",
|
||||
"num_samples_kmeans",
|
||||
"initial_lambda",
|
||||
"reassign_neighbor_count",
|
||||
"merge_threshold",
|
||||
"num_centers_to_merge_to",
|
||||
"ef_construction",
|
||||
"ef_search",
|
||||
"max_neighbors",
|
||||
]
|
||||
)
|
||||
|
||||
search_nprobe: Optional[int] = None
|
||||
write_nprobe: Optional[int] = None
|
||||
ef_construction: Optional[int] = None
|
||||
@@ -1527,10 +1598,13 @@ class SpannIndexConfig(BaseModel):
|
||||
class VectorIndexConfig(BaseModel):
|
||||
"""Configuration for vector index with space, embedding function, and algorithm config."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
|
||||
|
||||
space: Optional[Space] = None
|
||||
embedding_function: Optional[Any] = DefaultEmbeddingFunction()
|
||||
source_key: Optional[str] = None # key to source the vector from (accepts str or Key)
|
||||
source_key: Optional[str] = (
|
||||
None # key to source the vector from (accepts str or Key)
|
||||
)
|
||||
hnsw: Optional[HnswIndexConfig] = None
|
||||
spann: Optional[SpannIndexConfig] = None
|
||||
|
||||
@@ -1542,6 +1616,7 @@ class VectorIndexConfig(BaseModel):
|
||||
return None
|
||||
# Import Key at runtime to avoid circular import
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if isinstance(v, KeyType):
|
||||
v = v.name # Extract string from Key
|
||||
elif isinstance(v, str):
|
||||
@@ -1574,10 +1649,13 @@ class VectorIndexConfig(BaseModel):
|
||||
class SparseVectorIndexConfig(BaseModel):
|
||||
"""Configuration for sparse vector index."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
|
||||
|
||||
# TODO(Sanket): Change this to the appropriate sparse ef and use a default here.
|
||||
embedding_function: Optional[Any] = None
|
||||
source_key: Optional[str] = None # key to source the sparse vector from (accepts str or Key)
|
||||
source_key: Optional[str] = (
|
||||
None # key to source the sparse vector from (accepts str or Key)
|
||||
)
|
||||
bm25: Optional[bool] = None
|
||||
|
||||
@field_validator("source_key", mode="before")
|
||||
@@ -1588,6 +1666,7 @@ class SparseVectorIndexConfig(BaseModel):
|
||||
return None
|
||||
# Import Key at runtime to avoid circular import
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if isinstance(v, KeyType):
|
||||
v = v.name # Extract string from Key
|
||||
elif isinstance(v, str):
|
||||
@@ -1622,24 +1701,32 @@ class SparseVectorIndexConfig(BaseModel):
|
||||
class StringInvertedIndexConfig(BaseModel):
|
||||
"""Configuration for string inverted index."""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IntInvertedIndexConfig(BaseModel):
|
||||
"""Configuration for integer inverted index."""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FloatInvertedIndexConfig(BaseModel):
|
||||
"""Configuration for float inverted index."""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BoolInvertedIndexConfig(BaseModel):
|
||||
"""Configuration for boolean inverted index."""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -1683,6 +1770,148 @@ TYPE_KEY: Final[str] = "#type"
|
||||
SPARSE_VECTOR_TYPE_VALUE: Final[str] = "sparse_vector"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CMEK (Customer-Managed Encryption Key) Support
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CmekProvider(str, Enum):
|
||||
"""Supported cloud providers for customer-managed encryption keys.
|
||||
|
||||
Currently only Google Cloud Platform (GCP) is supported.
|
||||
"""
|
||||
|
||||
GCP = "gcp"
|
||||
|
||||
|
||||
# Regex pattern for validating GCP KMS resource names
|
||||
_CMEK_GCP_PATTERN = re.compile(r"^projects/.+/locations/.+/keyRings/.+/cryptoKeys/.+$")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Cmek:
|
||||
"""Customer-managed encryption key (CMEK) for collection data encryption.
|
||||
|
||||
CMEK allows you to use your own encryption keys managed by cloud providers'
|
||||
key management services (KMS) instead of default provider-managed keys. This
|
||||
gives you greater control over key lifecycle, access policies, and audit logging.
|
||||
|
||||
Attributes:
|
||||
provider: The cloud provider (currently only 'gcp' supported)
|
||||
resource: The provider-specific resource identifier for the encryption key
|
||||
|
||||
Example:
|
||||
>>> # Create a CMEK for GCP
|
||||
>>> cmek = Cmek.gcp(
|
||||
... "projects/my-project/locations/us-central1/"
|
||||
... "keyRings/my-ring/cryptoKeys/my-key"
|
||||
... )
|
||||
>>>
|
||||
>>> # Validate the resource name format
|
||||
>>> if cmek.validate_pattern():
|
||||
... print("Valid CMEK format")
|
||||
|
||||
Note:
|
||||
Pattern validation only checks format correctness. It does not verify
|
||||
that the key exists or is accessible. Key permissions and access control
|
||||
must be configured separately in your cloud provider's console.
|
||||
"""
|
||||
|
||||
provider: CmekProvider
|
||||
resource: str
|
||||
|
||||
@classmethod
|
||||
def gcp(cls, resource: str) -> "Cmek":
|
||||
"""Create a CMEK instance for Google Cloud Platform.
|
||||
|
||||
Args:
|
||||
resource: GCP Cloud KMS resource name in the format:
|
||||
projects/{project-id}/locations/{location}/keyRings/{key-ring}/cryptoKeys/{key-name}
|
||||
|
||||
Example: "projects/my-project/locations/us-central1/keyRings/my-ring/cryptoKeys/my-key"
|
||||
|
||||
Returns:
|
||||
Cmek: A new CMEK instance configured for GCP
|
||||
|
||||
Raises:
|
||||
ValueError: If the resource format is invalid (when validate_pattern() is called)
|
||||
|
||||
Example:
|
||||
>>> cmek = Cmek.gcp(
|
||||
... "projects/my-project/locations/us-central1/"
|
||||
... "keyRings/my-ring/cryptoKeys/my-key"
|
||||
... )
|
||||
"""
|
||||
return cls(provider=CmekProvider.GCP, resource=resource)
|
||||
|
||||
def validate_pattern(self) -> bool:
|
||||
"""Validate the CMEK resource name format.
|
||||
|
||||
Validates that the resource name matches the expected pattern for the
|
||||
provider. This is a format check only and does not verify that the key
|
||||
exists or that you have access to it.
|
||||
|
||||
For GCP, the expected format is:
|
||||
projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}
|
||||
|
||||
Returns:
|
||||
bool: True if the resource name format is valid, False otherwise
|
||||
|
||||
Example:
|
||||
>>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
|
||||
>>> cmek.validate_pattern() # Returns True
|
||||
>>>
|
||||
>>> bad_cmek = Cmek.gcp("invalid-format")
|
||||
>>> bad_cmek.validate_pattern() # Returns False
|
||||
"""
|
||||
if self.provider == CmekProvider.GCP:
|
||||
return _CMEK_GCP_PATTERN.match(self.resource) is not None
|
||||
return False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize CMEK to dictionary format for API transport.
|
||||
|
||||
Returns:
|
||||
Dict containing the provider variant and resource identifier.
|
||||
|
||||
Example:
|
||||
>>> cmek = Cmek.gcp("projects/p/locations/l/keyRings/r/cryptoKeys/k")
|
||||
>>> cmek.to_dict()
|
||||
{'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
|
||||
"""
|
||||
if self.provider == CmekProvider.GCP:
|
||||
return {"gcp": self.resource}
|
||||
# Unreachable with current providers, but future-proof
|
||||
raise ValueError(f"Unknown CMEK provider: {self.provider}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Cmek":
|
||||
"""Deserialize CMEK from dictionary format.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing provider variant and resource.
|
||||
Format matches Rust serde enum serialization with snake_case.
|
||||
|
||||
Returns:
|
||||
Cmek: Deserialized CMEK instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is unsupported or data is malformed
|
||||
|
||||
Example:
|
||||
>>> data = {'gcp': 'projects/p/locations/l/keyRings/r/cryptoKeys/k'}
|
||||
>>> cmek = Cmek.from_dict(data)
|
||||
"""
|
||||
if "gcp" in data:
|
||||
resource = data["gcp"]
|
||||
if not isinstance(resource, str):
|
||||
raise ValueError(
|
||||
f"CMEK gcp resource must be a string, got: {type(resource)}"
|
||||
)
|
||||
return cls.gcp(resource)
|
||||
raise ValueError(f"Unsupported or missing CMEK provider in data: {data}")
|
||||
|
||||
|
||||
# Index Type Classes
|
||||
|
||||
|
||||
@@ -1774,24 +2003,40 @@ class ValueTypes:
|
||||
|
||||
@dataclass
|
||||
class Schema:
|
||||
"""Collection schema for configuring indexes and encryption.
|
||||
|
||||
The schema controls how data is indexed and can optionally specify
|
||||
customer-managed encryption keys (CMEK) for data at rest.
|
||||
|
||||
Attributes:
|
||||
defaults: Default index configurations for each value type
|
||||
keys: Key-specific index overrides
|
||||
cmek: Optional customer-managed encryption key for collection data
|
||||
"""
|
||||
|
||||
defaults: ValueTypes
|
||||
keys: Dict[str, ValueTypes]
|
||||
cmek: Optional[Cmek] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Initialize the dataclass fields first
|
||||
self.defaults = ValueTypes()
|
||||
self.keys: Dict[str, ValueTypes] = {}
|
||||
self.cmek: Optional[Cmek] = None
|
||||
|
||||
# Populate with sensible defaults automatically
|
||||
self._initialize_defaults()
|
||||
self._initialize_keys()
|
||||
|
||||
def create_index(
|
||||
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
|
||||
self,
|
||||
config: Optional[IndexConfig] = None,
|
||||
key: Optional[Union[str, "Key"]] = None,
|
||||
) -> "Schema":
|
||||
"""Create an index configuration."""
|
||||
# Convert Key to string if provided
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if key is not None and isinstance(key, KeyType):
|
||||
key = key.name
|
||||
|
||||
@@ -1869,11 +2114,14 @@ class Schema:
|
||||
return self
|
||||
|
||||
def delete_index(
|
||||
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
|
||||
self,
|
||||
config: Optional[IndexConfig] = None,
|
||||
key: Optional[Union[str, "Key"]] = None,
|
||||
) -> "Schema":
|
||||
"""Disable an index configuration (set enabled=False)."""
|
||||
# Convert Key to string if provided
|
||||
from chromadb.execution.expression.operator import Key as KeyType
|
||||
|
||||
if key is not None and isinstance(key, KeyType):
|
||||
key = key.name
|
||||
|
||||
@@ -1930,6 +2178,23 @@ class Schema:
|
||||
|
||||
return self
|
||||
|
||||
def set_cmek(self, cmek: Optional[Cmek]) -> "Schema":
|
||||
"""Set customer-managed encryption key for the collection (fluent interface).
|
||||
|
||||
Args:
|
||||
cmek: Customer-managed encryption key configuration, or None to remove CMEK
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
|
||||
Example:
|
||||
>>> schema = Schema().set_cmek(
|
||||
... Cmek.gcp("projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key")
|
||||
... )
|
||||
"""
|
||||
self.cmek = cmek
|
||||
return self
|
||||
|
||||
def _get_config_class_name(self, config: IndexConfig) -> str:
|
||||
"""Get the class name for a config."""
|
||||
return config.__class__.__name__
|
||||
@@ -1943,11 +2208,15 @@ class Schema:
|
||||
"""
|
||||
# Update the config in defaults (preserve enabled=False)
|
||||
current_enabled = self.defaults.float_list.vector_index.enabled # type: ignore[union-attr]
|
||||
self.defaults.float_list.vector_index = VectorIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
|
||||
self.defaults.float_list.vector_index = VectorIndexType(
|
||||
enabled=current_enabled, config=config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
# Update the config on #embedding key (preserve enabled=True and source_key="#document")
|
||||
current_enabled = self.keys[EMBEDDING_KEY].float_list.vector_index.enabled # type: ignore[union-attr]
|
||||
current_source_key = self.keys[EMBEDDING_KEY].float_list.vector_index.config.source_key # type: ignore[union-attr]
|
||||
current_source_key = self.keys[
|
||||
EMBEDDING_KEY
|
||||
].float_list.vector_index.config.source_key # type: ignore[union-attr]
|
||||
|
||||
# Create a new config with user settings but preserve the original source_key
|
||||
embedding_config = VectorIndexConfig(
|
||||
@@ -1957,7 +2226,9 @@ class Schema:
|
||||
spann=config.spann,
|
||||
source_key=current_source_key, # Preserve original source_key (should be "#document")
|
||||
)
|
||||
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(enabled=current_enabled, config=embedding_config) # type: ignore[union-attr]
|
||||
self.keys[EMBEDDING_KEY].float_list.vector_index = VectorIndexType(
|
||||
enabled=current_enabled, config=embedding_config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
def _set_fts_index_config(self, config: FtsIndexConfig) -> None:
|
||||
"""
|
||||
@@ -1967,11 +2238,15 @@ class Schema:
|
||||
"""
|
||||
# Update the config in defaults (preserve enabled=False)
|
||||
current_enabled = self.defaults.string.fts_index.enabled # type: ignore[union-attr]
|
||||
self.defaults.string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
|
||||
self.defaults.string.fts_index = FtsIndexType(
|
||||
enabled=current_enabled, config=config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
# Update the config on #document key (preserve enabled=True)
|
||||
current_enabled = self.keys[DOCUMENT_KEY].string.fts_index.enabled # type: ignore[union-attr]
|
||||
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(enabled=current_enabled, config=config) # type: ignore[union-attr]
|
||||
self.keys[DOCUMENT_KEY].string.fts_index = FtsIndexType(
|
||||
enabled=current_enabled, config=config
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
def _set_index_in_defaults(self, config: IndexConfig, enabled: bool) -> None:
|
||||
"""Set an index configuration in the defaults."""
|
||||
@@ -2071,37 +2346,51 @@ class Schema:
|
||||
if config_name == "FtsIndexConfig":
|
||||
if self.keys[key].string is None:
|
||||
self.keys[key].string = StringValueType()
|
||||
self.keys[key].string.fts_index = FtsIndexType(enabled=enabled, config=cast(FtsIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].string.fts_index = FtsIndexType(
|
||||
enabled=enabled, config=cast(FtsIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "StringInvertedIndexConfig":
|
||||
if self.keys[key].string is None:
|
||||
self.keys[key].string = StringValueType()
|
||||
self.keys[key].string.string_inverted_index = StringInvertedIndexType(enabled=enabled, config=cast(StringInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].string.string_inverted_index = StringInvertedIndexType(
|
||||
enabled=enabled, config=cast(StringInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "VectorIndexConfig":
|
||||
if self.keys[key].float_list is None:
|
||||
self.keys[key].float_list = FloatListValueType()
|
||||
self.keys[key].float_list.vector_index = VectorIndexType(enabled=enabled, config=cast(VectorIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].float_list.vector_index = VectorIndexType(
|
||||
enabled=enabled, config=cast(VectorIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "SparseVectorIndexConfig":
|
||||
if self.keys[key].sparse_vector is None:
|
||||
self.keys[key].sparse_vector = SparseVectorValueType()
|
||||
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(enabled=enabled, config=cast(SparseVectorIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].sparse_vector.sparse_vector_index = SparseVectorIndexType(
|
||||
enabled=enabled, config=cast(SparseVectorIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "IntInvertedIndexConfig":
|
||||
if self.keys[key].int_value is None:
|
||||
self.keys[key].int_value = IntValueType()
|
||||
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(enabled=enabled, config=cast(IntInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].int_value.int_inverted_index = IntInvertedIndexType(
|
||||
enabled=enabled, config=cast(IntInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "FloatInvertedIndexConfig":
|
||||
if self.keys[key].float_value is None:
|
||||
self.keys[key].float_value = FloatValueType()
|
||||
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(enabled=enabled, config=cast(FloatInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].float_value.float_inverted_index = FloatInvertedIndexType(
|
||||
enabled=enabled, config=cast(FloatInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
elif config_name == "BoolInvertedIndexConfig":
|
||||
if self.keys[key].boolean is None:
|
||||
self.keys[key].boolean = BoolValueType()
|
||||
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(enabled=enabled, config=cast(BoolInvertedIndexConfig, config)) # type: ignore[union-attr]
|
||||
self.keys[key].boolean.bool_inverted_index = BoolInvertedIndexType(
|
||||
enabled=enabled, config=cast(BoolInvertedIndexConfig, config)
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
def _enable_all_indexes_for_key(self, key: str) -> None:
|
||||
"""Enable all possible index types for a specific key."""
|
||||
@@ -2249,7 +2538,13 @@ class Schema:
|
||||
for key, value_types in self.keys.items():
|
||||
keys_json[key] = self._serialize_value_types(value_types)
|
||||
|
||||
return {"defaults": defaults_json, "keys": keys_json}
|
||||
result: Dict[str, Any] = {"defaults": defaults_json, "keys": keys_json}
|
||||
|
||||
# Add CMEK if present
|
||||
if self.cmek is not None:
|
||||
result["cmek"] = self.cmek.to_dict()
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def deserialize_from_json(cls, json_data: Dict[str, Any]) -> "Schema":
|
||||
@@ -2263,6 +2558,11 @@ class Schema:
|
||||
for key, value_types_json in json_data.get("keys", {}).items():
|
||||
instance.keys[key] = cls._deserialize_value_types(value_types_json)
|
||||
|
||||
# Deserialize CMEK if present
|
||||
instance.cmek = None
|
||||
if "cmek" in json_data and json_data["cmek"] is not None:
|
||||
instance.cmek = Cmek.from_dict(json_data["cmek"])
|
||||
|
||||
return instance
|
||||
|
||||
def _serialize_value_types(self, value_types: ValueTypes) -> Dict[str, Any]:
|
||||
@@ -2410,6 +2710,10 @@ class Schema:
|
||||
if embedding_func.is_legacy():
|
||||
config_dict["embedding_function"] = {"type": "legacy"}
|
||||
else:
|
||||
if hasattr(embedding_func, "validate_config"):
|
||||
embedding_func.validate_config(
|
||||
embedding_func.get_config()
|
||||
)
|
||||
config_dict["embedding_function"] = {
|
||||
"name": embedding_func.name(),
|
||||
"type": "known",
|
||||
@@ -2439,6 +2743,8 @@ class Schema:
|
||||
config_dict["embedding_function"] = {"type": "unknown"}
|
||||
else:
|
||||
embedding_func = cast(SparseEmbeddingFunction, embedding_func) # type: ignore
|
||||
if hasattr(embedding_func, "validate_config"):
|
||||
embedding_func.validate_config(embedding_func.get_config())
|
||||
config_dict["embedding_function"] = {
|
||||
"name": embedding_func.name(),
|
||||
"type": "known",
|
||||
|
||||
@@ -16,15 +16,18 @@ class SparseVector:
|
||||
Attributes:
|
||||
indices: List of dimension indices (must be non-negative integers, sorted in strictly ascending order)
|
||||
values: List of values corresponding to each index (floats)
|
||||
labels: Optional list of string labels corresponding to each index
|
||||
|
||||
Note:
|
||||
- Indices must be sorted in strictly ascending order (no duplicates)
|
||||
- Indices and values must have the same length
|
||||
- If labels is provided, it must have the same length as indices and values
|
||||
- All validations are performed in __post_init__
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
values: List[float]
|
||||
labels: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate the sparse vector structure."""
|
||||
@@ -44,6 +47,17 @@ class SparseVector:
|
||||
f"got {len(self.indices)} indices and {len(self.values)} values"
|
||||
)
|
||||
|
||||
if self.labels is not None:
|
||||
if not isinstance(self.labels, list):
|
||||
raise ValueError(
|
||||
f"Expected SparseVector labels to be a list, got {type(self.labels).__name__}"
|
||||
)
|
||||
if len(self.labels) != len(self.indices):
|
||||
raise ValueError(
|
||||
f"SparseVector labels must have the same length as indices and values, "
|
||||
f"got {len(self.labels)} labels, {len(self.indices)} indices"
|
||||
)
|
||||
|
||||
for i, idx in enumerate(self.indices):
|
||||
if not isinstance(idx, int):
|
||||
raise ValueError(
|
||||
@@ -70,21 +84,36 @@ class SparseVector:
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize to transport format with type tag."""
|
||||
return {
|
||||
"""Serialize to transport format with type tag.
|
||||
|
||||
Note: Uses 'tokens' as the wire format key name for compatibility
|
||||
with the protobuf schema, even though the Python attribute is 'labels'.
|
||||
"""
|
||||
result = {
|
||||
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
|
||||
"indices": self.indices,
|
||||
"values": self.values,
|
||||
}
|
||||
if self.labels is not None:
|
||||
result["tokens"] = self.labels # Wire format uses 'tokens'
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> "SparseVector":
|
||||
"""Deserialize from transport format (strict - requires #type field)."""
|
||||
"""Deserialize from transport format (strict - requires #type field).
|
||||
|
||||
Note: Reads from 'tokens' key in the wire format for compatibility
|
||||
with the protobuf schema, mapping it to the 'labels' attribute.
|
||||
"""
|
||||
if d.get(TYPE_KEY) != SPARSE_VECTOR_TYPE_VALUE:
|
||||
raise ValueError(
|
||||
f"Expected {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {d.get(TYPE_KEY)}"
|
||||
)
|
||||
return cls(indices=d["indices"], values=d["values"])
|
||||
return cls(
|
||||
indices=d["indices"],
|
||||
values=d["values"],
|
||||
labels=d.get("tokens") # Wire format uses 'tokens'
|
||||
)
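For orientation, a minimal round-trip sketch of the labels/tokens mapping documented above (SparseVector is re-exported from the top-level chromadb package, as the tests later in this diff show; the exact #type tag string is whatever the TYPE_KEY constant defines):

from chromadb import SparseVector

sv = SparseVector(indices=[0, 3, 7], values=[0.5, 1.2, 0.1], labels=["a", "b", "c"])
wire = sv.to_dict()
# The Python attribute 'labels' travels as 'tokens' on the wire, alongside the type tag.
assert "tokens" in wire and "labels" not in wire

restored = SparseVector.from_dict(wire)
assert restored.labels == ["a", "b", "c"]  # 'tokens' maps back onto 'labels'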
|
||||
|
||||
|
||||
Metadata = Mapping[str, Optional[Union[str, int, float, bool, SparseVector]]]
|
||||
|
||||
@@ -154,6 +154,10 @@ class Settings(BaseSettings): # type: ignore
|
||||
# eg ["http://localhost:8000"]
|
||||
chroma_server_cors_allow_origins: List[str] = []
|
||||
|
||||
chroma_http_keepalive_secs: Optional[float] = 40.0
|
||||
chroma_http_max_connections: Optional[int] = None
|
||||
chroma_http_max_keepalive_connections: Optional[int] = None
|
||||
|
||||
# ==================
|
||||
# Server config
|
||||
# ==================
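A short sketch of wiring the new HTTP client knobs through Settings (field names taken from the hunk above; how the underlying HTTP connection pool consumes them is assumed, not shown in this diff):

from chromadb.config import Settings

settings = Settings(
    chroma_http_keepalive_secs=15.0,           # drop idle connections after 15 seconds
    chroma_http_max_connections=100,           # cap the total pooled connections
    chroma_http_max_keepalive_connections=20,  # cap the idle keep-alive connections
)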
|
||||
|
||||
Binary file not shown.
@@ -39,6 +39,11 @@ from chromadb.execution.expression.operator import (
|
||||
Sub,
|
||||
Sum,
|
||||
Val,
|
||||
# GroupBy and Aggregate expressions
|
||||
Aggregate,
|
||||
MinK,
|
||||
MaxK,
|
||||
GroupBy,
|
||||
)
|
||||
|
||||
from chromadb.execution.expression.plan import (
|
||||
@@ -87,4 +92,9 @@ __all__ = [
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Val",
|
||||
# GroupBy and Aggregate expressions
|
||||
"Aggregate",
|
||||
"MinK",
|
||||
"MaxK",
|
||||
"GroupBy",
|
||||
]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Set, Any, Union
|
||||
from typing import Optional, List, Dict, Set, Any, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -7,9 +7,11 @@ from chromadb.api.types import (
|
||||
Embeddings,
|
||||
IDs,
|
||||
Include,
|
||||
OneOrMany,
|
||||
SparseVector,
|
||||
TYPE_KEY,
|
||||
SPARSE_VECTOR_TYPE_VALUE,
|
||||
maybe_cast_one_to_many,
|
||||
normalize_embeddings,
|
||||
validate_embeddings,
|
||||
)
|
||||
@@ -1024,7 +1026,7 @@ class Knn(Rank):
|
||||
- A dense vector (list or numpy array)
|
||||
- A sparse vector (SparseVector dict)
|
||||
key: The embedding key to search against. Can be:
|
||||
- "#embedding" (default) - searches the main embedding field
|
||||
- Key.EMBEDDING (default) - searches the main embedding field
|
||||
- A metadata field name (e.g., "my_custom_field") - searches that metadata field
|
||||
limit: Maximum number of results to consider (default: 16)
|
||||
default: Default score for records not in KNN results (default: None)
|
||||
@@ -1054,7 +1056,7 @@ class Knn(Rank):
|
||||
"NDArray[np.float64]",
|
||||
"NDArray[np.int32]",
|
||||
]
|
||||
key: str = "#embedding"
|
||||
key: Union[Key, str] = K.EMBEDDING
|
||||
limit: int = 16
|
||||
default: Optional[float] = None
|
||||
return_rank: bool = False
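A small sketch of the widened key field (the Knn import path appears in the tests later in this diff; the Key import path is taken from plan.py further below):

from chromadb.execution.expression import Knn
from chromadb.execution.expression.operator import Key

knn_default = Knn(query=[0.1, 0.2, 0.3])                   # key defaults to Key.EMBEDDING
knn_metadata = Knn(query=[0.1, 0.2, 0.3], key="my_field")  # plain strings still work
# to_dict() (next hunk) converts a Key instance to key.name, so the serialized "key" is always a string.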
|
||||
@@ -1069,8 +1071,12 @@ class Knn(Rank):
|
||||
# Convert numpy array to list
|
||||
query_value = query_value.tolist()
|
||||
|
||||
key_value = self.key
|
||||
if isinstance(key_value, Key):
|
||||
key_value = key_value.name
|
||||
|
||||
# Build result dict - only include non-default values to keep JSON clean
|
||||
result = {"query": query_value, "key": self.key, "limit": self.limit}
|
||||
result = {"query": query_value, "key": key_value, "limit": self.limit}
|
||||
|
||||
# Only include optional fields if they're set to non-default values
|
||||
if self.default is not None:
|
||||
@@ -1291,3 +1297,216 @@ class Select:
|
||||
|
||||
# Convert to set while preserving the Key instances
|
||||
return Select(keys=set(key_list))
|
||||
|
||||
|
||||
# GroupBy and Aggregate types for grouping search results
|
||||
|
||||
|
||||
def _keys_to_strings(keys: OneOrMany[Union[Key, str]]) -> List[str]:
|
||||
"""Convert OneOrMany[Key|str] to List[str] for serialization."""
|
||||
keys_list = cast(List[Union[Key, str]], maybe_cast_one_to_many(keys))
|
||||
return [k.name if isinstance(k, Key) else k for k in keys_list]
|
||||
|
||||
|
||||
def _strings_to_keys(keys: Union[List[Any], tuple[Any, ...]]) -> List[Union[Key, str]]:
|
||||
"""Convert List[str] to List[Key] for deserialization."""
|
||||
return [Key(k) if isinstance(k, str) else k for k in keys]
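A quick sanity sketch of the two helpers above, using the names defined in this module (it assumes Key("category").name round-trips to "category"):

assert _keys_to_strings(Key("category")) == ["category"]
assert _keys_to_strings([Key("a"), "b"]) == ["a", "b"]  # mixed Key/str input flattens to strings
assert all(isinstance(k, Key) for k in _strings_to_keys(["a", "b"]))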
|
||||
|
||||
|
||||
def _parse_k_aggregate(
|
||||
op: str, data: Dict[str, Any]
|
||||
) -> tuple[List[Union[Key, str]], int]:
|
||||
"""Parse common fields for MinK/MaxK from dict.
|
||||
|
||||
Args:
|
||||
op: The operator name (e.g., "$min_k" or "$max_k")
|
||||
data: The dict containing the operator
|
||||
|
||||
Returns:
|
||||
Tuple of (keys, k) where keys is List[Union[Key, str]] and k is int
|
||||
|
||||
Raises:
|
||||
TypeError: If data types are invalid
|
||||
ValueError: If required fields are missing or invalid
|
||||
"""
|
||||
agg_data = data[op]
|
||||
if not isinstance(agg_data, dict):
|
||||
raise TypeError(f"{op} requires a dict, got {type(agg_data).__name__}")
|
||||
if "keys" not in agg_data:
|
||||
raise ValueError(f"{op} requires 'keys' field")
|
||||
if "k" not in agg_data:
|
||||
raise ValueError(f"{op} requires 'k' field")
|
||||
|
||||
keys = agg_data["keys"]
|
||||
if not isinstance(keys, (list, tuple)):
|
||||
raise TypeError(f"{op} keys must be a list, got {type(keys).__name__}")
|
||||
if not keys:
|
||||
raise ValueError(f"{op} keys cannot be empty")
|
||||
|
||||
k = agg_data["k"]
|
||||
if not isinstance(k, int):
|
||||
raise TypeError(f"{op} k must be an integer, got {type(k).__name__}")
|
||||
if k <= 0:
|
||||
raise ValueError(f"{op} k must be positive, got {k}")
|
||||
|
||||
return _strings_to_keys(keys), k
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aggregate:
|
||||
"""Base class for aggregation expressions within groups.
|
||||
|
||||
Aggregations determine which records to keep from each group:
|
||||
- MinK: Keep k records with minimum values (ascending order)
|
||||
- MaxK: Keep k records with maximum values (descending order)
|
||||
|
||||
Examples:
|
||||
# Keep top 3 by score per group (single key)
|
||||
MinK(keys=Key.SCORE, k=3)
|
||||
|
||||
# Keep top 5 by priority, then score as tiebreaker (multiple keys)
|
||||
MinK(keys=[Key("priority"), Key.SCORE], k=5)
|
||||
|
||||
# Keep bottom 2 by score per group
|
||||
MaxK(keys=Key.SCORE, k=2)
|
||||
"""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the Aggregate expression to a dictionary for JSON serialization"""
|
||||
raise NotImplementedError("Subclasses must implement to_dict()")
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "Aggregate":
|
||||
"""Create Aggregate expression from dictionary.
|
||||
|
||||
Supports:
|
||||
- {"$min_k": {"keys": [...], "k": n}} -> MinK(keys=[...], k=n)
|
||||
- {"$max_k": {"keys": [...], "k": n}} -> MaxK(keys=[...], k=n)
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected dict for Aggregate, got {type(data).__name__}")
|
||||
|
||||
if not data:
|
||||
raise ValueError("Aggregate dict cannot be empty")
|
||||
|
||||
if len(data) != 1:
|
||||
raise ValueError(
|
||||
f"Aggregate dict must contain exactly one operator, got {len(data)}"
|
||||
)
|
||||
|
||||
op = next(iter(data.keys()))
|
||||
|
||||
if op == "$min_k":
|
||||
keys, k = _parse_k_aggregate(op, data)
|
||||
return MinK(keys=keys, k=k)
|
||||
elif op == "$max_k":
|
||||
keys, k = _parse_k_aggregate(op, data)
|
||||
return MaxK(keys=keys, k=k)
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregate operator: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinK(Aggregate):
|
||||
"""Keep k records with minimum aggregate key values per group"""
|
||||
|
||||
keys: OneOrMany[Union[Key, str]]
|
||||
k: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {"$min_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaxK(Aggregate):
|
||||
"""Keep k records with maximum aggregate key values per group"""
|
||||
|
||||
keys: OneOrMany[Union[Key, str]]
|
||||
k: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {"$max_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupBy:
|
||||
"""Group results by metadata keys and aggregate within each group.
|
||||
|
||||
Groups search results by one or more metadata fields, then applies an
|
||||
aggregation (MinK or MaxK) to select records within each group.
|
||||
The final output is flattened and sorted by score.
|
||||
|
||||
Args:
|
||||
keys: Metadata key(s) to group by. Can be a single key or a list of keys.
|
||||
E.g., Key("category") or [Key("category"), Key("author")]
|
||||
aggregate: Aggregation to apply within each group (MinK or MaxK)
|
||||
|
||||
Note: Both keys and aggregate must be specified together.
|
||||
|
||||
Examples:
|
||||
# Top 3 documents per category (single key)
|
||||
GroupBy(
|
||||
keys=Key("category"),
|
||||
aggregate=MinK(keys=Key.SCORE, k=3)
|
||||
)
|
||||
|
||||
# Top 2 per (year, category) combination (multiple keys)
|
||||
GroupBy(
|
||||
keys=[Key("year"), Key("category")],
|
||||
aggregate=MinK(keys=Key.SCORE, k=2)
|
||||
)
|
||||
|
||||
# Top 1 per category by priority, score as tiebreaker
|
||||
GroupBy(
|
||||
keys=Key("category"),
|
||||
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=1)
|
||||
)
|
||||
"""
|
||||
|
||||
keys: OneOrMany[Union[Key, str]] = field(default_factory=list)
|
||||
aggregate: Optional[Aggregate] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the GroupBy to a dictionary for JSON serialization"""
|
||||
# Default GroupBy (no keys, no aggregate) serializes to {}
|
||||
if not self.keys or self.aggregate is None:
|
||||
return {}
|
||||
result: Dict[str, Any] = {"keys": _keys_to_strings(self.keys)}
|
||||
result["aggregate"] = self.aggregate.to_dict()
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "GroupBy":
|
||||
"""Create GroupBy from dictionary.
|
||||
|
||||
Examples:
|
||||
- {} -> GroupBy() (default, no grouping)
|
||||
- {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected dict for GroupBy, got {type(data).__name__}")
|
||||
|
||||
# Empty dict returns default GroupBy (no grouping)
|
||||
if not data:
|
||||
return GroupBy()
|
||||
|
||||
# Non-empty dict requires keys and aggregate
|
||||
if "keys" not in data:
|
||||
raise ValueError("GroupBy requires 'keys' field")
|
||||
if "aggregate" not in data:
|
||||
raise ValueError("GroupBy requires 'aggregate' field")
|
||||
|
||||
keys = data["keys"]
|
||||
if not isinstance(keys, (list, tuple)):
|
||||
raise TypeError(f"GroupBy keys must be a list, got {type(keys).__name__}")
|
||||
if not keys:
|
||||
raise ValueError("GroupBy keys cannot be empty")
|
||||
|
||||
aggregate_data = data["aggregate"]
|
||||
if not isinstance(aggregate_data, dict):
|
||||
raise TypeError(
|
||||
f"GroupBy aggregate must be a dict, got {type(aggregate_data).__name__}"
|
||||
)
|
||||
aggregate = Aggregate.from_dict(aggregate_data)
|
||||
|
||||
return GroupBy(keys=_strings_to_keys(keys), aggregate=aggregate)
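A serialization sketch tying GroupBy and MinK together (same Key.name assumption as in the aggregate sketch above):

gb = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
wire = gb.to_dict()  # {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
assert GroupBy.from_dict(wire).to_dict() == wire

# The default GroupBy collapses to {} on the wire and deserializes back to "no grouping".
assert GroupBy().to_dict() == {}
assert GroupBy.from_dict({}).aggregate is None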
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import List, Dict, Any, Union, Set, Optional
|
||||
from chromadb.execution.expression.operator import (
|
||||
KNN,
|
||||
Filter,
|
||||
GroupBy,
|
||||
Limit,
|
||||
Projection,
|
||||
Scan,
|
||||
Rank,
|
||||
Select,
|
||||
Val,
|
||||
Where,
|
||||
Key,
|
||||
)
|
||||
@@ -77,9 +77,18 @@ class Search:
|
||||
Combined with metadata filtering:
|
||||
Search().where((Key.ID.is_in(["id1", "id2"])) & (Key("status") == "active"))
|
||||
|
||||
With group_by:
|
||||
(Search()
|
||||
.rank(Knn(query=[0.1, 0.2]))
|
||||
.group_by(GroupBy(
|
||||
keys=[Key("category")],
|
||||
aggregate=MinK(keys=[Key.SCORE], k=3)
|
||||
)))
|
||||
|
||||
Empty Search() is valid and will use defaults:
|
||||
- where: None (no filtering)
|
||||
- rank: None (no ranking - results ordered by default order)
|
||||
- group_by: None (no grouping)
|
||||
- limit: No limit
|
||||
- select: Empty selection
|
||||
"""
|
||||
@@ -88,6 +97,7 @@ class Search:
|
||||
self,
|
||||
where: Optional[Union[Where, Dict[str, Any]]] = None,
|
||||
rank: Optional[Union[Rank, Dict[str, Any]]] = None,
|
||||
group_by: Optional[Union[GroupBy, Dict[str, Any]]] = None,
|
||||
limit: Optional[Union[Limit, Dict[str, Any], int]] = None,
|
||||
select: Optional[Union[Select, Dict[str, Any], List[str], Set[str]]] = None,
|
||||
):
|
||||
@@ -99,11 +109,13 @@ class Search:
|
||||
rank: Rank expression or dict for scoring (defaults to None - no ranking)
|
||||
Dict will be converted using Rank.from_dict()
|
||||
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
|
||||
group_by: GroupBy configuration for grouping and aggregating results (defaults to None)
|
||||
Dict will be converted using GroupBy.from_dict()
|
||||
limit: Limit configuration for pagination (defaults to no limit)
|
||||
Can be a Limit object, a dict for Limit.from_dict(), or an int
|
||||
When passing an int, it creates Limit(limit=value, offset=0)
|
||||
select: Select configuration for keys (defaults to empty selection)
|
||||
Can be a Select object, a dict for Select.from_dict(),

|
||||
or a list/set of strings (e.g., ["#document", "#score"])
|
||||
"""
|
||||
# Handle where parameter
|
||||
@@ -117,7 +129,7 @@ class Search:
|
||||
raise TypeError(
|
||||
f"where must be a Where object, dict, or None, got {type(where).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# Handle rank parameter
|
||||
if rank is None:
|
||||
self._rank = None
|
||||
@@ -129,7 +141,19 @@ class Search:
|
||||
raise TypeError(
|
||||
f"rank must be a Rank object, dict, or None, got {type(rank).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# Handle group_by parameter
|
||||
if group_by is None:
|
||||
self._group_by = GroupBy()
|
||||
elif isinstance(group_by, GroupBy):
|
||||
self._group_by = group_by
|
||||
elif isinstance(group_by, dict):
|
||||
self._group_by = GroupBy.from_dict(group_by)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"group_by must be a GroupBy object, dict, or None, got {type(group_by).__name__}"
|
||||
)
|
||||
|
||||
# Handle limit parameter
|
||||
if limit is None:
|
||||
self._limit = Limit()
|
||||
@@ -143,7 +167,7 @@ class Search:
|
||||
raise TypeError(
|
||||
f"limit must be a Limit object, dict, int, or None, got {type(limit).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# Handle select parameter
|
||||
if select is None:
|
||||
self._select = Select()
|
||||
@@ -164,6 +188,7 @@ class Search:
|
||||
return {
|
||||
"filter": self._where.to_dict() if self._where is not None else None,
|
||||
"rank": self._rank.to_dict() if self._rank is not None else None,
|
||||
"group_by": self._group_by.to_dict(),
|
||||
"limit": self._limit.to_dict(),
|
||||
"select": self._select.to_dict(),
|
||||
}
|
||||
@@ -173,7 +198,11 @@ class Search:
|
||||
"""Select all predefined keys (document, embedding, metadata, score)"""
|
||||
new_select = Select(keys={Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE})
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=self._limit, select=new_select
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=new_select,
|
||||
)
|
||||
|
||||
def select(self, *keys: Union[Key, str]) -> "Search":
|
||||
@@ -187,7 +216,11 @@ class Search:
|
||||
"""
|
||||
new_select = Select(keys=set(keys))
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=self._limit, select=new_select
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=new_select,
|
||||
)
|
||||
|
||||
def where(self, where: Optional[Union[Where, Dict[str, Any]]]) -> "Search":
|
||||
@@ -202,20 +235,12 @@ class Search:
|
||||
search.where({"status": "active"})
|
||||
search.where({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
|
||||
"""
|
||||
# Convert dict to Where if needed
|
||||
if where is None:
|
||||
converted_where = None
|
||||
elif isinstance(where, Where):
|
||||
converted_where = where
|
||||
elif isinstance(where, dict):
|
||||
converted_where = Where.from_dict(where)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"where must be a Where object, dict, or None, got {type(where).__name__}"
|
||||
)
|
||||
|
||||
return Search(
|
||||
where=converted_where, rank=self._rank, limit=self._limit, select=self._select
|
||||
where=where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
def rank(self, rank_expr: Optional[Union[Rank, Dict[str, Any]]]) -> "Search":
|
||||
@@ -231,20 +256,37 @@ class Search:
|
||||
search.rank({"$knn": {"query": [0.1, 0.2]}})
|
||||
search.rank({"$sum": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.5}]})
|
||||
"""
|
||||
# Convert dict to Rank if needed
|
||||
if rank_expr is None:
|
||||
converted_rank = None
|
||||
elif isinstance(rank_expr, Rank):
|
||||
converted_rank = rank_expr
|
||||
elif isinstance(rank_expr, dict):
|
||||
converted_rank = Rank.from_dict(rank_expr)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"rank_expr must be a Rank object, dict, or None, got {type(rank_expr).__name__}"
|
||||
)
|
||||
|
||||
return Search(
|
||||
where=self._where, rank=converted_rank, limit=self._limit, select=self._select
|
||||
where=self._where,
|
||||
rank=rank_expr,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
def group_by(self, group_by: Optional[Union[GroupBy, Dict[str, Any]]]) -> "Search":
|
||||
"""Set the group_by configuration for grouping and aggregating results
|
||||
|
||||
Args:
|
||||
group_by: A GroupBy object, dict, or None for grouping
|
||||
Dicts will be converted using GroupBy.from_dict()
|
||||
|
||||
Example:
|
||||
search.group_by(GroupBy(
|
||||
keys=[Key("category")],
|
||||
aggregate=MinK(keys=[Key.SCORE], k=3)
|
||||
))
|
||||
search.group_by({
|
||||
"keys": ["category"],
|
||||
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}
|
||||
})
|
||||
"""
|
||||
return Search(
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=group_by,
|
||||
limit=self._limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
def limit(self, limit: int, offset: int = 0) -> "Search":
|
||||
@@ -259,5 +301,9 @@ class Search:
|
||||
"""
|
||||
new_limit = Limit(offset=offset, limit=limit)
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=new_limit, select=self._select
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=new_limit,
|
||||
select=self._select,
|
||||
)
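Putting the builder methods together, a usage sketch of a grouped search (import paths as used elsewhere in this diff; Key imported from the operator module):

from chromadb.execution.expression import GroupBy, Knn, MinK, Search
from chromadb.execution.expression.operator import Key

search = (
    Search()
    .where({"status": "active"})
    .rank(Knn(query=[0.1, 0.2, 0.3]))
    .group_by(GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3)))
    .limit(10)
    .select(Key.DOCUMENT, Key.SCORE)
)
payload = search.to_dict()  # carries the "filter", "rank", "group_by", "limit" and "select" keys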
|
||||
|
||||
@@ -425,6 +425,20 @@ class FastAPI(Server):
|
||||
response_model=None,
|
||||
)
|
||||
|
||||
self.router.add_api_route(
|
||||
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/functions/attach",
|
||||
self.attach_function,
|
||||
methods=["POST"],
|
||||
response_model=None,
|
||||
)
|
||||
|
||||
self.router.add_api_route(
|
||||
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/functions/{function_name}",
|
||||
self.get_attached_function,
|
||||
methods=["GET"],
|
||||
response_model=None,
|
||||
)
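A rough sketch of exercising the two routes registered above over HTTP (host, tenant/database names and the collection UUID are placeholders; the request body mirrors the fields read by process_attach_function below):

import httpx

base = "http://localhost:8000/api/v2/tenants/default_tenant/databases/default_database"
collection_id = "00000000-0000-0000-0000-000000000000"  # placeholder collection UUID

attach_resp = httpx.post(
    f"{base}/collections/{collection_id}/functions/attach",
    json={
        "name": "count_my_docs",
        "function_id": "record_counter",
        "output_collection": "my_documents_counts",
        "params": None,
    },
)
attached = attach_resp.json()["attached_function"]

get_resp = httpx.get(f"{base}/collections/{collection_id}/functions/{attached['name']}")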
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._system.stop()
|
||||
|
||||
@@ -938,6 +952,103 @@ class FastAPI(Server):
|
||||
limiter=self._capacity_limiter,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.attach_function", OpenTelemetryGranularity.OPERATION)
|
||||
@rate_limit
|
||||
async def attach_function(
|
||||
self,
|
||||
request: Request,
|
||||
tenant: str,
|
||||
database_name: str,
|
||||
collection_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
|
||||
def process_attach_function(request: Request, raw_body: bytes) -> Dict[str, Any]:
|
||||
body = orjson.loads(raw_body)
|
||||
# NOTE: Auth check for attaching functions
|
||||
self.sync_auth_request(
|
||||
request.headers,
|
||||
AuthzAction.UPDATE_COLLECTION, # Using UPDATE_COLLECTION as the auth action
|
||||
tenant,
|
||||
database_name,
|
||||
collection_id,
|
||||
)
|
||||
self._set_request_context(request=request)
|
||||
|
||||
name = body.get("name")
|
||||
function_id = body.get("function_id")
|
||||
output_collection = body.get("output_collection")
|
||||
params = body.get("params")
|
||||
|
||||
attached_fn = self._api.attach_function(
|
||||
function_id=function_id,
|
||||
name=name,
|
||||
input_collection_id=_uuid(collection_id),
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=tenant,
|
||||
database=database_name,
|
||||
)
|
||||
|
||||
return {
|
||||
"attached_function": {
|
||||
"id": str(attached_fn.id),
|
||||
"name": attached_fn.name,
|
||||
"function_name": attached_fn.function_name,
|
||||
"output_collection": attached_fn.output_collection,
|
||||
"params": attached_fn.params,
|
||||
}
|
||||
}
|
||||
|
||||
raw_body = await request.body()
|
||||
return await to_thread.run_sync(
|
||||
process_attach_function,
|
||||
request,
|
||||
raw_body,
|
||||
limiter=self._capacity_limiter,
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.OPERATION)
|
||||
@rate_limit
|
||||
async def get_attached_function(
|
||||
self,
|
||||
request: Request,
|
||||
tenant: str,
|
||||
database_name: str,
|
||||
collection_id: str,
|
||||
function_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
# NOTE: Auth check for getting attached functions
|
||||
await self.auth_request(
|
||||
request.headers,
|
||||
AuthzAction.GET_COLLECTION, # Using GET_COLLECTION as the auth action
|
||||
tenant,
|
||||
database_name,
|
||||
collection_id,
|
||||
)
|
||||
add_attributes_to_current_span({"tenant": tenant})
|
||||
|
||||
attached_fn = await to_thread.run_sync(
|
||||
self._api.get_attached_function,
|
||||
function_name,
|
||||
_uuid(collection_id),
|
||||
tenant,
|
||||
database_name,
|
||||
limiter=self._capacity_limiter,
|
||||
)
|
||||
|
||||
return {
|
||||
"attached_function": {
|
||||
"id": str(attached_fn.id),
|
||||
"name": attached_fn.name,
|
||||
"function_name": attached_fn.function_name,
|
||||
"output_collection": attached_fn.output_collection,
|
||||
"params": attached_fn.params,
|
||||
}
|
||||
}
|
||||
|
||||
@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
|
||||
@rate_limit
|
||||
async def add(
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
@@ -31,7 +31,7 @@ from chromadb.utils.embedding_functions import (
|
||||
)
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.errors import InvalidArgumentError, InternalError
|
||||
from chromadb.errors import InvalidArgumentError
|
||||
from chromadb.execution.expression import Knn, Search
|
||||
from chromadb.types import Collection as CollectionModel
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
|
||||
@@ -509,7 +509,7 @@ def test_sparse_vector_not_allowed_locally(
|
||||
schema = Schema()
|
||||
schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig())
|
||||
with pytest.raises(
|
||||
InternalError, match="Sparse vector indexing is not enabled in local"
|
||||
InvalidArgumentError, match="Sparse vector indexing is not enabled in local"
|
||||
):
|
||||
_create_isolated_collection(client_factories, schema=schema)
|
||||
|
||||
@@ -1011,12 +1011,26 @@ def test_collection_fork_inherits_and_isolates_schema(
|
||||
client_factories,
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
parent_version_before_add = get_collection_version(client, collection.name)
|
||||
|
||||
parent_ids = [f"parent-{i}" for i in range(251)]
|
||||
parent_docs = [f"parent doc {i}" for i in range(251)]
|
||||
parent_metadatas: List[Mapping[str, Any]] = [
|
||||
{"shared_key": f"parent_{i}"} for i in range(251)
|
||||
]
|
||||
|
||||
collection.add(
|
||||
ids=["parent-1"],
|
||||
documents=["parent doc"],
|
||||
metadatas=[{"shared_key": "parent"}],
|
||||
ids=parent_ids,
|
||||
documents=parent_docs,
|
||||
metadatas=parent_metadatas,
|
||||
)
|
||||
|
||||
# Wait for parent to compact before forking. Otherwise, the fork inherits
|
||||
# uncompacted logs, and compaction of those inherited logs could increment
|
||||
# the fork's version before the fork's own data is compacted.
|
||||
wait_for_version_increase(client, collection.name, parent_version_before_add)
|
||||
|
||||
assert collection.schema is not None
|
||||
parent_schema_json = collection.schema.serialize_to_json()
|
||||
|
||||
@@ -1050,8 +1064,8 @@ def test_collection_fork_inherits_and_isolates_schema(
|
||||
assert reloaded_parent.schema is not None
|
||||
assert "child_only" not in reloaded_parent.schema.keys
|
||||
|
||||
parent_results = reloaded_parent.get(where={"shared_key": "parent"})
|
||||
assert set(parent_results["ids"]) == {"parent-1"}
|
||||
parent_results = reloaded_parent.get(where={"shared_key": "parent_10"})
|
||||
assert set(parent_results["ids"]) == {"parent-10"}
|
||||
|
||||
child_results = forked.get(where={"child_only": "value_10"})
|
||||
assert set(child_results["ids"]) == {"fork-10"}
|
||||
|
||||
@@ -7,11 +7,21 @@ for automatically processing collections.
|
||||
|
||||
import pytest
|
||||
from chromadb.api.client import Client as ClientCreator
|
||||
from chromadb.api.functions import (
|
||||
RECORD_COUNTER_FUNCTION,
|
||||
STATISTICS_FUNCTION,
|
||||
Function,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.errors import ChromaError, NotFoundError
|
||||
from chromadb.test.utils.wait_for_version_increase import (
|
||||
get_collection_version,
|
||||
wait_for_version_increase,
|
||||
)
|
||||
from time import sleep
|
||||
|
||||
|
||||
def test_function_attach_and_detach(basic_http_client: System) -> None:
|
||||
def test_count_function_attach_and_detach(basic_http_client: System) -> None:
|
||||
"""Test creating and removing a function with the record_counter operator"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
@@ -22,45 +32,39 @@ def test_function_attach_and_detach(basic_http_client: System) -> None:
|
||||
metadata={"description": "Sample documents for task processing"},
|
||||
)
|
||||
|
||||
# Add initial documents
|
||||
collection.add(
|
||||
ids=["doc1", "doc2", "doc3"],
|
||||
documents=[
|
||||
"The quick brown fox jumps over the lazy dog",
|
||||
"Machine learning is a subset of artificial intelligence",
|
||||
"Python is a popular programming language",
|
||||
],
|
||||
metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}],
|
||||
)
|
||||
|
||||
# Verify collection has documents
|
||||
assert collection.count() == 3
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn = collection.attach_function(
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function_id="record_counter", # Built-in operator that counts records
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Verify task creation succeeded
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
initial_version = get_collection_version(client, collection.name)
|
||||
|
||||
# Add more documents
|
||||
# Add documents
|
||||
collection.add(
|
||||
ids=["doc4", "doc5"],
|
||||
documents=[
|
||||
"Chroma is a vector database",
|
||||
"Tasks automate data processing",
|
||||
],
|
||||
ids=["doc_{}".format(i) for i in range(0, 300)],
|
||||
documents=["test document"] * 300,
|
||||
)
|
||||
|
||||
# Verify documents were added
|
||||
assert collection.count() == 5
|
||||
assert collection.count() == 300
|
||||
|
||||
wait_for_version_increase(client, collection.name, initial_version)
|
||||
# Give some time to invalidate the frontend query cache
|
||||
sleep(60)
|
||||
|
||||
result = client.get_collection("my_documents_counts").get("function_output")
|
||||
assert result["metadatas"] is not None
|
||||
assert result["metadatas"][0]["total_count"] == 300
|
||||
|
||||
# Remove the task
|
||||
success = attached_fn.detach(
|
||||
success = collection.detach_function(
|
||||
attached_fn.name,
|
||||
delete_output_collection=True,
|
||||
)
|
||||
|
||||
@@ -79,13 +83,42 @@ def test_task_with_invalid_function(basic_http_client: System) -> None:
|
||||
# Attempt to create task with non-existent function should raise ChromaError
|
||||
with pytest.raises(ChromaError, match="function not found"):
|
||||
collection.attach_function(
|
||||
function=Function._NONEXISTENT_TEST_ONLY,
|
||||
name="invalid_task",
|
||||
function_id="nonexistent_function",
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_attach_function_returns_function_name(basic_http_client: System) -> None:
|
||||
"""Test that attach_function and get_attached_function return function_name field instead of UUID"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
collection = client.create_collection(name="test_function_name")
|
||||
collection.add(ids=["id1"], documents=["doc1"])
|
||||
|
||||
# Attach a function and verify function_name field in response
|
||||
attached_fn, created = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="my_counter",
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Verify the attached function has function_name (not function_id UUID)
|
||||
assert created is True
|
||||
assert attached_fn.function_name == "record_counter"
|
||||
assert attached_fn.name == "my_counter"
|
||||
|
||||
# Get the attached function and verify function_name field is also present
|
||||
retrieved_fn = collection.get_attached_function("my_counter")
|
||||
assert retrieved_fn == attached_fn
|
||||
|
||||
# Clean up
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
|
||||
def test_function_multiple_collections(basic_http_client: System) -> None:
|
||||
"""Test attaching functions on multiple collections"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
@@ -95,93 +128,163 @@ def test_function_multiple_collections(basic_http_client: System) -> None:
|
||||
collection1 = client.create_collection(name="collection_1")
|
||||
collection1.add(ids=["id1", "id2"], documents=["doc1", "doc2"])
|
||||
|
||||
attached_fn1 = collection1.attach_function(
|
||||
attached_fn1, created1 = collection1.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_1",
|
||||
function_id="record_counter",
|
||||
output_collection="output_1",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn1 is not None
|
||||
assert created1 is True
|
||||
|
||||
# Create second collection and task
|
||||
collection2 = client.create_collection(name="collection_2")
|
||||
collection2.add(ids=["id3", "id4"], documents=["doc3", "doc4"])
|
||||
|
||||
attached_fn2 = collection2.attach_function(
|
||||
attached_fn2, created2 = collection2.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_2",
|
||||
function_id="record_counter",
|
||||
output_collection="output_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn2 is not None
|
||||
assert created2 is True
|
||||
|
||||
# Task IDs should be different
|
||||
assert attached_fn1.id != attached_fn2.id
|
||||
|
||||
# Clean up
|
||||
assert attached_fn1.detach(delete_output_collection=True) is True
|
||||
assert attached_fn2.detach(delete_output_collection=True) is True
|
||||
assert (
|
||||
collection1.detach_function(attached_fn1.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
collection2.detach_function(attached_fn2.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_functions_multiple_attached_functions(basic_http_client: System) -> None:
|
||||
"""Test attaching multiple functions on the same collection"""
|
||||
def test_functions_one_attached_function_per_collection(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test that only one attached function is allowed per collection"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a single collection
|
||||
collection = client.create_collection(name="multi_task_collection")
|
||||
collection = client.create_collection(name="single_task_collection")
|
||||
collection.add(ids=["id1", "id2", "id3"], documents=["doc1", "doc2", "doc3"])
|
||||
|
||||
# Create first task on the collection
|
||||
attached_fn1 = collection.attach_function(
|
||||
attached_fn1, created = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_1",
|
||||
function_id="record_counter",
|
||||
output_collection="output_1",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn1 is not None
|
||||
assert created is True
|
||||
|
||||
# Create second task on the SAME collection with a different name
|
||||
attached_fn2 = collection.attach_function(
|
||||
# Attempt to create a second task with a different name should fail
|
||||
# (only one attached function allowed per collection)
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match="collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
|
||||
):
|
||||
collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_2",
|
||||
output_collection="output_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Attempting to reuse the name task_1 with a different function should also fail
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match=r"collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
|
||||
):
|
||||
collection.attach_function(
|
||||
function=STATISTICS_FUNCTION,
|
||||
name="task_1",
|
||||
output_collection="output_different", # Different output collection
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Detach the first function
|
||||
assert (
|
||||
collection.detach_function(attached_fn1.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
# Now we should be able to attach a new function
|
||||
attached_fn2, created2 = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_2",
|
||||
function_id="record_counter",
|
||||
output_collection="output_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn2 is not None
|
||||
assert created2 is True
|
||||
assert attached_fn2.id != attached_fn1.id
|
||||
|
||||
# Task IDs should be different even though they're on the same collection
|
||||
assert attached_fn1.id != attached_fn2.id
|
||||
|
||||
# Create third task on the same collection
|
||||
attached_fn3 = collection.attach_function(
|
||||
name="task_3",
|
||||
function_id="record_counter",
|
||||
output_collection="output_3",
|
||||
params=None,
|
||||
# Clean up
|
||||
assert (
|
||||
collection.detach_function(attached_fn2.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
assert attached_fn3 is not None
|
||||
assert attached_fn3.id != attached_fn1.id
|
||||
assert attached_fn3.id != attached_fn2.id
|
||||
|
||||
# Attempt to create a task with duplicate name on same collection should fail
|
||||
with pytest.raises(ChromaError, match="already exists"):
|
||||
def test_attach_function_with_invalid_params(basic_http_client: System) -> None:
|
||||
"""Test that attach_function with non-empty params raises an error"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
collection = client.create_collection(name="test_invalid_params")
|
||||
collection.add(ids=["id1"], documents=["test document"])
|
||||
|
||||
# Attempt to create task with non-empty params should fail
|
||||
# (no functions currently accept parameters)
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match="params must be empty - no functions currently accept parameters",
|
||||
):
|
||||
collection.attach_function(
|
||||
name="task_1", # Duplicate name
|
||||
function_id="record_counter",
|
||||
output_collection="output_duplicate",
|
||||
params=None,
|
||||
name="invalid_params_task",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params={"some_key": "some_value"},
|
||||
)
|
||||
|
||||
# Clean up - remove each task individually
|
||||
assert attached_fn1.detach(delete_output_collection=True) is True
|
||||
assert attached_fn2.detach(delete_output_collection=True) is True
|
||||
assert attached_fn3.detach(delete_output_collection=True) is True
|
||||
|
||||
def test_attach_function_output_collection_already_exists(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test that attach_function fails when output collection name already exists"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a collection that will be used as input
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test document"])
|
||||
|
||||
# Create another collection with the name we want to use for output
|
||||
client.create_collection(name="existing_output_collection")
|
||||
|
||||
# Attempt to create task with output collection name that already exists
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match=r"Output collection \[existing_output_collection\] already exists",
|
||||
):
|
||||
input_collection.attach_function(
|
||||
name="my_task",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="existing_output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_function_remove_nonexistent(basic_http_client: System) -> None:
|
||||
@@ -191,15 +294,294 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None:
|
||||
|
||||
collection = client.create_collection(name="test_collection")
|
||||
collection.add(ids=["id1"], documents=["test"])
|
||||
attached_fn = collection.attach_function(
|
||||
attached_fn, _ = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="test_function",
|
||||
function_id="record_counter",
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
attached_fn.detach(delete_output_collection=True)
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
# Trying to detach this function again should raise NotFoundError
|
||||
with pytest.raises(NotFoundError, match="does not exist"):
|
||||
attached_fn.detach(delete_output_collection=True)
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
|
||||
def test_attach_to_output_collection_fails(basic_http_client: System) -> None:
|
||||
"""Test that attaching a function to an output collection fails"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
_, _ = input_collection.attach_function(
|
||||
name="test_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
output_collection = client.get_collection(name="output_collection")
|
||||
|
||||
with pytest.raises(
|
||||
ChromaError, match="cannot attach function to an output collection"
|
||||
):
|
||||
_ = output_collection.attach_function(
|
||||
name="test_function_2",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_delete_output_collection_detaches_function(basic_http_client: System) -> None:
|
||||
"""Test that deleting an output collection also detaches the attached function"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection and attach a function
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
attached_fn, created = input_collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
# Delete the output collection directly
|
||||
client.delete_collection("output_collection")
|
||||
|
||||
# The attached function should now be gone - trying to get it should raise NotFoundError
|
||||
with pytest.raises(NotFoundError):
|
||||
input_collection.get_attached_function("my_function")
|
||||
|
||||
|
||||
def test_delete_orphaned_output_collection(basic_http_client: System) -> None:
|
||||
"""Test that deleting an output collection from a recently detached function works"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection and attach a function
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
attached_fn, created = input_collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
input_collection.detach_function(attached_fn.name, delete_output_collection=False)
|
||||
|
||||
# Delete the output collection directly
|
||||
client.delete_collection("output_collection")
|
||||
|
||||
# The function was detached above, so looking it up by name should raise NotFoundError
|
||||
with pytest.raises(NotFoundError):
|
||||
input_collection.get_attached_function("my_function")
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
# Try to use the function - it should fail since it's detached
|
||||
client.get_collection("output_collection")
|
||||
|
||||
def test_partial_attach_function_repair(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test creating and removing a function with the record_counter operator"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a collection
|
||||
collection = client.get_or_create_collection(
|
||||
name="my_document",
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert created is True
|
||||
|
||||
# Verify task creation succeeded
|
||||
assert attached_fn is not None
|
||||
|
||||
collection2 = client.get_or_create_collection(
|
||||
name="my_document2",
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
# This should fail
|
||||
with pytest.raises(
|
||||
ChromaError, match=r"Output collection \[my_documents_counts\] already exists"
|
||||
):
|
||||
attached_fn, _ = collection2.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Detach the function
|
||||
assert (
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn, created = collection2.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
|
||||
def test_output_collection_created_with_schema(basic_http_client: System) -> None:
|
||||
"""Test that output collections are created with the source_attached_function_id in the schema"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection and attach a function
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
attached_fn, created = input_collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
# Get the output collection - it should exist
|
||||
output_collection = client.get_collection(name="output_collection")
|
||||
assert output_collection is not None
|
||||
|
||||
# The source_attached_function_id is stored in the schema (not metadata)
|
||||
# The client exposes no public schema accessor, so we check the serialized schema on the
# internal model; the attached-function orchestrator consumes this field internally
|
||||
assert "source_attached_function_id" in output_collection._model.pretty_schema()
|
||||
|
||||
# Clean up
|
||||
input_collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
|
||||
def test_count_function_attach_and_detach_attach_attach(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test creating and removing a function with the record_counter operator"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a collection
|
||||
collection = client.get_or_create_collection(
|
||||
name="my_document",
|
||||
metadata={"description": "Sample documents for task processing"},
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Verify task creation succeeded
|
||||
assert created is True
|
||||
assert attached_fn is not None
|
||||
initial_version = get_collection_version(client, collection.name)
|
||||
|
||||
# Add documents
|
||||
collection.add(
|
||||
ids=["doc_{}".format(i) for i in range(0, 300)],
|
||||
documents=["test document"] * 300,
|
||||
)
|
||||
|
||||
# Verify documents were added
|
||||
assert collection.count() == 300
|
||||
|
||||
wait_for_version_increase(client, collection.name, initial_version)
|
||||
# Give some time to invalidate the frontend query cache
|
||||
sleep(60)
|
||||
|
||||
result = client.get_collection("my_documents_counts").get("function_output")
|
||||
assert result["metadatas"] is not None
|
||||
assert result["metadatas"][0]["total_count"] == 300
|
||||
|
||||
# Remove the task
|
||||
success = collection.detach_function(
|
||||
attached_fn.name, delete_output_collection=True
|
||||
)
|
||||
|
||||
# Verify task removal succeeded
|
||||
assert success is True
|
||||
|
||||
# Attach a function that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
# Attach a function that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert created is False
|
||||
assert attached_fn is not None
|
||||
|
||||
def test_attach_function_idempotency(basic_http_client: System) -> None:
|
||||
"""Test that attach_function is idempotent - calling it twice with same params returns created=False"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
collection = client.create_collection(name="idempotency_test")
|
||||
collection.add(ids=["id1"], documents=["test document"])
|
||||
|
||||
# First attach - should be newly created
|
||||
attached_fn1, created1 = collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn1 is not None
|
||||
assert created1 is True
|
||||
|
||||
# Second attach with identical params - should be idempotent (created=False)
|
||||
attached_fn2, created2 = collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn2 is not None
|
||||
assert created2 is False
|
||||
|
||||
# Both should return the same function ID
|
||||
assert attached_fn1.id == attached_fn2.id
|
||||
|
||||
# Clean up
|
||||
collection.detach_function(attached_fn1.name, delete_output_collection=True)
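The tuple return enables a simple get-or-attach pattern on the client side; a hedged sketch reusing the names from the tests above:

fn, created = collection.attach_function(
    name="my_function",
    function=RECORD_COUNTER_FUNCTION,
    output_collection="output_collection",
    params=None,
)
if created:
    print(f"attached {fn.name} -> {fn.output_collection}")
else:
    print(f"reusing existing attachment {fn.id}")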
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import math
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
|
||||
from chromadb import SparseVector
|
||||
from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
|
||||
DEFAULT_CHROMA_BM25_STOPWORDS,
|
||||
ChromaBm25EmbeddingFunction,
|
||||
@@ -137,3 +139,64 @@ def test_validate_config_update_allows_known_keys() -> None:
|
||||
embedder.validate_config_update(
|
||||
embedder.get_config(), {"k": 1.1, "stopwords": ["custom"]}
|
||||
)
|
||||
|
||||
|
||||
def test_multithreaded_usage() -> None:
|
||||
embedder = ChromaBm25EmbeddingFunction()
|
||||
base_texts = [
|
||||
"""The gravitational wave background from massive black hole binaries emit bursts of
|
||||
gravitational waves at periapse. Such events may be directly resolvable in the Galactic
|
||||
centre. However, if the star does not spiral in, the emitted GWs are not resolvable for
|
||||
extra-galactic MBHs, but constitute a source of background noise. We estimate the power
|
||||
spectrum of this extreme mass ratio burst background.""",
|
||||
"""Dynamics of planets in exoplanetary systems with multiple stars showing how the
|
||||
gravitational interactions between the stars and planets affect the orbital stability
|
||||
and long-term evolution of the planetary system architectures.""",
|
||||
"""Diurnal Thermal Tides in a Non-rotating atmosphere with realistic heating profiles
|
||||
and temperature gradients that demonstrate the complex interplay between radiation
|
||||
and atmospheric dynamics in planetary atmospheres.""",
|
||||
"""Intermittent turbulence, noise and waves in stellar atmospheres create complex
|
||||
patterns of energy transport and momentum deposition that influence the structure
|
||||
and evolution of stellar interiors and surfaces.""",
|
||||
"""Superconductivity in quantum materials and condensed matter physics systems
|
||||
exhibiting novel quantum phenomena including topological phases, strongly correlated
|
||||
electron systems, and exotic superconducting pairing mechanisms.""",
|
||||
"""Machine learning models require careful tuning of hyperparameters including learning
|
||||
rates, regularization coefficients, and architectural choices that demonstrate the
|
||||
complex interplay between optimization algorithms and model capacity.""",
|
||||
"""Natural language processing enables text understanding through sophisticated
|
||||
algorithms that analyze semantic relationships, syntactic structures, and contextual
|
||||
information to extract meaningful representations from unstructured textual data.""",
|
||||
"""Vector databases store high-dimensional embeddings efficiently using advanced
|
||||
indexing techniques including approximate nearest neighbor search algorithms that
|
||||
balance accuracy and computational efficiency for large-scale similarity search.""",
|
||||
]
|
||||
texts = base_texts * 30
|
||||
|
||||
num_threads = 10
|
||||
|
||||
def process_single_text(text: str) -> SparseVector:
|
||||
return embedder([text])[0]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(process_single_text, text) for text in texts]
|
||||
all_results = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
embedding = future.result()
|
||||
all_results.append(embedding)
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Threading error detected: {type(e).__name__}: {e}. "
|
||||
"This indicates the stemmer is not thread-safe when cached."
|
||||
)
|
||||
|
||||
assert len(all_results) == len(texts)
|
||||
|
||||
for embedding in all_results:
|
||||
assert embedding.indices
|
||||
assert len(embedding.indices) == len(embedding.values)
|
||||
assert _is_sorted(embedding.indices)
|
||||
for value in embedding.values:
|
||||
assert value > 0
|
||||
assert math.isfinite(value)
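The failure message above points at shared stemmer state; a common remedy is to cache one stemmer per thread. A minimal, self-contained sketch of that pattern (the dummy stemmer stands in for whatever stemmer the BM25 embedder actually uses):

import threading

class _DummyStemmer:
    """Placeholder for the real, not-necessarily-thread-safe stemmer."""
    def stem(self, token: str) -> str:
        return token.lower()

_local = threading.local()

def get_thread_local_stemmer() -> _DummyStemmer:
    # Each thread lazily creates and caches its own instance, so concurrent
    # embedding calls never share mutable stemmer state.
    if not hasattr(_local, "stemmer"):
        _local.stemmer = _DummyStemmer()
    return _local.stemmer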
|
||||
|
||||
@@ -33,12 +33,14 @@ def test_get_builtins_holds() -> None:
|
||||
"GoogleGenerativeAiEmbeddingFunction",
|
||||
"GooglePalmEmbeddingFunction",
|
||||
"GoogleVertexEmbeddingFunction",
|
||||
"GoogleGenaiEmbeddingFunction",
|
||||
"HuggingFaceEmbeddingFunction",
|
||||
"HuggingFaceEmbeddingServer",
|
||||
"InstructorEmbeddingFunction",
|
||||
"JinaEmbeddingFunction",
|
||||
"MistralEmbeddingFunction",
|
||||
"MorphEmbeddingFunction",
|
||||
"NomicEmbeddingFunction",
|
||||
"ONNXMiniLM_L6_V2",
|
||||
"OllamaEmbeddingFunction",
|
||||
"OpenAIEmbeddingFunction",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import hypothesis
|
||||
import hypothesis.strategies as st
|
||||
from typing import Any, Optional, List, Dict, Union, cast
|
||||
from typing import Any, Optional, List, Dict, Union, cast, Tuple
|
||||
from typing_extensions import TypedDict
|
||||
import uuid
|
||||
import numpy as np
|
||||
@@ -9,6 +9,10 @@ import numpy.typing as npt
|
||||
import chromadb.api.types as types
|
||||
import re
|
||||
from hypothesis.strategies._internal.strategies import SearchStrategy
|
||||
from chromadb.test.api.test_schema_e2e import (
|
||||
SimpleEmbeddingFunction,
|
||||
DeterministicSparseEmbeddingFunction,
|
||||
)
|
||||
from chromadb.test.conftest import NOT_CLUSTER_ONLY
|
||||
from dataclasses import dataclass
|
||||
from chromadb.api.types import (
|
||||
@@ -17,12 +21,27 @@ from chromadb.api.types import (
|
||||
EmbeddingFunction,
|
||||
Embeddings,
|
||||
Metadata,
|
||||
Schema,
|
||||
CollectionMetadata,
|
||||
VectorIndexConfig,
|
||||
SparseVectorIndexConfig,
|
||||
StringInvertedIndexConfig,
|
||||
IntInvertedIndexConfig,
|
||||
FloatInvertedIndexConfig,
|
||||
BoolInvertedIndexConfig,
|
||||
HnswIndexConfig,
|
||||
SpannIndexConfig,
|
||||
Space,
|
||||
)
|
||||
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
|
||||
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
|
||||
from chromadb.test.conftest import is_spann_disabled_mode
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
CreateSpannConfiguration,
|
||||
CreateHNSWConfiguration,
|
||||
)
|
||||
from chromadb.utils.embedding_functions import (
|
||||
register_embedding_function,
|
||||
)
|
||||
|
||||
# Set the random seed for reproducibility
|
||||
@@ -266,6 +285,365 @@ class ExternalCollection:
|
||||
embedding_function: Optional[types.EmbeddingFunction[Embeddable]]
|
||||
|
||||
|
||||
@register_embedding_function
|
||||
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
|
||||
"""Simple embedding function with ip space for persistence tests."""
|
||||
|
||||
def default_space(self) -> str: # type: ignore[override]
|
||||
return "ip"
|
||||
|
||||
|
||||
@st.composite
|
||||
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
|
||||
"""Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann."""
|
||||
space = None
|
||||
embedding_function = None
|
||||
source_key = None
|
||||
hnsw = None
|
||||
spann = None
|
||||
|
||||
if draw(st.booleans()):
|
||||
space = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
|
||||
if draw(st.booleans()):
|
||||
embedding_function = SimpleIpEmbeddingFunction(
|
||||
dim=draw(st.integers(min_value=1, max_value=1000))
|
||||
)
|
||||
|
||||
if draw(st.booleans()):
|
||||
source_key = draw(st.one_of(st.just("#document"), safe_text))
|
||||
|
||||
index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))
|
||||
|
||||
if index_choice == "hnsw":
|
||||
hnsw = HnswIndexConfig(
|
||||
ef_construction=draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
max_neighbors=draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
ef_search=draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
sync_threshold=draw(st.integers(min_value=2, max_value=10000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
resize_factor=draw(st.floats(min_value=1.0, max_value=5.0))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
)
|
||||
elif index_choice == "spann":
|
||||
spann = SpannIndexConfig(
|
||||
search_nprobe=draw(st.integers(min_value=1, max_value=128))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
write_nprobe=draw(st.integers(min_value=1, max_value=64))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
ef_construction=draw(st.integers(min_value=1, max_value=200))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
ef_search=draw(st.integers(min_value=1, max_value=200))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
max_neighbors=draw(st.integers(min_value=1, max_value=64))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
split_threshold=draw(st.integers(min_value=50, max_value=200))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
merge_threshold=draw(st.integers(min_value=25, max_value=100))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
)
|
||||
|
||||
return VectorIndexConfig(
|
||||
space=cast(Space, space),
|
||||
embedding_function=embedding_function,
|
||||
source_key=source_key,
|
||||
hnsw=hnsw,
|
||||
spann=spann,
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
|
||||
"""Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25."""
|
||||
embedding_function = None
|
||||
source_key = None
|
||||
bm25 = None
|
||||
|
||||
if draw(st.booleans()):
|
||||
embedding_function = DeterministicSparseEmbeddingFunction()
|
||||
source_key = draw(st.one_of(st.just("#document"), safe_text))
|
||||
|
||||
if draw(st.booleans()):
|
||||
bm25 = draw(st.booleans())
|
||||
|
||||
return SparseVectorIndexConfig(
|
||||
embedding_function=embedding_function,
|
||||
source_key=source_key,
|
||||
bm25=bm25,
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
|
||||
"""Generate a Schema object with various create_index/delete_index operations."""
|
||||
if draw(st.booleans()):
|
||||
return None
|
||||
|
||||
schema = Schema()
|
||||
|
||||
# Decide how many operations to perform
|
||||
num_operations = draw(st.integers(min_value=0, max_value=5))
|
||||
sparse_index_created = False
|
||||
|
||||
for _ in range(num_operations):
|
||||
operation = draw(st.sampled_from(["create_index", "delete_index"]))
|
||||
config_type = draw(
|
||||
st.sampled_from(
|
||||
[
|
||||
"string_inverted",
|
||||
"int_inverted",
|
||||
"float_inverted",
|
||||
"bool_inverted",
|
||||
"vector",
|
||||
"sparse_vector",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Decide if we're setting on a key or globally
|
||||
use_key = draw(st.booleans())
|
||||
key = None
|
||||
if use_key and config_type != "vector":
|
||||
# Vector indexes can't be set on specific keys, only globally
|
||||
key = draw(safe_text)
|
||||
|
||||
if operation == "create_index":
|
||||
if config_type == "string_inverted":
|
||||
schema.create_index(config=StringInvertedIndexConfig(), key=key)
|
||||
elif config_type == "int_inverted":
|
||||
schema.create_index(config=IntInvertedIndexConfig(), key=key)
|
||||
elif config_type == "float_inverted":
|
||||
schema.create_index(config=FloatInvertedIndexConfig(), key=key)
|
||||
elif config_type == "bool_inverted":
|
||||
schema.create_index(config=BoolInvertedIndexConfig(), key=key)
|
||||
elif config_type == "vector":
|
||||
vector_config = draw(vector_index_config_strategy())
|
||||
schema.create_index(config=vector_config, key=None)
|
||||
elif (
|
||||
config_type == "sparse_vector"
|
||||
and not is_spann_disabled_mode
|
||||
and not sparse_index_created
|
||||
):
|
||||
sparse_config = draw(sparse_vector_index_config_strategy())
|
||||
# Sparse vector MUST have a key
|
||||
if key is None:
|
||||
key = draw(safe_text)
|
||||
schema.create_index(config=sparse_config, key=key)
|
||||
sparse_index_created = True
|
||||
|
||||
elif operation == "delete_index":
    if config_type == "string_inverted":
        schema.delete_index(config=StringInvertedIndexConfig(), key=key)
    elif config_type == "int_inverted":
        schema.delete_index(config=IntInvertedIndexConfig(), key=key)
    elif config_type == "float_inverted":
        schema.delete_index(config=FloatInvertedIndexConfig(), key=key)
    elif config_type == "bool_inverted":
        schema.delete_index(config=BoolInvertedIndexConfig(), key=key)
    # Vector, FTS, and sparse_vector deletion is not currently supported

return schema
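# Illustrative sketch: a quick way to eyeball what the strategies above
# generate. Hypothesis's .example() is meant for interactive exploration
# only, not for use inside tests.
if __name__ == "__main__":
    print(vector_index_config_strategy().example())
    print(sparse_vector_index_config_strategy().example())
    print(schema_strategy().example())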
|
||||
|
||||
|
||||
@st.composite
|
||||
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
|
||||
"""Generate metadata with hnsw parameters."""
|
||||
metadata: CollectionMetadata = {}
|
||||
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:construction_ef"] = draw(
|
||||
st.integers(min_value=1, max_value=1000)
|
||||
)
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:sync_threshold"] = draw(
|
||||
st.integers(min_value=2, max_value=10000)
|
||||
)
|
||||
|
||||
return metadata if metadata else None
|
||||
|
||||
|
||||
@st.composite
|
||||
def create_configuration_strategy(
|
||||
draw: st.DrawFn,
|
||||
) -> Optional[CreateCollectionConfiguration]:
|
||||
"""Generate CreateCollectionConfiguration with mutual exclusivity rules."""
|
||||
configuration: CreateCollectionConfiguration = {}
|
||||
|
||||
# Optionally set embedding_function (independent)
|
||||
if draw(st.booleans()):
|
||||
configuration["embedding_function"] = SimpleIpEmbeddingFunction(
|
||||
dim=draw(st.integers(min_value=1, max_value=1000))
|
||||
)
|
||||
|
||||
# Decide: set space only, OR set hnsw config, OR set spann config
|
||||
config_choice = draw(
|
||||
st.sampled_from(
|
||||
["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
|
||||
)
|
||||
)
|
||||
|
||||
if config_choice == "space_only_hnsw":
|
||||
configuration["hnsw"] = CreateHNSWConfiguration(
|
||||
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
)
|
||||
elif config_choice == "space_only_spann":
|
||||
configuration["spann"] = CreateSpannConfiguration(
|
||||
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
)
|
||||
elif config_choice == "hnsw":
|
||||
# Set hnsw config (optionally with space)
|
||||
hnsw_config: CreateHNSWConfiguration = {}
|
||||
if draw(st.booleans()):
|
||||
hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000))
|
||||
hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
|
||||
configuration["hnsw"] = hnsw_config
|
||||
elif config_choice == "spann":
|
||||
# Set spann config (optionally with space)
|
||||
spann_config: CreateSpannConfiguration = {}
|
||||
if draw(st.booleans()):
|
||||
spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128))
|
||||
spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64))
|
||||
spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200))
|
||||
spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200))
|
||||
spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64))
|
||||
spann_config["reassign_neighbor_count"] = draw(
|
||||
st.integers(min_value=1, max_value=64)
|
||||
)
|
||||
spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200))
|
||||
spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100))
|
||||
configuration["spann"] = spann_config
|
||||
|
||||
return configuration if configuration else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CollectionInputCombination:
|
||||
"""
|
||||
Input tuple for collection creation tests.
|
||||
"""
|
||||
|
||||
metadata: Optional[CollectionMetadata]
|
||||
configuration: Optional[CreateCollectionConfiguration]
|
||||
schema: Optional[Schema]
|
||||
schema_vector_info: Optional[Dict[str, Any]]
|
||||
kind: str
|
||||
|
||||
|
||||
def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: v for k, v in items.items() if v is not None}
|
||||
|
||||
|
||||
def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
|
||||
embedding_default_space: Optional[str] = None
|
||||
if config.embedding_function is not None and hasattr(
|
||||
config.embedding_function, "default_space"
|
||||
):
|
||||
embedding_default_space = cast(str, config.embedding_function.default_space())
|
||||
|
||||
return {
|
||||
"space": config.space,
|
||||
"hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None,
|
||||
"spann": config.spann.model_dump(exclude_none=True) if config.spann else None,
|
||||
"embedding_function_default_space": embedding_default_space,
|
||||
}
|
||||
|
||||
|
||||
@st.composite
|
||||
def _schema_input_strategy(
|
||||
draw: st.DrawFn,
|
||||
) -> Tuple[Schema, Dict[str, Any]]:
|
||||
schema = Schema()
|
||||
vector_config = draw(vector_index_config_strategy())
|
||||
schema.create_index(config=vector_config, key=None)
|
||||
return schema, vector_index_to_dict(vector_config)
|
||||
|
||||
|
||||
@st.composite
|
||||
def metadata_configuration_schema_strategy(
|
||||
draw: st.DrawFn,
|
||||
) -> CollectionInputCombination:
|
||||
"""
|
||||
Generate compatible combinations of metadata, configuration, and schema inputs.
|
||||
"""
|
||||
|
||||
choice = draw(
|
||||
st.sampled_from(
|
||||
[
|
||||
"none",
|
||||
"metadata",
|
||||
"configuration",
|
||||
"metadata_configuration",
|
||||
"schema",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
metadata: Optional[CollectionMetadata] = None
|
||||
configuration: Optional[CreateCollectionConfiguration] = None
|
||||
schema: Optional[Schema] = None
|
||||
schema_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
if choice in ("metadata", "metadata_configuration"):
|
||||
metadata = draw(
|
||||
metadata_with_hnsw_strategy().filter(
|
||||
lambda value: value is not None and len(value) > 0
|
||||
)
|
||||
)
|
||||
|
||||
if choice in ("configuration", "metadata_configuration"):
|
||||
configuration = draw(
|
||||
create_configuration_strategy().filter(
|
||||
lambda value: value is not None
|
||||
and (
|
||||
(value.get("hnsw") is not None and len(value["hnsw"]) > 0)
|
||||
or (value.get("spann") is not None and len(value["spann"]) > 0)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if choice == "schema":
|
||||
schema, schema_info = draw(_schema_input_strategy())
|
||||
|
||||
return CollectionInputCombination(
|
||||
metadata=metadata,
|
||||
configuration=configuration,
|
||||
schema=schema,
|
||||
schema_vector_info=schema_info,
|
||||
kind=choice,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Collection(ExternalCollection):
|
||||
"""
|
||||
@@ -344,7 +722,7 @@ def collections(
|
||||
spann_config: CreateSpannConfiguration = {
    "space": spann_space,
    "write_nprobe": 4,
    "reassign_neighbor_count": 4
    "reassign_neighbor_count": 4,
}
collection_config = {
    "spann": spann_config,
@@ -395,7 +773,7 @@ def collections(
known_document_keywords=known_document_keywords,
has_embeddings=has_embeddings,
embedding_function=embedding_function,
collection_config=collection_config
collection_config=collection_config,
)
@@ -421,7 +799,9 @@ def metadata(
|
||||
del metadata[key] # type: ignore
# Finally, add in some of the known keys for the collection
sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
    k: st.just(v) for k, v in collection.known_metadata_keys.items()
    k: st.just(v)
    for k, v in collection.known_metadata_keys.items()
    if isinstance(v, (str, int, float))
}
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore
# We don't allow submitting empty metadata
@@ -31,7 +31,7 @@ collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="co
|
||||
normal=hypothesis.settings(max_examples=500),
fast=hypothesis.settings(max_examples=200),
),
max_examples=2
max_examples=2,
)
def test_add_miniscule(
    client: ClientAPI,
@@ -332,7 +332,8 @@ def test_out_of_order_ids(client: ClientAPI) -> None:
]

coll = client.create_collection(
    "test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore
    "test",
    embedding_function=lambda input: [[1, 2, 3] for _ in input], # type: ignore
)
embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
coll.add(ids=ooo_ids, embeddings=embeddings)
@@ -369,3 +370,155 @@ def test_add_partial(client: ClientAPI) -> None:
|
||||
assert results["ids"] == ["1", "2", "3"]
|
||||
assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
|
||||
assert results["documents"] == ["a", "b", None]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
NOT_CLUSTER_ONLY,
|
||||
reason="GroupBy is only supported in distributed mode",
|
||||
)
|
||||
def test_search_group_by(client: ClientAPI) -> None:
|
||||
"""Test GroupBy with single key, multiple keys, and multiple ranking keys."""
|
||||
from chromadb.execution.expression.operator import GroupBy, MinK, Key
|
||||
from chromadb.execution.expression.plan import Search
|
||||
from chromadb.execution.expression import Knn
|
||||
|
||||
create_isolated_database(client)
|
||||
|
||||
coll = client.create_collection(name="test_group_by")
|
||||
|
||||
# Test data: 12 records across 3 categories and 2 years
|
||||
# Embeddings are designed so science docs are closest to query [1,0,0,0]
|
||||
ids = [
|
||||
"sci_2023_1",
|
||||
"sci_2023_2",
|
||||
"sci_2024_1",
|
||||
"sci_2024_2",
|
||||
"tech_2023_1",
|
||||
"tech_2023_2",
|
||||
"tech_2024_1",
|
||||
"tech_2024_2",
|
||||
"arts_2023_1",
|
||||
"arts_2023_2",
|
||||
"arts_2024_1",
|
||||
"arts_2024_2",
|
||||
]
|
||||
embeddings = cast(
|
||||
Embeddings,
|
||||
[
|
||||
# Science - closest to [1,0,0,0]
|
||||
[1.0, 0.0, 0.0, 0.0], # sci_2023_1: score ~0.0
|
||||
[0.9, 0.1, 0.0, 0.0], # sci_2023_2: score ~0.141
|
||||
[0.8, 0.2, 0.0, 0.0], # sci_2024_1: score ~0.283
|
||||
[0.7, 0.3, 0.0, 0.0], # sci_2024_2: score ~0.424
|
||||
# Tech - farther from [1,0,0,0]
|
||||
[0.0, 1.0, 0.0, 0.0], # tech_2023_1: score ~1.414
|
||||
[0.0, 0.9, 0.1, 0.0], # tech_2023_2: score ~1.345
|
||||
[0.0, 0.8, 0.2, 0.0], # tech_2024_1: score ~1.281
|
||||
[0.0, 0.7, 0.3, 0.0], # tech_2024_2: score ~1.221
|
||||
# Arts - farther from [1,0,0,0]
|
||||
[0.0, 0.0, 1.0, 0.0], # arts_2023_1: score ~1.414
|
||||
[0.0, 0.0, 0.9, 0.1], # arts_2023_2: score ~1.345
|
||||
[0.0, 0.0, 0.8, 0.2], # arts_2024_1: score ~1.281
|
||||
[0.0, 0.0, 0.7, 0.3], # arts_2024_2: score ~1.221
|
||||
],
|
||||
)
|
||||
metadatas: Metadatas = [
|
||||
{"category": "science", "year": 2023, "priority": 1},
|
||||
{"category": "science", "year": 2023, "priority": 2},
|
||||
{"category": "science", "year": 2024, "priority": 1},
|
||||
{"category": "science", "year": 2024, "priority": 3},
|
||||
{"category": "tech", "year": 2023, "priority": 2},
|
||||
{"category": "tech", "year": 2023, "priority": 1},
|
||||
{"category": "tech", "year": 2024, "priority": 1},
|
||||
{"category": "tech", "year": 2024, "priority": 2},
|
||||
{"category": "arts", "year": 2023, "priority": 3},
|
||||
{"category": "arts", "year": 2023, "priority": 1},
|
||||
{"category": "arts", "year": 2024, "priority": 2},
|
||||
{"category": "arts", "year": 2024, "priority": 1},
|
||||
]
|
||||
documents = [f"doc_{id}" for id in ids]
|
||||
|
||||
coll.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
query = [1.0, 0.0, 0.0, 0.0]
|
||||
|
||||
# Test 1: Single key grouping - top 2 per category by score
|
||||
# Expected: 2 best from each category (science, tech, arts)
|
||||
# - science: sci_2023_1 (0.0), sci_2023_2 (0.141)
|
||||
# - tech: tech_2024_2 (1.221), tech_2024_1 (1.281)
|
||||
# - arts: arts_2024_2 (1.221), arts_2024_1 (1.281)
|
||||
results1 = coll.search(
|
||||
Search()
|
||||
.rank(Knn(query=query, limit=12))
|
||||
.group_by(GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2)))
|
||||
.limit(12)
|
||||
)
|
||||
assert results1["ids"] is not None
|
||||
result1_ids = results1["ids"][0]
|
||||
assert len(result1_ids) == 6
|
||||
expected1 = {
|
||||
"sci_2023_1",
|
||||
"sci_2023_2",
|
||||
"tech_2024_2",
|
||||
"tech_2024_1",
|
||||
"arts_2024_2",
|
||||
"arts_2024_1",
|
||||
}
|
||||
assert set(result1_ids) == expected1
|
||||
|
||||
# Test 2: Multiple key grouping - top 1 per (category, year) combination
|
||||
# 6 groups: (science,2023), (science,2024), (tech,2023), (tech,2024), (arts,2023), (arts,2024)
|
||||
results2 = coll.search(
|
||||
Search()
|
||||
.rank(Knn(query=query, limit=12))
|
||||
.group_by(
|
||||
GroupBy(
|
||||
keys=[Key("category"), Key("year")],
|
||||
aggregate=MinK(keys=Key.SCORE, k=1),
|
||||
)
|
||||
)
|
||||
.limit(12)
|
||||
)
|
||||
assert results2["ids"] is not None
|
||||
result2_ids = results2["ids"][0]
|
||||
assert len(result2_ids) == 6
|
||||
expected2 = {
|
||||
"sci_2023_1",
|
||||
"sci_2024_1",
|
||||
"tech_2023_2",
|
||||
"tech_2024_2",
|
||||
"arts_2023_2",
|
||||
"arts_2024_2",
|
||||
}
|
||||
assert set(result2_ids) == expected2
|
||||
|
||||
# Test 3: Multiple ranking keys - priority first, then score as tiebreaker
# Top 2 per category, sorted by priority (ascending), then score (ascending)
results3 = coll.search(
    Search()
    .rank(Knn(query=query, limit=12))
    .group_by(
        GroupBy(
            keys=Key("category"),
            aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
        )
    )
    .limit(12)
)
assert results3["ids"] is not None
result3_ids = results3["ids"][0]
assert len(result3_ids) == 6
expected3 = {
    "sci_2023_1",
    "sci_2024_1",
    "tech_2024_1",
    "tech_2023_2",
    "arts_2024_2",
    "arts_2023_2",
}
assert set(result3_ids) == expected3
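# Sanity check (illustrative, not part of the test): the per-group winners in
# the comments above follow directly from the L2 distances to the query; the
# science vectors reproduce the quoted values exactly.
import numpy as np
q = np.array([1.0, 0.0, 0.0, 0.0])
for vec in ([1.0, 0.0, 0.0, 0.0], [0.9, 0.1, 0.0, 0.0], [0.8, 0.2, 0.0, 0.0], [0.7, 0.3, 0.0, 0.0]):
    print(round(float(np.linalg.norm(np.array(vec) - q)), 3))  # 0.0, 0.141, 0.283, 0.424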
|
||||
|
||||
@@ -1983,7 +1983,9 @@ def test_sparse_vector_in_metadata_validation():
|
||||
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
|
||||
invalid_metadata_4 = {
|
||||
"text": "non-numeric value",
|
||||
"sparse_embedding": SparseVector(indices=[0, 1], values=[0.1, "not_a_number"]), # type: ignore
|
||||
"sparse_embedding": SparseVector(
|
||||
indices=[0, 1], values=[0.1, "not_a_number"]
|
||||
), # type: ignore
|
||||
}
|
||||
|
||||
# Test 7: Multiple sparse vectors in metadata
|
||||
@@ -2683,6 +2685,59 @@ def test_rrf_to_dict() -> None:
|
||||
print("All RRF tests passed!")
|
||||
|
||||
|
||||
def test_group_by_serialization() -> None:
|
||||
"""Test GroupBy, MinK, and MaxK serialization and deserialization."""
|
||||
import pytest
|
||||
from chromadb.execution.expression.operator import (
|
||||
GroupBy,
|
||||
MinK,
|
||||
MaxK,
|
||||
Key,
|
||||
Aggregate,
|
||||
)
|
||||
|
||||
# to_dict with OneOrMany keys
|
||||
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
|
||||
assert group_by.to_dict() == {
|
||||
"keys": ["category"],
|
||||
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
|
||||
}
|
||||
|
||||
# to_dict with multiple keys and MaxK
|
||||
group_by = GroupBy(
|
||||
keys=[Key("year"), Key("category")],
|
||||
aggregate=MaxK(keys=[Key.SCORE, Key("priority")], k=5),
|
||||
)
|
||||
assert group_by.to_dict() == {
|
||||
"keys": ["year", "category"],
|
||||
"aggregate": {"$max_k": {"keys": ["#score", "priority"], "k": 5}},
|
||||
}
|
||||
|
||||
# Round-trip
|
||||
original = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
|
||||
assert GroupBy.from_dict(original.to_dict()).to_dict() == original.to_dict()
|
||||
|
||||
# Empty GroupBy serializes to {} and from_dict({}) returns default GroupBy
|
||||
empty_group_by = GroupBy()
|
||||
assert empty_group_by.to_dict() == {}
|
||||
assert GroupBy.from_dict({}).to_dict() == {}
|
||||
|
||||
# Error cases
|
||||
with pytest.raises(ValueError, match="requires 'keys' field"):
|
||||
GroupBy.from_dict({"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}})
|
||||
|
||||
with pytest.raises(ValueError, match="requires 'aggregate' field"):
|
||||
GroupBy.from_dict({"keys": ["category"]})
|
||||
|
||||
with pytest.raises(ValueError, match="keys cannot be empty"):
|
||||
GroupBy.from_dict(
|
||||
{"keys": [], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown aggregate operator"):
|
||||
Aggregate.from_dict({"$unknown": {"keys": ["#score"], "k": 3}})
|
||||
|
||||
|
||||
# Expression API Tests - Testing dict support and from_dict methods
|
||||
class TestSearchDictSupport:
|
||||
"""Test Search class dict input support."""
|
||||
@@ -2786,6 +2841,49 @@ class TestSearchDictSupport:
|
||||
with pytest.raises(TypeError, match="select must be"):
|
||||
Search(select=123)
|
||||
|
||||
def test_search_with_group_by(self):
|
||||
"""Test Search accepts group_by as dict, object, and builder method."""
|
||||
import pytest
|
||||
from chromadb.execution.expression.plan import Search
|
||||
from chromadb.execution.expression.operator import GroupBy, MinK, Key
|
||||
|
||||
# Dict input
|
||||
search = Search(
|
||||
group_by={
|
||||
"keys": ["category"],
|
||||
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
|
||||
}
|
||||
)
|
||||
assert isinstance(search._group_by, GroupBy)
|
||||
|
||||
# Object input and builder method
|
||||
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
|
||||
assert Search(group_by=group_by)._group_by is group_by
|
||||
assert Search().group_by(group_by)._group_by.aggregate is not None
|
||||
|
||||
# Invalid inputs
|
||||
with pytest.raises(TypeError, match="group_by must be"):
|
||||
Search(group_by="invalid")
|
||||
with pytest.raises(ValueError, match="requires 'aggregate' field"):
|
||||
Search(group_by={"keys": ["category"]})
|
||||
|
||||
def test_search_group_by_serialization(self):
|
||||
"""Test Search serializes group_by correctly."""
|
||||
from chromadb.execution.expression.plan import Search
|
||||
from chromadb.execution.expression.operator import GroupBy, MinK, Key, Knn
|
||||
|
||||
# Without group_by - empty dict
|
||||
search = Search().rank(Knn(query=[0.1, 0.2])).limit(10)
|
||||
assert search.to_dict()["group_by"] == {}
|
||||
|
||||
# With group_by - has keys and aggregate
|
||||
search = Search().group_by(
|
||||
GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
|
||||
)
|
||||
result = search.to_dict()["group_by"]
|
||||
assert result["keys"] == ["category"]
|
||||
assert result["aggregate"] == {"$min_k": {"keys": ["#score"], "k": 3}}
|
||||
|
||||
|
||||
class TestWhereFromDict:
|
||||
"""Test Where.from_dict() conversion."""
|
||||
@@ -3310,3 +3408,27 @@ class TestRoundTripConversion:
|
||||
return d1 == d2
|
||||
|
||||
assert compare_search_dicts(new_dict, search_dict)
|
||||
|
||||
def test_search_round_trip_with_group_by(self):
|
||||
"""Test Search round-trip with group_by."""
|
||||
from chromadb.execution.expression.plan import Search
|
||||
from chromadb.execution.expression.operator import Key, GroupBy, MinK
|
||||
|
||||
original = Search(
|
||||
where=Key("status") == "active",
|
||||
group_by=GroupBy(
|
||||
keys=[Key("category")],
|
||||
aggregate=MinK(keys=[Key.SCORE], k=3),
|
||||
),
|
||||
)
|
||||
|
||||
# Verify group_by round-trip
|
||||
search_dict = original.to_dict()
|
||||
assert search_dict["group_by"]["keys"] == ["category"]
|
||||
assert search_dict["group_by"]["aggregate"] == {
|
||||
"$min_k": {"keys": ["#score"], "k": 3}
|
||||
}
|
||||
|
||||
# Reconstruct and compare group_by
|
||||
restored = Search(group_by=GroupBy.from_dict(search_dict["group_by"]))
|
||||
assert restored.to_dict()["group_by"] == search_dict["group_by"]
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Generator, cast
|
||||
from unittest.mock import patch
|
||||
from typing import Any, Callable, Generator, cast, Dict, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.config import Settings, System
|
||||
from chromadb.api import ClientAPI
|
||||
import chromadb.server.fastapi
|
||||
from chromadb.api.fastapi import FastAPI
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
@@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings(
|
||||
str(e)
|
||||
== "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
|
||||
)
|
||||
|
||||
|
||||
def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
# takes any positional args to match httpx.Client
|
||||
def factory(*_: Any, **kwargs: Any) -> Any:
|
||||
captured.update(kwargs)
|
||||
session = MagicMock()
|
||||
session.headers = {}
|
||||
return session
|
||||
|
||||
return factory, captured
|
||||
|
||||
|
||||
def test_fastapi_uses_http_limits_from_settings() -> None:
|
||||
settings = Settings(
|
||||
chroma_api_impl="chromadb.api.fastapi.FastAPI",
|
||||
chroma_server_host="localhost",
|
||||
chroma_server_http_port=9000,
|
||||
chroma_server_ssl_verify=True,
|
||||
chroma_http_keepalive_secs=12.5,
|
||||
chroma_http_max_connections=64,
|
||||
chroma_http_max_keepalive_connections=16,
|
||||
)
|
||||
system = System(settings)

factory, captured = make_sync_client_factory()

with patch.object(FastAPI, "require", side_effect=[MagicMock(), MagicMock()]):
    with patch("chromadb.api.fastapi.httpx.Client", side_effect=factory):
        api = FastAPI(system)

api.stop()
limits = captured["limits"]
assert limits.keepalive_expiry == 12.5
assert limits.max_connections == 64
assert limits.max_keepalive_connections == 16
assert captured["timeout"] is None
assert captured["verify"] is True
|
||||
|
||||
@@ -189,3 +189,21 @@ def test_runtime_dependencies() -> None:
|
||||
assert data.starts == ["D", "C"]
|
||||
system.stop()
|
||||
assert data.stops == ["C", "D"]
|
||||
|
||||
|
||||
def test_http_client_setting_defaults() -> None:
|
||||
settings = Settings()
|
||||
assert settings.chroma_http_keepalive_secs == 40.0
|
||||
assert settings.chroma_http_max_connections is None
|
||||
assert settings.chroma_http_max_keepalive_connections is None
|
||||
|
||||
|
||||
def test_http_client_setting_overrides() -> None:
|
||||
settings = Settings(
|
||||
chroma_http_keepalive_secs=5.5,
|
||||
chroma_http_max_connections=123,
|
||||
chroma_http_max_keepalive_connections=17,
|
||||
)
|
||||
assert settings.chroma_http_keepalive_secs == 5.5
|
||||
assert settings.chroma_http_max_connections == 123
|
||||
assert settings.chroma_http_max_keepalive_connections == 17
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from typing import List, Any, Callable
|
||||
from typing import List, Any, Callable, Dict
|
||||
from jsonschema import ValidationError
|
||||
from unittest.mock import MagicMock, create_autospec
|
||||
from chromadb.utils.embedding_functions.schemas import (
|
||||
@@ -7,7 +7,10 @@ from chromadb.utils.embedding_functions.schemas import (
|
||||
load_schema,
|
||||
get_available_schemas,
|
||||
)
|
||||
from chromadb.utils.embedding_functions import known_embedding_functions
|
||||
from chromadb.utils.embedding_functions import (
|
||||
known_embedding_functions,
|
||||
sparse_known_embedding_functions,
|
||||
)
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
@@ -143,3 +146,306 @@ class TestEmbeddingFunctionSchemas:
|
||||
if schema.get("additionalProperties", True) is False:
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config_schema(test_config, schema_name)
|
||||
|
||||
def _create_valid_config_from_schema(
|
||||
self, schema: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a valid config from a schema by filling in required fields"""
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
if "required" in schema and "properties" in schema:
|
||||
for field in schema["required"]:
|
||||
if field in schema["properties"]:
|
||||
field_schema = schema["properties"][field]
|
||||
config[field] = self._get_value_from_field_schema(field_schema)
|
||||
|
||||
return config
|
||||
|
||||
def _get_value_from_field_schema(self, field_schema: Dict[str, Any]) -> Any:
|
||||
"""Get a valid value from a field schema"""
|
||||
# Handle enums - use first enum value
|
||||
if "enum" in field_schema:
|
||||
return field_schema["enum"][0]
|
||||
|
||||
# Handle type (could be a list or single value)
|
||||
field_type = field_schema.get("type")
|
||||
if field_type is None:
|
||||
return "dummy" # Fallback if no type specified
|
||||
|
||||
if isinstance(field_type, list):
|
||||
# If null is in the type list, prefer non-null type
|
||||
non_null_types = [t for t in field_type if t != "null"]
|
||||
field_type = non_null_types[0] if non_null_types else field_type[0]
|
||||
|
||||
if field_type == "object":
|
||||
# Handle nested objects
|
||||
nested_config = {}
|
||||
if "properties" in field_schema:
|
||||
nested_required = field_schema.get("required", [])
|
||||
for prop in nested_required:
|
||||
if prop in field_schema["properties"]:
|
||||
nested_config[prop] = self._get_value_from_field_schema(
|
||||
field_schema["properties"][prop]
|
||||
)
|
||||
return nested_config if nested_config else {}
|
||||
|
||||
if field_type == "array":
|
||||
# Return empty array for arrays
|
||||
return []
|
||||
|
||||
# Use the existing dummy value method for primitive types
|
||||
return self._get_dummy_value(field_type)
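# Tiny standalone illustration of the "enum wins over declared type" rule
# implemented above (plain dict logic, independent of the test class):
field_schema = {"enum": ["search_document", "search_query"], "type": "string"}
value = field_schema["enum"][0] if "enum" in field_schema else "fallback"
assert value == "search_document"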
|
||||
|
||||
def _has_custom_validation(self, ef_class: Any) -> bool:
|
||||
"""Check if validate_config actually validates (not just base implementation)"""
|
||||
try:
|
||||
# Try with an obviously invalid config - if it doesn't raise, it's base implementation
|
||||
invalid_config = {"__invalid_test_config__": True}
|
||||
try:
|
||||
ef_class.validate_config(invalid_config)
|
||||
# If we get here without exception, it's using base implementation
|
||||
return False
|
||||
except (ValidationError, ValueError, FileNotFoundError):
|
||||
# If it raises any validation-related error, it's actually validating
|
||||
return True
|
||||
except Exception:
|
||||
# Any other exception means it's trying to validate (e.g., schema not found)
|
||||
return True
|
||||
|
||||
def _setup_env_vars_for_ef(
|
||||
self, ef_name: str, mock_common_deps: MonkeyPatch
|
||||
) -> None:
|
||||
"""Set up environment variables needed for embedding function instantiation"""
|
||||
# Map of embedding function names to their default API key environment variable names
|
||||
api_key_env_vars = {
|
||||
"cohere": "CHROMA_COHERE_API_KEY",
|
||||
"openai": "CHROMA_OPENAI_API_KEY",
|
||||
"huggingface": "CHROMA_HUGGINGFACE_API_KEY",
|
||||
"huggingface_server": "CHROMA_HUGGINGFACE_API_KEY",
|
||||
"google_palm": "CHROMA_GOOGLE_PALM_API_KEY",
|
||||
"google_generative_ai": "CHROMA_GOOGLE_GENAI_API_KEY",
|
||||
"google_vertex": "CHROMA_GOOGLE_VERTEX_API_KEY",
|
||||
"jina": "CHROMA_JINA_API_KEY",
|
||||
"mistral": "MISTRAL_API_KEY",
|
||||
"morph": "MORPH_API_KEY",
|
||||
"voyageai": "CHROMA_VOYAGE_API_KEY",
|
||||
"cloudflare_workers_ai": "CHROMA_CLOUDFLARE_API_KEY",
|
||||
"together_ai": "CHROMA_TOGETHER_AI_API_KEY",
|
||||
"baseten": "CHROMA_BASETEN_API_KEY",
|
||||
"roboflow": "CHROMA_ROBOFLOW_API_KEY",
|
||||
"amazon_bedrock": "AWS_ACCESS_KEY_ID", # AWS uses different env vars
|
||||
"chroma-cloud-qwen": "CHROMA_API_KEY",
|
||||
# Sparse embedding functions
|
||||
"chroma-cloud-splade": "CHROMA_API_KEY",
|
||||
}
|
||||
|
||||
# Set API key environment variable if needed
|
||||
if ef_name in api_key_env_vars:
|
||||
mock_common_deps.setenv(api_key_env_vars[ef_name], "test-api-key")
|
||||
|
||||
# Special cases that need additional environment variables
|
||||
if ef_name == "amazon_bedrock":
|
||||
mock_common_deps.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
|
||||
mock_common_deps.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
def _create_ef_instance(
|
||||
self, ef_name: str, ef_class: Any, mock_common_deps: MonkeyPatch
|
||||
) -> Any:
|
||||
"""Create an embedding function instance, handling special cases"""
|
||||
# Set up environment variables first
|
||||
self._setup_env_vars_for_ef(ef_name, mock_common_deps)
|
||||
|
||||
# Mock missing modules that are imported inside __init__ methods
|
||||
import sys
|
||||
|
||||
# Create mock modules
|
||||
mock_pil = MagicMock()
|
||||
mock_pil_image = MagicMock()
|
||||
mock_google_genai = MagicMock()
|
||||
mock_vertexai = MagicMock()
|
||||
mock_vertexai_lm = MagicMock()
|
||||
mock_boto3 = MagicMock()
|
||||
mock_jina = MagicMock()
|
||||
mock_mistralai = MagicMock()
|
||||
|
||||
# Mock boto3.Session for amazon_bedrock
|
||||
mock_boto3_session = MagicMock()
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.region_name = "us-east-1"
|
||||
mock_session_instance.profile_name = None
|
||||
mock_session_instance.client.return_value = MagicMock()
|
||||
mock_boto3_session.return_value = mock_session_instance
|
||||
mock_boto3.Session = mock_boto3_session
|
||||
|
||||
# Mock vertexai.init and TextEmbeddingModel
|
||||
mock_text_embedding_model = MagicMock()
|
||||
mock_text_embedding_model.from_pretrained.return_value = MagicMock()
|
||||
mock_vertexai_lm.TextEmbeddingModel = mock_text_embedding_model
|
||||
mock_vertexai.language_models = mock_vertexai_lm
|
||||
mock_vertexai.init = MagicMock()
|
||||
|
||||
# Mock google.generativeai - need to set up google module first
|
||||
mock_google = MagicMock()
|
||||
mock_google_genai.configure = MagicMock() # For palm.configure()
|
||||
mock_google_genai.GenerativeModel = MagicMock(return_value=MagicMock())
|
||||
mock_google.generativeai = mock_google_genai
|
||||
|
||||
# Mock jina Client
|
||||
mock_jina.Client = MagicMock()
|
||||
|
||||
# Mock mistralai
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_mistral_client.return_value.embeddings.create.return_value.data = [
|
||||
MagicMock(embedding=[0.1, 0.2, 0.3])
|
||||
]
|
||||
mock_mistralai.Mistral = mock_mistral_client
|
||||
|
||||
# Add missing modules to sys.modules using monkeypatch
|
||||
modules_to_mock = {
|
||||
"PIL": mock_pil,
|
||||
"PIL.Image": mock_pil_image,
|
||||
"google": mock_google,
|
||||
"google.generativeai": mock_google_genai,
|
||||
"vertexai": mock_vertexai,
|
||||
"vertexai.language_models": mock_vertexai_lm,
|
||||
"boto3": mock_boto3,
|
||||
"jina": mock_jina,
|
||||
"mistralai": mock_mistralai,
|
||||
}
|
||||
|
||||
for module_name, mock_module in modules_to_mock.items():
|
||||
mock_common_deps.setitem(sys.modules, module_name, mock_module)
|
||||
|
||||
# Special cases that need additional arguments
|
||||
if ef_name == "cloudflare_workers_ai":
|
||||
return ef_class(
|
||||
model_name="test-model",
|
||||
account_id="test-account-id",
|
||||
)
|
||||
elif ef_name == "baseten":
|
||||
# Baseten needs api_key explicitly passed even with env var
|
||||
return ef_class(
|
||||
api_key="test-api-key",
|
||||
api_base="https://test.api.baseten.co",
|
||||
)
|
||||
elif ef_name == "amazon_bedrock":
|
||||
# Amazon Bedrock needs a boto3 session - create a mock session
|
||||
# boto3 is already mocked in sys.modules above
|
||||
mock_session = mock_boto3.Session(region_name="us-east-1")
|
||||
return ef_class(
|
||||
session=mock_session,
|
||||
model_name="amazon.titan-embed-text-v1",
|
||||
)
|
||||
elif ef_name == "huggingface_server":
|
||||
return ef_class(url="http://localhost:8080")
|
||||
elif ef_name == "google_vertex":
|
||||
return ef_class(project_id="test-project", region="us-central1")
|
||||
elif ef_name == "mistral":
|
||||
return ef_class(model="mistral-embed")
|
||||
elif ef_name == "roboflow":
|
||||
return ef_class() # No model_name needed
|
||||
elif ef_name == "chroma-cloud-qwen":
|
||||
from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
|
||||
ChromaCloudQwenEmbeddingModel,
|
||||
)
|
||||
|
||||
return ef_class(
|
||||
model=ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
|
||||
task="nl_to_code",
|
||||
)
|
||||
else:
|
||||
# Try with no args first
|
||||
try:
|
||||
return ef_class()
|
||||
except Exception:
|
||||
# If that fails, try with common minimal args
|
||||
return ef_class(model_name="test-model")
|
||||
|
||||
@pytest.mark.parametrize("ef_name", get_embedding_function_names())
|
||||
def test_validate_config_with_schema(
|
||||
self,
|
||||
ef_name: str,
|
||||
mock_embeddings: Callable[[Documents], Embeddings],
|
||||
mock_common_deps: MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test that validate_config works correctly with actual configs from embedding functions"""
|
||||
ef_class = known_embedding_functions[ef_name]
|
||||
|
||||
# Skip if the embedding function doesn't have a validate_config method
|
||||
if not hasattr(ef_class, "validate_config"):
|
||||
pytest.skip(f"{ef_name} does not have validate_config method")
|
||||
|
||||
# Check if it's callable (static methods are callable on the class)
|
||||
if not callable(getattr(ef_class, "validate_config", None)):
|
||||
pytest.skip(f"{ef_name} validate_config is not callable")
|
||||
|
||||
# Skip if using base implementation (doesn't actually validate)
|
||||
if not self._has_custom_validation(ef_class):
|
||||
pytest.skip(
|
||||
f"{ef_name} uses base validate_config implementation (no validation)"
|
||||
)
|
||||
|
||||
# Create a real instance to get the actual config
|
||||
# We'll mock __call__ to avoid needing to actually generate embeddings
|
||||
try:
|
||||
ef_instance = self._create_ef_instance(ef_name, ef_class, mock_common_deps)
|
||||
except Exception as e:
|
||||
pytest.skip(
|
||||
f"{ef_name} requires arguments that we cannot provide without external deps: {e}"
|
||||
)
|
||||
|
||||
# Mock only __call__ to avoid needing to actually generate embeddings
|
||||
mock_call = MagicMock(return_value=mock_embeddings(["test"]))
|
||||
mock_common_deps.setattr(ef_instance, "__call__", mock_call)
|
||||
|
||||
# Get the actual config from the embedding function (this uses the real get_config method)
|
||||
config = ef_instance.get_config()
|
||||
|
||||
# Filter out None values - optional fields with None shouldn't be included in validation
|
||||
# This matches common JSON schema practice where optional fields are omitted rather than null
|
||||
config = {k: v for k, v in config.items() if v is not None}
|
||||
|
||||
# Validate the actual config using the embedding function's validate_config method
|
||||
ef_class.validate_config(config)
|
||||
|
||||
def test_validate_config_sparse_embedding_functions(
|
||||
self,
|
||||
mock_embeddings: Callable[[Documents], Embeddings],
|
||||
mock_common_deps: MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test validate_config for sparse embedding functions with actual configs"""
|
||||
for ef_name, ef_class in sparse_known_embedding_functions.items():
|
||||
# Skip if the embedding function doesn't have a validate_config method
|
||||
if not hasattr(ef_class, "validate_config"):
|
||||
continue
|
||||
|
||||
# Check if it's callable (static methods are callable on the class)
|
||||
if not callable(getattr(ef_class, "validate_config", None)):
|
||||
continue
|
||||
|
||||
# Skip if using base implementation (doesn't actually validate)
|
||||
if not self._has_custom_validation(ef_class):
|
||||
continue
|
||||
|
||||
# Create a real instance to get the actual config
|
||||
# We'll mock __call__ to avoid needing to actually generate embeddings
|
||||
try:
|
||||
ef_instance = self._create_ef_instance(
|
||||
ef_name, ef_class, mock_common_deps
|
||||
)
|
||||
except Exception:
|
||||
continue # Skip if we can't create instance
|
||||
|
||||
# Mock only __call__ to avoid needing to actually generate embeddings
|
||||
mock_call = MagicMock(return_value=mock_embeddings(["test"]))
|
||||
mock_common_deps.setattr(ef_instance, "__call__", mock_call)
|
||||
|
||||
# Get the actual config from the embedding function (this uses the real get_config method)
|
||||
config = ef_instance.get_config()
|
||||
|
||||
# Filter out None values - optional fields with None shouldn't be included in validation
|
||||
# This matches common JSON schema practice where optional fields are omitted rather than null
|
||||
config = {k: v for k, v in config.items() if v is not None}
|
||||
|
||||
# Validate the actual config using the embedding function's validate_config method
|
||||
ef_class.validate_config(config)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -23,6 +23,7 @@ from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GooglePalmEmbeddingFunction,
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
GoogleGenaiEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
@@ -68,6 +69,10 @@ from chromadb.utils.embedding_functions.mistral_embedding_function import (
|
||||
from chromadb.utils.embedding_functions.morph_embedding_function import (
|
||||
MorphEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.nomic_embedding_function import (
|
||||
NomicEmbeddingFunction,
|
||||
NomicQueryConfig,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
|
||||
HuggingFaceSparseEmbeddingFunction,
|
||||
)
|
||||
@@ -98,11 +103,13 @@ _all_classes: Set[str] = {
|
||||
"GooglePalmEmbeddingFunction",
|
||||
"GoogleGenerativeAiEmbeddingFunction",
|
||||
"GoogleVertexEmbeddingFunction",
|
||||
"GoogleGenaiEmbeddingFunction",
|
||||
"OllamaEmbeddingFunction",
|
||||
"InstructorEmbeddingFunction",
|
||||
"JinaEmbeddingFunction",
|
||||
"MistralEmbeddingFunction",
|
||||
"MorphEmbeddingFunction",
|
||||
"NomicEmbeddingFunction",
|
||||
"VoyageAIEmbeddingFunction",
|
||||
"ONNXMiniLM_L6_V2",
|
||||
"OpenCLIPEmbeddingFunction",
|
||||
@@ -137,11 +144,13 @@ known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignor
|
||||
"google_palm": GooglePalmEmbeddingFunction,
|
||||
"google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google_vertex": GoogleVertexEmbeddingFunction,
|
||||
"google_genai": GoogleGenaiEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"mistral": MistralEmbeddingFunction,
|
||||
"morph": MorphEmbeddingFunction,
|
||||
"nomic": NomicEmbeddingFunction,
|
||||
"voyageai": VoyageAIEmbeddingFunction,
|
||||
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
|
||||
"open_clip": OpenCLIPEmbeddingFunction,
|
||||
@@ -259,12 +268,15 @@ __all__ = [
|
||||
"GooglePalmEmbeddingFunction",
|
||||
"GoogleGenerativeAiEmbeddingFunction",
|
||||
"GoogleVertexEmbeddingFunction",
|
||||
"GoogleGenaiEmbeddingFunction",
|
||||
"OllamaEmbeddingFunction",
|
||||
"InstructorEmbeddingFunction",
|
||||
"JinaEmbeddingFunction",
|
||||
"JinaQueryConfig",
|
||||
"MistralEmbeddingFunction",
|
||||
"MorphEmbeddingFunction",
|
||||
"NomicEmbeddingFunction",
|
||||
"NomicQueryConfig",
|
||||
"VoyageAIEmbeddingFunction",
|
||||
"ONNXMiniLM_L6_V2",
|
||||
"OpenCLIPEmbeddingFunction",
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,6 +2,7 @@ import os
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.schemas import validate_config_schema
|
||||
from typing import Dict, Any, Optional, List
|
||||
from chromadb.api.types import Space
|
||||
import warnings
|
||||
@@ -35,12 +36,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
if os.getenv("BASETEN_API_KEY") is not None:
|
||||
self.api_key_env_var = "BASETEN_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
# Prioritize api_key argument, then environment variable
|
||||
resolved_api_key = api_key or os.getenv(api_key_env_var)
|
||||
resolved_api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not resolved_api_key:
|
||||
raise ValueError(
|
||||
f"API key not provided and {api_key_env_var} environment variable is not set."
|
||||
f"API key not provided and {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
self.api_key = resolved_api_key
|
||||
if not api_base:
|
||||
@@ -96,3 +101,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
|
||||
api_base=api_base,
|
||||
api_key_env_var=api_key_env_var,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_config(config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the configuration using the JSON schema.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValidationError: If the configuration does not match the schema
|
||||
"""
|
||||
validate_config_schema(config, "baseten")
|
||||
|
||||
@@ -231,4 +231,4 @@ class Bm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
Raises:
|
||||
ValidationError: If the configuration does not match the schema
|
||||
"""
|
||||
validate_config_schema(config, "bm25")
|
||||
validate_config_schema(config, "bm25")
|
||||
@@ -23,12 +23,32 @@ DEFAULT_TOKEN_MAX_LENGTH = 40
|
||||
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
|
||||
|
||||
|
||||
class _HashedToken:
    __slots__ = ("hash", "label")

    def __init__(self, hash: int, label: Optional[str]):
        self.hash = hash
        self.label = label

    def __hash__(self) -> int:
        return self.hash

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _HashedToken):
            return NotImplemented
        return self.hash == other.hash

    def __lt__(self, other: "_HashedToken") -> bool:
        return self.hash < other.hash
|
||||
|
||||
|
||||
class ChromaBm25Config(TypedDict, total=False):
|
||||
k: float
|
||||
b: float
|
||||
avg_doc_length: float
|
||||
token_max_length: int
|
||||
stopwords: List[str]
|
||||
include_tokens: bool
|
||||
|
||||
|
||||
class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
@@ -39,6 +59,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
|
||||
token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
|
||||
stopwords: Optional[Iterable[str]] = None,
|
||||
include_tokens: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the BM25 sparse embedding function."""
|
||||
|
||||
@@ -46,38 +67,51 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
self.b = float(b)
|
||||
self.avg_doc_length = float(avg_doc_length)
|
||||
self.token_max_length = int(token_max_length)
|
||||
self.include_tokens = bool(include_tokens)
|
||||
|
||||
if stopwords is not None:
|
||||
self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
|
||||
stopword_list: Iterable[str] = self.stopwords
|
||||
self._stopword_list: Iterable[str] = self.stopwords
|
||||
else:
|
||||
self.stopwords = None
|
||||
stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
|
||||
self._stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
|
||||
|
||||
stemmer = get_english_stemmer()
|
||||
self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
|
||||
self._hasher = Murmur3AbsHasher()
|
||||
|
||||
def _encode(self, text: str) -> SparseVector:
|
||||
tokens = self._tokenizer.tokenize(text)
|
||||
stemmer = get_english_stemmer()
|
||||
tokenizer = Bm25Tokenizer(stemmer, self._stopword_list, self.token_max_length)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
|
||||
if not tokens:
|
||||
return SparseVector(indices=[], values=[])
|
||||
|
||||
doc_len = float(len(tokens))
|
||||
counts = Counter(self._hasher.hash(token) for token in tokens)
|
||||
counts = Counter(
|
||||
_HashedToken(
|
||||
self._hasher.hash(token), token if self.include_tokens else None
|
||||
)
|
||||
for token in tokens
|
||||
)
|
||||
|
||||
indices = sorted(counts.keys())
sorted_keys = sorted(counts.keys())
indices: List[int] = []
values: List[float] = []
for idx in indices:
tf = float(counts[idx])
labels: Optional[List[str]] = [] if self.include_tokens else None

for key in sorted_keys:
tf = float(counts[key])
denominator = tf + self.k * (
1 - self.b + (self.b * doc_len) / self.avg_doc_length
)
score = tf * (self.k + 1) / denominator
values.append(score)

return SparseVector(indices=indices, values=values)
indices.append(key.hash)
values.append(score)
if labels is not None and key.label is not None:
labels.append(key.label)

return SparseVector(indices=indices, values=values, labels=labels)
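# Rough numeric check of the scoring above (illustrative; k=1.2 and b=0.75 are
# typical BM25 defaults assumed here, the real values come from DEFAULT_K and
# DEFAULT_B): a term appearing twice in a 10-token document with
# avg_doc_length=10 scores as follows.
k, b, tf, doc_len, avg_doc_length = 1.2, 0.75, 2.0, 10.0, 10.0
denominator = tf + k * (1 - b + (b * doc_len) / avg_doc_length)
score = tf * (k + 1) / denominator
assert abs(score - 1.375) < 1e-9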
|
||||
|
||||
def __call__(self, input: Documents) -> SparseVectors:
|
||||
sparse_vectors: SparseVectors = []
|
||||
@@ -99,7 +133,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
|
||||
@staticmethod
|
||||
def build_from_config(
|
||||
config: Dict[str, Any]
|
||||
config: Dict[str, Any],
|
||||
) -> "SparseEmbeddingFunction[Documents]":
|
||||
return ChromaBm25EmbeddingFunction(
|
||||
k=config.get("k", DEFAULT_K),
|
||||
@@ -107,6 +141,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
|
||||
token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
|
||||
stopwords=config.get("stopwords"),
|
||||
include_tokens=config.get("include_tokens", False),
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
@@ -115,6 +150,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
"b": self.b,
|
||||
"avg_doc_length": self.avg_doc_length,
|
||||
"token_max_length": self.token_max_length,
|
||||
"include_tokens": self.include_tokens,
|
||||
}
|
||||
|
||||
if self.stopwords is not None:
|
||||
@@ -125,7 +161,14 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
def validate_config_update(
|
||||
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
|
||||
) -> None:
|
||||
mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
|
||||
mutable_keys = {
|
||||
"k",
|
||||
"b",
|
||||
"avg_doc_length",
|
||||
"token_max_length",
|
||||
"stopwords",
|
||||
"include_tokens",
|
||||
}
|
||||
for key in new_config:
|
||||
if key not in mutable_keys:
|
||||
raise ValueError(f"Updating '{key}' is not supported for {NAME}")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
|
||||
from typing import List, Dict, Any, Union
|
||||
from typing import List, Dict, Any, Union, Optional
|
||||
import os
|
||||
import numpy as np
|
||||
from chromadb.utils.embedding_functions.schemas import validate_config_schema
|
||||
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -32,7 +33,7 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
def __init__(
|
||||
self,
|
||||
model: ChromaCloudQwenEmbeddingModel,
|
||||
task: str,
|
||||
task: Optional[str],
|
||||
instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
|
||||
api_key_env_var: str = "CHROMA_API_KEY",
|
||||
):
|
||||
@@ -41,7 +42,8 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
|
||||
Args:
|
||||
model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
|
||||
task (str): The task for which embeddings are being generated.
|
||||
task (str, optional): The task for which embeddings are being generated. If None or empty,
|
||||
empty instructions will be used for both documents and queries.
|
||||
instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
|
||||
custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
|
||||
api_key_env_var (str, optional): Environment variable name that contains your API key.
|
||||
@@ -55,9 +57,18 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
# First, try to get API key from environment variable
|
||||
self.api_key = os.getenv(api_key_env_var)
|
||||
# If not found in env var, try to get it from existing client instances
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
SharedSystemClient = _get_shared_system_client()
|
||||
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
|
||||
# Raise error if still no API key found
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
f"API key not found in environment variable {api_key_env_var} "
|
||||
f"or in any existing client instances"
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.task = task
|
||||
@@ -102,10 +113,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
if not input:
|
||||
return []
|
||||
|
||||
payload: Dict[str, Union[str, Documents]] = {
|
||||
"instructions": self.instructions[self.task][
|
||||
instruction = ""
|
||||
if self.task and self.task in self.instructions:
|
||||
instruction = self.instructions[self.task][
|
||||
ChromaCloudQwenEmbeddingTarget.DOCUMENTS
|
||||
],
|
||||
]
|
||||
|
||||
payload: Dict[str, Union[str, Documents]] = {
|
||||
"instructions": instruction,
|
||||
"texts": input,
|
||||
}
|
||||
|
||||
@@ -120,10 +135,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
if not input:
return []

payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
instruction = ""
if self.task and self.task in self.instructions:
instruction = self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.QUERY
],
]

payload: Dict[str, Union[str, Documents]] = {
"instructions": instruction,
"texts": input,
}
|
||||
|
||||
|
||||
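Taken together, the Qwen hunks make task optional: when it is None, empty, or not present in the instructions map, an empty instruction string is sent for both documents and queries. A standalone restatement of that fallback (plain string keys stand in for the ChromaCloudQwenEmbeddingTarget enum; this is not the class itself):

from typing import Dict, Optional


def pick_instruction(
    task: Optional[str],
    instructions: Dict[str, Dict[str, str]],
    target: str,
) -> str:
    # Fall back to an empty instruction when task is None/empty or unknown.
    instruction = ""
    if task and task in instructions:
        instruction = instructions[task][target]
    return instruction


instructions = {"retrieval": {"documents": "Represent the document", "query": "Represent the query"}}
assert pick_instruction(None, instructions, "documents") == ""
assert pick_instruction("retrieval", instructions, "query") == "Represent the query"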
@@ -1,15 +1,16 @@
from chromadb.api.types import (
    SparseEmbeddingFunction,
    SparseVector,
    SparseVectors,
    Documents,
)
from typing import Dict, Any
from typing import Dict, Any, List, Optional
from enum import Enum
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
from chromadb.base_types import SparseVector
import os
from typing import Union
from chromadb.utils.embedding_functions.utils import _get_shared_system_client


class ChromaCloudSpladeEmbeddingModel(Enum):
@@ -21,6 +22,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
        self,
        api_key_env_var: str = "CHROMA_API_KEY",
        model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
        include_tokens: bool = False,
    ):
        """
        Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -36,12 +38,20 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )
        self.api_key_env_var = api_key_env_var
        # First, try to get API key from environment variable
        self.api_key = os.getenv(self.api_key_env_var)
        # If not found in env var, try to get it from existing client instances
        if not self.api_key:
            SharedSystemClient = _get_shared_system_client()
            self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
        # Raise error if still no API key found
        if not self.api_key:
            raise ValueError(
                f"API key not found in environment variable {self.api_key_env_var}"
                f"API key not found in environment variable {self.api_key_env_var} "
                f"or in any existing client instances"
            )
        self.model = model
        self.include_tokens = bool(include_tokens)
        self._api_url = "https://embed.trychroma.com/embed_sparse"
        self._session = httpx.Client()
        self._session.headers.update(
@@ -80,6 +90,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
            "texts": list(input),
            "task": "",
            "target": "",
            "fetch_tokens": "true" if self.include_tokens is True else "false",
        }

        try:
@@ -113,13 +124,17 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
            if isinstance(emb, dict):
                indices = emb.get("indices", [])
                values = emb.get("values", [])
                raw_labels = emb.get("labels") if self.include_tokens else None
                labels: Optional[List[str]] = raw_labels if raw_labels else None
            else:
                # Already a SparseVector, extract its data
                assert isinstance(emb, SparseVector)
                indices = emb.indices
                values = emb.values
                labels = emb.labels if self.include_tokens else None

            normalized_vectors.append(
                normalize_sparse_vector(indices=indices, values=values)
                normalize_sparse_vector(indices=indices, values=values, labels=labels)
            )

        return normalized_vectors
@@ -141,18 +156,25 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
        return ChromaCloudSpladeEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model=ChromaCloudSpladeEmbeddingModel(model),
            include_tokens=config.get("include_tokens", False),
        )

    def get_config(self) -> Dict[str, Any]:
        return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}
        return {
            "api_key_env_var": self.api_key_env_var,
            "model": self.model.value,
            "include_tokens": self.include_tokens,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model" in new_config:
            raise ValueError(
                "model cannot be changed after the embedding function has been initialized"
            )
        immutable_keys = {"include_tokens", "model"}
        for key in immutable_keys:
            if key in new_config and new_config[key] != old_config.get(key):
                raise ValueError(
                    f"Updating '{key}' is not supported for chroma-cloud-splade"
                )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:

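With include_tokens now part of the persisted config, it survives a get_config/build_from_config round trip and cannot be changed afterwards. A rough usage sketch, assuming the class is importable from this file, that CHROMA_API_KEY is set, and that build_from_config is the static constructor the hunk above is taken from:

import os

os.environ["CHROMA_API_KEY"] = "<your-key>"  # otherwise the constructor falls back to keys held by existing clients

ef = ChromaCloudSpladeEmbeddingFunction(include_tokens=True)
cfg = ef.get_config()
# cfg -> {"api_key_env_var": "CHROMA_API_KEY", "model": "...", "include_tokens": True}

# include_tokens now survives a persist/restore cycle:
restored = ChromaCloudSpladeEmbeddingFunction.build_from_config(cfg)

# and cannot be silently changed afterwards:
# restored.validate_config_update(cfg, {"include_tokens": False})  # raises ValueError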
@@ -53,12 +53,19 @@ class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
        )
        self.model_name = model_name
        self.account_id = account_id
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)

        if os.getenv("CLOUDFLARE_API_KEY") is not None:
            self.api_key_env_var = "CLOUDFLARE_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        self.gateway_id = gateway_id

        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        if self.gateway_id:
            self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"

@@ -43,10 +43,16 @@ class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if os.getenv("COHERE_API_KEY") is not None:
            self.api_key_env_var = "COHERE_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.model_name = model_name


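The Cloudflare and Cohere hunks above (and the Google, HuggingFace, and Jina hunks below) all apply the same resolution order: an explicitly passed api_key wins, then the provider's canonical environment variable if it is set, then the caller-supplied api_key_env_var. A standalone restatement of that order (not the classes themselves; the variable names are illustrative):

import os
from typing import Optional, Tuple


def resolve_api_key(
    api_key: Optional[str],
    api_key_env_var: str,
    canonical_env_var: str,  # e.g. "COHERE_API_KEY" or "CLOUDFLARE_API_KEY"
) -> Tuple[str, str]:
    # Prefer the provider's canonical variable when it is present in the environment.
    env_var = canonical_env_var if os.getenv(canonical_env_var) is not None else api_key_env_var
    key = api_key or os.getenv(env_var)
    if not key:
        raise ValueError(f"The {env_var} environment variable is not set.")
    return key, env_var


os.environ["COHERE_API_KEY"] = "ck-123"
key, env_var = resolve_api_key(None, "MY_CUSTOM_VAR", "COHERE_API_KEY")
assert (key, env_var) == ("ck-123", "COHERE_API_KEY")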
@@ -7,6 +7,150 @@ from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings


class GoogleGenaiEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        model_name: str,
        vertexai: Optional[bool] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        api_key_env_var: str = "GOOGLE_API_KEY",
    ):
        """
        Initialize the GoogleGenaiEmbeddingFunction.

        Args:
            model_name (str): The name of the model to use for text embeddings.
            api_key_env_var (str, optional): Environment variable name that contains your API key for the Google GenAI API.
                Defaults to "GOOGLE_API_KEY".
        """
        try:
            import google.genai as genai
        except ImportError:
            raise ValueError(
                "The google-genai python package is not installed. Please install it with `pip install google-genai`"
            )

        self.model_name = model_name
        self.api_key_env_var = api_key_env_var
        self.vertexai = vertexai
        self.project = project
        self.location = location
        self.api_key = os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.client = genai.Client(
            api_key=self.api_key, vertexai=vertexai, project=project, location=location
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents or images to generate embeddings for.

        Returns:
            Embeddings for the documents.
        """
        if not input:
            raise ValueError("Input documents cannot be empty")
        if not isinstance(input, (list, tuple)):
            raise ValueError("Input must be a list or tuple of documents")
        if not all(isinstance(doc, str) for doc in input):
            raise ValueError("All input documents must be strings")

        try:
            response = self.client.models.embed_content(
                model=self.model_name, contents=input
            )
        except Exception as e:
            raise ValueError(f"Failed to generate embeddings: {str(e)}") from e

        # Validate response structure
        if not hasattr(response, "embeddings") or not response.embeddings:
            raise ValueError("No embeddings returned from the API")

        embeddings_list = []
        for ce in response.embeddings:
            if not hasattr(ce, "values"):
                raise ValueError("Malformed embedding response: missing 'values'")
            embeddings_list.append(np.array(ce.values, dtype=np.float32))

        return cast(Embeddings, embeddings_list)

    @staticmethod
    def name() -> str:
        return "google_genai"

    def default_space(self) -> Space:
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        model_name = config.get("model_name")
        vertexai = config.get("vertexai")
        project = config.get("project")
        location = config.get("location")

        if model_name is None:
            raise ValueError("The model name is required.")

        return GoogleGenaiEmbeddingFunction(
            model_name=model_name,
            vertexai=vertexai,
            project=project,
            location=location,
        )

    def get_config(self) -> Dict[str, Any]:
        return {
            "model_name": self.model_name,
            "vertexai": self.vertexai,
            "project": self.project,
            "location": self.location,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )
        if "vertexai" in new_config:
            raise ValueError(
                "The vertexai cannot be changed after the embedding function has been initialized."
            )
        if "project" in new_config:
            raise ValueError(
                "The project cannot be changed after the embedding function has been initialized."
            )
        if "location" in new_config:
            raise ValueError(
                "The location cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "google_genai")


class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""

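A minimal usage sketch for the new GoogleGenaiEmbeddingFunction, assuming GOOGLE_API_KEY is set and the google-genai package is installed; the model name below is illustrative and not taken from this diff:

import os

os.environ.setdefault("GOOGLE_API_KEY", "<your-key>")

ef = GoogleGenaiEmbeddingFunction(model_name="text-embedding-004")  # model name is an assumption
vectors = ef(["hello world", "chroma embeddings"])
# Each element is a float32 numpy array; get_config() persists model_name, vertexai, project and location.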
@@ -38,10 +182,16 @@ class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if os.getenv("GOOGLE_API_KEY") is not None:
            self.api_key_env_var = "GOOGLE_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.model_name = model_name

@@ -154,10 +304,16 @@ class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if os.getenv("GOOGLE_API_KEY") is not None:
            self.api_key_env_var = "GOOGLE_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.model_name = model_name
        self.task_type = task_type
@@ -289,10 +445,16 @@ class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if os.getenv("GOOGLE_API_KEY") is not None:
            self.api_key_env_var = "GOOGLE_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.model_name = model_name
        self.project_id = project_id

@@ -40,10 +40,16 @@ class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
                "Please use environment variables via api_key_env_var for persistent storage.",
                DeprecationWarning,
            )
        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if os.getenv("HUGGINGFACE_API_KEY") is not None:
            self.api_key_env_var = "HUGGINGFACE_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.model_name = model_name

@@ -160,6 +166,9 @@ class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
        self.url = url

        self.api_key_env_var = api_key_env_var
        if os.getenv("HUGGINGFACE_API_KEY") is not None:
            self.api_key_env_var = "HUGGINGFACE_API_KEY"

        if self.api_key_env_var is not None:
            self.api_key = api_key or os.getenv(self.api_key_env_var)
        else:

@@ -81,10 +81,16 @@ class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
            DeprecationWarning,
        )

        self.api_key_env_var = api_key_env_var
        self.api_key = api_key or os.getenv(api_key_env_var)
        if os.getenv("JINA_API_KEY") is not None:
            self.api_key_env_var = "JINA_API_KEY"
        else:
            self.api_key_env_var = api_key_env_var

        self.api_key = api_key or os.getenv(self.api_key_env_var)
        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")
            raise ValueError(
                f"The {self.api_key_env_var} environment variable is not set."
            )

        self.model_name = model_name

Some files were not shown because too many files have changed in this diff