from typing import TYPE_CHECKING from chromadb import ( CollectionMetadata, Embeddings, GetResult, IDs, Where, WhereDocument, Include, Documents, Metadatas, QueryResult, URIs, ) from chromadb.api import ServerAPI if TYPE_CHECKING: from chromadb.api.models.AttachedFunction import AttachedFunction from chromadb.api.collection_configuration import ( CreateCollectionConfiguration, UpdateCollectionConfiguration, create_collection_configuration_to_json_str, update_collection_configuration_to_json_str, ) from chromadb.auth import UserIdentity from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.telemetry.product import ProductTelemetryClient from chromadb.telemetry.product.events import ( CollectionAddEvent, CollectionDeleteEvent, CollectionGetEvent, CollectionUpdateEvent, CollectionQueryEvent, ClientCreateCollectionEvent, ) from chromadb.api.types import ( IncludeMetadataDocuments, IncludeMetadataDocumentsDistances, IncludeMetadataDocumentsEmbeddings, Schema, SearchResult, ) # TODO(hammadb): Unify imports across types vs root __init__.py from chromadb.types import Database, Tenant, Collection as CollectionModel from chromadb.execution.expression.plan import Search import chromadb_rust_bindings from typing import Optional, Sequence, List, Dict, Any, Tuple from overrides import override from uuid import UUID import json import platform if platform.system() != "Windows": import resource elif platform.system() == "Windows": import ctypes # RustBindingsAPI is an implementation of ServerAPI which shims # the Rust bindings to the Python API, providing a full implementation # of the API. It could be that bindings was a direct implementation of # ServerAPI, but in order to prevent propagating the bindings types # into the Python API, we have to shim it here so we can convert into # the legacy Python types. # TODO(hammadb): Propagate the types from the bindings into the Python API # and remove the python-level types entirely. 
class RustBindingsAPI(ServerAPI):
    """ServerAPI implementation backed by the Rust bindings.

    This class shims the Rust bindings into the legacy Python API: each
    method forwards to ``chromadb_rust_bindings.Bindings`` and converts the
    bindings' return values into the Python-level types (``CollectionModel``,
    ``GetResult``, ``QueryResult``, ...). The bindings could have implemented
    ``ServerAPI`` directly, but shimming here keeps the bindings' types from
    leaking into the Python API.
    """

    # Handle to the Rust engine; created in start(), dropped in stop().
    bindings: chromadb_rust_bindings.Bindings
    # Max number of HNSW indices kept open at once (derived from the
    # process file-descriptor limit).
    hnsw_cache_size: int
    product_telemetry_client: ProductTelemetryClient

    def __init__(self, system: System):
        super().__init__(system)
        self.product_telemetry_client = self.require(ProductTelemetryClient)

        # Size the HNSW cache from the OS open-file limit: each HNSW index
        # keeps 4 data files plus 1 metadata file open, so divide by 5.
        if platform.system() != "Windows":
            max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
        else:
            max_file_handles = ctypes.windll.msvcrt._getmaxstdio()  # type: ignore
        self.hnsw_cache_size = max_file_handles // 5

    @override
    def start(self) -> None:
        """Construct the Rust ``Bindings`` from the Python ``Settings``.

        Translates the Python-level settings (persistence, migration hash
        algorithm, migration mode) into the bindings' config types.
        """
        # TODO: We should add a "config converter"
        if self._system.settings.require("is_persistent"):
            persist_path = self._system.settings.require("persist_directory")
            sqlite_persist_path = persist_path + "/chroma.sqlite3"
        else:
            persist_path = None
            sqlite_persist_path = None

        hash_type = self._system.settings.require("migrations_hash_algorithm")
        hash_type_bindings = (
            chromadb_rust_bindings.MigrationHash.MD5
            if hash_type == "md5"
            else chromadb_rust_bindings.MigrationHash.SHA256
        )

        migration_mode = self._system.settings.require("migrations")
        migration_mode_bindings = (
            chromadb_rust_bindings.MigrationMode.Apply
            if migration_mode == "apply"
            else chromadb_rust_bindings.MigrationMode.Validate
        )

        sqlite_config = chromadb_rust_bindings.SqliteDBConfig(
            hash_type=hash_type_bindings,
            migration_mode=migration_mode_bindings,
            url=sqlite_persist_path,
        )

        self.bindings = chromadb_rust_bindings.Bindings(
            allow_reset=self._system.settings.require("allow_reset"),
            sqlite_db_config=sqlite_config,
            persist_path=persist_path,
            hnsw_cache_size=self.hnsw_cache_size,
        )

    @override
    def stop(self) -> None:
        # Dropping the bindings object releases the Rust-side resources.
        del self.bindings

    @staticmethod
    def _to_collection_model(collection: Any) -> CollectionModel:
        """Convert a bindings collection record into the legacy Python model."""
        return CollectionModel(
            id=collection.id,
            name=collection.name,
            configuration_json=collection.configuration,
            serialized_schema=collection.schema,
            metadata=collection.metadata,
            dimension=collection.dimension,
            tenant=collection.tenant,
            database=collection.database,
        )

    # ////////////////////////////// Admin API //////////////////////////////

    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        return self.bindings.create_database(name, tenant)

    @override
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        database = self.bindings.get_database(name, tenant)
        return {
            "id": database.id,
            "name": database.name,
            "tenant": database.tenant,
        }

    @override
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        return self.bindings.delete_database(name, tenant)

    @override
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        databases = self.bindings.list_databases(limit, offset, tenant)
        return [
            {
                "id": database.id,
                "name": database.name,
                "tenant": database.tenant,
            }
            for database in databases
        ]

    @override
    def create_tenant(self, name: str) -> None:
        return self.bindings.create_tenant(name)

    @override
    def get_tenant(self, name: str) -> Tenant:
        tenant = self.bindings.get_tenant(name)
        return Tenant(name=tenant.name)

    # ////////////////////////////// Base API //////////////////////////////

    @override
    def heartbeat(self) -> int:
        return self.bindings.heartbeat()

    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        return self.bindings.count_collections(tenant, database)

    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        collections = self.bindings.list_collections(limit, offset, tenant, database)
        return [self._to_collection_model(collection) for collection in collections]

    @override
    def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Create (or get-or-create) a collection.

        The configuration and schema are serialized to JSON strings before
        crossing into the Rust bindings.
        """
        if configuration:
            configuration_json_str = create_collection_configuration_to_json_str(
                configuration, metadata
            )
        else:
            configuration_json_str = None

        if schema:
            schema_str = json.dumps(schema.serialize_to_json())
        else:
            schema_str = None

        collection = self.bindings.create_collection(
            name,
            configuration_json_str,
            schema_str,
            metadata,
            get_or_create,
            tenant,
            database,
        )

        # TODO: This event doesn't capture the get_or_create case appropriately
        # TODO: Re-enable embedding function tracking in create_collection
        # BUGFIX: this previously reported str(id) — the *builtin* id function,
        # since no local `id` exists here — so every event carried
        # "<built-in function id>". Capture after creation so the real
        # collection id is available.
        self.product_telemetry_client.capture(
            ClientCreateCollectionEvent(
                collection_uuid=str(collection.id),
                # embedding_function=embedding_function.__class__.__name__,
            )
        )

        return self._to_collection_model(collection)

    @override
    def get_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        collection = self.bindings.get_collection(name, tenant, database)
        return self._to_collection_model(collection)

    @override
    def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        # Delegates to create_collection with get_or_create=True.
        return self.create_collection(
            name, schema, configuration, metadata, True, tenant, database
        )

    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        self.bindings.delete_collection(name, tenant, database)

    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Rename / re-configure a collection.

        NOTE(review): tenant/database are accepted for interface parity but
        are not forwarded — the bindings' update_collection does not take
        them.
        """
        if new_configuration:
            new_configuration_json_str = update_collection_configuration_to_json_str(
                new_configuration
            )
        else:
            new_configuration_json_str = None
        self.bindings.update_collection(
            str(id), new_name, new_metadata, new_configuration_json_str
        )

    @override
    def _fork(
        self,
        collection_id: UUID,
        new_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        raise NotImplementedError(
            "Collection forking is not implemented for Local Chroma"
        )

    @override
    def _search(
        self,
        collection_id: UUID,
        searches: List[Search],
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> SearchResult:
        raise NotImplementedError("Search is not implemented for Local Chroma")

    @override
    def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        return self.bindings.count(str(collection_id), tenant, database)

    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        # Peek is a get() limited to n records with embeddings included.
        return self._get(
            collection_id,
            limit=n,
            tenant=tenant,
            database=database,
            include=IncludeMetadataDocumentsEmbeddings,
        )

    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Fetch records by id/filter; where-clauses are JSON-serialized
        before crossing into the bindings."""
        ids_amount = len(ids) if ids else 0
        self.product_telemetry_client.capture(
            CollectionGetEvent(
                collection_uuid=str(collection_id),
                ids_count=ids_amount,
                limit=limit if limit else 0,
                include_metadata=ids_amount if "metadatas" in include else 0,
                include_documents=ids_amount if "documents" in include else 0,
                include_uris=ids_amount if "uris" in include else 0,
            )
        )
        rust_response = self.bindings.get(
            str(collection_id),
            ids,
            json.dumps(where) if where else None,
            limit,
            offset or 0,
            json.dumps(where_document) if where_document else None,
            include,
            tenant,
            database,
        )
        return GetResult(
            ids=rust_response.ids,
            embeddings=rust_response.embeddings,
            documents=rust_response.documents,
            uris=rust_response.uris,
            included=include,
            data=None,
            metadatas=rust_response.metadatas,
        )

    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        self.product_telemetry_client.capture(
            CollectionAddEvent(
                collection_uuid=str(collection_id),
                add_amount=len(ids),
                with_metadata=len(ids) if metadatas is not None else 0,
                with_documents=len(ids) if documents is not None else 0,
                with_uris=len(ids) if uris is not None else 0,
            )
        )
        return self.bindings.add(
            ids,
            str(collection_id),
            embeddings,
            metadatas,
            documents,
            uris,
            tenant,
            database,
        )

    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        self.product_telemetry_client.capture(
            CollectionUpdateEvent(
                collection_uuid=str(collection_id),
                update_amount=len(ids),
                with_embeddings=len(embeddings) if embeddings else 0,
                with_metadata=len(metadatas) if metadatas else 0,
                with_documents=len(documents) if documents else 0,
                with_uris=len(uris) if uris else 0,
            )
        )
        return self.bindings.update(
            str(collection_id),
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
            tenant,
            database,
        )

    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        return self.bindings.upsert(
            str(collection_id),
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
            tenant,
            database,
        )

    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> QueryResult:
        """KNN query; filters are JSON-serialized before crossing into the
        bindings."""
        query_amount = len(query_embeddings)
        filtered_ids_amount = len(ids) if ids else 0
        self.product_telemetry_client.capture(
            CollectionQueryEvent(
                collection_uuid=str(collection_id),
                query_amount=query_amount,
                filtered_ids_amount=filtered_ids_amount,
                n_results=n_results,
                with_metadata_filter=query_amount if where is not None else 0,
                with_document_filter=query_amount if where_document is not None else 0,
                include_metadatas=query_amount if "metadatas" in include else 0,
                include_documents=query_amount if "documents" in include else 0,
                include_uris=query_amount if "uris" in include else 0,
                include_distances=query_amount if "distances" in include else 0,
            )
        )
        rust_response = self.bindings.query(
            str(collection_id),
            ids,
            query_embeddings,
            n_results,
            json.dumps(where) if where else None,
            json.dumps(where_document) if where_document else None,
            include,
            tenant,
            database,
        )
        return QueryResult(
            ids=rust_response.ids,
            embeddings=rust_response.embeddings,
            documents=rust_response.documents,
            uris=rust_response.uris,
            included=include,
            data=None,
            metadatas=rust_response.metadatas,
            distances=rust_response.distances,
        )

    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        self.product_telemetry_client.capture(
            CollectionDeleteEvent(
                # NOTE: the delete amount is not observable from python
                # TODO: Fix this when posthog is pushed into Rust frontend
                collection_uuid=str(collection_id),
                delete_amount=0,
            )
        )
        return self.bindings.delete(
            str(collection_id),
            ids,
            json.dumps(where) if where else None,
            json.dumps(where_document) if where_document else None,
            tenant,
            database,
        )

    @override
    def reset(self) -> bool:
        return self.bindings.reset()

    @override
    def get_version(self) -> str:
        return self.bindings.get_version()

    @override
    def get_settings(self) -> Settings:
        return self._system.settings

    @override
    def get_max_batch_size(self) -> int:
        return self.bindings.get_max_batch_size()

    @override
    def attach_function(
        self,
        function_id: str,
        name: str,
        input_collection_id: UUID,
        output_collection: str,
        params: Optional[Dict[str, Any]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple["AttachedFunction", bool]:
        """Attached functions are not supported in the Rust bindings (local embedded mode)."""
        raise NotImplementedError(
            "Attached functions are only supported when connecting to a Chroma server via HttpClient. "
            "The Rust bindings (embedded mode) do not support attached function operations."
        )

    @override
    def get_attached_function(
        self,
        name: str,
        input_collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "AttachedFunction":
        """Attached functions are not supported in the Rust bindings (local embedded mode)."""
        raise NotImplementedError(
            "Attached functions are only supported when connecting to a Chroma server via HttpClient. "
            "The Rust bindings (embedded mode) do not support attached function operations."
        )

    @override
    def detach_function(
        self,
        name: str,
        input_collection_id: UUID,
        delete_output: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Attached functions are not supported in the Rust bindings (local embedded mode)."""
        raise NotImplementedError(
            "Attached functions are only supported when connecting to a Chroma server via HttpClient. "
            "The Rust bindings (embedded mode) do not support attached function operations."
        )

    # TODO: Remove this if it's not planned to be used
    @override
    def get_user_identity(self) -> UserIdentity:
        # Local embedded mode has no auth; report the default identity.
        return UserIdentity(
            user_id="",
            tenant=DEFAULT_TENANT,
            databases=[DEFAULT_DATABASE],
        )