import orjson
import logging
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
from typing import Sequence
from uuid import UUID
import httpx
import urllib.parse
from overrides import override

from chromadb.api.models.AttachedFunction import AttachedFunction
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    UpdateCollectionConfiguration,
    update_collection_configuration_to_json,
    create_collection_configuration_to_json,
)
from chromadb import __version__
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.api import ServerAPI
from chromadb.execution.expression.plan import Search
from chromadb.api.types import (
    Documents,
    Embeddings,
    IDs,
    Include,
    Schema,
    Metadatas,
    URIs,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    SearchResult,
    CollectionMetadata,
    validate_batch,
    convert_np_embeddings_to_list,
    IncludeMetadataDocuments,
    IncludeMetadataDocumentsDistances,
)
from chromadb.api.types import (
    IncludeMetadataDocumentsEmbeddings,
    optional_embeddings_to_base64_strings,
    serialize_metadata,
    deserialize_metadata,
)
from chromadb.auth import UserIdentity
from chromadb.auth import (
    ClientAuthProvider,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient

logger = logging.getLogger(__name__)


def _deserialize_metadata_list(metadatas: Optional[List[Any]]) -> Optional[List[Any]]:
    """Apply ``deserialize_metadata`` to every non-None entry of one metadata list.

    ``None`` entries are kept as ``None``; a ``None`` list is returned as ``None``.
    """
    if metadatas is None:
        return None
    return [
        deserialize_metadata(metadata) if metadata is not None else None
        for metadata in metadatas
    ]


def _deserialize_metadata_batches(
    metadata_batches: Optional[List[Any]],
) -> Optional[List[Any]]:
    """Apply ``_deserialize_metadata_list`` to each (possibly None) batch.

    Handles the nested shape List[Optional[List[Optional[Metadata]]]] used by
    query and search responses. A ``None`` input is returned as ``None``.
    """
    if metadata_batches is None:
        return None
    return [_deserialize_metadata_list(metadatas) for metadatas in metadata_batches]


class FastAPI(BaseHTTPClient, ServerAPI):
    """``ServerAPI`` implementation that talks to a Chroma server over HTTP.

    Requests are issued through a synchronous ``httpx.Client`` session; request
    bodies are serialized and responses parsed with ``orjson``.
    """

    def __init__(self, system: System):
        """Build the HTTP session, default headers and (optional) auth headers.

        Requires the ``chroma_server_host`` and ``chroma_server_http_port``
        settings to be present.
        """
        super().__init__(system)
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")

        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._settings = system.settings

        self._api_url = FastAPI.resolve_url(
            chroma_server_host=str(system.settings.chroma_server_host),
            chroma_server_http_port=system.settings.chroma_server_http_port,
            chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
            default_api_path=system.settings.chroma_server_api_default_path,
        )

        # Only pass `verify` when it was explicitly configured so that httpx's
        # own default applies otherwise.
        if self._settings.chroma_server_ssl_verify is not None:
            self._session = httpx.Client(
                timeout=None,
                limits=self.http_limits,
                verify=self._settings.chroma_server_ssl_verify,
            )
        else:
            self._session = httpx.Client(timeout=None, limits=self.http_limits)

        # Copy the configured headers: the entries added below must not leak
        # back into the shared Settings object.
        self._header = dict(system.settings.chroma_server_headers or {})
        self._header["Content-Type"] = "application/json"
        self._header["User-Agent"] = (
            "Chroma Python Client v"
            + __version__
            + " (https://github.com/chroma-core/chroma)"
        )
        self._session.headers.update(self._header)

        if system.settings.chroma_client_auth_provider:
            self._auth_provider = self.require(ClientAuthProvider)
            _headers = self._auth_provider.authenticate()
            for header, value in _headers.items():
                self._session.headers[header] = value.get_secret_value()

    @override
    def get_request_headers(self) -> Mapping[str, str]:
        """Return a copy of the headers currently set on the HTTP session."""
        return dict(self._session.headers)

    @override
    def get_api_url(self) -> str:
        """Return the base API URL resolved at construction time."""
        return self._api_url

    def _make_request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send a request to ``self._api_url + path`` and return the parsed JSON body.

        Args:
            method: HTTP method name passed to ``httpx.Client.request``.
            path: Path appended to the base API URL; it is percent-escaped
                except for ``/``.
            **kwargs: Extra keyword arguments forwarded to
                ``httpx.Client.request``. A ``json`` entry is serialized with
                orjson and sent as ``content`` instead.

        Returns:
            The response text decoded with ``orjson.loads``.

        Raises:
            Whatever ``BaseHTTPClient._raise_chroma_error`` raises for the
            response.
        """
        # If the request has json in kwargs, use orjson to serialize it,
        # remove it from kwargs, and add it to the content parameter
        # This is because httpx uses a slower json serializer
        if "json" in kwargs:
            data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
            kwargs["content"] = data

        # Unlike requests, httpx does not automatically escape the path
        escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
        url = self._api_url + escaped_path

        response = self._session.request(method, url, **kwargs)
        BaseHTTPClient._raise_chroma_error(response)
        return orjson.loads(response.text)

    @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
    @override
    def heartbeat(self) -> int:
        """Returns the current server time in nanoseconds to check if the server is alive"""
        resp_json = self._make_request("get", "/heartbeat")
        return int(resp_json["nanosecond heartbeat"])

    # Migrated to rust in distributed.
    @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
    @override
    def create_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Creates a database"""
        self._make_request(
            "post",
            f"/tenants/{tenant}/databases",
            json={"name": name},
        )

    # Migrated to rust in distributed.
    @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
    @override
    def get_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> Database:
        """Returns a database"""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{name}",
        )
        return Database(
            id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
        )

    @trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Deletes a database"""
        self._make_request(
            "delete",
            f"/tenants/{tenant}/databases/{name}",
        )

    @trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
    @override
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        """Returns a list of all databases"""
        json_databases = self._make_request(
            "get",
            f"/tenants/{tenant}/databases",
            params=BaseHTTPClient._clean_params(
                {
                    "limit": limit,
                    "offset": offset,
                }
            ),
        )
        databases = [
            Database(id=db["id"], name=db["name"], tenant=db["tenant"])
            for db in json_databases
        ]
        return databases

    @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def create_tenant(self, name: str) -> None:
        """Creates a tenant"""
        self._make_request("post", "/tenants", json={"name": name})

    @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def get_tenant(self, name: str) -> Tenant:
        """Returns a tenant"""
        resp_json = self._make_request("get", f"/tenants/{name}")
        return Tenant(name=resp_json["name"])

    @trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
    @override
    def get_user_identity(self) -> UserIdentity:
        """Fetch ``/auth/identity`` and build a ``UserIdentity`` from the response fields"""
        return UserIdentity(**self._make_request("get", "/auth/identity"))

    @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        """Returns a list of all collections"""
        json_collections = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections",
            params=BaseHTTPClient._clean_params(
                {
                    "limit": limit,
                    "offset": offset,
                }
            ),
        )
        collection_models = [
            CollectionModel.from_json(json_collection)
            for json_collection in json_collections
        ]
        return collection_models

    @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        """Returns a count of collections"""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections_count",
        )
        return cast(int, resp_json)

    @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Creates a collection"""
        # Configuration and schema are only serialized when provided; otherwise
        # they are sent as null.
        config_json = (
            create_collection_configuration_to_json(configuration, metadata)
            if configuration
            else None
        )
        serialized_schema = schema.serialize_to_json() if schema else None
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections",
            json={
                "name": name,
                "metadata": metadata,
                "configuration": config_json,
                "schema": serialized_schema,
                "get_or_create": get_or_create,
            },
        )
        model = CollectionModel.from_json(resp_json)
        return model

    @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def get_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Returns a collection"""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections/{name}",
        )
        model = CollectionModel.from_json(resp_json)
        return model

    @trace_method(
        "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
    )
    @override
    def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Delegates to ``create_collection`` with ``get_or_create=True``"""
        return self.create_collection(
            name=name,
            metadata=metadata,
            configuration=configuration,
            schema=schema,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )

    @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Updates a collection"""
        self._make_request(
            "put",
            f"/tenants/{tenant}/databases/{database}/collections/{id}",
            json={
                "new_metadata": new_metadata,
                "new_name": new_name,
                "new_configuration": update_collection_configuration_to_json(
                    new_configuration
                )
                if new_configuration
                else None,
            },
        )

    @trace_method("FastAPI._fork", OpenTelemetryGranularity.OPERATION)
    @override
    def _fork(
        self,
        collection_id: UUID,
        new_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Forks a collection"""
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
            json={"new_name": new_name},
        )
        model = CollectionModel.from_json(resp_json)
        return model

    @trace_method("FastAPI._search", OpenTelemetryGranularity.OPERATION)
    @override
    def _search(
        self,
        collection_id: UUID,
        searches: List[Search],
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> SearchResult:
        """Performs hybrid search on a collection"""
        # Convert Search objects to dictionaries
        payload = {"searches": [s.to_dict() for s in searches]}

        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
            json=payload,
        )

        # Deserialize metadatas: convert transport format to SparseVector instances
        # SearchResult has nested structure: List[Optional[List[Optional[Metadata]]]]
        metadata_batches = resp_json.get("metadatas", None)
        if metadata_batches is not None:
            # Only overwrite the key when the server actually returned it.
            resp_json["metadatas"] = _deserialize_metadata_batches(metadata_batches)

        return SearchResult(resp_json)

    @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Deletes a collection"""
        self._make_request(
            "delete",
            f"/tenants/{tenant}/databases/{database}/collections/{name}",
        )

    @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
    @override
    def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Returns the number of embeddings in the database"""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
        )
        return cast(int, resp_json)

    @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Returns up to ``n`` records via ``_get``, including metadatas, documents and embeddings"""
        return cast(
            GetResult,
            self._get(
                collection_id,
                tenant=tenant,
                database=database,
                limit=n,
                include=IncludeMetadataDocumentsEmbeddings,
            ),
        )

    @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Fetches records from a collection by ids and/or filters.

        ``data`` is always returned as ``None``; ``included`` echoes the
        caller's original ``include`` list (before ``"data"`` is filtered out).
        """
        # Servers do not support receiving "data", as that is hydrated by the client as a loadable
        filtered_include = [i for i in include if i != "data"]

        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
            json={
                "ids": ids,
                "where": where,
                "limit": limit,
                "offset": offset,
                "where_document": where_document,
                "include": filtered_include,
            },
        )

        # Deserialize metadatas: convert transport format to SparseVector instances
        metadatas = _deserialize_metadata_list(resp_json.get("metadatas", None))

        return GetResult(
            ids=resp_json["ids"],
            embeddings=resp_json.get("embeddings", None),
            metadatas=metadatas,
            documents=resp_json.get("documents", None),
            data=None,
            uris=resp_json.get("uris", None),
            included=include,
        )

    @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Deletes embeddings from the database"""
        self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
            json={
                "ids": ids,
                "where": where,
                "where_document": where_document,
            },
        )
        return None

    @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
    def _submit_batch(
        self,
        batch: Tuple[
            IDs,
            Optional[Embeddings],
            Optional[Metadatas],
            Optional[Documents],
            Optional[URIs],
        ],
        url: str,
    ) -> None:
        """
        Submits a batch of embeddings to the database

        ``batch`` is the column-oriented tuple
        ``(ids, embeddings, metadatas, documents, uris)``; ``url`` is the API
        path the batch is POSTed to.
        """
        # Serialize metadatas: convert SparseVector instances to transport format
        serialized_metadatas = None
        if batch[2] is not None:
            serialized_metadatas = [
                serialize_metadata(metadata) if metadata is not None else None
                for metadata in batch[2]
            ]

        data = {
            "ids": batch[0],
            # Embeddings are base64-encoded only when the server's pre-flight
            # checks report support for it.
            "embeddings": optional_embeddings_to_base64_strings(batch[1])
            if self.supports_base64_encoding()
            else batch[1],
            "metadatas": serialized_metadatas,
            "documents": batch[3],
            "uris": batch[4],
        }

        self._make_request("post", url, json=data)

    @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """
        Adds a batch of embeddings to the database
        - pass in column oriented data lists
        """
        batch = (
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(
            batch,
            f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
        )
        return True

    @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """
        Updates a batch of embeddings in the database
        - pass in column oriented data lists
        """
        batch = (
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(
            batch,
            f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
        )
        return True

    @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """
        Upserts a batch of embeddings in the database
        - pass in column oriented data lists
        """
        batch = (
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(
            batch,
            f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
        )
        return True

    @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> QueryResult:
        """Gets the nearest neighbors of a single embedding"""
        # Servers do not support receiving "data", as that is hydrated by the client as a loadable
        filtered_include = [i for i in include if i != "data"]

        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
            json={
                "ids": ids,
                "query_embeddings": convert_np_embeddings_to_list(query_embeddings)
                if query_embeddings is not None
                else None,
                "n_results": n_results,
                "where": where,
                "where_document": where_document,
                "include": filtered_include,
            },
        )

        # Deserialize metadatas: convert transport format to SparseVector instances
        metadata_batches = _deserialize_metadata_batches(
            resp_json.get("metadatas", None)
        )

        return QueryResult(
            ids=resp_json["ids"],
            distances=resp_json.get("distances", None),
            embeddings=resp_json.get("embeddings", None),
            metadatas=metadata_batches,
            documents=resp_json.get("documents", None),
            uris=resp_json.get("uris", None),
            data=None,
            included=include,
        )

    @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
    @override
    def reset(self) -> bool:
        """Resets the database"""
        resp_json = self._make_request("post", "/reset")
        return cast(bool, resp_json)

    @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
    @override
    def get_version(self) -> str:
        """Returns the version of the server"""
        resp_json = self._make_request("get", "/version")
        return cast(str, resp_json)

    @override
    def get_settings(self) -> Settings:
        """Returns the settings of the client"""
        return self._settings

    @trace_method("FastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION)
    def get_pre_flight_checks(self) -> Any:
        """Returns the server's pre-flight checks, fetching them only on first use"""
        # Cached on the instance so the endpoint is hit at most once while the
        # cached value is set.
        if self.pre_flight_checks is None:
            resp_json = self._make_request("get", "/pre-flight-checks")
            self.pre_flight_checks = resp_json
        return self.pre_flight_checks

    @trace_method(
        "FastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
    )
    def supports_base64_encoding(self) -> bool:
        """Returns the ``supports_base64_encoding`` pre-flight flag (False when absent)"""
        pre_flight_checks = self.get_pre_flight_checks()
        b64_encoding_enabled = cast(
            bool, pre_flight_checks.get("supports_base64_encoding", False)
        )
        return b64_encoding_enabled

    @trace_method("FastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
    @override
    def get_max_batch_size(self) -> int:
        """Returns the ``max_batch_size`` pre-flight value (-1 when absent)"""
        pre_flight_checks = self.get_pre_flight_checks()
        max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
        return max_batch_size

    @trace_method("FastAPI.attach_function", OpenTelemetryGranularity.ALL)
    @override
    def attach_function(
        self,
        function_id: str,
        name: str,
        input_collection_id: UUID,
        output_collection: str,
        params: Optional[Dict[str, Any]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple["AttachedFunction", bool]:
        """Attach a function to a collection.

        Returns:
            A tuple of the ``AttachedFunction`` built from the response and the
            response's ``created`` flag.
        """
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/attach",
            json={
                "name": name,
                "function_id": function_id,
                "output_collection": output_collection,
                "params": params,
            },
        )

        af = resp_json["attached_function"]
        attached_function = AttachedFunction(
            client=self,
            id=UUID(af["id"]),
            name=af["name"],
            function_name=af["function_name"],
            input_collection_id=input_collection_id,
            output_collection=output_collection,
            params=params,
            tenant=tenant,
            database=database,
        )
        created = resp_json.get(
            "created", True
        )  # Default to True for backwards compatibility
        return (attached_function, created)

    @trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
    @override
    def get_attached_function(
        self,
        name: str,
        input_collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "AttachedFunction":
        """Get an attached function by name for a specific collection."""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/{name}",
        )

        af = resp_json["attached_function"]
        return AttachedFunction(
            client=self,
            id=UUID(af["id"]),
            name=af["name"],
            function_name=af["function_name"],
            input_collection_id=input_collection_id,
            output_collection=af["output_collection"],
            params=af.get("params"),
            tenant=tenant,
            database=database,
        )

    @trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
    @override
    def detach_function(
        self,
        name: str,
        input_collection_id: UUID,
        delete_output: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Detach a function and prevent any further runs."""
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
            json={
                "delete_output": delete_output,
            },
        )
        return cast(bool, resp_json["success"])