chore: add virtual environment to the repository

- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is roughly 393 MB and contains 12,655 files
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions


@@ -0,0 +1,439 @@
from typing import Dict, Optional, Union
import logging
from chromadb.api.client import Client as ClientCreator
from chromadb.api.client import (
AdminClient as AdminClientCreator,
)
from chromadb.api.async_client import AsyncClient as AsyncClientCreator
from chromadb.auth.token_authn import TokenTransportHeader
import chromadb.config
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
CollectionMetadata,
UpdateMetadata,
Documents,
EmbeddingFunction,
Embeddings,
URI,
URIs,
IDs,
Include,
Metadata,
Metadatas,
Where,
QueryResult,
GetResult,
WhereDocument,
UpdateCollectionMetadata,
SparseVector,
SparseVectors,
SparseEmbeddingFunction,
Schema,
VectorIndexConfig,
HnswIndexConfig,
SpannIndexConfig,
FtsIndexConfig,
SparseVectorIndexConfig,
StringInvertedIndexConfig,
IntInvertedIndexConfig,
FloatInvertedIndexConfig,
BoolInvertedIndexConfig,
)
# Import Search API components
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import (
# Key builder for where conditions and field selection
Key,
K, # Alias for Key
# KNN-based ranking for hybrid search
Knn,
# Reciprocal Rank Fusion for combining rankings
Rrf,
)
from pathlib import Path
import os
# Re-export types from chromadb.types
__all__ = [
"Collection",
"Metadata",
"Metadatas",
"Where",
"WhereDocument",
"Documents",
"IDs",
"URI",
"URIs",
"Embeddings",
"EmbeddingFunction",
"Include",
"CollectionMetadata",
"UpdateMetadata",
"UpdateCollectionMetadata",
"QueryResult",
"GetResult",
"TokenTransportHeader",
# Search API components
"Search",
"Key",
"K",
"Knn",
"Rrf",
# Sparse Vector Types
"SparseVector",
"SparseVectors",
"SparseEmbeddingFunction",
# Schema and Index Configuration
"Schema",
"VectorIndexConfig",
"HnswIndexConfig",
"SpannIndexConfig",
"FtsIndexConfig",
"SparseVectorIndexConfig",
"StringInvertedIndexConfig",
"IntInvertedIndexConfig",
"FloatInvertedIndexConfig",
"BoolInvertedIndexConfig",
]
from chromadb.types import CloudClientArg
logger = logging.getLogger(__name__)
__settings = Settings()
__version__ = "1.3.4"
# Workaround to deal with Colab's old sqlite3 version
def is_in_colab() -> bool:
try:
import google.colab # noqa: F401
return True
except ImportError:
return False
IN_COLAB = is_in_colab()
is_client = False
try:
from chromadb.is_thin_client import is_thin_client
is_client = is_thin_client
except ImportError:
is_client = False
if not is_client:
import sqlite3
if sqlite3.sqlite_version_info < (3, 35, 0):
if IN_COLAB:
# In Colab, hotswap to pysqlite-binary if it's too old
import subprocess
import sys
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pysqlite3-binary"]
)
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
else:
raise RuntimeError(
"\033[91mYour system has an unsupported version of sqlite3. Chroma \
requires sqlite3 >= 3.35.0.\033[0m\n"
"\033[94mPlease visit \
https://docs.trychroma.com/troubleshooting#sqlite to learn how \
to upgrade.\033[0m"
)
def configure(**kwargs) -> None: # type: ignore
"""Override Chroma's default settings, environment variables or .env files"""
global __settings
__settings = chromadb.config.Settings(**kwargs)
def get_settings() -> Settings:
return __settings
def EphemeralClient(
settings: Optional[Settings] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> ClientAPI:
"""
Creates an in-memory instance of Chroma. This is useful for testing and
development, but not recommended for production use.
Args:
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
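Examples:
```python
# A minimal usage sketch; the collection name and documents are illustrative.
import chromadb
client = chromadb.EphemeralClient()
collection = client.get_or_create_collection("my_collection")
collection.add(ids=["id1"], documents=["hello chroma"])
results = collection.query(query_texts=["hello"], n_results=1)
```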
"""
if settings is None:
settings = Settings()
settings.is_persistent = False
# Make sure parameters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
return ClientCreator(settings=settings, tenant=tenant, database=database)
def PersistentClient(
path: Union[str, Path] = "./chroma",
settings: Optional[Settings] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> ClientAPI:
"""
Creates a persistent instance of Chroma that saves to disk. This is useful for
testing and development, but not recommended for production use.
Args:
path: The directory to save Chroma's data to. Defaults to "./chroma".
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
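Examples:
```python
# A minimal usage sketch; the path and collection name are illustrative.
import chromadb
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_or_create_collection("my_collection")
collection.add(ids=["id1"], documents=["persisted document"])
```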
"""
if settings is None:
settings = Settings()
settings.persist_directory = str(path)
settings.is_persistent = True
# Make sure parameters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
return ClientCreator(tenant=tenant, database=database, settings=settings)
def RustClient(
path: Optional[str] = None,
settings: Optional[Settings] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> ClientAPI:
"""
Creates an ephemeral or persistent instance of Chroma; data is saved to disk only when a path is provided.
This is useful for testing and development, but not recommended for production use.
Args:
path: An optional directory to save Chroma's data to. The client is ephemeral if a None value is provided. Defaults to None.
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
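Examples:
```python
# A minimal usage sketch: the client is ephemeral when no path is given, persistent otherwise.
import chromadb
ephemeral_client = chromadb.RustClient()
persistent_client = chromadb.RustClient(path="./chroma")
```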
"""
if settings is None:
settings = Settings()
settings.chroma_api_impl = "chromadb.api.rust.RustBindingsAPI"
settings.is_persistent = path is not None
settings.persist_directory = path or ""
# Make sure parameters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
return ClientCreator(tenant=tenant, database=database, settings=settings)
def HttpClient(
host: str = "localhost",
port: int = 8000,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
settings: Optional[Settings] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> ClientAPI:
"""
Creates a client that connects to a remote Chroma server. This supports
many clients connecting to the same server, and is the recommended way to
use Chroma in production.
Args:
host: The hostname of the Chroma server. Defaults to "localhost".
port: The port of the Chroma server. Defaults to 8000.
ssl: Whether to use SSL to connect to the Chroma server. Defaults to False.
headers: A dictionary of headers to send to the Chroma server. Defaults to {}.
settings: A dictionary of settings to communicate with the chroma server.
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
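Examples:
```python
# A minimal usage sketch; assumes a Chroma server is reachable at localhost:8000.
import chromadb
client = chromadb.HttpClient(host="localhost", port=8000)
client.heartbeat()
collection = client.get_or_create_collection("my_collection")
```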
"""
if settings is None:
settings = Settings()
# Make sure parameters are the correct types -- users can pass anything.
host = str(host)
port = int(port)
ssl = bool(ssl)
tenant = str(tenant)
database = str(database)
settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
if settings.chroma_server_host and settings.chroma_server_host != host:
raise ValueError(
f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]"
)
settings.chroma_server_host = host
if settings.chroma_server_http_port and settings.chroma_server_http_port != port:
raise ValueError(
f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]"
)
settings.chroma_server_http_port = port
settings.chroma_server_ssl_enabled = ssl
settings.chroma_server_headers = headers
return ClientCreator(tenant=tenant, database=database, settings=settings)
async def AsyncHttpClient(
host: str = "localhost",
port: int = 8000,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
settings: Optional[Settings] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> AsyncClientAPI:
"""
Creates an async client that connects to a remote Chroma server. This supports
many clients connecting to the same server, and is the recommended way to
use Chroma in production.
Args:
host: The hostname of the Chroma server. Defaults to "localhost".
port: The port of the Chroma server. Defaults to 8000.
ssl: Whether to use SSL to connect to the Chroma server. Defaults to False.
headers: A dictionary of headers to send to the Chroma server. Defaults to {}.
settings: A dictionary of settings to communicate with the chroma server.
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
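Examples:
```python
# A minimal usage sketch; assumes a Chroma server is reachable at localhost:8000.
import asyncio
import chromadb

async def main() -> None:
    client = await chromadb.AsyncHttpClient(host="localhost", port=8000)
    await client.heartbeat()
    collection = await client.get_or_create_collection("my_collection")

asyncio.run(main())
```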
"""
if settings is None:
settings = Settings()
# Make sure parameters are the correct types -- users can pass anything.
host = str(host)
port = int(port)
ssl = bool(ssl)
tenant = str(tenant)
database = str(database)
settings.chroma_api_impl = "chromadb.api.async_fastapi.AsyncFastAPI"
if settings.chroma_server_host and settings.chroma_server_host != host:
raise ValueError(
f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]"
)
settings.chroma_server_host = host
if settings.chroma_server_http_port and settings.chroma_server_http_port != port:
raise ValueError(
f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]"
)
settings.chroma_server_http_port = port
settings.chroma_server_ssl_enabled = ssl
settings.chroma_server_headers = headers
return await AsyncClientCreator.create(
tenant=tenant, database=database, settings=settings
)
def CloudClient(
tenant: Optional[str] = None,
database: Optional[str] = None,
api_key: Optional[str] = None,
settings: Optional[Settings] = None,
*, # Following arguments are keyword-only, intended for testing only.
cloud_host: str = "api.trychroma.com",
cloud_port: int = 443,
enable_ssl: bool = True,
) -> ClientAPI:
"""
Creates a client to connect to a tenant and database on Chroma cloud.
Args:
tenant: The tenant to use for this client. Optional. If not provided, it will be inferred from the API key if the key is scoped to a single tenant. If provided, it will be validated against the API key's scope.
database: The database to use for this client. Optional. If not provided, it will be inferred from the API key if the key is scoped to a single database. If provided, it will be validated against the API key's scope.
api_key: The api key to use for this client.
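Examples:
```python
# A minimal usage sketch; the tenant, database and API key values are illustrative.
# The key may instead be supplied via the CHROMA_API_KEY environment variable.
import chromadb
client = chromadb.CloudClient(
    tenant="my-tenant",
    database="my-database",
    api_key="my-api-key",
)
collection = client.get_or_create_collection("my_collection")
```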
"""
required_args = [
CloudClientArg(name="api_key", env_var="CHROMA_API_KEY", value=api_key),
]
# If api_key is not provided, try to load it from the environment variable
if not all([arg.value for arg in required_args]):
for arg in required_args:
arg.value = arg.value or os.environ.get(arg.env_var)
missing_args = [arg for arg in required_args if arg.value is None]
if missing_args:
raise ValueError(
f"Missing required arguments: {', '.join([arg.name for arg in missing_args])}. "
f"Please provide them or set the environment variables: {', '.join([arg.env_var for arg in missing_args])}"
)
if settings is None:
settings = Settings()
# Make sure parameters are the correct types -- users can pass anything.
tenant = tenant or os.environ.get("CHROMA_TENANT")
if tenant is not None:
tenant = str(tenant)
database = database or os.environ.get("CHROMA_DATABASE")
if database is not None:
database = str(database)
# Use the resolved value (explicit argument or the CHROMA_API_KEY environment variable)
api_key = str(required_args[0].value)
cloud_host = str(cloud_host)
cloud_port = int(cloud_port)
enable_ssl = bool(enable_ssl)
settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
settings.chroma_server_host = cloud_host
settings.chroma_server_http_port = cloud_port
settings.chroma_server_ssl_enabled = enable_ssl
settings.chroma_client_auth_provider = (
"chromadb.auth.token_authn.TokenAuthClientProvider"
)
settings.chroma_client_auth_credentials = api_key
settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN
settings.chroma_overwrite_singleton_tenant_database_access_from_auth = True
return ClientCreator(tenant=tenant, database=database, settings=settings)
def Client(
settings: Settings = __settings,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> ClientAPI:
"""
Return a running chroma.API instance
Args:
tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.
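Examples:
```python
# A minimal usage sketch with explicit, illustrative settings.
import chromadb
from chromadb.config import Settings
client = chromadb.Client(Settings(is_persistent=False))
collection = client.get_or_create_collection("my_collection")
```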
"""
# Make sure parameters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
return ClientCreator(tenant=tenant, database=database, settings=settings)
def AdminClient(settings: Settings = Settings()) -> AdminAPI:
"""
Creates an admin client that can be used to create tenants and databases.
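Examples:
```python
# A minimal usage sketch; the tenant and database names are illustrative.
import chromadb
admin = chromadb.AdminClient()
admin.create_tenant("my_tenant")
admin.create_database("my_database", tenant="my_tenant")
```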
"""
return AdminClientCreator(settings=settings)


@@ -0,0 +1,863 @@
from chromadb.api.types import * # noqa: F401, F403
from chromadb.execution.expression import ( # noqa: F401, F403
Search,
Key,
K,
SearchWhere,
And,
Or,
Eq,
Ne,
Gt,
Gte,
Lt,
Lte,
In,
Nin,
Regex,
NotRegex,
Contains,
NotContains,
Limit,
Select,
Rank,
Abs,
Div,
Exp,
Log,
Max,
Min,
Mul,
Knn,
Rrf,
Sub,
Sum,
Val,
)
from abc import ABC, abstractmethod
from typing import Sequence, Optional, List, Dict, Any
from uuid import UUID
from overrides import override
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.api.types import (
CollectionMetadata,
Documents,
Embeddable,
EmbeddingFunction,
DataLoader,
Embeddings,
IDs,
Include,
IncludeMetadataDocumentsDistances,
IncludeMetadataDocuments,
Loadable,
Metadatas,
Schema,
URIs,
Where,
QueryResult,
GetResult,
WhereDocument,
SearchResult,
DefaultEmbeddingFunction,
)
from chromadb.auth import UserIdentity
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.api.models.Collection import Collection
from chromadb.api.models.AttachedFunction import AttachedFunction
# Re-export the async version
from chromadb.api.async_api import ( # noqa: F401
AsyncBaseAPI as AsyncBaseAPI,
AsyncClientAPI as AsyncClientAPI,
AsyncAdminAPI as AsyncAdminAPI,
AsyncServerAPI as AsyncServerAPI,
)
class BaseAPI(ABC):
@abstractmethod
def heartbeat(self) -> int:
"""Get the current time in nanoseconds since epoch.
Used to check if the server is alive.
Returns:
int: The current time in nanoseconds since epoch
"""
pass
#
# COLLECTION METHODS
#
@abstractmethod
def count_collections(self) -> int:
"""Count the number of collections.
Returns:
int: The number of collections.
Examples:
```python
client.count_collections()
# 1
```
"""
pass
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
) -> None:
"""[Internal] Modify a collection by UUID. Can update the name and/or metadata.
Args:
id: The internal UUID of the collection to modify.
new_name: The new name of the collection.
If None, the existing name will remain. Defaults to None.
new_metadata: The new metadata to associate with the collection.
Defaults to None.
new_configuration: The new configuration to associate with the collection.
Defaults to None.
"""
pass
@abstractmethod
def delete_collection(
self,
name: str,
) -> None:
"""Delete a collection with the given name.
Args:
name: The name of the collection to delete.
Raises:
ValueError: If the collection does not exist.
Examples:
```python
client.delete_collection("my_collection")
```
"""
pass
#
# ITEM METHODS
#
@abstractmethod
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.
Args:
ids: The ids to associate with the embeddings.
collection_id: The UUID of the collection to add the embeddings to.
embeddings: The sequence of embeddings to add.
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
Returns:
True if the embeddings were added successfully.
"""
pass
@abstractmethod
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Update entries in a collection specified by UUID.
Args:
collection_id: The UUID of the collection to update the embeddings in.
ids: The IDs of the entries to update.
embeddings: The sequence of embeddings to update. Defaults to None.
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
Returns:
True if the embeddings were updated successfully.
"""
pass
@abstractmethod
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Add or update entries in the a collection specified by UUID.
If an entry with the same id already exists, it will be updated,
otherwise it will be added.
Args:
collection_id: The collection to add the embeddings to
ids: The ids to associate with the embeddings.
embeddings: The sequence of embeddings to add
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
"""
pass
@abstractmethod
def _count(self, collection_id: UUID) -> int:
"""[Internal] Returns the number of entries in a collection specified by UUID.
Args:
collection_id: The UUID of the collection to count the embeddings in.
Returns:
int: The number of embeddings in the collection
"""
pass
@abstractmethod
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
"""[Internal] Returns the first n entries in a collection specified by UUID.
Args:
collection_id: The UUID of the collection to peek into.
n: The number of entries to peek. Defaults to 10.
Returns:
GetResult: The first n entries in the collection.
"""
pass
@abstractmethod
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
) -> GetResult:
"""[Internal] Returns entries from a collection specified by UUID.
Args:
ids: The IDs of the entries to get. Defaults to None.
where: Conditional filtering on metadata. Defaults to None.
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
include: The fields to include in the response.
Defaults to ["metadatas", "documents"].
Returns:
GetResult: The entries in the collection that match the query.
"""
pass
@abstractmethod
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs],
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
"""[Internal] Deletes entries from a collection specified by UUID.
Args:
collection_id: The UUID of the collection to delete the entries from.
ids: The IDs of the entries to delete. Defaults to None.
where: Conditional filtering on metadata. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
Returns:
None
"""
pass
@abstractmethod
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
) -> QueryResult:
"""[Internal] Performs a nearest neighbors query on a collection specified by UUID.
Args:
collection_id: The UUID of the collection to query.
query_embeddings: The embeddings to use as the query.
ids: The IDs to filter by during the query. Defaults to None.
n_results: The number of results to return. Defaults to 10.
where: Conditional filtering on metadata. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
include: The fields to include in the response.
Defaults to ["metadatas", "documents", "distances"].
Returns:
QueryResult: The results of the query.
"""
pass
@abstractmethod
def reset(self) -> bool:
"""Resets the database. This will delete all collections and entries.
Returns:
bool: True if the database was reset successfully.
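Examples:
```python
# A minimal sketch: reset() is destructive and requires Settings(allow_reset=True).
client.reset()
```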
"""
pass
@abstractmethod
def get_version(self) -> str:
"""Get the version of Chroma.
Returns:
str: The version of Chroma
"""
pass
@abstractmethod
def get_settings(self) -> Settings:
"""Get the settings used to initialize.
Returns:
Settings: The settings used to initialize.
"""
pass
@abstractmethod
def get_max_batch_size(self) -> int:
"""Return the maximum number of records that can be created or mutated in a single call."""
pass
@abstractmethod
def get_user_identity(self) -> UserIdentity:
"""Resolve the tenant and databases for the client. Returns the default
values if they can't be resolved.
"""
pass
class ClientAPI(BaseAPI, ABC):
tenant: str
database: str
@abstractmethod
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Collection]:
"""List all collections.
Args:
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
Returns:
Sequence[Collection]: A list of collections
Examples:
```python
client.list_collections()
# [collection(name="my_collection", metadata={})]
```
"""
pass
@abstractmethod
def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> Collection:
"""Create a new collection with the given name and metadata.
Args:
name: The name of the collection to create.
metadata: Optional metadata to associate with the collection.
embedding_function: Optional function to use to embed documents.
Uses the default embedding function if not provided.
get_or_create: If True, return the existing collection if it exists.
data_loader: Optional function to use to load records (documents, images, etc.)
Returns:
Collection: The newly created collection.
Raises:
ValueError: If the collection already exists and get_or_create is False.
ValueError: If the collection name is invalid.
Examples:
```python
client.create_collection("my_collection")
# collection(name="my_collection", metadata={})
client.create_collection("my_collection", metadata={"foo": "bar"})
# collection(name="my_collection", metadata={"foo": "bar"})
```
"""
pass
@abstractmethod
def get_collection(
self,
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
"""Get a collection with the given name.
Args:
name: The name of the collection to get
embedding_function: Optional function to use to embed documents.
Uses the default embedding function if not provided.
data_loader: Optional function to use to load records (documents, images, etc.)
Returns:
Collection: The collection
Raises:
ValueError: If the collection does not exist
Examples:
```python
client.get_collection("my_collection")
# collection(name="my_collection", metadata={})
```
"""
pass
@abstractmethod
def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
"""Get or create a collection with the given name and metadata.
Args:
name: The name of the collection to get or create
metadata: Optional metadata to associate with the collection. If
the collection already exists, the metadata provided is ignored.
If the collection does not exist, the new collection will be created
with the provided metadata.
embedding_function: Optional function to use to embed documents
data_loader: Optional function to use to load records (documents, images, etc.)
Returns:
The collection
Examples:
```python
client.get_or_create_collection("my_collection")
# collection(name="my_collection", metadata={})
```
"""
pass
@abstractmethod
def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
"""Set the tenant and database for the client. Raises an error if the tenant or
database does not exist.
Args:
tenant: The tenant to set.
database: The database to set.
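Examples:
```python
# A minimal sketch; the tenant and database (illustrative names) must already exist.
client.set_tenant("my_tenant", database="my_database")
```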
"""
pass
@abstractmethod
def set_database(self, database: str) -> None:
"""Set the database for the client. Raises an error if the database does not exist.
Args:
database: The database to set.
"""
pass
@staticmethod
@abstractmethod
def clear_system_cache() -> None:
"""Clear the system cache so that new systems can be created for an existing path.
This should only be used for testing purposes."""
pass
class AdminAPI(ABC):
@abstractmethod
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
"""Create a new database. Raises an error if the database already exists.
Args:
name: The name of the database to create.
"""
pass
@abstractmethod
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
"""Get a database. Raises an error if the database does not exist.
Args:
name: The name of the database to get.
tenant: The tenant of the database to get.
"""
pass
@abstractmethod
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
"""Delete a database. Raises an error if the database does not exist.
Args:
name: The name of the database to delete.
tenant: The tenant of the database to delete.
"""
pass
@abstractmethod
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""List all databases for a tenant. Raises an error if the tenant does not exist.
Args:
tenant: The tenant to list databases for.
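Examples:
```python
# A minimal sketch; `admin` is an AdminAPI instance and the tenant name is illustrative.
databases = admin.list_databases(limit=10, offset=0, tenant="my_tenant")
```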
"""
pass
@abstractmethod
def create_tenant(self, name: str) -> None:
"""Create a new tenant. Raises an error if the tenant already exists.
Args:
name: The name of the tenant to create.
"""
pass
@abstractmethod
def get_tenant(self, name: str) -> Tenant:
"""Get a tenant. Raises an error if the tenant does not exist.
Args:
name: The name of the tenant to get.
"""
pass
class ServerAPI(BaseAPI, AdminAPI, Component):
"""An API instance that extends the relevant Base API methods by passing
in a tenant and database. This is the root component of the Chroma System"""
@abstractmethod
@override
def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
pass
@abstractmethod
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
pass
@abstractmethod
def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
def get_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
@override
def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass
@abstractmethod
@override
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass
@abstractmethod
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
def _search(
self,
collection_id: UUID,
searches: List[Search],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> SearchResult:
pass
@abstractmethod
@override
def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
pass
@abstractmethod
@override
def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
pass
@abstractmethod
@override
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
pass
@abstractmethod
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass
@abstractmethod
@override
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass
@abstractmethod
@override
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass
@abstractmethod
@override
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
pass
@abstractmethod
@override
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass
@abstractmethod
def attach_function(
self,
function_id: str,
name: str,
input_collection_id: UUID,
output_collection: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attach a function to a collection.
Args:
function_id: Built-in function identifier
name: Unique name for this attached function
input_collection_id: Source collection that triggers the function
output_collection: Target collection where function output is stored
params: Optional dictionary with function-specific parameters
tenant: The tenant name
database: The database name
Returns:
AttachedFunction: Object representing the attached function
"""
pass
@abstractmethod
def detach_function(
self,
attached_function_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Detach a function and prevent any further runs.
Args:
attached_function_id: ID of the attached function to remove
delete_output: Whether to also delete the output collection
tenant: The tenant name
database: The database name
Returns:
bool: True if successful
"""
pass


@@ -0,0 +1,770 @@
from abc import ABC, abstractmethod
from typing import Sequence, Optional, List
from uuid import UUID
from overrides import override
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
)
from chromadb.auth import UserIdentity
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.api.types import (
CollectionMetadata,
Documents,
Embeddable,
EmbeddingFunction,
DataLoader,
Embeddings,
IDs,
Include,
Loadable,
Metadatas,
Schema,
URIs,
Where,
QueryResult,
GetResult,
WhereDocument,
IncludeMetadataDocuments,
IncludeMetadataDocumentsDistances,
SearchResult,
DefaultEmbeddingFunction,
)
from chromadb.execution.expression.plan import Search
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant, Collection as CollectionModel
class AsyncBaseAPI(ABC):
@abstractmethod
async def heartbeat(self) -> int:
"""Get the current time in nanoseconds since epoch.
Used to check if the server is alive.
Returns:
int: The current time in nanoseconds since epoch
"""
pass
#
# COLLECTION METHODS
#
@abstractmethod
async def count_collections(self) -> int:
"""Count the number of collections.
Returns:
int: The number of collections.
Examples:
```python
await client.count_collections()
# 1
```
"""
pass
@abstractmethod
async def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
) -> None:
"""[Internal] Modify a collection by UUID. Can update the name and/or metadata.
Args:
id: The internal UUID of the collection to modify.
new_name: The new name of the collection.
If None, the existing name will remain. Defaults to None.
new_metadata: The new metadata to associate with the collection.
Defaults to None.
new_configuration: The new configuration to associate with the collection.
Defaults to None.
"""
pass
@abstractmethod
async def delete_collection(
self,
name: str,
) -> None:
"""Delete a collection with the given name.
Args:
name: The name of the collection to delete.
Raises:
ValueError: If the collection does not exist.
Examples:
```python
await client.delete_collection("my_collection")
```
"""
pass
#
# ITEM METHODS
#
@abstractmethod
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.
Args:
ids: The ids to associate with the embeddings.
collection_id: The UUID of the collection to add the embeddings to.
embeddings: The sequence of embeddings to add.
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
Returns:
True if the embeddings were added successfully.
"""
pass
@abstractmethod
async def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Update entries in a collection specified by UUID.
Args:
collection_id: The UUID of the collection to update the embeddings in.
ids: The IDs of the entries to update.
embeddings: The sequence of embeddings to update. Defaults to None.
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
Returns:
True if the embeddings were updated successfully.
"""
pass
@abstractmethod
async def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
"""[Internal] Add or update entries in the a collection specified by UUID.
If an entry with the same id already exists, it will be updated,
otherwise it will be added.
Args:
collection_id: The collection to add the embeddings to
ids: The ids to associate with the embeddings.
embeddings: The sequence of embeddings to add
metadatas: The metadata to associate with the embeddings. Defaults to None.
documents: The documents to associate with the embeddings. Defaults to None.
uris: URIs of data sources for each embedding. Defaults to None.
"""
pass
@abstractmethod
async def _count(self, collection_id: UUID) -> int:
"""[Internal] Returns the number of entries in a collection specified by UUID.
Args:
collection_id: The UUID of the collection to count the embeddings in.
Returns:
int: The number of embeddings in the collection
"""
pass
@abstractmethod
async def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
"""[Internal] Returns the first n entries in a collection specified by UUID.
Args:
collection_id: The UUID of the collection to peek into.
n: The number of entries to peek. Defaults to 10.
Returns:
GetResult: The first n entries in the collection.
"""
pass
@abstractmethod
async def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
) -> GetResult:
"""[Internal] Returns entries from a collection specified by UUID.
Args:
ids: The IDs of the entries to get. Defaults to None.
where: Conditional filtering on metadata. Defaults to None.
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
include: The fields to include in the response.
Defaults to ["embeddings", "metadatas", "documents"].
Returns:
GetResult: The entries in the collection that match the query.
"""
pass
@abstractmethod
async def _delete(
self,
collection_id: UUID,
ids: Optional[IDs],
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
"""[Internal] Deletes entries from a collection specified by UUID.
Args:
collection_id: The UUID of the collection to delete the entries from.
ids: The IDs of the entries to delete. Defaults to None.
where: Conditional filtering on metadata. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
Returns:
None
"""
pass
@abstractmethod
async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
) -> QueryResult:
"""[Internal] Performs a nearest neighbors query on a collection specified by UUID.
Args:
collection_id: The UUID of the collection to query.
query_embeddings: The embeddings to use as the query.
ids: The IDs to filter by during the query. Defaults to None.
n_results: The number of results to return. Defaults to 10.
where: Conditional filtering on metadata. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
include: The fields to include in the response.
Defaults to ["embeddings", "metadatas", "documents", "distances"].
Returns:
QueryResult: The results of the query.
"""
pass
@abstractmethod
async def reset(self) -> bool:
"""Resets the database. This will delete all collections and entries.
Returns:
bool: True if the database was reset successfully.
"""
pass
@abstractmethod
async def get_version(self) -> str:
"""Get the version of Chroma.
Returns:
str: The version of Chroma
"""
pass
@abstractmethod
def get_settings(self) -> Settings:
"""Get the settings used to initialize.
Returns:
Settings: The settings used to initialize.
"""
pass
@abstractmethod
async def get_max_batch_size(self) -> int:
"""Return the maximum number of records that can be created or mutated in a single call."""
pass
@abstractmethod
async def get_user_identity(self) -> UserIdentity:
"""Resolve the tenant and databases for the client. Returns the default
values if they can't be resolved.
"""
pass
class AsyncClientAPI(AsyncBaseAPI, ABC):
tenant: str
database: str
@abstractmethod
async def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[AsyncCollection]:
"""List all collections.
Args:
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
Returns:
Sequence[AsyncCollection]: A list of collections.
Examples:
```python
await client.list_collections()
# [collection(name="my_collection", metadata={})]
```
"""
pass
@abstractmethod
async def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> AsyncCollection:
"""Create a new collection with the given name and metadata.
Args:
name: The name of the collection to create.
metadata: Optional metadata to associate with the collection.
embedding_function: Optional function to use to embed documents.
Uses the default embedding function if not provided.
get_or_create: If True, return the existing collection if it exists.
data_loader: Optional function to use to load records (documents, images, etc.)
Returns:
Collection: The newly created collection.
Raises:
ValueError: If the collection already exists and get_or_create is False.
ValueError: If the collection name is invalid.
Examples:
```python
await client.create_collection("my_collection")
# collection(name="my_collection", metadata={})
await client.create_collection("my_collection", metadata={"foo": "bar"})
# collection(name="my_collection", metadata={"foo": "bar"})
```
"""
pass
@abstractmethod
async def get_collection(
self,
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
"""Get a collection with the given name.
Args:
name: The name of the collection to get
embedding_function: Optional function to use to embed documents.
Uses the default embedding function if not provided.
data_loader: Optional function to use to load records (documents, images, etc.)
Returns:
Collection: The collection
Raises:
ValueError: If the collection does not exist
Examples:
```python
await client.get_collection("my_collection")
# collection(name="my_collection", metadata={})
```
"""
pass
@abstractmethod
async def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
"""Get or create a collection with the given name and metadata.
Args:
name: The name of the collection to get or create
metadata: Optional metadata to associate with the collection. If
the collection already exists, the metadata provided is ignored.
If the collection does not exist, the new collection will be created
with the provided metadata.
embedding_function: Optional function to use to embed documents
data_loader: Optional function to use to load records (documents, images, etc.)
Returns:
The collection
Examples:
```python
await client.get_or_create_collection("my_collection")
# collection(name="my_collection", metadata={})
```
"""
pass
@abstractmethod
async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
"""Set the tenant and database for the client. Raises an error if the tenant or
database does not exist.
Args:
tenant: The tenant to set.
database: The database to set.
"""
pass
@abstractmethod
async def set_database(self, database: str) -> None:
"""Set the database for the client. Raises an error if the database does not exist.
Args:
database: The database to set.
"""
pass
@staticmethod
@abstractmethod
def clear_system_cache() -> None:
"""Clear the system cache so that new systems can be created for an existing path.
This should only be used for testing purposes."""
pass
class AsyncAdminAPI(ABC):
@abstractmethod
async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
"""Create a new database. Raises an error if the database already exists.
Args:
name: The name of the database to create.
"""
pass
@abstractmethod
async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
"""Get a database. Raises an error if the database does not exist.
Args:
name: The name of the database to get.
tenant: The tenant of the database to get.
"""
pass
@abstractmethod
async def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
"""Delete a database. Raises an error if the database does not exist.
Args:
name: The name of the database to delete.
tenant: The tenant of the database to delete.
"""
pass
@abstractmethod
async def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""List all databases for a tenant. Raises an error if the tenant does not exist.
Args:
tenant: The tenant to list databases for.
"""
pass
@abstractmethod
async def create_tenant(self, name: str) -> None:
"""Create a new tenant. Raises an error if the tenant already exists.
Args:
name: The name of the tenant to create.
"""
pass
@abstractmethod
async def get_tenant(self, name: str) -> Tenant:
"""Get a tenant. Raises an error if the tenant does not exist.
Args:
name: The name of the tenant to get.
"""
pass
class AsyncServerAPI(AsyncBaseAPI, AsyncAdminAPI, Component):
"""An API instance that extends the relevant Base API methods by passing
in a tenant and database. This is the root component of the Chroma System"""
@abstractmethod
async def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
pass
@abstractmethod
@override
async def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
pass
@abstractmethod
async def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
async def get_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
async def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
@override
async def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass
@abstractmethod
@override
async def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass
@abstractmethod
async def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass
@abstractmethod
async def _search(
self,
collection_id: UUID,
searches: List[Search],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> SearchResult:
pass
@abstractmethod
@override
async def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
pass
@abstractmethod
@override
async def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
pass
@abstractmethod
@override
async def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
pass
@abstractmethod
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass
@abstractmethod
@override
async def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass
@abstractmethod
@override
async def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass
@abstractmethod
@override
async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
pass
@abstractmethod
@override
async def _delete(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass


@@ -0,0 +1,526 @@
import httpx
from typing import Optional, Sequence
from uuid import UUID
from overrides import override
from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.api import AsyncAdminAPI, AsyncClientAPI, AsyncServerAPI
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
validate_embedding_function_conflict_on_create,
validate_embedding_function_conflict_on_get,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.shared_system_client import SharedSystemClient
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Documents,
Embeddable,
EmbeddingFunction,
Embeddings,
GetResult,
IDs,
Include,
IncludeMetadataDocuments,
IncludeMetadataDocumentsDistances,
Loadable,
Metadatas,
QueryResult,
Schema,
URIs,
DefaultEmbeddingFunction,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument
class AsyncClient(SharedSystemClient, AsyncClientAPI):
"""A client for Chroma. This is the main entrypoint for interacting with Chroma.
A client internally stores its tenant and database and proxies calls to a
Server API instance of Chroma. It treats the Server API and corresponding System
as a singleton, so multiple clients connecting to the same resource will share the
same API instance.
Client implementations should implement their own API-caching strategies.
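Example (a minimal sketch; an AsyncClient is normally obtained via
chromadb.AsyncHttpClient(), which awaits AsyncClient.create() internally):
```python
client = await AsyncClient.create(
    settings=Settings(
        chroma_api_impl="chromadb.api.async_fastapi.AsyncFastAPI",
        chroma_server_host="localhost",
        chroma_server_http_port=8000,
    )
)
```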
"""
# An internal admin client for verifying that databases and tenants exist
_admin_client: AsyncAdminAPI
tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE
_server: AsyncServerAPI
@classmethod
async def create(
cls,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
settings: Settings = Settings(),
) -> "AsyncClient":
# Create an admin client for verifying that databases and tenants exist
self = cls(settings=settings)
SharedSystemClient._populate_data_from_system(self._system)
self.tenant = tenant
self.database = database
# Get the root system component we want to interact with
self._server = self._system.instance(AsyncServerAPI)
user_identity = await self.get_user_identity()
maybe_tenant, maybe_database = maybe_set_tenant_and_database(
user_identity,
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
user_provided_tenant=tenant,
user_provided_database=database,
)
if maybe_tenant:
self.tenant = maybe_tenant
if maybe_database:
self.database = maybe_database
self._admin_client = AsyncAdminClient.from_system(self._system)
await self._validate_tenant_database(tenant=self.tenant, database=self.database)
self._submit_client_start_event()
return self
@classmethod
# (we can't override and use from_system() because it's synchronous)
async def from_system_async(
cls,
system: System,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AsyncClient":
"""Create a client from an existing system. This is useful for testing and debugging."""
return await AsyncClient.create(tenant, database, system.settings)
@classmethod
@override
def from_system(
cls,
system: System,
) -> "SharedSystemClient":
"""AsyncClient cannot be created synchronously. Use .from_system_async() instead."""
raise NotImplementedError(
"AsyncClient cannot be created synchronously. Use .from_system_async() instead."
)
@override
async def get_user_identity(self) -> UserIdentity:
return await self._server.get_user_identity()
@override
async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
await self._validate_tenant_database(tenant=tenant, database=database)
self.tenant = tenant
self.database = database
@override
async def set_database(self, database: str) -> None:
await self._validate_tenant_database(tenant=self.tenant, database=database)
self.database = database
async def _validate_tenant_database(self, tenant: str, database: str) -> None:
try:
await self._admin_client.get_tenant(name=tenant)
except httpx.ConnectError:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# Propagate ChromaErrors
except ChromaError as e:
raise e
except Exception:
raise ValueError(
f"Could not connect to tenant {tenant}. Are you sure it exists?"
)
try:
await self._admin_client.get_database(name=database, tenant=tenant)
except httpx.ConnectError:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# region BaseAPI Methods
# Note - we could do this in less verbose ways, but they break type checking
@override
async def heartbeat(self) -> int:
return await self._server.heartbeat()
@override
async def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence[AsyncCollection]:
models = await self._server.list_collections(
limit, offset, tenant=self.tenant, database=self.database
)
return [AsyncCollection(client=self._server, model=model) for model in models]
@override
async def count_collections(self) -> int:
return await self._server.count_collections(
tenant=self.tenant, database=self.database
)
@override
async def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> AsyncCollection:
if configuration is None:
configuration = {}
configuration_ef = configuration.get("embedding_function")
validate_embedding_function_conflict_on_create(
embedding_function, configuration_ef
)
# If ef provided in function params and collection config ef is None,
# set the collection config ef to the function params
if embedding_function is not None and configuration_ef is None:
configuration["embedding_function"] = embedding_function
model = await self._server.create_collection(
name=name,
schema=schema,
configuration=configuration,
metadata=metadata,
tenant=self.tenant,
database=self.database,
get_or_create=get_or_create,
)
return AsyncCollection(
client=self._server,
model=model,
embedding_function=embedding_function,
data_loader=data_loader,
)
@override
async def get_collection(
self,
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
model = await self._server.get_collection(
name=name,
tenant=self.tenant,
database=self.database,
)
persisted_ef_config = model.configuration_json.get("embedding_function")
validate_embedding_function_conflict_on_get(
embedding_function, persisted_ef_config
)
return AsyncCollection(
client=self._server,
model=model,
embedding_function=embedding_function,
data_loader=data_loader,
)
@override
async def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
if configuration is None:
configuration = {}
configuration_ef = configuration.get("embedding_function")
validate_embedding_function_conflict_on_create(
embedding_function, configuration_ef
)
if embedding_function is not None and configuration_ef is None:
configuration["embedding_function"] = embedding_function
model = await self._server.get_or_create_collection(
name=name,
schema=schema,
configuration=configuration,
metadata=metadata,
tenant=self.tenant,
database=self.database,
)
persisted_ef_config = model.configuration_json.get("embedding_function")
validate_embedding_function_conflict_on_get(
embedding_function, persisted_ef_config
)
return AsyncCollection(
client=self._server,
model=model,
embedding_function=embedding_function,
data_loader=data_loader,
)
@override
async def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
) -> None:
return await self._server._modify(
id=id,
new_name=new_name,
new_metadata=new_metadata,
new_configuration=new_configuration,
tenant=self.tenant,
database=self.database,
)
@override
async def delete_collection(
self,
name: str,
) -> None:
return await self._server.delete_collection(
name=name,
tenant=self.tenant,
database=self.database,
)
#
# ITEM METHODS
#
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return await self._server._add(
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
tenant=self.tenant,
database=self.database,
)
@override
async def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return await self._server._update(
collection_id=collection_id,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
tenant=self.tenant,
database=self.database,
)
@override
async def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return await self._server._upsert(
collection_id=collection_id,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
tenant=self.tenant,
database=self.database,
)
@override
async def _count(self, collection_id: UUID) -> int:
return await self._server._count(
collection_id=collection_id,
)
@override
async def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
return await self._server._peek(
collection_id=collection_id,
n=n,
)
@override
async def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
) -> GetResult:
return await self._server._get(
collection_id=collection_id,
ids=ids,
where=where,
limit=limit,
offset=offset,
where_document=where_document,
include=include,
tenant=self.tenant,
database=self.database,
)
async def _delete(
self,
collection_id: UUID,
ids: Optional[IDs],
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
await self._server._delete(
collection_id=collection_id,
ids=ids,
where=where,
where_document=where_document,
tenant=self.tenant,
database=self.database,
)
@override
async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
) -> QueryResult:
return await self._server._query(
collection_id=collection_id,
query_embeddings=query_embeddings,
ids=ids,
n_results=n_results,
where=where,
where_document=where_document,
include=include,
tenant=self.tenant,
database=self.database,
)
@override
async def reset(self) -> bool:
return await self._server.reset()
@override
async def get_version(self) -> str:
return await self._server.get_version()
@override
def get_settings(self) -> Settings:
return self._server.get_settings()
@override
async def get_max_batch_size(self) -> int:
return await self._server.get_max_batch_size()
# endregion
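
# Illustrative sketch (not part of the upstream file): how an AsyncClient is
# obtained. Because from_system() raises NotImplementedError for the async
# client, construction goes through the awaitable create() factory. The
# Settings below are a placeholder; a real call would point them at a running
# Chroma server (host, port, async API implementation).
async def _example_create_async_client() -> None:
    client = await AsyncClient.create(
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
        settings=Settings(),  # placeholder settings (assumption)
    )
    print(await client.heartbeat())
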
class AsyncAdminClient(SharedSystemClient, AsyncAdminAPI):
_server: AsyncServerAPI
def __init__(self, settings: Settings = Settings()) -> None:
super().__init__(settings)
self._server = self._system.instance(AsyncServerAPI)
@override
async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return await self._server.create_database(name=name, tenant=tenant)
@override
async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
return await self._server.get_database(name=name, tenant=tenant)
@override
async def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return await self._server.delete_database(name=name, tenant=tenant)
@override
async def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
return await self._server.list_databases(
limit=limit, offset=offset, tenant=tenant
)
@override
async def create_tenant(self, name: str) -> None:
return await self._server.create_tenant(name=name)
@override
async def get_tenant(self, name: str) -> Tenant:
return await self._server.get_tenant(name=name)
@classmethod
@override
def from_system(
cls,
system: System,
) -> "AsyncAdminClient":
SharedSystemClient._populate_data_from_system(system)
instance = cls(settings=system.settings)
return instance

View File

@@ -0,0 +1,773 @@
import asyncio
from uuid import UUID
import urllib.parse
import orjson
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
import logging
import httpx
from overrides import override
from chromadb import __version__
from chromadb.auth import UserIdentity
from chromadb.api.async_api import AsyncServerAPI
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
create_collection_configuration_to_json,
update_collection_configuration_to_json,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, Settings
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.utils.async_to_sync import async_to_sync
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.execution.expression.plan import Search
from chromadb.api.types import (
Documents,
Embeddings,
IDs,
Include,
Schema,
Metadatas,
URIs,
Where,
WhereDocument,
GetResult,
QueryResult,
SearchResult,
CollectionMetadata,
optional_embeddings_to_base64_strings,
validate_batch,
convert_np_embeddings_to_list,
IncludeMetadataDocuments,
IncludeMetadataDocumentsDistances,
)
from chromadb.api.types import (
IncludeMetadataDocumentsEmbeddings,
serialize_metadata,
deserialize_metadata,
)
logger = logging.getLogger(__name__)
class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
# We make one client per event loop to avoid unexpected issues if a client
# is shared between event loops.
# For example, if a client is constructed in the main thread, then passed
# (or a returned Collection is passed) to a new thread, the client would
# normally throw an obscure asyncio error.
    # Mixing asyncio and threading in this manner is usually discouraged, but
    # this gives a better user experience with practically no downsides.
    # (A simplified standalone sketch of this per-loop caching pattern appears
    # at the end of this file.)
# https://github.com/encode/httpx/issues/2058
_clients: Dict[int, httpx.AsyncClient] = {}
def __init__(self, system: System):
super().__init__(system)
system.settings.require("chroma_server_host")
system.settings.require("chroma_server_http_port")
self._opentelemetry_client = self.require(OpenTelemetryClient)
self._product_telemetry_client = self.require(ProductTelemetryClient)
self._settings = system.settings
self._api_url = AsyncFastAPI.resolve_url(
chroma_server_host=str(system.settings.chroma_server_host),
chroma_server_http_port=system.settings.chroma_server_http_port,
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
default_api_path=system.settings.chroma_server_api_default_path,
)
async def __aenter__(self) -> "AsyncFastAPI":
self._get_client()
return self
async def _cleanup(self) -> None:
while len(self._clients) > 0:
(_, client) = self._clients.popitem()
await client.aclose()
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
await self._cleanup()
@override
def stop(self) -> None:
super().stop()
@async_to_sync
async def sync_cleanup() -> None:
await self._cleanup()
sync_cleanup()
def _get_client(self) -> httpx.AsyncClient:
# Ideally this would use anyio to be compatible with both
# asyncio and trio, but anyio does not expose any way to identify
# the current event loop.
# We attempt to get the loop assuming the environment is asyncio, and
# otherwise gracefully fall back to using a singleton client.
loop_hash = None
try:
loop = asyncio.get_event_loop()
loop_hash = loop.__hash__()
except RuntimeError:
loop_hash = 0
if loop_hash not in self._clients:
headers = (self._settings.chroma_server_headers or {}).copy()
headers["Content-Type"] = "application/json"
headers["User-Agent"] = (
"Chroma Python Client v"
+ __version__
+ " (https://github.com/chroma-core/chroma)"
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._clients[loop_hash] = httpx.AsyncClient(
timeout=None,
headers=headers,
verify=self._settings.chroma_server_ssl_verify or False,
limits=limits,
)
return self._clients[loop_hash]
async def _make_request(
self, method: str, path: str, **kwargs: Dict[str, Any]
) -> Any:
# If the request has json in kwargs, use orjson to serialize it,
# remove it from kwargs, and add it to the content parameter
# This is because httpx uses a slower json serializer
if "json" in kwargs:
data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
kwargs["content"] = data
# Unlike requests, httpx does not automatically escape the path
escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
url = self._api_url + escaped_path
response = await self._get_client().request(method, url, **cast(Any, kwargs))
BaseHTTPClient._raise_chroma_error(response)
return orjson.loads(response.text)
@trace_method("AsyncFastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
async def heartbeat(self) -> int:
response = await self._make_request("get", "")
return int(response["nanosecond heartbeat"])
@trace_method("AsyncFastAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
async def create_database(
self,
name: str,
tenant: str = DEFAULT_TENANT,
) -> None:
await self._make_request(
"post",
f"/tenants/{tenant}/databases",
json={"name": name},
)
@trace_method("AsyncFastAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
async def get_database(
self,
name: str,
tenant: str = DEFAULT_TENANT,
) -> Database:
response = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{name}",
params={"tenant": tenant},
)
return Database(
id=response["id"], name=response["name"], tenant=response["tenant"]
)
@trace_method("AsyncFastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
@override
async def delete_database(
self,
name: str,
tenant: str = DEFAULT_TENANT,
) -> None:
await self._make_request(
"delete",
f"/tenants/{tenant}/databases/{name}",
)
@trace_method("AsyncFastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
@override
async def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
response = await self._make_request(
"get",
f"/tenants/{tenant}/databases",
params=BaseHTTPClient._clean_params(
{
"limit": limit,
"offset": offset,
}
),
)
return [
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
for db in response
]
@trace_method("AsyncFastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
async def create_tenant(self, name: str) -> None:
await self._make_request(
"post",
"/tenants",
json={"name": name},
)
@trace_method("AsyncFastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
async def get_tenant(self, name: str) -> Tenant:
resp_json = await self._make_request(
"get",
"/tenants/" + name,
)
return Tenant(name=resp_json["name"])
@trace_method("AsyncFastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
@override
async def get_user_identity(self) -> UserIdentity:
return UserIdentity(**(await self._make_request("get", "/auth/identity")))
@trace_method("AsyncFastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
@override
async def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections",
params=BaseHTTPClient._clean_params(
{
"limit": limit,
"offset": offset,
}
),
)
models = [
CollectionModel.from_json(json_collection) for json_collection in resp_json
]
return models
@trace_method("AsyncFastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
@override
async def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections_count",
)
return cast(int, resp_json)
@trace_method("AsyncFastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
async def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
"""Creates a collection"""
config_json = (
create_collection_configuration_to_json(configuration, metadata)
if configuration
else None
)
serialized_schema = schema.serialize_to_json() if schema else None
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections",
json={
"name": name,
"metadata": metadata,
"configuration": config_json,
"schema": serialized_schema,
"get_or_create": get_or_create,
},
)
model = CollectionModel.from_json(resp_json)
return model
@trace_method("AsyncFastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
@override
async def get_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{name}",
)
model = CollectionModel.from_json(resp_json)
return model
@trace_method(
"AsyncFastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
)
@override
async def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
return await self.create_collection(
name=name,
schema=schema,
configuration=configuration,
metadata=metadata,
get_or_create=True,
tenant=tenant,
database=database,
)
@trace_method("AsyncFastAPI._modify", OpenTelemetryGranularity.OPERATION)
@override
async def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
await self._make_request(
"put",
f"/tenants/{tenant}/databases/{database}/collections/{id}",
json={
"new_metadata": new_metadata,
"new_name": new_name,
"new_configuration": update_collection_configuration_to_json(
new_configuration
)
if new_configuration
else None,
},
)
@trace_method("AsyncFastAPI._fork", OpenTelemetryGranularity.OPERATION)
@override
async def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
json={"new_name": new_name},
)
model = CollectionModel.from_json(resp_json)
return model
@trace_method("AsyncFastAPI._search", OpenTelemetryGranularity.OPERATION)
@override
async def _search(
self,
collection_id: UUID,
searches: List[Search],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> SearchResult:
"""Performs hybrid search on a collection"""
payload = {"searches": [s.to_dict() for s in searches]}
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
json=payload,
)
metadata_batches = resp_json.get("metadatas", None)
if metadata_batches is not None:
resp_json["metadatas"] = [
[
deserialize_metadata(metadata) if metadata is not None else None
for metadata in metadatas
]
if metadatas is not None
else None
for metadatas in metadata_batches
]
return SearchResult(resp_json)
@trace_method("AsyncFastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
async def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
await self._make_request(
"delete",
f"/tenants/{tenant}/databases/{database}/collections/{name}",
)
@trace_method("AsyncFastAPI._count", OpenTelemetryGranularity.OPERATION)
@override
async def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
"""Returns the number of embeddings in the database"""
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
)
return cast(int, resp_json)
@trace_method("AsyncFastAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
async def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
resp = await self._get(
collection_id,
tenant=tenant,
database=database,
limit=n,
include=IncludeMetadataDocumentsEmbeddings,
)
return resp
@trace_method("AsyncFastAPI._get", OpenTelemetryGranularity.OPERATION)
@override
async def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
# Servers do not support the "data" include, as that is hydrated on the client side
filtered_include = [i for i in include if i != "data"]
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
json={
"ids": ids,
"where": where,
"limit": limit,
"offset": offset,
"where_document": where_document,
"include": filtered_include,
},
)
metadatas = resp_json.get("metadatas", None)
if metadatas is not None:
metadatas = [
deserialize_metadata(metadata) if metadata is not None else None
for metadata in metadatas
]
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
included=include,
)
@trace_method("AsyncFastAPI._delete", OpenTelemetryGranularity.OPERATION)
@override
async def _delete(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
json={"where": where, "ids": ids, "where_document": where_document},
)
return None
@trace_method("AsyncFastAPI._submit_batch", OpenTelemetryGranularity.ALL)
async def _submit_batch(
self,
batch: Tuple[
IDs,
Optional[Embeddings],
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
],
url: str,
) -> Any:
"""
Submits a batch of embeddings to the database
"""
supports_base64_encoding = await self.supports_base64_encoding()
serialized_metadatas = None
if batch[2] is not None:
serialized_metadatas = [
serialize_metadata(metadata) if metadata is not None else None
for metadata in batch[2]
]
data = {
"ids": batch[0],
"embeddings": optional_embeddings_to_base64_strings(batch[1])
if supports_base64_encoding
else batch[1],
"metadatas": serialized_metadatas,
"documents": batch[3],
"uris": batch[4],
}
return await self._make_request(
"post",
url,
json=data,
)
@trace_method("AsyncFastAPI._add", OpenTelemetryGranularity.ALL)
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
batch = (
ids,
embeddings,
metadatas,
documents,
uris,
)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(
batch,
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
)
return True
@trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL)
@override
async def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
batch = (
ids,
embeddings if embeddings is not None else None,
metadatas,
documents,
uris,
)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(
batch,
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
)
return True
@trace_method("AsyncFastAPI._upsert", OpenTelemetryGranularity.ALL)
@override
async def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
batch = (
ids,
embeddings,
metadatas,
documents,
uris,
)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(
batch,
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
)
return True
@trace_method("AsyncFastAPI._query", OpenTelemetryGranularity.ALL)
@override
async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
# Servers do not support the "data" include, as that is hydrated on the client side
filtered_include = [i for i in include if i != "data"]
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
json={
"ids": ids,
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
if query_embeddings is not None
else None,
"n_results": n_results,
"where": where,
"where_document": where_document,
"include": filtered_include,
},
)
metadata_batches = resp_json.get("metadatas", None)
if metadata_batches is not None:
metadata_batches = [
[
deserialize_metadata(metadata) if metadata is not None else None
for metadata in metadatas
]
if metadatas is not None
else None
for metadatas in metadata_batches
]
return QueryResult(
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
included=include,
)
@trace_method("AsyncFastAPI.reset", OpenTelemetryGranularity.ALL)
@override
async def reset(self) -> bool:
resp_json = await self._make_request("post", "/reset")
return cast(bool, resp_json)
@trace_method("AsyncFastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
async def get_version(self) -> str:
resp_json = await self._make_request("get", "/version")
return cast(str, resp_json)
@override
def get_settings(self) -> Settings:
return self._settings
@trace_method(
"AsyncFastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION
)
async def get_pre_flight_checks(self) -> Any:
if self.pre_flight_checks is None:
resp_json = await self._make_request("get", "/pre-flight-checks")
self.pre_flight_checks = resp_json
return self.pre_flight_checks
@trace_method(
"AsyncFastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
)
async def supports_base64_encoding(self) -> bool:
pre_flight_checks = await self.get_pre_flight_checks()
b64_encoding_enabled = cast(
bool, pre_flight_checks.get("supports_base64_encoding", False)
)
return b64_encoding_enabled
@trace_method("AsyncFastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
@override
async def get_max_batch_size(self) -> int:
pre_flight_checks = await self.get_pre_flight_checks()
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
return max_batch_size
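
# Illustrative sketch (assumption, not upstream code): the per-event-loop
# client caching used by _get_client() above, reduced to a standalone helper.
# Each asyncio event loop gets its own httpx.AsyncClient so a client created
# on one loop is never reused from another.
_example_clients: Dict[int, httpx.AsyncClient] = {}


def _example_client_for_current_loop() -> httpx.AsyncClient:
    try:
        loop_key = hash(asyncio.get_event_loop())
    except RuntimeError:
        # No running loop (e.g. called from synchronous code): fall back to a
        # single shared key.
        loop_key = 0
    if loop_key not in _example_clients:
        _example_clients[loop_key] = httpx.AsyncClient(timeout=None)
    return _example_clients[loop_key]
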

View File

@@ -0,0 +1,105 @@
from typing import Any, Dict, Optional, TypeVar
from urllib.parse import quote, urlparse, urlunparse
import logging
import orjson as json
import httpx
import chromadb.errors as errors
from chromadb.config import Settings
logger = logging.getLogger(__name__)
class BaseHTTPClient:
_settings: Settings
pre_flight_checks: Any = None
keepalive_secs: int = 40
@staticmethod
def _validate_host(host: str) -> None:
parsed = urlparse(host)
if "/" in host and parsed.scheme not in {"http", "https"}:
raise ValueError(
"Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
)
if "/" in host and (not host.startswith("http")):
raise ValueError(
"Invalid URL. "
"Seems that you are trying to pass URL as a host but without \
specifying the protocol. "
"Please add http:// or https:// to the host."
)
@staticmethod
def resolve_url(
chroma_server_host: str,
chroma_server_ssl_enabled: Optional[bool] = False,
default_api_path: Optional[str] = "",
chroma_server_http_port: Optional[int] = 8000,
) -> str:
_skip_port = False
_chroma_server_host = chroma_server_host
BaseHTTPClient._validate_host(_chroma_server_host)
if _chroma_server_host.startswith("http"):
logger.debug("Skipping port as the user is passing a full URL")
_skip_port = True
parsed = urlparse(_chroma_server_host)
scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
net_loc = parsed.netloc or parsed.hostname or chroma_server_host
port = (
":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
)
path = parsed.path or default_api_path
if not path or path == net_loc:
path = default_api_path if default_api_path else ""
if not path.endswith(default_api_path or ""):
path = path + default_api_path if default_api_path else ""
full_url = urlunparse(
(scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
)
return full_url
    # requests removes None values from the built query string, but httpx includes them as empty values
T = TypeVar("T", bound=Dict[Any, Any])
@staticmethod
def _clean_params(params: T) -> T:
"""Remove None values from provided dict."""
return {k: v for k, v in params.items() if v is not None} # type: ignore
@staticmethod
def _raise_chroma_error(resp: httpx.Response) -> None:
"""Raises an error if the response is not ok, using a ChromaError if possible."""
try:
resp.raise_for_status()
return
except httpx.HTTPStatusError:
pass
chroma_error = None
try:
body = json.loads(resp.text)
if "error" in body:
if body["error"] in errors.error_types:
chroma_error = errors.error_types[body["error"]](body["message"])
trace_id = resp.headers.get("chroma-trace-id")
if trace_id:
chroma_error.trace_id = trace_id
except BaseException:
pass
if chroma_error:
raise chroma_error
try:
resp.raise_for_status()
except httpx.HTTPStatusError:
trace_id = resp.headers.get("chroma-trace-id")
if trace_id:
raise Exception(f"{resp.text} (trace ID: {trace_id})")
raise (Exception(resp.text))
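
# Illustrative sketch (assumption, not upstream code): what resolve_url()
# produces for a bare hostname versus a full URL, based on the logic above.
def _example_resolve_url() -> None:
    # A bare hostname picks up a scheme and the default port:
    #   "http://localhost:8000"
    print(BaseHTTPClient.resolve_url("localhost"))
    # A full URL keeps its scheme and host and appends the API path:
    #   "https://example.com/api/v2"
    print(
        BaseHTTPClient.resolve_url(
            "https://example.com", default_api_path="/api/v2"
        )
    )
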

View File

@@ -0,0 +1,545 @@
from typing import Optional, Sequence
from uuid import UUID
from overrides import override
import httpx
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
validate_embedding_function_conflict_on_create,
validate_embedding_function_conflict_on_get,
)
from chromadb.api.shared_system_client import SharedSystemClient
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Documents,
Embeddable,
EmbeddingFunction,
Embeddings,
GetResult,
IDs,
Include,
Loadable,
Metadatas,
QueryResult,
Schema,
URIs,
IncludeMetadataDocuments,
IncludeMetadataDocumentsDistances,
DefaultEmbeddingFunction,
)
from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.errors import ChromaAuthError, ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument
class Client(SharedSystemClient, ClientAPI):
"""A client for Chroma. This is the main entrypoint for interacting with Chroma.
A client internally stores its tenant and database and proxies calls to a
Server API instance of Chroma. It treats the Server API and corresponding System
as a singleton, so multiple clients connecting to the same resource will share the
same API instance.
    Client implementations should implement their own API-caching strategies.
"""
tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE
_server: ServerAPI
# An internal admin client for verifying that databases and tenants exist
_admin_client: AdminAPI
# region Initialization
def __init__(
self,
tenant: Optional[str] = DEFAULT_TENANT,
database: Optional[str] = DEFAULT_DATABASE,
settings: Settings = Settings(),
) -> None:
super().__init__(settings=settings)
if tenant is not None:
self.tenant = tenant
if database is not None:
self.database = database
# Get the root system component we want to interact with
self._server = self._system.instance(ServerAPI)
user_identity = self.get_user_identity()
maybe_tenant, maybe_database = maybe_set_tenant_and_database(
user_identity,
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
user_provided_tenant=tenant,
user_provided_database=database,
)
# this should not happen unless types are invalidated
if maybe_tenant is None and tenant is None:
raise ChromaAuthError(
"Could not determine a tenant from the current authentication method. Please provide a tenant."
)
if maybe_database is None and database is None:
raise ChromaAuthError(
"Could not determine a database name from the current authentication method. Please provide a database name."
)
if maybe_tenant:
self.tenant = maybe_tenant
if maybe_database:
self.database = maybe_database
# Create an admin client for verifying that databases and tenants exist
self._admin_client = AdminClient.from_system(self._system)
self._validate_tenant_database(tenant=self.tenant, database=self.database)
self._submit_client_start_event()
@classmethod
@override
def from_system(
cls,
system: System,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "Client":
SharedSystemClient._populate_data_from_system(system)
instance = cls(tenant=tenant, database=database, settings=system.settings)
return instance
# endregion
@override
def get_user_identity(self) -> UserIdentity:
try:
return self._server.get_user_identity()
except httpx.ConnectError:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# Propagate ChromaErrors
except ChromaError as e:
raise e
except Exception as e:
raise ValueError(str(e))
# region BaseAPI Methods
# Note - we could do this in less verbose ways, but they break type checking
@override
def heartbeat(self) -> int:
return self._server.heartbeat()
@override
def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence[Collection]:
return [
Collection(client=self._server, model=model)
for model in self._server.list_collections(
limit, offset, tenant=self.tenant, database=self.database
)
]
@override
def count_collections(self) -> int:
return self._server.count_collections(
tenant=self.tenant, database=self.database
)
@override
def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> Collection:
if configuration is None:
configuration = {}
configuration_ef = configuration.get("embedding_function")
validate_embedding_function_conflict_on_create(
embedding_function, configuration_ef
)
# If ef provided in function params and collection config ef is None,
# set the collection config ef to the function params
if embedding_function is not None and configuration_ef is None:
configuration["embedding_function"] = embedding_function
model = self._server.create_collection(
name=name,
schema=schema,
metadata=metadata,
tenant=self.tenant,
database=self.database,
get_or_create=get_or_create,
configuration=configuration,
)
return Collection(
client=self._server,
model=model,
embedding_function=embedding_function,
data_loader=data_loader,
)
@override
def get_collection(
self,
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
model = self._server.get_collection(
name=name,
tenant=self.tenant,
database=self.database,
)
persisted_ef_config = model.configuration_json.get("embedding_function")
validate_embedding_function_conflict_on_get(
embedding_function, persisted_ef_config
)
return Collection(
client=self._server,
model=model,
embedding_function=embedding_function,
data_loader=data_loader,
)
@override
def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
if configuration is None:
configuration = {}
configuration_ef = configuration.get("embedding_function")
validate_embedding_function_conflict_on_create(
embedding_function, configuration_ef
)
if embedding_function is not None and configuration_ef is None:
configuration["embedding_function"] = embedding_function
model = self._server.get_or_create_collection(
name=name,
schema=schema,
metadata=metadata,
tenant=self.tenant,
database=self.database,
configuration=configuration,
)
persisted_ef_config = model.configuration_json.get("embedding_function")
validate_embedding_function_conflict_on_get(
embedding_function, persisted_ef_config
)
return Collection(
client=self._server,
model=model,
embedding_function=embedding_function,
data_loader=data_loader,
)
@override
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
) -> None:
return self._server._modify(
id=id,
tenant=self.tenant,
database=self.database,
new_name=new_name,
new_metadata=new_metadata,
new_configuration=new_configuration,
)
@override
def delete_collection(
self,
name: str,
) -> None:
return self._server.delete_collection(
name=name,
tenant=self.tenant,
database=self.database,
)
#
# ITEM METHODS
#
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return self._server._add(
ids=ids,
tenant=self.tenant,
database=self.database,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
@override
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return self._server._update(
collection_id=collection_id,
tenant=self.tenant,
database=self.database,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
@override
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
return self._server._upsert(
collection_id=collection_id,
tenant=self.tenant,
database=self.database,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
@override
def _count(self, collection_id: UUID) -> int:
return self._server._count(
collection_id=collection_id,
tenant=self.tenant,
database=self.database,
)
@override
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
return self._server._peek(
collection_id=collection_id,
n=n,
tenant=self.tenant,
database=self.database,
)
@override
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
) -> GetResult:
return self._server._get(
collection_id=collection_id,
tenant=self.tenant,
database=self.database,
ids=ids,
where=where,
limit=limit,
offset=offset,
where_document=where_document,
include=include,
)
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs],
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
self._server._delete(
collection_id=collection_id,
tenant=self.tenant,
database=self.database,
ids=ids,
where=where,
where_document=where_document,
)
@override
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
) -> QueryResult:
return self._server._query(
collection_id=collection_id,
ids=ids,
tenant=self.tenant,
database=self.database,
query_embeddings=query_embeddings,
n_results=n_results,
where=where,
where_document=where_document,
include=include,
)
@override
def reset(self) -> bool:
return self._server.reset()
@override
def get_version(self) -> str:
return self._server.get_version()
@override
def get_settings(self) -> Settings:
return self._server.get_settings()
@override
def get_max_batch_size(self) -> int:
return self._server.get_max_batch_size()
# endregion
# region ClientAPI Methods
@override
def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
self._validate_tenant_database(tenant=tenant, database=database)
self.tenant = tenant
self.database = database
@override
def set_database(self, database: str) -> None:
self._validate_tenant_database(tenant=self.tenant, database=database)
self.database = database
def _validate_tenant_database(self, tenant: str, database: str) -> None:
try:
self._admin_client.get_tenant(name=tenant)
except httpx.ConnectError:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# Propagate ChromaErrors
except ChromaError as e:
raise e
except Exception:
raise ValueError(
f"Could not connect to tenant {tenant}. Are you sure it exists?"
)
try:
self._admin_client.get_database(name=database, tenant=tenant)
except httpx.ConnectError:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# endregion
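
# Illustrative sketch (not part of the upstream file): the default construction
# path. Building a Client validates the tenant and database via the internal
# AdminClient and then proxies calls such as heartbeat() to the shared ServerAPI.
def _example_client_heartbeat() -> int:
    client = Client(settings=Settings())  # default, local settings (assumption)
    return client.heartbeat()
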
class AdminClient(SharedSystemClient, AdminAPI):
_server: ServerAPI
def __init__(self, settings: Settings = Settings()) -> None:
super().__init__(settings)
self._server = self._system.instance(ServerAPI)
@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return self._server.create_database(name=name, tenant=tenant)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
return self._server.get_database(name=name, tenant=tenant)
@override
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return self._server.delete_database(name=name, tenant=tenant)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
return self._server.list_databases(limit, offset, tenant=tenant)
@override
def create_tenant(self, name: str) -> None:
return self._server.create_tenant(name=name)
@override
def get_tenant(self, name: str) -> Tenant:
return self._server.get_tenant(name=name)
@classmethod
@override
def from_system(
cls,
system: System,
) -> "AdminClient":
SharedSystemClient._populate_data_from_system(system)
instance = cls(settings=system.settings)
return instance

View File

@@ -0,0 +1,882 @@
from typing import TypedDict, Dict, Any, Optional, cast, get_args
import json
from chromadb.api.types import (
Space,
CollectionMetadata,
UpdateMetadata,
EmbeddingFunction,
)
from chromadb.utils.embedding_functions import (
known_embedding_functions,
register_embedding_function,
)
from multiprocessing import cpu_count
import warnings
from chromadb.api.types import Schema
class HNSWConfiguration(TypedDict, total=False):
space: Space
ef_construction: int
max_neighbors: int
ef_search: int
num_threads: int
batch_size: int
sync_threshold: int
resize_factor: float
class SpannConfiguration(TypedDict, total=False):
search_nprobe: int
write_nprobe: int
space: Space
ef_construction: int
ef_search: int
max_neighbors: int
reassign_neighbor_count: int
split_threshold: int
merge_threshold: int
class CollectionConfiguration(TypedDict, total=True):
hnsw: Optional[HNSWConfiguration]
spann: Optional[SpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
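
# Illustrative sketch (assumption, not upstream code): a CollectionConfiguration
# as a plain dict using the fields defined above. At most one of "hnsw"/"spann"
# is expected to be set.
_example_collection_configuration: CollectionConfiguration = {
    "hnsw": {"space": "cosine", "ef_construction": 200, "max_neighbors": 32},
    "spann": None,
    "embedding_function": None,
}
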
def load_collection_configuration_from_json_str(
config_json_str: str,
) -> CollectionConfiguration:
config_json_map = json.loads(config_json_str)
return load_collection_configuration_from_json(config_json_map)
# TODO: make warnings prettier and add link to migration docs
def load_collection_configuration_from_json(
config_json_map: Dict[str, Any]
) -> CollectionConfiguration:
if (
config_json_map.get("spann") is not None
and config_json_map.get("hnsw") is not None
):
raise ValueError("hnsw and spann cannot both be provided")
hnsw_config = None
spann_config = None
ef_config = None
# Process vector index configuration (HNSW or SPANN)
if config_json_map.get("hnsw") is not None:
hnsw_config = cast(HNSWConfiguration, config_json_map["hnsw"])
if config_json_map.get("spann") is not None:
spann_config = cast(SpannConfiguration, config_json_map["spann"])
# Process embedding function configuration
if config_json_map.get("embedding_function") is not None:
ef_config = config_json_map["embedding_function"]
if ef_config["type"] == "legacy":
warnings.warn(
"legacy embedding function config",
DeprecationWarning,
stacklevel=2,
)
ef = None
else:
try:
ef_name = ef_config["name"]
except KeyError:
raise ValueError(
f"Embedding function name not found in config: {ef_config}"
)
try:
ef = known_embedding_functions[ef_name]
except KeyError:
raise ValueError(
f"Embedding function {ef_name} not found. Add @register_embedding_function decorator to the class definition."
)
try:
ef = ef.build_from_config(ef_config["config"]) # type: ignore
except Exception as e:
raise ValueError(
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
)
else:
ef = None
return CollectionConfiguration(
hnsw=hnsw_config,
spann=spann_config,
embedding_function=ef, # type: ignore
)
def collection_configuration_to_json_str(config: CollectionConfiguration) -> str:
return json.dumps(collection_configuration_to_json(config))
def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[str, Any]:
if isinstance(config, dict):
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
ef = config.get("embedding_function")
else:
try:
hnsw_config = config.get_parameter("hnsw").value
except ValueError:
hnsw_config = None
try:
spann_config = config.get_parameter("spann").value
except ValueError:
spann_config = None
try:
ef = config.get_parameter("embedding_function").value
except ValueError:
ef = None
ef_config: Dict[str, Any] | None = None
if hnsw_config is not None:
try:
hnsw_config = cast(HNSWConfiguration, hnsw_config)
except Exception as e:
raise ValueError(f"not a valid hnsw config: {e}")
if spann_config is not None:
try:
spann_config = cast(SpannConfiguration, spann_config)
except Exception as e:
raise ValueError(f"not a valid spann config: {e}")
if ef is None:
ef = None
ef_config = {"type": "legacy"}
if ef is not None:
try:
if ef.is_legacy():
ef_config = {"type": "legacy"}
else:
ef_config = {
"name": ef.name(),
"type": "known",
"config": ef.get_config(),
}
register_embedding_function(type(ef)) # type: ignore
except Exception as e:
warnings.warn(
f"legacy embedding function config: {e}",
DeprecationWarning,
stacklevel=2,
)
ef = None
ef_config = {"type": "legacy"}
return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
}
class CreateHNSWConfiguration(TypedDict, total=False):
space: Space
ef_construction: int
max_neighbors: int
ef_search: int
num_threads: int
batch_size: int
sync_threshold: int
resize_factor: float
def json_to_create_hnsw_configuration(
json_map: Dict[str, Any]
) -> CreateHNSWConfiguration:
config: CreateHNSWConfiguration = {}
if "space" in json_map:
space_value = json_map["space"]
if space_value in get_args(Space):
config["space"] = space_value
else:
raise ValueError(f"not a valid space: {space_value}")
if "ef_construction" in json_map:
config["ef_construction"] = json_map["ef_construction"]
if "max_neighbors" in json_map:
config["max_neighbors"] = json_map["max_neighbors"]
if "ef_search" in json_map:
config["ef_search"] = json_map["ef_search"]
if "num_threads" in json_map:
config["num_threads"] = json_map["num_threads"]
if "batch_size" in json_map:
config["batch_size"] = json_map["batch_size"]
if "sync_threshold" in json_map:
config["sync_threshold"] = json_map["sync_threshold"]
if "resize_factor" in json_map:
config["resize_factor"] = json_map["resize_factor"]
return config
class CreateSpannConfiguration(TypedDict, total=False):
search_nprobe: int
write_nprobe: int
space: Space
ef_construction: int
ef_search: int
max_neighbors: int
reassign_neighbor_count: int
split_threshold: int
merge_threshold: int
def json_to_create_spann_configuration(
json_map: Dict[str, Any]
) -> CreateSpannConfiguration:
config: CreateSpannConfiguration = {}
if "search_nprobe" in json_map:
config["search_nprobe"] = json_map["search_nprobe"]
if "write_nprobe" in json_map:
config["write_nprobe"] = json_map["write_nprobe"]
if "space" in json_map:
space_value = json_map["space"]
if space_value in get_args(Space):
config["space"] = space_value
else:
raise ValueError(f"not a valid space: {space_value}")
if "ef_construction" in json_map:
config["ef_construction"] = json_map["ef_construction"]
if "ef_search" in json_map:
config["ef_search"] = json_map["ef_search"]
if "max_neighbors" in json_map:
config["max_neighbors"] = json_map["max_neighbors"]
return config
class CreateCollectionConfiguration(TypedDict, total=False):
hnsw: Optional[CreateHNSWConfiguration]
spann: Optional[CreateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
def create_collection_configuration_from_legacy_collection_metadata(
metadata: CollectionMetadata,
) -> CreateCollectionConfiguration:
"""Create a CreateCollectionConfiguration from legacy collection metadata"""
return create_collection_configuration_from_legacy_metadata_dict(metadata)
def create_collection_configuration_from_legacy_metadata_dict(
metadata: Dict[str, Any],
) -> CreateCollectionConfiguration:
"""Create a CreateCollectionConfiguration from legacy collection metadata"""
old_to_new = {
"hnsw:space": "space",
"hnsw:construction_ef": "ef_construction",
"hnsw:M": "max_neighbors",
"hnsw:search_ef": "ef_search",
"hnsw:num_threads": "num_threads",
"hnsw:batch_size": "batch_size",
"hnsw:sync_threshold": "sync_threshold",
"hnsw:resize_factor": "resize_factor",
}
json_map = {}
for name, value in metadata.items():
if name in old_to_new:
json_map[old_to_new[name]] = value
hnsw_config = json_to_create_hnsw_configuration(json_map)
hnsw_config = populate_create_hnsw_defaults(hnsw_config)
return CreateCollectionConfiguration(hnsw=hnsw_config)
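
# Illustrative sketch (assumption, not upstream code): how legacy "hnsw:*"
# metadata keys map onto the new configuration via the old_to_new table above.
# {"hnsw:space": "cosine", "hnsw:construction_ef": 200} becomes a configuration
# whose hnsw block has space="cosine", ef_construction=200 and the remaining
# HNSW fields filled with defaults.
def _example_legacy_metadata_conversion() -> CreateCollectionConfiguration:
    return create_collection_configuration_from_legacy_collection_metadata(
        {"hnsw:space": "cosine", "hnsw:construction_ef": 200}
    )
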
# TODO: make warnings prettier and add link to migration docs
def load_create_collection_configuration_from_json(
json_map: Dict[str, Any]
) -> CreateCollectionConfiguration:
if json_map.get("hnsw") is not None and json_map.get("spann") is not None:
raise ValueError("hnsw and spann cannot both be provided")
result = CreateCollectionConfiguration()
# Handle vector index configuration
if json_map.get("hnsw") is not None:
result["hnsw"] = json_to_create_hnsw_configuration(json_map["hnsw"])
if json_map.get("spann") is not None:
result["spann"] = json_to_create_spann_configuration(json_map["spann"])
# Handle embedding function configuration
if json_map.get("embedding_function") is not None:
ef_config = json_map["embedding_function"]
if ef_config["type"] == "legacy":
warnings.warn(
"legacy embedding function config",
DeprecationWarning,
stacklevel=2,
)
else:
ef = known_embedding_functions[ef_config["name"]]
result["embedding_function"] = ef.build_from_config(ef_config["config"])
return result
def create_collection_configuration_to_json_str(
config: CreateCollectionConfiguration,
metadata: Optional[CollectionMetadata] = None,
) -> str:
"""Convert a CreateCollection configuration to a JSON-serializable string"""
return json.dumps(create_collection_configuration_to_json(config, metadata))
# TODO: make warnings prettier and add link to migration docs
def create_collection_configuration_to_json(
config: CreateCollectionConfiguration,
metadata: Optional[CollectionMetadata] = None,
) -> Dict[str, Any]:
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
ef_config: Dict[str, Any] | None = None
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
if hnsw_config is not None:
try:
hnsw_config = cast(CreateHNSWConfiguration, hnsw_config)
except Exception as e:
raise ValueError(f"not a valid hnsw config: {e}")
if spann_config is not None:
try:
spann_config = cast(CreateSpannConfiguration, spann_config)
except Exception as e:
raise ValueError(f"not a valid spann config: {e}")
if hnsw_config is not None and spann_config is not None:
raise ValueError("hnsw and spann cannot both be provided")
if config.get("embedding_function") is None:
ef = None
ef_config = {"type": "legacy"}
return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
}
try:
ef = cast(EmbeddingFunction, config.get("embedding_function")) # type: ignore
if ef.is_legacy():
ef_config = {"type": "legacy"}
else:
            # Default space logic: if neither an hnsw nor a spann config is provided
            # and the metadata does not specify a space, populate the space from the
            # embedding function; otherwise keep the user-provided space. Afterwards,
            # validate the space against the embedding function's supported spaces
            # and warn if it is not supported.
if hnsw_config is None and spann_config is None:
if metadata is None or metadata.get("hnsw:space") is None:
# this populates space from ef if not provided in either config
hnsw_config = CreateHNSWConfiguration(space=ef.default_space())
# if hnsw config or spann config exists but space is not provided, populate it from ef
if hnsw_config is not None and hnsw_config.get("space") is None:
hnsw_config["space"] = ef.default_space()
if spann_config is not None and spann_config.get("space") is None:
spann_config["space"] = ef.default_space()
# Validate space compatibility with embedding function
if hnsw_config is not None:
if hnsw_config.get("space") not in ef.supported_spaces():
warnings.warn(
f"space {hnsw_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
UserWarning,
stacklevel=2,
)
if spann_config is not None:
if spann_config.get("space") not in ef.supported_spaces():
warnings.warn(
f"space {spann_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
UserWarning,
stacklevel=2,
)
# only validate space from metadata if config is not provided
if (
hnsw_config is None
and spann_config is None
and metadata is not None
and metadata.get("hnsw:space") is not None
):
if metadata.get("hnsw:space") not in ef.supported_spaces():
warnings.warn(
f"space {metadata.get('hnsw:space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
UserWarning,
stacklevel=2,
)
ef_config = {
"name": ef.name(),
"type": "known",
"config": ef.get_config(),
}
register_embedding_function(type(ef)) # type: ignore
except Exception as e:
warnings.warn(
f"legacy embedding function config: {e}",
DeprecationWarning,
stacklevel=2,
)
ef = None
ef_config = {"type": "legacy"}
return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
}
def populate_create_hnsw_defaults(
config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction] = None # type: ignore
) -> CreateHNSWConfiguration:
"""Populate a CreateHNSW configuration with default values"""
if config.get("space") is None:
config["space"] = ef.default_space() if ef else "l2"
if config.get("ef_construction") is None:
config["ef_construction"] = 100
if config.get("max_neighbors") is None:
config["max_neighbors"] = 16
if config.get("ef_search") is None:
config["ef_search"] = 100
if config.get("num_threads") is None:
config["num_threads"] = cpu_count()
if config.get("batch_size") is None:
config["batch_size"] = 100
if config.get("sync_threshold") is None:
config["sync_threshold"] = 1000
if config.get("resize_factor") is None:
config["resize_factor"] = 1.2
return config
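# Illustrative sketch (not part of the original module): calling the helper above
# on an empty config fills in the library defaults, assuming CreateHNSWConfiguration
# is a total=False TypedDict so an empty dict is a valid starting point.
#
#   cfg = populate_create_hnsw_defaults(CreateHNSWConfiguration())
#   # cfg["space"] == "l2", cfg["ef_construction"] == 100, cfg["max_neighbors"] == 16,
#   # cfg["ef_search"] == 100, cfg["batch_size"] == 100, cfg["sync_threshold"] == 1000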
class UpdateHNSWConfiguration(TypedDict, total=False):
ef_search: int
num_threads: int
batch_size: int
sync_threshold: int
resize_factor: float
def json_to_update_hnsw_configuration(
json_map: Dict[str, Any]
) -> UpdateHNSWConfiguration:
config: UpdateHNSWConfiguration = {}
if "ef_search" in json_map:
config["ef_search"] = json_map["ef_search"]
if "num_threads" in json_map:
config["num_threads"] = json_map["num_threads"]
if "batch_size" in json_map:
config["batch_size"] = json_map["batch_size"]
if "sync_threshold" in json_map:
config["sync_threshold"] = json_map["sync_threshold"]
if "resize_factor" in json_map:
config["resize_factor"] = json_map["resize_factor"]
return config
class UpdateSpannConfiguration(TypedDict, total=False):
search_nprobe: int
ef_search: int
def json_to_update_spann_configuration(
json_map: Dict[str, Any]
) -> UpdateSpannConfiguration:
config: UpdateSpannConfiguration = {}
if "search_nprobe" in json_map:
config["search_nprobe"] = json_map["search_nprobe"]
if "ef_search" in json_map:
config["ef_search"] = json_map["ef_search"]
return config
class UpdateCollectionConfiguration(TypedDict, total=False):
hnsw: Optional[UpdateHNSWConfiguration]
spann: Optional[UpdateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
def update_collection_configuration_from_legacy_collection_metadata(
metadata: CollectionMetadata,
) -> UpdateCollectionConfiguration:
"""Create an UpdateCollectionConfiguration from legacy collection metadata"""
old_to_new = {
"hnsw:search_ef": "ef_search",
"hnsw:num_threads": "num_threads",
"hnsw:batch_size": "batch_size",
"hnsw:sync_threshold": "sync_threshold",
"hnsw:resize_factor": "resize_factor",
}
json_map = {}
for name, value in metadata.items():
if name in old_to_new:
json_map[old_to_new[name]] = value
hnsw_config = json_to_update_hnsw_configuration(json_map)
return UpdateCollectionConfiguration(hnsw=hnsw_config)
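# Illustrative sketch: mapping legacy "hnsw:*" metadata keys onto the new update
# configuration shape; the metadata values below are placeholders.
#
#   legacy = {"hnsw:search_ef": 200, "hnsw:num_threads": 4}
#   cfg = update_collection_configuration_from_legacy_collection_metadata(legacy)
#   # cfg == {"hnsw": {"ef_search": 200, "num_threads": 4}}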
def update_collection_configuration_from_legacy_update_metadata(
metadata: UpdateMetadata,
) -> UpdateCollectionConfiguration:
"""Create an UpdateCollectionConfiguration from legacy update metadata"""
old_to_new = {
"hnsw:search_ef": "ef_search",
"hnsw:num_threads": "num_threads",
"hnsw:batch_size": "batch_size",
"hnsw:sync_threshold": "sync_threshold",
"hnsw:resize_factor": "resize_factor",
}
json_map = {}
for name, value in metadata.items():
if name in old_to_new:
json_map[old_to_new[name]] = value
hnsw_config = json_to_update_hnsw_configuration(json_map)
return UpdateCollectionConfiguration(hnsw=hnsw_config)
def update_collection_configuration_to_json_str(
config: UpdateCollectionConfiguration,
) -> str:
"""Convert an UpdateCollectionConfiguration to a JSON-serializable string"""
json_dict = update_collection_configuration_to_json(config)
return json.dumps(json_dict)
def update_collection_configuration_to_json(
config: UpdateCollectionConfiguration,
) -> Dict[str, Any]:
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
hnsw_config = config.get("hnsw")
spann_config = config.get("spann")
ef = config.get("embedding_function")
if hnsw_config is None and spann_config is None and ef is None:
return {}
if hnsw_config is not None:
try:
hnsw_config = cast(UpdateHNSWConfiguration, hnsw_config)
except Exception as e:
raise ValueError(f"not a valid hnsw config: {e}")
if spann_config is not None:
try:
spann_config = cast(UpdateSpannConfiguration, spann_config)
except Exception as e:
raise ValueError(f"not a valid spann config: {e}")
ef_config: Dict[str, Any] | None = None
if ef is not None:
if ef.is_legacy():
ef_config = {"type": "legacy"}
else:
ef.validate_config(ef.get_config())
ef_config = {
"name": ef.name(),
"type": "known",
"config": ef.get_config(),
}
register_embedding_function(type(ef)) # type: ignore
else:
ef_config = None
return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
}
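# Illustrative sketch: serializing an update that only changes ef_search. Keys that
# are not being updated are emitted as None.
#
#   update_collection_configuration_to_json({"hnsw": {"ef_search": 50}})
#   # -> {"hnsw": {"ef_search": 50}, "spann": None, "embedding_function": None}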
def load_update_collection_configuration_from_json_str(
json_str: str,
) -> UpdateCollectionConfiguration:
json_map = json.loads(json_str)
return load_update_collection_configuration_from_json(json_map)
# TODO: make warnings prettier and add link to migration docs
def load_update_collection_configuration_from_json(
json_map: Dict[str, Any]
) -> UpdateCollectionConfiguration:
"""Convert a JSON dict to an UpdateCollectionConfiguration"""
if json_map.get("hnsw") is not None and json_map.get("spann") is not None:
raise ValueError("hnsw and spann cannot both be provided")
result = UpdateCollectionConfiguration()
# Handle vector index configurations
if json_map.get("hnsw") is not None:
result["hnsw"] = json_to_update_hnsw_configuration(json_map["hnsw"])
if json_map.get("spann") is not None:
result["spann"] = json_to_update_spann_configuration(json_map["spann"])
# Handle embedding function
if json_map.get("embedding_function") is not None:
if json_map["embedding_function"]["type"] == "legacy":
warnings.warn(
"legacy embedding function config",
DeprecationWarning,
stacklevel=2,
)
else:
ef = known_embedding_functions[json_map["embedding_function"]["name"]]
result["embedding_function"] = ef.build_from_config(
json_map["embedding_function"]["config"]
)
return result
def overwrite_hnsw_configuration(
existing_hnsw_config: HNSWConfiguration, update_hnsw_config: UpdateHNSWConfiguration
) -> HNSWConfiguration:
"""Overwrite a HNSWConfiguration with a new configuration"""
# Create a copy of the existing config and update with new values
result = dict(existing_hnsw_config)
update_fields = [
"ef_search",
"num_threads",
"batch_size",
"sync_threshold",
"resize_factor",
]
for field in update_fields:
if field in update_hnsw_config:
result[field] = update_hnsw_config[field] # type: ignore
return cast(HNSWConfiguration, result)
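# Illustrative sketch: only the mutable fields listed above are copied from the
# update onto the existing configuration; static fields such as space and
# ef_construction are left untouched.
#
#   existing = {"space": "cosine", "ef_construction": 100, "ef_search": 100}
#   merged = overwrite_hnsw_configuration(existing, {"ef_search": 250})
#   # merged == {"space": "cosine", "ef_construction": 100, "ef_search": 250}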
def overwrite_spann_configuration(
existing_spann_config: SpannConfiguration,
update_spann_config: UpdateSpannConfiguration,
) -> SpannConfiguration:
"""Overwrite a SpannConfiguration with a new configuration"""
result = dict(existing_spann_config)
update_fields = [
"search_nprobe",
"ef_search",
]
for field in update_fields:
if field in update_spann_config:
result[field] = update_spann_config[field] # type: ignore
return cast(SpannConfiguration, result)
# TODO: make warnings prettier and add link to migration docs
def overwrite_embedding_function(
existing_embedding_function: EmbeddingFunction, # type: ignore
update_embedding_function: EmbeddingFunction, # type: ignore
) -> EmbeddingFunction: # type: ignore
"""Overwrite an EmbeddingFunction with a new configuration"""
# Check for legacy embedding functions
if existing_embedding_function.is_legacy() or update_embedding_function.is_legacy():
warnings.warn(
"cannot update legacy embedding function config",
DeprecationWarning,
stacklevel=2,
)
return existing_embedding_function
# Validate function compatibility
if existing_embedding_function.name() != update_embedding_function.name():
raise ValueError(
f"Cannot update embedding function: incompatible types "
f"({existing_embedding_function.name()} vs {update_embedding_function.name()})"
)
# Validate and apply the configuration update
update_embedding_function.validate_config_update(
existing_embedding_function.get_config(), update_embedding_function.get_config()
)
return update_embedding_function
def overwrite_collection_configuration(
existing_config: CollectionConfiguration,
update_config: UpdateCollectionConfiguration,
) -> CollectionConfiguration:
"""Overwrite a CollectionConfiguration with a new configuration"""
update_spann = update_config.get("spann")
update_hnsw = update_config.get("hnsw")
if update_spann is not None and update_hnsw is not None:
raise ValueError("hnsw and spann cannot both be provided")
# Handle HNSW configuration update
updated_hnsw_config = existing_config.get("hnsw")
if updated_hnsw_config is not None and update_hnsw is not None:
updated_hnsw_config = overwrite_hnsw_configuration(
updated_hnsw_config, update_hnsw
)
# Handle SPANN configuration update
updated_spann_config = existing_config.get("spann")
if updated_spann_config is not None and update_spann is not None:
updated_spann_config = overwrite_spann_configuration(
updated_spann_config, update_spann
)
# Handle embedding function update
updated_embedding_function = existing_config.get("embedding_function")
update_ef = update_config.get("embedding_function")
if update_ef is not None:
if updated_embedding_function is not None:
updated_embedding_function = overwrite_embedding_function(
updated_embedding_function, update_ef
)
else:
updated_embedding_function = update_ef
return CollectionConfiguration(
hnsw=updated_hnsw_config,
spann=updated_spann_config,
embedding_function=updated_embedding_function,
)
def validate_embedding_function_conflict_on_create(
embedding_function: Optional[EmbeddingFunction], # type: ignore
configuration_ef: Optional[EmbeddingFunction], # type: ignore
) -> None:
"""
Validates that there are no conflicting embedding functions between function parameter
and collection configuration.
Args:
embedding_function: The embedding function provided as a parameter
configuration_ef: The embedding function from collection configuration
Raises:
ValueError: If the embedding function parameter and the configuration specify conflicting embedding functions
"""
# If ef provided in function params and collection config, check if they are the same
# If not, there's a conflict
# ef is by default "default" if not provided, so ignore that case.
if embedding_function is not None and configuration_ef is not None:
if (
embedding_function.name() != "default"
and embedding_function.name() != configuration_ef.name()
):
raise ValueError(
f"Multiple embedding functions provided. Please provide only one. Embedding function conflict: {embedding_function.name()} vs {configuration_ef.name()}"
)
return None
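# Illustrative sketch (hypothetical embedding functions): passing two different,
# non-default embedding functions raises; passing the default alongside a
# configured one is allowed.
#
#   validate_embedding_function_conflict_on_create(openai_ef, cohere_ef)
#   # -> ValueError: Multiple embedding functions provided. ...
#   validate_embedding_function_conflict_on_create(default_ef, cohere_ef)  # OK if default_ef.name() == "default"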
# The reason to use the config on get, rather than building the ef, is that if
# there is an issue deserializing the config, an error shouldn't be raised at get time.
# CollectionCommon.py will raise an error at _embed time if deserialization fails.
def validate_embedding_function_conflict_on_get(
embedding_function: Optional[EmbeddingFunction], # type: ignore
persisted_ef_config: Optional[Dict[str, Any]],
) -> None:
"""
Validates that there are no conflicting embedding functions between function parameter
and collection configuration.
"""
if persisted_ef_config is not None and embedding_function is not None:
if (
embedding_function.name() != "default"
and persisted_ef_config.get("name") is not None
and persisted_ef_config.get("name") != embedding_function.name()
):
raise ValueError(
f"An embedding function already exists in the collection configuration, and a new one is provided. If this is intentional, please embed documents separately. Embedding function conflict: new: {embedding_function.name()} vs persisted: {persisted_ef_config.get('name')}"
)
return None
def update_schema_from_collection_configuration(
schema: "Schema", configuration: "UpdateCollectionConfiguration"
) -> "Schema":
"""
Updates a schema with configuration changes.
Only updates fields that are present in the configuration update.
Args:
schema: The existing Schema object
configuration: The configuration updates to apply
Returns:
Updated Schema object
"""
# Get the vector index from defaults and #embedding key
if (
schema.defaults.float_list is None
or schema.defaults.float_list.vector_index is None
):
raise ValueError("Schema is missing defaults.float_list.vector_index")
embedding_key = "#embedding"
if embedding_key not in schema.keys:
raise ValueError(f"Schema is missing keys[{embedding_key}]")
embedding_value_types = schema.keys[embedding_key]
if (
embedding_value_types.float_list is None
or embedding_value_types.float_list.vector_index is None
):
raise ValueError(
f"Schema is missing keys[{embedding_key}].float_list.vector_index"
)
# Update vector index config in both locations
for vector_index in [
schema.defaults.float_list.vector_index,
embedding_value_types.float_list.vector_index,
]:
if "hnsw" in configuration and configuration["hnsw"] is not None:
# Update HNSW config
if vector_index.config.hnsw is None:
raise ValueError("Trying to update HNSW config but schema has SPANN")
hnsw_config = vector_index.config.hnsw
update_hnsw = configuration["hnsw"]
# Only update fields that are present in the update
if "ef_search" in update_hnsw:
hnsw_config.ef_search = update_hnsw["ef_search"]
if "num_threads" in update_hnsw:
hnsw_config.num_threads = update_hnsw["num_threads"]
if "batch_size" in update_hnsw:
hnsw_config.batch_size = update_hnsw["batch_size"]
if "sync_threshold" in update_hnsw:
hnsw_config.sync_threshold = update_hnsw["sync_threshold"]
if "resize_factor" in update_hnsw:
hnsw_config.resize_factor = update_hnsw["resize_factor"]
elif "spann" in configuration and configuration["spann"] is not None:
# Update SPANN config
if vector_index.config.spann is None:
raise ValueError("Trying to update SPANN config but schema has HNSW")
spann_config = vector_index.config.spann
update_spann = configuration["spann"]
# Only update fields that are present in the update
if "search_nprobe" in update_spann:
spann_config.search_nprobe = update_spann["search_nprobe"]
if "ef_search" in update_spann:
spann_config.ef_search = update_spann["ef_search"]
# Update embedding function if present
if (
"embedding_function" in configuration
and configuration["embedding_function"] is not None
):
vector_index.config.embedding_function = configuration["embedding_function"]
return schema

View File

@@ -0,0 +1,410 @@
from abc import abstractmethod
import json
from overrides import override
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
Protocol,
Union,
TypeVar,
cast,
)
from typing_extensions import Self
from multiprocessing import cpu_count
from chromadb.serde import JSONSerializable
# TODO: move out of API
class StaticParameterError(Exception):
"""Represents an error that occurs when a static parameter is set."""
pass
class InvalidConfigurationError(ValueError):
"""Represents an error that occurs when a configuration is invalid."""
pass
ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"]
class ParameterValidator(Protocol):
"""Represents an abstract parameter validator."""
@abstractmethod
def __call__(self, value: ParameterValue) -> bool:
"""Returns whether the given value is valid."""
raise NotImplementedError()
class ConfigurationDefinition:
"""Represents the definition of a configuration."""
name: str
validator: ParameterValidator
is_static: bool
default_value: ParameterValue
def __init__(
self,
name: str,
validator: ParameterValidator,
is_static: bool,
default_value: ParameterValue,
):
self.name = name
self.validator = validator
self.is_static = is_static
self.default_value = default_value
class ConfigurationParameter:
"""Represents a parameter of a configuration."""
name: str
value: ParameterValue
def __init__(self, name: str, value: ParameterValue):
self.name = name
self.value = value
def __repr__(self) -> str:
return f"ConfigurationParameter({self.name}, {self.value})"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, ConfigurationParameter):
return NotImplemented
return self.name == __value.name and self.value == __value.value
T = TypeVar("T", bound="ConfigurationInternal")
class ConfigurationInternal(JSONSerializable["ConfigurationInternal"]):
"""Represents an abstract configuration, used internally by Chroma."""
# The internal data structure used to store the parameters
# All expected parameters must be present with defaults or None values at initialization
parameter_map: Dict[str, ConfigurationParameter]
definitions: ClassVar[Dict[str, ConfigurationDefinition]]
def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
"""Initializes a new instance of the Configuration class. Respecting defaults and
validators."""
self.parameter_map = {}
if parameters is not None:
for parameter in parameters:
if parameter.name not in self.definitions:
raise ValueError(f"Invalid parameter name: {parameter.name}")
definition = self.definitions[parameter.name]
# Handle the case where we have a recursive configuration definition
if isinstance(parameter.value, dict):
child_type = globals().get(parameter.value.get("_type", None))
if child_type is None:
raise ValueError(
f"Invalid configuration type: {parameter.value}"
)
parameter.value = child_type.from_json(parameter.value)
if not isinstance(parameter.value, type(definition.default_value)):
raise ValueError(f"Invalid parameter value: {parameter.value}")
parameter_validator = definition.validator
if not parameter_validator(parameter.value):
raise ValueError(f"Invalid parameter value: {parameter.value}")
self.parameter_map[parameter.name] = parameter
# Apply the defaults for any missing parameters
for name, definition in self.definitions.items():
if name not in self.parameter_map:
self.parameter_map[name] = ConfigurationParameter(
name=name, value=definition.default_value
)
self.configuration_validator()
def __repr__(self) -> str:
return f"Configuration({self.parameter_map.values()})"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, ConfigurationInternal):
return NotImplemented
return self.parameter_map == __value.parameter_map
@abstractmethod
def configuration_validator(self) -> None:
"""Perform custom validation when parameters are dependent on each other.
Raises an InvalidConfigurationError if the configuration is invalid.
"""
pass
def get_parameters(self) -> List[ConfigurationParameter]:
"""Returns the parameters of the configuration."""
return list(self.parameter_map.values())
def get_parameter(self, name: str) -> ConfigurationParameter:
"""Returns the parameter with the given name, or except if it doesn't exist."""
if name not in self.parameter_map:
raise ValueError(
f"Invalid parameter name: {name} for configuration {self.__class__.__name__}"
)
param_value = cast(ConfigurationParameter, self.parameter_map.get(name))
return param_value
def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None:
"""Sets the parameter with the given name to the given value."""
if name not in self.definitions:
raise ValueError(f"Invalid parameter name: {name}")
definition = self.definitions[name]
parameter = self.parameter_map[name]
if definition.is_static:
raise StaticParameterError(f"Cannot set static parameter: {name}")
if not definition.validator(value):
raise ValueError(f"Invalid value for parameter {name}: {value}")
parameter.value = value
@override
def to_json_str(self) -> str:
"""Returns the JSON representation of the configuration."""
return json.dumps(self.to_json())
@classmethod
@override
def from_json_str(cls, json_str: str) -> Self:
"""Returns a configuration from the given JSON string."""
try:
config_json = json.loads(json_str)
except json.JSONDecodeError:
raise ValueError(
f"Unable to decode configuration from JSON string: {json_str}"
)
return cls.from_json(config_json) if config_json else cls()
@override
def to_json(self) -> Dict[str, Any]:
"""Returns the JSON compatible dictionary representation of the configuration."""
json_dict = {
name: parameter.value.to_json()
if isinstance(parameter.value, ConfigurationInternal)
else parameter.value
for name, parameter in self.parameter_map.items()
}
# What kind of configuration is this?
json_dict["_type"] = self.__class__.__name__
return json_dict
@classmethod
@override
def from_json(cls, json_map: Dict[str, Any]) -> Self:
"""Returns a configuration from the given JSON string."""
if cls.__name__ != json_map.get("_type", None):
raise ValueError(
f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {json_map['_type']}"
)
parameters = []
for name, value in json_map.items():
# Type value is only for storage
if name == "_type":
continue
parameters.append(ConfigurationParameter(name=name, value=value))
return cls(parameters=parameters)
class HNSWConfigurationInternal(ConfigurationInternal):
"""Internal representation of the HNSW configuration.
Used for validation, defaults, serialization and deserialization."""
definitions = {
"space": ConfigurationDefinition(
name="space",
validator=lambda value: isinstance(value, str)
and value in ["l2", "ip", "cosine"],
is_static=True,
default_value="l2",
),
"ef_construction": ConfigurationDefinition(
name="ef_construction",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=100,
),
"ef_search": ConfigurationDefinition(
name="ef_search",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=False,
default_value=100,
),
"num_threads": ConfigurationDefinition(
name="num_threads",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=False,
default_value=cpu_count(), # By default use all cores available
),
"M": ConfigurationDefinition(
name="M",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=16,
),
"resize_factor": ConfigurationDefinition(
name="resize_factor",
validator=lambda value: isinstance(value, float) and value >= 1,
is_static=True,
default_value=1.2,
),
"batch_size": ConfigurationDefinition(
name="batch_size",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=100,
),
"sync_threshold": ConfigurationDefinition(
name="sync_threshold",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=1000,
),
}
@override
def configuration_validator(self) -> None:
batch_size = self.parameter_map.get("batch_size")
sync_threshold = self.parameter_map.get("sync_threshold")
if (
batch_size
and sync_threshold
and cast(int, batch_size.value) > cast(int, sync_threshold.value)
):
raise InvalidConfigurationError(
"batch_size must be less than or equal to sync_threshold"
)
@classmethod
def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
"""Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration."""
# We maintain this map to avoid a circular import with HnswParams, and
# because the names won't change since we intend to deprecate HnswParams
# in favor of this type of configuration.
old_to_new = {
"hnsw:space": "space",
"hnsw:construction_ef": "ef_construction",
"hnsw:search_ef": "ef_search",
"hnsw:M": "M",
"hnsw:num_threads": "num_threads",
"hnsw:resize_factor": "resize_factor",
"hnsw:batch_size": "batch_size",
"hnsw:sync_threshold": "sync_threshold",
}
parameters = []
for name, value in params.items():
if name not in old_to_new:
raise ValueError(f"Invalid legacy HNSW parameter name: {name}")
parameters.append(
ConfigurationParameter(name=old_to_new[name], value=value)
)
return cls(parameters)
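# Illustrative sketch: legacy metadata keys are renamed to their configuration
# counterparts before the internal configuration is constructed.
#
#   cfg = HNSWConfigurationInternal.from_legacy_params(
#       {"hnsw:space": "cosine", "hnsw:construction_ef": 128}
#   )
#   # cfg.get_parameter("space").value == "cosine"
#   # cfg.get_parameter("ef_construction").value == 128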
# This is the user-facing interface for HNSW index configuration parameters.
# Internally, we pass around HNSWConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor with default arguments.
class HNSWConfigurationInterface(HNSWConfigurationInternal):
"""HNSW index configuration parameters.
See https://docs.trychroma.com/guides#changing-the-distance-function for more information.
"""
def __init__(
self,
space: str = "l2",
ef_construction: int = 100,
ef_search: int = 100,
num_threads: int = cpu_count(),
M: int = 16,
resize_factor: float = 1.2,
batch_size: int = 100,
sync_threshold: int = 1000,
):
parameters = [
ConfigurationParameter(name="space", value=space),
ConfigurationParameter(name="ef_construction", value=ef_construction),
ConfigurationParameter(name="ef_search", value=ef_search),
ConfigurationParameter(name="num_threads", value=num_threads),
ConfigurationParameter(name="M", value=M),
ConfigurationParameter(name="resize_factor", value=resize_factor),
ConfigurationParameter(name="batch_size", value=batch_size),
ConfigurationParameter(name="sync_threshold", value=sync_threshold),
]
super().__init__(parameters=parameters)
# Alias for user convenience - the user doesn't need to know this is an 'Interface'
HNSWConfiguration = HNSWConfigurationInterface
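# Illustrative sketch of the user-facing constructor defined above; unspecified
# parameters fall back to the defaults declared in the class definitions.
#
#   hnsw = HNSWConfiguration(space="cosine", ef_construction=200, M=32)
#   hnsw.get_parameter("space").value  # "cosine"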
class CollectionConfigurationInternal(ConfigurationInternal):
"""Internal representation of the collection configuration.
Used for validation, defaults, and serialization / deserialization."""
definitions = {
"hnsw_configuration": ConfigurationDefinition(
name="hnsw_configuration",
validator=lambda value: isinstance(value, HNSWConfigurationInternal),
is_static=True,
default_value=HNSWConfigurationInternal(),
),
}
@override
def configuration_validator(self) -> None:
pass
# This is the user-facing interface for collection configuration parameters.
# Internally, we pass around CollectionConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor.
class CollectionConfigurationInterface(CollectionConfigurationInternal):
"""Configuration parameters for creating a collection."""
def __init__(self, hnsw_configuration: Optional[HNSWConfigurationInternal]):
"""Initializes a new instance of the CollectionConfiguration class.
Args:
hnsw_configuration: The HNSW configuration to use for the collection.
"""
if hnsw_configuration is None:
hnsw_configuration = HNSWConfigurationInternal()
parameters = [
ConfigurationParameter(name="hnsw_configuration", value=hnsw_configuration)
]
super().__init__(parameters=parameters)
# Alias for user convenience - the user doesn't need to know this is an 'Interface'.
CollectionConfiguration = CollectionConfigurationInterface
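# Illustrative sketch: wrapping an HNSW configuration for collection creation.
#
#   config = CollectionConfiguration(
#       hnsw_configuration=HNSWConfiguration(space="cosine")
#   )
#   config.get_parameter("hnsw_configuration").value  # an HNSWConfiguration instance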
class EmbeddingsQueueConfigurationInternal(ConfigurationInternal):
definitions = {
"automatically_purge": ConfigurationDefinition(
name="automatically_purge",
validator=lambda value: isinstance(value, bool),
is_static=False,
default_value=True,
),
}
@override
def configuration_validator(self) -> None:
pass

View File

@@ -0,0 +1,806 @@
import orjson
import logging
from typing import Any, Dict, Optional, cast, Tuple, List
from typing import Sequence
from uuid import UUID
import httpx
import urllib.parse
from overrides import override
from chromadb.api.models.AttachedFunction import AttachedFunction
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
update_collection_configuration_to_json,
create_collection_configuration_to_json,
)
from chromadb import __version__
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.api import ServerAPI
from chromadb.execution.expression.plan import Search
from chromadb.api.types import (
Documents,
Embeddings,
IDs,
Include,
Schema,
Metadatas,
URIs,
Where,
WhereDocument,
GetResult,
QueryResult,
SearchResult,
CollectionMetadata,
validate_batch,
convert_np_embeddings_to_list,
IncludeMetadataDocuments,
IncludeMetadataDocumentsDistances,
)
from chromadb.api.types import (
IncludeMetadataDocumentsEmbeddings,
optional_embeddings_to_base64_strings,
serialize_metadata,
deserialize_metadata,
)
from chromadb.auth import UserIdentity
from chromadb.auth import (
ClientAuthProvider,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
logger = logging.getLogger(__name__)
class FastAPI(BaseHTTPClient, ServerAPI):
def __init__(self, system: System):
super().__init__(system)
system.settings.require("chroma_server_host")
system.settings.require("chroma_server_http_port")
self._opentelemetry_client = self.require(OpenTelemetryClient)
self._product_telemetry_client = self.require(ProductTelemetryClient)
self._settings = system.settings
self._api_url = FastAPI.resolve_url(
chroma_server_host=str(system.settings.chroma_server_host),
chroma_server_http_port=system.settings.chroma_server_http_port,
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
default_api_path=system.settings.chroma_server_api_default_path,
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._session = httpx.Client(timeout=None, limits=limits)
self._header = system.settings.chroma_server_headers or {}
self._header["Content-Type"] = "application/json"
self._header["User-Agent"] = (
"Chroma Python Client v"
+ __version__
+ " (https://github.com/chroma-core/chroma)"
)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
if self._header is not None:
self._session.headers.update(self._header)
if system.settings.chroma_client_auth_provider:
self._auth_provider = self.require(ClientAuthProvider)
_headers = self._auth_provider.authenticate()
for header, value in _headers.items():
self._session.headers[header] = value.get_secret_value()
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
# If the request has json in kwargs, use orjson to serialize it,
# remove it from kwargs, and add it to the content parameter
# This is because httpx uses a slower json serializer
if "json" in kwargs:
data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
kwargs["content"] = data
# Unlike requests, httpx does not automatically escape the path
escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
url = self._api_url + escaped_path
response = self._session.request(method, url, **cast(Any, kwargs))
BaseHTTPClient._raise_chroma_error(response)
return orjson.loads(response.text)
@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp_json = self._make_request("get", "/heartbeat")
return int(resp_json["nanosecond heartbeat"])
# Migrated to rust in distributed.
@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
def create_database(
self,
name: str,
tenant: str = DEFAULT_TENANT,
) -> None:
"""Creates a database"""
self._make_request(
"post",
f"/tenants/{tenant}/databases",
json={"name": name},
)
# Migrated to rust in distributed.
@trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(
self,
name: str,
tenant: str = DEFAULT_TENANT,
) -> Database:
"""Returns a database"""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{name}",
)
return Database(
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
)
@trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
@override
def delete_database(
self,
name: str,
tenant: str = DEFAULT_TENANT,
) -> None:
"""Deletes a database"""
self._make_request(
"delete",
f"/tenants/{tenant}/databases/{name}",
)
@trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""Returns a list of all databases"""
json_databases = self._make_request(
"get",
f"/tenants/{tenant}/databases",
params=BaseHTTPClient._clean_params(
{
"limit": limit,
"offset": offset,
}
),
)
databases = [
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
for db in json_databases
]
return databases
@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
self._make_request("post", "/tenants", json={"name": name})
@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> Tenant:
resp_json = self._make_request("get", "/tenants/" + name)
return Tenant(name=resp_json["name"])
@trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
@override
def get_user_identity(self) -> UserIdentity:
return UserIdentity(**self._make_request("get", "/auth/identity"))
@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
@override
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
"""Returns a list of all collections"""
json_collections = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections",
params=BaseHTTPClient._clean_params(
{
"limit": limit,
"offset": offset,
}
),
)
collection_models = [
CollectionModel.from_json(json_collection)
for json_collection in json_collections
]
return collection_models
@trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
@override
def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
"""Returns a count of collections"""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections_count",
)
return cast(int, resp_json)
@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
"""Creates a collection"""
config_json = (
create_collection_configuration_to_json(configuration, metadata)
if configuration
else None
)
serialized_schema = schema.serialize_to_json() if schema else None
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections",
json={
"name": name,
"metadata": metadata,
"configuration": config_json,
"schema": serialized_schema,
"get_or_create": get_or_create,
},
)
model = CollectionModel.from_json(resp_json)
return model
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
@override
def get_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
"""Returns a collection"""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{name}",
)
model = CollectionModel.from_json(resp_json)
return model
@trace_method(
"FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
)
@override
def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
return self.create_collection(
name=name,
metadata=metadata,
configuration=configuration,
schema=schema,
get_or_create=True,
tenant=tenant,
database=database,
)
@trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
@override
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
"""Updates a collection"""
self._make_request(
"put",
f"/tenants/{tenant}/databases/{database}/collections/{id}",
json={
"new_metadata": new_metadata,
"new_name": new_name,
"new_configuration": update_collection_configuration_to_json(
new_configuration
)
if new_configuration
else None,
},
)
@trace_method("FastAPI._fork", OpenTelemetryGranularity.OPERATION)
@override
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
"""Forks a collection"""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
json={"new_name": new_name},
)
model = CollectionModel.from_json(resp_json)
return model
@trace_method("FastAPI._search", OpenTelemetryGranularity.OPERATION)
@override
def _search(
self,
collection_id: UUID,
searches: List[Search],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> SearchResult:
"""Performs hybrid search on a collection"""
# Convert Search objects to dictionaries
payload = {"searches": [s.to_dict() for s in searches]}
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
json=payload,
)
# Deserialize metadatas: convert transport format to SparseVector instances
metadata_batches = resp_json.get("metadatas", None)
if metadata_batches is not None:
# SearchResult has nested structure: List[Optional[List[Optional[Metadata]]]]
resp_json["metadatas"] = [
[
deserialize_metadata(metadata) if metadata is not None else None
for metadata in metadatas
]
if metadatas is not None
else None
for metadatas in metadata_batches
]
return SearchResult(resp_json)
@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
"""Deletes a collection"""
self._make_request(
"delete",
f"/tenants/{tenant}/databases/{database}/collections/{name}",
)
@trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
@override
def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
"""Returns the number of embeddings in the database"""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
)
return cast(int, resp_json)
@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
return cast(
GetResult,
self._get(
collection_id,
tenant=tenant,
database=database,
limit=n,
include=IncludeMetadataDocumentsEmbeddings,
),
)
@trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
@override
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
# Servers do not support receiving "data", as that is hydrated by the client as a loadable
filtered_include = [i for i in include if i != "data"]
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
json={
"ids": ids,
"where": where,
"limit": limit,
"offset": offset,
"where_document": where_document,
"include": filtered_include,
},
)
# Deserialize metadatas: convert transport format to SparseVector instances
metadatas = resp_json.get("metadatas", None)
if metadatas is not None:
metadatas = [
deserialize_metadata(metadata) if metadata is not None else None
for metadata in metadatas
]
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
included=include,
)
@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
@override
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
"""Deletes embeddings from the database"""
self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
json={
"ids": ids,
"where": where,
"where_document": where_document,
},
)
return None
@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
def _submit_batch(
self,
batch: Tuple[
IDs,
Optional[Embeddings],
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
],
url: str,
) -> None:
"""
Submits a batch of embeddings to the database
"""
# Serialize metadatas: convert SparseVector instances to transport format
serialized_metadatas = None
if batch[2] is not None:
serialized_metadatas = [
serialize_metadata(metadata) if metadata is not None else None
for metadata in batch[2]
]
data = {
"ids": batch[0],
"embeddings": optional_embeddings_to_base64_strings(batch[1])
if self.supports_base64_encoding()
else batch[1],
"metadatas": serialized_metadatas,
"documents": batch[3],
"uris": batch[4],
}
self._make_request("post", url, json=data)
@trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""
Adds a batch of embeddings to the database
- pass in column oriented data lists
"""
batch = (
ids,
embeddings,
metadatas,
documents,
uris,
)
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
self._submit_batch(
batch,
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
)
return True
@trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
@override
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""
batch = (
ids,
embeddings if embeddings is not None else None,
metadatas,
documents,
uris,
)
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
self._submit_batch(
batch,
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
)
return True
@trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
@override
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""
Upserts a batch of embeddings in the database
- pass in column oriented data lists
"""
batch = (
ids,
embeddings,
metadatas,
documents,
uris,
)
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
self._submit_batch(
batch,
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
)
return True
@trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
@override
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
"""Gets the nearest neighbors of a single embedding"""
# Servers do not support receiving "data", as that is hydrated by the client as a loadable
filtered_include = [i for i in include if i != "data"]
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
json={
"ids": ids,
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
if query_embeddings is not None
else None,
"n_results": n_results,
"where": where,
"where_document": where_document,
"include": filtered_include,
},
)
# Deserialize metadatas: convert transport format to SparseVector instances
metadata_batches = resp_json.get("metadatas", None)
if metadata_batches is not None:
metadata_batches = [
[
deserialize_metadata(metadata) if metadata is not None else None
for metadata in metadatas
]
if metadatas is not None
else None
for metadatas in metadata_batches
]
return QueryResult(
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
included=include,
)
@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
@override
def reset(self) -> bool:
"""Resets the database"""
resp_json = self._make_request("post", "/reset")
return cast(bool, resp_json)
@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
def get_version(self) -> str:
"""Returns the version of the server"""
resp_json = self._make_request("get", "/version")
return cast(str, resp_json)
@override
def get_settings(self) -> Settings:
"""Returns the settings of the client"""
return self._settings
@trace_method("FastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION)
def get_pre_flight_checks(self) -> Any:
if self.pre_flight_checks is None:
resp_json = self._make_request("get", "/pre-flight-checks")
self.pre_flight_checks = resp_json
return self.pre_flight_checks
@trace_method(
"FastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
)
def supports_base64_encoding(self) -> bool:
pre_flight_checks = self.get_pre_flight_checks()
b64_encoding_enabled = cast(
bool, pre_flight_checks.get("supports_base64_encoding", False)
)
return b64_encoding_enabled
@trace_method("FastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
@override
def get_max_batch_size(self) -> int:
pre_flight_checks = self.get_pre_flight_checks()
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
return max_batch_size
@trace_method("FastAPI.attach_function", OpenTelemetryGranularity.ALL)
@override
def attach_function(
self,
function_id: str,
name: str,
input_collection_id: UUID,
output_collection: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attach a function to a collection."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/attach",
json={
"name": name,
"function_id": function_id,
"output_collection": output_collection,
"params": params,
},
)
return AttachedFunction(
client=self,
id=UUID(resp_json["attached_function"]["id"]),
name=resp_json["attached_function"]["name"],
function_id=resp_json["attached_function"]["function_id"],
input_collection_id=input_collection_id,
output_collection=output_collection,
params=params,
tenant=tenant,
database=database,
)
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
@override
def detach_function(
self,
attached_function_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Detach a function and prevent any further runs."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
json={
"delete_output": delete_output,
},
)
return cast(bool, resp_json["success"])

View File

@@ -0,0 +1,492 @@
from typing import TYPE_CHECKING, Optional, Union, List, cast
from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
PyEmbedding,
Include,
Metadata,
Document,
Image,
Where,
IDs,
GetResult,
QueryResult,
ID,
OneOrMany,
WhereDocument,
SearchResult,
maybe_cast_one_to_many,
)
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
from chromadb.execution.expression.plan import Search
if TYPE_CHECKING:
from chromadb.api import AsyncServerAPI # noqa: F401
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
async def add(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
images: The images to associate with the embeddings. Optional.
uris: The uris of the images to associate with the embeddings. Optional.
Returns:
None
Raises:
ValueError: If you don't provide either embeddings or documents
ValueError: If the length of ids, embeddings, metadatas, or documents don't match
ValueError: If you don't provide an embedding function and don't provide embeddings
ValueError: If you provide both embeddings and documents
ValueError: If you provide an id that already exists
"""
add_request = self._validate_and_prepare_add_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
await self._client._add(
collection_id=self.id,
ids=add_request["ids"],
embeddings=add_request["embeddings"],
metadatas=add_request["metadatas"],
documents=add_request["documents"],
uris=add_request["uris"],
tenant=self.tenant,
database=self.database,
)
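# Illustrative usage sketch (not part of the original class); assumes an
# AsyncClient has already created or fetched this collection:
#
#   await collection.add(
#       ids=["id1", "id2"],
#       documents=["first document", "second document"],
#       metadatas=[{"source": "a"}, {"source": "b"}],
#   )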
async def count(self) -> int:
"""The total number of embeddings added to the database
Returns:
int: The total number of embeddings added to the database
"""
return await self._client._count(
collection_id=self.id,
tenant=self.tenant,
database=self.database,
)
async def get(
self,
ids: Optional[OneOrMany[ID]] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
) -> GetResult:
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
all embeddings up to limit starting at offset.
Args:
ids: The ids of the embeddings to get. Optional.
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
limit: The number of documents to return. Optional.
offset: The offset to start returning results from. Useful for paging results with limit. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional.
Returns:
GetResult: A GetResult object containing the results.
"""
get_request = self._validate_and_prepare_get_request(
ids=ids,
where=where,
where_document=where_document,
include=include,
)
get_results = await self._client._get(
collection_id=self.id,
ids=get_request["ids"],
where=get_request["where"],
where_document=get_request["where_document"],
include=get_request["include"],
limit=limit,
offset=offset,
tenant=self.tenant,
database=self.database,
)
return self._transform_get_response(
response=get_results, include=get_request["include"]
)
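# Illustrative usage sketch (metadata filter and limit are placeholders):
#
#   result = await collection.get(
#       where={"source": "a"},
#       include=["metadatas", "documents"],
#       limit=5,
#   )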
async def peek(self, limit: int = 10) -> GetResult:
"""Get the first few results in the database up to limit
Args:
limit: The number of results to return.
Returns:
GetResult: A GetResult object containing the results.
"""
return self._transform_peek_response(
await self._client._peek(
collection_id=self.id,
n=limit,
tenant=self.tenant,
database=self.database,
)
)
async def query(
self,
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
query_images: Optional[OneOrMany[Image]] = None,
query_uris: Optional[OneOrMany[URI]] = None,
ids: Optional[OneOrMany[ID]] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = [
"metadatas",
"documents",
"distances",
],
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Args:
query_embeddings: The embeddings to get the closest neighbors of. Optional.
query_texts: The document texts to get the closest neighbors of. Optional.
query_images: The images to get the closest neighbors of. Optional.
ids: A subset of ids to search within. Optional.
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.
Returns:
QueryResult: A QueryResult object containing the results.
Raises:
ValueError: If you don't provide either query_embeddings, query_texts, or query_images
ValueError: If you provide both query_embeddings and query_texts
ValueError: If you provide both query_embeddings and query_images
ValueError: If you provide both query_texts and query_images
"""
query_request = self._validate_and_prepare_query_request(
query_embeddings=query_embeddings,
query_texts=query_texts,
query_images=query_images,
query_uris=query_uris,
ids=ids,
n_results=n_results,
where=where,
where_document=where_document,
include=include,
)
query_results = await self._client._query(
collection_id=self.id,
ids=query_request["ids"],
query_embeddings=query_request["embeddings"],
n_results=query_request["n_results"],
where=query_request["where"],
where_document=query_request["where_document"],
include=query_request["include"],
tenant=self.tenant,
database=self.database,
)
return self._transform_query_response(
response=query_results, include=query_request["include"]
)
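# Illustrative usage sketch (query text is a placeholder); the collection's
# embedding function embeds the text before the nearest-neighbor search:
#
#   results = await collection.query(query_texts=["first document"], n_results=2)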
async def modify(
self,
name: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
configuration: Optional[UpdateCollectionConfiguration] = None,
) -> None:
"""Modify the collection name or metadata
Args:
name: The updated name for the collection. Optional.
metadata: The updated metadata for the collection. Optional.
Returns:
None
"""
self._validate_modify_request(metadata)
# Note there is a race condition here where the metadata can be updated
# but another thread sees the cached local metadata.
# TODO: fixme
await self._client._modify(
id=self.id,
new_name=name,
new_metadata=metadata,
new_configuration=configuration,
tenant=self.tenant,
database=self.database,
)
self._update_model_after_modify_success(name, metadata, configuration)
async def fork(
self,
new_name: str,
) -> "AsyncCollection":
"""Fork the current collection under a new name. The returning collection should contain identical data to the current collection.
This is an experimental API that only works for Hosted Chroma for now.
Args:
new_name: The name of the new collection.
Returns:
Collection: A new collection with the specified name and containing identical data to the current collection.
"""
model = await self._client._fork(
collection_id=self.id,
new_name=new_name,
tenant=self.tenant,
database=self.database,
)
return AsyncCollection(
client=self._client,
model=model,
embedding_function=self._embedding_function,
data_loader=self._data_loader,
)
async def search(
self,
searches: OneOrMany[Search],
) -> SearchResult:
"""Perform hybrid search on the collection.
This is an experimental API that only works for Hosted Chroma for now.
Args:
searches: A single Search object or a list of Search objects, each containing:
- where: Where expression for filtering
- rank: Ranking expression for hybrid search (defaults to Val(0.0))
- limit: Limit configuration for pagination (defaults to no limit)
- select: Select configuration for keys to return (defaults to empty)
Returns:
SearchResult: Column-major format response with:
- ids: List of result IDs for each search payload
- documents: Optional documents for each payload
- embeddings: Optional embeddings for each payload
- metadatas: Optional metadata for each payload
- scores: Optional scores for each payload
- select: List of selected keys for each payload
Raises:
NotImplementedError: For local/segment API implementations
Examples:
# Using builder pattern with Key constants
from chromadb.execution.expression import (
Search, Key, K, Knn, Val
)
# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
search = (Search()
.where((K("category") == "science") & (K("score") > 0.5))
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
.limit(10, offset=0)
.select(K.DOCUMENT, K.SCORE, "title"))
# Direct construction
from chromadb.execution.expression import (
Search, Eq, And, Gt, Knn, Limit, Select, Key
)
search = Search(
where=And([Eq("category", "science"), Gt("score", 0.5)]),
rank=Knn(query=[0.1, 0.2, 0.3]),
limit=Limit(offset=0, limit=10),
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
)
# Single search
result = await collection.search(search)
# Multiple searches at once
searches = [
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
Search().where(K("type") == "paper").rank(Knn(query=[0.3, 0.4]))
]
results = await collection.search(searches)
"""
# Convert single search to list for consistent handling
searches_list = maybe_cast_one_to_many(searches)
if searches_list is None:
searches_list = []
# Embed any string queries in Knn objects
embedded_searches = [
self._embed_search_string_queries(search) for search in searches_list
]
return await self._client._search(
collection_id=self.id,
searches=cast(List[Search], embedded_searches),
tenant=self.tenant,
database=self.database,
)
async def update(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
"""Update the embeddings, metadatas or documents for provided ids.
Args:
ids: The ids of the embeddings to update
embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris of the images to associate with the embeddings. Optional.
Returns:
None
"""
update_request = self._validate_and_prepare_update_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
await self._client._update(
collection_id=self.id,
ids=update_request["ids"],
embeddings=update_request["embeddings"],
metadatas=update_request["metadatas"],
documents=update_request["documents"],
uris=update_request["uris"],
tenant=self.tenant,
database=self.database,
)
async def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
"""Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
Args:
ids: The ids of the embeddings to update
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris of the images to associate with the embeddings. Optional.
Returns:
None
"""
upsert_request = self._validate_and_prepare_upsert_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
await self._client._upsert(
collection_id=self.id,
ids=upsert_request["ids"],
embeddings=upsert_request["embeddings"],
metadatas=upsert_request["metadatas"],
documents=upsert_request["documents"],
uris=upsert_request["uris"],
tenant=self.tenant,
database=self.database,
)
async def delete(
self,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
"""Delete the embeddings based on ids and/or a where filter
Args:
ids: The ids of the embeddings to delete
            where: A Where type dict used to filter the deletion by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. `{"$contains": "hello"}`. Optional.
Returns:
None
Raises:
ValueError: If you don't provide either ids, where, or where_document
"""
delete_request = self._validate_and_prepare_delete_request(
ids, where, where_document
)
await self._client._delete(
collection_id=self.id,
ids=delete_request["ids"],
where=delete_request["where"],
where_document=delete_request["where_document"],
tenant=self.tenant,
database=self.database,
)
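
# --- Usage sketch (illustrative, not part of the class above) ---
# A minimal end-to-end example of the async collection methods defined in this
# module. It assumes a Chroma server reachable at localhost:8000 and that
# `chromadb.AsyncHttpClient` is available as the async client factory; the
# collection name, ids and documents below are placeholders.
if __name__ == "__main__":
    import asyncio

    import chromadb

    async def _demo() -> None:
        client = await chromadb.AsyncHttpClient(host="localhost", port=8000)
        collection = await client.get_or_create_collection("demo_docs")
        await collection.add(
            ids=["doc-1", "doc-2"],
            documents=["first document", "second document"],
        )
        results = await collection.query(query_texts=["first"], n_results=1)
        print(results["ids"])
        await collection.delete(ids=["doc-2"])

    asyncio.run(_demo())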

View File

@@ -0,0 +1,101 @@
from typing import TYPE_CHECKING, Optional, Dict, Any
from uuid import UUID
if TYPE_CHECKING:
from chromadb.api import ServerAPI # noqa: F401
class AttachedFunction:
"""Represents a function attached to a collection."""
def __init__(
self,
client: "ServerAPI",
id: UUID,
name: str,
function_id: str,
input_collection_id: UUID,
output_collection: str,
params: Optional[Dict[str, Any]],
tenant: str,
database: str,
):
"""Initialize an AttachedFunction.
Args:
client: The API client
id: Unique identifier for this attached function
name: Name of this attached function instance
function_id: The function identifier (e.g., "record_counter")
input_collection_id: ID of the input collection
output_collection: Name of the output collection
params: Function-specific parameters
tenant: The tenant name
database: The database name
"""
self._client = client
self._id = id
self._name = name
self._function_id = function_id
self._input_collection_id = input_collection_id
self._output_collection = output_collection
self._params = params
self._tenant = tenant
self._database = database
@property
def id(self) -> UUID:
"""The unique identifier of this attached function."""
return self._id
@property
def name(self) -> str:
"""The name of this attached function instance."""
return self._name
@property
def function_id(self) -> str:
"""The function identifier."""
return self._function_id
@property
def input_collection_id(self) -> UUID:
"""The ID of the input collection."""
return self._input_collection_id
@property
def output_collection(self) -> str:
"""The name of the output collection."""
return self._output_collection
@property
def params(self) -> Optional[Dict[str, Any]]:
"""The function parameters."""
return self._params
def detach(self, delete_output_collection: bool = False) -> bool:
"""Detach this function and prevent any further runs.
Args:
delete_output_collection: Whether to also delete the output collection. Defaults to False.
Returns:
bool: True if successful
Example:
>>> success = attached_fn.detach(delete_output_collection=True)
"""
return self._client.detach_function(
attached_function_id=self._id,
delete_output=delete_output_collection,
tenant=self._tenant,
database=self._database,
)
def __repr__(self) -> str:
return (
f"AttachedFunction(id={self._id}, name='{self._name}', "
f"function_id='{self._function_id}', "
f"input_collection_id={self._input_collection_id}, "
f"output_collection='{self._output_collection}')"
)
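
# --- Usage sketch (illustrative) ---
# AttachedFunction instances are created via Collection.attach_function and torn
# down with detach(). The sketch assumes a Chroma server that supports attached
# functions (the embedded Rust bindings raise NotImplementedError for them); the
# host, collection name, function id and params are placeholders taken from the
# attach_function docstring example.
if __name__ == "__main__":
    import chromadb

    client = chromadb.HttpClient(host="localhost", port=8000)
    collection = client.get_or_create_collection("mycoll")
    attached_fn = collection.attach_function(
        function_id="record_counter",
        name="mycoll_stats_fn",
        output_collection="mycoll_stats",
        params={"threshold": 100},
    )
    print(attached_fn)  # rendered by AttachedFunction.__repr__
    attached_fn.detach(delete_output_collection=True)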

View File

@@ -0,0 +1,535 @@
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
PyEmbedding,
Include,
Metadata,
Document,
Image,
Where,
IDs,
GetResult,
QueryResult,
ID,
OneOrMany,
WhereDocument,
SearchResult,
maybe_cast_one_to_many,
)
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
from chromadb.execution.expression.plan import Search
import logging
if TYPE_CHECKING:
from chromadb.api.models.AttachedFunction import AttachedFunction
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from chromadb.api import ServerAPI # noqa: F401
class Collection(CollectionCommon["ServerAPI"]):
def count(self) -> int:
"""The total number of embeddings added to the database
Returns:
int: The total number of embeddings added to the database
"""
return self._client._count(
collection_id=self.id,
tenant=self.tenant,
database=self.database,
)
def add(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
images: The images to associate with the embeddings. Optional.
uris: The uris of the images to associate with the embeddings. Optional.
Returns:
None
Raises:
ValueError: If you don't provide either embeddings or documents
ValueError: If the length of ids, embeddings, metadatas, or documents don't match
ValueError: If you don't provide an embedding function and don't provide embeddings
ValueError: If you provide both embeddings and documents
ValueError: If you provide an id that already exists
"""
add_request = self._validate_and_prepare_add_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
self._client._add(
collection_id=self.id,
ids=add_request["ids"],
embeddings=add_request["embeddings"],
metadatas=add_request["metadatas"],
documents=add_request["documents"],
uris=add_request["uris"],
tenant=self.tenant,
database=self.database,
)
def get(
self,
ids: Optional[OneOrMany[ID]] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
) -> GetResult:
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
all embeddings up to limit starting at offset.
Args:
ids: The ids of the embeddings to get. Optional.
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
limit: The number of documents to return. Optional.
offset: The offset to start returning results from. Useful for paging results with limit. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional.
Returns:
GetResult: A GetResult object containing the results.
"""
get_request = self._validate_and_prepare_get_request(
ids=ids,
where=where,
where_document=where_document,
include=include,
)
get_results = self._client._get(
collection_id=self.id,
ids=get_request["ids"],
where=get_request["where"],
where_document=get_request["where_document"],
include=get_request["include"],
limit=limit,
offset=offset,
tenant=self.tenant,
database=self.database,
)
return self._transform_get_response(
response=get_results, include=get_request["include"]
)
def peek(self, limit: int = 10) -> GetResult:
"""Get the first few results in the database up to limit
Args:
limit: The number of results to return.
Returns:
GetResult: A GetResult object containing the results.
"""
return self._transform_peek_response(
self._client._peek(
collection_id=self.id,
n=limit,
tenant=self.tenant,
database=self.database,
)
)
def query(
self,
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
query_images: Optional[OneOrMany[Image]] = None,
query_uris: Optional[OneOrMany[URI]] = None,
ids: Optional[OneOrMany[ID]] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = [
"metadatas",
"documents",
"distances",
],
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Args:
            query_embeddings: The embeddings to get the closest neighbors of. Optional.
            query_texts: The document texts to get the closest neighbors of. Optional.
            query_images: The images to get the closest neighbors of. Optional.
query_uris: The URIs to be used with data loader. Optional.
ids: A subset of ids to search within. Optional.
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.
Returns:
QueryResult: A QueryResult object containing the results.
Raises:
ValueError: If you don't provide either query_embeddings, query_texts, or query_images
ValueError: If you provide both query_embeddings and query_texts
ValueError: If you provide both query_embeddings and query_images
ValueError: If you provide both query_texts and query_images
"""
query_request = self._validate_and_prepare_query_request(
query_embeddings=query_embeddings,
query_texts=query_texts,
query_images=query_images,
query_uris=query_uris,
ids=ids,
n_results=n_results,
where=where,
where_document=where_document,
include=include,
)
query_results = self._client._query(
collection_id=self.id,
ids=query_request["ids"],
query_embeddings=query_request["embeddings"],
n_results=query_request["n_results"],
where=query_request["where"],
where_document=query_request["where_document"],
include=query_request["include"],
tenant=self.tenant,
database=self.database,
)
return self._transform_query_response(
response=query_results, include=query_request["include"]
)
def modify(
self,
name: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
configuration: Optional[UpdateCollectionConfiguration] = None,
) -> None:
"""Modify the collection name or metadata
Args:
name: The updated name for the collection. Optional.
            metadata: The updated metadata for the collection. Optional.
            configuration: The updated configuration for the collection. Optional.
Returns:
None
"""
self._validate_modify_request(metadata)
# Note there is a race condition here where the metadata can be updated
# but another thread sees the cached local metadata.
# TODO: fixme
self._client._modify(
id=self.id,
new_name=name,
new_metadata=metadata,
new_configuration=configuration,
tenant=self.tenant,
database=self.database,
)
self._update_model_after_modify_success(name, metadata, configuration)
def fork(
self,
new_name: str,
) -> "Collection":
"""Fork the current collection under a new name. The returning collection should contain identical data to the current collection.
This is an experimental API that only works for Hosted Chroma for now.
Args:
new_name: The name of the new collection.
Returns:
Collection: A new collection with the specified name and containing identical data to the current collection.
"""
model = self._client._fork(
collection_id=self.id,
new_name=new_name,
tenant=self.tenant,
database=self.database,
)
return Collection(
client=self._client,
model=model,
embedding_function=self._embedding_function,
data_loader=self._data_loader,
)
def search(
self,
searches: OneOrMany[Search],
) -> SearchResult:
"""Perform hybrid search on the collection.
This is an experimental API that only works for Hosted Chroma for now.
Args:
searches: A single Search object or a list of Search objects, each containing:
- where: Where expression for filtering
- rank: Ranking expression for hybrid search (defaults to Val(0.0))
- limit: Limit configuration for pagination (defaults to no limit)
- select: Select configuration for keys to return (defaults to empty)
Returns:
SearchResult: Column-major format response with:
- ids: List of result IDs for each search payload
- documents: Optional documents for each payload
- embeddings: Optional embeddings for each payload
- metadatas: Optional metadata for each payload
- scores: Optional scores for each payload
- select: List of selected keys for each payload
Raises:
NotImplementedError: For local/segment API implementations
Examples:
# Using builder pattern with Key constants
from chromadb.execution.expression import (
Search, Key, K, Knn, Val
)
# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
search = (Search()
.where((K("category") == "science") & (K("score") > 0.5))
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
.limit(10, offset=0)
.select(K.DOCUMENT, K.SCORE, "title"))
# Direct construction
from chromadb.execution.expression import (
Search, Eq, And, Gt, Knn, Limit, Select, Key
)
search = Search(
where=And([Eq("category", "science"), Gt("score", 0.5)]),
rank=Knn(query=[0.1, 0.2, 0.3]),
limit=Limit(offset=0, limit=10),
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
)
# Single search
result = collection.search(search)
# Multiple searches at once
searches = [
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
Search().where(K("type") == "paper").rank(Knn(query=[0.3, 0.4]))
]
results = collection.search(searches)
"""
# Convert single search to list for consistent handling
searches_list = maybe_cast_one_to_many(searches)
if searches_list is None:
searches_list = []
# Embed any string queries in Knn objects
embedded_searches = [
self._embed_search_string_queries(search) for search in searches_list
]
return self._client._search(
collection_id=self.id,
searches=cast(List[Search], embedded_searches),
tenant=self.tenant,
database=self.database,
)
def update(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
"""Update the embeddings, metadatas or documents for provided ids.
Args:
ids: The ids of the embeddings to update
embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris of the images to associate with the embeddings. Optional.
Returns:
None
"""
update_request = self._validate_and_prepare_update_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
self._client._update(
collection_id=self.id,
ids=update_request["ids"],
embeddings=update_request["embeddings"],
metadatas=update_request["metadatas"],
documents=update_request["documents"],
uris=update_request["uris"],
tenant=self.tenant,
database=self.database,
)
def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
"""Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
Args:
ids: The ids of the embeddings to update
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris of the images to associate with the embeddings. Optional.
Returns:
None
"""
upsert_request = self._validate_and_prepare_upsert_request(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
self._client._upsert(
collection_id=self.id,
ids=upsert_request["ids"],
embeddings=upsert_request["embeddings"],
metadatas=upsert_request["metadatas"],
documents=upsert_request["documents"],
uris=upsert_request["uris"],
tenant=self.tenant,
database=self.database,
)
def delete(
self,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
"""Delete the embeddings based on ids and/or a where filter
Args:
ids: The ids of the embeddings to delete
            where: A Where type dict used to filter the deletion by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. `{"$contains": "hello"}`. Optional.
Returns:
None
Raises:
ValueError: If you don't provide either ids, where, or where_document
"""
delete_request = self._validate_and_prepare_delete_request(
ids, where, where_document
)
self._client._delete(
collection_id=self.id,
ids=delete_request["ids"],
where=delete_request["where"],
where_document=delete_request["where_document"],
tenant=self.tenant,
database=self.database,
)
def attach_function(
self,
function_id: str,
name: str,
output_collection: str,
params: Optional[Dict[str, Any]] = None,
) -> "AttachedFunction":
"""Attach a function to this collection.
Args:
function_id: Built-in function identifier (e.g., "record_counter")
name: Unique name for this attached function
output_collection: Name of the collection where function output will be stored
params: Optional dictionary with function-specific parameters
Returns:
AttachedFunction: Object representing the attached function
Example:
>>> attached_fn = collection.attach_function(
... function_id="record_counter",
... name="mycoll_stats_fn",
... output_collection="mycoll_stats",
... params={"threshold": 100}
... )
"""
return self._client.attach_function(
function_id=function_id,
name=name,
input_collection_id=self.id,
output_collection=output_collection,
params=params,
tenant=self.tenant,
database=self.database,
)
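
# --- Usage sketch (illustrative) ---
# A minimal walk through the synchronous Collection methods above using an
# in-process ephemeral client; embeddings are computed from the documents by the
# collection's embedding function. Ids, documents and metadata are placeholders.
if __name__ == "__main__":
    import chromadb

    client = chromadb.Client()
    collection = client.get_or_create_collection("demo")
    collection.upsert(
        ids=["a", "b"],
        documents=["alpha document", "beta document"],
        metadatas=[{"lang": "en"}, {"lang": "en"}],
    )
    print(collection.count())  # 2
    print(collection.peek(limit=1)["ids"])
    hits = collection.query(query_texts=["alpha"], n_results=1, where={"lang": "en"})
    print(hits["ids"])
    collection.delete(ids=["b"])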

View File

@@ -0,0 +1,644 @@
from typing import TYPE_CHECKING
from chromadb import (
CollectionMetadata,
Embeddings,
GetResult,
IDs,
Where,
WhereDocument,
Include,
Documents,
Metadatas,
QueryResult,
URIs,
)
from chromadb.api import ServerAPI
if TYPE_CHECKING:
from chromadb.api.models.AttachedFunction import AttachedFunction
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
create_collection_configuration_to_json_str,
update_collection_configuration_to_json_str,
)
from chromadb.auth import UserIdentity
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import (
CollectionAddEvent,
CollectionDeleteEvent,
CollectionGetEvent,
CollectionUpdateEvent,
CollectionQueryEvent,
ClientCreateCollectionEvent,
)
from chromadb.api.types import (
IncludeMetadataDocuments,
IncludeMetadataDocumentsDistances,
IncludeMetadataDocumentsEmbeddings,
Schema,
SearchResult,
)
# TODO(hammadb): Unify imports across types vs root __init__.py
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.execution.expression.plan import Search
import chromadb_rust_bindings
from typing import Optional, Sequence, List, Dict, Any
from overrides import override
from uuid import UUID
import json
import platform
if platform.system() != "Windows":
import resource
elif platform.system() == "Windows":
import ctypes
# RustBindingsAPI is an implementation of ServerAPI which shims
# the Rust bindings to the Python API, providing a full implementation
# of the API. It could be that bindings was a direct implementation of
# ServerAPI, but in order to prevent propagating the bindings types
# into the Python API, we have to shim it here so we can convert into
# the legacy Python types.
# TODO(hammadb): Propagate the types from the bindings into the Python API
# and remove the python-level types entirely.
class RustBindingsAPI(ServerAPI):
bindings: chromadb_rust_bindings.Bindings
hnsw_cache_size: int
product_telemetry_client: ProductTelemetryClient
def __init__(self, system: System):
super().__init__(system)
self.product_telemetry_client = self.require(ProductTelemetryClient)
if platform.system() != "Windows":
max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
else:
max_file_handles = ctypes.windll.msvcrt._getmaxstdio() # type: ignore
self.hnsw_cache_size = (
max_file_handles
# This is integer division in Python 3, and not a comment.
# Each HNSW index has 4 data files and 1 metadata file
// 5
)
@override
def start(self) -> None:
# Construct the SqliteConfig
        # TODO: We should add a "config converter"
if self._system.settings.require("is_persistent"):
persist_path = self._system.settings.require("persist_directory")
sqlite_persist_path = persist_path + "/chroma.sqlite3"
else:
persist_path = None
sqlite_persist_path = None
hash_type = self._system.settings.require("migrations_hash_algorithm")
hash_type_bindings = (
chromadb_rust_bindings.MigrationHash.MD5
if hash_type == "md5"
else chromadb_rust_bindings.MigrationHash.SHA256
)
migration_mode = self._system.settings.require("migrations")
migration_mode_bindings = (
chromadb_rust_bindings.MigrationMode.Apply
if migration_mode == "apply"
else chromadb_rust_bindings.MigrationMode.Validate
)
sqlite_config = chromadb_rust_bindings.SqliteDBConfig(
hash_type=hash_type_bindings,
migration_mode=migration_mode_bindings,
url=sqlite_persist_path,
)
self.bindings = chromadb_rust_bindings.Bindings(
allow_reset=self._system.settings.require("allow_reset"),
sqlite_db_config=sqlite_config,
persist_path=persist_path,
hnsw_cache_size=self.hnsw_cache_size,
)
@override
def stop(self) -> None:
del self.bindings
# ////////////////////////////// Admin API //////////////////////////////
@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return self.bindings.create_database(name, tenant)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
database = self.bindings.get_database(name, tenant)
return {
"id": database.id,
"name": database.name,
"tenant": database.tenant,
}
@override
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return self.bindings.delete_database(name, tenant)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
databases = self.bindings.list_databases(limit, offset, tenant)
return [
{
"id": database.id,
"name": database.name,
"tenant": database.tenant,
}
for database in databases
]
@override
def create_tenant(self, name: str) -> None:
return self.bindings.create_tenant(name)
@override
def get_tenant(self, name: str) -> Tenant:
tenant = self.bindings.get_tenant(name)
return Tenant(name=tenant.name)
# ////////////////////////////// Base API //////////////////////////////
@override
def heartbeat(self) -> int:
return self.bindings.heartbeat()
@override
def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
return self.bindings.count_collections(tenant, database)
@override
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
collections = self.bindings.list_collections(limit, offset, tenant, database)
return [
CollectionModel(
id=collection.id,
name=collection.name,
serialized_schema=collection.schema,
configuration_json=collection.configuration,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
database=collection.database,
)
for collection in collections
]
@override
def create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
# TODO: This event doesn't capture the get_or_create case appropriately
# TODO: Re-enable embedding function tracking in create_collection
self.product_telemetry_client.capture(
ClientCreateCollectionEvent(
collection_uuid=str(id),
# embedding_function=embedding_function.__class__.__name__,
)
)
if configuration:
configuration_json_str = create_collection_configuration_to_json_str(
configuration, metadata
)
else:
configuration_json_str = None
if schema:
schema_str = json.dumps(schema.serialize_to_json())
else:
schema_str = None
collection = self.bindings.create_collection(
name,
configuration_json_str,
schema_str,
metadata,
get_or_create,
tenant,
database,
)
collection_model = CollectionModel(
id=collection.id,
name=collection.name,
configuration_json=collection.configuration,
serialized_schema=collection.schema,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
database=collection.database,
)
return collection_model
@override
def get_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
collection = self.bindings.get_collection(name, tenant, database)
return CollectionModel(
id=collection.id,
name=collection.name,
configuration_json=collection.configuration,
serialized_schema=collection.schema,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
database=collection.database,
)
@override
def get_or_create_collection(
self,
name: str,
schema: Optional[Schema] = None,
configuration: Optional[CreateCollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
return self.create_collection(
name, schema, configuration, metadata, True, tenant, database
)
@override
def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
self.bindings.delete_collection(name, tenant, database)
@override
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration: Optional[UpdateCollectionConfiguration] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
if new_configuration:
new_configuration_json_str = update_collection_configuration_to_json_str(
new_configuration
)
else:
new_configuration_json_str = None
self.bindings.update_collection(
str(id), new_name, new_metadata, new_configuration_json_str
)
@override
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
raise NotImplementedError(
"Collection forking is not implemented for Local Chroma"
)
@override
def _search(
self,
collection_id: UUID,
searches: List[Search],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> SearchResult:
raise NotImplementedError("Search is not implemented for Local Chroma")
@override
def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
return self.bindings.count(str(collection_id), tenant, database)
@override
def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
return self._get(
collection_id,
limit=n,
tenant=tenant,
database=database,
include=IncludeMetadataDocumentsEmbeddings,
)
@override
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocuments,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
ids_amount = len(ids) if ids else 0
self.product_telemetry_client.capture(
CollectionGetEvent(
collection_uuid=str(collection_id),
ids_count=ids_amount,
limit=limit if limit else 0,
include_metadata=ids_amount if "metadatas" in include else 0,
include_documents=ids_amount if "documents" in include else 0,
include_uris=ids_amount if "uris" in include else 0,
)
)
rust_response = self.bindings.get(
str(collection_id),
ids,
json.dumps(where) if where else None,
limit,
offset or 0,
json.dumps(where_document) if where_document else None,
include,
tenant,
database,
)
return GetResult(
ids=rust_response.ids,
embeddings=rust_response.embeddings,
documents=rust_response.documents,
uris=rust_response.uris,
included=include,
data=None,
metadatas=rust_response.metadatas,
)
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
self.product_telemetry_client.capture(
CollectionAddEvent(
collection_uuid=str(collection_id),
add_amount=len(ids),
with_metadata=len(ids) if metadatas is not None else 0,
with_documents=len(ids) if documents is not None else 0,
with_uris=len(ids) if uris is not None else 0,
)
)
return self.bindings.add(
ids,
str(collection_id),
embeddings,
metadatas,
documents,
uris,
tenant,
database,
)
@override
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
self.product_telemetry_client.capture(
CollectionUpdateEvent(
collection_uuid=str(collection_id),
update_amount=len(ids),
with_embeddings=len(embeddings) if embeddings else 0,
with_metadata=len(metadatas) if metadatas else 0,
with_documents=len(documents) if documents else 0,
with_uris=len(uris) if uris else 0,
)
)
return self.bindings.update(
str(collection_id),
ids,
embeddings,
metadatas,
documents,
uris,
tenant,
database,
)
@override
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
return self.bindings.upsert(
str(collection_id),
ids,
embeddings,
metadatas,
documents,
uris,
tenant,
database,
)
@override
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
query_amount = len(query_embeddings)
filtered_ids_amount = len(ids) if ids else 0
self.product_telemetry_client.capture(
CollectionQueryEvent(
collection_uuid=str(collection_id),
query_amount=query_amount,
filtered_ids_amount=filtered_ids_amount,
n_results=n_results,
with_metadata_filter=query_amount if where is not None else 0,
with_document_filter=query_amount if where_document is not None else 0,
include_metadatas=query_amount if "metadatas" in include else 0,
include_documents=query_amount if "documents" in include else 0,
include_uris=query_amount if "uris" in include else 0,
include_distances=query_amount if "distances" in include else 0,
)
)
rust_response = self.bindings.query(
str(collection_id),
ids,
query_embeddings,
n_results,
json.dumps(where) if where else None,
json.dumps(where_document) if where_document else None,
include,
tenant,
database,
)
return QueryResult(
ids=rust_response.ids,
embeddings=rust_response.embeddings,
documents=rust_response.documents,
uris=rust_response.uris,
included=include,
data=None,
metadatas=rust_response.metadatas,
distances=rust_response.distances,
)
@override
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
self.product_telemetry_client.capture(
CollectionDeleteEvent(
# NOTE: the delete amount is not observable from python
# TODO: Fix this when posthog is pushed into Rust frontend
collection_uuid=str(collection_id),
delete_amount=0,
)
)
return self.bindings.delete(
str(collection_id),
ids,
json.dumps(where) if where else None,
json.dumps(where_document) if where_document else None,
tenant,
database,
)
@override
def reset(self) -> bool:
return self.bindings.reset()
@override
def get_version(self) -> str:
return self.bindings.get_version()
@override
def get_settings(self) -> Settings:
return self._system.settings
@override
def get_max_batch_size(self) -> int:
return self.bindings.get_max_batch_size()
@override
def attach_function(
self,
function_id: str,
name: str,
input_collection_id: UUID,
output_collection: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
"The Rust bindings (embedded mode) do not support attached function operations."
)
@override
def detach_function(
self,
attached_function_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
"The Rust bindings (embedded mode) do not support attached function operations."
)
# TODO: Remove this if it's not planned to be used
@override
def get_user_identity(self) -> UserIdentity:
return UserIdentity(
user_id="",
tenant=DEFAULT_TENANT,
databases=[DEFAULT_DATABASE],
)
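
# --- Usage sketch (illustrative) ---
# RustBindingsAPI is not instantiated directly; the System constructs it when
# `chroma_api_impl` points at this class. The sketch below assumes a local
# persistent client configured for this implementation (the path is a
# placeholder). Server-only features such as _fork, _search and attach_function
# raise NotImplementedError here, as defined above.
if __name__ == "__main__":
    import chromadb
    from chromadb.config import Settings

    client = chromadb.PersistentClient(
        path="./chroma_data",
        settings=Settings(chroma_api_impl="chromadb.api.rust.RustBindingsAPI"),
    )
    print(client.heartbeat())    # served by Bindings.heartbeat()
    print(client.get_version())  # served by Bindings.get_version()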

File diff suppressed because it is too large

View File

@@ -0,0 +1,96 @@
from typing import ClassVar, Dict
import uuid
from chromadb.api import ServerAPI
from chromadb.config import Settings, System
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import ClientStartEvent
class SharedSystemClient:
_identifier_to_system: ClassVar[Dict[str, System]] = {}
_identifier: str
def __init__(
self,
settings: Settings = Settings(),
) -> None:
self._identifier = SharedSystemClient._get_identifier_from_settings(settings)
SharedSystemClient._create_system_if_not_exists(self._identifier, settings)
@classmethod
def _create_system_if_not_exists(
cls, identifier: str, settings: Settings
) -> System:
if identifier not in cls._identifier_to_system:
new_system = System(settings)
cls._identifier_to_system[identifier] = new_system
new_system.instance(ProductTelemetryClient)
new_system.instance(ServerAPI)
new_system.start()
else:
previous_system = cls._identifier_to_system[identifier]
# For now, the settings must match
if previous_system.settings != settings:
raise ValueError(
f"An instance of Chroma already exists for {identifier} with different settings"
)
return cls._identifier_to_system[identifier]
@staticmethod
def _get_identifier_from_settings(settings: Settings) -> str:
identifier = ""
api_impl = settings.chroma_api_impl
if api_impl is None:
raise ValueError("Chroma API implementation must be set in settings")
elif api_impl in [
"chromadb.api.segment.SegmentAPI",
"chromadb.api.rust.RustBindingsAPI",
]:
if settings.is_persistent:
identifier = settings.persist_directory
else:
identifier = (
"ephemeral" # TODO: support pathing and multiple ephemeral clients
)
elif api_impl in [
"chromadb.api.fastapi.FastAPI",
"chromadb.api.async_fastapi.AsyncFastAPI",
]:
# FastAPI clients can all use unique system identifiers since their configurations can be independent, e.g. different auth tokens
identifier = str(uuid.uuid4())
else:
raise ValueError(f"Unsupported Chroma API implementation {api_impl}")
return identifier
@staticmethod
def _populate_data_from_system(system: System) -> str:
identifier = SharedSystemClient._get_identifier_from_settings(system.settings)
SharedSystemClient._identifier_to_system[identifier] = system
return identifier
@classmethod
def from_system(cls, system: System) -> "SharedSystemClient":
"""Create a client from an existing system. This is useful for testing and debugging."""
SharedSystemClient._populate_data_from_system(system)
instance = cls(system.settings)
return instance
@staticmethod
def clear_system_cache() -> None:
SharedSystemClient._identifier_to_system = {}
@property
def _system(self) -> System:
return SharedSystemClient._identifier_to_system[self._identifier]
def _submit_client_start_event(self) -> None:
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ClientStartEvent())
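
# --- Usage sketch (illustrative) ---
# Systems are cached by an identifier derived from the settings: the persist
# directory for local clients, a fresh UUID for HTTP clients. Two local clients
# pointed at the same path therefore share one System, and reusing an identifier
# with different settings raises ValueError. The path below is a placeholder.
if __name__ == "__main__":
    import chromadb

    a = chromadb.PersistentClient(path="./shared_db")
    b = chromadb.PersistentClient(path="./shared_db")  # reuses the cached System
    print(a._identifier == b._identifier)  # True: both map to the same System
    SharedSystemClient.clear_system_cache()  # drop cached Systems (handy in tests)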

File diff suppressed because it is too large

View File

@@ -0,0 +1,7 @@
import chromadb
import chromadb.config
from chromadb.server.fastapi import FastAPI
settings = chromadb.config.Settings()
server = FastAPI(settings)
app = server.app()
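
# --- Usage sketch (illustrative) ---
# `app` is a plain ASGI application, so it can be served by any ASGI server.
# A programmatic uvicorn launch might look like the following; it assumes this
# module is importable as `chromadb.app`, and the host/port are placeholders
# (the CLI equivalent is `uvicorn chromadb.app:app --port 8000`).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("chromadb.app:app", host="0.0.0.0", port=8000)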

View File

@@ -0,0 +1,237 @@
from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import (
Any,
List,
Optional,
Dict,
Tuple,
TypeVar,
)
from dataclasses import dataclass
from pydantic import SecretStr
from chromadb.config import (
Component,
System,
)
T = TypeVar("T")
S = TypeVar("S")
class AuthError(Exception):
pass
ClientAuthHeaders = Dict[str, SecretStr]
class ClientAuthProvider(Component):
"""
ClientAuthProvider is responsible for providing authentication headers for
client requests. Client implementations (in our case, just the FastAPI
client) must inject these headers into their requests.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
@abstractmethod
def authenticate(self) -> ClientAuthHeaders:
pass
@dataclass
class UserIdentity:
"""
UserIdentity represents the identity of a user. In general, not all fields
will be populated, and the fields that are populated will depend on the
authentication provider.
The idea is that the AuthenticationProvider is responsible for populating
_all_ information known about the user, and the AuthorizationProvider is
responsible for making decisions based on that information.
"""
user_id: str
tenant: Optional[str] = None
databases: Optional[List[str]] = None
# This can be used for any additional auth context which needs to be
# propagated from the authentication provider to the authorization
# provider.
attributes: Optional[Dict[str, Any]] = None
class ServerAuthenticationProvider(Component):
"""
ServerAuthenticationProvider is responsible for authenticating requests. If
a ServerAuthenticationProvider is configured, it will be called by the
    server to authenticate requests. If no ServerAuthenticationProvider is
    configured, requests are not authenticated and are allowed through.
The ServerAuthenticationProvider should return a UserIdentity object if the
request is authenticated for use by the ServerAuthorizationProvider.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self._ignore_auth_paths: Dict[
str, List[str]
] = system.settings.chroma_server_auth_ignore_paths
self.overwrite_singleton_tenant_database_access_from_auth = (
system.settings.chroma_overwrite_singleton_tenant_database_access_from_auth
)
@abstractmethod
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
pass
def ignore_operation(self, verb: str, path: str) -> bool:
if (
path in self._ignore_auth_paths.keys()
and verb.upper() in self._ignore_auth_paths[path]
):
return True
return False
def read_creds_or_creds_file(self) -> List[str]:
_creds_file = None
_creds = None
if self._system.settings.chroma_server_authn_credentials_file:
_creds_file = str(
self._system.settings["chroma_server_authn_credentials_file"]
)
if self._system.settings.chroma_server_authn_credentials:
_creds = str(self._system.settings["chroma_server_authn_credentials"])
if not _creds_file and not _creds:
raise ValueError(
"No credentials file or credentials found in "
"[chroma_server_authn_credentials]."
)
if _creds_file and _creds:
raise ValueError(
"Both credentials file and credentials found."
"Please provide only one."
)
if _creds:
return [c for c in _creds.split("\n") if c]
elif _creds_file:
with open(_creds_file, "r") as f:
return f.readlines()
raise ValueError("Should never happen")
def singleton_tenant_database_if_applicable(
self, user: Optional[UserIdentity]
) -> Tuple[Optional[str], Optional[str]]:
"""
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
is False, this function always returns (None, None).
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
is True, follows the following logic:
- If the user only has access to a single tenant, this function will
return that tenant as its first return value.
- If the user only has access to a single database, this function will
return that database as its second return value. If the user has
access to multiple tenants and/or databases, including "*", this
function will return None for the corresponding value(s).
- If the user has access to multiple tenants and/or databases this
function will return None for the corresponding value(s).
"""
if not self.overwrite_singleton_tenant_database_access_from_auth or not user:
return None, None
tenant = None
database = None
if user.tenant and user.tenant != "*":
tenant = user.tenant
if user.databases and len(user.databases) == 1 and user.databases[0] != "*":
database = user.databases[0]
return tenant, database
class AuthzAction(str, Enum):
"""
The set of actions that can be authorized by the authorization provider.
"""
RESET = "system:reset"
CREATE_TENANT = "tenant:create_tenant"
GET_TENANT = "tenant:get_tenant"
CREATE_DATABASE = "db:create_database"
GET_DATABASE = "db:get_database"
DELETE_DATABASE = "db:delete_database"
LIST_DATABASES = "db:list_databases"
LIST_COLLECTIONS = "db:list_collections"
COUNT_COLLECTIONS = "db:count_collections"
CREATE_COLLECTION = "db:create_collection"
GET_OR_CREATE_COLLECTION = "db:get_or_create_collection"
GET_COLLECTION = "collection:get_collection"
DELETE_COLLECTION = "collection:delete_collection"
UPDATE_COLLECTION = "collection:update_collection"
ADD = "collection:add"
DELETE = "collection:delete"
GET = "collection:get"
QUERY = "collection:query"
COUNT = "collection:count"
UPDATE = "collection:update"
UPSERT = "collection:upsert"
@dataclass
class AuthzResource:
"""
The resource being accessed in an authorization request.
"""
tenant: Optional[str]
database: Optional[str]
collection: Optional[str]
class ServerAuthorizationProvider(Component):
"""
ServerAuthorizationProvider is responsible for authorizing requests. If a
ServerAuthorizationProvider is configured, it will be called by the server
to authorize requests. If no ServerAuthorizationProvider is configured, all
requests will be authorized.
ServerAuthorizationProvider should raise an exception if the request is not
authorized.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
@abstractmethod
def authorize_or_raise(
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
) -> None:
pass
def read_config_or_config_file(self) -> List[str]:
_config_file = None
_config = None
if self._system.settings.chroma_server_authz_config_file:
_config_file = self._system.settings["chroma_server_authz_config_file"]
if self._system.settings.chroma_server_authz_config:
_config = str(self._system.settings["chroma_server_authz_config"])
if not _config_file and not _config:
raise ValueError(
"No authz configuration file or authz configuration found."
)
if _config_file and _config:
raise ValueError(
"Both authz configuration file and authz configuration found."
"Please provide only one."
)
if _config:
return [c for c in _config.split("\n") if c]
elif _config_file:
with open(_config_file, "r") as f:
return f.readlines()
raise ValueError("Should never happen")
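
# --- Implementation sketch (illustrative) ---
# A minimal custom ClientAuthProvider: subclasses implement authenticate() and
# return a mapping of header name to SecretStr, which the HTTP client injects
# into each request. The header name and token below are placeholders; a real
# provider would read the secret from system.settings (e.g.
# chroma_client_auth_credentials) rather than hard-coding it.
class StaticHeaderClientAuthProvider(ClientAuthProvider):
    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._secret = SecretStr("example-static-token")  # placeholder value

    def authenticate(self) -> ClientAuthHeaders:
        return {"X-Example-Auth": self._secret}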

View File

@@ -0,0 +1,146 @@
import base64
import random
import re
import time
import traceback
import bcrypt
import logging
from overrides import override
from pydantic import SecretStr
from chromadb.auth import (
UserIdentity,
ServerAuthenticationProvider,
ClientAuthProvider,
ClientAuthHeaders,
AuthError,
)
from chromadb.config import System
from chromadb.errors import ChromaAuthError
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from typing import Dict
logger = logging.getLogger(__name__)
__all__ = ["BasicAuthenticationServerProvider", "BasicAuthClientProvider"]
AUTHORIZATION_HEADER = "Authorization"
class BasicAuthClientProvider(ClientAuthProvider):
"""
Client auth provider for basic auth. The credentials are passed as a
base64-encoded string in the Authorization header prepended with "Basic ".
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_client_auth_credentials")
self._creds = SecretStr(str(system.settings.chroma_client_auth_credentials))
@override
def authenticate(self) -> ClientAuthHeaders:
encoded = base64.b64encode(
f"{self._creds.get_secret_value()}".encode("utf-8")
).decode("utf-8")
return {
AUTHORIZATION_HEADER: SecretStr(f"Basic {encoded}"),
}
class BasicAuthenticationServerProvider(ServerAuthenticationProvider):
"""
Server auth provider for basic auth. The credentials are read from
`chroma_server_authn_credentials_file` and each line must be in the format
<username>:<bcrypt passwd>.
Expects tokens to be passed as a base64-encoded string in the Authorization
header prepended with "Basic".
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
self._creds: Dict[str, SecretStr] = {}
creds = self.read_creds_or_creds_file()
for line in creds:
if not line.strip():
continue
_raw_creds = [v for v in line.strip().split(":")]
if (
_raw_creds
and _raw_creds[0]
and len(_raw_creds) != 2
or not all(_raw_creds)
):
raise ValueError(
f"Invalid htpasswd credentials found: {_raw_creds}. "
"Lines must be exactly <username>:<bcrypt passwd>."
)
username = _raw_creds[0]
password = _raw_creds[1]
if username in self._creds:
raise ValueError(
"Duplicate username found in "
"[chroma_server_authn_credentials]. "
"Usernames must be unique."
)
self._creds[username] = SecretStr(password)
@trace_method(
"BasicAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
)
@override
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
try:
if AUTHORIZATION_HEADER.lower() not in headers.keys():
raise AuthError(AUTHORIZATION_HEADER + " header not found")
_auth_header = headers[AUTHORIZATION_HEADER.lower()]
_auth_header = re.sub(r"^Basic ", "", _auth_header)
_auth_header = _auth_header.strip()
base64_decoded = base64.b64decode(_auth_header).decode("utf-8")
if ":" not in base64_decoded:
raise AuthError("Invalid Authorization header format")
username, password = base64_decoded.split(":", 1)
username = str(username) # convert to string to prevent header injection
password = str(password) # convert to string to prevent header injection
if username not in self._creds:
raise AuthError("Invalid username or password")
_pwd_check = bcrypt.checkpw(
password.encode("utf-8"),
self._creds[username].get_secret_value().encode("utf-8"),
)
if not _pwd_check:
raise AuthError("Invalid username or password")
return UserIdentity(user_id=username)
except AuthError as e:
logger.error(
f"BasicAuthenticationServerProvider.authenticate failed: {repr(e)}"
)
except Exception as e:
tb = traceback.extract_tb(e.__traceback__)
# Get the last call stack
last_call_stack = tb[-1]
line_number = last_call_stack.lineno
filename = last_call_stack.filename
logger.error(
"BasicAuthenticationServerProvider.authenticate failed: "
f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
)
time.sleep(
random.uniform(0.001, 0.005)
) # add some jitter to avoid timing attacks
raise ChromaAuthError()
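
# --- Configuration sketch (illustrative) ---
# Wiring the client side of basic auth through Settings. The provider path
# assumes this module is importable as chromadb.auth.basic_authn; host, port and
# credentials are placeholders. The server side reads one
# `<username>:<bcrypt hash>` line per user from chroma_server_authn_credentials
# or chroma_server_authn_credentials_file, as parsed above.
if __name__ == "__main__":
    import chromadb
    from chromadb.config import Settings

    client = chromadb.HttpClient(
        host="localhost",
        port=8000,
        settings=Settings(
            chroma_client_auth_provider="chromadb.auth.basic_authn.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:password123",
        ),
    )
    print(client.heartbeat())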

View File

@@ -0,0 +1,75 @@
import logging
from typing import Dict, Set
from overrides import override
import yaml
from chromadb.auth import (
AuthzAction,
AuthzResource,
UserIdentity,
ServerAuthorizationProvider,
)
from chromadb.config import System
from fastapi import HTTPException
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
logger = logging.getLogger(__name__)
class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
"""
A simple Role-Based Access Control (RBAC) authorization provider. This
provider reads a configuration file that maps users to roles, and roles to
actions. The provider then checks if the user has the action they are
attempting to perform.
For an example of an RBAC configuration file, see
examples/basic_functionality/authz/authz.yaml.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
self._config = yaml.safe_load("\n".join(self.read_config_or_config_file()))
# We favor preprocessing here to avoid having to parse the config file
# on every request. This AuthorizationProvider does not support
# per-resource authorization so we just map the user ID to the
# permissions they have. We're not worried about the size of this dict
# since users are all specified in the file -- anyone with a gigantic
# number of users can roll their own AuthorizationProvider.
self._permissions: Dict[str, Set[str]] = {}
for user in self._config["users"]:
_actions = self._config["roles_mapping"][user["role"]]["actions"]
self._permissions[user["id"]] = set(_actions)
logger.info(
"Authorization Provider SimpleRBACAuthorizationProvider " "initialized"
)
@trace_method(
"SimpleRBACAuthorizationProvider.authorize",
OpenTelemetryGranularity.ALL,
)
@override
def authorize_or_raise(
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
) -> None:
policy_decision = False
if (
user.user_id in self._permissions
and action in self._permissions[user.user_id]
):
policy_decision = True
logger.debug(
f"Authorization decision: Access "
f"{'granted' if policy_decision else 'denied'} for "
f"user [{user.user_id}] attempting to "
f"[{action}] [{resource}]"
)
if not policy_decision:
raise HTTPException(status_code=403, detail="Forbidden")
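
# --- Configuration sketch (illustrative) ---
# The provider expects a YAML document with a top-level `users` list (each entry
# carrying `id` and `role`) and a `roles_mapping` section mapping each role to
# its allowed `actions` (AuthzAction string values). The ids and role split
# below are placeholders; the document is supplied via
# chroma_server_authz_config or chroma_server_authz_config_file.
EXAMPLE_RBAC_CONFIG = """
users:
  - id: alice
    role: admin
  - id: bob
    role: reader
roles_mapping:
  admin:
    actions:
      - "db:list_collections"
      - "db:create_collection"
      - "collection:add"
      - "collection:query"
  reader:
    actions:
      - "db:list_collections"
      - "collection:query"
"""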

View File

@@ -0,0 +1,235 @@
import logging
import random
import re
import string
import time
import traceback
from enum import Enum
from typing import cast, Dict, List, Optional, TypedDict, TypeVar
from overrides import override
from pydantic import SecretStr
import yaml
from chromadb.auth import (
ServerAuthenticationProvider,
ClientAuthProvider,
ClientAuthHeaders,
UserIdentity,
AuthError,
)
from chromadb.config import System
from chromadb.errors import ChromaAuthError
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
T = TypeVar("T")
logger = logging.getLogger(__name__)
__all__ = [
"TokenAuthenticationServerProvider",
"TokenAuthClientProvider",
"TokenTransportHeader",
]
class TokenTransportHeader(str, Enum):
"""
    Acceptable token transport headers.
"""
# I don't love having this enum here -- it's weird to have an enum
# for just two values and it's weird to have users pass X_CHROMA_TOKEN
# to configure "x-chroma-token". But I also like having a single source
# of truth, so 🤷🏻‍♂️
AUTHORIZATION = "Authorization"
X_CHROMA_TOKEN = "X-Chroma-Token"
valid_token_chars = set(string.digits + string.ascii_letters + string.punctuation)
def _check_token(token: str) -> None:
token_str = str(token)
if not all(c in valid_token_chars for c in token_str):
raise ValueError(
"Invalid token. Must contain only ASCII letters, digits, and punctuation."
)
allowed_token_headers = [
TokenTransportHeader.AUTHORIZATION.value,
TokenTransportHeader.X_CHROMA_TOKEN.value,
]
def _check_allowed_token_headers(token_header: str) -> None:
if token_header not in allowed_token_headers:
raise ValueError(
f"Invalid token transport header: {token_header}. "
f"Must be one of {allowed_token_headers}"
)
class TokenAuthClientProvider(ClientAuthProvider):
"""
Client auth provider for token-based auth. Header key will be either
"Authorization" or "X-Chroma-Token" depending on
`chroma_auth_token_transport_header`. If the header is "Authorization",
the token is passed as a bearer token.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_client_auth_credentials")
self._token = SecretStr(str(system.settings.chroma_client_auth_credentials))
_check_token(self._token.get_secret_value())
if system.settings.chroma_auth_token_transport_header:
_check_allowed_token_headers(
system.settings.chroma_auth_token_transport_header
)
self._token_transport_header = TokenTransportHeader(
system.settings.chroma_auth_token_transport_header
)
else:
self._token_transport_header = TokenTransportHeader.AUTHORIZATION
@override
def authenticate(self) -> ClientAuthHeaders:
val = self._token.get_secret_value()
if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
val = f"Bearer {val}"
return {
self._token_transport_header.value: SecretStr(val),
}
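# Hypothetical helper (not part of this module) sketching how this provider is wired up
# via Settings; the credential value below is a placeholder.
def _example_token_client_settings():
    from chromadb.config import Settings

    return Settings(
        chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
        chroma_client_auth_credentials="placeholder-token",
        # Optional: defaults to "Authorization" (sent as a bearer token) when unset.
        chroma_auth_token_transport_header=TokenTransportHeader.X_CHROMA_TOKEN.value,
    )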
class User(TypedDict):
"""
    A simple User class for use in this module only. If you need a generic
    way to represent a User, please use UserIdentity instead, since this
    class keeps track of sensitive tokens.
"""
id: str
role: str
tenant: Optional[str]
databases: Optional[List[str]]
tokens: List[str]
class TokenAuthenticationServerProvider(ServerAuthenticationProvider):
"""
Server authentication provider for token-based auth. The provider will
- On initialization, read the users from the file specified in
`chroma_server_authn_credentials_file`. This file must be a well-formed
YAML file with a top-level array called `users`. Each user must have
an `id` field and a `tokens` (string array) field.
- On each request, check the token in the header specified by
`chroma_auth_token_transport_header`. If the configured header is
"Authorization", the token is expected to be a bearer token.
- If the token is valid, the server will return the user identity
associated with the token.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
if system.settings.chroma_auth_token_transport_header:
_check_allowed_token_headers(
system.settings.chroma_auth_token_transport_header
)
self._token_transport_header = TokenTransportHeader(
system.settings.chroma_auth_token_transport_header
)
else:
self._token_transport_header = TokenTransportHeader.AUTHORIZATION
self._token_user_mapping: Dict[str, User] = {}
creds = self.read_creds_or_creds_file()
# If we only get one cred, assume it's just a valid token.
if len(creds) == 1:
self._token_user_mapping[creds[0]] = User(
id="anonymous",
tenant="*",
databases=["*"],
role="anonymous",
tokens=[creds[0]],
)
return
self._users = cast(List[User], yaml.safe_load("\n".join(creds))["users"])
for user in self._users:
if "tokens" not in user:
raise ValueError("User missing tokens")
if "tenant" not in user:
user["tenant"] = "*"
if "databases" not in user:
user["databases"] = ["*"]
for token in user["tokens"]:
_check_token(token)
if (
token in self._token_user_mapping
and self._token_user_mapping[token] != user
):
raise ValueError(
f"Token {token} already in use: wanted to use it for "
f"user {user['id']} but it's already in use by "
f"user {self._token_user_mapping[token]}"
)
self._token_user_mapping[token] = user
@trace_method(
"TokenAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
)
@override
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
try:
if self._token_transport_header.value.lower() not in headers.keys():
raise AuthError(
f"Authorization header '{self._token_transport_header.value}' not found"
)
token = headers[self._token_transport_header.value.lower()]
if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
if not token.startswith("Bearer "):
raise AuthError("Bearer not found in Authorization header")
token = re.sub(r"^Bearer ", "", token)
token = token.strip()
_check_token(token)
if token not in self._token_user_mapping:
raise AuthError("Invalid credentials: Token not found}")
user_identity = UserIdentity(
user_id=self._token_user_mapping[token]["id"],
tenant=self._token_user_mapping[token]["tenant"],
databases=self._token_user_mapping[token]["databases"],
)
return user_identity
except AuthError as e:
logger.debug(
f"TokenAuthenticationServerProvider.authenticate failed: {repr(e)}"
)
except Exception as e:
tb = traceback.extract_tb(e.__traceback__)
# Get the last call stack
last_call_stack = tb[-1]
line_number = last_call_stack.lineno
filename = last_call_stack.filename
logger.debug(
"TokenAuthenticationServerProvider.authenticate failed: "
f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
)
time.sleep(
random.uniform(0.001, 0.005)
) # add some jitter to avoid timing attacks
raise ChromaAuthError()
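# Illustrative sketch of the credentials file described in the class docstring above
# (ids, roles and token strings are placeholders; `tenant` and `databases` are optional
# and default to "*"):
#
# users:
#   - id: alice@example.com
#     role: admin
#     tenant: default_tenant
#     databases: ["default_database"]
#     tokens:
#       - alice-placeholder-token
#   - id: bob@example.com
#     role: read_only
#     tokens:
#       - bob-placeholder-token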

View File

@@ -0,0 +1,86 @@
from typing import Optional, Tuple
from chromadb.auth import UserIdentity
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.errors import ChromaAuthError
def _singleton_tenant_database_if_applicable(
user_identity: UserIdentity,
overwrite_singleton_tenant_database_access_from_auth: bool,
) -> Tuple[Optional[str], Optional[str]]:
"""
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
is False, this function always returns (None, None).
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
    is True, it follows this logic:
    - If the user only has access to a single tenant, this function will
      return that tenant as its first return value.
    - If the user only has access to a single database, this function will
      return that database as its second return value.
    - If the user has access to multiple tenants and/or databases, or to
      "*", this function will return None for the corresponding value(s).
"""
if not overwrite_singleton_tenant_database_access_from_auth:
return None, None
tenant = None
database = None
user_tenant = user_identity.tenant
user_databases = user_identity.databases
if user_tenant and user_tenant != "*":
tenant = user_tenant
if user_databases:
user_databases_set = set(user_databases)
if len(user_databases_set) == 1 and "*" not in user_databases_set:
database = list(user_databases_set)[0]
return tenant, database
def maybe_set_tenant_and_database(
user_identity: UserIdentity,
overwrite_singleton_tenant_database_access_from_auth: bool,
user_provided_tenant: Optional[str] = None,
user_provided_database: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
(
new_tenant,
new_database,
) = _singleton_tenant_database_if_applicable(
user_identity=user_identity,
overwrite_singleton_tenant_database_access_from_auth=overwrite_singleton_tenant_database_access_from_auth,
)
# The only error case is if the user provides a tenant and database that
# don't match what we resolved from auth. This can incorrectly happen when
# there is no auth provider set, but overwrite_singleton_tenant_database_access_from_auth
# is set to True. In this case, we'll resolve tenant/database to the default
# values, which might not match the provided values. Thus, it's important
# to ensure that the flag is set to True only when there is an auth provider.
if (
user_provided_tenant
and user_provided_tenant != DEFAULT_TENANT
and new_tenant
and new_tenant != user_provided_tenant
):
raise ChromaAuthError(f"Tenant {user_provided_tenant} does not match {new_tenant} from the server. Are you sure the tenant is correct?")
if (
user_provided_database
and user_provided_database != DEFAULT_DATABASE
and new_database
and new_database != user_provided_database
):
raise ChromaAuthError(f"Database {user_provided_database} does not match {new_database} from the server. Are you sure the database is correct?")
if (
not user_provided_tenant or user_provided_tenant == DEFAULT_TENANT
) and new_tenant:
user_provided_tenant = new_tenant
if (
not user_provided_database or user_provided_database == DEFAULT_DATABASE
) and new_database:
user_provided_database = new_database
return user_provided_tenant, user_provided_database
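# Minimal worked sketch of the resolution above (values are placeholders): an identity
# that grants exactly one tenant and one database overrides the client-side defaults,
# while an explicit, conflicting value raises ChromaAuthError.
def _example_maybe_set_tenant_and_database() -> None:
    identity = UserIdentity(user_id="alice", tenant="acme", databases=["prod"])
    tenant, database = maybe_set_tenant_and_database(
        user_identity=identity,
        overwrite_singleton_tenant_database_access_from_auth=True,
        user_provided_tenant=DEFAULT_TENANT,
        user_provided_database=DEFAULT_DATABASE,
    )
    assert (tenant, database) == ("acme", "prod")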

View File

@@ -0,0 +1,122 @@
from typing import Dict, List, Mapping, Optional, Sequence, Union, Any
from typing_extensions import Literal, Final
from dataclasses import dataclass
import numpy as np
from numpy.typing import NDArray
# Type tag constants
TYPE_KEY: Final[str] = "#type"
SPARSE_VECTOR_TYPE_VALUE: Final[str] = "sparse_vector"
@dataclass
class SparseVector:
"""Represents a sparse vector using parallel arrays for indices and values.
Attributes:
indices: List of dimension indices (must be non-negative integers, sorted in strictly ascending order)
values: List of values corresponding to each index (floats)
Note:
- Indices must be sorted in strictly ascending order (no duplicates)
- Indices and values must have the same length
- All validations are performed in __post_init__
"""
indices: List[int]
values: List[float]
def __post_init__(self) -> None:
"""Validate the sparse vector structure."""
if not isinstance(self.indices, list):
raise ValueError(
f"Expected SparseVector indices to be a list, got {type(self.indices).__name__}"
)
if not isinstance(self.values, list):
raise ValueError(
f"Expected SparseVector values to be a list, got {type(self.values).__name__}"
)
if len(self.indices) != len(self.values):
raise ValueError(
f"SparseVector indices and values must have the same length, "
f"got {len(self.indices)} indices and {len(self.values)} values"
)
for i, idx in enumerate(self.indices):
if not isinstance(idx, int):
raise ValueError(
f"SparseVector indices must be integers, got {type(idx).__name__} at position {i}"
)
if idx < 0:
raise ValueError(
f"SparseVector indices must be non-negative, got {idx} at position {i}"
)
for i, val in enumerate(self.values):
if not isinstance(val, (int, float)):
raise ValueError(
f"SparseVector values must be numbers, got {type(val).__name__} at position {i}"
)
# Validate indices are sorted in strictly ascending order
if len(self.indices) > 1:
for i in range(1, len(self.indices)):
if self.indices[i] <= self.indices[i - 1]:
raise ValueError(
f"SparseVector indices must be sorted in strictly ascending order, "
f"found indices[{i}]={self.indices[i]} <= indices[{i-1}]={self.indices[i-1]}"
)
def to_dict(self) -> Dict[str, Any]:
"""Serialize to transport format with type tag."""
return {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": self.indices,
"values": self.values,
}
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> "SparseVector":
"""Deserialize from transport format (strict - requires #type field)."""
if d.get(TYPE_KEY) != SPARSE_VECTOR_TYPE_VALUE:
raise ValueError(
f"Expected {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {d.get(TYPE_KEY)}"
)
return cls(indices=d["indices"], values=d["values"])
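# Brief usage sketch: build a sparse vector, round-trip it through the transport format,
# and note that validation rejects unsorted indices.
def _example_sparse_vector_roundtrip() -> None:
    sv = SparseVector(indices=[0, 3, 7], values=[0.5, 1.25, -2.0])
    payload = sv.to_dict()  # {"#type": "sparse_vector", "indices": [...], "values": [...]}
    assert SparseVector.from_dict(payload) == sv
    try:
        SparseVector(indices=[3, 1], values=[1.0, 2.0])
    except ValueError:
        pass  # indices must be strictly ascending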
Metadata = Mapping[str, Optional[Union[str, int, float, bool, SparseVector]]]
UpdateMetadata = Mapping[str, Union[int, float, str, bool, SparseVector, None]]
PyVector = Union[Sequence[float], Sequence[int]]
Vector = NDArray[Union[np.int32, np.float32]] # TODO: Specify that the vector is 1D
# Metadata Query Grammar
LiteralValue = Union[str, int, float, bool]
LogicalOperator = Union[Literal["$and"], Literal["$or"]]
WhereOperator = Union[
Literal["$gt"],
Literal["$gte"],
Literal["$lt"],
Literal["$lte"],
Literal["$ne"],
Literal["$eq"],
]
InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]]
OperatorExpression = Union[
Dict[Union[WhereOperator, LogicalOperator], LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
]
Where = Dict[
Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]
]
WhereDocumentOperator = Union[
Literal["$contains"],
Literal["$not_contains"],
Literal["$regex"],
Literal["$not_regex"],
LogicalOperator,
]
WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]]
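# Illustrative filters that satisfy the Where / WhereDocument grammar above; the field
# names and values are placeholders.
_EXAMPLE_WHERE: Where = {
    "$and": [
        {"category": {"$eq": "news"}},
        {"year": {"$gte": 2020}},
    ]
}
_EXAMPLE_WHERE_DOCUMENT: WhereDocument = {"$contains": "chroma"}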

View File

@@ -0,0 +1,195 @@
from typing import List, Optional, Sequence
from uuid import UUID
from chromadb import CollectionMetadata, Embeddings, IDs
from chromadb.api.types import (
CollectionMetadata,
Documents,
Embeddings,
IDs,
Metadatas,
URIs,
Include,
)
from chromadb.types import Tenant, Collection as CollectionModel
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from enum import Enum
class DatabaseFromBindings:
id: UUID
name: str
tenant: str
# Result Types
class GetResponse:
ids: IDs
embeddings: Embeddings
documents: Documents
uris: URIs
metadatas: Metadatas
include: Include
class QueryResponse:
ids: List[IDs]
embeddings: Optional[List[Embeddings]]
documents: Optional[List[Documents]]
uris: Optional[List[URIs]]
metadatas: Optional[List[Metadatas]]
distances: Optional[List[List[float]]]
include: Include
class GetTenantResponse:
name: str
# SqliteDBConfig types
class MigrationMode(Enum):
Apply = 0
Validate = 1
class MigrationHash(Enum):
SHA256 = 0
MD5 = 1
class SqliteDBConfig:
url: str
hash_type: MigrationHash
migration_mode: MigrationMode
def __init__(
self, url: str, hash_type: MigrationHash, migration_mode: MigrationMode
) -> None: ...
class Bindings:
def __init__(
self,
allow_reset: bool,
sqlite_db_config: SqliteDBConfig,
persist_path: str,
hnsw_cache_size: int,
) -> None: ...
def heartbeat(self) -> int: ...
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: ...
def get_database(
self, name: str, tenant: str = DEFAULT_TENANT
) -> DatabaseFromBindings: ...
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: ...
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[DatabaseFromBindings]: ...
def create_tenant(self, name: str) -> None: ...
def get_tenant(self, name: str) -> GetTenantResponse: ...
def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int: ...
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]: ...
def create_collection(
self,
name: str,
configuration_json_str: Optional[str] = None,
schema_str: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel: ...
def get_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel: ...
def update_collection(
self,
id: str,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
new_configuration_json_str: Optional[str] = None,
) -> None: ...
def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None: ...
def add(
self,
ids: IDs,
collection_id: str,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool: ...
def update(
self,
collection_id: str,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool: ...
def upsert(
self,
collection_id: str,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool: ...
def delete(
self,
collection_id: str,
ids: Optional[IDs] = None,
where: Optional[str] = None,
where_document: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None: ...
def count(
self,
collection_id: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int: ...
def get(
self,
collection_id: str,
ids: Optional[IDs] = None,
where: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[str] = None,
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResponse: ...
def query(
self,
collection_id: str,
query_embeddings: Embeddings,
n_results: int = 10,
where: Optional[str] = None,
where_document: Optional[str] = None,
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResponse: ...
def reset(self) -> bool: ...
def get_version(self) -> str: ...
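# Illustrative bootstrap sketch (assumes the native module exposes exactly the
# constructors stubbed above; paths and cache size are placeholders):
#
#   config = SqliteDBConfig(
#       url="./chroma/chroma.sqlite3",
#       hash_type=MigrationHash.MD5,
#       migration_mode=MigrationMode.Apply,
#   )
#   bindings = Bindings(
#       allow_reset=False,
#       sqlite_db_config=config,
#       persist_path="./chroma",
#       hnsw_cache_size=1000,
#   )
#   bindings.heartbeat()  # returns an int per the stub above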

View File

@@ -0,0 +1,56 @@
import re
import sys
import chromadb_rust_bindings
import requests
from packaging.version import parse
import chromadb
def build_cli_args(**kwargs):
args = []
for key, value in kwargs.items():
if isinstance(value, bool):
if value:
args.append(f"--{key}")
elif value is not None:
args.extend([f"--{key}", str(value)])
return args
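# Quick sketch of build_cli_args: booleans become bare flags, None values are dropped,
# and everything else becomes a "--key value" pair (keys and values are placeholders).
#
#   build_cli_args(path="./chroma", port=8000, verbose=True, host=None)
#   -> ["--path", "./chroma", "--port", "8000", "--verbose"]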
def update():
try:
url = f"https://api.github.com/repos/chroma-core/chroma/releases"
response = requests.get(url)
response.raise_for_status()
releases = response.json()
version_pattern = re.compile(r'^\d+\.\d+\.\d+$')
numeric_releases = [r["tag_name"] for r in releases if version_pattern.fullmatch(r["tag_name"])]
if not numeric_releases:
print("Couldn't fetch the latest Chroma version")
return
latest = max(numeric_releases, key=parse)
if latest == chromadb.__version__:
print("Your Chroma version is up-to-date")
return
print(
f"A new version of Chroma is available!\nIf you're using pip, run 'pip install --upgrade chromadb' to upgrade to version {latest}")
except Exception as e:
print("Couldn't fetch the latest Chroma version")
def app():
args = sys.argv
if ["chroma", "update"] in args:
update()
return
try:
chromadb_rust_bindings.cli(args)
except KeyboardInterrupt:
pass

View File

@@ -0,0 +1,40 @@
from typing import Any, Dict
import os
import yaml
def set_log_file_path(
log_config_path: str, new_filename: str = "chroma.log"
) -> Dict[str, Any]:
"""This works with the standard log_config.yml file.
It will not work with custom log configs that may use different handlers"""
with open(f"{log_config_path}", "r") as file:
log_config = yaml.safe_load(file)
for handler in log_config["handlers"].values():
if handler.get("class") == "logging.handlers.RotatingFileHandler":
handler["filename"] = new_filename
return log_config
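# Example usage sketch (file names are placeholders): load the standard config, retarget
# its rotating file handler, and hand the result to the stdlib dictConfig.
#
#   import logging.config
#   logging.config.dictConfig(set_log_file_path("log_config.yml", "server.log"))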
def get_directory_size(directory: str) -> int:
"""Get the size of a directory in bytes"""
total = 0
with os.scandir(directory) as it:
for entry in it:
if entry.is_file():
total += entry.stat().st_size
elif entry.is_dir():
total += get_directory_size(entry.path)
return total
# https://stackoverflow.com/a/1094933
def sizeof_fmt(num: int, suffix: str = "B") -> str:
n: float = float(num)
for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
if abs(n) < 1024.0:
return f"{n:3.1f}{unit}{suffix}"
n /= 1024.0
return f"{n:.1f}Yi{suffix}"

View File

@@ -0,0 +1,503 @@
import importlib
import inspect
import logging
from abc import ABC
from enum import Enum
from graphlib import TopologicalSorter
from typing import Optional, List, Any, Dict, Set, Iterable, Union
from typing import Type, TypeVar, cast
from overrides import EnforceOverrides
from overrides import override
from typing_extensions import Literal
import platform
in_pydantic_v2 = False
try:
from pydantic import BaseSettings
except ImportError:
in_pydantic_v2 = True
from pydantic.v1 import BaseSettings
from pydantic.v1 import validator
if not in_pydantic_v2:
from pydantic import validator # type: ignore # noqa
# The thin client will have a flag to control which implementations to use
is_thin_client = False
try:
from chromadb.is_thin_client import is_thin_client # type: ignore
except ImportError:
is_thin_client = False
logger = logging.getLogger(__name__)
LEGACY_ERROR = """\033[91mYou are using a deprecated configuration of Chroma.
\033[94mIf you do not have data you wish to migrate, you only need to change how you construct
your Chroma client. Please see the "New Clients" section of https://docs.trychroma.com/deployment/migration.
________________________________________________________________________________________________
If you do have data you wish to migrate, we have a migration tool you can use in order to
migrate your data to the new Chroma architecture.
Please `pip install chroma-migrate` and run `chroma-migrate` to migrate your data and then
change how you construct your Chroma client.
See https://docs.trychroma.com/deployment/migration for more information or join our discord at https://discord.gg/MMeYNTmh3x for help!\033[0m"""
_legacy_config_keys = {
"chroma_db_impl",
}
_legacy_config_values = {
"duckdb",
"duckdb+parquet",
"clickhouse",
"local",
"rest",
"chromadb.db.duckdb.DuckDB",
"chromadb.db.duckdb.PersistentDuckDB",
"chromadb.db.clickhouse.Clickhouse",
"chromadb.api.local.LocalAPI",
}
# Map specific abstract types to the setting which specifies which
# concrete implementation to use.
# Please keep these sorted. We're civilized people.
_abstract_type_keys: Dict[str, str] = {
# TODO: Don't use concrete types here to avoid circular deps. Strings are
# fine for right here!
# NOTE: this is to support legacy api construction. Use ServerAPI instead
"chromadb.api.API": "chroma_api_impl",
"chromadb.api.ServerAPI": "chroma_api_impl",
"chromadb.api.async_api.AsyncServerAPI": "chroma_api_impl",
"chromadb.auth.ClientAuthProvider": "chroma_client_auth_provider",
"chromadb.auth.ServerAuthenticationProvider": "chroma_server_authn_provider",
"chromadb.auth.ServerAuthorizationProvider": "chroma_server_authz_provider",
"chromadb.db.system.SysDB": "chroma_sysdb_impl",
"chromadb.execution.executor.abstract.Executor": "chroma_executor_impl",
"chromadb.ingest.Consumer": "chroma_consumer_impl",
"chromadb.ingest.Producer": "chroma_producer_impl",
"chromadb.quota.QuotaEnforcer": "chroma_quota_enforcer_impl",
"chromadb.rate_limit.RateLimitEnforcer": "chroma_rate_limit_enforcer_impl",
"chromadb.rate_limit.AsyncRateLimitEnforcer": "chroma_async_rate_limit_enforcer_impl",
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
"chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",
"chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl",
}
DEFAULT_TENANT = "default_tenant"
DEFAULT_DATABASE = "default_database"
class APIVersion(str, Enum):
V1 = "/api/v1"
V2 = "/api/v2"
# NOTE(hammadb) 1/13/2024 - This has to be in config.py instead of being localized to the module
# that uses it because of a circular import issue. This is a temporary solution until we can
# refactor the code to remove the circular import.
class RoutingMode(Enum):
"""
Routing mode for the segment directory
node - Assign based on the node name, used in production with multi-node settings with the assumption that
there is one query service pod per node. This is useful for when there is a disk based cache on the
node that we want to route to.
    id - Assign based on the member id, used in development and testing environments where the node name is not
        guaranteed to be unique (i.e., a local development Kubernetes environment), or when there are multiple
        query service pods per node.
"""
NODE = "node"
ID = "id"
class Settings(BaseSettings): # type: ignore
# ==============
# Generic config
# ==============
environment: str = ""
# Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" or "chromadb.api.rust.RustBindingsAPI"
chroma_api_impl: str = "chromadb.api.rust.RustBindingsAPI"
@validator("chroma_server_nofile", pre=True, always=True, allow_reuse=True)
def empty_str_to_none(cls, v: str) -> Optional[str]:
if type(v) is str and v.strip() == "":
return None
return v
chroma_server_nofile: Optional[int] = None
# the number of maximum threads to handle synchronous tasks in the FastAPI server
chroma_server_thread_pool_size: int = 40
# ==================
# Client-mode config
# ==================
tenant_id: str = "default"
topic_namespace: str = "default"
chroma_server_host: Optional[str] = None
chroma_server_headers: Optional[Dict[str, str]] = None
chroma_server_http_port: Optional[int] = None
chroma_server_ssl_enabled: Optional[bool] = False
chroma_server_ssl_verify: Optional[Union[bool, str]] = None
chroma_server_api_default_path: Optional[APIVersion] = APIVersion.V2
# eg ["http://localhost:8000"]
chroma_server_cors_allow_origins: List[str] = []
# ==================
# Server config
# ==================
is_persistent: bool = False
persist_directory: str = "./chroma"
chroma_memory_limit_bytes: int = 0
chroma_segment_cache_policy: Optional[str] = None
allow_reset: bool = False
# ===========================
# {Client, Server} auth{n, z}
# ===========================
# The header to use for the token. Defaults to "Authorization".
chroma_auth_token_transport_header: Optional[str] = None
# ================
# Client auth{n,z}
# ================
# The provider for client auth. See chromadb/auth/__init__.py
chroma_client_auth_provider: Optional[str] = None
# If needed by the provider (e.g. BasicAuthClientProvider),
# the credentials to use.
chroma_client_auth_credentials: Optional[str] = None
# ================
# Server auth{n,z}
# ================
chroma_server_auth_ignore_paths: Dict[str, List[str]] = {
f"{APIVersion.V2}": ["GET"],
f"{APIVersion.V2}/heartbeat": ["GET"],
f"{APIVersion.V2}/version": ["GET"],
f"{APIVersion.V1}": ["GET"],
f"{APIVersion.V1}/heartbeat": ["GET"],
f"{APIVersion.V1}/version": ["GET"],
}
# Overwrite singleton tenant and database access from the auth provider
# if applicable. See chromadb/auth/utils/__init__.py's
# authenticate_and_authorize_or_raise method.
chroma_overwrite_singleton_tenant_database_access_from_auth: bool = False
# ============
# Server authn
# ============
chroma_server_authn_provider: Optional[str] = None
# Only one of the below may be specified.
chroma_server_authn_credentials: Optional[str] = None
chroma_server_authn_credentials_file: Optional[str] = None
# ============
# Server authz
# ============
chroma_server_authz_provider: Optional[str] = None
# Only one of the below may be specified.
chroma_server_authz_config: Optional[str] = None
chroma_server_authz_config_file: Optional[str] = None
# =========
# Telemetry
# =========
chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog"
# Required for backwards compatibility
chroma_telemetry_impl: str = chroma_product_telemetry_impl
anonymized_telemetry: bool = True
chroma_otel_collection_endpoint: Optional[str] = ""
chroma_otel_service_name: Optional[str] = "chromadb"
chroma_otel_collection_headers: Dict[str, str] = {}
chroma_otel_granularity: Optional[str] = None
# ==========
# Migrations
# ==========
migrations: Literal["none", "validate", "apply"] = "apply"
# you cannot change the hash_algorithm after migrations have already
# been applied once this is intended to be a first-time setup configuration
migrations_hash_algorithm: Literal["md5", "sha256"] = "md5"
# ==================
# Distributed Chroma
# ==================
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
chroma_segment_directory_routing_mode: RoutingMode = RoutingMode.ID
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
worker_memberlist_name: str = "query-service-memberlist"
chroma_coordinator_host = "localhost"
# TODO this is the sysdb port. Should probably rename it.
chroma_server_grpc_port: Optional[int] = None
chroma_sysdb_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
chroma_producer_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
chroma_consumer_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
chroma_segment_manager_impl: str = (
"chromadb.segment.impl.manager.local.LocalSegmentManager"
)
chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor"
chroma_query_replication_factor: int = 2
chroma_logservice_host = "localhost"
chroma_logservice_port = 50052
chroma_quota_provider_impl: Optional[str] = None
chroma_rate_limiting_provider_impl: Optional[str] = None
chroma_quota_enforcer_impl: str = (
"chromadb.quota.simple_quota_enforcer.SimpleQuotaEnforcer"
)
chroma_rate_limit_enforcer_impl: str = (
"chromadb.rate_limit.simple_rate_limit.SimpleRateLimitEnforcer"
)
chroma_async_rate_limit_enforcer_impl: str = (
"chromadb.rate_limit.simple_rate_limit.SimpleAsyncRateLimitEnforcer"
)
# ==========
# gRPC service config
# ==========
chroma_logservice_request_timeout_seconds: int = 3
chroma_sysdb_request_timeout_seconds: int = 3
chroma_query_request_timeout_seconds: int = 60
# ======
# Legacy
# ======
chroma_db_impl: Optional[str] = None
chroma_collection_assignment_policy_impl: str = (
"chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy"
)
# =======
# Methods
# =======
def require(self, key: str) -> Any:
"""Return the value of a required config key, or raise an exception if it is not
set"""
val = self[key]
if val is None:
raise ValueError(f"Missing required config value '{key}'")
return val
def __getitem__(self, key: str) -> Any:
val = getattr(self, key)
# Error on legacy config values
if isinstance(val, str) and val in _legacy_config_values:
raise ValueError(LEGACY_ERROR)
return val
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
T = TypeVar("T", bound="Component")
class Component(ABC, EnforceOverrides):
_dependencies: Set["Component"]
_system: "System"
_running: bool
def __init__(self, system: "System"):
self._dependencies = set()
self._system = system
self._running = False
def require(self, type: Type[T]) -> T:
"""Get a Component instance of the given type, and register as a dependency of
that instance."""
inst = self._system.instance(type)
self._dependencies.add(inst)
return inst
def dependencies(self) -> Set["Component"]:
"""Return the full set of components this component depends on."""
return self._dependencies
def stop(self) -> None:
"""Idempotently stop this component's execution and free all associated
resources."""
logger.debug(f"Stopping component {self.__class__.__name__}")
self._running = False
def start(self) -> None:
"""Idempotently start this component's execution"""
logger.debug(f"Starting component {self.__class__.__name__}")
self._running = True
def reset_state(self) -> None:
"""Reset this component's state to its initial blank state. Only intended to be
called from tests."""
logger.debug(f"Resetting component {self.__class__.__name__}")
class System(Component):
settings: Settings
_instances: Dict[Type[Component], Component]
def __init__(self, settings: Settings):
if is_thin_client:
# The thin client is a system with only the API component
if settings["chroma_api_impl"] not in [
"chromadb.api.fastapi.FastAPI",
"chromadb.api.async_fastapi.AsyncFastAPI",
]:
raise RuntimeError(
"Chroma is running in http-only client mode, and can only be run with 'chromadb.api.fastapi.FastAPI' or 'chromadb.api.async_fastapi.AsyncFastAPI' as the chroma_api_impl. \
see https://docs.trychroma.com/guides#using-the-python-http-only-client for more information."
)
# Validate settings don't contain any legacy config values
for key in _legacy_config_keys:
if settings[key] is not None:
raise ValueError(LEGACY_ERROR)
if (
settings["chroma_segment_cache_policy"] is not None
and settings["chroma_segment_cache_policy"] != "LRU"
):
logger.error(
"Failed to set chroma_segment_cache_policy: Only LRU is available."
)
if settings["chroma_memory_limit_bytes"] == 0:
logger.error(
"Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require."
)
# Apply the nofile limit if set
if settings["chroma_server_nofile"] is not None:
if platform.system() != "Windows":
import resource
curr_soft, curr_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
desired_soft = settings["chroma_server_nofile"]
# Validate
if desired_soft > curr_hard:
logging.warning(
f"chroma_server_nofile cannot be set to a value greater than the current hard limit of {curr_hard}. Keeping soft limit at {curr_soft}"
)
# Apply
elif desired_soft > curr_soft:
try:
resource.setrlimit(
resource.RLIMIT_NOFILE, (desired_soft, curr_hard)
)
logger.info(f"Set chroma_server_nofile to {desired_soft}")
except Exception as e:
logger.error(
f"Failed to set chroma_server_nofile to {desired_soft}: {e} nofile soft limit will remain at {curr_soft}"
)
# Don't apply if reducing the limit
elif desired_soft < curr_soft:
logger.warning(
f"chroma_server_nofile is set to {desired_soft}, but this is less than current soft limit of {curr_soft}. chroma_server_nofile will not be set."
)
else:
logger.warning(
"chroma_server_nofile is not supported on Windows. chroma_server_nofile will not be set."
)
self.settings = settings
self._instances = {}
super().__init__(self)
def instance(self, type: Type[T]) -> T:
"""Return an instance of the component type specified. If the system is running,
the component will be started as well."""
if inspect.isabstract(type):
type_fqn = get_fqn(type)
if type_fqn not in _abstract_type_keys:
raise ValueError(f"Cannot instantiate abstract type: {type}")
key = _abstract_type_keys[type_fqn]
fqn = self.settings.require(key)
type = get_class(fqn, type)
if type not in self._instances:
impl = type(self)
self._instances[type] = impl
if self._running:
impl.start()
inst = self._instances[type]
return cast(T, inst)
def components(self) -> Iterable[Component]:
"""Return the full set of all components and their dependencies in dependency
order."""
sorter: TopologicalSorter[Component] = TopologicalSorter()
for component in self._instances.values():
sorter.add(component, *component.dependencies())
return sorter.static_order()
@override
def start(self) -> None:
super().start()
for component in self.components():
component.start()
@override
def stop(self) -> None:
super().stop()
for component in reversed(list(self.components())):
component.stop()
@override
def reset_state(self) -> None:
"""Reset the state of this system and all constituents in reverse dependency order"""
if not self.settings.allow_reset:
raise ValueError(
"Resetting is not allowed by this configuration (to enable it, set `allow_reset` to `True` in your Settings() or include `ALLOW_RESET=TRUE` in your environment variables)"
)
for component in reversed(list(self.components())):
component.reset_state()
C = TypeVar("C")
def get_class(fqn: str, type: Type[C]) -> Type[C]:
"""Given a fully qualifed class name, import the module and return the class"""
module_name, class_name = fqn.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
return cast(Type[C], cls)
def get_fqn(cls: Type[object]) -> str:
"""Given a class, return its fully qualified name"""
return f"{cls.__module__}.{cls.__name__}"

View File

@@ -0,0 +1,122 @@
from abc import abstractmethod
from typing import List, Sequence, Optional, Tuple
from uuid import UUID
from chromadb.api.types import (
Embeddings,
Documents,
IDs,
Metadatas,
Metadata,
Where,
WhereDocument,
)
from chromadb.config import Component
class DB(Component):
@abstractmethod
def create_collection(
self,
name: str,
metadata: Optional[Metadata] = None,
get_or_create: bool = False,
) -> Sequence: # type: ignore
pass
@abstractmethod
def get_collection(self, name: str) -> Sequence: # type: ignore
pass
@abstractmethod
def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence: # type: ignore
pass
@abstractmethod
def count_collections(self) -> int:
pass
@abstractmethod
def update_collection(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[Metadata] = None,
) -> None:
pass
@abstractmethod
def delete_collection(self, name: str) -> None:
pass
@abstractmethod
def get_collection_uuid_from_name(self, collection_name: str) -> UUID:
pass
@abstractmethod
def add(
self,
collection_uuid: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas],
documents: Optional[Documents],
ids: List[str],
) -> List[UUID]:
pass
@abstractmethod
def get(
self,
where: Optional[Where] = None,
collection_name: Optional[str] = None,
collection_uuid: Optional[UUID] = None,
ids: Optional[IDs] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
columns: Optional[List[str]] = None,
) -> Sequence: # type: ignore
pass
@abstractmethod
def update(
self,
collection_uuid: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
) -> bool:
pass
@abstractmethod
def count(self, collection_id: UUID) -> int:
pass
@abstractmethod
def delete(
self,
where: Optional[Where] = None,
collection_uuid: Optional[UUID] = None,
ids: Optional[IDs] = None,
where_document: Optional[WhereDocument] = None,
) -> None:
pass
@abstractmethod
def get_nearest_neighbors(
self,
collection_uuid: UUID,
where: Optional[Where] = None,
embeddings: Optional[Embeddings] = None,
n_results: int = 10,
where_document: Optional[WhereDocument] = None,
) -> Tuple[List[List[UUID]], List[List[float]]]:
pass
@abstractmethod
def get_by_ids(
self, uuids: List[UUID], columns: Optional[List[str]] = None
) -> Sequence: # type: ignore
pass

View File

@@ -0,0 +1,180 @@
from typing import Any, Optional, Sequence, Tuple, Type
from types import TracebackType
from typing_extensions import Protocol, Self, Literal
from abc import ABC, abstractmethod
from threading import local
from overrides import override, EnforceOverrides
import pypika
import pypika.queries
from chromadb.config import System, Component
from uuid import UUID
from itertools import islice, count
class Cursor(Protocol):
"""Reifies methods we use from a DBAPI2 Cursor since DBAPI2 is not typed."""
def execute(self, sql: str, params: Optional[Tuple[Any, ...]] = None) -> Self:
...
def executescript(self, script: str) -> Self:
...
def executemany(
self, sql: str, params: Optional[Sequence[Tuple[Any, ...]]] = None
) -> Self:
...
def fetchone(self) -> Tuple[Any, ...]:
...
def fetchall(self) -> Sequence[Tuple[Any, ...]]:
...
class TxWrapper(ABC, EnforceOverrides):
"""Wrapper class for DBAPI 2.0 Connection objects, with which clients can implement transactions.
Makes two guarantees that basic DBAPI 2.0 connections do not:
- __enter__ returns a Cursor object consistently (instead of a Connection like some do)
- Always re-raises an exception if one was thrown from the body
"""
@abstractmethod
def __enter__(self) -> Cursor:
pass
@abstractmethod
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Literal[False]:
pass
class SqlDB(Component):
"""DBAPI 2.0 interface wrapper to ensure consistent behavior between implementations"""
def __init__(self, system: System):
super().__init__(system)
@abstractmethod
def tx(self) -> TxWrapper:
"""Return a transaction wrapper"""
pass
@staticmethod
@abstractmethod
def querybuilder() -> Type[pypika.Query]:
"""Return a PyPika Query builder of an appropriate subtype for this database
implementation (see
https://pypika.readthedocs.io/en/latest/3_advanced.html#handling-different-database-platforms)
"""
pass
@staticmethod
@abstractmethod
def parameter_format() -> str:
"""Return the appropriate parameter format for this database implementation.
Will be called with str.format(i) where i is the numeric index of the parameter.
"""
pass
@staticmethod
@abstractmethod
def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
"""Convert a UUID to a value that can be passed to the DB driver"""
pass
@staticmethod
@abstractmethod
def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
"""Convert a value from the DB driver to a UUID"""
pass
@staticmethod
@abstractmethod
def unique_constraint_error() -> Type[BaseException]:
"""Return the exception type that the DB raises when a unique constraint is
violated"""
pass
def param(self, idx: int) -> pypika.Parameter:
"""Return a PyPika Parameter object for the given index"""
return pypika.Parameter(self.parameter_format().format(idx))
_context = local()
class ParameterValue(pypika.Parameter): # type: ignore
"""
    Wrapper class for PyPika parameters that allows the values for Parameters
to be expressed inline while building a query. See get_sql() for
detailed usage information.
"""
def __init__(self, value: Any):
self.value = value
@override
def get_sql(self, **kwargs: Any) -> str:
if isinstance(self.value, (list, tuple)):
_context.values.extend(self.value)
indexes = islice(_context.generator, len(self.value))
placeholders = ", ".join(_context.formatstr.format(i) for i in indexes)
val = f"({placeholders})"
else:
_context.values.append(self.value)
val = _context.formatstr.format(next(_context.generator))
return str(val)
def get_sql(
query: pypika.queries.QueryBuilder, formatstr: str = "?"
) -> Tuple[str, Tuple[Any, ...]]:
"""
Wrapper for pypika's get_sql method that allows the values for Parameters
to be expressed inline while building a query, and that returns a tuple of the
SQL string and parameters. This makes it easier to construct complex queries
programmatically and automatically matches up the generated SQL with the required
parameter vector.
Doing so requires using the ParameterValue class defined in this module instead
of the base pypika.Parameter class.
Usage Example:
q = (
        pypika.Query.from_("table")
        .select("col1")
        .where(pypika.Field("col2") == ParameterValue("foo"))
        .where(pypika.Field("col3") == ParameterValue("bar"))
)
sql, params = get_sql(q)
cursor.execute(sql, params)
    Note how it is not necessary to construct the parameter vector manually: it
    will always be generated with the parameter values in the same order as they
    appear in the emitted SQL string.
The format string should match the parameter format for the database being used.
It will be called with str.format(i) where i is the numeric index of the parameter.
    For example, Postgres requires numbered parameters like `$1`, `$2`, etc., so
    the format string should be `"${}"`.
See https://pypika.readthedocs.io/en/latest/2_tutorial.html#parametrized-queries for more
information on parameterized queries in PyPika.
"""
_context.values = []
_context.generator = count(1)
_context.formatstr = formatstr
sql = query.get_sql()
params = tuple(_context.values)
return sql, params

View File

@@ -0,0 +1,558 @@
from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
from overrides import overrides
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
create_collection_configuration_to_json_str,
UpdateCollectionConfiguration,
update_collection_configuration_to_json_str,
CollectionMetadata,
)
from chromadb.api.types import Schema
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
from chromadb.db.system import SysDB
from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError
from chromadb.proto.convert import (
from_proto_collection,
from_proto_segment,
to_proto_update_metadata,
to_proto_segment,
to_proto_segment_scope,
)
from chromadb.proto.coordinator_pb2 import (
CreateCollectionRequest,
CreateDatabaseRequest,
CreateSegmentRequest,
CreateTenantRequest,
CountCollectionsRequest,
CountCollectionsResponse,
DeleteCollectionRequest,
DeleteDatabaseRequest,
DeleteSegmentRequest,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionSizeRequest,
GetCollectionSizeResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
GetSegmentsRequest,
GetTenantRequest,
ListDatabasesRequest,
UpdateCollectionRequest,
UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
Collection,
CollectionAndSegments,
Database,
Metadata,
OptionalArgument,
Segment,
SegmentScope,
Tenant,
Unspecified,
UpdateMetadata,
)
from google.protobuf.empty_pb2 import Empty
import grpc
class GrpcSysDB(SysDB):
"""A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
called the 'Coordinator'. This implementation is used by Chroma frontend servers
to call a remote SysDB (Coordinator) service."""
_sys_db_stub: SysDBStub
_channel: grpc.Channel
_coordinator_url: str
_coordinator_port: int
_request_timeout_seconds: int
def __init__(self, system: System):
self._coordinator_url = system.settings.require("chroma_coordinator_host")
# TODO: break out coordinator_port into a separate setting?
self._coordinator_port = system.settings.require("chroma_server_grpc_port")
self._request_timeout_seconds = system.settings.require(
"chroma_sysdb_request_timeout_seconds"
)
return super().__init__(system)
@overrides
def start(self) -> None:
self._channel = grpc.insecure_channel(
f"{self._coordinator_url}:{self._coordinator_port}",
options=[("grpc.max_concurrent_streams", 1000)],
)
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
self._channel = grpc.intercept_channel(self._channel, *interceptors)
self._sys_db_stub = SysDBStub(self._channel) # type: ignore
return super().start()
@overrides
def stop(self) -> None:
self._channel.close()
return super().stop()
@overrides
def reset_state(self) -> None:
self._sys_db_stub.ResetState(Empty())
return super().reset_state()
@overrides
def create_database(
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
) -> None:
try:
request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
response = self._sys_db_stub.CreateDatabase(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.info(
f"Failed to create database name {name} and database id {id} for tenant {tenant} due to error: {e}"
)
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise UniqueConstraintError()
raise InternalError()
@overrides
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
try:
request = GetDatabaseRequest(name=name, tenant=tenant)
response = self._sys_db_stub.GetDatabase(
request, timeout=self._request_timeout_seconds
)
return Database(
id=UUID(hex=response.database.id),
name=response.database.name,
tenant=response.database.tenant,
)
except grpc.RpcError as e:
logger.info(
f"Failed to get database {name} for tenant {tenant} due to error: {e}"
)
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
raise InternalError()
@overrides
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
try:
request = DeleteDatabaseRequest(name=name, tenant=tenant)
self._sys_db_stub.DeleteDatabase(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.info(
f"Failed to delete database {name} for tenant {tenant} due to error: {e}"
)
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
raise InternalError
@overrides
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
try:
request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
response = self._sys_db_stub.ListDatabases(
request, timeout=self._request_timeout_seconds
)
results: List[Database] = []
for proto_database in response.databases:
results.append(
Database(
id=UUID(hex=proto_database.id),
name=proto_database.name,
tenant=proto_database.tenant,
)
)
return results
except grpc.RpcError as e:
logger.info(
f"Failed to list databases for tenant {tenant} due to error: {e}"
)
raise InternalError()
@overrides
def create_tenant(self, name: str) -> None:
try:
request = CreateTenantRequest(name=name)
response = self._sys_db_stub.CreateTenant(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.info(f"Failed to create tenant {name} due to error: {e}")
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise UniqueConstraintError()
raise InternalError()
@overrides
def get_tenant(self, name: str) -> Tenant:
try:
request = GetTenantRequest(name=name)
response = self._sys_db_stub.GetTenant(
request, timeout=self._request_timeout_seconds
)
return Tenant(
name=response.tenant.name,
)
except grpc.RpcError as e:
logger.info(f"Failed to get tenant {name} due to error: {e}")
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
raise InternalError()
@overrides
def create_segment(self, segment: Segment) -> None:
try:
proto_segment = to_proto_segment(segment)
request = CreateSegmentRequest(
segment=proto_segment,
)
response = self._sys_db_stub.CreateSegment(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.info(f"Failed to create segment {segment}, error: {e}")
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise UniqueConstraintError()
raise InternalError()
@overrides
def delete_segment(self, collection: UUID, id: UUID) -> None:
try:
request = DeleteSegmentRequest(
id=id.hex,
collection=collection.hex,
)
response = self._sys_db_stub.DeleteSegment(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.info(
f"Failed to delete segment with id {id} for collection {collection} due to error: {e}"
)
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
raise InternalError()
@overrides
def get_segments(
self,
collection: UUID,
id: Optional[UUID] = None,
type: Optional[str] = None,
scope: Optional[SegmentScope] = None,
) -> Sequence[Segment]:
try:
request = GetSegmentsRequest(
id=id.hex if id else None,
type=type,
scope=to_proto_segment_scope(scope) if scope else None,
collection=collection.hex,
)
response = self._sys_db_stub.GetSegments(
request, timeout=self._request_timeout_seconds
)
results: List[Segment] = []
for proto_segment in response.segments:
segment = from_proto_segment(proto_segment)
results.append(segment)
return results
except grpc.RpcError as e:
logger.info(
f"Failed to get segment id {id}, type {type}, scope {scope} for collection {collection} due to error: {e}"
)
raise InternalError()
@overrides
def update_segment(
self,
collection: UUID,
id: UUID,
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
try:
write_metadata = None
if metadata != Unspecified():
write_metadata = cast(Union[UpdateMetadata, None], metadata)
request = UpdateSegmentRequest(
id=id.hex,
collection=collection.hex,
metadata=to_proto_update_metadata(write_metadata)
if write_metadata
else None,
)
if metadata is None:
request.ClearField("metadata")
request.reset_metadata = True
self._sys_db_stub.UpdateSegment(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.info(
f"Failed to update segment with id {id} for collection {collection}, error: {e}"
)
raise InternalError()
@overrides
def create_collection(
self,
id: UUID,
name: str,
schema: Optional[Schema],
configuration: CreateCollectionConfiguration,
segments: Sequence[Segment],
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple[Collection, bool]:
try:
request = CreateCollectionRequest(
id=id.hex,
name=name,
configuration_json_str=create_collection_configuration_to_json_str(
configuration, cast(CollectionMetadata, metadata)
),
metadata=to_proto_update_metadata(metadata) if metadata else None,
dimension=dimension,
get_or_create=get_or_create,
tenant=tenant,
database=database,
segments=[to_proto_segment(segment) for segment in segments],
)
response = self._sys_db_stub.CreateCollection(
request, timeout=self._request_timeout_seconds
)
collection = from_proto_collection(response.collection)
return collection, response.created
except grpc.RpcError as e:
logger.error(
f"Failed to create collection id {id}, name {name} for database {database} and tenant {tenant} due to error: {e}"
)
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise UniqueConstraintError()
raise InternalError()
@overrides
def delete_collection(
self,
id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
try:
request = DeleteCollectionRequest(
id=id.hex,
tenant=tenant,
database=database,
)
response = self._sys_db_stub.DeleteCollection(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
logger.error(
f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
)
e = cast(grpc.Call, e)
logger.error(
f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
)
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
raise InternalError()
@overrides
def get_collections(
self,
id: Optional[UUID] = None,
name: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Collection]:
try:
# TODO: implement limit and offset in the gRPC service
request = None
if id is not None:
request = GetCollectionsRequest(
id=id.hex,
limit=limit,
offset=offset,
)
if name is not None:
if tenant is None and database is None:
raise ValueError(
"If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
)
request = GetCollectionsRequest(
name=name,
tenant=tenant,
database=database,
limit=limit,
offset=offset,
)
if id is None and name is None:
request = GetCollectionsRequest(
tenant=tenant,
database=database,
limit=limit,
offset=offset,
)
response: GetCollectionsResponse = self._sys_db_stub.GetCollections(
request, timeout=self._request_timeout_seconds
)
results: List[Collection] = []
for collection in response.collections:
results.append(from_proto_collection(collection))
return results
except grpc.RpcError as e:
logger.error(
f"Failed to get collections with id {id}, name {name}, tenant {tenant}, database {database} due to error: {e}"
)
raise InternalError()
@overrides
def count_collections(
self,
tenant: str = DEFAULT_TENANT,
database: Optional[str] = None,
) -> int:
try:
            if database is None or database == "":
                request = CountCollectionsRequest(tenant=tenant)
            else:
                request = CountCollectionsRequest(
                    tenant=tenant,
                    database=database,
                )
            response: CountCollectionsResponse = self._sys_db_stub.CountCollections(
                request
            )
            return response.count
except grpc.RpcError as e:
logger.error(f"Failed to count collections due to error: {e}")
raise InternalError()
@overrides
def get_collection_size(self, id: UUID) -> int:
try:
request = GetCollectionSizeRequest(id=id.hex)
response: GetCollectionSizeResponse = self._sys_db_stub.GetCollectionSize(
request
)
return response.total_records_post_compaction
except grpc.RpcError as e:
logger.error(f"Failed to get collection {id} size due to error: {e}")
raise InternalError()
@trace_method(
"SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION
)
@overrides
def get_collection_with_segments(
self, collection_id: UUID
) -> CollectionAndSegments:
try:
request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
response: GetCollectionWithSegmentsResponse = (
self._sys_db_stub.GetCollectionWithSegments(request)
)
return CollectionAndSegments(
collection=from_proto_collection(response.collection),
segments=[from_proto_segment(segment) for segment in response.segments],
)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
logger.error(
f"Failed to get collection {collection_id} and its segments due to error: {e}"
)
raise InternalError()
@overrides
def update_collection(
self,
id: UUID,
name: OptionalArgument[str] = Unspecified(),
dimension: OptionalArgument[Optional[int]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
configuration: OptionalArgument[
Optional[UpdateCollectionConfiguration]
] = Unspecified(),
) -> None:
try:
write_name = None
if name != Unspecified():
write_name = cast(str, name)
write_dimension = None
if dimension != Unspecified():
write_dimension = cast(Union[int, None], dimension)
write_metadata = None
if metadata != Unspecified():
write_metadata = cast(Union[UpdateMetadata, None], metadata)
write_configuration = None
if configuration != Unspecified():
write_configuration = cast(
Union[UpdateCollectionConfiguration, None], configuration
)
request = UpdateCollectionRequest(
id=id.hex,
name=write_name,
dimension=write_dimension,
metadata=to_proto_update_metadata(write_metadata)
if write_metadata
else None,
configuration_json_str=update_collection_configuration_to_json_str(
write_configuration
)
if write_configuration
else None,
)
if metadata is None:
request.ClearField("metadata")
request.reset_metadata = True
response = self._sys_db_stub.UpdateCollection(
request, timeout=self._request_timeout_seconds
)
except grpc.RpcError as e:
e = cast(grpc.Call, e)
logger.error(
f"Failed to update collection id {id}, name {name} due to error: {e}"
)
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise UniqueConstraintError()
raise InternalError()
def reset_and_wait_for_ready(self) -> None:
self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)
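# Illustrative settings sketch for pointing a frontend at a remote SysDB/Coordinator.
# The host and port are placeholders, and the implementation path is assumed from this
# file's location in the package:
#
#   Settings(
#       chroma_sysdb_impl="chromadb.db.impl.grpc.client.GrpcSysDB",
#       chroma_coordinator_host="sysdb.chroma.svc.cluster.local",
#       chroma_server_grpc_port=50051,
#   )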

View File

@@ -0,0 +1,497 @@
from concurrent import futures
from typing import Any, Dict, List, cast
from uuid import UUID
from overrides import overrides
import json
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
from chromadb.proto.convert import (
from_proto_metadata,
from_proto_update_metadata,
from_proto_segment,
from_proto_segment_scope,
to_proto_collection,
to_proto_segment,
)
import chromadb.proto.chroma_pb2 as proto
from chromadb.proto.coordinator_pb2 import (
CreateCollectionRequest,
CreateCollectionResponse,
CreateDatabaseRequest,
CreateDatabaseResponse,
CreateSegmentRequest,
CreateSegmentResponse,
CreateTenantRequest,
CreateTenantResponse,
CountCollectionsRequest,
CountCollectionsResponse,
DeleteCollectionRequest,
DeleteCollectionResponse,
DeleteSegmentRequest,
DeleteSegmentResponse,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionSizeRequest,
GetCollectionSizeResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
GetDatabaseResponse,
GetSegmentsRequest,
GetSegmentsResponse,
GetTenantRequest,
GetTenantResponse,
ResetStateResponse,
UpdateCollectionRequest,
UpdateCollectionResponse,
UpdateSegmentRequest,
UpdateSegmentResponse,
)
from chromadb.proto.coordinator_pb2_grpc import (
SysDBServicer,
add_SysDBServicer_to_server,
)
import grpc
from google.protobuf.empty_pb2 import Empty
from chromadb.types import Collection, Metadata, Segment, SegmentScope
class GrpcMockSysDB(SysDBServicer, Component):
"""A mock sysdb implementation that can be used for testing the grpc client. It stores
state in simple python data structures instead of a database."""
_server: grpc.Server
_server_port: int
_segments: Dict[str, Segment] = {}
_collection_to_segments: Dict[str, List[str]] = {}
_tenants_to_databases_to_collections: Dict[
str, Dict[str, Dict[str, Collection]]
] = {}
_tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {}
def __init__(self, system: System):
self._server_port = system.settings.require("chroma_server_grpc_port")
return super().__init__(system)
@overrides
def start(self) -> None:
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
add_SysDBServicer_to_server(self, self._server) # type: ignore
self._server.add_insecure_port(f"[::]:{self._server_port}")
self._server.start()
return super().start()
@overrides
def stop(self) -> None:
self._server.stop(None)
return super().stop()
@overrides
def reset_state(self) -> None:
self._segments = {}
self._tenants_to_databases_to_collections = {}
# Create defaults
self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {}
self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {}
self._tenants_to_database_to_id[DEFAULT_TENANT] = {}
self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0)
return super().reset_state()
@overrides(check_signature=False)
def CreateDatabase(
self, request: CreateDatabaseRequest, context: grpc.ServicerContext
) -> CreateDatabaseResponse:
tenant = request.tenant
database = request.name
if tenant not in self._tenants_to_databases_to_collections:
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
if database in self._tenants_to_databases_to_collections[tenant]:
context.abort(
grpc.StatusCode.ALREADY_EXISTS, f"Database {database} already exists"
)
self._tenants_to_databases_to_collections[tenant][database] = {}
self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
return CreateDatabaseResponse()
@overrides(check_signature=False)
def GetDatabase(
self, request: GetDatabaseRequest, context: grpc.ServicerContext
) -> GetDatabaseResponse:
tenant = request.tenant
database = request.name
if tenant not in self._tenants_to_databases_to_collections:
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
if database not in self._tenants_to_databases_to_collections[tenant]:
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
id = self._tenants_to_database_to_id[tenant][database]
return GetDatabaseResponse(
database=proto.Database(id=id.hex, name=database, tenant=tenant),
)
@overrides(check_signature=False)
def CreateTenant(
self, request: CreateTenantRequest, context: grpc.ServicerContext
) -> CreateTenantResponse:
tenant = request.name
if tenant in self._tenants_to_databases_to_collections:
context.abort(
grpc.StatusCode.ALREADY_EXISTS, f"Tenant {tenant} already exists"
)
self._tenants_to_databases_to_collections[tenant] = {}
self._tenants_to_database_to_id[tenant] = {}
return CreateTenantResponse()
@overrides(check_signature=False)
def GetTenant(
self, request: GetTenantRequest, context: grpc.ServicerContext
) -> GetTenantResponse:
tenant = request.name
if tenant not in self._tenants_to_databases_to_collections:
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
return GetTenantResponse(
tenant=proto.Tenant(name=tenant),
)
# We are forced to use check_signature=False because the generated proto code
# does not have type annotations for the request and response objects.
# TODO: investigate generating types for the request and response objects
@overrides(check_signature=False)
def CreateSegment(
self, request: CreateSegmentRequest, context: grpc.ServicerContext
) -> CreateSegmentResponse:
segment = from_proto_segment(request.segment)
return self.CreateSegmentHelper(segment, context)
def CreateSegmentHelper(
self, segment: Segment, context: grpc.ServicerContext
) -> CreateSegmentResponse:
if segment["id"].hex in self._segments:
context.abort(
grpc.StatusCode.ALREADY_EXISTS,
f"Segment {segment['id']} already exists",
)
self._segments[segment["id"].hex] = segment
return CreateSegmentResponse()
@overrides(check_signature=False)
def DeleteSegment(
self, request: DeleteSegmentRequest, context: grpc.ServicerContext
) -> DeleteSegmentResponse:
id_to_delete = request.id
if id_to_delete in self._segments:
del self._segments[id_to_delete]
return DeleteSegmentResponse()
else:
context.abort(
grpc.StatusCode.NOT_FOUND, f"Segment {id_to_delete} not found"
)
@overrides(check_signature=False)
def GetSegments(
self, request: GetSegmentsRequest, context: grpc.ServicerContext
) -> GetSegmentsResponse:
target_id = UUID(hex=request.id) if request.HasField("id") else None
target_type = request.type if request.HasField("type") else None
target_scope = (
from_proto_segment_scope(request.scope)
if request.HasField("scope")
else None
)
target_collection = UUID(hex=request.collection)
found_segments = []
for segment in self._segments.values():
if target_id and segment["id"] != target_id:
continue
if target_type and segment["type"] != target_type:
continue
if target_scope and segment["scope"] != target_scope:
continue
if target_collection and segment["collection"] != target_collection:
continue
found_segments.append(segment)
return GetSegmentsResponse(
segments=[to_proto_segment(segment) for segment in found_segments]
)
@overrides(check_signature=False)
def UpdateSegment(
self, request: UpdateSegmentRequest, context: grpc.ServicerContext
) -> UpdateSegmentResponse:
id_to_update = UUID(request.id)
if id_to_update.hex not in self._segments:
context.abort(
grpc.StatusCode.NOT_FOUND, f"Segment {id_to_update} not found"
)
else:
segment = self._segments[id_to_update.hex]
if request.HasField("metadata"):
target = cast(Dict[str, Any], segment["metadata"])
if segment["metadata"] is None:
segment["metadata"] = {}
self._merge_metadata(target, request.metadata)
if request.HasField("reset_metadata") and request.reset_metadata:
segment["metadata"] = {}
return UpdateSegmentResponse()
@overrides(check_signature=False)
def CreateCollection(
self, request: CreateCollectionRequest, context: grpc.ServicerContext
) -> CreateCollectionResponse:
collection_name = request.name
tenant = request.tenant
database = request.database
if tenant not in self._tenants_to_databases_to_collections:
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
if database not in self._tenants_to_databases_to_collections[tenant]:
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
# Check if the collection already exists globally by id
for (
search_tenant,
databases,
) in self._tenants_to_databases_to_collections.items():
for search_database, search_collections in databases.items():
if request.id in search_collections:
if (
search_tenant != request.tenant
or search_database != request.database
):
context.abort(
grpc.StatusCode.ALREADY_EXISTS,
f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
)
elif not request.get_or_create:
# If the id exists for this tenant and database, and we are not doing a get_or_create, then
# we should return an already exists error
context.abort(
grpc.StatusCode.ALREADY_EXISTS,
f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
)
# Check if the collection already exists in this database by name
collections = self._tenants_to_databases_to_collections[tenant][database]
matches = [c for c in collections.values() if c["name"] == collection_name]
assert len(matches) <= 1
if len(matches) > 0:
if request.get_or_create:
existing_collection = matches[0]
return CreateCollectionResponse(
collection=to_proto_collection(existing_collection),
created=False,
)
context.abort(
grpc.StatusCode.ALREADY_EXISTS,
f"Collection {collection_name} already exists",
)
configuration_json = json.loads(request.configuration_json_str)
id = UUID(hex=request.id)
new_collection = Collection(
id=id,
name=request.name,
configuration_json=configuration_json,
serialized_schema=None,
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
database=database,
tenant=tenant,
version=0,
)
# Check that segments are unique and do not already exist
# Keep a track of the segments that are being added
segments_added = []
# Create segments for the collection
for segment_proto in request.segments:
segment = from_proto_segment(segment_proto)
if segment["id"].hex in self._segments:
# Remove the already added segment since we need to roll back
for s in segments_added:
self.DeleteSegment(DeleteSegmentRequest(id=s), context)
context.abort(
grpc.StatusCode.ALREADY_EXISTS,
f"Segment {segment['id']} already exists",
)
self.CreateSegmentHelper(segment, context)
segments_added.append(segment["id"].hex)
collections[request.id] = new_collection
collection_unique_key = f"{tenant}:{database}:{request.id}"
self._collection_to_segments[collection_unique_key] = segments_added
return CreateCollectionResponse(
collection=to_proto_collection(new_collection),
created=True,
)
@overrides(check_signature=False)
def DeleteCollection(
self, request: DeleteCollectionRequest, context: grpc.ServicerContext
) -> DeleteCollectionResponse:
collection_id = request.id
tenant = request.tenant
database = request.database
if tenant not in self._tenants_to_databases_to_collections:
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
if database not in self._tenants_to_databases_to_collections[tenant]:
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
collections = self._tenants_to_databases_to_collections[tenant][database]
if collection_id in collections:
del collections[collection_id]
collection_unique_key = f"{tenant}:{database}:{collection_id}"
segment_ids = self._collection_to_segments[collection_unique_key]
if segment_ids: # Delete segments if provided.
for segment_id in segment_ids:
del self._segments[segment_id]
return DeleteCollectionResponse()
else:
context.abort(
grpc.StatusCode.NOT_FOUND, f"Collection {collection_id} not found"
)
@overrides(check_signature=False)
def GetCollections(
self, request: GetCollectionsRequest, context: grpc.ServicerContext
) -> GetCollectionsResponse:
target_id = UUID(hex=request.id) if request.HasField("id") else None
target_name = request.name if request.HasField("name") else None
allCollections = {}
for tenant, databases in self._tenants_to_databases_to_collections.items():
for database, collections in databases.items():
if request.tenant != "" and tenant != request.tenant:
continue
if request.database != "" and database != request.database:
continue
allCollections.update(collections)
print(
f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
)
found_collections = []
for collection in allCollections.values():
if target_id and collection["id"] != target_id:
continue
if target_name and collection["name"] != target_name:
continue
found_collections.append(collection)
return GetCollectionsResponse(
collections=[
to_proto_collection(collection) for collection in found_collections
]
)
@overrides(check_signature=False)
def CountCollections(
self, request: CountCollectionsRequest, context: grpc.ServicerContext
) -> CountCollectionsResponse:
request = GetCollectionsRequest(
tenant=request.tenant,
database=request.database,
)
collections = self.GetCollections(request, context)
return CountCollectionsResponse(count=len(collections.collections))
@overrides(check_signature=False)
def GetCollectionSize(
self, request: GetCollectionSizeRequest, context: grpc.ServicerContext
) -> GetCollectionSizeResponse:
return GetCollectionSizeResponse(
total_records_post_compaction=0,
)
@overrides(check_signature=False)
def GetCollectionWithSegments(
self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
) -> GetCollectionWithSegmentsResponse:
allCollections = {}
for tenant, databases in self._tenants_to_databases_to_collections.items():
for database, collections in databases.items():
allCollections.update(collections)
print(
f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
)
collection = allCollections.get(request.id, None)
if collection is None:
context.abort(
grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found"
)
collection_unique_key = (
f"{collection.tenant}:{collection.database}:{request.id}"
)
segments = [
self._segments[id]
for id in self._collection_to_segments[collection_unique_key]
]
if {segment["scope"] for segment in segments} != {
SegmentScope.METADATA,
SegmentScope.RECORD,
SegmentScope.VECTOR,
}:
context.abort(
grpc.StatusCode.INTERNAL,
f"Incomplete segments for collection {collection}: {segments}",
)
return GetCollectionWithSegmentsResponse(
collection=to_proto_collection(collection),
segments=[to_proto_segment(segment) for segment in segments],
)
@overrides(check_signature=False)
def UpdateCollection(
self, request: UpdateCollectionRequest, context: grpc.ServicerContext
) -> UpdateCollectionResponse:
id_to_update = UUID(request.id)
# Find the collection with this id
collections = {}
for tenant, databases in self._tenants_to_databases_to_collections.items():
for database, maybe_collections in databases.items():
if id_to_update.hex in maybe_collections:
collections = maybe_collections
if id_to_update.hex not in collections:
context.abort(
grpc.StatusCode.NOT_FOUND, f"Collection {id_to_update} not found"
)
else:
collection = collections[id_to_update.hex]
if request.HasField("name"):
collection["name"] = request.name
if request.HasField("dimension"):
collection["dimension"] = request.dimension
if request.HasField("metadata"):
# TODO: In SysDB SQLite we have technical debt where we
# replace the entire metadata dict with the new one. We should
# fix that by merging it. For now we just do the same thing here
update_metadata = from_proto_update_metadata(request.metadata)
cleaned_metadata = None
if update_metadata is not None:
cleaned_metadata = {}
for key, value in update_metadata.items():
if value is not None:
cleaned_metadata[key] = value
collection["metadata"] = cleaned_metadata
elif request.HasField("reset_metadata"):
if request.reset_metadata:
collection["metadata"] = {}
return UpdateCollectionResponse()
@overrides(check_signature=False)
def ResetState(
self, request: Empty, context: grpc.ServicerContext
) -> ResetStateResponse:
self.reset_state()
return ResetStateResponse()
def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
target_metadata = cast(Dict[str, Any], target)
source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source))
target_metadata.update(source_metadata)
# If a key has a None value, remove it from the metadata
for key, value in source_metadata.items():
if value is None and key in target:
del target_metadata[key]
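# --- Illustrative sketch (editor's addition, not part of the library) ---
# A hedged example of how this mock might be started standalone in a test via
# the Component system. It assumes the standard `chroma_server_grpc_port`
# setting; the port value below is made up for illustration and the exact
# Settings required may differ between releases.
if __name__ == "__main__":
    from chromadb.config import Settings, System

    system = System(Settings(chroma_server_grpc_port=50051))
    mock_sysdb = system.instance(GrpcMockSysDB)
    system.start()  # serves the SysDB gRPC API on the configured port
    print(f"Mock SysDB listening on port {mock_sysdb._server_port}")
    system.stop()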


@@ -0,0 +1,273 @@
import logging
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
from chromadb.db.migrations import MigratableDB, Migration
from chromadb.config import System, Settings
import chromadb.db.base as base
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
from chromadb.db.mixins.sysdb import SqlSysDB
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
import sqlite3
from overrides import override
import pypika
from typing import Sequence, cast, Optional, Type, Any
from typing_extensions import Literal
from types import TracebackType
import os
from uuid import UUID
from threading import local
from importlib_resources import files
from importlib_resources.abc import Traversable
logger = logging.getLogger(__name__)
class TxWrapper(base.TxWrapper):
_conn: Connection
_pool: Pool
def __init__(self, conn_pool: Pool, stack: local):
self._tx_stack = stack
self._conn = conn_pool.connect()
self._pool = conn_pool
@override
def __enter__(self) -> base.Cursor:
if len(self._tx_stack.stack) == 0:
self._conn.execute("PRAGMA case_sensitive_like = ON")
self._conn.execute("BEGIN;")
self._tx_stack.stack.append(self)
return self._conn.cursor() # type: ignore
@override
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Literal[False]:
self._tx_stack.stack.pop()
if len(self._tx_stack.stack) == 0:
if exc_type is None:
self._conn.commit()
else:
self._conn.rollback()
self._conn.cursor().close()
self._pool.return_to_pool(self._conn)
return False
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
_conn_pool: Pool
_settings: Settings
_migration_imports: Sequence[Traversable]
_db_file: str
_tx_stack: local
_is_persistent: bool
def __init__(self, system: System):
self._settings = system.settings
self._migration_imports = [
files("chromadb.migrations.embeddings_queue"),
files("chromadb.migrations.sysdb"),
files("chromadb.migrations.metadb"),
]
self._is_persistent = self._settings.require("is_persistent")
self._opentelemetry_client = system.require(OpenTelemetryClient)
if not self._is_persistent:
# In order to allow sqlite to be shared between multiple threads, we need to use a
# URI connection string with shared cache.
# See https://www.sqlite.org/sharedcache.html
# https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
self._db_file = "file::memory:?cache=shared"
self._conn_pool = LockPool(self._db_file, is_uri=True)
else:
self._db_file = (
self._settings.require("persist_directory") + "/chroma.sqlite3"
)
if not os.path.exists(self._db_file):
os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
self._conn_pool = PerThreadPool(self._db_file)
self._tx_stack = local()
super().__init__(system)
@trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL)
@override
def start(self) -> None:
super().start()
with self.tx() as cur:
cur.execute("PRAGMA foreign_keys = ON")
cur.execute("PRAGMA case_sensitive_like = ON")
self.initialize_migrations()
if (
# (don't attempt to access .config if migrations haven't been run)
self._settings.require("migrations") == "apply"
and self.config.get_parameter("automatically_purge").value is False
):
logger.warning(
"⚠️ It looks like you upgraded from a version below 0.5.6 and could benefit from vacuuming your database. Run chromadb utils vacuum --help for more information."
)
@trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL)
@override
def stop(self) -> None:
super().stop()
self._conn_pool.close()
@staticmethod
@override
def querybuilder() -> Type[pypika.Query]:
return pypika.Query # type: ignore
@staticmethod
@override
def parameter_format() -> str:
return "?"
@staticmethod
@override
def migration_scope() -> str:
return "sqlite"
@override
def migration_dirs(self) -> Sequence[Traversable]:
return self._migration_imports
@override
def tx(self) -> TxWrapper:
if not hasattr(self._tx_stack, "stack"):
self._tx_stack.stack = []
return TxWrapper(self._conn_pool, stack=self._tx_stack)
@trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL)
@override
def reset_state(self) -> None:
if not self._settings.require("allow_reset"):
raise ValueError(
"Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
)
with self.tx() as cur:
# Drop all tables
cur.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table'
"""
)
for row in cur.fetchall():
cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
self._conn_pool.close()
self.start()
super().reset_state()
@trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL)
@override
def setup_migrations(self) -> None:
with self.tx() as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS migrations (
dir TEXT NOT NULL,
version INTEGER NOT NULL,
filename TEXT NOT NULL,
sql TEXT NOT NULL,
hash TEXT NOT NULL,
PRIMARY KEY (dir, version)
)
"""
)
@trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL)
@override
def migrations_initialized(self) -> bool:
with self.tx() as cur:
cur.execute(
"""SELECT count(*) FROM sqlite_master
WHERE type='table' AND name='migrations'"""
)
if cur.fetchone()[0] == 0:
return False
else:
return True
@trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL)
@override
def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
with self.tx() as cur:
cur.execute(
"""
SELECT dir, version, filename, sql, hash
FROM migrations
WHERE dir = ?
ORDER BY version ASC
""",
(dir.name,),
)
migrations = []
for row in cur.fetchall():
found_dir = cast(str, row[0])
found_version = cast(int, row[1])
found_filename = cast(str, row[2])
found_sql = cast(str, row[3])
found_hash = cast(str, row[4])
migrations.append(
Migration(
dir=found_dir,
version=found_version,
filename=found_filename,
sql=found_sql,
hash=found_hash,
scope=self.migration_scope(),
)
)
return migrations
@override
def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
cur.executescript(migration["sql"])
cur.execute(
"""
INSERT INTO migrations (dir, version, filename, sql, hash)
VALUES (?, ?, ?, ?, ?)
""",
(
migration["dir"],
migration["version"],
migration["filename"],
migration["sql"],
migration["hash"],
),
)
@staticmethod
@override
def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
return UUID(value) if value is not None else None
@staticmethod
@override
def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
return str(uuid) if uuid is not None else None
@staticmethod
@override
def unique_constraint_error() -> Type[BaseException]:
return sqlite3.IntegrityError
def vacuum(self, timeout: int = 5) -> None:
"""Runs VACUUM on the database. `timeout` is the maximum time to wait for an exclusive lock in seconds."""
conn = self._conn_pool.connect()
conn.execute(f"PRAGMA busy_timeout = {int(timeout) * 1000}")
conn.execute("VACUUM")
conn.execute(
"""
INSERT INTO maintenance_log (operation, timestamp)
VALUES ('vacuum', CURRENT_TIMESTAMP)
"""
)
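# --- Illustrative sketch (editor's addition, not part of the library) ---
# A hedged example of constructing an in-memory SqliteDB the way tests
# typically do, via the Component system. This only assumes the default
# Settings plus allow_reset; the exact settings needed may differ between
# releases.
if __name__ == "__main__":
    from chromadb.config import Settings, System

    system = System(Settings(allow_reset=True))
    db = system.instance(SqliteDB)
    system.start()  # runs initialize_migrations() against the shared in-memory DB
    with db.tx() as cur:
        cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        print([row[0] for row in cur.fetchall()])  # migrated tables, e.g. 'collections'
    system.stop()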


@@ -0,0 +1,163 @@
import sqlite3
import weakref
from abc import ABC, abstractmethod
from typing import Any, Set
import threading
from overrides import override
from typing_extensions import Annotated
class Connection:
"""A threadpool connection that returns itself to the pool on close()"""
_pool: "Pool"
_db_file: str
_conn: sqlite3.Connection
def __init__(
self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
):
self._pool = pool
self._db_file = db_file
self._conn = sqlite3.connect(
db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs
) # type: ignore
self._conn.isolation_level = None # Handle commits explicitly
def execute(self, sql: str, parameters=...) -> sqlite3.Cursor: # type: ignore
if parameters is ...:
return self._conn.execute(sql)
return self._conn.execute(sql, parameters)
def commit(self) -> None:
self._conn.commit()
def rollback(self) -> None:
self._conn.rollback()
def cursor(self) -> sqlite3.Cursor:
return self._conn.cursor()
def close_actual(self) -> None:
"""Actually closes the connection to the db"""
self._conn.close()
class Pool(ABC):
"""Abstract base class for a pool of connections to a sqlite database."""
@abstractmethod
def __init__(self, db_file: str, is_uri: bool) -> None:
pass
@abstractmethod
def connect(self, *args: Any, **kwargs: Any) -> Connection:
"""Return a connection from the pool."""
pass
@abstractmethod
def close(self) -> None:
"""Close all connections in the pool."""
pass
@abstractmethod
def return_to_pool(self, conn: Connection) -> None:
"""Return a connection to the pool."""
pass
class LockPool(Pool):
"""A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time.
This is used because sqlite does not support multithreaded access with connection timeouts when using the
shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
"""
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
_lock: threading.RLock
_connection: threading.local
_db_file: str
_is_uri: bool
def __init__(self, db_file: str, is_uri: bool = False):
self._connections = set()
self._connection = threading.local()
self._lock = threading.RLock()
self._db_file = db_file
self._is_uri = is_uri
@override
def connect(self, *args: Any, **kwargs: Any) -> Connection:
self._lock.acquire()
if hasattr(self._connection, "conn") and self._connection.conn is not None:
return self._connection.conn # type: ignore # cast doesn't work here for some reason
else:
new_connection = Connection(
self, self._db_file, self._is_uri, *args, **kwargs
)
self._connection.conn = new_connection
self._connections.add(weakref.ref(new_connection))
return new_connection
@override
def return_to_pool(self, conn: Connection) -> None:
try:
self._lock.release()
except RuntimeError:
pass
@override
def close(self) -> None:
for conn in self._connections:
if conn() is not None:
conn().close_actual() # type: ignore
self._connections.clear()
self._connection = threading.local()
try:
self._lock.release()
except RuntimeError:
pass
class PerThreadPool(Pool):
"""Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be
extended to do so and block on connect() if the cap is reached.
"""
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
_lock: threading.Lock
_connection: threading.local
_db_file: str
_is_uri: bool
def __init__(self, db_file: str, is_uri: bool = False):
self._connections = set()
self._connection = threading.local()
self._lock = threading.Lock()
self._db_file = db_file
self._is_uri = is_uri
@override
def connect(self, *args: Any, **kwargs: Any) -> Connection:
if hasattr(self._connection, "conn") and self._connection.conn is not None:
return self._connection.conn # type: ignore # cast doesn't work here for some reason
else:
new_connection = Connection(
self, self._db_file, self._is_uri, *args, **kwargs
)
self._connection.conn = new_connection
with self._lock:
self._connections.add(weakref.ref(new_connection))
return new_connection
@override
def close(self) -> None:
with self._lock:
for conn in self._connections:
if conn() is not None:
conn().close_actual() # type: ignore
self._connections.clear()
self._connection = threading.local()
@override
def return_to_pool(self, conn: Connection) -> None:
pass # Each thread gets its own connection, so we don't need to return it to the pool
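# --- Illustrative sketch (editor's addition, not part of the library) ---
# A hedged, minimal example of exercising the pools above: PerThreadPool hands
# each thread its own sqlite3 connection, while LockPool serializes access to a
# single shared-cache connection. Uses only the classes in this module and the
# standard library; the table and values are made up for illustration.
if __name__ == "__main__":
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        db_file = os.path.join(tmp, "example.sqlite3")
        pool = PerThreadPool(db_file)
        conn = pool.connect()
        conn.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)")
        conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", ("greeting", "hello"))
        conn.commit()
        row = conn.execute("SELECT v FROM kv WHERE k = ?", ("greeting",)).fetchone()
        print(row[0])  # -> hello
        pool.return_to_pool(conn)  # no-op for PerThreadPool; releases the lock for LockPool
        pool.close()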


@@ -0,0 +1,276 @@
import sys
from typing import Sequence
from typing_extensions import TypedDict, NotRequired
from importlib_resources.abc import Traversable
import re
import hashlib
from chromadb.db.base import SqlDB, Cursor
from abc import abstractmethod
from chromadb.config import System, Settings
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
class MigrationFile(TypedDict):
path: NotRequired[Traversable]
dir: str
filename: str
version: int
scope: str
class Migration(MigrationFile):
hash: str
sql: str
class UninitializedMigrationsError(Exception):
def __init__(self) -> None:
super().__init__("Migrations have not been initialized")
class UnappliedMigrationsError(Exception):
def __init__(self, dir: str, version: int):
self.dir = dir
self.version = version
super().__init__(
f"Unapplied migrations in {dir}, starting with version {version}"
)
class InconsistentVersionError(Exception):
def __init__(self, dir: str, db_version: int, source_version: int):
super().__init__(
f"Inconsistent migration versions in {dir}:"
+ f"db version was {db_version}, source version was {source_version}."
+ " Has the migration sequence been modified since being applied to the DB?"
)
class InconsistentHashError(Exception):
def __init__(self, path: str, db_hash: str, source_hash: str):
super().__init__(
f"Inconsistent hashes in {path}:"
+ f"db hash was {db_hash}, source has was {source_hash}."
+ " Was the migration file modified after being applied to the DB?"
)
class InvalidHashError(Exception):
def __init__(self, alg: str):
super().__init__(f"Invalid hash algorithm specified: {alg}")
class InvalidMigrationFilename(Exception):
pass
class MigratableDB(SqlDB):
"""Simple base class for databases which support basic migrations.
Migrations are SQL files stored as package resources and accessed via
importlib_resources.
All migrations in the same directory are assumed to be dependent on previous
migrations in the same directory, where "previous" is defined by lexicographic
ordering of filenames.
Migrations have an ascending numeric version number and a hash of the file contents.
When migrations are applied, the hashes of previous migrations are checked to ensure
that the database is consistent with the source repository. If they are not, an
error is thrown and no migrations will be applied.
Migration files must follow the naming convention:
<version>-<description>.<scope>.sql, where <version> is a 5-digit zero-padded
integer, <description> is a short textual description, and <scope> is a short string
identifying the database implementation (see the illustrative example at the end of
this module).
"""
_settings: Settings
def __init__(self, system: System) -> None:
self._settings = system.settings
self._opentelemetry_client = system.require(OpenTelemetryClient)
super().__init__(system)
@staticmethod
@abstractmethod
def migration_scope() -> str:
"""The database implementation to use for migrations (e.g, sqlite, pgsql)"""
pass
@abstractmethod
def migration_dirs(self) -> Sequence[Traversable]:
"""Directories containing the migration sequences that should be applied to this
DB."""
pass
@abstractmethod
def setup_migrations(self) -> None:
"""Idempotently creates the migrations table"""
pass
@abstractmethod
def migrations_initialized(self) -> bool:
"""Return true if the migrations table exists"""
pass
@abstractmethod
def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
"""Return a list of all migrations already applied to this database, from the
given source directory, in ascending order."""
pass
@abstractmethod
def apply_migration(self, cur: Cursor, migration: Migration) -> None:
"""Apply a single migration to the database"""
pass
def initialize_migrations(self) -> None:
"""Initialize migrations for this DB"""
migrate = self._settings.require("migrations")
if migrate == "validate":
self.validate_migrations()
if migrate == "apply":
self.apply_migrations()
@trace_method("MigratableDB.validate_migrations", OpenTelemetryGranularity.ALL)
def validate_migrations(self) -> None:
"""Validate all migrations and throw an exception if there are any unapplied
migrations in the source repo."""
if not self.migrations_initialized():
raise UninitializedMigrationsError()
for dir in self.migration_dirs():
db_migrations = self.db_migrations(dir)
source_migrations = find_migrations(
dir,
self.migration_scope(),
self._settings.require("migrations_hash_algorithm"),
)
unapplied_migrations = verify_migration_sequence(
db_migrations, source_migrations
)
if len(unapplied_migrations) > 0:
version = unapplied_migrations[0]["version"]
raise UnappliedMigrationsError(dir=dir.name, version=version)
@trace_method("MigratableDB.apply_migrations", OpenTelemetryGranularity.ALL)
def apply_migrations(self) -> None:
"""Validate existing migrations, and apply all new ones."""
self.setup_migrations()
for dir in self.migration_dirs():
db_migrations = self.db_migrations(dir)
source_migrations = find_migrations(
dir,
self.migration_scope(),
self._settings.require("migrations_hash_algorithm"),
)
unapplied_migrations = verify_migration_sequence(
db_migrations, source_migrations
)
with self.tx() as cur:
for migration in unapplied_migrations:
self.apply_migration(cur, migration)
# Format is <version>-<name>.<scope>.sql
# e.g., 00001-users.sqlite.sql
filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql")
def _parse_migration_filename(
dir: str, filename: str, path: Traversable
) -> MigrationFile:
"""Parse a migration filename into a MigrationFile object"""
match = filename_regex.match(filename)
if match is None:
raise InvalidMigrationFilename("Invalid migration filename: " + filename)
version, _, scope = match.groups()
return {
"path": path,
"dir": dir,
"filename": filename,
"version": int(version),
"scope": scope,
}
def verify_migration_sequence(
db_migrations: Sequence[Migration],
source_migrations: Sequence[Migration],
) -> Sequence[Migration]:
"""Given a list of migrations already applied to a database, and a list of
migrations from the source code, validate that the applied migrations are correct
and match the expected migrations.
Throws an exception if any migrations are missing, out of order, or if the source
hash does not match.
Returns a list of all unapplied migrations, or an empty list if all migrations are
applied and the database is up to date."""
for db_migration, source_migration in zip(db_migrations, source_migrations):
if db_migration["version"] != source_migration["version"]:
raise InconsistentVersionError(
dir=db_migration["dir"],
db_version=db_migration["version"],
source_version=source_migration["version"],
)
if db_migration["hash"] != source_migration["hash"]:
raise InconsistentHashError(
path=db_migration["dir"] + "/" + db_migration["filename"],
db_hash=db_migration["hash"],
source_hash=source_migration["hash"],
)
return source_migrations[len(db_migrations) :]
def find_migrations(
dir: Traversable, scope: str, hash_alg: str = "md5"
) -> Sequence[Migration]:
"""Return a list of all migration present in the given directory, in ascending
order. Filter by scope."""
files = [
_parse_migration_filename(dir.name, t.name, t)
for t in dir.iterdir()
if t.name.endswith(".sql")
]
files = list(filter(lambda f: f["scope"] == scope, files))
files = sorted(files, key=lambda f: f["version"])
return [_read_migration_file(f, hash_alg) for f in files]
def _read_migration_file(file: MigrationFile, hash_alg: str) -> Migration:
"""Read a migration file"""
if "path" not in file or not file["path"].is_file():
raise FileNotFoundError(
f"No migration file found for dir {file['dir']} with filename {file['filename']} and scope {file['scope']} at version {file['version']}"
)
sql = file["path"].read_text()
if hash_alg == "md5":
hash = (
hashlib.md5(sql.encode("utf-8"), usedforsecurity=False).hexdigest()
if sys.version_info >= (3, 9)
else hashlib.md5(sql.encode("utf-8")).hexdigest()
)
elif hash_alg == "sha256":
hash = hashlib.sha256(sql.encode("utf-8")).hexdigest()
else:
raise InvalidHashError(alg=hash_alg)
return {
"hash": hash,
"sql": sql,
"dir": file["dir"],
"filename": file["filename"],
"version": file["version"],
"scope": file["scope"],
}
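# --- Illustrative sketch (editor's addition, not part of the library) ---
# A hedged example of the migration filename convention enforced by
# filename_regex above. The filename below is made up purely for illustration.
if __name__ == "__main__":
    example_name = "00001-collections.sqlite.sql"
    match = filename_regex.match(example_name)
    assert match is not None
    version, description, scope = match.groups()
    print(int(version), description, scope)  # -> 1 collections sqlite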


@@ -0,0 +1,507 @@
from functools import cached_property
import json
from chromadb.api.configuration import (
ConfigurationParameter,
EmbeddingsQueueConfigurationInternal,
)
from chromadb.db.base import SqlDB, ParameterValue, get_sql
from chromadb.errors import BatchSizeExceededError
from chromadb.ingest import (
Producer,
Consumer,
ConsumerCallbackFn,
decode_vector,
encode_vector,
)
from chromadb.types import (
OperationRecord,
LogRecord,
ScalarEncoding,
SeqId,
Operation,
)
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from overrides import override
from collections import defaultdict
from typing import Sequence, Optional, Dict, Set, Tuple, cast
from uuid import UUID
from pypika import Table, functions
import uuid
import logging
from chromadb.ingest.impl.utils import create_topic_name
logger = logging.getLogger(__name__)
_operation_codes = {
Operation.ADD: 0,
Operation.UPDATE: 1,
Operation.UPSERT: 2,
Operation.DELETE: 3,
}
_operation_codes_inv = {v: k for k, v in _operation_codes.items()}
# Set in conftest.py to rethrow errors in the "async" path during testing
# https://doc.pytest.org/en/latest/example/simple.html#detect-if-running-from-within-a-pytest-run
_called_from_test = False
class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
"""A SQL database that stores embeddings, allowing a traditional RDBMS to be used as
the primary ingest queue and satisfying the top level Producer/Consumer interfaces.
Note that this class is only suitable for use cases where the producer and consumer
are in the same process.
This is because notification of new embeddings happens solely in-process: this
implementation does not actively listen to the database for new records added by
other processes.
"""
class Subscription:
id: UUID
topic_name: str
start: int
end: int
callback: ConsumerCallbackFn
def __init__(
self,
id: UUID,
topic_name: str,
start: int,
end: int,
callback: ConsumerCallbackFn,
):
self.id = id
self.topic_name = topic_name
self.start = start
self.end = end
self.callback = callback
_subscriptions: Dict[str, Set[Subscription]]
_max_batch_size: Optional[int]
_tenant: str
_topic_namespace: str
# How many variables are in the insert statement for a single record
VARIABLES_PER_RECORD = 6
def __init__(self, system: System):
self._subscriptions = defaultdict(set)
self._max_batch_size = None
self._opentelemetry_client = system.require(OpenTelemetryClient)
self._tenant = system.settings.require("tenant_id")
self._topic_namespace = system.settings.require("topic_namespace")
super().__init__(system)
@trace_method("SqlEmbeddingsQueue.reset_state", OpenTelemetryGranularity.ALL)
@override
def reset_state(self) -> None:
super().reset_state()
self._subscriptions = defaultdict(set)
# Invalidate the cached property
try:
del self.config
except AttributeError:
# Cached property hasn't been accessed yet
pass
@trace_method("SqlEmbeddingsQueue.delete_topic", OpenTelemetryGranularity.ALL)
@override
def delete_log(self, collection_id: UUID) -> None:
topic_name = create_topic_name(
self._tenant, self._topic_namespace, collection_id
)
t = Table("embeddings_queue")
q = (
self.querybuilder()
.from_(t)
.where(t.topic == ParameterValue(topic_name))
.delete()
)
with self.tx() as cur:
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
@trace_method("SqlEmbeddingsQueue.purge_log", OpenTelemetryGranularity.ALL)
@override
def purge_log(self, collection_id: UUID) -> None:
# (We need to purge on a per topic/collection basis, because the maximum sequence ID is tracked on a per topic/collection basis.)
segments_t = Table("segments")
segment_ids_q = (
self.querybuilder()
.from_(segments_t)
# This coalesce prevents a correctness bug when > 1 segments exist and:
# - > 1 has written to the max_seq_id table
# - > 1 has never written to the max_seq_id table
# In that case, we should not delete any WAL entries as we can't be sure that all segments are caught up.
.select(functions.Coalesce(Table("max_seq_id").seq_id, -1))
.where(
segments_t.collection == ParameterValue(self.uuid_to_db(collection_id))
)
.left_join(Table("max_seq_id"))
.on(segments_t.id == Table("max_seq_id").segment_id)
)
topic_name = create_topic_name(
self._tenant, self._topic_namespace, collection_id
)
with self.tx() as cur:
sql, params = get_sql(segment_ids_q, self.parameter_format())
cur.execute(sql, params)
results = cur.fetchall()
if results:
min_seq_id = min(row[0] for row in results)
else:
return
t = Table("embeddings_queue")
q = (
self.querybuilder()
.from_(t)
.where(t.seq_id < ParameterValue(min_seq_id))
.where(t.topic == ParameterValue(topic_name))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
@trace_method("SqlEmbeddingsQueue.submit_embedding", OpenTelemetryGranularity.ALL)
@override
def submit_embedding(
self, collection_id: UUID, embedding: OperationRecord
) -> SeqId:
if not self._running:
raise RuntimeError("Component not running")
return self.submit_embeddings(collection_id, [embedding])[0]
@trace_method("SqlEmbeddingsQueue.submit_embeddings", OpenTelemetryGranularity.ALL)
@override
def submit_embeddings(
self, collection_id: UUID, embeddings: Sequence[OperationRecord]
) -> Sequence[SeqId]:
if not self._running:
raise RuntimeError("Component not running")
if len(embeddings) == 0:
return []
if len(embeddings) > self.max_batch_size:
raise BatchSizeExceededError(
f"""
Cannot submit more than {self.max_batch_size:,} embeddings at once.
Please submit your embeddings in batches of size
{self.max_batch_size:,} or less.
"""
)
# This creates the persisted configuration if it doesn't exist.
# It should be run as soon as possible (before any WAL mutations) since the default configuration depends on the WAL size.
# (We can't run this in __init__()/start() because the migrations have not been run at that point and the table may not be available.)
_ = self.config
topic_name = create_topic_name(
self._tenant, self._topic_namespace, collection_id
)
t = Table("embeddings_queue")
insert = (
self.querybuilder()
.into(t)
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
)
id_to_idx: Dict[str, int] = {}
for embedding in embeddings:
(
embedding_bytes,
encoding,
metadata,
) = self._prepare_vector_encoding_metadata(embedding)
insert = insert.insert(
ParameterValue(_operation_codes[embedding["operation"]]),
ParameterValue(topic_name),
ParameterValue(embedding["id"]),
ParameterValue(embedding_bytes),
ParameterValue(encoding),
ParameterValue(metadata),
)
id_to_idx[embedding["id"]] = len(id_to_idx)
with self.tx() as cur:
sql, params = get_sql(insert, self.parameter_format())
# The RETURNING clause does not guarantee order, so we need to reorder
# the results. https://www.sqlite.org/lang_returning.html
sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING
results = cur.execute(sql, params).fetchall()
# Reorder the results
seq_ids = [cast(SeqId, None)] * len(
results
) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list
embedding_records = []
for seq_id, id in results:
seq_ids[id_to_idx[id]] = seq_id
submit_embedding_record = embeddings[id_to_idx[id]]
# We allow notifying consumers out of order relative to one call to
# submit_embeddings so we do not reorder the records before submitting them
embedding_record = LogRecord(
log_offset=seq_id,
record=OperationRecord(
id=id,
embedding=submit_embedding_record["embedding"],
encoding=submit_embedding_record["encoding"],
metadata=submit_embedding_record["metadata"],
operation=submit_embedding_record["operation"],
),
)
embedding_records.append(embedding_record)
self._notify_all(topic_name, embedding_records)
if self.config.get_parameter("automatically_purge").value:
self.purge_log(collection_id)
return seq_ids
@trace_method("SqlEmbeddingsQueue.subscribe", OpenTelemetryGranularity.ALL)
@override
def subscribe(
self,
collection_id: UUID,
consume_fn: ConsumerCallbackFn,
start: Optional[SeqId] = None,
end: Optional[SeqId] = None,
id: Optional[UUID] = None,
) -> UUID:
if not self._running:
raise RuntimeError("Component not running")
topic_name = create_topic_name(
self._tenant, self._topic_namespace, collection_id
)
subscription_id = id or uuid.uuid4()
start, end = self._validate_range(start, end)
subscription = self.Subscription(
subscription_id, topic_name, start, end, consume_fn
)
# Backfill first, so if it errors we do not add the subscription
self._backfill(subscription)
self._subscriptions[topic_name].add(subscription)
return subscription_id
@trace_method("SqlEmbeddingsQueue.unsubscribe", OpenTelemetryGranularity.ALL)
@override
def unsubscribe(self, subscription_id: UUID) -> None:
for topic_name, subscriptions in self._subscriptions.items():
for subscription in subscriptions:
if subscription.id == subscription_id:
subscriptions.remove(subscription)
if len(subscriptions) == 0:
del self._subscriptions[topic_name]
return
@override
def min_seqid(self) -> SeqId:
return -1
@override
def max_seqid(self) -> SeqId:
return 2**63 - 1
@property
@trace_method("SqlEmbeddingsQueue.max_batch_size", OpenTelemetryGranularity.ALL)
@override
def max_batch_size(self) -> int:
if self._max_batch_size is None:
with self.tx() as cur:
cur.execute("PRAGMA compile_options;")
compile_options = cur.fetchall()
for option in compile_options:
if "MAX_VARIABLE_NUMBER" in option[0]:
# The pragma returns a string like 'MAX_VARIABLE_NUMBER=999'
self._max_batch_size = int(option[0].split("=")[1]) // (
self.VARIABLES_PER_RECORD
)
if self._max_batch_size is None:
# This value is the default for sqlite3 versions < 3.32.0
# It is the safest value to use if we can't find the pragma for some
# reason
self._max_batch_size = 999 // self.VARIABLES_PER_RECORD
return self._max_batch_size
@trace_method(
"SqlEmbeddingsQueue._prepare_vector_encoding_metadata",
OpenTelemetryGranularity.ALL,
)
def _prepare_vector_encoding_metadata(
self, embedding: OperationRecord
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
if embedding["embedding"] is not None:
encoding_type = cast(ScalarEncoding, embedding["encoding"])
encoding = encoding_type.value
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
else:
embedding_bytes = None
encoding = None
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
return embedding_bytes, encoding, metadata
@trace_method("SqlEmbeddingsQueue._backfill", OpenTelemetryGranularity.ALL)
def _backfill(self, subscription: Subscription) -> None:
"""Backfill the given subscription with any currently matching records in the
DB"""
t = Table("embeddings_queue")
q = (
self.querybuilder()
.from_(t)
.where(t.topic == ParameterValue(subscription.topic_name))
.where(t.seq_id > ParameterValue(subscription.start))
.where(t.seq_id <= ParameterValue(subscription.end))
.select(t.seq_id, t.operation, t.id, t.vector, t.encoding, t.metadata)
.orderby(t.seq_id)
)
with self.tx() as cur:
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
rows = cur.fetchall()
for row in rows:
if row[3]:
encoding = ScalarEncoding(row[4])
vector = decode_vector(row[3], encoding)
else:
encoding = None
vector = None
self._notify_one(
subscription,
[
LogRecord(
log_offset=row[0],
record=OperationRecord(
operation=_operation_codes_inv[row[1]],
id=row[2],
embedding=vector,
encoding=encoding,
metadata=json.loads(row[5]) if row[5] else None,
),
)
],
)
@trace_method("SqlEmbeddingsQueue._validate_range", OpenTelemetryGranularity.ALL)
def _validate_range(
self, start: Optional[SeqId], end: Optional[SeqId]
) -> Tuple[int, int]:
"""Validate and normalize the start and end SeqIDs for a subscription using this
impl."""
start = start or self._next_seq_id()
end = end or self.max_seqid()
if not isinstance(start, int) or not isinstance(end, int):
raise TypeError("SeqIDs must be integers for sql-based EmbeddingsDB")
if start >= end:
raise ValueError(f"Invalid SeqID range: {start} to {end}")
return start, end
@trace_method("SqlEmbeddingsQueue._next_seq_id", OpenTelemetryGranularity.ALL)
def _next_seq_id(self) -> int:
"""Get the next SeqID for this database."""
t = Table("embeddings_queue")
q = self.querybuilder().from_(t).select(functions.Max(t.seq_id))
with self.tx() as cur:
cur.execute(q.get_sql())
return int(cur.fetchone()[0]) + 1
@trace_method("SqlEmbeddingsQueue._notify_all", OpenTelemetryGranularity.ALL)
def _notify_all(self, topic: str, embeddings: Sequence[LogRecord]) -> None:
"""Send a notification to each subscriber of the given topic."""
if self._running:
for sub in self._subscriptions[topic]:
self._notify_one(sub, embeddings)
@trace_method("SqlEmbeddingsQueue._notify_one", OpenTelemetryGranularity.ALL)
def _notify_one(self, sub: Subscription, embeddings: Sequence[LogRecord]) -> None:
"""Send a notification to a single subscriber."""
# Filter out any embeddings that are not in the subscription range
should_unsubscribe = False
filtered_embeddings = []
for embedding in embeddings:
if embedding["log_offset"] <= sub.start:
continue
if embedding["log_offset"] > sub.end:
should_unsubscribe = True
break
filtered_embeddings.append(embedding)
# Log errors instead of throwing them to preserve async semantics
# for consistency between local and distributed configurations
try:
if len(filtered_embeddings) > 0:
sub.callback(filtered_embeddings)
if should_unsubscribe:
self.unsubscribe(sub.id)
except BaseException as e:
logger.error(
f"Exception occurred invoking consumer for subscription {sub.id.hex}"
+ f"to topic {sub.topic_name} %s",
str(e),
)
if _called_from_test:
raise e
@cached_property
def config(self) -> EmbeddingsQueueConfigurationInternal:
t = Table("embeddings_queue_config")
q = self.querybuilder().from_(t).select(t.config_json_str).limit(1)
with self.tx() as cur:
cur.execute(q.get_sql())
result = cur.fetchone()
if result is None:
is_fresh_system = self._get_wal_size() == 0
config = EmbeddingsQueueConfigurationInternal(
[ConfigurationParameter("automatically_purge", is_fresh_system)]
)
self.set_config(config)
return config
return EmbeddingsQueueConfigurationInternal.from_json_str(result[0])
def set_config(self, config: EmbeddingsQueueConfigurationInternal) -> None:
with self.tx() as cur:
cur.execute(
"""
INSERT OR REPLACE INTO embeddings_queue_config (id, config_json_str)
VALUES (?, ?)
""",
(
1,
config.to_json_str(),
),
)
# Invalidate the cached property
try:
del self.config
except AttributeError:
# Cached property hasn't been accessed yet
pass
def _get_wal_size(self) -> int:
t = Table("embeddings_queue")
q = self.querybuilder().from_(t).select(functions.Count("*"))
with self.tx() as cur:
cur.execute(q.get_sql())
return int(cur.fetchone()[0])
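# --- Illustrative sketch (editor's addition, not part of the library) ---
# A hedged example of the vector encode/decode round trip that
# _prepare_vector_encoding_metadata() and _backfill() rely on above.
# The vector values are made up for illustration and chosen to be exactly
# representable as float32.
if __name__ == "__main__":
    vector = [0.25, 0.5, 0.75]
    raw = encode_vector(vector, ScalarEncoding.FLOAT32)
    assert isinstance(raw, bytes)
    restored = decode_vector(raw, ScalarEncoding.FLOAT32)
    print([float(x) for x in restored])  # -> [0.25, 0.5, 0.75]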


@@ -0,0 +1,986 @@
import logging
import sys
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
from uuid import UUID
from overrides import override
from pypika import Table, Column
from itertools import groupby
from chromadb.api.types import Schema
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
from chromadb.db.base import Cursor, SqlDB, ParameterValue, get_sql
from chromadb.db.system import SysDB
from chromadb.errors import (
NotFoundError,
UniqueConstraintError,
)
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.ingest import Producer
from chromadb.types import (
CollectionAndSegments,
Database,
OptionalArgument,
Segment,
Metadata,
Collection,
SegmentScope,
Tenant,
Unspecified,
UpdateMetadata,
)
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
create_collection_configuration_to_json_str,
load_collection_configuration_from_json_str,
CollectionConfiguration,
create_collection_configuration_to_json,
collection_configuration_to_json,
collection_configuration_to_json_str,
overwrite_collection_configuration,
update_collection_configuration_from_legacy_update_metadata,
CollectionMetadata,
)
logger = logging.getLogger(__name__)
class SqlSysDB(SqlDB, SysDB):
# Used only to delete log streams on collection deletion.
# TODO: refactor to remove this dependency into a separate interface
_producer: Producer
def __init__(self, system: System):
super().__init__(system)
self._opentelemetry_client = system.require(OpenTelemetryClient)
@trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL)
@override
def start(self) -> None:
super().start()
self._producer = self._system.instance(Producer)
@override
def create_database(
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
) -> None:
with self.tx() as cur:
# Get the tenant id for the tenant name and then insert the database with the id, name and tenant id
databases = Table("databases")
tenants = Table("tenants")
insert_database = (
self.querybuilder()
.into(databases)
.columns(databases.id, databases.name, databases.tenant_id)
.insert(
ParameterValue(self.uuid_to_db(id)),
ParameterValue(name),
self.querybuilder()
.select(tenants.id)
.from_(tenants)
.where(tenants.id == ParameterValue(tenant)),
)
)
sql, params = get_sql(insert_database, self.parameter_format())
try:
cur.execute(sql, params)
except self.unique_constraint_error() as e:
raise UniqueConstraintError(
f"Database {name} already exists for tenant {tenant}"
) from e
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
with self.tx() as cur:
databases = Table("databases")
q = (
self.querybuilder()
.from_(databases)
.select(databases.id, databases.name)
.where(databases.name == ParameterValue(name))
.where(databases.tenant_id == ParameterValue(tenant))
)
sql, params = get_sql(q, self.parameter_format())
row = cur.execute(sql, params).fetchone()
if not row:
raise NotFoundError(
f"Database {name} not found for tenant {tenant}. Are you sure it exists?"
)
if row[0] is None:
raise NotFoundError(
f"Database {name} not found for tenant {tenant}. Are you sure it exists?"
)
id: UUID = cast(UUID, self.uuid_from_db(row[0]))
return Database(
id=id,
name=row[1],
tenant=tenant,
)
@override
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
with self.tx() as cur:
databases = Table("databases")
q = (
self.querybuilder()
.from_(databases)
.where(databases.name == ParameterValue(name))
.where(databases.tenant_id == ParameterValue(tenant))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if not result:
raise NotFoundError(f"Database {name} not found for tenant {tenant}")
# As of 01/09/2025, cascading deletes don't work because foreign keys are not enabled.
# See https://github.com/chroma-core/chroma/issues/3456.
collections = Table("collections")
q = (
self.querybuilder()
.from_(collections)
.where(collections.database_id == ParameterValue(result[0]))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
with self.tx() as cur:
databases = Table("databases")
q = (
self.querybuilder()
.from_(databases)
.select(databases.id, databases.name)
.where(databases.tenant_id == ParameterValue(tenant))
.offset(offset)
.limit(
sys.maxsize if limit is None else limit
) # SQLite requires that a limit is provided to use offset
.orderby(databases.created_at)
)
sql, params = get_sql(q, self.parameter_format())
rows = cur.execute(sql, params).fetchall()
return [
Database(
id=cast(UUID, self.uuid_from_db(row[0])),
name=row[1],
tenant=tenant,
)
for row in rows
]
@override
def create_tenant(self, name: str) -> None:
with self.tx() as cur:
tenants = Table("tenants")
insert_tenant = (
self.querybuilder()
.into(tenants)
.columns(tenants.id)
.insert(ParameterValue(name))
)
sql, params = get_sql(insert_tenant, self.parameter_format())
try:
cur.execute(sql, params)
except self.unique_constraint_error() as e:
raise UniqueConstraintError(f"Tenant {name} already exists") from e
@override
def get_tenant(self, name: str) -> Tenant:
with self.tx() as cur:
tenants = Table("tenants")
q = (
self.querybuilder()
.from_(tenants)
.select(tenants.id)
.where(tenants.id == ParameterValue(name))
)
sql, params = get_sql(q, self.parameter_format())
row = cur.execute(sql, params).fetchone()
if not row:
raise NotFoundError(f"Tenant {name} not found")
return Tenant(name=name)
# Create a segment using the passed cursor, so that the other changes
# can be in the same transaction.
def create_segment_with_tx(self, cur: Cursor, segment: Segment) -> None:
add_attributes_to_current_span(
{
"segment_id": str(segment["id"]),
"segment_type": segment["type"],
"segment_scope": segment["scope"].value,
"collection": str(segment["collection"]),
}
)
segments = Table("segments")
insert_segment = (
self.querybuilder()
.into(segments)
.columns(
segments.id,
segments.type,
segments.scope,
segments.collection,
)
.insert(
ParameterValue(self.uuid_to_db(segment["id"])),
ParameterValue(segment["type"]),
ParameterValue(segment["scope"].value),
ParameterValue(self.uuid_to_db(segment["collection"])),
)
)
sql, params = get_sql(insert_segment, self.parameter_format())
try:
cur.execute(sql, params)
except self.unique_constraint_error() as e:
raise UniqueConstraintError(
f"Segment {segment['id']} already exists"
) from e
# Insert segment metadata if it exists
metadata_t = Table("segment_metadata")
if segment["metadata"]:
try:
self._insert_metadata(
cur,
metadata_t,
metadata_t.segment_id,
segment["id"],
segment["metadata"],
)
except Exception as e:
logger.error(f"Error inserting segment metadata: {e}")
raise
# TODO(rohit): Investigate and remove this method completely.
@trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL)
@override
def create_segment(self, segment: Segment) -> None:
with self.tx() as cur:
self.create_segment_with_tx(cur, segment)
@trace_method("SqlSysDB.create_collection", OpenTelemetryGranularity.ALL)
@override
def create_collection(
self,
id: UUID,
name: str,
schema: Optional[Schema],
configuration: CreateCollectionConfiguration,
segments: Sequence[Segment],
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple[Collection, bool]:
if id is None and not get_or_create:
raise ValueError("id must be specified if get_or_create is False")
add_attributes_to_current_span(
{
"collection_id": str(id),
"collection_name": name,
}
)
existing = self.get_collections(name=name, tenant=tenant, database=database)
if existing:
if get_or_create:
collection = existing[0]
return (
self.get_collections(
id=collection.id, tenant=tenant, database=database
)[0],
False,
)
else:
raise UniqueConstraintError(f"Collection {name} already exists")
collection = Collection(
id=id,
name=name,
configuration_json=create_collection_configuration_to_json(
configuration, cast(CollectionMetadata, metadata)
),
serialized_schema=None,
metadata=metadata,
dimension=dimension,
tenant=tenant,
database=database,
version=0,
)
with self.tx() as cur:
collections = Table("collections")
databases = Table("databases")
insert_collection = (
self.querybuilder()
.into(collections)
.columns(
collections.id,
collections.name,
collections.config_json_str,
collections.dimension,
collections.database_id,
)
.insert(
ParameterValue(self.uuid_to_db(collection["id"])),
ParameterValue(collection["name"]),
ParameterValue(
create_collection_configuration_to_json_str(
configuration, cast(CollectionMetadata, metadata)
)
),
ParameterValue(collection["dimension"]),
# Get the database id for the database with the given name and tenant
self.querybuilder()
.select(databases.id)
.from_(databases)
.where(databases.name == ParameterValue(database))
.where(databases.tenant_id == ParameterValue(tenant)),
)
)
sql, params = get_sql(insert_collection, self.parameter_format())
try:
cur.execute(sql, params)
except self.unique_constraint_error() as e:
raise UniqueConstraintError(
f"Collection {collection['id']} already exists"
) from e
metadata_t = Table("collection_metadata")
if collection["metadata"]:
self._insert_metadata(
cur,
metadata_t,
metadata_t.collection_id,
collection.id,
collection["metadata"],
)
for segment in segments:
self.create_segment_with_tx(cur, segment)
return collection, True
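# Illustrative note, not part of the original source: the boolean in the returned
# tuple reports whether a new row was inserted here. The get_or_create branch near
# the top of create_collection returns the pre-existing collection with False instead.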
@trace_method("SqlSysDB.get_segments", OpenTelemetryGranularity.ALL)
@override
def get_segments(
self,
collection: UUID,
id: Optional[UUID] = None,
type: Optional[str] = None,
scope: Optional[SegmentScope] = None,
) -> Sequence[Segment]:
add_attributes_to_current_span(
{
"segment_id": str(id),
"segment_type": type if type else "",
"segment_scope": scope.value if scope else "",
"collection": str(collection),
}
)
segments_t = Table("segments")
metadata_t = Table("segment_metadata")
q = (
self.querybuilder()
.from_(segments_t)
.select(
segments_t.id,
segments_t.type,
segments_t.scope,
segments_t.collection,
metadata_t.key,
metadata_t.str_value,
metadata_t.int_value,
metadata_t.float_value,
metadata_t.bool_value,
)
.left_join(metadata_t)
.on(segments_t.id == metadata_t.segment_id)
.orderby(segments_t.id)
)
if id:
q = q.where(segments_t.id == ParameterValue(self.uuid_to_db(id)))
if type:
q = q.where(segments_t.type == ParameterValue(type))
if scope:
q = q.where(segments_t.scope == ParameterValue(scope.value))
if collection:
q = q.where(
segments_t.collection == ParameterValue(self.uuid_to_db(collection))
)
with self.tx() as cur:
sql, params = get_sql(q, self.parameter_format())
rows = cur.execute(sql, params).fetchall()
by_segment = groupby(rows, lambda r: cast(object, r[0]))
segments = []
for segment_id, segment_rows in by_segment:
id = self.uuid_from_db(str(segment_id))
rows = list(segment_rows)
type = str(rows[0][1])
scope = SegmentScope(str(rows[0][2]))
collection = self.uuid_from_db(rows[0][3]) # type: ignore[assignment]
metadata = self._metadata_from_rows(rows)
segments.append(
Segment(
id=cast(UUID, id),
type=type,
scope=scope,
collection=collection,
metadata=metadata,
file_paths={},
)
)
return segments
@trace_method("SqlSysDB.get_collections", OpenTelemetryGranularity.ALL)
@override
def get_collections(
self,
id: Optional[UUID] = None,
name: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Collection]:
"""Get collections by name, embedding function and/or metadata"""
if name is not None and (tenant is None or database is None):
raise ValueError(
"If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
)
add_attributes_to_current_span(
{
"collection_id": str(id),
"collection_name": name if name else "",
}
)
collections_t = Table("collections")
metadata_t = Table("collection_metadata")
databases_t = Table("databases")
q = (
self.querybuilder()
.from_(collections_t)
.select(
collections_t.id,
collections_t.name,
collections_t.config_json_str,
collections_t.dimension,
databases_t.name,
databases_t.tenant_id,
metadata_t.key,
metadata_t.str_value,
metadata_t.int_value,
metadata_t.float_value,
metadata_t.bool_value,
)
.left_join(metadata_t)
.on(collections_t.id == metadata_t.collection_id)
.left_join(databases_t)
.on(collections_t.database_id == databases_t.id)
.orderby(collections_t.id)
)
if id:
q = q.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
if name:
q = q.where(collections_t.name == ParameterValue(name))
# Only if we have a name, tenant and database do we need to filter databases
# Given an id, we can uniquely identify the collection so we don't need to filter databases
if id is None and tenant and database:
databases_t = Table("databases")
q = q.where(
collections_t.database_id
== self.querybuilder()
.select(databases_t.id)
.from_(databases_t)
.where(databases_t.name == ParameterValue(database))
.where(databases_t.tenant_id == ParameterValue(tenant))
)
# can't set limit and offset here because this is metadata and we haven't reduced yet
with self.tx() as cur:
sql, params = get_sql(q, self.parameter_format())
rows = cur.execute(sql, params).fetchall()
by_collection = groupby(rows, lambda r: cast(object, r[0]))
collections = []
for collection_id, collection_rows in by_collection:
id = self.uuid_from_db(str(collection_id))
rows = list(collection_rows)
name = str(rows[0][1])
metadata = self._metadata_from_rows(rows)
dimension = int(rows[0][3]) if rows[0][3] else None
if rows[0][2] is not None:
configuration = load_collection_configuration_from_json_str(
rows[0][2]
)
else:
# 07/2024: This is a legacy case where we don't have a collection
# configuration stored in the database. This non-destructively migrates
# the collection to have a configuration, and takes into account any
# HNSW params that might be in the existing metadata.
configuration = self._insert_config_from_legacy_params(
collection_id, metadata
)
collections.append(
Collection(
id=cast(UUID, id),
name=name,
configuration_json=collection_configuration_to_json(
configuration
),
serialized_schema=None,
metadata=metadata,
dimension=dimension,
tenant=str(rows[0][5]),
database=str(rows[0][4]),
version=0,
)
)
# apply limit and offset
if limit is not None:
if offset is None:
offset = 0
collections = collections[offset : offset + limit]
else:
collections = collections[offset:]
return collections
@override
def get_collection_with_segments(
self, collection_id: UUID
) -> CollectionAndSegments:
collections = self.get_collections(id=collection_id)
if len(collections) == 0:
raise NotFoundError(f"Collection {collection_id} does not exist.")
return CollectionAndSegments(
collection=collections[0],
segments=self.get_segments(collection=collection_id),
)
@trace_method("SqlSysDB.delete_segment", OpenTelemetryGranularity.ALL)
@override
def delete_segment(self, collection: UUID, id: UUID) -> None:
"""Delete a segment from the SysDB"""
add_attributes_to_current_span(
{
"segment_id": str(id),
}
)
t = Table("segments")
q = (
self.querybuilder()
.from_(t)
.where(t.id == ParameterValue(self.uuid_to_db(id)))
.delete()
)
with self.tx() as cur:
# no need for explicit del from metadata table because of ON DELETE CASCADE
sql, params = get_sql(q, self.parameter_format())
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if not result:
raise NotFoundError(f"Segment {id} not found")
# Used by delete_collection to delete all segments for a collection along with
# the collection itself in a single transaction.
def delete_segments_for_collection(self, cur: Cursor, collection: UUID) -> None:
segments_t = Table("segments")
q = (
self.querybuilder()
.from_(segments_t)
.where(segments_t.collection == ParameterValue(self.uuid_to_db(collection)))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
@trace_method("SqlSysDB.delete_collection", OpenTelemetryGranularity.ALL)
@override
def delete_collection(
self,
id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
"""Delete a collection and all associated segments from the SysDB. Deletes
the log stream for this collection as well."""
add_attributes_to_current_span(
{
"collection_id": str(id),
}
)
t = Table("collections")
databases_t = Table("databases")
q = (
self.querybuilder()
.from_(t)
.where(t.id == ParameterValue(self.uuid_to_db(id)))
.where(
t.database_id
== self.querybuilder()
.select(databases_t.id)
.from_(databases_t)
.where(databases_t.name == ParameterValue(database))
.where(databases_t.tenant_id == ParameterValue(tenant))
)
.delete()
)
with self.tx() as cur:
# no need for explicit del from metadata table because of ON DELETE CASCADE
sql, params = get_sql(q, self.parameter_format())
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if not result:
raise NotFoundError(f"Collection {id} not found")
# Delete segments.
self.delete_segments_for_collection(cur, id)
self._producer.delete_log(result[0])
@trace_method("SqlSysDB.update_segment", OpenTelemetryGranularity.ALL)
@override
def update_segment(
self,
collection: UUID,
id: UUID,
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
add_attributes_to_current_span(
{
"segment_id": str(id),
"collection": str(collection),
}
)
segments_t = Table("segments")
metadata_t = Table("segment_metadata")
q = (
self.querybuilder()
.update(segments_t)
.where(segments_t.id == ParameterValue(self.uuid_to_db(id)))
.set(segments_t.collection, ParameterValue(self.uuid_to_db(collection)))
)
with self.tx() as cur:
sql, params = get_sql(q, self.parameter_format())
if sql: # pypika emits a blank string if nothing to do
cur.execute(sql, params)
if metadata is None:
q = (
self.querybuilder()
.from_(metadata_t)
.where(metadata_t.segment_id == ParameterValue(self.uuid_to_db(id)))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
elif metadata != Unspecified():
metadata = cast(UpdateMetadata, metadata)
self._insert_metadata(
cur,
metadata_t,
metadata_t.segment_id,
id,
metadata,
set(metadata.keys()),
)
@trace_method("SqlSysDB.update_collection", OpenTelemetryGranularity.ALL)
@override
def update_collection(
self,
id: UUID,
name: OptionalArgument[str] = Unspecified(),
dimension: OptionalArgument[Optional[int]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
configuration: OptionalArgument[
Optional[UpdateCollectionConfiguration]
] = Unspecified(),
) -> None:
add_attributes_to_current_span(
{
"collection_id": str(id),
}
)
collections_t = Table("collections")
metadata_t = Table("collection_metadata")
q = (
self.querybuilder()
.update(collections_t)
.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
)
if not name == Unspecified():
q = q.set(collections_t.name, ParameterValue(name))
if not dimension == Unspecified():
q = q.set(collections_t.dimension, ParameterValue(dimension))
with self.tx() as cur:
sql, params = get_sql(q, self.parameter_format())
if sql: # pypika emits a blank string if nothing to do
sql = sql + " RETURNING id"
result = cur.execute(sql, params)
if not result.fetchone():
raise NotFoundError(f"Collection {id} not found")
# TODO: Update to use better semantics where it's possible to update
# individual keys without wiping all the existing metadata.
# For now, follow current legacy semantics where metadata is fully reset
if metadata != Unspecified():
q = (
self.querybuilder()
.from_(metadata_t)
.where(
metadata_t.collection_id == ParameterValue(self.uuid_to_db(id))
)
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
if metadata is not None:
metadata = cast(UpdateMetadata, metadata)
self._insert_metadata(
cur,
metadata_t,
metadata_t.collection_id,
id,
metadata,
set(metadata.keys()),
)
if configuration != Unspecified():
update_configuration = cast(
UpdateCollectionConfiguration, configuration
)
self._update_config_json_str(cur, update_configuration, id)
else:
if metadata != Unspecified():
metadata = cast(UpdateMetadata, metadata)
if metadata is not None:
update_configuration = (
update_collection_configuration_from_legacy_update_metadata(
metadata
)
)
self._update_config_json_str(cur, update_configuration, id)
def _update_config_json_str(
self, cur: Cursor, update_configuration: UpdateCollectionConfiguration, id: UUID
) -> None:
collections_t = Table("collections")
q = (
self.querybuilder()
.from_(collections_t)
.select(collections_t.config_json_str)
.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
)
sql, params = get_sql(q, self.parameter_format())
row = cur.execute(sql, params).fetchone()
if not row:
raise NotFoundError(f"Collection {id} not found")
config_json_str = row[0]
existing_config = load_collection_configuration_from_json_str(config_json_str)
new_config = overwrite_collection_configuration(
existing_config, update_configuration
)
q = (
self.querybuilder()
.update(collections_t)
.set(
collections_t.config_json_str,
ParameterValue(collection_configuration_to_json_str(new_config)),
)
.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
@trace_method("SqlSysDB._metadata_from_rows", OpenTelemetryGranularity.ALL)
def _metadata_from_rows(
self, rows: Sequence[Tuple[Any, ...]]
) -> Optional[Metadata]:
"""Given SQL rows, return a metadata map (assuming that the last four columns
are the key, str_value, int_value & float_value)"""
add_attributes_to_current_span(
{
"num_rows": len(rows),
}
)
metadata: Dict[str, Union[str, int, float, bool]] = {}
for row in rows:
key = str(row[-5])
if row[-4] is not None:
metadata[key] = str(row[-4])
elif row[-3] is not None:
metadata[key] = int(row[-3])
elif row[-2] is not None:
metadata[key] = float(row[-2])
elif row[-1] is not None:
metadata[key] = bool(row[-1])
return metadata or None
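# Illustrative example, not in the original source: with trailing columns
# (key, str_value, int_value, float_value, bool_value), a row ending in
# ("color", "red", None, None, None) contributes {"color": "red"}, while
# ("flag", None, None, None, 1) contributes {"flag": True}.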
@trace_method("SqlSysDB._insert_metadata", OpenTelemetryGranularity.ALL)
def _insert_metadata(
self,
cur: Cursor,
table: Table,
id_col: Column,
id: UUID,
metadata: UpdateMetadata,
clear_keys: Optional[Set[str]] = None,
) -> None:
# It would be cleaner to use something like ON CONFLICT UPDATE here, but that is
# very difficult to do in a portable way (e.g. sqlite and postgres have
# completely different syntax)
add_attributes_to_current_span(
{
"num_keys": len(metadata),
}
)
if clear_keys:
q = (
self.querybuilder()
.from_(table)
.where(id_col == ParameterValue(self.uuid_to_db(id)))
.where(table.key.isin([ParameterValue(k) for k in clear_keys]))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)
q = (
self.querybuilder()
.into(table)
.columns(
id_col,
table.key,
table.str_value,
table.int_value,
table.float_value,
table.bool_value,
)
)
sql_id = self.uuid_to_db(id)
for k, v in metadata.items():
# Note: The order is important here because isinstance(v, bool)
# and isinstance(v, int) both are true for v of bool type.
if isinstance(v, bool):
q = q.insert(
ParameterValue(sql_id),
ParameterValue(k),
None,
None,
None,
ParameterValue(int(v)),
)
elif isinstance(v, str):
q = q.insert(
ParameterValue(sql_id),
ParameterValue(k),
ParameterValue(v),
None,
None,
None,
)
elif isinstance(v, int):
q = q.insert(
ParameterValue(sql_id),
ParameterValue(k),
None,
ParameterValue(v),
None,
None,
)
elif isinstance(v, float):
q = q.insert(
ParameterValue(sql_id),
ParameterValue(k),
None,
None,
ParameterValue(v),
None,
)
elif v is None:
continue
sql, params = get_sql(q, self.parameter_format())
if sql:
cur.execute(sql, params)
def _insert_config_from_legacy_params(
self, collection_id: Any, metadata: Optional[Metadata]
) -> CollectionConfiguration:
"""Insert the configuration from legacy metadata params into the collections table, and return the configuration object."""
# This is a legacy case where we don't have configuration stored in the database
# This is non-destructive, we don't delete or overwrite any keys in the metadata
collections_t = Table("collections")
create_collection_config = CreateCollectionConfiguration()
# Write the configuration into the database
configuration_json_str = create_collection_configuration_to_json_str(
create_collection_config, cast(CollectionMetadata, metadata)
)
q = (
self.querybuilder()
.update(collections_t)
.set(
collections_t.config_json_str,
ParameterValue(configuration_json_str),
)
.where(collections_t.id == ParameterValue(collection_id))
)
sql, params = get_sql(q, self.parameter_format())
with self.tx() as cur:
cur.execute(sql, params)
return load_collection_configuration_from_json_str(configuration_json_str)
@override
def get_collection_size(self, id: UUID) -> int:
raise NotImplementedError
@override
def count_collections(
self,
tenant: str = DEFAULT_TENANT,
database: Optional[str] = None,
) -> int:
"""Gets the number of collections for the (tenant, database) combination."""
# TODO(Sanket): Implement this efficiently using a count query.
# Note, the underlying get_collections api always requires a database
# to be specified. In the sysdb implementation in go code, it does not
# filter on database if it is set to "". This is a bad API and
# should be fixed. For now, we will replicate the behavior.
request_database: str = "" if database is None or database == "" else database
return len(self.get_collections(tenant=tenant, database=request_database))

View File

@@ -0,0 +1,189 @@
from abc import abstractmethod
from typing import Optional, Sequence, Tuple
from uuid import UUID
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
)
from chromadb.api.types import Schema
from chromadb.types import (
Collection,
CollectionAndSegments,
Database,
Tenant,
Metadata,
Segment,
SegmentScope,
OptionalArgument,
Unspecified,
UpdateMetadata,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component
class SysDB(Component):
"""Data interface for Chroma's System database"""
@abstractmethod
def create_database(
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
) -> None:
"""Create a new database in the System database. Raises an Error if the Database
already exists."""
pass
@abstractmethod
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
"""Get a database by name and tenant. Raises an Error if the Database does not
exist."""
pass
@abstractmethod
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
"""Delete a database."""
pass
@abstractmethod
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""List all databases for a tenant."""
pass
@abstractmethod
def create_tenant(self, name: str) -> None:
"""Create a new tenant in the System database. The name must be unique.
Raises an Error if the Tenant already exists."""
pass
@abstractmethod
def get_tenant(self, name: str) -> Tenant:
"""Get a tenant by name. Raises an Error if the Tenant does not exist."""
pass
# TODO: Investigate and remove this method, as segment creation is done as
# part of collection creation.
@abstractmethod
def create_segment(self, segment: Segment) -> None:
"""Create a new segment in the System database. Raises an Error if the ID
already exists."""
pass
@abstractmethod
def delete_segment(self, collection: UUID, id: UUID) -> None:
"""Delete a segment from the System database."""
pass
@abstractmethod
def get_segments(
self,
collection: UUID,
id: Optional[UUID] = None,
type: Optional[str] = None,
scope: Optional[SegmentScope] = None,
) -> Sequence[Segment]:
"""Find segments by id, type, scope or collection."""
pass
@abstractmethod
def update_segment(
self,
collection: UUID,
id: UUID,
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
"""Update a segment. Unspecified fields will be left unchanged. For the
metadata, keys with None values will be removed and keys not present in the
UpdateMetadata dict will be left unchanged."""
pass
@abstractmethod
def create_collection(
self,
id: UUID,
name: str,
schema: Optional[Schema],
configuration: CreateCollectionConfiguration,
segments: Sequence[Segment],
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Tuple[Collection, bool]:
"""Create a new collection and associated resources
in the SysDB. If get_or_create is True, the
collection will be created if one with the same name does not exist.
The metadata will be updated using the same protocol as update_collection. If get_or_create
is False and the collection already exists, an error will be raised.
Returns a tuple of the created collection and a boolean indicating whether the
collection was created or not.
"""
pass
@abstractmethod
def delete_collection(
self,
id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
"""Delete a collection, all associated segments and any associate resources (log stream)
from the SysDB and the system at large."""
pass
@abstractmethod
def get_collections(
self,
id: Optional[UUID] = None,
name: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Collection]:
"""Find collections by id or name. If name is provided, tenant and database must also be provided."""
pass
@abstractmethod
def count_collections(
self,
tenant: str = DEFAULT_TENANT,
database: Optional[str] = None,
) -> int:
"""Gets the number of collections for the (tenant, database) combination."""
pass
@abstractmethod
def get_collection_with_segments(
self, collection_id: UUID
) -> CollectionAndSegments:
"""Get a consistent snapshot of a collection by id. This will return a collection with segment
information that matches the collection version and log position.
"""
pass
@abstractmethod
def update_collection(
self,
id: UUID,
name: OptionalArgument[str] = Unspecified(),
dimension: OptionalArgument[Optional[int]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
configuration: OptionalArgument[
Optional[UpdateCollectionConfiguration]
] = Unspecified(),
) -> None:
"""Update a collection. Unspecified fields will be left unchanged. For metadata,
keys with None values will be removed and keys not present in the UpdateMetadata
dict will be left unchanged."""
pass
@abstractmethod
def get_collection_size(self, id: UUID) -> int:
"""Returns the number of records in a collection."""
pass

View File

@@ -0,0 +1,194 @@
from abc import abstractmethod
from typing import Dict, Optional, Type
from overrides import overrides, EnforceOverrides
class ChromaError(Exception, EnforceOverrides):
trace_id: Optional[str] = None
def code(self) -> int:
"""Return an appropriate HTTP response code for this error"""
return 400 # Bad Request
def message(self) -> str:
return ", ".join(self.args)
@classmethod
@abstractmethod
def name(cls) -> str:
"""Return the error name"""
pass
class InvalidDimensionException(ChromaError):
@classmethod
@overrides
def name(cls) -> str:
return "InvalidDimension"
class IDAlreadyExistsError(ChromaError):
@overrides
def code(self) -> int:
return 409 # Conflict
@classmethod
@overrides
def name(cls) -> str:
return "IDAlreadyExists"
class ChromaAuthError(ChromaError):
@overrides
def code(self) -> int:
return 403
@classmethod
@overrides
def name(cls) -> str:
return "AuthError"
@overrides
def message(self) -> str:
return "Forbidden"
class DuplicateIDError(ChromaError):
@classmethod
@overrides
def name(cls) -> str:
return "DuplicateID"
class InvalidArgumentError(ChromaError):
@overrides
def code(self) -> int:
return 400
@classmethod
@overrides
def name(cls) -> str:
return "InvalidArgument"
class InvalidUUIDError(ChromaError):
@classmethod
@overrides
def name(cls) -> str:
return "InvalidUUID"
class InvalidHTTPVersion(ChromaError):
@classmethod
@overrides
def name(cls) -> str:
return "InvalidHTTPVersion"
class AuthorizationError(ChromaError):
@overrides
def code(self) -> int:
return 401
@classmethod
@overrides
def name(cls) -> str:
return "AuthorizationError"
class NotFoundError(ChromaError):
@overrides
def code(self) -> int:
return 404
@classmethod
@overrides
def name(cls) -> str:
return "NotFoundError"
class UniqueConstraintError(ChromaError):
@overrides
def code(self) -> int:
return 409
@classmethod
@overrides
def name(cls) -> str:
return "UniqueConstraintError"
class BatchSizeExceededError(ChromaError):
@overrides
def code(self) -> int:
return 413
@classmethod
@overrides
def name(cls) -> str:
return "BatchSizeExceededError"
class VersionMismatchError(ChromaError):
@overrides
def code(self) -> int:
return 500
@classmethod
@overrides
def name(cls) -> str:
return "VersionMismatchError"
class InternalError(ChromaError):
@overrides
def code(self) -> int:
return 500
@classmethod
@overrides
def name(cls) -> str:
return "InternalError"
class RateLimitError(ChromaError):
@overrides
def code(self) -> int:
return 429
@classmethod
@overrides
def name(cls) -> str:
return "RateLimitError"
class QuotaError(ChromaError):
@overrides
def code(self) -> int:
return 400
@classmethod
@overrides
def name(cls) -> str:
return "QuotaError"
error_types: Dict[str, Type[ChromaError]] = {
"InvalidDimension": InvalidDimensionException,
"InvalidArgumentError": InvalidArgumentError,
"IDAlreadyExists": IDAlreadyExistsError,
"DuplicateID": DuplicateIDError,
"InvalidUUID": InvalidUUIDError,
"InvalidHTTPVersion": InvalidHTTPVersion,
"AuthorizationError": AuthorizationError,
"NotFoundError": NotFoundError,
"BatchSizeExceededError": BatchSizeExceededError,
"VersionMismatchError": VersionMismatchError,
"RateLimitError": RateLimitError,
"AuthError": ChromaAuthError,
"UniqueConstraintError": UniqueConstraintError,
"QuotaError": QuotaError,
"InternalError": InternalError,
# Catch-all for any other errors
"ChromaError": ChromaError,
}
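# Illustrative sketch, not part of the original file: a client that receives a
# serialized error name can rehydrate a concrete exception class from this map,
# falling back to the generic ChromaError, e.g.
#   cls = error_types.get("NotFoundError", ChromaError)
#   raise cls("Collection abc not found")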

View File

@@ -0,0 +1,19 @@
from abc import abstractmethod
from chromadb.api.types import GetResult, QueryResult
from chromadb.config import Component
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
class Executor(Component):
@abstractmethod
def count(self, plan: CountPlan) -> int:
pass
@abstractmethod
def get(self, plan: GetPlan) -> GetResult:
pass
@abstractmethod
def knn(self, plan: KNNPlan) -> QueryResult:
pass

View File

@@ -0,0 +1,242 @@
import threading
import random
from typing import Callable, Dict, List, Optional, TypeVar
import grpc
from overrides import overrides
from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from tenacity import (
RetryCallState,
Retrying,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception,
)
from opentelemetry.trace import Span
def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
"""Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
if not metadata:
return None
result = {}
for k, v in metadata.items():
if not k.startswith("chroma:"):
result[k] = v
if len(result) == 0:
return None
return result
def _uri(metadata: Optional[Metadata]) -> Optional[str]:
"""Retrieve the uri (if any) from a Metadata map"""
if metadata and "chroma:uri" in metadata:
return str(metadata["chroma:uri"])
return None
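# Illustrative example, not in the original source: for a record whose metadata is
# {"chroma:uri": "s3://bucket/item", "color": "red"}, _uri() returns
# "s3://bucket/item" while _clean_metadata() strips the "chroma:"-prefixed keys
# and returns {"color": "red"}.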
# Type variables for input and output types of the round-robin retry function
I = TypeVar("I") # noqa: E741
O = TypeVar("O") # noqa: E741
class DistributedExecutor(Executor):
_mtx: threading.Lock
_grpc_stub_pool: Dict[str, QueryExecutorStub]
_manager: DistributedSegmentManager
_request_timeout_seconds: int
_query_replication_factor: int
def __init__(self, system: System):
super().__init__(system)
self._mtx = threading.Lock()
self._grpc_stub_pool = {}
self._manager = self.require(DistributedSegmentManager)
self._request_timeout_seconds = system.settings.require(
"chroma_query_request_timeout_seconds"
)
self._query_replication_factor = system.settings.require(
"chroma_query_replication_factor"
)
def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
"""
Retry a list of functions in a round-robin fashion until one of them succeeds.
funcs: List of functions to retry
args: Arguments to pass to each function
"""
attempt_count = 0
sleep_span: Optional[Span] = None
def before_sleep(_: RetryCallState) -> None:
# HACK(hammadb) 1/14/2024 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
# This should really use our component system to get the tracer. Since our grpc utils use this pattern
# we are copying it here. This should be removed once we have a better way to get the tracer
from chromadb.telemetry.opentelemetry import tracer
nonlocal sleep_span
if tracer is not None:
sleep_span = tracer.start_span("Waiting to retry RPC")
for attempt in Retrying(
stop=stop_after_attempt(5),
wait=wait_exponential_jitter(0.1, jitter=0.1),
reraise=True,
retry=retry_if_exception(
lambda x: isinstance(x, grpc.RpcError)
and x.code() in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
),
before_sleep=before_sleep,
):
if sleep_span is not None:
sleep_span.end()
sleep_span = None
with attempt:
return funcs[attempt_count % len(funcs)](args)
attempt_count += 1
# NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
raise Exception("Unreachable code error - should never reach here")
@overrides
def count(self, plan: CountPlan) -> int:
endpoints = self._get_grpc_endpoints(plan.scan)
count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
count_result = self._round_robin_retry(
count_funcs, convert.to_proto_count_plan(plan)
)
return convert.from_proto_count_result(count_result)
@overrides
def get(self, plan: GetPlan) -> GetResult:
endpoints = self._get_grpc_endpoints(plan.scan)
get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
records = convert.from_proto_get_result(get_result)
ids = [record["id"] for record in records]
embeddings = (
[record["embedding"] for record in records]
if plan.projection.embedding
else None
)
documents = (
[record["document"] for record in records]
if plan.projection.document
else None
)
uris = (
[_uri(record["metadata"]) for record in records]
if plan.projection.uri
else None
)
metadatas = (
[_clean_metadata(record["metadata"]) for record in records]
if plan.projection.metadata
else None
)
# TODO: Fix typing
return GetResult(
ids=ids,
embeddings=embeddings, # type: ignore[typeddict-item]
documents=documents, # type: ignore[typeddict-item]
uris=uris, # type: ignore[typeddict-item]
data=None,
metadatas=metadatas, # type: ignore[typeddict-item]
included=plan.projection.included,
)
@overrides
def knn(self, plan: KNNPlan) -> QueryResult:
endpoints = self._get_grpc_endpoints(plan.scan)
knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
results = convert.from_proto_knn_batch_result(knn_result)
ids = [[record["record"]["id"] for record in records] for records in results]
embeddings = (
[
[record["record"]["embedding"] for record in records]
for records in results
]
if plan.projection.embedding
else None
)
documents = (
[
[record["record"]["document"] for record in records]
for records in results
]
if plan.projection.document
else None
)
uris = (
[
[_uri(record["record"]["metadata"]) for record in records]
for records in results
]
if plan.projection.uri
else None
)
metadatas = (
[
[_clean_metadata(record["record"]["metadata"]) for record in records]
for records in results
]
if plan.projection.metadata
else None
)
distances = (
[[record["distance"] for record in records] for records in results]
if plan.projection.rank
else None
)
# TODO: Fix typing
return QueryResult(
ids=ids,
embeddings=embeddings, # type: ignore[typeddict-item]
documents=documents, # type: ignore[typeddict-item]
uris=uris, # type: ignore[typeddict-item]
data=None,
metadatas=metadatas, # type: ignore[typeddict-item]
distances=distances, # type: ignore[typeddict-item]
included=plan.projection.included,
)
def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
# Since the grpc endpoint is determined by the collection uuid,
# the endpoint should be the same for all segments of the same collection
grpc_urls = self._manager.get_endpoints(
scan.record, self._query_replication_factor
)
# Shuffle the grpc urls to distribute the load evenly
random.shuffle(grpc_urls)
return grpc_urls
def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
with self._mtx:
if grpc_url not in self._grpc_stub_pool:
channel = grpc.insecure_channel(
grpc_url,
options=[
("grpc.max_concurrent_streams", 1000),
("grpc.max_receive_message_length", 32000000), # 32 MB
],
)
interceptors = [OtelInterceptor()]
channel = grpc.intercept_channel(channel, *interceptors)
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)
return self._grpc_stub_pool[grpc_url]

View File

@@ -0,0 +1,205 @@
from typing import Optional, Sequence
from overrides import overrides
from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.segment import MetadataReader, VectorReader
from chromadb.segment.impl.manager.local import LocalSegmentManager
from chromadb.types import Collection, VectorQuery, VectorQueryResult
def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
"""Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
if not metadata:
return None
result = {}
for k, v in metadata.items():
if not k.startswith("chroma:"):
result[k] = v
if len(result) == 0:
return None
return result
def _doc(metadata: Optional[Metadata]) -> Optional[str]:
"""Retrieve the document (if any) from a Metadata map"""
if metadata and "chroma:document" in metadata:
return str(metadata["chroma:document"])
return None
def _uri(metadata: Optional[Metadata]) -> Optional[str]:
"""Retrieve the uri (if any) from a Metadata map"""
if metadata and "chroma:uri" in metadata:
return str(metadata["chroma:uri"])
return None
class LocalExecutor(Executor):
_manager: LocalSegmentManager
def __init__(self, system: System):
super().__init__(system)
self._manager = self.require(LocalSegmentManager)
@overrides
def count(self, plan: CountPlan) -> int:
return self._metadata_segment(plan.scan.collection).count(plan.scan.version)
@overrides
def get(self, plan: GetPlan) -> GetResult:
records = self._metadata_segment(plan.scan.collection).get_metadata(
request_version_context=plan.scan.version,
where=plan.filter.where,
where_document=plan.filter.where_document,
ids=plan.filter.user_ids,
limit=plan.limit.limit,
offset=plan.limit.offset,
include_metadata=True,
)
ids = [r["id"] for r in records]
embeddings = None
documents = None
uris = None
metadatas = None
included = list()
if plan.projection.embedding:
if len(records) > 0:
vectors = self._vector_segment(plan.scan.collection).get_vectors(
ids=ids, request_version_context=plan.scan.version
)
embeddings = [v["embedding"] for v in vectors]
else:
embeddings = list()
included.append("embeddings")
if plan.projection.document:
documents = [_doc(r["metadata"]) for r in records]
included.append("documents")
if plan.projection.uri:
uris = [_uri(r["metadata"]) for r in records]
included.append("uris")
if plan.projection.metadata:
metadatas = [_clean_metadata(r["metadata"]) for r in records]
included.append("metadatas")
# TODO: Fix typing
return GetResult(
ids=ids,
embeddings=embeddings,
documents=documents, # type: ignore[typeddict-item]
uris=uris, # type: ignore[typeddict-item]
data=None,
metadatas=metadatas, # type: ignore[typeddict-item]
included=included,
)
@overrides
def knn(self, plan: KNNPlan) -> QueryResult:
prefiltered_ids = None
if plan.filter.user_ids or plan.filter.where or plan.filter.where_document:
records = self._metadata_segment(plan.scan.collection).get_metadata(
request_version_context=plan.scan.version,
where=plan.filter.where,
where_document=plan.filter.where_document,
ids=plan.filter.user_ids,
limit=None,
offset=0,
include_metadata=False,
)
prefiltered_ids = [r["id"] for r in records]
knns: Sequence[Sequence[VectorQueryResult]] = [[]] * len(plan.knn.embeddings)
# Query vectors only when the user did not specify a filter or when the filter
# yields non-empty ids. Otherwise, the user specified a filter but it yields
# no matching ids, in which case we can return an empty result.
if prefiltered_ids is None or len(prefiltered_ids) > 0:
query = VectorQuery(
vectors=plan.knn.embeddings,
k=plan.knn.fetch,
allowed_ids=prefiltered_ids,
include_embeddings=plan.projection.embedding,
options=None,
request_version_context=plan.scan.version,
)
knns = self._vector_segment(plan.scan.collection).query_vectors(query)
ids = [[r["id"] for r in result] for result in knns]
embeddings = None
documents = None
uris = None
metadatas = None
distances = None
included = list()
if plan.projection.embedding:
embeddings = [[r["embedding"] for r in result] for result in knns]
included.append("embeddings")
if plan.projection.rank:
distances = [[r["distance"] for r in result] for result in knns]
included.append("distances")
if plan.projection.document or plan.projection.metadata or plan.projection.uri:
merged_ids = list(set([id for result in ids for id in result]))
hydrated_records = self._metadata_segment(
plan.scan.collection
).get_metadata(
request_version_context=plan.scan.version,
where=None,
where_document=None,
ids=merged_ids,
limit=None,
offset=0,
include_metadata=True,
)
metadata_by_id = {r["id"]: r["metadata"] for r in hydrated_records}
if plan.projection.document:
documents = [
[_doc(metadata_by_id.get(id, None)) for id in result]
for result in ids
]
included.append("documents")
if plan.projection.uri:
uris = [
[_uri(metadata_by_id.get(id, None)) for id in result]
for result in ids
]
included.append("uris")
if plan.projection.metadata:
metadatas = [
[_clean_metadata(metadata_by_id.get(id, None)) for id in result]
for result in ids
]
included.append("metadatas")
# TODO: Fix typing
return QueryResult(
ids=ids,
embeddings=embeddings, # type: ignore[typeddict-item]
documents=documents, # type: ignore[typeddict-item]
uris=uris, # type: ignore[typeddict-item]
data=None,
metadatas=metadatas, # type: ignore[typeddict-item]
distances=distances,
included=included,
)
def _metadata_segment(self, collection: Collection) -> MetadataReader:
return self._manager.get_segment(collection.id, MetadataReader)
def _vector_segment(self, collection: Collection) -> VectorReader:
return self._manager.get_segment(collection.id, VectorReader)

View File

@@ -0,0 +1,90 @@
"""
Chromadb execution expression module for search operations.
"""
from chromadb.execution.expression.operator import (
# Field proxy for building Where conditions
Key,
K,
# Where expressions
Where,
And,
Or,
Eq,
Ne,
Gt,
Gte,
Lt,
Lte,
In,
Nin,
Regex,
NotRegex,
Contains,
NotContains,
# Search configuration
Limit,
Select,
# Rank expressions
Rank,
Abs,
Div,
Exp,
Log,
Max,
Min,
Mul,
Knn,
Rrf,
Sub,
Sum,
Val,
)
from chromadb.execution.expression.plan import (
Search,
)
SearchWhere = Where
__all__ = [
# Main search class
"Search",
# Field proxy
"Key",
"K",
# Where expressions
"SearchWhere",
"Where",
"And",
"Or",
"Eq",
"Ne",
"Gt",
"Gte",
"Lt",
"Lte",
"In",
"Nin",
"Regex",
"NotRegex",
"Contains",
"NotContains",
# Search configuration
"Limit",
"Select",
# Rank expressions
"Rank",
"Abs",
"Div",
"Exp",
"Log",
"Max",
"Min",
"Mul",
"Knn",
"Rrf",
"Sub",
"Sum",
"Val",
]
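# Illustrative usage sketch, not part of the original file; `query_embedding` is a
# placeholder for an embedding produced elsewhere:
#   search = (
#       Search()
#       .where(K("status") == "active")
#       .rank(Knn(query=query_embedding))
#       .limit(10)
#       .select(Key.DOCUMENT)
#   )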

View File

@@ -0,0 +1,263 @@
from dataclasses import dataclass, field
from typing import List, Dict, Any, Union, Set, Optional
from chromadb.execution.expression.operator import (
KNN,
Filter,
Limit,
Projection,
Scan,
Rank,
Select,
Val,
Where,
Key,
)
@dataclass
class CountPlan:
scan: Scan
@dataclass
class GetPlan:
scan: Scan
filter: Filter = field(default_factory=Filter)
limit: Limit = field(default_factory=Limit)
projection: Projection = field(default_factory=Projection)
@dataclass
class KNNPlan:
scan: Scan
knn: KNN
filter: Filter = field(default_factory=Filter)
projection: Projection = field(default_factory=Projection)
class Search:
"""Payload for hybrid search operations.
Can be constructed directly or using builder pattern:
Direct construction with expressions:
Search(
where=Key("status") == "active",
rank=Knn(query=[0.1, 0.2]),
limit=Limit(limit=10),
select=Select(keys={Key.DOCUMENT})
)
Direct construction with dicts:
Search(
where={"status": "active"},
rank={"$knn": {"query": [0.1, 0.2]}},
limit=10, # Creates Limit(limit=10, offset=0)
select=["#document", "#score"]
)
Builder pattern:
(Search()
.where(Key("status") == "active")
.rank(Knn(query=[0.1, 0.2]))
.limit(10)
.select(Key.DOCUMENT))
Builder pattern with dicts:
(Search()
.where({"status": "active"})
.rank({"$knn": {"query": [0.1, 0.2]}})
.limit(10)
.select(Key.DOCUMENT))
Filter by IDs:
Search().where(Key.ID.is_in(["id1", "id2", "id3"]))
Combined with metadata filtering:
Search().where((Key.ID.is_in(["id1", "id2"])) & (Key("status") == "active"))
Empty Search() is valid and will use defaults:
- where: None (no filtering)
- rank: None (no ranking - results ordered by default order)
- limit: No limit
- select: Empty selection
"""
def __init__(
self,
where: Optional[Union[Where, Dict[str, Any]]] = None,
rank: Optional[Union[Rank, Dict[str, Any]]] = None,
limit: Optional[Union[Limit, Dict[str, Any], int]] = None,
select: Optional[Union[Select, Dict[str, Any], List[str], Set[str]]] = None,
):
"""Initialize a Search with optional parameters.
Args:
where: Where expression or dict for filtering results (defaults to None - no filtering)
Dict will be converted using Where.from_dict()
rank: Rank expression or dict for scoring (defaults to None - no ranking)
Dict will be converted using Rank.from_dict()
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
limit: Limit configuration for pagination (defaults to no limit)
Can be a Limit object, a dict for Limit.from_dict(), or an int
When passing an int, it creates Limit(limit=value, offset=0)
select: Select configuration for keys (defaults to empty selection)
Can be a Select object, a dict for Select.from_dict(),
or a list/set of strings (e.g., ["#document", "#score"])
"""
# Handle where parameter
if where is None:
self._where = None
elif isinstance(where, Where):
self._where = where
elif isinstance(where, dict):
self._where = Where.from_dict(where)
else:
raise TypeError(
f"where must be a Where object, dict, or None, got {type(where).__name__}"
)
# Handle rank parameter
if rank is None:
self._rank = None
elif isinstance(rank, Rank):
self._rank = rank
elif isinstance(rank, dict):
self._rank = Rank.from_dict(rank)
else:
raise TypeError(
f"rank must be a Rank object, dict, or None, got {type(rank).__name__}"
)
# Handle limit parameter
if limit is None:
self._limit = Limit()
elif isinstance(limit, Limit):
self._limit = limit
elif isinstance(limit, int):
self._limit = Limit.from_dict({"limit": limit, "offset": 0})
elif isinstance(limit, dict):
self._limit = Limit.from_dict(limit)
else:
raise TypeError(
f"limit must be a Limit object, dict, int, or None, got {type(limit).__name__}"
)
# Handle select parameter
if select is None:
self._select = Select()
elif isinstance(select, Select):
self._select = select
elif isinstance(select, dict):
self._select = Select.from_dict(select)
elif isinstance(select, (list, set)):
# Convert list/set of strings to Select object
self._select = Select.from_dict({"keys": list(select)})
else:
raise TypeError(
f"select must be a Select object, dict, list, set, or None, got {type(select).__name__}"
)
def to_dict(self) -> Dict[str, Any]:
"""Convert the Search to a dictionary for JSON serialization"""
return {
"filter": self._where.to_dict() if self._where is not None else None,
"rank": self._rank.to_dict() if self._rank is not None else None,
"limit": self._limit.to_dict(),
"select": self._select.to_dict(),
}
# Builder methods for chaining
def select_all(self) -> "Search":
"""Select all predefined keys (document, embedding, metadata, score)"""
new_select = Select(keys={Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE})
return Search(
where=self._where, rank=self._rank, limit=self._limit, select=new_select
)
def select(self, *keys: Union[Key, str]) -> "Search":
"""Select specific keys
Args:
*keys: Variable number of Key objects or string key names
Returns:
New Search object with updated select configuration
"""
new_select = Select(keys=set(keys))
return Search(
where=self._where, rank=self._rank, limit=self._limit, select=new_select
)
def where(self, where: Optional[Union[Where, Dict[str, Any]]]) -> "Search":
"""Set the where clause for filtering
Args:
where: A Where expression, dict, or None for filtering
Dicts will be converted using Where.from_dict()
Example:
search.where((Key("status") == "active") & (Key("score") > 0.5))
search.where({"status": "active"})
search.where({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
"""
# Convert dict to Where if needed
if where is None:
converted_where = None
elif isinstance(where, Where):
converted_where = where
elif isinstance(where, dict):
converted_where = Where.from_dict(where)
else:
raise TypeError(
f"where must be a Where object, dict, or None, got {type(where).__name__}"
)
return Search(
where=converted_where, rank=self._rank, limit=self._limit, select=self._select
)
def rank(self, rank_expr: Optional[Union[Rank, Dict[str, Any]]]) -> "Search":
"""Set the ranking expression
Args:
rank_expr: A Rank expression, dict, or None for scoring
Dicts will be converted using Rank.from_dict()
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
Example:
search.rank(Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2)
search.rank({"$knn": {"query": [0.1, 0.2]}})
search.rank({"$sum": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.5}]})
"""
# Convert dict to Rank if needed
if rank_expr is None:
converted_rank = None
elif isinstance(rank_expr, Rank):
converted_rank = rank_expr
elif isinstance(rank_expr, dict):
converted_rank = Rank.from_dict(rank_expr)
else:
raise TypeError(
f"rank_expr must be a Rank object, dict, or None, got {type(rank_expr).__name__}"
)
return Search(
where=self._where, rank=converted_rank, limit=self._limit, select=self._select
)
def limit(self, limit: int, offset: int = 0) -> "Search":
"""Set the limit and offset for pagination
Args:
limit: Maximum number of results to return
offset: Number of results to skip (default: 0)
Example:
search.limit(20, offset=10)
"""
new_limit = Limit(offset=offset, limit=limit)
return Search(
where=self._where, rank=self._rank, limit=new_limit, select=self._select
)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,121 @@
from abc import abstractmethod
from typing import Callable, Optional, Sequence
from chromadb.types import (
OperationRecord,
LogRecord,
SeqId,
Vector,
ScalarEncoding,
)
from chromadb.config import Component
from uuid import UUID
import numpy as np
def encode_vector(vector: Vector, encoding: ScalarEncoding) -> bytes:
"""Encode a vector into a byte array."""
if encoding == ScalarEncoding.FLOAT32:
return np.array(vector, dtype=np.float32).tobytes()
elif encoding == ScalarEncoding.INT32:
return np.array(vector, dtype=np.int32).tobytes()
else:
raise ValueError(f"Unsupported encoding: {encoding.value}")
def decode_vector(vector: bytes, encoding: ScalarEncoding) -> Vector:
"""Decode a byte array into a vector"""
if encoding == ScalarEncoding.FLOAT32:
return np.frombuffer(vector, dtype=np.float32)
elif encoding == ScalarEncoding.INT32:
return np.frombuffer(vector, dtype=np.int32)
else:
raise ValueError(f"Unsupported encoding: {encoding.value}")
class Producer(Component):
"""Interface for writing embeddings to an ingest stream"""
@abstractmethod
def delete_log(self, collection_id: UUID) -> None:
pass
@abstractmethod
def purge_log(self, collection_id: UUID) -> None:
"""Truncates the log for the given collection, removing all seen records."""
pass
@abstractmethod
def submit_embedding(
self, collection_id: UUID, embedding: OperationRecord
) -> SeqId:
"""Add an embedding record to the given collections log. Returns the SeqID of the record."""
pass
@abstractmethod
def submit_embeddings(
self, collection_id: UUID, embeddings: Sequence[OperationRecord]
) -> Sequence[SeqId]:
"""Add a batch of embedding records to the given collections log. Returns the SeqIDs of
the records. The returned SeqIDs will be in the same order as the given
SubmitEmbeddingRecords. However, it is not guaranteed that the SeqIDs will be
processed in the same order as the given SubmitEmbeddingRecords. If the number
of records exceeds the maximum batch size, an exception will be thrown."""
pass
@property
@abstractmethod
def max_batch_size(self) -> int:
"""Return the maximum number of records that can be submitted in a single call
to submit_embeddings."""
pass
ConsumerCallbackFn = Callable[[Sequence[LogRecord]], None]
class Consumer(Component):
"""Interface for reading embeddings off an ingest stream"""
@abstractmethod
def subscribe(
self,
collection_id: UUID,
consume_fn: ConsumerCallbackFn,
start: Optional[SeqId] = None,
end: Optional[SeqId] = None,
id: Optional[UUID] = None,
) -> UUID:
"""Register a function that will be called to receive embeddings for a given
collection's log stream. The given function may be called any number of times, with any number of
records, and may be called concurrently.
Only records between start (exclusive) and end (inclusive) SeqIDs will be
returned. If start is None, the first record returned will be the next record
generated, not including those generated before creating the subscription. If
end is None, the consumer will consume indefinitely, otherwise it will
automatically be unsubscribed when the end SeqID is reached.
If the function throws an exception, the function may be called again with the
same or different records.
Takes an optional UUID as a unique subscription ID. If no ID is provided, a new
ID will be generated and returned."""
pass
@abstractmethod
def unsubscribe(self, subscription_id: UUID) -> None:
"""Unregister a subscription. The consume function will no longer be invoked,
and resources associated with the subscription will be released."""
pass
@abstractmethod
def min_seqid(self) -> SeqId:
"""Return the minimum possible SeqID in this implementation."""
pass
@abstractmethod
def max_seqid(self) -> SeqId:
"""Return the maximum possible SeqID in this implementation."""
pass

View File

@@ -0,0 +1,49 @@
import re
from typing import Tuple
from uuid import UUID
from chromadb.db.base import SqlDB
from chromadb.segment import SegmentManager, VectorReader
topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"
def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:
"""Parse the topic name into the tenant, namespace and topic name"""
match = re.match(topic_regex, topic_name)
if not match:
raise ValueError(f"Invalid topic name: {topic_name}")
return match.group("tenant"), match.group("namespace"), match.group("topic")
def create_topic_name(tenant: str, namespace: str, collection_id: UUID) -> str:
return f"persistent://{tenant}/{namespace}/{str(collection_id)}"
def trigger_vector_segments_max_seq_id_migration(
db: SqlDB, segment_manager: SegmentManager
) -> None:
"""
Trigger the migration of vector segments' max_seq_id from the pickled metadata file to SQLite.
Vector segments migrate this field automatically on init—so this should be used when we know segments are likely unmigrated and unloaded.
This is a no-op if all vector segments have already migrated their max_seq_id.
"""
with db.tx() as cur:
cur.execute(
"""
SELECT collection
FROM "segments"
WHERE "id" NOT IN (SELECT "segment_id" FROM "max_seq_id") AND
"type" = 'urn:chroma:segment/vector/hnsw-local-persisted'
"""
)
collection_ids_with_unmigrated_segments = [row[0] for row in cur.fetchall()]
if len(collection_ids_with_unmigrated_segments) == 0:
return
for collection_id in collection_ids_with_unmigrated_segments:
# Loading the segment triggers the migration on init
segment_manager.get_segment(UUID(collection_id), VectorReader)

View File

@@ -0,0 +1,37 @@
version: 1
disable_existing_loggers: False
formatters:
default:
"()": uvicorn.logging.DefaultFormatter
format: '%(levelprefix)s [%(asctime)s] %(message)s'
use_colors: null
datefmt: '%d-%m-%Y %H:%M:%S'
access:
"()": uvicorn.logging.AccessFormatter
format: '%(levelprefix)s [%(asctime)s] %(client_addr)s - "%(request_line)s" %(status_code)s'
datefmt: '%d-%m-%Y %H:%M:%S'
handlers:
default:
formatter: default
class: logging.StreamHandler
stream: ext://sys.stderr
access:
formatter: access
class: logging.StreamHandler
stream: ext://sys.stdout
console:
class: logging.StreamHandler
stream: ext://sys.stdout
formatter: default
file:
class : logging.handlers.RotatingFileHandler
filename: chroma.log
formatter: default
loggers:
root:
level: WARN
handlers: [console, file]
chromadb:
level: DEBUG
uvicorn:
level: INFO
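# Note, not part of the original file: logging.handlers.RotatingFileHandler
# defaults to maxBytes=0 when no size is given, so chroma.log above grows
# without ever rotating unless maxBytes/backupCount are added to the handler.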

View File

@@ -0,0 +1,181 @@
import sys
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
import grpc
import time
from chromadb.ingest import (
Producer,
Consumer,
ConsumerCallbackFn,
)
from chromadb.proto.convert import to_proto_submit
from chromadb.proto.logservice_pb2 import PushLogsRequest, PullLogsRequest, LogRecord
from chromadb.proto.logservice_pb2_grpc import LogServiceStub
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
OperationRecord,
SeqId,
)
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
add_attributes_to_current_span,
trace_method,
)
from overrides import override
from typing import Sequence, Optional, cast
from uuid import UUID
import logging
logger = logging.getLogger(__name__)
class LogService(Producer, Consumer):
"""
Distributed Chroma Log Service
"""
_log_service_stub: LogServiceStub
_request_timeout_seconds: int
_channel: grpc.Channel
_log_service_url: str
_log_service_port: int
def __init__(self, system: System):
self._log_service_url = system.settings.require("chroma_logservice_host")
self._log_service_port = system.settings.require("chroma_logservice_port")
self._request_timeout_seconds = system.settings.require(
"chroma_logservice_request_timeout_seconds"
)
self._opentelemetry_client = system.require(OpenTelemetryClient)
super().__init__(system)
@trace_method("LogService.start", OpenTelemetryGranularity.ALL)
@override
def start(self) -> None:
self._channel = grpc.insecure_channel(
f"{self._log_service_url}:{self._log_service_port}",
)
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
self._channel = grpc.intercept_channel(self._channel, *interceptors)
self._log_service_stub = LogServiceStub(self._channel) # type: ignore
super().start()
@trace_method("LogService.stop", OpenTelemetryGranularity.ALL)
@override
def stop(self) -> None:
self._channel.close()
super().stop()
@trace_method("LogService.reset_state", OpenTelemetryGranularity.ALL)
@override
def reset_state(self) -> None:
super().reset_state()
@trace_method("LogService.delete_log", OpenTelemetryGranularity.ALL)
@override
def delete_log(self, collection_id: UUID) -> None:
raise NotImplementedError("Not implemented")
@trace_method("LogService.purge_log", OpenTelemetryGranularity.ALL)
@override
def purge_log(self, collection_id: UUID) -> None:
raise NotImplementedError("Not implemented")
@trace_method("LogService.submit_embedding", OpenTelemetryGranularity.ALL)
@override
def submit_embedding(
self, collection_id: UUID, embedding: OperationRecord
) -> SeqId:
if not self._running:
raise RuntimeError("Component not running")
return self.submit_embeddings(collection_id, [embedding])[0]
@trace_method("LogService.submit_embeddings", OpenTelemetryGranularity.ALL)
@override
def submit_embeddings(
self, collection_id: UUID, embeddings: Sequence[OperationRecord]
) -> Sequence[SeqId]:
logger.info(
f"Submitting {len(embeddings)} embeddings to log for collection {collection_id}"
)
add_attributes_to_current_span(
{
"records_count": len(embeddings),
}
)
if not self._running:
raise RuntimeError("Component not running")
if len(embeddings) == 0:
return []
# push records to the log service
counts = []
protos_to_submit = [to_proto_submit(record) for record in embeddings]
counts.append(
self.push_logs(
collection_id,
cast(Sequence[OperationRecord], protos_to_submit),
)
)
# This returns counts, which is completely incorrect
# TODO: Fix this
return counts
@trace_method("LogService.subscribe", OpenTelemetryGranularity.ALL)
@override
def subscribe(
self,
collection_id: UUID,
consume_fn: ConsumerCallbackFn,
start: Optional[SeqId] = None,
end: Optional[SeqId] = None,
id: Optional[UUID] = None,
) -> UUID:
logger.info(f"Subscribing to log for {collection_id}, noop for logservice")
return UUID(int=0)
@trace_method("LogService.unsubscribe", OpenTelemetryGranularity.ALL)
@override
def unsubscribe(self, subscription_id: UUID) -> None:
logger.info(f"Unsubscribing from {subscription_id}, noop for logservice")
@override
def min_seqid(self) -> SeqId:
return 0
@override
def max_seqid(self) -> SeqId:
return sys.maxsize
@property
@override
def max_batch_size(self) -> int:
return 100
def push_logs(self, collection_id: UUID, records: Sequence[OperationRecord]) -> int:
request = PushLogsRequest(collection_id=str(collection_id), records=records)
response = self._log_service_stub.PushLogs(
request, timeout=self._request_timeout_seconds
)
return response.record_count # type: ignore
def pull_logs(
self, collection_id: UUID, start_offset: int, batch_size: int
) -> Sequence[LogRecord]:
request = PullLogsRequest(
collection_id=str(collection_id),
start_from_offset=start_offset,
batch_size=batch_size,
end_timestamp=time.time_ns(),
)
response = self._log_service_stub.PullLogs(
request, timeout=self._request_timeout_seconds
)
return response.records # type: ignore
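A rough usage sketch rather than a definitive recipe: it assumes a distributed Chroma deployment with a reachable log service, and the host, port, and timeout values below are illustrative placeholders:

from uuid import uuid4

from chromadb.config import Settings, System

settings = Settings(
    chroma_logservice_host="localhost",            # assumed host
    chroma_logservice_port=50051,                  # assumed port
    chroma_logservice_request_timeout_seconds=10,
)
system = System(settings)
log_service = system.instance(LogService)
system.start()

# Pull up to 100 records for a collection, starting from offset 0.
records = log_service.pull_logs(uuid4(), start_offset=0, batch_size=100)
system.stop()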

View File

@@ -0,0 +1,10 @@
CREATE TABLE embeddings_queue (
seq_id INTEGER PRIMARY KEY,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
operation INTEGER NOT NULL,
topic TEXT NOT NULL,
id TEXT NOT NULL,
vector BLOB,
encoding TEXT,
metadata TEXT
);

View File

@@ -0,0 +1,4 @@
CREATE TABLE embeddings_queue_config (
id INTEGER PRIMARY KEY,
config_json_str TEXT
);

View File

@@ -0,0 +1,24 @@
CREATE TABLE embeddings (
id INTEGER PRIMARY KEY,
segment_id TEXT NOT NULL,
embedding_id TEXT NOT NULL,
seq_id BLOB NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE (segment_id, embedding_id)
);
CREATE TABLE embedding_metadata (
id INTEGER REFERENCES embeddings(id),
key TEXT NOT NULL,
string_value TEXT,
int_value INTEGER,
float_value REAL,
PRIMARY KEY (id, key)
);
CREATE TABLE max_seq_id (
segment_id TEXT PRIMARY KEY,
seq_id BLOB NOT NULL
);
CREATE VIRTUAL TABLE embedding_fulltext USING fts5(id, string_value);

View File

@@ -0,0 +1,5 @@
-- SQLite does not support adding a CHECK constraint via ALTER TABLE; doing so would
-- require creating a new table and copying the data over, which is overkill for adding
-- a boolean-typed column. The application writing to the table is responsible for
-- ensuring data integrity.
ALTER TABLE embedding_metadata ADD COLUMN bool_value INTEGER
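For context, a sketch of the table-rebuild pattern the comment above alludes to, using a deliberately simplified copy of the table; this is what enforcing a real CHECK constraint would require and is intentionally not what this migration does:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE embedding_metadata (id INTEGER, key TEXT NOT NULL, bool_value INTEGER);
    -- Rebuilding with a CHECK constraint: create, copy, drop, rename.
    CREATE TABLE embedding_metadata_new (
        id INTEGER,
        key TEXT NOT NULL,
        bool_value INTEGER CHECK (bool_value IN (0, 1))
    );
    INSERT INTO embedding_metadata_new SELECT * FROM embedding_metadata;
    DROP TABLE embedding_metadata;
    ALTER TABLE embedding_metadata_new RENAME TO embedding_metadata;
    """
)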

View File

@@ -0,0 +1,3 @@
CREATE VIRTUAL TABLE embedding_fulltext_search USING fts5(string_value, tokenize='trigram');
INSERT INTO embedding_fulltext_search (rowid, string_value) SELECT rowid, string_value FROM embedding_metadata;
DROP TABLE embedding_fulltext;

View File

@@ -0,0 +1,3 @@
CREATE INDEX IF NOT EXISTS embedding_metadata_int_value ON embedding_metadata (key, int_value) WHERE int_value IS NOT NULL;
CREATE INDEX IF NOT EXISTS embedding_metadata_float_value ON embedding_metadata (key, float_value) WHERE float_value IS NOT NULL;
CREATE INDEX IF NOT EXISTS embedding_metadata_string_value ON embedding_metadata (key, string_value) WHERE string_value IS NOT NULL;

View File

@@ -0,0 +1,27 @@
ALTER TABLE max_seq_id ADD COLUMN int_seq_id INTEGER;
-- Convert 8 byte wide big-endian integer as blob to native 64 bit integer.
-- Adapted from https://stackoverflow.com/a/70296198.
UPDATE max_seq_id SET int_seq_id = (
SELECT (
(instr('123456789ABCDEF', substr(hex(seq_id), -1 , 1)) << 0)
| (instr('123456789ABCDEF', substr(hex(seq_id), -2 , 1)) << 4)
| (instr('123456789ABCDEF', substr(hex(seq_id), -3 , 1)) << 8)
| (instr('123456789ABCDEF', substr(hex(seq_id), -4 , 1)) << 12)
| (instr('123456789ABCDEF', substr(hex(seq_id), -5 , 1)) << 16)
| (instr('123456789ABCDEF', substr(hex(seq_id), -6 , 1)) << 20)
| (instr('123456789ABCDEF', substr(hex(seq_id), -7 , 1)) << 24)
| (instr('123456789ABCDEF', substr(hex(seq_id), -8 , 1)) << 28)
| (instr('123456789ABCDEF', substr(hex(seq_id), -9 , 1)) << 32)
| (instr('123456789ABCDEF', substr(hex(seq_id), -10, 1)) << 36)
| (instr('123456789ABCDEF', substr(hex(seq_id), -11, 1)) << 40)
| (instr('123456789ABCDEF', substr(hex(seq_id), -12, 1)) << 44)
| (instr('123456789ABCDEF', substr(hex(seq_id), -13, 1)) << 48)
| (instr('123456789ABCDEF', substr(hex(seq_id), -14, 1)) << 52)
| (instr('123456789ABCDEF', substr(hex(seq_id), -15, 1)) << 56)
| (instr('123456789ABCDEF', substr(hex(seq_id), -16, 1)) << 60)
)
);
ALTER TABLE max_seq_id DROP COLUMN seq_id;
ALTER TABLE max_seq_id RENAME COLUMN int_seq_id TO seq_id;
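A quick sanity check of the conversion above from Python: the nested instr/substr arithmetic reads the blob as a big-endian unsigned integer, matching int.from_bytes. The sample value 42 is purely illustrative, and only the two lowest nibbles are evaluated on the SQL side to keep the expression short:

import sqlite3

blob = (42).to_bytes(8, "big")                  # seq_id stored as an 8-byte big-endian blob
assert int.from_bytes(blob, "big") == 42

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE max_seq_id (segment_id TEXT PRIMARY KEY, seq_id BLOB NOT NULL)")
conn.execute("INSERT INTO max_seq_id VALUES ('seg', ?)", (blob,))
(converted,) = conn.execute(
    "SELECT (instr('123456789ABCDEF', substr(hex(seq_id), -1, 1)) << 0)"
    " | (instr('123456789ABCDEF', substr(hex(seq_id), -2, 1)) << 4)"
    " FROM max_seq_id"
).fetchone()
assert converted == 42                          # same result as the full 16-nibble expression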

View File

@@ -0,0 +1,15 @@
CREATE TABLE collections (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
topic TEXT NOT NULL,
UNIQUE (name)
);
CREATE TABLE collection_metadata (
collection_id TEXT REFERENCES collections(id) ON DELETE CASCADE,
key TEXT NOT NULL,
str_value TEXT,
int_value INTEGER,
float_value REAL,
PRIMARY KEY (collection_id, key)
);

View File

@@ -0,0 +1,16 @@
CREATE TABLE segments (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
scope TEXT NOT NULL,
topic TEXT,
collection TEXT REFERENCES collection(id)
);
CREATE TABLE segment_metadata (
segment_id TEXT REFERENCES segments(id) ON DELETE CASCADE,
key TEXT NOT NULL,
str_value TEXT,
int_value INTEGER,
float_value REAL,
PRIMARY KEY (segment_id, key)
);

View File

@@ -0,0 +1 @@
ALTER TABLE collections ADD COLUMN dimension INTEGER;

View File

@@ -0,0 +1,29 @@
CREATE TABLE IF NOT EXISTS tenants (
id TEXT PRIMARY KEY,
UNIQUE (id)
);
CREATE TABLE IF NOT EXISTS databases (
id TEXT PRIMARY KEY, -- unique globally
name TEXT NOT NULL, -- unique per tenant
tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
UNIQUE (tenant_id, name) -- Ensure that a tenant has only one database with a given name
);
CREATE TABLE IF NOT EXISTS collections_tmp (
id TEXT PRIMARY KEY, -- unique globally
name TEXT NOT NULL, -- unique per database
topic TEXT NOT NULL,
dimension INTEGER,
database_id TEXT NOT NULL REFERENCES databases(id) ON DELETE CASCADE,
UNIQUE (name, database_id)
);
-- Create default tenant and database
INSERT OR REPLACE INTO tenants (id) VALUES ('default_tenant'); -- The default tenant id is 'default_tenant' others are UUIDs
INSERT OR REPLACE INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default_database', 'default_tenant');
INSERT OR REPLACE INTO collections_tmp (id, name, topic, dimension, database_id)
SELECT id, name, topic, dimension, '00000000-0000-0000-0000-000000000000' FROM collections;
DROP TABLE collections;
ALTER TABLE collections_tmp RENAME TO collections;
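A small illustration of the layout this migration produces: reading back the default tenant and database rows from a migrated database (the file path is an assumed example):

import sqlite3

conn = sqlite3.connect("./chroma/chroma.sqlite3")   # assumed path to a migrated database
row = conn.execute(
    """
    SELECT d.id, d.name, d.tenant_id
    FROM databases d JOIN tenants t ON d.tenant_id = t.id
    WHERE t.id = 'default_tenant'
    """
).fetchone()
# Expected: ('00000000-0000-0000-0000-000000000000', 'default_database', 'default_tenant')
print(row)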

View File

@@ -0,0 +1,4 @@
-- Remove the topic column from the Collections and Segments tables
ALTER TABLE collections DROP COLUMN topic;
ALTER TABLE segments DROP COLUMN topic;

View File

@@ -0,0 +1,6 @@
-- SQLite does not support adding a CHECK constraint via ALTER TABLE; doing so would
-- require creating a new table and copying the data over, which is overkill for adding
-- a boolean-typed column. The application writing to the table is responsible for
-- ensuring data integrity.
ALTER TABLE collection_metadata ADD COLUMN bool_value INTEGER;
ALTER TABLE segment_metadata ADD COLUMN bool_value INTEGER;

View File

@@ -0,0 +1,2 @@
-- Stores collection configuration dictionaries.
ALTER TABLE collections ADD COLUMN config_json_str TEXT;

View File

@@ -0,0 +1,7 @@
-- Records when database maintenance operations are performed.
-- At time of creation, this table is only used to record vacuum operations.
CREATE TABLE maintenance_log (
id INT PRIMARY KEY,
timestamp INT NOT NULL,
operation TEXT NOT NULL
);
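A hedged sketch of recording a vacuum in this table; the id expression and the nanosecond timestamp are illustrative assumptions, since the schema does not prescribe either:

import sqlite3
import time

conn = sqlite3.connect("./chroma/chroma.sqlite3")    # assumed path
conn.execute("VACUUM")
conn.execute(
    "INSERT INTO maintenance_log (id, timestamp, operation) "
    "VALUES ((SELECT COALESCE(MAX(id), 0) + 1 FROM maintenance_log), ?, ?)",
    (time.time_ns(), "vacuum"),
)
conn.commit()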

View File

@@ -0,0 +1,11 @@
-- This makes segments.collection non-nullable.
CREATE TABLE segments_temp (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
scope TEXT NOT NULL,
collection TEXT REFERENCES collection(id) NOT NULL
);
INSERT INTO segments_temp SELECT * FROM segments;
DROP TABLE segments;
ALTER TABLE segments_temp RENAME TO segments;

View File

@@ -0,0 +1,2 @@
*_pb2.py*
*_pb2_grpc.py

View File

@@ -0,0 +1,688 @@
from typing import Dict, Optional, Sequence, Tuple, TypedDict, Union, cast
from uuid import UUID
import json
import numpy as np
from numpy.typing import NDArray
import chromadb.proto.chroma_pb2 as chroma_pb
import chromadb.proto.query_executor_pb2 as query_pb
from chromadb.api.collection_configuration import (
collection_configuration_to_json_str,
)
from chromadb.api.types import Embedding, Where, WhereDocument
from chromadb.execution.expression.operator import (
KNN,
Filter,
Limit,
Projection,
Scan,
)
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.types import (
Collection,
LogRecord,
Metadata,
Operation,
OperationRecord,
RequestVersionContext,
ScalarEncoding,
Segment,
SegmentScope,
SeqId,
UpdateMetadata,
Vector,
VectorEmbeddingRecord,
VectorQueryResult,
)
class ProjectionRecord(TypedDict):
id: str
document: Optional[str]
embedding: Optional[Vector]
metadata: Optional[Metadata]
class KNNProjectionRecord(TypedDict):
record: ProjectionRecord
distance: Optional[float]
# TODO: Unit tests for this file, handling optional states etc
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vector:
if encoding == ScalarEncoding.FLOAT32:
as_bytes = np.array(vector, dtype=np.float32).tobytes()
proto_encoding = chroma_pb.ScalarEncoding.FLOAT32
elif encoding == ScalarEncoding.INT32:
as_bytes = np.array(vector, dtype=np.int32).tobytes()
proto_encoding = chroma_pb.ScalarEncoding.INT32
else:
raise ValueError(
f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \
or {ScalarEncoding.INT32}"
)
return chroma_pb.Vector(
dimension=vector.size, vector=as_bytes, encoding=proto_encoding
)
def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncoding]:
encoding = vector.encoding
as_array: Union[NDArray[np.int32], NDArray[np.float32]]
if encoding == chroma_pb.ScalarEncoding.FLOAT32:
as_array = np.frombuffer(vector.vector, dtype=np.float32)
out_encoding = ScalarEncoding.FLOAT32
elif encoding == chroma_pb.ScalarEncoding.INT32:
as_array = np.frombuffer(vector.vector, dtype=np.int32)
out_encoding = ScalarEncoding.INT32
else:
raise ValueError(
f"Unknown encoding {encoding}, expected one of \
{chroma_pb.ScalarEncoding.FLOAT32} or {chroma_pb.ScalarEncoding.INT32}"
)
return (as_array, out_encoding)
def from_proto_operation(operation: chroma_pb.Operation) -> Operation:
if operation == chroma_pb.Operation.ADD:
return Operation.ADD
elif operation == chroma_pb.Operation.UPDATE:
return Operation.UPDATE
elif operation == chroma_pb.Operation.UPSERT:
return Operation.UPSERT
elif operation == chroma_pb.Operation.DELETE:
return Operation.DELETE
else:
# TODO: full error
raise RuntimeError(f"Unknown operation {operation}")
def from_proto_metadata(metadata: chroma_pb.UpdateMetadata) -> Optional[Metadata]:
return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False))
def from_proto_update_metadata(
metadata: chroma_pb.UpdateMetadata,
) -> Optional[UpdateMetadata]:
return cast(
Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True)
)
def _from_proto_metadata_handle_none(
metadata: chroma_pb.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
if not metadata.metadata:
return None
out_metadata: Dict[str, Union[str, int, float, bool, None]] = {}
for key, value in metadata.metadata.items():
if value.HasField("bool_value"):
out_metadata[key] = value.bool_value
elif value.HasField("string_value"):
out_metadata[key] = value.string_value
elif value.HasField("int_value"):
out_metadata[key] = value.int_value
elif value.HasField("float_value"):
out_metadata[key] = value.float_value
elif is_update:
out_metadata[key] = None
else:
raise ValueError(f"Metadata key {key} value cannot be None")
return out_metadata
def to_proto_update_metadata(metadata: UpdateMetadata) -> chroma_pb.UpdateMetadata:
return chroma_pb.UpdateMetadata(
metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()}
)
def from_proto_submit(
operation_record: chroma_pb.OperationRecord, seq_id: SeqId
) -> LogRecord:
embedding, encoding = from_proto_vector(operation_record.vector)
record = LogRecord(
log_offset=seq_id,
record=OperationRecord(
id=operation_record.id,
embedding=embedding,
encoding=encoding,
metadata=from_proto_update_metadata(operation_record.metadata),
operation=from_proto_operation(operation_record.operation),
),
)
return record
def from_proto_segment(segment: chroma_pb.Segment) -> Segment:
return Segment(
id=UUID(hex=segment.id),
type=segment.type,
scope=from_proto_segment_scope(segment.scope),
collection=UUID(hex=segment.collection),
metadata=from_proto_metadata(segment.metadata)
if segment.HasField("metadata")
else None,
file_paths={
name: [path for path in paths.paths]
for name, paths in segment.file_paths.items()
},
)
def to_proto_segment(segment: Segment) -> chroma_pb.Segment:
return chroma_pb.Segment(
id=segment["id"].hex,
type=segment["type"],
scope=to_proto_segment_scope(segment["scope"]),
collection=segment["collection"].hex,
metadata=None
if segment["metadata"] is None
else to_proto_update_metadata(segment["metadata"]),
file_paths={
name: chroma_pb.FilePaths(paths=paths)
for name, paths in segment["file_paths"].items()
},
)
def from_proto_segment_scope(segment_scope: chroma_pb.SegmentScope) -> SegmentScope:
if segment_scope == chroma_pb.SegmentScope.VECTOR:
return SegmentScope.VECTOR
elif segment_scope == chroma_pb.SegmentScope.METADATA:
return SegmentScope.METADATA
elif segment_scope == chroma_pb.SegmentScope.RECORD:
return SegmentScope.RECORD
else:
raise RuntimeError(f"Unknown segment scope {segment_scope}")
def to_proto_segment_scope(segment_scope: SegmentScope) -> chroma_pb.SegmentScope:
if segment_scope == SegmentScope.VECTOR:
return chroma_pb.SegmentScope.VECTOR
elif segment_scope == SegmentScope.METADATA:
return chroma_pb.SegmentScope.METADATA
elif segment_scope == SegmentScope.RECORD:
return chroma_pb.SegmentScope.RECORD
else:
raise RuntimeError(f"Unknown segment scope {segment_scope}")
def to_proto_metadata_update_value(
value: Union[str, int, float, bool, None]
) -> chroma_pb.UpdateMetadataValue:
# Be careful with the order here. Since bools are a subtype of int in python,
# isinstance(value, bool) and isinstance(value, int) both return true
# for a value of bool type.
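# For example, isinstance(True, int) is True, so checking bool first prevents
# True and False from being serialized as int_value 1 and 0.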
if isinstance(value, bool):
return chroma_pb.UpdateMetadataValue(bool_value=value)
elif isinstance(value, str):
return chroma_pb.UpdateMetadataValue(string_value=value)
elif isinstance(value, int):
return chroma_pb.UpdateMetadataValue(int_value=value)
elif isinstance(value, float):
return chroma_pb.UpdateMetadataValue(float_value=value)
# None is used to delete the metadata key.
elif value is None:
return chroma_pb.UpdateMetadataValue()
else:
raise ValueError(
f"Unknown metadata value type {type(value)}, expected one of str, int, \
float, or None"
)
def from_proto_collection(collection: chroma_pb.Collection) -> Collection:
return Collection(
id=UUID(hex=collection.id),
name=collection.name,
configuration_json=json.loads(collection.configuration_json_str),
metadata=from_proto_metadata(collection.metadata)
if collection.HasField("metadata")
else None,
dimension=collection.dimension
if collection.HasField("dimension") and collection.dimension
else None,
database=collection.database,
tenant=collection.tenant,
version=collection.version,
log_position=collection.log_position,
)
def to_proto_collection(collection: Collection) -> chroma_pb.Collection:
return chroma_pb.Collection(
id=collection["id"].hex,
name=collection["name"],
configuration_json_str=collection_configuration_to_json_str(
collection.get_configuration()
),
metadata=None
if collection["metadata"] is None
else to_proto_update_metadata(collection["metadata"]),
dimension=collection["dimension"],
tenant=collection["tenant"],
database=collection["database"],
log_position=collection["log_position"],
version=collection["version"],
)
def to_proto_operation(operation: Operation) -> chroma_pb.Operation:
if operation == Operation.ADD:
return chroma_pb.Operation.ADD
elif operation == Operation.UPDATE:
return chroma_pb.Operation.UPDATE
elif operation == Operation.UPSERT:
return chroma_pb.Operation.UPSERT
elif operation == Operation.DELETE:
return chroma_pb.Operation.DELETE
else:
raise ValueError(
f"Unknown operation {operation}, expected one of {Operation.ADD}, \
{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
)
def to_proto_submit(
submit_record: OperationRecord,
) -> chroma_pb.OperationRecord:
vector = None
if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])
metadata = None
if submit_record["metadata"] is not None:
metadata = to_proto_update_metadata(submit_record["metadata"])
return chroma_pb.OperationRecord(
id=submit_record["id"],
vector=vector,
metadata=metadata,
operation=to_proto_operation(submit_record["operation"]),
)
def from_proto_vector_embedding_record(
embedding_record: chroma_pb.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
return VectorEmbeddingRecord(
id=embedding_record.id,
embedding=from_proto_vector(embedding_record.vector)[0],
)
def to_proto_vector_embedding_record(
embedding_record: VectorEmbeddingRecord,
encoding: ScalarEncoding,
) -> chroma_pb.VectorEmbeddingRecord:
return chroma_pb.VectorEmbeddingRecord(
id=embedding_record["id"],
vector=to_proto_vector(embedding_record["embedding"], encoding),
)
def from_proto_vector_query_result(
vector_query_result: chroma_pb.VectorQueryResult,
) -> VectorQueryResult:
return VectorQueryResult(
id=vector_query_result.id,
distance=vector_query_result.distance,
embedding=from_proto_vector(vector_query_result.vector)[0],
)
def from_proto_request_version_context(
request_version_context: chroma_pb.RequestVersionContext,
) -> RequestVersionContext:
return RequestVersionContext(
collection_version=request_version_context.collection_version,
log_position=request_version_context.log_position,
)
def to_proto_request_version_context(
request_version_context: RequestVersionContext,
) -> chroma_pb.RequestVersionContext:
return chroma_pb.RequestVersionContext(
collection_version=request_version_context["collection_version"],
log_position=request_version_context["log_position"],
)
def to_proto_where(where: Where) -> chroma_pb.Where:
response = chroma_pb.Where()
if len(where) != 1:
raise ValueError(f"Expected where to have exactly one operator, got {where}")
for key, value in where.items():
if not isinstance(key, str):
raise ValueError(f"Expected where key to be a str, got {key}")
if key == "$and" or key == "$or":
if not isinstance(value, list):
raise ValueError(
f"Expected where value for $and or $or to be a list of where expressions, got {value}"
)
children: chroma_pb.WhereChildren = chroma_pb.WhereChildren(
children=[to_proto_where(w) for w in value]
)
if key == "$and":
children.operator = chroma_pb.BooleanOperator.AND
else:
children.operator = chroma_pb.BooleanOperator.OR
response.children.CopyFrom(children)
return response
# At this point we know we're at a direct comparison. It can either
# be of the form {"key": "value"} or {"key": {"$operator": "value"}}.
dc = chroma_pb.DirectComparison()
dc.key = key
if not isinstance(value, dict):
# {'key': 'value'} case
if type(value) is str:
ssc = chroma_pb.SingleStringComparison()
ssc.value = value
ssc.comparator = chroma_pb.GenericComparator.EQ
dc.single_string_operand.CopyFrom(ssc)
elif type(value) is bool:
sbc = chroma_pb.SingleBoolComparison()
sbc.value = value
sbc.comparator = chroma_pb.GenericComparator.EQ
dc.single_bool_operand.CopyFrom(sbc)
elif type(value) is int:
sic = chroma_pb.SingleIntComparison()
sic.value = value
sic.generic_comparator = chroma_pb.GenericComparator.EQ
dc.single_int_operand.CopyFrom(sic)
elif type(value) is float:
sdc = chroma_pb.SingleDoubleComparison()
sdc.value = value
sdc.generic_comparator = chroma_pb.GenericComparator.EQ
dc.single_double_operand.CopyFrom(sdc)
else:
raise ValueError(
f"Expected where value to be a string, int, or float, got {value}"
)
else:
for operator, operand in value.items():
if operator in ["$in", "$nin"]:
if not isinstance(operand, list):
raise ValueError(
f"Expected where value for $in or $nin to be a list of values, got {value}"
)
if len(operand) == 0 or not all(
isinstance(x, type(operand[0])) for x in operand
):
raise ValueError(
f"Expected where operand value to be a non-empty list, and all values to be of the same type "
f"got {operand}"
)
list_operator = None
if operator == "$in":
list_operator = chroma_pb.ListOperator.IN
else:
list_operator = chroma_pb.ListOperator.NIN
if type(operand[0]) is str:
slo = chroma_pb.StringListComparison()
for x in operand:
slo.values.extend([x]) # type: ignore
slo.list_operator = list_operator
dc.string_list_operand.CopyFrom(slo)
elif type(operand[0]) is bool:
blo = chroma_pb.BoolListComparison()
for x in operand:
blo.values.extend([x]) # type: ignore
blo.list_operator = list_operator
dc.bool_list_operand.CopyFrom(blo)
elif type(operand[0]) is int:
ilo = chroma_pb.IntListComparison()
for x in operand:
ilo.values.extend([x]) # type: ignore
ilo.list_operator = list_operator
dc.int_list_operand.CopyFrom(ilo)
elif type(operand[0]) is float:
dlo = chroma_pb.DoubleListComparison()
for x in operand:
dlo.values.extend([x]) # type: ignore
dlo.list_operator = list_operator
dc.double_list_operand.CopyFrom(dlo)
else:
raise ValueError(
f"Expected where operand value to be a list of strings, ints, or floats, got {operand}"
)
elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]:
# Direct comparison to a single value.
if type(operand) is str:
ssc = chroma_pb.SingleStringComparison()
ssc.value = operand
if operator == "$eq":
ssc.comparator = chroma_pb.GenericComparator.EQ
elif operator == "$ne":
ssc.comparator = chroma_pb.GenericComparator.NE
else:
raise ValueError(
f"Expected where operator to be $eq or $ne, got {operator}"
)
dc.single_string_operand.CopyFrom(ssc)
elif type(operand) is bool:
sbc = chroma_pb.SingleBoolComparison()
sbc.value = operand
if operator == "$eq":
sbc.comparator = chroma_pb.GenericComparator.EQ
elif operator == "$ne":
sbc.comparator = chroma_pb.GenericComparator.NE
else:
raise ValueError(
f"Expected where operator to be $eq or $ne, got {operator}"
)
dc.single_bool_operand.CopyFrom(sbc)
elif type(operand) is int:
sic = chroma_pb.SingleIntComparison()
sic.value = operand
if operator == "$eq":
sic.generic_comparator = chroma_pb.GenericComparator.EQ
elif operator == "$ne":
sic.generic_comparator = chroma_pb.GenericComparator.NE
elif operator == "$gt":
sic.number_comparator = chroma_pb.NumberComparator.GT
elif operator == "$lt":
sic.number_comparator = chroma_pb.NumberComparator.LT
elif operator == "$gte":
sic.number_comparator = chroma_pb.NumberComparator.GTE
elif operator == "$lte":
sic.number_comparator = chroma_pb.NumberComparator.LTE
else:
raise ValueError(
f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}"
)
dc.single_int_operand.CopyFrom(sic)
elif type(operand) is float:
sfc = chroma_pb.SingleDoubleComparison()
sfc.value = operand
if operator == "$eq":
sfc.generic_comparator = chroma_pb.GenericComparator.EQ
elif operator == "$ne":
sfc.generic_comparator = chroma_pb.GenericComparator.NE
elif operator == "$gt":
sfc.number_comparator = chroma_pb.NumberComparator.GT
elif operator == "$lt":
sfc.number_comparator = chroma_pb.NumberComparator.LT
elif operator == "$gte":
sfc.number_comparator = chroma_pb.NumberComparator.GTE
elif operator == "$lte":
sfc.number_comparator = chroma_pb.NumberComparator.LTE
else:
raise ValueError(
f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}"
)
dc.single_double_operand.CopyFrom(sfc)
else:
raise ValueError(
f"Expected where operand value to be a string, int, or float, got {operand}"
)
else:
# This case should never happen, as we've already
# handled the case for direct comparisons.
pass
response.direct_comparison.CopyFrom(dc)
return response
def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDocument:
response = chroma_pb.WhereDocument()
if len(where_document) != 1:
raise ValueError(
f"Expected where_document to have exactly one operator, got {where_document}"
)
for operator, operand in where_document.items():
if operator == "$and" or operator == "$or":
# Nested "$and" or "$or" expression.
if not isinstance(operand, list):
raise ValueError(
f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}"
)
children: chroma_pb.WhereDocumentChildren = chroma_pb.WhereDocumentChildren(
children=[to_proto_where_document(w) for w in operand]
)
if operator == "$and":
children.operator = chroma_pb.BooleanOperator.AND
else:
children.operator = chroma_pb.BooleanOperator.OR
response.children.CopyFrom(children)
else:
# Direct "$contains" or "$not_contains" comparison to a single
# value.
if not isinstance(operand, str):
raise ValueError(
f"Expected where_document operand to be a string, got {operand}"
)
dwd = chroma_pb.DirectWhereDocument()
dwd.document = operand
if operator == "$contains":
dwd.operator = chroma_pb.WhereDocumentOperator.CONTAINS
elif operator == "$not_contains":
dwd.operator = chroma_pb.WhereDocumentOperator.NOT_CONTAINS
else:
raise ValueError(
f"Expected where_document operator to be one of $contains, $not_contains, got {operator}"
)
response.direct.CopyFrom(dwd)
return response
def to_proto_scan(scan: Scan) -> query_pb.ScanOperator:
return query_pb.ScanOperator(
collection=to_proto_collection(scan.collection),
knn=to_proto_segment(scan.knn),
metadata=to_proto_segment(scan.metadata),
record=to_proto_segment(scan.record),
)
def to_proto_filter(filter: Filter) -> query_pb.FilterOperator:
return query_pb.FilterOperator(
ids=chroma_pb.UserIds(ids=filter.user_ids)
if filter.user_ids is not None
else None,
where=to_proto_where(filter.where) if filter.where else None,
where_document=to_proto_where_document(filter.where_document)
if filter.where_document
else None,
)
def to_proto_knn(knn: KNN) -> query_pb.KNNOperator:
return query_pb.KNNOperator(
embeddings=[
to_proto_vector(vector=embedding, encoding=ScalarEncoding.FLOAT32)
for embedding in knn.embeddings
],
fetch=knn.fetch,
)
def to_proto_limit(limit: Limit) -> query_pb.LimitOperator:
return query_pb.LimitOperator(offset=limit.offset, limit=limit.limit)
def to_proto_projection(projection: Projection) -> query_pb.ProjectionOperator:
return query_pb.ProjectionOperator(
document=projection.document,
embedding=projection.embedding,
metadata=projection.metadata,
)
def to_proto_knn_projection(projection: Projection) -> query_pb.KNNProjectionOperator:
return query_pb.KNNProjectionOperator(
projection=to_proto_projection(projection), distance=projection.rank
)
def to_proto_count_plan(count: CountPlan) -> query_pb.CountPlan:
return query_pb.CountPlan(scan=to_proto_scan(count.scan))
def from_proto_count_result(result: query_pb.CountResult) -> int:
return result.count
def to_proto_get_plan(get: GetPlan) -> query_pb.GetPlan:
return query_pb.GetPlan(
scan=to_proto_scan(get.scan),
filter=to_proto_filter(get.filter),
limit=to_proto_limit(get.limit),
projection=to_proto_projection(get.projection),
)
def from_proto_projection_record(record: query_pb.ProjectionRecord) -> ProjectionRecord:
return ProjectionRecord(
id=record.id,
document=record.document if record.document else None,
embedding=from_proto_vector(record.embedding)[0]
if record.embedding is not None
else None,
metadata=from_proto_metadata(record.metadata),
)
def from_proto_get_result(result: query_pb.GetResult) -> Sequence[ProjectionRecord]:
return [from_proto_projection_record(record) for record in result.records]
def to_proto_knn_plan(knn: KNNPlan) -> query_pb.KNNPlan:
return query_pb.KNNPlan(
scan=to_proto_scan(knn.scan),
filter=to_proto_filter(knn.filter),
knn=to_proto_knn(knn.knn),
projection=to_proto_knn_projection(knn.projection),
)
def from_proto_knn_projection_record(
record: query_pb.KNNProjectionRecord,
) -> KNNProjectionRecord:
return KNNProjectionRecord(
record=from_proto_projection_record(record.record), distance=record.distance
)
def from_proto_knn_batch_result(
results: query_pb.KNNBatchResult,
) -> Sequence[Sequence[KNNProjectionRecord]]:
return [
[from_proto_knn_projection_record(record) for record in result.records]
for result in results.results
]
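A short, hedged example of the where-clause conversion implemented above; the filter itself is arbitrary and assumes the generated chroma_pb stubs are importable:

from chromadb.proto.convert import to_proto_where

# Nested $and combining a direct string equality with a numeric range operator.
where_proto = to_proto_where(
    {"$and": [{"color": "red"}, {"price": {"$lt": 10.0}}]}
)
print(where_proto)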

Some files were not shown because too many files have changed in this diff.