chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,439 @@
from typing import Dict, Optional, Union
import logging
from chromadb.api.client import Client as ClientCreator
from chromadb.api.client import (
    AdminClient as AdminClientCreator,
)
from chromadb.api.async_client import AsyncClient as AsyncClientCreator
from chromadb.auth.token_authn import TokenTransportHeader
import chromadb.config
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
    CollectionMetadata,
    UpdateMetadata,
    Documents,
    EmbeddingFunction,
    Embeddings,
    URI,
    URIs,
    IDs,
    Include,
    Metadata,
    Metadatas,
    Where,
    QueryResult,
    GetResult,
    WhereDocument,
    UpdateCollectionMetadata,
    SparseVector,
    SparseVectors,
    SparseEmbeddingFunction,
    Schema,
    VectorIndexConfig,
    HnswIndexConfig,
    SpannIndexConfig,
    FtsIndexConfig,
    SparseVectorIndexConfig,
    StringInvertedIndexConfig,
    IntInvertedIndexConfig,
    FloatInvertedIndexConfig,
    BoolInvertedIndexConfig,
)

# Import Search API components
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import (
    # Key builder for where conditions and field selection
    Key,
    K,  # Alias for Key
    # KNN-based ranking for hybrid search
    Knn,
    # Reciprocal Rank Fusion for combining rankings
    Rrf,
)
from pathlib import Path
import os

# Re-export types from chromadb.types
__all__ = [
    "Collection",
    "Metadata",
    "Metadatas",
    "Where",
    "WhereDocument",
    "Documents",
    "IDs",
    "URI",
    "URIs",
    "Embeddings",
    "EmbeddingFunction",
    "Include",
    "CollectionMetadata",
    "UpdateMetadata",
    "UpdateCollectionMetadata",
    "QueryResult",
    "GetResult",
    "TokenTransportHeader",
    # Search API components
    "Search",
    "Key",
    "K",
    "Knn",
    "Rrf",
    # Sparse Vector Types
    "SparseVector",
    "SparseVectors",
    "SparseEmbeddingFunction",
    # Schema and Index Configuration
    "Schema",
    "VectorIndexConfig",
    "HnswIndexConfig",
    "SpannIndexConfig",
    "FtsIndexConfig",
    "SparseVectorIndexConfig",
    "StringInvertedIndexConfig",
    "IntInvertedIndexConfig",
    "FloatInvertedIndexConfig",
    "BoolInvertedIndexConfig",
]

from chromadb.types import CloudClientArg

logger = logging.getLogger(__name__)

__settings = Settings()

__version__ = "1.3.4"


# Workaround to deal with Colab's old sqlite3 version
def is_in_colab() -> bool:
    try:
        import google.colab  # noqa: F401

        return True
    except ImportError:
        return False


IN_COLAB = is_in_colab()

is_client = False
try:
    from chromadb.is_thin_client import is_thin_client

    is_client = is_thin_client
except ImportError:
    is_client = False

if not is_client:
    import sqlite3

    if sqlite3.sqlite_version_info < (3, 35, 0):
        if IN_COLAB:
            # In Colab, hotswap to pysqlite-binary if it's too old
            import subprocess
            import sys

            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "pysqlite3-binary"]
            )
            __import__("pysqlite3")
            sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
        else:
            raise RuntimeError(
                "\033[91mYour system has an unsupported version of sqlite3. Chroma \
                    requires sqlite3 >= 3.35.0.\033[0m\n"
                "\033[94mPlease visit \
                    https://docs.trychroma.com/troubleshooting#sqlite to learn how \
                    to upgrade.\033[0m"
            )


def configure(**kwargs) -> None:  # type: ignore
    """Override Chroma's default settings, environment variables or .env files"""
    global __settings
    __settings = chromadb.config.Settings(**kwargs)


def get_settings() -> Settings:
    return __settings
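

# A minimal usage sketch (illustrative only, not part of the library): configure()
# replaces the module-level settings that a later Client() call picks up by default.
# The specific Settings fields used here (anonymized_telemetry, allow_reset) are
# assumptions about commonly used options, not an exhaustive or authoritative list.
def _example_configure() -> None:  # pragma: no cover - illustrative sketch
    configure(anonymized_telemetry=False, allow_reset=True)
    # get_settings() returns the same Settings object that configure() installed.
    assert get_settings().allow_reset is True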


def EphemeralClient(
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates an in-memory instance of Chroma. This is useful for testing and
    development, but not recommended for production use.

    Args:
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """
    if settings is None:
        settings = Settings()
    settings.is_persistent = False

    # Make sure parameters are the correct types -- users can pass anything.
    tenant = str(tenant)
    database = str(database)

    return ClientCreator(settings=settings, tenant=tenant, database=database)
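

# Illustrative sketch (not part of the library): EphemeralClient keeps everything in
# memory, which makes it handy for unit tests. The collection add/count calls follow
# the standard Chroma collection API; treat them as an assumption if your installed
# version differs.
def _example_ephemeral_client() -> None:  # pragma: no cover - illustrative sketch
    client = EphemeralClient()
    collection = client.create_collection("scratch")
    collection.add(ids=["1", "2"], documents=["hello world", "goodbye world"])
    print(collection.count())  # 2; everything is discarded with the process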


def PersistentClient(
    path: Union[str, Path] = "./chroma",
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates a persistent instance of Chroma that saves to disk. This is useful for
    testing and development, but not recommended for production use.

    Args:
        path: The directory to save Chroma's data to. Defaults to "./chroma".
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """
    if settings is None:
        settings = Settings()
    settings.persist_directory = str(path)
    settings.is_persistent = True

    # Make sure parameters are the correct types -- users can pass anything.
    tenant = str(tenant)
    database = str(database)

    return ClientCreator(tenant=tenant, database=database, settings=settings)
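

# Illustrative sketch (not part of the library): PersistentClient writes to the given
# directory, so data survives process restarts. The query_texts/n_results arguments
# follow the standard collection API and rely on the collection's embedding function;
# adjust if your setup differs.
def _example_persistent_client() -> None:  # pragma: no cover - illustrative sketch
    client = PersistentClient(path="./chroma")  # assumes a writable directory
    collection = client.get_or_create_collection("docs")
    collection.add(ids=["a"], documents=["Chroma persists this document to disk."])
    results = collection.query(query_texts=["what is persisted?"], n_results=1)
    print(results["ids"])  # [["a"]] on subsequent runs as well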


def RustClient(
    path: Optional[str] = None,
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates an ephemeral or persistent instance of Chroma that saves to disk when a path is provided.
    This is useful for testing and development, but not recommended for production use.

    Args:
        path: An optional directory to save Chroma's data to. The client is ephemeral if a None value is provided. Defaults to None.
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """
    if settings is None:
        settings = Settings()

    settings.chroma_api_impl = "chromadb.api.rust.RustBindingsAPI"
    settings.is_persistent = path is not None
    settings.persist_directory = path or ""

    # Make sure parameters are the correct types -- users can pass anything.
    tenant = str(tenant)
    database = str(database)

    return ClientCreator(tenant=tenant, database=database, settings=settings)


def HttpClient(
    host: str = "localhost",
    port: int = 8000,
    ssl: bool = False,
    headers: Optional[Dict[str, str]] = None,
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates a client that connects to a remote Chroma server. This supports
    many clients connecting to the same server, and is the recommended way to
    use Chroma in production.

    Args:
        host: The hostname of the Chroma server. Defaults to "localhost".
        port: The port of the Chroma server. Defaults to 8000.
        ssl: Whether to use SSL to connect to the Chroma server. Defaults to False.
        headers: A dictionary of headers to send to the Chroma server. Defaults to {}.
        settings: A dictionary of settings to communicate with the chroma server.
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """

    if settings is None:
        settings = Settings()

    # Make sure parameters are the correct types -- users can pass anything.
    host = str(host)
    port = int(port)
    ssl = bool(ssl)
    tenant = str(tenant)
    database = str(database)

    settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
    if settings.chroma_server_host and settings.chroma_server_host != host:
        raise ValueError(
            f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]"
        )
    settings.chroma_server_host = host
    if settings.chroma_server_http_port and settings.chroma_server_http_port != port:
        raise ValueError(
            f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]"
        )
    settings.chroma_server_http_port = port
    settings.chroma_server_ssl_enabled = ssl
    settings.chroma_server_headers = headers

    return ClientCreator(tenant=tenant, database=database, settings=settings)
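

# Illustrative sketch (not part of the library): HttpClient assumes a Chroma server is
# already running and reachable, for example one started locally on localhost:8000 --
# the exact launch command for the server is an assumption and not shown here.
def _example_http_client() -> None:  # pragma: no cover - illustrative sketch
    client = HttpClient(host="localhost", port=8000, ssl=False)
    client.heartbeat()  # raises if the server is unreachable
    for collection in client.list_collections():
        print(collection.name)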


async def AsyncHttpClient(
    host: str = "localhost",
    port: int = 8000,
    ssl: bool = False,
    headers: Optional[Dict[str, str]] = None,
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> AsyncClientAPI:
    """
    Creates an async client that connects to a remote Chroma server. This supports
    many clients connecting to the same server, and is the recommended way to
    use Chroma in production.

    Args:
        host: The hostname of the Chroma server. Defaults to "localhost".
        port: The port of the Chroma server. Defaults to 8000.
        ssl: Whether to use SSL to connect to the Chroma server. Defaults to False.
        headers: A dictionary of headers to send to the Chroma server. Defaults to {}.
        settings: A dictionary of settings to communicate with the chroma server.
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """

    if settings is None:
        settings = Settings()

    # Make sure parameters are the correct types -- users can pass anything.
    host = str(host)
    port = int(port)
    ssl = bool(ssl)
    tenant = str(tenant)
    database = str(database)

    settings.chroma_api_impl = "chromadb.api.async_fastapi.AsyncFastAPI"
    if settings.chroma_server_host and settings.chroma_server_host != host:
        raise ValueError(
            f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]"
        )
    settings.chroma_server_host = host
    if settings.chroma_server_http_port and settings.chroma_server_http_port != port:
        raise ValueError(
            f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]"
        )
    settings.chroma_server_http_port = port
    settings.chroma_server_ssl_enabled = ssl
    settings.chroma_server_headers = headers

    return await AsyncClientCreator.create(
        tenant=tenant, database=database, settings=settings
    )
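

# Illustrative sketch (not part of the library): the async client must be awaited at
# construction time and used from a running event loop; collection methods on the
# returned AsyncCollection objects are coroutines as well.
async def _example_async_http_client() -> None:  # pragma: no cover - illustrative sketch
    client = await AsyncHttpClient(host="localhost", port=8000)
    collection = await client.get_or_create_collection("docs")
    await collection.add(ids=["1"], documents=["async hello"])
    print(await collection.count())


# To run the sketch: asyncio.run(_example_async_http_client())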


def CloudClient(
    tenant: Optional[str] = None,
    database: Optional[str] = None,
    api_key: Optional[str] = None,
    settings: Optional[Settings] = None,
    *,  # Following arguments are keyword-only, intended for testing only.
    cloud_host: str = "api.trychroma.com",
    cloud_port: int = 443,
    enable_ssl: bool = True,
) -> ClientAPI:
    """
    Creates a client to connect to a tenant and database on Chroma cloud.

    Args:
        tenant: The tenant to use for this client. Optional. If not provided, it will be inferred from the API key if the key is scoped to a single tenant. If provided, it will be validated against the API key's scope.
        database: The database to use for this client. Optional. If not provided, it will be inferred from the API key if the key is scoped to a single database. If provided, it will be validated against the API key's scope.
        api_key: The api key to use for this client.
    """

    required_args = [
        CloudClientArg(name="api_key", env_var="CHROMA_API_KEY", value=api_key),
    ]

    # If api_key is not provided, try to load it from the environment variable
    if not all([arg.value for arg in required_args]):
        for arg in required_args:
            arg.value = arg.value or os.environ.get(arg.env_var)

    missing_args = [arg for arg in required_args if arg.value is None]
    if missing_args:
        raise ValueError(
            f"Missing required arguments: {', '.join([arg.name for arg in missing_args])}. "
            f"Please provide them or set the environment variables: {', '.join([arg.env_var for arg in missing_args])}"
        )

    if settings is None:
        settings = Settings()

    # Make sure parameters are the correct types -- users can pass anything.
    tenant = tenant or os.environ.get("CHROMA_TENANT")
    if tenant is not None:
        tenant = str(tenant)
    database = database or os.environ.get("CHROMA_DATABASE")
    if database is not None:
        database = str(database)
    api_key = str(api_key)
    cloud_host = str(cloud_host)
    cloud_port = int(cloud_port)
    enable_ssl = bool(enable_ssl)

    settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
    settings.chroma_server_host = cloud_host
    settings.chroma_server_http_port = cloud_port
    settings.chroma_server_ssl_enabled = enable_ssl

    settings.chroma_client_auth_provider = (
        "chromadb.auth.token_authn.TokenAuthClientProvider"
    )
    settings.chroma_client_auth_credentials = api_key
    settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN
    settings.chroma_overwrite_singleton_tenant_database_access_from_auth = True

    return ClientCreator(tenant=tenant, database=database, settings=settings)
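

# Illustrative sketch (not part of the library): CloudClient reads its API key (and
# optionally tenant/database) from the environment when they are not passed
# explicitly, using the CHROMA_API_KEY / CHROMA_TENANT / CHROMA_DATABASE variables
# referenced above. The key value here is a placeholder.
def _example_cloud_client() -> None:  # pragma: no cover - illustrative sketch
    os.environ.setdefault("CHROMA_API_KEY", "ck-placeholder")
    client = CloudClient()  # tenant/database inferred from the key's scope
    print(client.count_collections())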


def Client(
    settings: Settings = __settings,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Return a running chroma.API instance

    tenant: The tenant to use for this client. Defaults to the default tenant.
    database: The database to use for this client. Defaults to the default database.
    """

    # Make sure parameters are the correct types -- users can pass anything.
    tenant = str(tenant)
    database = str(database)

    return ClientCreator(tenant=tenant, database=database, settings=settings)


def AdminClient(settings: Settings = Settings()) -> AdminAPI:
    """
    Creates an admin client that can be used to create tenants and databases.
    """
    return AdminClientCreator(settings=settings)
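

# Illustrative sketch (not part of the library): AdminClient provisions tenants and
# databases, after which a regular Client can be scoped to them. The tenant and
# database names used here are placeholders.
def _example_admin_client() -> None:  # pragma: no cover - illustrative sketch
    admin = AdminClient()
    admin.create_tenant("acme")
    admin.create_database("analytics", tenant="acme")
    scoped = Client(tenant="acme", database="analytics")
    print(scoped.list_collections())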
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,863 @@
from chromadb.api.types import *  # noqa: F401, F403
from chromadb.execution.expression import (  # noqa: F401, F403
    Search,
    Key,
    K,
    SearchWhere,
    And,
    Or,
    Eq,
    Ne,
    Gt,
    Gte,
    Lt,
    Lte,
    In,
    Nin,
    Regex,
    NotRegex,
    Contains,
    NotContains,
    Limit,
    Select,
    Rank,
    Abs,
    Div,
    Exp,
    Log,
    Max,
    Min,
    Mul,
    Knn,
    Rrf,
    Sub,
    Sum,
    Val,
)

from abc import ABC, abstractmethod
from typing import Sequence, Optional, List, Dict, Any
from uuid import UUID

from overrides import override
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    UpdateCollectionConfiguration,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.api.types import (
    CollectionMetadata,
    Documents,
    Embeddable,
    EmbeddingFunction,
    DataLoader,
    Embeddings,
    IDs,
    Include,
    IncludeMetadataDocumentsDistances,
    IncludeMetadataDocuments,
    Loadable,
    Metadatas,
    Schema,
    URIs,
    Where,
    QueryResult,
    GetResult,
    WhereDocument,
    SearchResult,
    DefaultEmbeddingFunction,
)

from chromadb.auth import UserIdentity
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.api.models.Collection import Collection
from chromadb.api.models.AttachedFunction import AttachedFunction

# Re-export the async version
from chromadb.api.async_api import (  # noqa: F401
    AsyncBaseAPI as AsyncBaseAPI,
    AsyncClientAPI as AsyncClientAPI,
    AsyncAdminAPI as AsyncAdminAPI,
    AsyncServerAPI as AsyncServerAPI,
)


class BaseAPI(ABC):
    @abstractmethod
    def heartbeat(self) -> int:
        """Get the current time in nanoseconds since epoch.
        Used to check if the server is alive.

        Returns:
            int: The current time in nanoseconds since epoch

        """
        pass

    #
    # COLLECTION METHODS
    #
    @abstractmethod
    def count_collections(self) -> int:
        """Count the number of collections.

        Returns:
            int: The number of collections.

        Examples:
            ```python
            client.count_collections()
            # 1
            ```
        """
        pass

    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
    ) -> None:
        """[Internal] Modify a collection by UUID. Can update the name and/or metadata.

        Args:
            id: The internal UUID of the collection to modify.
            new_name: The new name of the collection.
                If None, the existing name will remain. Defaults to None.
            new_metadata: The new metadata to associate with the collection.
                Defaults to None.
            new_configuration: The new configuration to associate with the collection.
                Defaults to None.
        """
        pass

    @abstractmethod
    def delete_collection(
        self,
        name: str,
    ) -> None:
        """Delete a collection with the given name.
        Args:
            name: The name of the collection to delete.

        Raises:
            ValueError: If the collection does not exist.

        Examples:
            ```python
            client.delete_collection("my_collection")
            ```
        """
        pass

    #
    # ITEM METHODS
    #

    @abstractmethod
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Add embeddings to a collection specified by UUID.
        If (some) ids already exist, only the new embeddings will be added.

        Args:
            ids: The ids to associate with the embeddings.
            collection_id: The UUID of the collection to add the embeddings to.
            embeddings: The sequence of embeddings to add.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.

        Returns:
            True if the embeddings were added successfully.
        """
        pass

    @abstractmethod
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Update entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to update the embeddings in.
            ids: The IDs of the entries to update.
            embeddings: The sequence of embeddings to update. Defaults to None.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.
        Returns:
            True if the embeddings were updated successfully.
        """
        pass

    @abstractmethod
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Add or update entries in a collection specified by UUID.
        If an entry with the same id already exists, it will be updated,
        otherwise it will be added.

        Args:
            collection_id: The collection to add the embeddings to
            ids: The ids to associate with the embeddings. Defaults to None.
            embeddings: The sequence of embeddings to add
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.
        """
        pass

    @abstractmethod
    def _count(self, collection_id: UUID) -> int:
        """[Internal] Returns the number of entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to count the embeddings in.

        Returns:
            int: The number of embeddings in the collection

        """
        pass

    @abstractmethod
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """[Internal] Returns the first n entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to peek into.
            n: The number of entries to peek. Defaults to 10.

        Returns:
            GetResult: The first n entries in the collection.

        """

        pass

    @abstractmethod
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
    ) -> GetResult:
        """[Internal] Returns entries from a collection specified by UUID.

        Args:
            ids: The IDs of the entries to get. Defaults to None.
            where: Conditional filtering on metadata. Defaults to None.
            limit: The maximum number of entries to return. Defaults to None.
            offset: The number of entries to skip before returning. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to None.
            include: The fields to include in the response.
                Defaults to ["metadatas", "documents"].
        Returns:
            GetResult: The entries in the collection that match the query.

        """
        pass

    @abstractmethod
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        """[Internal] Deletes entries from a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to delete the entries from.
            ids: The IDs of the entries to delete. Defaults to None.
            where: Conditional filtering on metadata. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to None.

        Returns:
            IDs: The list of IDs of the entries that were deleted.
        """
        pass

    @abstractmethod
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
    ) -> QueryResult:
        """[Internal] Performs a nearest neighbors query on a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to query.
            query_embeddings: The embeddings to use as the query.
            ids: The IDs to filter by during the query. Defaults to None.
            n_results: The number of results to return. Defaults to 10.
            where: Conditional filtering on metadata. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to None.
            include: The fields to include in the response.
                Defaults to ["metadatas", "documents", "distances"].

        Returns:
            QueryResult: The results of the query.
        """
        pass

    @abstractmethod
    def reset(self) -> bool:
        """Resets the database. This will delete all collections and entries.

        Returns:
            bool: True if the database was reset successfully.
        """
        pass

    @abstractmethod
    def get_version(self) -> str:
        """Get the version of Chroma.

        Returns:
            str: The version of Chroma

        """
        pass

    @abstractmethod
    def get_settings(self) -> Settings:
        """Get the settings used to initialize.

        Returns:
            Settings: The settings used to initialize.

        """
        pass

    @abstractmethod
    def get_max_batch_size(self) -> int:
        """Return the maximum number of records that can be created or mutated in a single call."""
        pass

    @abstractmethod
    def get_user_identity(self) -> UserIdentity:
        """Resolve the tenant and databases for the client. Returns the default
        values if they can't be resolved.

        """
        pass
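

# Illustrative sketch (not part of the library): get_max_batch_size() reports the
# largest write the backing implementation accepts, so large ingests can be chunked
# against it. The chromadb.PersistentClient used here is assumed to be available as
# in the package root shown earlier in this commit.
def _example_batched_add() -> None:  # pragma: no cover - illustrative sketch
    import chromadb

    client = chromadb.PersistentClient(path="./chroma")
    collection = client.get_or_create_collection("bulk")
    ids = [str(i) for i in range(10_000)]
    docs = [f"document {i}" for i in range(10_000)]
    batch = client.get_max_batch_size()
    for start in range(0, len(ids), batch):
        # Each add() stays within the implementation's reported batch limit.
        collection.add(ids=ids[start:start + batch], documents=docs[start:start + batch])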


class ClientAPI(BaseAPI, ABC):
    tenant: str
    database: str

    @abstractmethod
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        """List all collections.
        Args:
            limit: The maximum number of entries to return. Defaults to None.
            offset: The number of entries to skip before returning. Defaults to None.

        Returns:
            Sequence[Collection]: A list of collections

        Examples:
            ```python
            client.list_collections()
            # [collection(name="my_collection", metadata={})]
            ```
        """
        pass

    @abstractmethod
    def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> Collection:
        """Create a new collection with the given name and metadata.
        Args:
            name: The name of the collection to create.
            metadata: Optional metadata to associate with the collection.
            embedding_function: Optional function to use to embed documents.
                Uses the default embedding function if not provided.
            get_or_create: If True, return the existing collection if it exists.
            data_loader: Optional function to use to load records (documents, images, etc.)

        Returns:
            Collection: The newly created collection.

        Raises:
            ValueError: If the collection already exists and get_or_create is False.
            ValueError: If the collection name is invalid.

        Examples:
            ```python
            client.create_collection("my_collection")
            # collection(name="my_collection", metadata={})

            client.create_collection("my_collection", metadata={"foo": "bar"})
            # collection(name="my_collection", metadata={"foo": "bar"})
            ```
        """
        pass

    @abstractmethod
    def get_collection(
        self,
        name: str,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Get a collection with the given name.
        Args:
            name: The name of the collection to get
            embedding_function: Optional function to use to embed documents.
                Uses the default embedding function if not provided.
            data_loader: Optional function to use to load records (documents, images, etc.)

        Returns:
            Collection: The collection

        Raises:
            ValueError: If the collection does not exist

        Examples:
            ```python
            client.get_collection("my_collection")
            # collection(name="my_collection", metadata={})
            ```
        """
        pass

    @abstractmethod
    def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Get or create a collection with the given name and metadata.
        Args:
            name: The name of the collection to get or create
            metadata: Optional metadata to associate with the collection. If
                the collection already exists, the metadata provided is ignored.
                If the collection does not exist, the new collection will be created
                with the provided metadata.
            embedding_function: Optional function to use to embed documents
            data_loader: Optional function to use to load records (documents, images, etc.)

        Returns:
            The collection

        Examples:
            ```python
            client.get_or_create_collection("my_collection")
            # collection(name="my_collection", metadata={})
            ```
        """
        pass

    @abstractmethod
    def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        """Set the tenant and database for the client. Raises an error if the tenant or
        database does not exist.

        Args:
            tenant: The tenant to set.
            database: The database to set.

        """
        pass

    @abstractmethod
    def set_database(self, database: str) -> None:
        """Set the database for the client. Raises an error if the database does not exist.

        Args:
            database: The database to set.

        """
        pass

    @staticmethod
    @abstractmethod
    def clear_system_cache() -> None:
        """Clear the system cache so that new systems can be created for an existing path.
        This should only be used for testing purposes."""
        pass
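

# Illustrative sketch (not part of the library): clear_system_cache() exists so tests
# can build a fresh client against the same persist directory without reusing the
# cached system from an earlier client. Its use outside tests is assumed unnecessary,
# and calling it through a client instance (it is a static method) is an assumption
# about the concrete client implementation.
def _example_clear_system_cache() -> None:  # pragma: no cover - illustrative sketch
    import chromadb

    first = chromadb.PersistentClient(path="./test-data")
    first.create_collection("fixture")
    first.clear_system_cache()  # allow a new system for the same path
    fresh = chromadb.PersistentClient(path="./test-data")
    print(fresh.count_collections())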


class AdminAPI(ABC):
    @abstractmethod
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a new database. Raises an error if the database already exists.

        Args:
            name: The name of the database to create.

        """
        pass

    @abstractmethod
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Get a database. Raises an error if the database does not exist.

        Args:
            name: The name of the database to get.
            tenant: The tenant of the database to get.

        """
        pass

    @abstractmethod
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Delete a database. Raises an error if the database does not exist.

        Args:
            name: The name of the database to delete.
            tenant: The tenant of the database to delete.

        """
        pass

    @abstractmethod
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        """List all databases for a tenant. Raises an error if the tenant does not exist.

        Args:
            tenant: The tenant to list databases for.

        """
        pass

    @abstractmethod
    def create_tenant(self, name: str) -> None:
        """Create a new tenant. Raises an error if the tenant already exists.

        Args:
            name: The name of the tenant to create.

        """
        pass

    @abstractmethod
    def get_tenant(self, name: str) -> Tenant:
        """Get a tenant. Raises an error if the tenant does not exist.

        Args:
            name: The name of the tenant to get.

        """
        pass


class ServerAPI(BaseAPI, AdminAPI, Component):
    """An API instance that extends the relevant Base API methods by passing
    in a tenant and database. This is the root component of the Chroma System"""

    @abstractmethod
    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        pass

    @abstractmethod
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        pass

    @abstractmethod
    def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    def get_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass

    @abstractmethod
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass

    @abstractmethod
    def _fork(
        self,
        collection_id: UUID,
        new_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    def _search(
        self,
        collection_id: UUID,
        searches: List[Search],
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> SearchResult:
        pass

    @abstractmethod
    @override
    def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        pass

    @abstractmethod
    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        pass

    @abstractmethod
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        pass

    @abstractmethod
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        pass

    @abstractmethod
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        pass

    @abstractmethod
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        pass

    @abstractmethod
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> QueryResult:
        pass

    @abstractmethod
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass

    @abstractmethod
    def attach_function(
        self,
        function_id: str,
        name: str,
        input_collection_id: UUID,
        output_collection: str,
        params: Optional[Dict[str, Any]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "AttachedFunction":
        """Attach a function to a collection.

        Args:
            function_id: Built-in function identifier
            name: Unique name for this attached function
            input_collection_id: Source collection that triggers the function
            output_collection: Target collection where function output is stored
            params: Optional dictionary with function-specific parameters
            tenant: The tenant name
            database: The database name

        Returns:
            AttachedFunction: Object representing the attached function
        """
        pass

    @abstractmethod
    def detach_function(
        self,
        attached_function_id: UUID,
        delete_output: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Detach a function and prevent any further runs.

        Args:
            attached_function_id: ID of the attached function to remove
            delete_output: Whether to also delete the output collection
            tenant: The tenant name
            database: The database name

        Returns:
            bool: True if successful
        """
        pass
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,770 @@
from abc import ABC, abstractmethod
from typing import Sequence, Optional, List
from uuid import UUID

from overrides import override
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    UpdateCollectionConfiguration,
)
from chromadb.auth import UserIdentity
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.api.types import (
    CollectionMetadata,
    Documents,
    Embeddable,
    EmbeddingFunction,
    DataLoader,
    Embeddings,
    IDs,
    Include,
    Loadable,
    Metadatas,
    Schema,
    URIs,
    Where,
    QueryResult,
    GetResult,
    WhereDocument,
    IncludeMetadataDocuments,
    IncludeMetadataDocumentsDistances,
    SearchResult,
    DefaultEmbeddingFunction,
)
from chromadb.execution.expression.plan import Search
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant, Collection as CollectionModel


class AsyncBaseAPI(ABC):
    @abstractmethod
    async def heartbeat(self) -> int:
        """Get the current time in nanoseconds since epoch.
        Used to check if the server is alive.

        Returns:
            int: The current time in nanoseconds since epoch

        """
        pass

    #
    # COLLECTION METHODS
    #

    @abstractmethod
    async def count_collections(self) -> int:
        """Count the number of collections.

        Returns:
            int: The number of collections.

        Examples:
            ```python
            await client.count_collections()
            # 1
            ```
        """
        pass

    @abstractmethod
    async def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
    ) -> None:
        """[Internal] Modify a collection by UUID. Can update the name and/or metadata.

        Args:
            id: The internal UUID of the collection to modify.
            new_name: The new name of the collection.
                If None, the existing name will remain. Defaults to None.
            new_metadata: The new metadata to associate with the collection.
                Defaults to None.
            new_configuration: The new configuration to associate with the collection.
                Defaults to None.
        """
        pass

    @abstractmethod
    async def delete_collection(
        self,
        name: str,
    ) -> None:
        """Delete a collection with the given name.
        Args:
            name: The name of the collection to delete.

        Raises:
            ValueError: If the collection does not exist.

        Examples:
            ```python
            await client.delete_collection("my_collection")
            ```
        """
        pass

    #
    # ITEM METHODS
    #

    @abstractmethod
    async def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Add embeddings to a collection specified by UUID.
        If (some) ids already exist, only the new embeddings will be added.

        Args:
            ids: The ids to associate with the embeddings.
            collection_id: The UUID of the collection to add the embeddings to.
            embeddings: The sequence of embeddings to add.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.

        Returns:
            True if the embeddings were added successfully.
        """
        pass

    @abstractmethod
    async def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Update entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to update the embeddings in.
            ids: The IDs of the entries to update.
            embeddings: The sequence of embeddings to update. Defaults to None.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.
        Returns:
            True if the embeddings were updated successfully.
        """
        pass

    @abstractmethod
    async def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Add or update entries in a collection specified by UUID.
        If an entry with the same id already exists, it will be updated,
        otherwise it will be added.

        Args:
            collection_id: The collection to add the embeddings to
            ids: The ids to associate with the embeddings. Defaults to None.
            embeddings: The sequence of embeddings to add
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.
        """
        pass

    @abstractmethod
    async def _count(self, collection_id: UUID) -> int:
        """[Internal] Returns the number of entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to count the embeddings in.

        Returns:
            int: The number of embeddings in the collection

        """
        pass

    @abstractmethod
    async def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """[Internal] Returns the first n entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to peek into.
            n: The number of entries to peek. Defaults to 10.

        Returns:
            GetResult: The first n entries in the collection.

        """

        pass

    @abstractmethod
    async def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
    ) -> GetResult:
        """[Internal] Returns entries from a collection specified by UUID.

        Args:
            ids: The IDs of the entries to get. Defaults to None.
            where: Conditional filtering on metadata. Defaults to None.
            limit: The maximum number of entries to return. Defaults to None.
            offset: The number of entries to skip before returning. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to None.
            include: The fields to include in the response.
                Defaults to ["embeddings", "metadatas", "documents"].
        Returns:
            GetResult: The entries in the collection that match the query.

        """
        pass

    @abstractmethod
    async def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        """[Internal] Deletes entries from a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to delete the entries from.
            ids: The IDs of the entries to delete. Defaults to None.
            where: Conditional filtering on metadata. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to None.

        Returns:
            IDs: The list of IDs of the entries that were deleted.
        """
        pass

    @abstractmethod
    async def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
    ) -> QueryResult:
        """[Internal] Performs a nearest neighbors query on a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to query.
            query_embeddings: The embeddings to use as the query.
            n_results: The number of results to return. Defaults to 10.
            where: Conditional filtering on metadata. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to None.
            include: The fields to include in the response.
                Defaults to ["embeddings", "metadatas", "documents", "distances"].

        Returns:
            QueryResult: The results of the query.
        """
        pass

    @abstractmethod
    async def reset(self) -> bool:
        """Resets the database. This will delete all collections and entries.

        Returns:
            bool: True if the database was reset successfully.
        """
        pass

    @abstractmethod
    async def get_version(self) -> str:
        """Get the version of Chroma.

        Returns:
            str: The version of Chroma

        """
        pass

    @abstractmethod
    def get_settings(self) -> Settings:
        """Get the settings used to initialize.

        Returns:
            Settings: The settings used to initialize.

        """
        pass

    @abstractmethod
    async def get_max_batch_size(self) -> int:
        """Return the maximum number of records that can be created or mutated in a single call."""
        pass

    @abstractmethod
    async def get_user_identity(self) -> UserIdentity:
        """Resolve the tenant and databases for the client. Returns the default
        values if they can't be resolved.

        """
        pass
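

# Illustrative sketch (not part of the library): because every method on the async API
# is a coroutine, independent queries can be issued concurrently with asyncio.gather.
# The chromadb.AsyncHttpClient constructor and the collection's query_texts argument
# are assumed to behave as in the package root shown earlier in this commit.
async def _example_concurrent_queries() -> None:  # pragma: no cover - illustrative sketch
    import asyncio
    import chromadb

    client = await chromadb.AsyncHttpClient(host="localhost", port=8000)
    names = ["docs", "tickets", "emails"]
    collections = await asyncio.gather(
        *(client.get_or_create_collection(n) for n in names)
    )
    results = await asyncio.gather(
        *(c.query(query_texts=["quarterly report"], n_results=3) for c in collections)
    )
    for name, result in zip(names, results):
        print(name, result["ids"])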
|
||||
|
||||
|
||||
class AsyncClientAPI(AsyncBaseAPI, ABC):
|
||||
tenant: str
|
||||
database: str
|
||||
|
||||
@abstractmethod
|
||||
async def list_collections(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
    ) -> Sequence[AsyncCollection]:
        """List all collections.

        Args:
            limit: The maximum number of entries to return. Defaults to None.
            offset: The number of entries to skip before returning. Defaults to None.

        Returns:
            Sequence[AsyncCollection]: A list of collections.

        Examples:
            ```python
            await client.list_collections()
            # [collection(name="my_collection", metadata={})]
            ```
        """
        pass

    @abstractmethod
    async def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> AsyncCollection:
        """Create a new collection with the given name and metadata.
        Args:
            name: The name of the collection to create.
            schema: Optional schema for the collection.
            configuration: Optional configuration for the collection.
            metadata: Optional metadata to associate with the collection.
            embedding_function: Optional function to use to embed documents.
                Uses the default embedding function if not provided.
            get_or_create: If True, return the existing collection if it exists.
            data_loader: Optional function to use to load records (documents, images, etc.)

        Returns:
            AsyncCollection: The newly created collection.

        Raises:
            ValueError: If the collection already exists and get_or_create is False.
            ValueError: If the collection name is invalid.

        Examples:
            ```python
            await client.create_collection("my_collection")
            # collection(name="my_collection", metadata={})

            await client.create_collection("my_collection", metadata={"foo": "bar"})
            # collection(name="my_collection", metadata={"foo": "bar"})
            ```
        """
        pass

    @abstractmethod
    async def get_collection(
        self,
        name: str,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> AsyncCollection:
        """Get a collection with the given name.
        Args:
            name: The name of the collection to get
            embedding_function: Optional function to use to embed documents.
                Uses the default embedding function if not provided.
            data_loader: Optional function to use to load records (documents, images, etc.)

        Returns:
            AsyncCollection: The collection

        Raises:
            ValueError: If the collection does not exist

        Examples:
            ```python
            await client.get_collection("my_collection")
            # collection(name="my_collection", metadata={})
            ```
        """
        pass

    @abstractmethod
    async def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> AsyncCollection:
        """Get or create a collection with the given name and metadata.
        Args:
            name: The name of the collection to get or create
            schema: Optional schema for the collection.
            configuration: Optional configuration for the collection.
            metadata: Optional metadata to associate with the collection. If
                the collection already exists, the metadata provided is ignored.
                If the collection does not exist, the new collection will be created
                with the provided metadata.
            embedding_function: Optional function to use to embed documents
            data_loader: Optional function to use to load records (documents, images, etc.)

        Returns:
            The collection

        Examples:
            ```python
            await client.get_or_create_collection("my_collection")
            # collection(name="my_collection", metadata={})
            ```
        """
        pass

    @abstractmethod
    async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        """Set the tenant and database for the client. Raises an error if the tenant or
        database does not exist.

        Args:
            tenant: The tenant to set.
            database: The database to set.

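        Examples:
            ```python
            # Illustrative addition: assumes a tenant and database that already
            # exist on the server this client is connected to.
            await client.set_tenant(tenant="my_tenant", database="my_database")
            ```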
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_database(self, database: str) -> None:
|
||||
"""Set the database for the client. Raises an error if the database does not exist.
|
||||
|
||||
Args:
|
||||
database: The database to set.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def clear_system_cache() -> None:
|
||||
"""Clear the system cache so that new systems can be created for an existing path.
|
||||
This should only be used for testing purposes."""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncAdminAPI(ABC):
|
||||
@abstractmethod
|
||||
async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
"""Create a new database. Raises an error if the database already exists.
|
||||
|
||||
Args:
|
||||
database: The name of the database to create.
|
||||
|
||||
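        Examples:
            ```python
            # Illustrative addition: "my_database" is a placeholder name and
            # admin_client stands in for any concrete AsyncAdminAPI implementation.
            await admin_client.create_database("my_database")
            ```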
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
"""Get a database. Raises an error if the database does not exist.
|
||||
|
||||
Args:
|
||||
database: The name of the database to get.
|
||||
tenant: The tenant of the database to get.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
"""Delete a database. Raises an error if the database does not exist.
|
||||
|
||||
Args:
|
||||
database: The name of the database to delete.
|
||||
tenant: The tenant of the database to delete.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
"""List all databases for a tenant. Raises an error if the tenant does not exist.
|
||||
|
||||
Args:
|
||||
tenant: The tenant to list databases for.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_tenant(self, name: str) -> None:
|
||||
"""Create a new tenant. Raises an error if the tenant already exists.
|
||||
|
||||
Args:
|
||||
tenant: The name of the tenant to create.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_tenant(self, name: str) -> Tenant:
|
||||
"""Get a tenant. Raises an error if the tenant does not exist.
|
||||
|
||||
Args:
|
||||
tenant: The name of the tenant to get.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
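# --- Illustrative sketch (added; not part of the original API surface) --------------
# The admin interface above is typically exercised through a concrete client such as
# chromadb.api.async_client.AsyncAdminClient. The tenant and database names below are
# placeholders, and this helper is never invoked at import time.
async def _example_async_admin_usage() -> None:  # pragma: no cover - documentation aid
    from chromadb.api.async_client import AsyncAdminClient

    admin = AsyncAdminClient()
    await admin.create_tenant("example_tenant")
    await admin.create_database("example_db", tenant="example_tenant")
    databases = await admin.list_databases(tenant="example_tenant")
    print(f"{len(databases)} database(s) visible for example_tenant")

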
class AsyncServerAPI(AsyncBaseAPI, AsyncAdminAPI, Component):
    """An API instance that extends the relevant Base API methods by passing
    in a tenant and database. This is the root component of the Chroma System"""

    @abstractmethod
    async def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        pass

    @abstractmethod
    @override
    async def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        pass

    @abstractmethod
    async def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    async def get_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    async def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    @override
    async def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass

    @abstractmethod
    @override
    async def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass

    @abstractmethod
    async def _fork(
        self,
        collection_id: UUID,
        new_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        pass

    @abstractmethod
    async def _search(
        self,
        collection_id: UUID,
        searches: List[Search],
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> SearchResult:
        pass

    @abstractmethod
    @override
    async def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        pass

    @abstractmethod
    @override
    async def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        pass

    @abstractmethod
    @override
    async def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        pass

    @abstractmethod
    @override
    async def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        pass

    @abstractmethod
    @override
    async def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        pass

    @abstractmethod
    @override
    async def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        pass

    @abstractmethod
    @override
    async def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> QueryResult:
        pass

    @abstractmethod
    @override
    async def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass
@@ -0,0 +1,526 @@
import httpx
from typing import Optional, Sequence
from uuid import UUID
from overrides import override

from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.api import AsyncAdminAPI, AsyncClientAPI, AsyncServerAPI
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    UpdateCollectionConfiguration,
    validate_embedding_function_conflict_on_create,
    validate_embedding_function_conflict_on_get,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.shared_system_client import SharedSystemClient
from chromadb.api.types import (
    CollectionMetadata,
    DataLoader,
    Documents,
    Embeddable,
    EmbeddingFunction,
    Embeddings,
    GetResult,
    IDs,
    Include,
    IncludeMetadataDocuments,
    IncludeMetadataDocumentsDistances,
    Loadable,
    Metadatas,
    QueryResult,
    Schema,
    URIs,
    DefaultEmbeddingFunction,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument


class AsyncClient(SharedSystemClient, AsyncClientAPI):
    """A client for Chroma. This is the main entrypoint for interacting with Chroma.
    A client internally stores its tenant and database and proxies calls to a
    Server API instance of Chroma. It treats the Server API and corresponding System
    as a singleton, so multiple clients connecting to the same resource will share the
    same API instance.

    Client implementations should implement their own API-caching strategies.
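
    Examples:
        ```python
        # Illustrative only: assumes a running Chroma server reachable with the
        # default settings; tenant and database names are the library defaults.
        client = await AsyncClient.create(tenant="default_tenant", database="default_database")
        collections = await client.list_collections()
        ```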
"""
|
||||
|
||||
# An internal admin client for verifying that databases and tenants exist
|
||||
_admin_client: AsyncAdminAPI
|
||||
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
|
||||
_server: AsyncServerAPI
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
settings: Settings = Settings(),
|
||||
) -> "AsyncClient":
|
||||
# Create an admin client for verifying that databases and tenants exist
|
||||
self = cls(settings=settings)
|
||||
SharedSystemClient._populate_data_from_system(self._system)
|
||||
|
||||
self.tenant = tenant
|
||||
self.database = database
|
||||
|
||||
# Get the root system component we want to interact with
|
||||
self._server = self._system.instance(AsyncServerAPI)
|
||||
|
||||
user_identity = await self.get_user_identity()
|
||||
|
||||
maybe_tenant, maybe_database = maybe_set_tenant_and_database(
|
||||
user_identity,
|
||||
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
|
||||
user_provided_tenant=tenant,
|
||||
user_provided_database=database,
|
||||
)
|
||||
if maybe_tenant:
|
||||
self.tenant = maybe_tenant
|
||||
if maybe_database:
|
||||
self.database = maybe_database
|
||||
|
||||
self._admin_client = AsyncAdminClient.from_system(self._system)
|
||||
await self._validate_tenant_database(tenant=self.tenant, database=self.database)
|
||||
|
||||
self._submit_client_start_event()
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
# (we can't override and use from_system() because it's synchronous)
|
||||
async def from_system_async(
|
||||
cls,
|
||||
system: System,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AsyncClient":
|
||||
"""Create a client from an existing system. This is useful for testing and debugging."""
|
||||
return await AsyncClient.create(tenant, database, system.settings)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_system(
|
||||
cls,
|
||||
system: System,
|
||||
) -> "SharedSystemClient":
|
||||
"""AsyncClient cannot be created synchronously. Use .from_system_async() instead."""
|
||||
raise NotImplementedError(
|
||||
"AsyncClient cannot be created synchronously. Use .from_system_async() instead."
|
||||
)
|
||||
|
||||
@override
|
||||
async def get_user_identity(self) -> UserIdentity:
|
||||
return await self._server.get_user_identity()
|
||||
|
||||
@override
|
||||
async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
|
||||
await self._validate_tenant_database(tenant=tenant, database=database)
|
||||
self.tenant = tenant
|
||||
self.database = database
|
||||
|
||||
@override
|
||||
async def set_database(self, database: str) -> None:
|
||||
await self._validate_tenant_database(tenant=self.tenant, database=database)
|
||||
self.database = database
|
||||
|
||||
async def _validate_tenant_database(self, tenant: str, database: str) -> None:
|
||||
try:
|
||||
await self._admin_client.get_tenant(name=tenant)
|
||||
except httpx.ConnectError:
|
||||
raise ValueError(
|
||||
"Could not connect to a Chroma server. Are you sure it is running?"
|
||||
)
|
||||
# Propagate ChromaErrors
|
||||
except ChromaError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not connect to tenant {tenant}. Are you sure it exists?"
|
||||
)
|
||||
|
||||
try:
|
||||
await self._admin_client.get_database(name=database, tenant=tenant)
|
||||
except httpx.ConnectError:
|
||||
raise ValueError(
|
||||
"Could not connect to a Chroma server. Are you sure it is running?"
|
||||
)
|
||||
|
||||
# region BaseAPI Methods
|
||||
# Note - we could do this in less verbose ways, but they break type checking
|
||||
@override
|
||||
async def heartbeat(self) -> int:
|
||||
return await self._server.heartbeat()
|
||||
|
||||
@override
|
||||
async def list_collections(
|
||||
self, limit: Optional[int] = None, offset: Optional[int] = None
|
||||
) -> Sequence[AsyncCollection]:
|
||||
models = await self._server.list_collections(
|
||||
limit, offset, tenant=self.tenant, database=self.database
|
||||
)
|
||||
return [AsyncCollection(client=self._server, model=model) for model in models]
|
||||
|
||||
@override
|
||||
async def count_collections(self) -> int:
|
||||
return await self._server.count_collections(
|
||||
tenant=self.tenant, database=self.database
|
||||
)
|
||||
|
||||
@override
|
||||
async def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
embedding_function: Optional[
|
||||
EmbeddingFunction[Embeddable]
|
||||
] = DefaultEmbeddingFunction(), # type: ignore
|
||||
data_loader: Optional[DataLoader[Loadable]] = None,
|
||||
get_or_create: bool = False,
|
||||
) -> AsyncCollection:
|
||||
if configuration is None:
|
||||
configuration = {}
|
||||
|
||||
configuration_ef = configuration.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_create(
|
||||
embedding_function, configuration_ef
|
||||
)
|
||||
|
||||
# If ef provided in function params and collection config ef is None,
|
||||
# set the collection config ef to the function params
|
||||
if embedding_function is not None and configuration_ef is None:
|
||||
configuration["embedding_function"] = embedding_function
|
||||
|
||||
model = await self._server.create_collection(
|
||||
name=name,
|
||||
schema=schema,
|
||||
configuration=configuration,
|
||||
metadata=metadata,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
get_or_create=get_or_create,
|
||||
)
|
||||
return AsyncCollection(
|
||||
client=self._server,
|
||||
model=model,
|
||||
embedding_function=embedding_function,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@override
|
||||
async def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
embedding_function: Optional[
|
||||
EmbeddingFunction[Embeddable]
|
||||
] = DefaultEmbeddingFunction(), # type: ignore
|
||||
data_loader: Optional[DataLoader[Loadable]] = None,
|
||||
) -> AsyncCollection:
|
||||
model = await self._server.get_collection(
|
||||
name=name,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
persisted_ef_config = model.configuration_json.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_get(
|
||||
embedding_function, persisted_ef_config
|
||||
)
|
||||
|
||||
return AsyncCollection(
|
||||
client=self._server,
|
||||
model=model,
|
||||
embedding_function=embedding_function,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@override
|
||||
async def get_or_create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
embedding_function: Optional[
|
||||
EmbeddingFunction[Embeddable]
|
||||
] = DefaultEmbeddingFunction(), # type: ignore
|
||||
data_loader: Optional[DataLoader[Loadable]] = None,
|
||||
) -> AsyncCollection:
|
||||
if configuration is None:
|
||||
configuration = {}
|
||||
|
||||
configuration_ef = configuration.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_create(
|
||||
embedding_function, configuration_ef
|
||||
)
|
||||
|
||||
if embedding_function is not None and configuration_ef is None:
|
||||
configuration["embedding_function"] = embedding_function
|
||||
model = await self._server.get_or_create_collection(
|
||||
name=name,
|
||||
schema=schema,
|
||||
configuration=configuration,
|
||||
metadata=metadata,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
persisted_ef_config = model.configuration_json.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_get(
|
||||
embedding_function, persisted_ef_config
|
||||
)
|
||||
|
||||
return AsyncCollection(
|
||||
client=self._server,
|
||||
model=model,
|
||||
embedding_function=embedding_function,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _modify(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
) -> None:
|
||||
return await self._server._modify(
|
||||
id=id,
|
||||
new_name=new_name,
|
||||
new_metadata=new_metadata,
|
||||
new_configuration=new_configuration,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
async def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
) -> None:
|
||||
return await self._server.delete_collection(
|
||||
name=name,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
#
|
||||
# ITEM METHODS
|
||||
#
|
||||
|
||||
@override
|
||||
async def _add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
) -> bool:
|
||||
return await self._server._add(
|
||||
ids=ids,
|
||||
collection_id=collection_id,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
uris=uris,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _update(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
) -> bool:
|
||||
return await self._server._update(
|
||||
collection_id=collection_id,
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
uris=uris,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _upsert(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
) -> bool:
|
||||
return await self._server._upsert(
|
||||
collection_id=collection_id,
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
uris=uris,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _count(self, collection_id: UUID) -> int:
|
||||
return await self._server._count(
|
||||
collection_id=collection_id,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
|
||||
return await self._server._peek(
|
||||
collection_id=collection_id,
|
||||
n=n,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _get(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocuments,
|
||||
) -> GetResult:
|
||||
return await self._server._get(
|
||||
collection_id=collection_id,
|
||||
ids=ids,
|
||||
where=where,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
async def _delete(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs],
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
) -> None:
|
||||
await self._server._delete(
|
||||
collection_id=collection_id,
|
||||
ids=ids,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _query(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
query_embeddings: Embeddings,
|
||||
ids: Optional[IDs] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocumentsDistances,
|
||||
) -> QueryResult:
|
||||
return await self._server._query(
|
||||
collection_id=collection_id,
|
||||
query_embeddings=query_embeddings,
|
||||
ids=ids,
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
async def reset(self) -> bool:
|
||||
return await self._server.reset()
|
||||
|
||||
@override
|
||||
async def get_version(self) -> str:
|
||||
return await self._server.get_version()
|
||||
|
||||
@override
|
||||
def get_settings(self) -> Settings:
|
||||
return self._server.get_settings()
|
||||
|
||||
@override
|
||||
async def get_max_batch_size(self) -> int:
|
||||
return await self._server.get_max_batch_size()
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
class AsyncAdminClient(SharedSystemClient, AsyncAdminAPI):
|
||||
_server: AsyncServerAPI
|
||||
|
||||
def __init__(self, settings: Settings = Settings()) -> None:
|
||||
super().__init__(settings)
|
||||
self._server = self._system.instance(AsyncServerAPI)
|
||||
|
||||
@override
|
||||
async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
return await self._server.create_database(name=name, tenant=tenant)
|
||||
|
||||
@override
|
||||
async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
return await self._server.get_database(name=name, tenant=tenant)
|
||||
|
||||
@override
|
||||
async def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
return await self._server.delete_database(name=name, tenant=tenant)
|
||||
|
||||
@override
|
||||
async def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
return await self._server.list_databases(
|
||||
limit=limit, offset=offset, tenant=tenant
|
||||
)
|
||||
|
||||
@override
|
||||
async def create_tenant(self, name: str) -> None:
|
||||
return await self._server.create_tenant(name=name)
|
||||
|
||||
@override
|
||||
async def get_tenant(self, name: str) -> Tenant:
|
||||
return await self._server.get_tenant(name=name)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_system(
|
||||
cls,
|
||||
system: System,
|
||||
) -> "AsyncAdminClient":
|
||||
SharedSystemClient._populate_data_from_system(system)
|
||||
instance = cls(settings=system.settings)
|
||||
return instance
|
||||
@@ -0,0 +1,773 @@
import asyncio
from uuid import UUID
import urllib.parse
import orjson
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
import logging
import httpx
from overrides import override
from chromadb import __version__
from chromadb.auth import UserIdentity
from chromadb.api.async_api import AsyncServerAPI
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    UpdateCollectionConfiguration,
    create_collection_configuration_to_json,
    update_collection_configuration_to_json,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, Settings
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.utils.async_to_sync import async_to_sync
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.execution.expression.plan import Search

from chromadb.api.types import (
    Documents,
    Embeddings,
    IDs,
    Include,
    Schema,
    Metadatas,
    URIs,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    SearchResult,
    CollectionMetadata,
    optional_embeddings_to_base64_strings,
    validate_batch,
    convert_np_embeddings_to_list,
    IncludeMetadataDocuments,
    IncludeMetadataDocumentsDistances,
)

from chromadb.api.types import (
    IncludeMetadataDocumentsEmbeddings,
    serialize_metadata,
    deserialize_metadata,
)


logger = logging.getLogger(__name__)


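# --- Illustrative sketch (added; not part of the original module) -------------------
# AsyncFastAPI below caches one httpx.AsyncClient per running event loop (see the
# comment on `_clients` in the class). This standalone helper shows the same pattern
# in isolation; the names and the use of the loop hash as a cache key mirror the class
# but are assumptions for illustration only.
_example_clients: Dict[int, httpx.AsyncClient] = {}


def _example_client_for_current_loop() -> httpx.AsyncClient:
    """Return an httpx.AsyncClient bound to the current event loop, creating it lazily."""
    try:
        key = asyncio.get_event_loop().__hash__()
    except RuntimeError:
        key = 0  # no usable event loop; fall back to a shared singleton
    if key not in _example_clients:
        _example_clients[key] = httpx.AsyncClient(timeout=None)
    return _example_clients[key]

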
class AsyncFastAPI(BaseHTTPClient, AsyncServerAPI):
    # We make one client per event loop to avoid unexpected issues if a client
    # is shared between event loops.
    # For example, if a client is constructed in the main thread, then passed
    # (or a returned Collection is passed) to a new thread, the client would
    # normally throw an obscure asyncio error.
    # Mixing asyncio and threading in this manner is usually discouraged, but
    # this gives a better user experience with practically no downsides.
    # https://github.com/encode/httpx/issues/2058
    _clients: Dict[int, httpx.AsyncClient] = {}

    def __init__(self, system: System):
        super().__init__(system)

        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")

        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._settings = system.settings

        self._api_url = AsyncFastAPI.resolve_url(
            chroma_server_host=str(system.settings.chroma_server_host),
            chroma_server_http_port=system.settings.chroma_server_http_port,
            chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
            default_api_path=system.settings.chroma_server_api_default_path,
        )

    async def __aenter__(self) -> "AsyncFastAPI":
        self._get_client()
        return self

    async def _cleanup(self) -> None:
        while len(self._clients) > 0:
            (_, client) = self._clients.popitem()
            await client.aclose()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        await self._cleanup()

    @override
    def stop(self) -> None:
        super().stop()

        @async_to_sync
        async def sync_cleanup() -> None:
            await self._cleanup()

        sync_cleanup()

    def _get_client(self) -> httpx.AsyncClient:
        # Ideally this would use anyio to be compatible with both
        # asyncio and trio, but anyio does not expose any way to identify
        # the current event loop.
        # We attempt to get the loop assuming the environment is asyncio, and
        # otherwise gracefully fall back to using a singleton client.
        loop_hash = None
        try:
            loop = asyncio.get_event_loop()
            loop_hash = loop.__hash__()
        except RuntimeError:
            loop_hash = 0

        if loop_hash not in self._clients:
            headers = (self._settings.chroma_server_headers or {}).copy()
            headers["Content-Type"] = "application/json"
            headers["User-Agent"] = (
                "Chroma Python Client v"
                + __version__
                + " (https://github.com/chroma-core/chroma)"
            )

            limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
            self._clients[loop_hash] = httpx.AsyncClient(
                timeout=None,
                headers=headers,
                verify=self._settings.chroma_server_ssl_verify or False,
                limits=limits,
            )

        return self._clients[loop_hash]

    async def _make_request(
        self, method: str, path: str, **kwargs: Dict[str, Any]
    ) -> Any:
        # If the request has json in kwargs, use orjson to serialize it,
        # remove it from kwargs, and add it to the content parameter.
        # This is because httpx uses a slower json serializer.
        if "json" in kwargs:
            data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
            kwargs["content"] = data

        # Unlike requests, httpx does not automatically escape the path
        escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
        url = self._api_url + escaped_path

        response = await self._get_client().request(method, url, **cast(Any, kwargs))
        BaseHTTPClient._raise_chroma_error(response)
        return orjson.loads(response.text)

    @trace_method("AsyncFastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
    @override
    async def heartbeat(self) -> int:
        response = await self._make_request("get", "")
        return int(response["nanosecond heartbeat"])

@trace_method("AsyncFastAPI.create_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def create_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> None:
|
||||
await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases",
|
||||
json={"name": name},
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI.get_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def get_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Database:
|
||||
response = await self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{name}",
|
||||
params={"tenant": tenant},
|
||||
)
|
||||
|
||||
return Database(
|
||||
id=response["id"], name=response["name"], tenant=response["tenant"]
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def delete_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> None:
|
||||
await self._make_request(
|
||||
"delete",
|
||||
f"/tenants/{tenant}/databases/{name}",
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
response = await self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases",
|
||||
params=BaseHTTPClient._clean_params(
|
||||
{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
|
||||
for db in response
|
||||
]
|
||||
|
||||
@trace_method("AsyncFastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def create_tenant(self, name: str) -> None:
|
||||
await self._make_request(
|
||||
"post",
|
||||
"/tenants",
|
||||
json={"name": name},
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def get_tenant(self, name: str) -> Tenant:
|
||||
resp_json = await self._make_request(
|
||||
"get",
|
||||
"/tenants/" + name,
|
||||
)
|
||||
|
||||
return Tenant(name=resp_json["name"])
|
||||
|
||||
@trace_method("AsyncFastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def get_user_identity(self) -> UserIdentity:
|
||||
return UserIdentity(**(await self._make_request("get", "/auth/identity")))
|
||||
|
||||
@trace_method("AsyncFastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def list_collections(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Sequence[CollectionModel]:
|
||||
resp_json = await self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections",
|
||||
params=BaseHTTPClient._clean_params(
|
||||
{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
models = [
|
||||
CollectionModel.from_json(json_collection) for json_collection in resp_json
|
||||
]
|
||||
return models
|
||||
|
||||
@trace_method("AsyncFastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def count_collections(
|
||||
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
|
||||
) -> int:
|
||||
resp_json = await self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections_count",
|
||||
)
|
||||
|
||||
return cast(int, resp_json)
|
||||
|
||||
@trace_method("AsyncFastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Creates a collection"""
|
||||
config_json = (
|
||||
create_collection_configuration_to_json(configuration, metadata)
|
||||
if configuration
|
||||
else None
|
||||
)
|
||||
serialized_schema = schema.serialize_to_json() if schema else None
|
||||
resp_json = await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections",
|
||||
json={
|
||||
"name": name,
|
||||
"metadata": metadata,
|
||||
"configuration": config_json,
|
||||
"schema": serialized_schema,
|
||||
"get_or_create": get_or_create,
|
||||
},
|
||||
)
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
|
||||
return model
|
||||
|
||||
@trace_method("AsyncFastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
resp_json = await self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{name}",
|
||||
)
|
||||
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
|
||||
return model
|
||||
|
||||
@trace_method(
|
||||
"AsyncFastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
@override
|
||||
async def get_or_create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
return await self.create_collection(
|
||||
name=name,
|
||||
schema=schema,
|
||||
configuration=configuration,
|
||||
metadata=metadata,
|
||||
get_or_create=True,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI._modify", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _modify(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
await self._make_request(
|
||||
"put",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{id}",
|
||||
json={
|
||||
"new_metadata": new_metadata,
|
||||
"new_name": new_name,
|
||||
"new_configuration": update_collection_configuration_to_json(
|
||||
new_configuration
|
||||
)
|
||||
if new_configuration
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI._fork", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _fork(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
new_name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
resp_json = await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
|
||||
json={"new_name": new_name},
|
||||
)
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
return model
|
||||
|
||||
@trace_method("AsyncFastAPI._search", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _search(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
searches: List[Search],
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> SearchResult:
|
||||
"""Performs hybrid search on a collection"""
|
||||
payload = {"searches": [s.to_dict() for s in searches]}
|
||||
|
||||
resp_json = await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
metadata_batches = resp_json.get("metadatas", None)
|
||||
if metadata_batches is not None:
|
||||
resp_json["metadatas"] = [
|
||||
[
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
if metadatas is not None
|
||||
else None
|
||||
for metadatas in metadata_batches
|
||||
]
|
||||
|
||||
return SearchResult(resp_json)
|
||||
|
||||
@trace_method("AsyncFastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
await self._make_request(
|
||||
"delete",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{name}",
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI._count", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _count(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> int:
|
||||
"""Returns the number of embeddings in the database"""
|
||||
resp_json = await self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
|
||||
)
|
||||
|
||||
return cast(int, resp_json)
|
||||
|
||||
@trace_method("AsyncFastAPI._peek", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _peek(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
n: int = 10,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
resp = await self._get(
|
||||
collection_id,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
limit=n,
|
||||
include=IncludeMetadataDocumentsEmbeddings,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
@trace_method("AsyncFastAPI._get", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _get(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocuments,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
# Servers do not support the "data" include, as that is hydrated on the client side
|
||||
filtered_include = [i for i in include if i != "data"]
|
||||
|
||||
resp_json = await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
|
||||
json={
|
||||
"ids": ids,
|
||||
"where": where,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"where_document": where_document,
|
||||
"include": filtered_include,
|
||||
},
|
||||
)
|
||||
|
||||
metadatas = resp_json.get("metadatas", None)
|
||||
if metadatas is not None:
|
||||
metadatas = [
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
included=include,
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI._delete", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def _delete(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
|
||||
json={"where": where, "ids": ids, "where_document": where_document},
|
||||
)
|
||||
return None
|
||||
|
||||
@trace_method("AsyncFastAPI._submit_batch", OpenTelemetryGranularity.ALL)
|
||||
async def _submit_batch(
|
||||
self,
|
||||
batch: Tuple[
|
||||
IDs,
|
||||
Optional[Embeddings],
|
||||
Optional[Metadatas],
|
||||
Optional[Documents],
|
||||
Optional[URIs],
|
||||
],
|
||||
url: str,
|
||||
) -> Any:
|
||||
"""
|
||||
Submits a batch of embeddings to the database
|
||||
"""
|
||||
supports_base64_encoding = await self.supports_base64_encoding()
|
||||
|
||||
serialized_metadatas = None
|
||||
if batch[2] is not None:
|
||||
serialized_metadatas = [
|
||||
serialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in batch[2]
|
||||
]
|
||||
|
||||
data = {
|
||||
"ids": batch[0],
|
||||
"embeddings": optional_embeddings_to_base64_strings(batch[1])
|
||||
if supports_base64_encoding
|
||||
else batch[1],
|
||||
"metadatas": serialized_metadatas,
|
||||
"documents": batch[3],
|
||||
"uris": batch[4],
|
||||
}
|
||||
|
||||
return await self._make_request(
|
||||
"post",
|
||||
url,
|
||||
json=data,
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI._add", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
async def _add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
batch = (
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
|
||||
await self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
async def _update(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
batch = (
|
||||
ids,
|
||||
embeddings if embeddings is not None else None,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
|
||||
|
||||
await self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@trace_method("AsyncFastAPI._upsert", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
async def _upsert(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
batch = (
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
|
||||
await self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("AsyncFastAPI._query", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
async def _query(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
query_embeddings: Embeddings,
|
||||
ids: Optional[IDs] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocumentsDistances,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> QueryResult:
|
||||
# Servers do not support the "data" include, as that is hydrated on the client side
|
||||
filtered_include = [i for i in include if i != "data"]
|
||||
|
||||
resp_json = await self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
|
||||
json={
|
||||
"ids": ids,
|
||||
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
|
||||
if query_embeddings is not None
|
||||
else None,
|
||||
"n_results": n_results,
|
||||
"where": where,
|
||||
"where_document": where_document,
|
||||
"include": filtered_include,
|
||||
},
|
||||
)
|
||||
|
||||
metadata_batches = resp_json.get("metadatas", None)
|
||||
if metadata_batches is not None:
|
||||
metadata_batches = [
|
||||
[
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
if metadatas is not None
|
||||
else None
|
||||
for metadatas in metadata_batches
|
||||
]
|
||||
|
||||
return QueryResult(
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
included=include,
|
||||
)
|
||||
|
||||
@trace_method("AsyncFastAPI.reset", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
async def reset(self) -> bool:
|
||||
resp_json = await self._make_request("post", "/reset")
|
||||
return cast(bool, resp_json)
|
||||
|
||||
@trace_method("AsyncFastAPI.get_version", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def get_version(self) -> str:
|
||||
resp_json = await self._make_request("get", "/version")
|
||||
return cast(str, resp_json)
|
||||
|
||||
@override
|
||||
def get_settings(self) -> Settings:
|
||||
return self._settings
|
||||
|
||||
@trace_method(
|
||||
"AsyncFastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
async def get_pre_flight_checks(self) -> Any:
|
||||
if self.pre_flight_checks is None:
|
||||
resp_json = await self._make_request("get", "/pre-flight-checks")
|
||||
self.pre_flight_checks = resp_json
|
||||
return self.pre_flight_checks
|
||||
|
||||
@trace_method(
|
||||
"AsyncFastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
async def supports_base64_encoding(self) -> bool:
|
||||
pre_flight_checks = await self.get_pre_flight_checks()
|
||||
b64_encoding_enabled = cast(
|
||||
bool, pre_flight_checks.get("supports_base64_encoding", False)
|
||||
)
|
||||
return b64_encoding_enabled
|
||||
|
||||
@trace_method("AsyncFastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
async def get_max_batch_size(self) -> int:
|
||||
pre_flight_checks = await self.get_pre_flight_checks()
|
||||
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
|
||||
return max_batch_size
|
||||
@@ -0,0 +1,105 @@
from typing import Any, Dict, Optional, TypeVar
from urllib.parse import quote, urlparse, urlunparse
import logging
import orjson as json
import httpx

import chromadb.errors as errors
from chromadb.config import Settings

logger = logging.getLogger(__name__)


class BaseHTTPClient:
    _settings: Settings
    pre_flight_checks: Any = None
    keepalive_secs: int = 40

    @staticmethod
    def _validate_host(host: str) -> None:
        parsed = urlparse(host)
        if "/" in host and parsed.scheme not in {"http", "https"}:
            raise ValueError(
                f"Invalid URL. Unrecognized protocol - {parsed.scheme}."
            )
        if "/" in host and (not host.startswith("http")):
            raise ValueError(
                "Invalid URL. "
                "Seems that you are trying to pass URL as a host but without "
                "specifying the protocol. "
                "Please add http:// or https:// to the host."
            )

    @staticmethod
    def resolve_url(
        chroma_server_host: str,
        chroma_server_ssl_enabled: Optional[bool] = False,
        default_api_path: Optional[str] = "",
        chroma_server_http_port: Optional[int] = 8000,
    ) -> str:
        _skip_port = False
        _chroma_server_host = chroma_server_host
        BaseHTTPClient._validate_host(_chroma_server_host)
        if _chroma_server_host.startswith("http"):
            logger.debug("Skipping port as the user is passing a full URL")
            _skip_port = True
        parsed = urlparse(_chroma_server_host)

        scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
        net_loc = parsed.netloc or parsed.hostname or chroma_server_host
        port = (
            ":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
        )
        path = parsed.path or default_api_path

        if not path or path == net_loc:
            path = default_api_path if default_api_path else ""
        if not path.endswith(default_api_path or ""):
            path = path + default_api_path if default_api_path else ""
        full_url = urlunparse(
            (scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
        )

        return full_url

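    # Added illustration (not in the original file): with the defaults above,
    # resolution behaves roughly like this:
    #   resolve_url("localhost", chroma_server_http_port=8000)          -> "http://localhost:8000"
    #   resolve_url("https://example.com", default_api_path="/api/v2")  -> "https://example.com/api/v2"
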
    # requests removes None values from the built query string, but httpx includes it as an empty value
    T = TypeVar("T", bound=Dict[Any, Any])

    @staticmethod
    def _clean_params(params: T) -> T:
        """Remove None values from provided dict."""
        return {k: v for k, v in params.items() if v is not None}  # type: ignore

    @staticmethod
    def _raise_chroma_error(resp: httpx.Response) -> None:
        """Raises an error if the response is not ok, using a ChromaError if possible."""
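        # Added note: the server is expected to return a JSON error body shaped roughly
        # like {"error": "<ErrorTypeName>", "message": "<details>"}; <ErrorTypeName> is
        # looked up in chromadb.errors.error_types below to rebuild a typed ChromaError.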
        try:
            resp.raise_for_status()
            return
        except httpx.HTTPStatusError:
            pass

        chroma_error = None
        try:
            body = json.loads(resp.text)
            if "error" in body:
                if body["error"] in errors.error_types:
                    chroma_error = errors.error_types[body["error"]](body["message"])

            trace_id = resp.headers.get("chroma-trace-id")
            if trace_id and chroma_error is not None:
                chroma_error.trace_id = trace_id

        except BaseException:
            pass

        if chroma_error:
            raise chroma_error

        try:
            resp.raise_for_status()
        except httpx.HTTPStatusError:
            trace_id = resp.headers.get("chroma-trace-id")
            if trace_id:
                raise Exception(f"{resp.text} (trace ID: {trace_id})")
            raise Exception(resp.text)
@@ -0,0 +1,545 @@
|
||||
from typing import Optional, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from overrides import override
|
||||
import httpx
|
||||
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
UpdateCollectionConfiguration,
|
||||
validate_embedding_function_conflict_on_create,
|
||||
validate_embedding_function_conflict_on_get,
|
||||
)
|
||||
from chromadb.api.shared_system_client import SharedSystemClient
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
DataLoader,
|
||||
Documents,
|
||||
Embeddable,
|
||||
EmbeddingFunction,
|
||||
Embeddings,
|
||||
GetResult,
|
||||
IDs,
|
||||
Include,
|
||||
Loadable,
|
||||
Metadatas,
|
||||
QueryResult,
|
||||
Schema,
|
||||
URIs,
|
||||
IncludeMetadataDocuments,
|
||||
IncludeMetadataDocumentsDistances,
|
||||
DefaultEmbeddingFunction,
|
||||
)
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.auth.utils import maybe_set_tenant_and_database
|
||||
from chromadb.config import Settings, System
|
||||
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.errors import ChromaAuthError, ChromaError
|
||||
from chromadb.types import Database, Tenant, Where, WhereDocument
|
||||
|
||||
|
||||
class Client(SharedSystemClient, ClientAPI):
|
||||
"""A client for Chroma. This is the main entrypoint for interacting with Chroma.
|
||||
A client internally stores its tenant and database and proxies calls to a
|
||||
Server API instance of Chroma. It treats the Server API and corresponding System
|
||||
as a singleton, so multiple clients connecting to the same resource will share the
|
||||
same API instance.
|
||||
|
||||
Client implementations should implement their own API-caching strategies.
|
||||
"""
|
||||
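# Illustrative sketch (not part of the original source): this class is normally
# obtained through the top-level factory helpers rather than constructed
# directly. A typical flow, assuming a local Chroma server is reachable on the
# default port:
#
#   import chromadb
#   client = chromadb.HttpClient(host="localhost", port=8000)
#   collection = client.get_or_create_collection("docs")
#   collection.add(ids=["1"], documents=["hello world"])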
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
|
||||
_server: ServerAPI
|
||||
# An internal admin client for verifying that databases and tenants exist
|
||||
_admin_client: AdminAPI
|
||||
|
||||
# region Initialization
|
||||
def __init__(
|
||||
self,
|
||||
tenant: Optional[str] = DEFAULT_TENANT,
|
||||
database: Optional[str] = DEFAULT_DATABASE,
|
||||
settings: Settings = Settings(),
|
||||
) -> None:
|
||||
super().__init__(settings=settings)
|
||||
if tenant is not None:
|
||||
self.tenant = tenant
|
||||
if database is not None:
|
||||
self.database = database
|
||||
|
||||
# Get the root system component we want to interact with
|
||||
self._server = self._system.instance(ServerAPI)
|
||||
|
||||
user_identity = self.get_user_identity()
|
||||
|
||||
maybe_tenant, maybe_database = maybe_set_tenant_and_database(
|
||||
user_identity,
|
||||
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
|
||||
user_provided_tenant=tenant,
|
||||
user_provided_database=database,
|
||||
)
|
||||
|
||||
# this should not happen unless types are invalidated
|
||||
if maybe_tenant is None and tenant is None:
|
||||
raise ChromaAuthError(
|
||||
"Could not determine a tenant from the current authentication method. Please provide a tenant."
|
||||
)
|
||||
if maybe_database is None and database is None:
|
||||
raise ChromaAuthError(
|
||||
"Could not determine a database name from the current authentication method. Please provide a database name."
|
||||
)
|
||||
|
||||
if maybe_tenant:
|
||||
self.tenant = maybe_tenant
|
||||
if maybe_database:
|
||||
self.database = maybe_database
|
||||
|
||||
# Create an admin client for verifying that databases and tenants exist
|
||||
self._admin_client = AdminClient.from_system(self._system)
|
||||
self._validate_tenant_database(tenant=self.tenant, database=self.database)
|
||||
|
||||
self._submit_client_start_event()
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_system(
|
||||
cls,
|
||||
system: System,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "Client":
|
||||
SharedSystemClient._populate_data_from_system(system)
|
||||
instance = cls(tenant=tenant, database=database, settings=system.settings)
|
||||
return instance
|
||||
|
||||
# endregion
|
||||
|
||||
@override
|
||||
def get_user_identity(self) -> UserIdentity:
|
||||
try:
|
||||
return self._server.get_user_identity()
|
||||
except httpx.ConnectError:
|
||||
raise ValueError(
|
||||
"Could not connect to a Chroma server. Are you sure it is running?"
|
||||
)
|
||||
# Propagate ChromaErrors
|
||||
except ChromaError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
# region BaseAPI Methods
|
||||
# Note - we could do this in less verbose ways, but they break type checking
|
||||
@override
|
||||
def heartbeat(self) -> int:
|
||||
return self._server.heartbeat()
|
||||
|
||||
@override
|
||||
def list_collections(
|
||||
self, limit: Optional[int] = None, offset: Optional[int] = None
|
||||
) -> Sequence[Collection]:
|
||||
return [
|
||||
Collection(client=self._server, model=model)
|
||||
for model in self._server.list_collections(
|
||||
limit, offset, tenant=self.tenant, database=self.database
|
||||
)
|
||||
]
|
||||
|
||||
@override
|
||||
def count_collections(self) -> int:
|
||||
return self._server.count_collections(
|
||||
tenant=self.tenant, database=self.database
|
||||
)
|
||||
|
||||
@override
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
embedding_function: Optional[
|
||||
EmbeddingFunction[Embeddable]
|
||||
] = DefaultEmbeddingFunction(), # type: ignore
|
||||
data_loader: Optional[DataLoader[Loadable]] = None,
|
||||
get_or_create: bool = False,
|
||||
) -> Collection:
|
||||
if configuration is None:
|
||||
configuration = {}
|
||||
|
||||
configuration_ef = configuration.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_create(
|
||||
embedding_function, configuration_ef
|
||||
)
|
||||
|
||||
# If ef provided in function params and collection config ef is None,
|
||||
# set the collection config ef to the function params
|
||||
if embedding_function is not None and configuration_ef is None:
|
||||
configuration["embedding_function"] = embedding_function
|
||||
|
||||
model = self._server.create_collection(
|
||||
name=name,
|
||||
schema=schema,
|
||||
metadata=metadata,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
get_or_create=get_or_create,
|
||||
configuration=configuration,
|
||||
)
|
||||
return Collection(
|
||||
client=self._server,
|
||||
model=model,
|
||||
embedding_function=embedding_function,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
embedding_function: Optional[
|
||||
EmbeddingFunction[Embeddable]
|
||||
] = DefaultEmbeddingFunction(), # type: ignore
|
||||
data_loader: Optional[DataLoader[Loadable]] = None,
|
||||
) -> Collection:
|
||||
model = self._server.get_collection(
|
||||
name=name,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
persisted_ef_config = model.configuration_json.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_get(
|
||||
embedding_function, persisted_ef_config
|
||||
)
|
||||
|
||||
return Collection(
|
||||
client=self._server,
|
||||
model=model,
|
||||
embedding_function=embedding_function,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_or_create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
embedding_function: Optional[
|
||||
EmbeddingFunction[Embeddable]
|
||||
] = DefaultEmbeddingFunction(), # type: ignore
|
||||
data_loader: Optional[DataLoader[Loadable]] = None,
|
||||
) -> Collection:
|
||||
if configuration is None:
|
||||
configuration = {}
|
||||
|
||||
configuration_ef = configuration.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_create(
|
||||
embedding_function, configuration_ef
|
||||
)
|
||||
|
||||
if embedding_function is not None and configuration_ef is None:
|
||||
configuration["embedding_function"] = embedding_function
|
||||
model = self._server.get_or_create_collection(
|
||||
name=name,
|
||||
schema=schema,
|
||||
metadata=metadata,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
persisted_ef_config = model.configuration_json.get("embedding_function")
|
||||
|
||||
validate_embedding_function_conflict_on_get(
|
||||
embedding_function, persisted_ef_config
|
||||
)
|
||||
|
||||
return Collection(
|
||||
client=self._server,
|
||||
model=model,
|
||||
embedding_function=embedding_function,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@override
|
||||
def _modify(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
) -> None:
|
||||
return self._server._modify(
|
||||
id=id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
new_name=new_name,
|
||||
new_metadata=new_metadata,
|
||||
new_configuration=new_configuration,
|
||||
)
|
||||
|
||||
@override
|
||||
def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
) -> None:
|
||||
return self._server.delete_collection(
|
||||
name=name,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
#
|
||||
# ITEM METHODS
|
||||
#
|
||||
|
||||
@override
|
||||
def _add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
) -> bool:
|
||||
return self._server._add(
|
||||
ids=ids,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
collection_id=collection_id,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
@override
|
||||
def _update(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
) -> bool:
|
||||
return self._server._update(
|
||||
collection_id=collection_id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
@override
|
||||
def _upsert(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
) -> bool:
|
||||
return self._server._upsert(
|
||||
collection_id=collection_id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
@override
|
||||
def _count(self, collection_id: UUID) -> int:
|
||||
return self._server._count(
|
||||
collection_id=collection_id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
|
||||
return self._server._peek(
|
||||
collection_id=collection_id,
|
||||
n=n,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@override
|
||||
def _get(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocuments,
|
||||
) -> GetResult:
|
||||
return self._server._get(
|
||||
collection_id=collection_id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
ids=ids,
|
||||
where=where,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
)
|
||||
|
||||
def _delete(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs],
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
) -> None:
|
||||
self._server._delete(
|
||||
collection_id=collection_id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
ids=ids,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
)
|
||||
|
||||
@override
|
||||
def _query(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
query_embeddings: Embeddings,
|
||||
ids: Optional[IDs] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocumentsDistances,
|
||||
) -> QueryResult:
|
||||
return self._server._query(
|
||||
collection_id=collection_id,
|
||||
ids=ids,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
query_embeddings=query_embeddings,
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
)
|
||||
|
||||
@override
|
||||
def reset(self) -> bool:
|
||||
return self._server.reset()
|
||||
|
||||
@override
|
||||
def get_version(self) -> str:
|
||||
return self._server.get_version()
|
||||
|
||||
@override
|
||||
def get_settings(self) -> Settings:
|
||||
return self._server.get_settings()
|
||||
|
||||
@override
|
||||
def get_max_batch_size(self) -> int:
|
||||
return self._server.get_max_batch_size()
|
||||
|
||||
# endregion
|
||||
|
||||
# region ClientAPI Methods
|
||||
|
||||
@override
|
||||
def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
|
||||
self._validate_tenant_database(tenant=tenant, database=database)
|
||||
self.tenant = tenant
|
||||
self.database = database
|
||||
|
||||
@override
|
||||
def set_database(self, database: str) -> None:
|
||||
self._validate_tenant_database(tenant=self.tenant, database=database)
|
||||
self.database = database
|
||||
|
||||
def _validate_tenant_database(self, tenant: str, database: str) -> None:
|
||||
try:
|
||||
self._admin_client.get_tenant(name=tenant)
|
||||
except httpx.ConnectError:
|
||||
raise ValueError(
|
||||
"Could not connect to a Chroma server. Are you sure it is running?"
|
||||
)
|
||||
# Propagate ChromaErrors
|
||||
except ChromaError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not connect to tenant {tenant}. Are you sure it exists?"
|
||||
)
|
||||
|
||||
try:
|
||||
self._admin_client.get_database(name=database, tenant=tenant)
|
||||
except httpx.ConnectError:
|
||||
raise ValueError(
|
||||
"Could not connect to a Chroma server. Are you sure it is running?"
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
class AdminClient(SharedSystemClient, AdminAPI):
|
||||
_server: ServerAPI
|
||||
|
||||
def __init__(self, settings: Settings = Settings()) -> None:
|
||||
super().__init__(settings)
|
||||
self._server = self._system.instance(ServerAPI)
|
||||
|
||||
@override
|
||||
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
return self._server.create_database(name=name, tenant=tenant)
|
||||
|
||||
@override
|
||||
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
return self._server.get_database(name=name, tenant=tenant)
|
||||
|
||||
@override
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
return self._server.delete_database(name=name, tenant=tenant)
|
||||
|
||||
@override
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
return self._server.list_databases(limit, offset, tenant=tenant)
|
||||
|
||||
@override
|
||||
def create_tenant(self, name: str) -> None:
|
||||
return self._server.create_tenant(name=name)
|
||||
|
||||
@override
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
return self._server.get_tenant(name=name)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_system(
|
||||
cls,
|
||||
system: System,
|
||||
) -> "AdminClient":
|
||||
SharedSystemClient._populate_data_from_system(system)
|
||||
instance = cls(settings=system.settings)
|
||||
return instance
|
||||
@@ -0,0 +1,882 @@
|
||||
from typing import TypedDict, Dict, Any, Optional, cast, get_args
|
||||
import json
|
||||
from chromadb.api.types import (
|
||||
Space,
|
||||
CollectionMetadata,
|
||||
UpdateMetadata,
|
||||
EmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions import (
|
||||
known_embedding_functions,
|
||||
register_embedding_function,
|
||||
)
|
||||
from multiprocessing import cpu_count
|
||||
import warnings
|
||||
|
||||
from chromadb.api.types import Schema
|
||||
|
||||
|
||||
class HNSWConfiguration(TypedDict, total=False):
|
||||
space: Space
|
||||
ef_construction: int
|
||||
max_neighbors: int
|
||||
ef_search: int
|
||||
num_threads: int
|
||||
batch_size: int
|
||||
sync_threshold: int
|
||||
resize_factor: float
|
||||
|
||||
|
||||
class SpannConfiguration(TypedDict, total=False):
|
||||
search_nprobe: int
|
||||
write_nprobe: int
|
||||
space: Space
|
||||
ef_construction: int
|
||||
ef_search: int
|
||||
max_neighbors: int
|
||||
reassign_neighbor_count: int
|
||||
split_threshold: int
|
||||
merge_threshold: int
|
||||
|
||||
|
||||
class CollectionConfiguration(TypedDict, total=True):
|
||||
hnsw: Optional[HNSWConfiguration]
|
||||
spann: Optional[SpannConfiguration]
|
||||
embedding_function: Optional[EmbeddingFunction] # type: ignore
|
||||
|
||||
|
||||
def load_collection_configuration_from_json_str(
|
||||
config_json_str: str,
|
||||
) -> CollectionConfiguration:
|
||||
config_json_map = json.loads(config_json_str)
|
||||
return load_collection_configuration_from_json(config_json_map)
|
||||
|
||||
|
||||
# TODO: make warnings prettier and add link to migration docs
|
||||
def load_collection_configuration_from_json(
|
||||
config_json_map: Dict[str, Any]
|
||||
) -> CollectionConfiguration:
|
||||
if (
|
||||
config_json_map.get("spann") is not None
|
||||
and config_json_map.get("hnsw") is not None
|
||||
):
|
||||
raise ValueError("hnsw and spann cannot both be provided")
|
||||
|
||||
hnsw_config = None
|
||||
spann_config = None
|
||||
ef_config = None
|
||||
|
||||
# Process vector index configuration (HNSW or SPANN)
|
||||
if config_json_map.get("hnsw") is not None:
|
||||
hnsw_config = cast(HNSWConfiguration, config_json_map["hnsw"])
|
||||
if config_json_map.get("spann") is not None:
|
||||
spann_config = cast(SpannConfiguration, config_json_map["spann"])
|
||||
|
||||
# Process embedding function configuration
|
||||
if config_json_map.get("embedding_function") is not None:
|
||||
ef_config = config_json_map["embedding_function"]
|
||||
if ef_config["type"] == "legacy":
|
||||
warnings.warn(
|
||||
"legacy embedding function config",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
ef = None
|
||||
else:
|
||||
try:
|
||||
ef_name = ef_config["name"]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Embedding function name not found in config: {ef_config}"
|
||||
)
|
||||
try:
|
||||
ef = known_embedding_functions[ef_name]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Embedding function {ef_name} not found. Add @register_embedding_function decorator to the class definition."
|
||||
)
|
||||
try:
|
||||
ef = ef.build_from_config(ef_config["config"]) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
|
||||
)
|
||||
else:
|
||||
ef = None
|
||||
|
||||
return CollectionConfiguration(
|
||||
hnsw=hnsw_config,
|
||||
spann=spann_config,
|
||||
embedding_function=ef, # type: ignore
|
||||
)
|
||||
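# Illustrative sketch (not part of the original source): a minimal example of the
# JSON shape this loader accepts. The parameter values are hypothetical; a
# "legacy" embedding function entry is used so the example does not depend on a
# registered embedding function name.
def _example_load_collection_configuration() -> CollectionConfiguration:
    config_json = {
        "hnsw": {"space": "cosine", "ef_construction": 200},
        "spann": None,
        "embedding_function": {"type": "legacy"},
    }
    return load_collection_configuration_from_json(config_json)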
|
||||
|
||||
def collection_configuration_to_json_str(config: CollectionConfiguration) -> str:
|
||||
return json.dumps(collection_configuration_to_json(config))
|
||||
|
||||
|
||||
def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[str, Any]:
|
||||
if isinstance(config, dict):
|
||||
hnsw_config = config.get("hnsw")
|
||||
spann_config = config.get("spann")
|
||||
ef = config.get("embedding_function")
|
||||
else:
|
||||
try:
|
||||
hnsw_config = config.get_parameter("hnsw").value
|
||||
except ValueError:
|
||||
hnsw_config = None
|
||||
try:
|
||||
spann_config = config.get_parameter("spann").value
|
||||
except ValueError:
|
||||
spann_config = None
|
||||
try:
|
||||
ef = config.get_parameter("embedding_function").value
|
||||
except ValueError:
|
||||
ef = None
|
||||
|
||||
ef_config: Dict[str, Any] | None = None
|
||||
if hnsw_config is not None:
|
||||
try:
|
||||
hnsw_config = cast(HNSWConfiguration, hnsw_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"not a valid hnsw config: {e}")
|
||||
if spann_config is not None:
|
||||
try:
|
||||
spann_config = cast(SpannConfiguration, spann_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"not a valid spann config: {e}")
|
||||
|
||||
if ef is None:
|
||||
ef = None
|
||||
ef_config = {"type": "legacy"}
|
||||
|
||||
if ef is not None:
|
||||
try:
|
||||
if ef.is_legacy():
|
||||
ef_config = {"type": "legacy"}
|
||||
else:
|
||||
ef_config = {
|
||||
"name": ef.name(),
|
||||
"type": "known",
|
||||
"config": ef.get_config(),
|
||||
}
|
||||
register_embedding_function(type(ef)) # type: ignore
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"legacy embedding function config: {e}",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
ef = None
|
||||
ef_config = {"type": "legacy"}
|
||||
|
||||
return {
|
||||
"hnsw": hnsw_config,
|
||||
"spann": spann_config,
|
||||
"embedding_function": ef_config,
|
||||
}
|
||||
|
||||
|
||||
class CreateHNSWConfiguration(TypedDict, total=False):
|
||||
space: Space
|
||||
ef_construction: int
|
||||
max_neighbors: int
|
||||
ef_search: int
|
||||
num_threads: int
|
||||
batch_size: int
|
||||
sync_threshold: int
|
||||
resize_factor: float
|
||||
|
||||
|
||||
def json_to_create_hnsw_configuration(
|
||||
json_map: Dict[str, Any]
|
||||
) -> CreateHNSWConfiguration:
|
||||
config: CreateHNSWConfiguration = {}
|
||||
if "space" in json_map:
|
||||
space_value = json_map["space"]
|
||||
if space_value in get_args(Space):
|
||||
config["space"] = space_value
|
||||
else:
|
||||
raise ValueError(f"not a valid space: {space_value}")
|
||||
if "ef_construction" in json_map:
|
||||
config["ef_construction"] = json_map["ef_construction"]
|
||||
if "max_neighbors" in json_map:
|
||||
config["max_neighbors"] = json_map["max_neighbors"]
|
||||
if "ef_search" in json_map:
|
||||
config["ef_search"] = json_map["ef_search"]
|
||||
if "num_threads" in json_map:
|
||||
config["num_threads"] = json_map["num_threads"]
|
||||
if "batch_size" in json_map:
|
||||
config["batch_size"] = json_map["batch_size"]
|
||||
if "sync_threshold" in json_map:
|
||||
config["sync_threshold"] = json_map["sync_threshold"]
|
||||
if "resize_factor" in json_map:
|
||||
config["resize_factor"] = json_map["resize_factor"]
|
||||
return config
|
||||
|
||||
|
||||
class CreateSpannConfiguration(TypedDict, total=False):
|
||||
search_nprobe: int
|
||||
write_nprobe: int
|
||||
space: Space
|
||||
ef_construction: int
|
||||
ef_search: int
|
||||
max_neighbors: int
|
||||
reassign_neighbor_count: int
|
||||
split_threshold: int
|
||||
merge_threshold: int
|
||||
|
||||
|
||||
def json_to_create_spann_configuration(
|
||||
json_map: Dict[str, Any]
|
||||
) -> CreateSpannConfiguration:
|
||||
config: CreateSpannConfiguration = {}
|
||||
if "search_nprobe" in json_map:
|
||||
config["search_nprobe"] = json_map["search_nprobe"]
|
||||
if "write_nprobe" in json_map:
|
||||
config["write_nprobe"] = json_map["write_nprobe"]
|
||||
if "space" in json_map:
|
||||
space_value = json_map["space"]
|
||||
if space_value in get_args(Space):
|
||||
config["space"] = space_value
|
||||
else:
|
||||
raise ValueError(f"not a valid space: {space_value}")
|
||||
if "ef_construction" in json_map:
|
||||
config["ef_construction"] = json_map["ef_construction"]
|
||||
if "ef_search" in json_map:
|
||||
config["ef_search"] = json_map["ef_search"]
|
||||
if "max_neighbors" in json_map:
|
||||
config["max_neighbors"] = json_map["max_neighbors"]
|
||||
return config
|
||||
|
||||
|
||||
class CreateCollectionConfiguration(TypedDict, total=False):
|
||||
hnsw: Optional[CreateHNSWConfiguration]
|
||||
spann: Optional[CreateSpannConfiguration]
|
||||
embedding_function: Optional[EmbeddingFunction] # type: ignore
|
||||
|
||||
|
||||
def create_collection_configuration_from_legacy_collection_metadata(
|
||||
metadata: CollectionMetadata,
|
||||
) -> CreateCollectionConfiguration:
|
||||
"""Create a CreateCollectionConfiguration from legacy collection metadata"""
|
||||
return create_collection_configuration_from_legacy_metadata_dict(metadata)
|
||||
|
||||
|
||||
def create_collection_configuration_from_legacy_metadata_dict(
|
||||
metadata: Dict[str, Any],
|
||||
) -> CreateCollectionConfiguration:
|
||||
"""Create a CreateCollectionConfiguration from legacy collection metadata"""
|
||||
old_to_new = {
|
||||
"hnsw:space": "space",
|
||||
"hnsw:construction_ef": "ef_construction",
|
||||
"hnsw:M": "max_neighbors",
|
||||
"hnsw:search_ef": "ef_search",
|
||||
"hnsw:num_threads": "num_threads",
|
||||
"hnsw:batch_size": "batch_size",
|
||||
"hnsw:sync_threshold": "sync_threshold",
|
||||
"hnsw:resize_factor": "resize_factor",
|
||||
}
|
||||
json_map = {}
|
||||
for name, value in metadata.items():
|
||||
if name in old_to_new:
|
||||
json_map[old_to_new[name]] = value
|
||||
hnsw_config = json_to_create_hnsw_configuration(json_map)
|
||||
hnsw_config = populate_create_hnsw_defaults(hnsw_config)
|
||||
|
||||
return CreateCollectionConfiguration(hnsw=hnsw_config)
|
||||
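# Illustrative sketch (not part of the original source): migrating the legacy
# "hnsw:*" metadata keys mapped above into the new create-time configuration.
# The metadata values are hypothetical.
def _example_legacy_metadata_migration() -> CreateCollectionConfiguration:
    legacy_metadata: CollectionMetadata = {
        "hnsw:space": "cosine",
        "hnsw:construction_ef": 200,
        "hnsw:M": 32,
    }
    return create_collection_configuration_from_legacy_collection_metadata(
        legacy_metadata
    )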
|
||||
|
||||
# TODO: make warnings prettier and add link to migration docs
|
||||
def load_create_collection_configuration_from_json(
|
||||
json_map: Dict[str, Any]
|
||||
) -> CreateCollectionConfiguration:
|
||||
if json_map.get("hnsw") is not None and json_map.get("spann") is not None:
|
||||
raise ValueError("hnsw and spann cannot both be provided")
|
||||
|
||||
result = CreateCollectionConfiguration()
|
||||
|
||||
# Handle vector index configuration
|
||||
if json_map.get("hnsw") is not None:
|
||||
result["hnsw"] = json_to_create_hnsw_configuration(json_map["hnsw"])
|
||||
|
||||
if json_map.get("spann") is not None:
|
||||
result["spann"] = json_to_create_spann_configuration(json_map["spann"])
|
||||
|
||||
# Handle embedding function configuration
|
||||
if json_map.get("embedding_function") is not None:
|
||||
ef_config = json_map["embedding_function"]
|
||||
if ef_config["type"] == "legacy":
|
||||
warnings.warn(
|
||||
"legacy embedding function config",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
else:
|
||||
ef = known_embedding_functions[ef_config["name"]]
|
||||
result["embedding_function"] = ef.build_from_config(ef_config["config"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def create_collection_configuration_to_json_str(
|
||||
config: CreateCollectionConfiguration,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
) -> str:
|
||||
"""Convert a CreateCollection configuration to a JSON-serializable string"""
|
||||
return json.dumps(create_collection_configuration_to_json(config, metadata))
|
||||
|
||||
|
||||
# TODO: make warnings prettier and add link to migration docs
|
||||
def create_collection_configuration_to_json(
|
||||
config: CreateCollectionConfiguration,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
|
||||
ef_config: Dict[str, Any] | None = None
|
||||
hnsw_config = config.get("hnsw")
|
||||
spann_config = config.get("spann")
|
||||
if hnsw_config is not None:
|
||||
try:
|
||||
hnsw_config = cast(CreateHNSWConfiguration, hnsw_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"not a valid hnsw config: {e}")
|
||||
if spann_config is not None:
|
||||
try:
|
||||
spann_config = cast(CreateSpannConfiguration, spann_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"not a valid spann config: {e}")
|
||||
|
||||
if hnsw_config is not None and spann_config is not None:
|
||||
raise ValueError("hnsw and spann cannot both be provided")
|
||||
|
||||
if config.get("embedding_function") is None:
|
||||
ef = None
|
||||
ef_config = {"type": "legacy"}
|
||||
return {
|
||||
"hnsw": hnsw_config,
|
||||
"spann": spann_config,
|
||||
"embedding_function": ef_config,
|
||||
}
|
||||
|
||||
try:
|
||||
ef = cast(EmbeddingFunction, config.get("embedding_function")) # type: ignore
|
||||
if ef.is_legacy():
|
||||
ef_config = {"type": "legacy"}
|
||||
else:
|
||||
# default space logic: if neither hnsw nor spann config is provided and metadata doesn't have space,
|
||||
# then populate space from ef
|
||||
# otherwise don't use the default space from the ef
|
||||
|
||||
# then validate the space afterwards based on the supported spaces of the embedding function,
|
||||
# warn if space is not supported
|
||||
|
||||
if hnsw_config is None and spann_config is None:
|
||||
if metadata is None or metadata.get("hnsw:space") is None:
|
||||
# this populates space from ef if not provided in either config
|
||||
hnsw_config = CreateHNSWConfiguration(space=ef.default_space())
|
||||
|
||||
# if hnsw config or spann config exists but space is not provided, populate it from ef
|
||||
if hnsw_config is not None and hnsw_config.get("space") is None:
|
||||
hnsw_config["space"] = ef.default_space()
|
||||
if spann_config is not None and spann_config.get("space") is None:
|
||||
spann_config["space"] = ef.default_space()
|
||||
|
||||
# Validate space compatibility with embedding function
|
||||
if hnsw_config is not None:
|
||||
if hnsw_config.get("space") not in ef.supported_spaces():
|
||||
warnings.warn(
|
||||
f"space {hnsw_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if spann_config is not None:
|
||||
if spann_config.get("space") not in ef.supported_spaces():
|
||||
warnings.warn(
|
||||
f"space {spann_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# only validate space from metadata if config is not provided
|
||||
if (
|
||||
hnsw_config is None
|
||||
and spann_config is None
|
||||
and metadata is not None
|
||||
and metadata.get("hnsw:space") is not None
|
||||
):
|
||||
if metadata.get("hnsw:space") not in ef.supported_spaces():
|
||||
warnings.warn(
|
||||
f"space {metadata.get('hnsw:space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
ef_config = {
|
||||
"name": ef.name(),
|
||||
"type": "known",
|
||||
"config": ef.get_config(),
|
||||
}
|
||||
register_embedding_function(type(ef)) # type: ignore
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"legacy embedding function config: {e}",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
ef = None
|
||||
ef_config = {"type": "legacy"}
|
||||
|
||||
return {
|
||||
"hnsw": hnsw_config,
|
||||
"spann": spann_config,
|
||||
"embedding_function": ef_config,
|
||||
}
|
||||
|
||||
|
||||
def populate_create_hnsw_defaults(
|
||||
config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction] = None # type: ignore
|
||||
) -> CreateHNSWConfiguration:
|
||||
"""Populate a CreateHNSW configuration with default values"""
|
||||
if config.get("space") is None:
|
||||
config["space"] = ef.default_space() if ef else "l2"
|
||||
if config.get("ef_construction") is None:
|
||||
config["ef_construction"] = 100
|
||||
if config.get("max_neighbors") is None:
|
||||
config["max_neighbors"] = 16
|
||||
if config.get("ef_search") is None:
|
||||
config["ef_search"] = 100
|
||||
if config.get("num_threads") is None:
|
||||
config["num_threads"] = cpu_count()
|
||||
if config.get("batch_size") is None:
|
||||
config["batch_size"] = 100
|
||||
if config.get("sync_threshold") is None:
|
||||
config["sync_threshold"] = 1000
|
||||
if config.get("resize_factor") is None:
|
||||
config["resize_factor"] = 1.2
|
||||
return config
|
||||
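# Illustrative sketch (not part of the original source): starting from an empty
# create-time config, the helper above fills in every HNSW field with its
# default ("l2" space, ef_construction=100, max_neighbors=16, ef_search=100,
# num_threads=cpu_count(), batch_size=100, sync_threshold=1000,
# resize_factor=1.2).
def _example_hnsw_defaults() -> CreateHNSWConfiguration:
    return populate_create_hnsw_defaults(CreateHNSWConfiguration())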
|
||||
|
||||
class UpdateHNSWConfiguration(TypedDict, total=False):
|
||||
ef_search: int
|
||||
num_threads: int
|
||||
batch_size: int
|
||||
sync_threshold: int
|
||||
resize_factor: float
|
||||
|
||||
|
||||
def json_to_update_hnsw_configuration(
|
||||
json_map: Dict[str, Any]
|
||||
) -> UpdateHNSWConfiguration:
|
||||
config: UpdateHNSWConfiguration = {}
|
||||
if "ef_search" in json_map:
|
||||
config["ef_search"] = json_map["ef_search"]
|
||||
if "num_threads" in json_map:
|
||||
config["num_threads"] = json_map["num_threads"]
|
||||
if "batch_size" in json_map:
|
||||
config["batch_size"] = json_map["batch_size"]
|
||||
if "sync_threshold" in json_map:
|
||||
config["sync_threshold"] = json_map["sync_threshold"]
|
||||
if "resize_factor" in json_map:
|
||||
config["resize_factor"] = json_map["resize_factor"]
|
||||
return config
|
||||
|
||||
|
||||
class UpdateSpannConfiguration(TypedDict, total=False):
|
||||
search_nprobe: int
|
||||
ef_search: int
|
||||
|
||||
|
||||
def json_to_update_spann_configuration(
|
||||
json_map: Dict[str, Any]
|
||||
) -> UpdateSpannConfiguration:
|
||||
config: UpdateSpannConfiguration = {}
|
||||
if "search_nprobe" in json_map:
|
||||
config["search_nprobe"] = json_map["search_nprobe"]
|
||||
if "ef_search" in json_map:
|
||||
config["ef_search"] = json_map["ef_search"]
|
||||
return config
|
||||
|
||||
|
||||
class UpdateCollectionConfiguration(TypedDict, total=False):
|
||||
hnsw: Optional[UpdateHNSWConfiguration]
|
||||
spann: Optional[UpdateSpannConfiguration]
|
||||
embedding_function: Optional[EmbeddingFunction] # type: ignore
|
||||
|
||||
|
||||
def update_collection_configuration_from_legacy_collection_metadata(
|
||||
metadata: CollectionMetadata,
|
||||
) -> UpdateCollectionConfiguration:
|
||||
"""Create an UpdateCollectionConfiguration from legacy collection metadata"""
|
||||
old_to_new = {
|
||||
"hnsw:search_ef": "ef_search",
|
||||
"hnsw:num_threads": "num_threads",
|
||||
"hnsw:batch_size": "batch_size",
|
||||
"hnsw:sync_threshold": "sync_threshold",
|
||||
"hnsw:resize_factor": "resize_factor",
|
||||
}
|
||||
json_map = {}
|
||||
for name, value in metadata.items():
|
||||
if name in old_to_new:
|
||||
json_map[old_to_new[name]] = value
|
||||
hnsw_config = json_to_update_hnsw_configuration(json_map)
|
||||
return UpdateCollectionConfiguration(hnsw=hnsw_config)
|
||||
|
||||
|
||||
def update_collection_configuration_from_legacy_update_metadata(
|
||||
metadata: UpdateMetadata,
|
||||
) -> UpdateCollectionConfiguration:
|
||||
"""Create an UpdateCollectionConfiguration from legacy update metadata"""
|
||||
old_to_new = {
|
||||
"hnsw:search_ef": "ef_search",
|
||||
"hnsw:num_threads": "num_threads",
|
||||
"hnsw:batch_size": "batch_size",
|
||||
"hnsw:sync_threshold": "sync_threshold",
|
||||
"hnsw:resize_factor": "resize_factor",
|
||||
}
|
||||
json_map = {}
|
||||
for name, value in metadata.items():
|
||||
if name in old_to_new:
|
||||
json_map[old_to_new[name]] = value
|
||||
hnsw_config = json_to_update_hnsw_configuration(json_map)
|
||||
return UpdateCollectionConfiguration(hnsw=hnsw_config)
|
||||
|
||||
|
||||
def update_collection_configuration_to_json_str(
|
||||
config: UpdateCollectionConfiguration,
|
||||
) -> str:
|
||||
"""Convert an UpdateCollectionConfiguration to a JSON-serializable string"""
|
||||
json_dict = update_collection_configuration_to_json(config)
|
||||
return json.dumps(json_dict)
|
||||
|
||||
|
||||
def update_collection_configuration_to_json(
|
||||
config: UpdateCollectionConfiguration,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
|
||||
hnsw_config = config.get("hnsw")
|
||||
spann_config = config.get("spann")
|
||||
ef = config.get("embedding_function")
|
||||
if hnsw_config is None and spann_config is None and ef is None:
|
||||
return {}
|
||||
|
||||
if hnsw_config is not None:
|
||||
try:
|
||||
hnsw_config = cast(UpdateHNSWConfiguration, hnsw_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"not a valid hnsw config: {e}")
|
||||
|
||||
if spann_config is not None:
|
||||
try:
|
||||
spann_config = cast(UpdateSpannConfiguration, spann_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"not a valid spann config: {e}")
|
||||
|
||||
ef_config: Dict[str, Any] | None = None
|
||||
if ef is not None:
|
||||
if ef.is_legacy():
|
||||
ef_config = {"type": "legacy"}
|
||||
else:
|
||||
ef.validate_config(ef.get_config())
|
||||
ef_config = {
|
||||
"name": ef.name(),
|
||||
"type": "known",
|
||||
"config": ef.get_config(),
|
||||
}
|
||||
register_embedding_function(type(ef)) # type: ignore
|
||||
else:
|
||||
ef_config = None
|
||||
|
||||
return {
|
||||
"hnsw": hnsw_config,
|
||||
"spann": spann_config,
|
||||
"embedding_function": ef_config,
|
||||
}
|
||||
|
||||
|
||||
def load_update_collection_configuration_from_json_str(
|
||||
json_str: str,
|
||||
) -> UpdateCollectionConfiguration:
|
||||
json_map = json.loads(json_str)
|
||||
return load_update_collection_configuration_from_json(json_map)
|
||||
|
||||
|
||||
# TODO: make warnings prettier and add link to migration docs
|
||||
def load_update_collection_configuration_from_json(
|
||||
json_map: Dict[str, Any]
|
||||
) -> UpdateCollectionConfiguration:
|
||||
"""Convert a JSON dict to an UpdateCollectionConfiguration"""
|
||||
if json_map.get("hnsw") is not None and json_map.get("spann") is not None:
|
||||
raise ValueError("hnsw and spann cannot both be provided")
|
||||
|
||||
result = UpdateCollectionConfiguration()
|
||||
|
||||
# Handle vector index configurations
|
||||
if json_map.get("hnsw") is not None:
|
||||
result["hnsw"] = json_to_update_hnsw_configuration(json_map["hnsw"])
|
||||
|
||||
if json_map.get("spann") is not None:
|
||||
result["spann"] = json_to_update_spann_configuration(json_map["spann"])
|
||||
|
||||
# Handle embedding function
|
||||
if json_map.get("embedding_function") is not None:
|
||||
if json_map["embedding_function"]["type"] == "legacy":
|
||||
warnings.warn(
|
||||
"legacy embedding function config",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
else:
|
||||
ef = known_embedding_functions[json_map["embedding_function"]["name"]]
|
||||
result["embedding_function"] = ef.build_from_config(
|
||||
json_map["embedding_function"]["config"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def overwrite_hnsw_configuration(
|
||||
existing_hnsw_config: HNSWConfiguration, update_hnsw_config: UpdateHNSWConfiguration
|
||||
) -> HNSWConfiguration:
|
||||
"""Overwrite a HNSWConfiguration with a new configuration"""
|
||||
# Create a copy of the existing config and update with new values
|
||||
result = dict(existing_hnsw_config)
|
||||
update_fields = [
|
||||
"ef_search",
|
||||
"num_threads",
|
||||
"batch_size",
|
||||
"sync_threshold",
|
||||
"resize_factor",
|
||||
]
|
||||
|
||||
for field in update_fields:
|
||||
if field in update_hnsw_config:
|
||||
result[field] = update_hnsw_config[field] # type: ignore
|
||||
|
||||
return cast(HNSWConfiguration, result)
|
||||
|
||||
|
||||
def overwrite_spann_configuration(
|
||||
existing_spann_config: SpannConfiguration,
|
||||
update_spann_config: UpdateSpannConfiguration,
|
||||
) -> SpannConfiguration:
|
||||
"""Overwrite a SpannConfiguration with a new configuration"""
|
||||
result = dict(existing_spann_config)
|
||||
update_fields = [
|
||||
"search_nprobe",
|
||||
"ef_search",
|
||||
]
|
||||
|
||||
for field in update_fields:
|
||||
if field in update_spann_config:
|
||||
result[field] = update_spann_config[field] # type: ignore
|
||||
|
||||
return cast(SpannConfiguration, result)
|
||||
|
||||
|
||||
# TODO: make warnings prettier and add link to migration docs
|
||||
def overwrite_embedding_function(
|
||||
existing_embedding_function: EmbeddingFunction, # type: ignore
|
||||
update_embedding_function: EmbeddingFunction, # type: ignore
|
||||
) -> EmbeddingFunction: # type: ignore
|
||||
"""Overwrite an EmbeddingFunction with a new configuration"""
|
||||
# Check for legacy embedding functions
|
||||
if existing_embedding_function.is_legacy() or update_embedding_function.is_legacy():
|
||||
warnings.warn(
|
||||
"cannot update legacy embedding function config",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return existing_embedding_function
|
||||
|
||||
# Validate function compatibility
|
||||
if existing_embedding_function.name() != update_embedding_function.name():
|
||||
raise ValueError(
|
||||
f"Cannot update embedding function: incompatible types "
|
||||
f"({existing_embedding_function.name()} vs {update_embedding_function.name()})"
|
||||
)
|
||||
|
||||
# Validate and apply the configuration update
|
||||
update_embedding_function.validate_config_update(
|
||||
existing_embedding_function.get_config(), update_embedding_function.get_config()
|
||||
)
|
||||
return update_embedding_function
|
||||
|
||||
|
||||
def overwrite_collection_configuration(
|
||||
existing_config: CollectionConfiguration,
|
||||
update_config: UpdateCollectionConfiguration,
|
||||
) -> CollectionConfiguration:
|
||||
"""Overwrite a CollectionConfiguration with a new configuration"""
|
||||
update_spann = update_config.get("spann")
|
||||
update_hnsw = update_config.get("hnsw")
|
||||
if update_spann is not None and update_hnsw is not None:
|
||||
raise ValueError("hnsw and spann cannot both be provided")
|
||||
|
||||
# Handle HNSW configuration update
|
||||
|
||||
updated_hnsw_config = existing_config.get("hnsw")
|
||||
if updated_hnsw_config is not None and update_hnsw is not None:
|
||||
updated_hnsw_config = overwrite_hnsw_configuration(
|
||||
updated_hnsw_config, update_hnsw
|
||||
)
|
||||
|
||||
# Handle SPANN configuration update
|
||||
updated_spann_config = existing_config.get("spann")
|
||||
if updated_spann_config is not None and update_spann is not None:
|
||||
updated_spann_config = overwrite_spann_configuration(
|
||||
updated_spann_config, update_spann
|
||||
)
|
||||
|
||||
# Handle embedding function update
|
||||
updated_embedding_function = existing_config.get("embedding_function")
|
||||
update_ef = update_config.get("embedding_function")
|
||||
if update_ef is not None:
|
||||
if updated_embedding_function is not None:
|
||||
updated_embedding_function = overwrite_embedding_function(
|
||||
updated_embedding_function, update_ef
|
||||
)
|
||||
else:
|
||||
updated_embedding_function = update_ef
|
||||
|
||||
return CollectionConfiguration(
|
||||
hnsw=updated_hnsw_config,
|
||||
spann=updated_spann_config,
|
||||
embedding_function=updated_embedding_function,
|
||||
)
|
||||
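# Illustrative sketch (not part of the original source): only the dynamic HNSW
# fields present in the update are overwritten; everything else is carried over
# from the existing configuration. The values are hypothetical.
def _example_overwrite_configuration(
    existing: CollectionConfiguration,
) -> CollectionConfiguration:
    update = UpdateCollectionConfiguration(
        hnsw=UpdateHNSWConfiguration(ef_search=200)
    )
    return overwrite_collection_configuration(existing, update)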
|
||||
|
||||
def validate_embedding_function_conflict_on_create(
|
||||
embedding_function: Optional[EmbeddingFunction], # type: ignore
|
||||
configuration_ef: Optional[EmbeddingFunction], # type: ignore
|
||||
) -> None:
|
||||
"""
|
||||
Validates that there are no conflicting embedding functions between function parameter
|
||||
and collection configuration.
|
||||
|
||||
Args:
|
||||
embedding_function: The embedding function provided as a parameter
|
||||
configuration_ef: The embedding function from collection configuration
|
||||
|
||||
Raises:
|
||||
ValueError: If the embedding function provided as a parameter conflicts with the one in the configuration.
|
||||
"""
|
||||
# If ef provided in function params and collection config, check if they are the same
|
||||
# If not, there's a conflict
|
||||
# ef is by default "default" if not provided, so ignore that case.
|
||||
if embedding_function is not None and configuration_ef is not None:
|
||||
if (
|
||||
embedding_function.name() != "default"
|
||||
and embedding_function.name() != configuration_ef.name()
|
||||
):
|
||||
raise ValueError(
|
||||
f"Multiple embedding functions provided. Please provide only one. Embedding function conflict: {embedding_function.name()} vs {configuration_ef.name()}"
|
||||
)
|
||||
return None
|
||||
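# Illustrative sketch (not part of the original source): the check above only
# fires when two different, explicitly named functions are supplied. Roughly:
#
#   validate_embedding_function_conflict_on_create(openai_ef, None)        # ok
#   validate_embedding_function_conflict_on_create(default_ef, cohere_ef)  # ok ("default" is ignored)
#   validate_embedding_function_conflict_on_create(openai_ef, cohere_ef)   # raises ValueError
#
# where openai_ef, cohere_ef and default_ef are hypothetical EmbeddingFunction
# instances whose name() returns "openai", "cohere" and "default" respectively.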
|
||||
|
||||
# The reason to compare against the persisted config on get, rather than building the ef, is that
|
||||
# if there is an issue with deserializing the config, an error shouldn't be raised
|
||||
# at get time. CollectionCommon.py will raise an error at _embed time if there is an issue deserializing.
|
||||
def validate_embedding_function_conflict_on_get(
|
||||
embedding_function: Optional[EmbeddingFunction], # type: ignore
|
||||
persisted_ef_config: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Validates that there are no conflicting embedding functions between function parameter
|
||||
and collection configuration.
|
||||
"""
|
||||
if persisted_ef_config is not None and embedding_function is not None:
|
||||
if (
|
||||
embedding_function.name() != "default"
|
||||
and persisted_ef_config.get("name") is not None
|
||||
and persisted_ef_config.get("name") != embedding_function.name()
|
||||
):
|
||||
raise ValueError(
|
||||
f"An embedding function already exists in the collection configuration, and a new one is provided. If this is intentional, please embed documents separately. Embedding function conflict: new: {embedding_function.name()} vs persisted: {persisted_ef_config.get('name')}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def update_schema_from_collection_configuration(
|
||||
schema: "Schema", configuration: "UpdateCollectionConfiguration"
|
||||
) -> "Schema":
|
||||
"""
|
||||
Updates a schema with configuration changes.
|
||||
Only updates fields that are present in the configuration update.
|
||||
|
||||
Args:
|
||||
schema: The existing Schema object
|
||||
configuration: The configuration updates to apply
|
||||
|
||||
Returns:
|
||||
Updated Schema object
|
||||
"""
|
||||
|
||||
# Get the vector index from defaults and #embedding key
|
||||
if (
|
||||
schema.defaults.float_list is None
|
||||
or schema.defaults.float_list.vector_index is None
|
||||
):
|
||||
raise ValueError("Schema is missing defaults.float_list.vector_index")
|
||||
|
||||
embedding_key = "#embedding"
|
||||
if embedding_key not in schema.keys:
|
||||
raise ValueError(f"Schema is missing keys[{embedding_key}]")
|
||||
|
||||
embedding_value_types = schema.keys[embedding_key]
|
||||
if (
|
||||
embedding_value_types.float_list is None
|
||||
or embedding_value_types.float_list.vector_index is None
|
||||
):
|
||||
raise ValueError(
|
||||
f"Schema is missing keys[{embedding_key}].float_list.vector_index"
|
||||
)
|
||||
|
||||
# Update vector index config in both locations
|
||||
for vector_index in [
|
||||
schema.defaults.float_list.vector_index,
|
||||
embedding_value_types.float_list.vector_index,
|
||||
]:
|
||||
if "hnsw" in configuration and configuration["hnsw"] is not None:
|
||||
# Update HNSW config
|
||||
if vector_index.config.hnsw is None:
|
||||
raise ValueError("Trying to update HNSW config but schema has SPANN")
|
||||
|
||||
hnsw_config = vector_index.config.hnsw
|
||||
update_hnsw = configuration["hnsw"]
|
||||
|
||||
# Only update fields that are present in the update
|
||||
if "ef_search" in update_hnsw:
|
||||
hnsw_config.ef_search = update_hnsw["ef_search"]
|
||||
if "num_threads" in update_hnsw:
|
||||
hnsw_config.num_threads = update_hnsw["num_threads"]
|
||||
if "batch_size" in update_hnsw:
|
||||
hnsw_config.batch_size = update_hnsw["batch_size"]
|
||||
if "sync_threshold" in update_hnsw:
|
||||
hnsw_config.sync_threshold = update_hnsw["sync_threshold"]
|
||||
if "resize_factor" in update_hnsw:
|
||||
hnsw_config.resize_factor = update_hnsw["resize_factor"]
|
||||
|
||||
elif "spann" in configuration and configuration["spann"] is not None:
|
||||
# Update SPANN config
|
||||
if vector_index.config.spann is None:
|
||||
raise ValueError("Trying to update SPANN config but schema has HNSW")
|
||||
|
||||
spann_config = vector_index.config.spann
|
||||
update_spann = configuration["spann"]
|
||||
|
||||
# Only update fields that are present in the update
|
||||
if "search_nprobe" in update_spann:
|
||||
spann_config.search_nprobe = update_spann["search_nprobe"]
|
||||
if "ef_search" in update_spann:
|
||||
spann_config.ef_search = update_spann["ef_search"]
|
||||
|
||||
# Update embedding function if present
|
||||
if (
|
||||
"embedding_function" in configuration
|
||||
and configuration["embedding_function"] is not None
|
||||
):
|
||||
vector_index.config.embedding_function = configuration["embedding_function"]
|
||||
|
||||
return schema
|
||||
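# Illustrative sketch (not part of the original source): applying an update to a
# schema that already carries an HNSW vector index. `schema` is a hypothetical
# Schema instance, e.g. one returned by the server for an existing collection.
#
#   updated = update_schema_from_collection_configuration(
#       schema,
#       UpdateCollectionConfiguration(hnsw=UpdateHNSWConfiguration(ef_search=200)),
#   )
#   updated.keys["#embedding"].float_list.vector_index.config.hnsw.ef_search  # 200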
@@ -0,0 +1,410 @@
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
from overrides import override
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from chromadb.serde import JSONSerializable
|
||||
|
||||
# TODO: move out of API
|
||||
|
||||
|
||||
class StaticParameterError(Exception):
|
||||
"""Represents an error that occurs when a static parameter is set."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidConfigurationError(ValueError):
|
||||
"""Represents an error that occurs when a configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"]
|
||||
|
||||
|
||||
class ParameterValidator(Protocol):
|
||||
"""Represents an abstract parameter validator."""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, value: ParameterValue) -> bool:
|
||||
"""Returns whether the given value is valid."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ConfigurationDefinition:
|
||||
"""Represents the definition of a configuration."""
|
||||
|
||||
name: str
|
||||
validator: ParameterValidator
|
||||
is_static: bool
|
||||
default_value: ParameterValue
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
validator: ParameterValidator,
|
||||
is_static: bool,
|
||||
default_value: ParameterValue,
|
||||
):
|
||||
self.name = name
|
||||
self.validator = validator
|
||||
self.is_static = is_static
|
||||
self.default_value = default_value
|
||||
|
||||
|
||||
class ConfigurationParameter:
|
||||
"""Represents a parameter of a configuration."""
|
||||
|
||||
name: str
|
||||
value: ParameterValue
|
||||
|
||||
def __init__(self, name: str, value: ParameterValue):
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ConfigurationParameter({self.name}, {self.value})"
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
if not isinstance(__value, ConfigurationParameter):
|
||||
return NotImplemented
|
||||
return self.name == __value.name and self.value == __value.value
|
||||
|
||||
|
||||
T = TypeVar("T", bound="ConfigurationInternal")
|
||||
|
||||
|
||||
class ConfigurationInternal(JSONSerializable["ConfigurationInternal"]):
|
||||
"""Represents an abstract configuration, used internally by Chroma."""
|
||||
|
||||
# The internal data structure used to store the parameters
|
||||
# All expected parameters must be present with defaults or None values at initialization
|
||||
parameter_map: Dict[str, ConfigurationParameter]
|
||||
definitions: ClassVar[Dict[str, ConfigurationDefinition]]
|
||||
|
||||
def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
|
||||
"""Initializes a new instance of the Configuration class. Respecting defaults and
|
||||
validators."""
|
||||
self.parameter_map = {}
|
||||
if parameters is not None:
|
||||
for parameter in parameters:
|
||||
if parameter.name not in self.definitions:
|
||||
raise ValueError(f"Invalid parameter name: {parameter.name}")
|
||||
|
||||
definition = self.definitions[parameter.name]
|
||||
# Handle the case where we have a recursive configuration definition
|
||||
if isinstance(parameter.value, dict):
|
||||
child_type = globals().get(parameter.value.get("_type", None))
|
||||
if child_type is None:
|
||||
raise ValueError(
|
||||
f"Invalid configuration type: {parameter.value}"
|
||||
)
|
||||
parameter.value = child_type.from_json(parameter.value)
|
||||
if not isinstance(parameter.value, type(definition.default_value)):
|
||||
raise ValueError(f"Invalid parameter value: {parameter.value}")
|
||||
|
||||
parameter_validator = definition.validator
|
||||
if not parameter_validator(parameter.value):
|
||||
raise ValueError(f"Invalid parameter value: {parameter.value}")
|
||||
self.parameter_map[parameter.name] = parameter
|
||||
# Apply the defaults for any missing parameters
|
||||
for name, definition in self.definitions.items():
|
||||
if name not in self.parameter_map:
|
||||
self.parameter_map[name] = ConfigurationParameter(
|
||||
name=name, value=definition.default_value
|
||||
)
|
||||
|
||||
self.configuration_validator()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Configuration({self.parameter_map.values()})"
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
if not isinstance(__value, ConfigurationInternal):
|
||||
return NotImplemented
|
||||
return self.parameter_map == __value.parameter_map
|
||||
|
||||
@abstractmethod
|
||||
def configuration_validator(self) -> None:
|
||||
"""Perform custom validation when parameters are dependent on each other.
|
||||
|
||||
Raises an InvalidConfigurationError if the configuration is invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_parameters(self) -> List[ConfigurationParameter]:
|
||||
"""Returns the parameters of the configuration."""
|
||||
return list(self.parameter_map.values())
|
||||
|
||||
def get_parameter(self, name: str) -> ConfigurationParameter:
|
||||
"""Returns the parameter with the given name, or except if it doesn't exist."""
|
||||
if name not in self.parameter_map:
|
||||
raise ValueError(
|
||||
f"Invalid parameter name: {name} for configuration {self.__class__.__name__}"
|
||||
)
|
||||
param_value = cast(ConfigurationParameter, self.parameter_map.get(name))
|
||||
return param_value
|
||||
|
||||
def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None:
|
||||
"""Sets the parameter with the given name to the given value."""
|
||||
if name not in self.definitions:
|
||||
raise ValueError(f"Invalid parameter name: {name}")
|
||||
definition = self.definitions[name]
|
||||
parameter = self.parameter_map[name]
|
||||
if definition.is_static:
|
||||
raise StaticParameterError(f"Cannot set static parameter: {name}")
|
||||
if not definition.validator(value):
|
||||
raise ValueError(f"Invalid value for parameter {name}: {value}")
|
||||
parameter.value = value
|
||||
|
||||
@override
|
||||
def to_json_str(self) -> str:
|
||||
"""Returns the JSON representation of the configuration."""
|
||||
return json.dumps(self.to_json())
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json_str(cls, json_str: str) -> Self:
|
||||
"""Returns a configuration from the given JSON string."""
|
||||
try:
|
||||
config_json = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
f"Unable to decode configuration from JSON string: {json_str}"
|
||||
)
|
||||
return cls.from_json(config_json) if config_json else cls()
|
||||
|
||||
@override
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
"""Returns the JSON compatible dictionary representation of the configuration."""
|
||||
json_dict = {
|
||||
name: parameter.value.to_json()
|
||||
if isinstance(parameter.value, ConfigurationInternal)
|
||||
else parameter.value
|
||||
for name, parameter in self.parameter_map.items()
|
||||
}
|
||||
# What kind of configuration is this?
|
||||
json_dict["_type"] = self.__class__.__name__
|
||||
return json_dict
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json(cls, json_map: Dict[str, Any]) -> Self:
|
||||
"""Returns a configuration from the given JSON string."""
|
||||
if cls.__name__ != json_map.get("_type", None):
|
||||
raise ValueError(
|
||||
f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {json_map['_type']}"
|
||||
)
|
||||
parameters = []
|
||||
for name, value in json_map.items():
|
||||
# Type value is only for storage
|
||||
if name == "_type":
|
||||
continue
|
||||
parameters.append(ConfigurationParameter(name=name, value=value))
|
||||
return cls(parameters=parameters)
|
||||
|
||||
|
||||
class HNSWConfigurationInternal(ConfigurationInternal):
|
||||
"""Internal representation of the HNSW configuration.
|
||||
Used for validation, defaults, serialization and deserialization."""
|
||||
|
||||
definitions = {
|
||||
"space": ConfigurationDefinition(
|
||||
name="space",
|
||||
validator=lambda value: isinstance(value, str)
|
||||
and value in ["l2", "ip", "cosine"],
|
||||
is_static=True,
|
||||
default_value="l2",
|
||||
),
|
||||
"ef_construction": ConfigurationDefinition(
|
||||
name="ef_construction",
|
||||
validator=lambda value: isinstance(value, int) and value >= 1,
|
||||
is_static=True,
|
||||
default_value=100,
|
||||
),
|
||||
"ef_search": ConfigurationDefinition(
|
||||
name="ef_search",
|
||||
validator=lambda value: isinstance(value, int) and value >= 1,
|
||||
is_static=False,
|
||||
default_value=100,
|
||||
),
|
||||
"num_threads": ConfigurationDefinition(
|
||||
name="num_threads",
|
||||
validator=lambda value: isinstance(value, int) and value >= 1,
|
||||
is_static=False,
|
||||
default_value=cpu_count(), # By default use all cores available
|
||||
),
|
||||
"M": ConfigurationDefinition(
|
||||
name="M",
|
||||
validator=lambda value: isinstance(value, int) and value >= 1,
|
||||
is_static=True,
|
||||
default_value=16,
|
||||
),
|
||||
"resize_factor": ConfigurationDefinition(
|
||||
name="resize_factor",
|
||||
validator=lambda value: isinstance(value, float) and value >= 1,
|
||||
is_static=True,
|
||||
default_value=1.2,
|
||||
),
|
||||
"batch_size": ConfigurationDefinition(
|
||||
name="batch_size",
|
||||
validator=lambda value: isinstance(value, int) and value >= 1,
|
||||
is_static=True,
|
||||
default_value=100,
|
||||
),
|
||||
"sync_threshold": ConfigurationDefinition(
|
||||
name="sync_threshold",
|
||||
validator=lambda value: isinstance(value, int) and value >= 1,
|
||||
is_static=True,
|
||||
default_value=1000,
|
||||
),
|
||||
}
|
||||
|
||||
@override
|
||||
def configuration_validator(self) -> None:
|
||||
batch_size = self.parameter_map.get("batch_size")
|
||||
sync_threshold = self.parameter_map.get("sync_threshold")
|
||||
|
||||
if (
|
||||
batch_size
|
||||
and sync_threshold
|
||||
and cast(int, batch_size.value) > cast(int, sync_threshold.value)
|
||||
):
|
||||
raise InvalidConfigurationError(
|
||||
"batch_size must be less than or equal to sync_threshold"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
|
||||
"""Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration."""
|
||||
|
||||
# We maintain this map to avoid a circular import with HnswParams, and
# because the names won't change, since we intend to deprecate HnswParams
# in favor of this type of configuration.
|
||||
old_to_new = {
|
||||
"hnsw:space": "space",
|
||||
"hnsw:construction_ef": "ef_construction",
|
||||
"hnsw:search_ef": "ef_search",
|
||||
"hnsw:M": "M",
|
||||
"hnsw:num_threads": "num_threads",
|
||||
"hnsw:resize_factor": "resize_factor",
|
||||
"hnsw:batch_size": "batch_size",
|
||||
"hnsw:sync_threshold": "sync_threshold",
|
||||
}
|
||||
|
||||
parameters = []
|
||||
for name, value in params.items():
|
||||
if name not in old_to_new:
|
||||
raise ValueError(f"Invalid legacy HNSW parameter name: {name}")
|
||||
parameters.append(
|
||||
ConfigurationParameter(name=old_to_new[name], value=value)
|
||||
)
|
||||
return cls(parameters)
|
||||
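# Illustrative sketch (not part of the original source): migrating a legacy
# metadata dict into an internal configuration via the old_to_new map above.
# The parameter values are arbitrary examples.
#
# legacy = {"hnsw:space": "cosine", "hnsw:construction_ef": 200}
# config = HNSWConfigurationInternal.from_legacy_params(legacy)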
|
||||
|
||||
# This is the user-facing interface for HNSW index configuration parameters.
|
||||
# Internally, we pass around HNSWConfigurationInternal objects, which perform
|
||||
# validation, serialization and deserialization. Users don't need to know
|
||||
# about that and instead get a clean constructor with default arguments.
|
||||
class HNSWConfigurationInterface(HNSWConfigurationInternal):
|
||||
"""HNSW index configuration parameters.
|
||||
See https://docs.trychroma.com/guides#changing-the-distance-function for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
space: str = "l2",
|
||||
ef_construction: int = 100,
|
||||
ef_search: int = 100,
|
||||
num_threads: int = cpu_count(),
|
||||
M: int = 16,
|
||||
resize_factor: float = 1.2,
|
||||
batch_size: int = 100,
|
||||
sync_threshold: int = 1000,
|
||||
):
|
||||
parameters = [
|
||||
ConfigurationParameter(name="space", value=space),
|
||||
ConfigurationParameter(name="ef_construction", value=ef_construction),
|
||||
ConfigurationParameter(name="ef_search", value=ef_search),
|
||||
ConfigurationParameter(name="num_threads", value=num_threads),
|
||||
ConfigurationParameter(name="M", value=M),
|
||||
ConfigurationParameter(name="resize_factor", value=resize_factor),
|
||||
ConfigurationParameter(name="batch_size", value=batch_size),
|
||||
ConfigurationParameter(name="sync_threshold", value=sync_threshold),
|
||||
]
|
||||
|
||||
super().__init__(parameters=parameters)
|
||||
|
||||
|
||||
# Alias for user convenience - the user doesn't need to know this is an 'Interface'
|
||||
HNSWConfiguration = HNSWConfigurationInterface
|
||||
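# Illustrative sketch (not part of the original source): constructing an HNSW
# configuration with non-default parameters through the alias above. The
# parameter values are arbitrary examples.
#
# hnsw_config = HNSWConfiguration(space="cosine", ef_construction=200, M=32)
# print(hnsw_config.to_json_str())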
|
||||
|
||||
class CollectionConfigurationInternal(ConfigurationInternal):
|
||||
"""Internal representation of the collection configuration.
|
||||
Used for validation, defaults, and serialization / deserialization."""
|
||||
|
||||
definitions = {
|
||||
"hnsw_configuration": ConfigurationDefinition(
|
||||
name="hnsw_configuration",
|
||||
validator=lambda value: isinstance(value, HNSWConfigurationInternal),
|
||||
is_static=True,
|
||||
default_value=HNSWConfigurationInternal(),
|
||||
),
|
||||
}
|
||||
|
||||
@override
|
||||
def configuration_validator(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# This is the user-facing interface for collection configuration parameters.
# Internally, we pass around CollectionConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor with default arguments.
|
||||
class CollectionConfigurationInterface(CollectionConfigurationInternal):
|
||||
"""Configuration parameters for creating a collection."""
|
||||
|
||||
def __init__(self, hnsw_configuration: Optional[HNSWConfigurationInternal]):
|
||||
"""Initializes a new instance of the CollectionConfiguration class.
|
||||
Args:
|
||||
hnsw_configuration: The HNSW configuration to use for the collection.
|
||||
"""
|
||||
if hnsw_configuration is None:
|
||||
hnsw_configuration = HNSWConfigurationInternal()
|
||||
parameters = [
|
||||
ConfigurationParameter(name="hnsw_configuration", value=hnsw_configuration)
|
||||
]
|
||||
super().__init__(parameters=parameters)
|
||||
|
||||
|
||||
# Alias for user convenience - the user doesn't need to know this is an 'Interface'.
|
||||
CollectionConfiguration = CollectionConfigurationInterface
|
||||
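# Illustrative sketch (not part of the original source): wrapping an HNSW
# configuration in a collection configuration via the alias above.
#
# collection_config = CollectionConfiguration(
#     hnsw_configuration=HNSWConfiguration(space="ip")
# )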
|
||||
|
||||
class EmbeddingsQueueConfigurationInternal(ConfigurationInternal):
|
||||
definitions = {
|
||||
"automatically_purge": ConfigurationDefinition(
|
||||
name="automatically_purge",
|
||||
validator=lambda value: isinstance(value, bool),
|
||||
is_static=False,
|
||||
default_value=True,
|
||||
),
|
||||
}
|
||||
|
||||
@override
|
||||
def configuration_validator(self) -> None:
|
||||
pass
|
||||
@@ -0,0 +1,806 @@
|
||||
import orjson
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, cast, Tuple, List
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
import httpx
|
||||
import urllib.parse
|
||||
from overrides import override
|
||||
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
UpdateCollectionConfiguration,
|
||||
update_collection_configuration_to_json,
|
||||
create_collection_configuration_to_json,
|
||||
)
|
||||
from chromadb import __version__
|
||||
from chromadb.api.base_http_client import BaseHTTPClient
|
||||
from chromadb.types import Database, Tenant, Collection as CollectionModel
|
||||
from chromadb.api import ServerAPI
|
||||
from chromadb.execution.expression.plan import Search
|
||||
|
||||
from chromadb.api.types import (
|
||||
Documents,
|
||||
Embeddings,
|
||||
IDs,
|
||||
Include,
|
||||
Schema,
|
||||
Metadatas,
|
||||
URIs,
|
||||
Where,
|
||||
WhereDocument,
|
||||
GetResult,
|
||||
QueryResult,
|
||||
SearchResult,
|
||||
CollectionMetadata,
|
||||
validate_batch,
|
||||
convert_np_embeddings_to_list,
|
||||
IncludeMetadataDocuments,
|
||||
IncludeMetadataDocumentsDistances,
|
||||
)
|
||||
|
||||
from chromadb.api.types import (
|
||||
IncludeMetadataDocumentsEmbeddings,
|
||||
optional_embeddings_to_base64_strings,
|
||||
serialize_metadata,
|
||||
deserialize_metadata,
|
||||
)
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.auth import (
|
||||
ClientAuthProvider,
|
||||
)
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
from chromadb.telemetry.product import ProductTelemetryClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
system.settings.require("chroma_server_host")
|
||||
system.settings.require("chroma_server_http_port")
|
||||
|
||||
self._opentelemetry_client = self.require(OpenTelemetryClient)
|
||||
self._product_telemetry_client = self.require(ProductTelemetryClient)
|
||||
self._settings = system.settings
|
||||
|
||||
self._api_url = FastAPI.resolve_url(
|
||||
chroma_server_host=str(system.settings.chroma_server_host),
|
||||
chroma_server_http_port=system.settings.chroma_server_http_port,
|
||||
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
|
||||
default_api_path=system.settings.chroma_server_api_default_path,
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._session = httpx.Client(timeout=None, limits=limits)
|
||||
|
||||
self._header = system.settings.chroma_server_headers or {}
|
||||
self._header["Content-Type"] = "application/json"
|
||||
self._header["User-Agent"] = (
|
||||
"Chroma Python Client v"
|
||||
+ __version__
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
|
||||
if self._header is not None:
|
||||
self._session.headers.update(self._header)
|
||||
|
||||
if system.settings.chroma_client_auth_provider:
|
||||
self._auth_provider = self.require(ClientAuthProvider)
|
||||
_headers = self._auth_provider.authenticate()
|
||||
for header, value in _headers.items():
|
||||
self._session.headers[header] = value.get_secret_value()
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
|
||||
# If the request has json in kwargs, use orjson to serialize it,
|
||||
# remove it from kwargs, and add it to the content parameter
|
||||
# This is because httpx uses a slower json serializer
|
||||
if "json" in kwargs:
|
||||
data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
|
||||
kwargs["content"] = data
|
||||
|
||||
# Unlike requests, httpx does not automatically escape the path
|
||||
escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
|
||||
url = self._api_url + escaped_path
|
||||
|
||||
response = self._session.request(method, url, **cast(Any, kwargs))
|
||||
BaseHTTPClient._raise_chroma_error(response)
|
||||
return orjson.loads(response.text)
|
||||
|
||||
@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def heartbeat(self) -> int:
|
||||
"""Returns the current server time in nanoseconds to check if the server is alive"""
|
||||
resp_json = self._make_request("get", "/heartbeat")
|
||||
return int(resp_json["nanosecond heartbeat"])
|
||||
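# Illustrative sketch (not part of the original source): this ServerAPI
# implementation is normally reached through the HTTP client wrapper, e.g.
#
# import chromadb
# client = chromadb.HttpClient(host="localhost", port=8000)
# client.heartbeat()  # server time in nanoseconds if the server is reachable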
|
||||
# Migrated to rust in distributed.
|
||||
@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def create_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> None:
|
||||
"""Creates a database"""
|
||||
self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases",
|
||||
json={"name": name},
|
||||
)
|
||||
|
||||
# Migrated to rust in distributed.
|
||||
@trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Database:
|
||||
"""Returns a database"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{name}",
|
||||
)
|
||||
return Database(
|
||||
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def delete_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> None:
|
||||
"""Deletes a database"""
|
||||
self._make_request(
|
||||
"delete",
|
||||
f"/tenants/{tenant}/databases/{name}",
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
"""Returns a list of all databases"""
|
||||
json_databases = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases",
|
||||
params=BaseHTTPClient._clean_params(
|
||||
{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
),
|
||||
)
|
||||
databases = [
|
||||
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
|
||||
for db in json_databases
|
||||
]
|
||||
return databases
|
||||
|
||||
@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def create_tenant(self, name: str) -> None:
|
||||
self._make_request("post", "/tenants", json={"name": name})
|
||||
|
||||
@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
resp_json = self._make_request("get", "/tenants/" + name)
|
||||
return Tenant(name=resp_json["name"])
|
||||
|
||||
@trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_user_identity(self) -> UserIdentity:
|
||||
return UserIdentity(**self._make_request("get", "/auth/identity"))
|
||||
|
||||
@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def list_collections(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Sequence[CollectionModel]:
|
||||
"""Returns a list of all collections"""
|
||||
json_collections = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections",
|
||||
params=BaseHTTPClient._clean_params(
|
||||
{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
),
|
||||
)
|
||||
collection_models = [
|
||||
CollectionModel.from_json(json_collection)
|
||||
for json_collection in json_collections
|
||||
]
|
||||
|
||||
return collection_models
|
||||
|
||||
@trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def count_collections(
|
||||
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
|
||||
) -> int:
|
||||
"""Returns a count of collections"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections_count",
|
||||
)
|
||||
return cast(int, resp_json)
|
||||
|
||||
@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Creates a collection"""
|
||||
config_json = (
|
||||
create_collection_configuration_to_json(configuration, metadata)
|
||||
if configuration
|
||||
else None
|
||||
)
|
||||
serialized_schema = schema.serialize_to_json() if schema else None
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections",
|
||||
json={
|
||||
"name": name,
|
||||
"metadata": metadata,
|
||||
"configuration": config_json,
|
||||
"schema": serialized_schema,
|
||||
"get_or_create": get_or_create,
|
||||
},
|
||||
)
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
|
||||
return model
|
||||
|
||||
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Returns a collection"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{name}",
|
||||
)
|
||||
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
return model
|
||||
|
||||
@trace_method(
|
||||
"FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
@override
|
||||
def get_or_create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
return self.create_collection(
|
||||
name=name,
|
||||
metadata=metadata,
|
||||
configuration=configuration,
|
||||
schema=schema,
|
||||
get_or_create=True,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _modify(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Updates a collection"""
|
||||
self._make_request(
|
||||
"put",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{id}",
|
||||
json={
|
||||
"new_metadata": new_metadata,
|
||||
"new_name": new_name,
|
||||
"new_configuration": update_collection_configuration_to_json(
|
||||
new_configuration
|
||||
)
|
||||
if new_configuration
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._fork", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _fork(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
new_name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Forks a collection"""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
|
||||
json={"new_name": new_name},
|
||||
)
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
return model
|
||||
|
||||
@trace_method("FastAPI._search", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _search(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
searches: List[Search],
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> SearchResult:
|
||||
"""Performs hybrid search on a collection"""
|
||||
# Convert Search objects to dictionaries
|
||||
payload = {"searches": [s.to_dict() for s in searches]}
|
||||
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
# Deserialize metadatas: convert transport format to SparseVector instances
|
||||
metadata_batches = resp_json.get("metadatas", None)
|
||||
if metadata_batches is not None:
|
||||
# SearchResult has nested structure: List[Optional[List[Optional[Metadata]]]]
|
||||
resp_json["metadatas"] = [
|
||||
[
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
if metadatas is not None
|
||||
else None
|
||||
for metadatas in metadata_batches
|
||||
]
|
||||
|
||||
return SearchResult(resp_json)
|
||||
|
||||
@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Deletes a collection"""
|
||||
self._make_request(
|
||||
"delete",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{name}",
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _count(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> int:
|
||||
"""Returns the number of embeddings in the database"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
|
||||
)
|
||||
return cast(int, resp_json)
|
||||
|
||||
@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _peek(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
n: int = 10,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
return cast(
|
||||
GetResult,
|
||||
self._get(
|
||||
collection_id,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
limit=n,
|
||||
include=IncludeMetadataDocumentsEmbeddings,
|
||||
),
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _get(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocuments,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
# Servers do not support receiving "data", as that is hydrated by the client as a loadable
|
||||
filtered_include = [i for i in include if i != "data"]
|
||||
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
|
||||
json={
|
||||
"ids": ids,
|
||||
"where": where,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"where_document": where_document,
|
||||
"include": filtered_include,
|
||||
},
|
||||
)
|
||||
|
||||
# Deserialize metadatas: convert transport format to SparseVector instances
|
||||
metadatas = resp_json.get("metadatas", None)
|
||||
if metadatas is not None:
|
||||
metadatas = [
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
included=include,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _delete(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Deletes embeddings from the database"""
|
||||
self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
|
||||
json={
|
||||
"ids": ids,
|
||||
"where": where,
|
||||
"where_document": where_document,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
|
||||
def _submit_batch(
|
||||
self,
|
||||
batch: Tuple[
|
||||
IDs,
|
||||
Optional[Embeddings],
|
||||
Optional[Metadatas],
|
||||
Optional[Documents],
|
||||
Optional[URIs],
|
||||
],
|
||||
url: str,
|
||||
) -> None:
|
||||
"""
|
||||
Submits a batch of embeddings to the database
|
||||
"""
|
||||
# Serialize metadatas: convert SparseVector instances to transport format
|
||||
serialized_metadatas = None
|
||||
if batch[2] is not None:
|
||||
serialized_metadatas = [
|
||||
serialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in batch[2]
|
||||
]
|
||||
|
||||
data = {
|
||||
"ids": batch[0],
|
||||
"embeddings": optional_embeddings_to_base64_strings(batch[1])
|
||||
if self.supports_base64_encoding()
|
||||
else batch[1],
|
||||
"metadatas": serialized_metadatas,
|
||||
"documents": batch[3],
|
||||
"uris": batch[4],
|
||||
}
|
||||
|
||||
self._make_request("post", url, json=data)
|
||||
|
||||
@trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""
|
||||
Adds a batch of embeddings to the database
|
||||
- pass in column oriented data lists
|
||||
"""
|
||||
batch = (
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
|
||||
self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _update(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""
|
||||
Updates a batch of embeddings in the database
|
||||
- pass in column oriented data lists
|
||||
"""
|
||||
batch = (
|
||||
ids,
|
||||
embeddings if embeddings is not None else None,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
|
||||
self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _upsert(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""
|
||||
Upserts a batch of embeddings in the database
|
||||
- pass in column oriented data lists
|
||||
"""
|
||||
batch = (
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
|
||||
self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _query(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
query_embeddings: Embeddings,
|
||||
ids: Optional[IDs] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocumentsDistances,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> QueryResult:
|
||||
"""Gets the nearest neighbors of the provided query embeddings"""
# Servers do not support receiving "data", as that is hydrated by the client as a loadable
filtered_include = [i for i in include if i != "data"]
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
|
||||
json={
|
||||
"ids": ids,
|
||||
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
|
||||
if query_embeddings is not None
|
||||
else None,
|
||||
"n_results": n_results,
|
||||
"where": where,
|
||||
"where_document": where_document,
|
||||
"include": filtered_include,
|
||||
},
|
||||
)
|
||||
|
||||
# Deserialize metadatas: convert transport format to SparseVector instances
|
||||
metadata_batches = resp_json.get("metadatas", None)
|
||||
if metadata_batches is not None:
|
||||
metadata_batches = [
|
||||
[
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
if metadatas is not None
|
||||
else None
|
||||
for metadatas in metadata_batches
|
||||
]
|
||||
|
||||
return QueryResult(
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
included=include,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def reset(self) -> bool:
|
||||
"""Resets the database"""
|
||||
resp_json = self._make_request("post", "/reset")
|
||||
return cast(bool, resp_json)
|
||||
|
||||
@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_version(self) -> str:
|
||||
"""Returns the version of the server"""
|
||||
resp_json = self._make_request("get", "/version")
|
||||
return cast(str, resp_json)
|
||||
|
||||
@override
|
||||
def get_settings(self) -> Settings:
|
||||
"""Returns the settings of the client"""
|
||||
return self._settings
|
||||
|
||||
@trace_method("FastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION)
|
||||
def get_pre_flight_checks(self) -> Any:
|
||||
if self.pre_flight_checks is None:
|
||||
resp_json = self._make_request("get", "/pre-flight-checks")
|
||||
self.pre_flight_checks = resp_json
|
||||
return self.pre_flight_checks
|
||||
|
||||
@trace_method(
|
||||
"FastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
def supports_base64_encoding(self) -> bool:
|
||||
pre_flight_checks = self.get_pre_flight_checks()
|
||||
b64_encoding_enabled = cast(
|
||||
bool, pre_flight_checks.get("supports_base64_encoding", False)
|
||||
)
|
||||
return b64_encoding_enabled
|
||||
|
||||
@trace_method("FastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_max_batch_size(self) -> int:
|
||||
pre_flight_checks = self.get_pre_flight_checks()
|
||||
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
|
||||
return max_batch_size
|
||||
|
||||
@trace_method("FastAPI.attach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attach a function to a collection."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/attach",
|
||||
json={
|
||||
"name": name,
|
||||
"function_id": function_id,
|
||||
"output_collection": output_collection,
|
||||
"params": params,
|
||||
},
|
||||
)
|
||||
|
||||
return AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(resp_json["attached_function"]["id"]),
|
||||
name=resp_json["attached_function"]["name"],
|
||||
function_id=resp_json["attached_function"]["function_id"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
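# Illustrative sketch (not part of the original source): attaching and later
# detaching a function through the methods above. "api" and "collection" are
# hypothetical; "record_counter" follows the example id used in the
# AttachedFunction docstring.
#
# attached = api.attach_function(
#     function_id="record_counter",
#     name="my_counter",
#     input_collection_id=collection.id,
#     output_collection="record_counts",
# )
# attached.detach(delete_output_collection=True)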
|
||||
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""Detach a function and prevent any further runs."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
|
||||
json={
|
||||
"delete_output": delete_output,
|
||||
},
|
||||
)
|
||||
return cast(bool, resp_json["success"])
|
||||
@@ -0,0 +1,492 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast
|
||||
|
||||
from chromadb.api.types import (
|
||||
URI,
|
||||
CollectionMetadata,
|
||||
Embedding,
|
||||
PyEmbedding,
|
||||
Include,
|
||||
Metadata,
|
||||
Document,
|
||||
Image,
|
||||
Where,
|
||||
IDs,
|
||||
GetResult,
|
||||
QueryResult,
|
||||
ID,
|
||||
OneOrMany,
|
||||
WhereDocument,
|
||||
SearchResult,
|
||||
maybe_cast_one_to_many,
|
||||
)
|
||||
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
|
||||
from chromadb.execution.expression.plan import Search
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api import AsyncServerAPI # noqa: F401
|
||||
|
||||
|
||||
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
|
||||
async def add(
|
||||
self,
|
||||
ids: OneOrMany[ID],
|
||||
embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
metadatas: Optional[OneOrMany[Metadata]] = None,
|
||||
documents: Optional[OneOrMany[Document]] = None,
|
||||
images: Optional[OneOrMany[Image]] = None,
|
||||
uris: Optional[OneOrMany[URI]] = None,
|
||||
) -> None:
|
||||
"""Add embeddings to the data store.
|
||||
Args:
|
||||
ids: The ids of the embeddings you wish to add
|
||||
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
|
||||
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
|
||||
documents: The documents to associate with the embeddings. Optional.
|
||||
images: The images to associate with the embeddings. Optional.
|
||||
uris: The uris of the images to associate with the embeddings. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: If you don't provide either embeddings or documents
|
||||
ValueError: If the length of ids, embeddings, metadatas, or documents don't match
|
||||
ValueError: If you don't provide an embedding function and don't provide embeddings
|
||||
ValueError: If you provide both embeddings and documents
|
||||
ValueError: If you provide an id that already exists
|
||||
|
||||
"""
|
||||
add_request = self._validate_and_prepare_add_request(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
images=images,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
await self._client._add(
|
||||
collection_id=self.id,
|
||||
ids=add_request["ids"],
|
||||
embeddings=add_request["embeddings"],
|
||||
metadatas=add_request["metadatas"],
|
||||
documents=add_request["documents"],
|
||||
uris=add_request["uris"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
async def count(self) -> int:
|
||||
"""The total number of embeddings added to the database
|
||||
|
||||
Returns:
|
||||
int: The total number of embeddings added to the database
|
||||
|
||||
"""
|
||||
return await self._client._count(
|
||||
collection_id=self.id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
ids: Optional[OneOrMany[ID]] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = ["metadatas", "documents"],
|
||||
) -> GetResult:
|
||||
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
|
||||
all embeddings up to limit starting at offset.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to get. Optional.
|
||||
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
|
||||
limit: The number of documents to return. Optional.
|
||||
offset: The offset to start returning results from. Useful for paging results with limit. Optional.
|
||||
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
|
||||
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional.
|
||||
|
||||
Returns:
|
||||
GetResult: A GetResult object containing the results.
|
||||
|
||||
"""
|
||||
get_request = self._validate_and_prepare_get_request(
|
||||
ids=ids,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
)
|
||||
|
||||
get_results = await self._client._get(
|
||||
collection_id=self.id,
|
||||
ids=get_request["ids"],
|
||||
where=get_request["where"],
|
||||
where_document=get_request["where_document"],
|
||||
include=get_request["include"],
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
return self._transform_get_response(
|
||||
response=get_results, include=get_request["include"]
|
||||
)
|
||||
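# Illustrative sketch (not part of the original source): fetching a filtered
# page of results from an AsyncCollection. The filter values are hypothetical.
#
# results = await collection.get(
#     where={"color": "red"},
#     limit=20,
#     offset=0,
#     include=["metadatas", "documents"],
# )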
|
||||
async def peek(self, limit: int = 10) -> GetResult:
|
||||
"""Get the first few results in the database up to limit
|
||||
|
||||
Args:
|
||||
limit: The number of results to return.
|
||||
|
||||
Returns:
|
||||
GetResult: A GetResult object containing the results.
|
||||
"""
|
||||
return self._transform_peek_response(
|
||||
await self._client._peek(
|
||||
collection_id=self.id,
|
||||
n=limit,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
)
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query_embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
query_texts: Optional[OneOrMany[Document]] = None,
|
||||
query_images: Optional[OneOrMany[Image]] = None,
|
||||
query_uris: Optional[OneOrMany[URI]] = None,
|
||||
ids: Optional[OneOrMany[ID]] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = [
|
||||
"metadatas",
|
||||
"documents",
|
||||
"distances",
|
||||
],
|
||||
) -> QueryResult:
|
||||
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
|
||||
|
||||
Args:
|
||||
query_embeddings: The embeddings to get the closest neighbors of. Optional.
query_texts: The document texts to get the closest neighbors of. Optional.
query_images: The images to get the closest neighbors of. Optional.
|
||||
ids: A subset of ids to search within. Optional.
|
||||
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
|
||||
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
|
||||
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
|
||||
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.
|
||||
|
||||
Returns:
|
||||
QueryResult: A QueryResult object containing the results.
|
||||
|
||||
Raises:
|
||||
ValueError: If you don't provide either query_embeddings, query_texts, or query_images
|
||||
ValueError: If you provide both query_embeddings and query_texts
|
||||
ValueError: If you provide both query_embeddings and query_images
|
||||
ValueError: If you provide both query_texts and query_images
|
||||
|
||||
"""
|
||||
|
||||
query_request = self._validate_and_prepare_query_request(
|
||||
query_embeddings=query_embeddings,
|
||||
query_texts=query_texts,
|
||||
query_images=query_images,
|
||||
query_uris=query_uris,
|
||||
ids=ids,
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
)
|
||||
|
||||
query_results = await self._client._query(
|
||||
collection_id=self.id,
|
||||
ids=query_request["ids"],
|
||||
query_embeddings=query_request["embeddings"],
|
||||
n_results=query_request["n_results"],
|
||||
where=query_request["where"],
|
||||
where_document=query_request["where_document"],
|
||||
include=query_request["include"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
return self._transform_query_response(
|
||||
response=query_results, include=query_request["include"]
|
||||
)
|
||||
|
||||
async def modify(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
) -> None:
|
||||
"""Modify the collection name or metadata
|
||||
|
||||
Args:
|
||||
name: The updated name for the collection. Optional.
|
||||
metadata: The updated metadata for the collection. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
self._validate_modify_request(metadata)
|
||||
|
||||
# Note there is a race condition here where the metadata can be updated
|
||||
# but another thread sees the cached local metadata.
|
||||
# TODO: fixme
|
||||
await self._client._modify(
|
||||
id=self.id,
|
||||
new_name=name,
|
||||
new_metadata=metadata,
|
||||
new_configuration=configuration,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
self._update_model_after_modify_success(name, metadata, configuration)
|
||||
|
||||
async def fork(
|
||||
self,
|
||||
new_name: str,
|
||||
) -> "AsyncCollection":
|
||||
"""Fork the current collection under a new name. The returning collection should contain identical data to the current collection.
|
||||
This is an experimental API that only works for Hosted Chroma for now.
|
||||
|
||||
Args:
|
||||
new_name: The name of the new collection.
|
||||
|
||||
Returns:
|
||||
Collection: A new collection with the specified name and containing identical data to the current collection.
|
||||
"""
|
||||
model = await self._client._fork(
|
||||
collection_id=self.id,
|
||||
new_name=new_name,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
return AsyncCollection(
|
||||
client=self._client,
|
||||
model=model,
|
||||
embedding_function=self._embedding_function,
|
||||
data_loader=self._data_loader,
|
||||
)
|
||||
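# Illustrative sketch (not part of the original source): forking a collection
# under a new name (Hosted Chroma only). The name is a hypothetical example.
#
# copy = await collection.fork(new_name="my_collection_copy")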
|
||||
async def search(
|
||||
self,
|
||||
searches: OneOrMany[Search],
|
||||
) -> SearchResult:
|
||||
"""Perform hybrid search on the collection.
|
||||
This is an experimental API that only works for Hosted Chroma for now.
|
||||
|
||||
Args:
|
||||
searches: A single Search object or a list of Search objects, each containing:
|
||||
- where: Where expression for filtering
|
||||
- rank: Ranking expression for hybrid search (defaults to Val(0.0))
|
||||
- limit: Limit configuration for pagination (defaults to no limit)
|
||||
- select: Select configuration for keys to return (defaults to empty)
|
||||
|
||||
Returns:
|
||||
SearchResult: Column-major format response with:
|
||||
- ids: List of result IDs for each search payload
|
||||
- documents: Optional documents for each payload
|
||||
- embeddings: Optional embeddings for each payload
|
||||
- metadatas: Optional metadata for each payload
|
||||
- scores: Optional scores for each payload
|
||||
- select: List of selected keys for each payload
|
||||
|
||||
Raises:
|
||||
NotImplementedError: For local/segment API implementations
|
||||
|
||||
Examples:
|
||||
# Using builder pattern with Key constants
|
||||
from chromadb.execution.expression import (
|
||||
Search, Key, K, Knn, Val
|
||||
)
|
||||
|
||||
# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
|
||||
search = (Search()
|
||||
.where((K("category") == "science") & (K("score") > 0.5))
|
||||
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
|
||||
.limit(10, offset=0)
|
||||
.select(K.DOCUMENT, K.SCORE, "title"))
|
||||
|
||||
# Direct construction
|
||||
from chromadb.execution.expression import (
|
||||
Search, Eq, And, Gt, Knn, Limit, Select, Key
|
||||
)
|
||||
|
||||
search = Search(
|
||||
where=And([Eq("category", "science"), Gt("score", 0.5)]),
|
||||
rank=Knn(query=[0.1, 0.2, 0.3]),
|
||||
limit=Limit(offset=0, limit=10),
|
||||
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
|
||||
)
|
||||
|
||||
# Single search
|
||||
result = await collection.search(search)
|
||||
|
||||
# Multiple searches at once
|
||||
searches = [
|
||||
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
|
||||
Search().where(K("type") == "paper").rank(Knn(query=[0.3, 0.4]))
|
||||
]
|
||||
results = await collection.search(searches)
|
||||
"""
|
||||
# Convert single search to list for consistent handling
|
||||
searches_list = maybe_cast_one_to_many(searches)
|
||||
if searches_list is None:
|
||||
searches_list = []
|
||||
|
||||
# Embed any string queries in Knn objects
|
||||
embedded_searches = [
|
||||
self._embed_search_string_queries(search) for search in searches_list
|
||||
]
|
||||
|
||||
return await self._client._search(
|
||||
collection_id=self.id,
|
||||
searches=cast(List[Search], embedded_searches),
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
async def update(
|
||||
self,
|
||||
ids: OneOrMany[ID],
|
||||
embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
metadatas: Optional[OneOrMany[Metadata]] = None,
|
||||
documents: Optional[OneOrMany[Document]] = None,
|
||||
images: Optional[OneOrMany[Image]] = None,
|
||||
uris: Optional[OneOrMany[URI]] = None,
|
||||
) -> None:
|
||||
"""Update the embeddings, metadatas or documents for provided ids.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to update
|
||||
embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
|
||||
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
|
||||
documents: The documents to associate with the embeddings. Optional.
|
||||
images: The images to associate with the embeddings. Optional.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
update_request = self._validate_and_prepare_update_request(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
images=images,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
await self._client._update(
|
||||
collection_id=self.id,
|
||||
ids=update_request["ids"],
|
||||
embeddings=update_request["embeddings"],
|
||||
metadatas=update_request["metadatas"],
|
||||
documents=update_request["documents"],
|
||||
uris=update_request["uris"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
ids: OneOrMany[ID],
|
||||
embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
metadatas: Optional[OneOrMany[Metadata]] = None,
|
||||
documents: Optional[OneOrMany[Document]] = None,
|
||||
images: Optional[OneOrMany[Image]] = None,
|
||||
uris: Optional[OneOrMany[URI]] = None,
|
||||
) -> None:
|
||||
"""Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to update
|
||||
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
|
||||
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
|
||||
documents: The documents to associate with the embeddings. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
upsert_request = self._validate_and_prepare_upsert_request(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
images=images,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
await self._client._upsert(
|
||||
collection_id=self.id,
|
||||
ids=upsert_request["ids"],
|
||||
embeddings=upsert_request["embeddings"],
|
||||
metadatas=upsert_request["metadatas"],
|
||||
documents=upsert_request["documents"],
|
||||
uris=upsert_request["uris"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
) -> None:
|
||||
"""Delete the embeddings based on ids and/or a where filter
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to delete
|
||||
where: A Where type dict used to filter the deletion by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
|
||||
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. `{"$contains": "hello"}`. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: If you don't provide either ids, where, or where_document
|
||||
"""
|
||||
delete_request = self._validate_and_prepare_delete_request(
|
||||
ids, where, where_document
|
||||
)
|
||||
|
||||
await self._client._delete(
|
||||
collection_id=self.id,
|
||||
ids=delete_request["ids"],
|
||||
where=delete_request["where"],
|
||||
where_document=delete_request["where_document"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
from typing import TYPE_CHECKING, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api import ServerAPI # noqa: F401
|
||||
|
||||
|
||||
class AttachedFunction:
|
||||
"""Represents a function attached to a collection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: "ServerAPI",
|
||||
id: UUID,
|
||||
name: str,
|
||||
function_id: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]],
|
||||
tenant: str,
|
||||
database: str,
|
||||
):
|
||||
"""Initialize an AttachedFunction.
|
||||
|
||||
Args:
|
||||
client: The API client
|
||||
id: Unique identifier for this attached function
|
||||
name: Name of this attached function instance
|
||||
function_id: The function identifier (e.g., "record_counter")
|
||||
input_collection_id: ID of the input collection
|
||||
output_collection: Name of the output collection
|
||||
params: Function-specific parameters
|
||||
tenant: The tenant name
|
||||
database: The database name
|
||||
"""
|
||||
self._client = client
|
||||
self._id = id
|
||||
self._name = name
|
||||
self._function_id = function_id
|
||||
self._input_collection_id = input_collection_id
|
||||
self._output_collection = output_collection
|
||||
self._params = params
|
||||
self._tenant = tenant
|
||||
self._database = database
|
||||
|
||||
@property
|
||||
def id(self) -> UUID:
|
||||
"""The unique identifier of this attached function."""
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of this attached function instance."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def function_id(self) -> str:
|
||||
"""The function identifier."""
|
||||
return self._function_id
|
||||
|
||||
@property
|
||||
def input_collection_id(self) -> UUID:
|
||||
"""The ID of the input collection."""
|
||||
return self._input_collection_id
|
||||
|
||||
@property
|
||||
def output_collection(self) -> str:
|
||||
"""The name of the output collection."""
|
||||
return self._output_collection
|
||||
|
||||
@property
|
||||
def params(self) -> Optional[Dict[str, Any]]:
|
||||
"""The function parameters."""
|
||||
return self._params
|
||||
|
||||
def detach(self, delete_output_collection: bool = False) -> bool:
|
||||
"""Detach this function and prevent any further runs.
|
||||
|
||||
Args:
|
||||
delete_output_collection: Whether to also delete the output collection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
|
||||
Example:
|
||||
>>> success = attached_fn.detach(delete_output_collection=True)
|
||||
"""
|
||||
return self._client.detach_function(
|
||||
attached_function_id=self._id,
|
||||
delete_output=delete_output_collection,
|
||||
tenant=self._tenant,
|
||||
database=self._database,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AttachedFunction(id={self._id}, name='{self._name}', "
|
||||
f"function_id='{self._function_id}', "
|
||||
f"input_collection_id={self._input_collection_id}, "
|
||||
f"output_collection='{self._output_collection}')"
|
||||
)
|
||||
@@ -0,0 +1,535 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
|
||||
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.api.types import (
|
||||
URI,
|
||||
CollectionMetadata,
|
||||
Embedding,
|
||||
PyEmbedding,
|
||||
Include,
|
||||
Metadata,
|
||||
Document,
|
||||
Image,
|
||||
Where,
|
||||
IDs,
|
||||
GetResult,
|
||||
QueryResult,
|
||||
ID,
|
||||
OneOrMany,
|
||||
WhereDocument,
|
||||
SearchResult,
|
||||
maybe_cast_one_to_many,
|
||||
)
|
||||
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
|
||||
from chromadb.execution.expression.plan import Search
|
||||
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api import ServerAPI # noqa: F401
|
||||
|
||||
|
||||
class Collection(CollectionCommon["ServerAPI"]):
|
||||
def count(self) -> int:
|
||||
"""The total number of embeddings added to the database
|
||||
|
||||
Returns:
|
||||
int: The total number of embeddings added to the database
|
||||
|
||||
"""
|
||||
return self._client._count(
|
||||
collection_id=self.id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
ids: OneOrMany[ID],
|
||||
embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
metadatas: Optional[OneOrMany[Metadata]] = None,
|
||||
documents: Optional[OneOrMany[Document]] = None,
|
||||
images: Optional[OneOrMany[Image]] = None,
|
||||
uris: Optional[OneOrMany[URI]] = None,
|
||||
) -> None:
|
||||
"""Add embeddings to the data store.
|
||||
Args:
|
||||
ids: The ids of the embeddings you wish to add
|
||||
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
|
||||
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
|
||||
documents: The documents to associate with the embeddings. Optional.
|
||||
images: The images to associate with the embeddings. Optional.
|
||||
uris: The uris of the images to associate with the embeddings. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: If you don't provide either embeddings or documents
|
||||
ValueError: If the length of ids, embeddings, metadatas, or documents don't match
|
||||
ValueError: If you don't provide an embedding function and don't provide embeddings
|
||||
ValueError: If you provide both embeddings and documents
|
||||
ValueError: If you provide an id that already exists
|
||||
|
||||
"""
|
||||
|
||||
add_request = self._validate_and_prepare_add_request(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
images=images,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
self._client._add(
|
||||
collection_id=self.id,
|
||||
ids=add_request["ids"],
|
||||
embeddings=add_request["embeddings"],
|
||||
metadatas=add_request["metadatas"],
|
||||
documents=add_request["documents"],
|
||||
uris=add_request["uris"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
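# Illustrative usage sketch (not part of the original file): assuming a
# collection created via client.get_or_create_collection("docs") with a
# default embedding function, add() is typically called like:
#
#   collection.add(
#       ids=["doc-1", "doc-2"],
#       documents=["Chroma is a vector database.", "It supports metadata filters."],
#       metadatas=[{"topic": "intro"}, {"topic": "filtering"}],
#   )
#
# Embeddings are computed from the documents because none are passed explicitly.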
|
||||
|
||||
def get(
|
||||
self,
|
||||
ids: Optional[OneOrMany[ID]] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = ["metadatas", "documents"],
|
||||
) -> GetResult:
|
||||
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
|
||||
all embeddings up to limit starting at offset.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to get. Optional.
|
||||
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
|
||||
limit: The number of documents to return. Optional.
|
||||
offset: The offset to start returning results from. Useful for paging results with limit. Optional.
|
||||
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
|
||||
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional.
|
||||
|
||||
Returns:
|
||||
GetResult: A GetResult object containing the results.
|
||||
|
||||
"""
|
||||
get_request = self._validate_and_prepare_get_request(
|
||||
ids=ids,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
)
|
||||
|
||||
get_results = self._client._get(
|
||||
collection_id=self.id,
|
||||
ids=get_request["ids"],
|
||||
where=get_request["where"],
|
||||
where_document=get_request["where_document"],
|
||||
include=get_request["include"],
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
return self._transform_get_response(
|
||||
response=get_results, include=get_request["include"]
|
||||
)
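# Illustrative usage sketch (not part of the original file): fetching records by
# a metadata filter with embeddings included; the collection and field values
# are assumptions for the example only.
#
#   result = collection.get(
#       where={"topic": "intro"},
#       limit=5,
#       include=["metadatas", "documents", "embeddings"],
#   )
#   print(result["ids"], result["documents"])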
|
||||
|
||||
def peek(self, limit: int = 10) -> GetResult:
|
||||
"""Get the first few results in the database up to limit
|
||||
|
||||
Args:
|
||||
limit: The number of results to return.
|
||||
|
||||
Returns:
|
||||
GetResult: A GetResult object containing the results.
|
||||
"""
|
||||
return self._transform_peek_response(
|
||||
self._client._peek(
|
||||
collection_id=self.id,
|
||||
n=limit,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
query_embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
query_texts: Optional[OneOrMany[Document]] = None,
|
||||
query_images: Optional[OneOrMany[Image]] = None,
|
||||
query_uris: Optional[OneOrMany[URI]] = None,
|
||||
ids: Optional[OneOrMany[ID]] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = [
|
||||
"metadatas",
|
||||
"documents",
|
||||
"distances",
|
||||
],
|
||||
) -> QueryResult:
|
||||
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
|
||||
|
||||
Args:
|
||||
query_embeddings: The embeddings to get the closest neighbors of. Optional.
|
||||
query_texts: The document texts to get the closest neighbors of. Optional.
|
||||
query_images: The images to get the closest neighbors of. Optional.
|
||||
query_uris: The URIs to be used with data loader. Optional.
|
||||
ids: A subset of ids to search within. Optional.
|
||||
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
|
||||
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
|
||||
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
|
||||
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.
|
||||
|
||||
Returns:
|
||||
QueryResult: A QueryResult object containing the results.
|
||||
|
||||
Raises:
|
||||
ValueError: If you don't provide either query_embeddings, query_texts, or query_images
|
||||
ValueError: If you provide both query_embeddings and query_texts
|
||||
ValueError: If you provide both query_embeddings and query_images
|
||||
ValueError: If you provide both query_texts and query_images
|
||||
|
||||
"""
|
||||
|
||||
query_request = self._validate_and_prepare_query_request(
|
||||
query_embeddings=query_embeddings,
|
||||
query_texts=query_texts,
|
||||
query_images=query_images,
|
||||
query_uris=query_uris,
|
||||
ids=ids,
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
include=include,
|
||||
)
|
||||
|
||||
query_results = self._client._query(
|
||||
collection_id=self.id,
|
||||
ids=query_request["ids"],
|
||||
query_embeddings=query_request["embeddings"],
|
||||
n_results=query_request["n_results"],
|
||||
where=query_request["where"],
|
||||
where_document=query_request["where_document"],
|
||||
include=query_request["include"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
return self._transform_query_response(
|
||||
response=query_results, include=query_request["include"]
|
||||
)
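# Illustrative usage sketch (not part of the original file): a text query with a
# metadata filter; the query string and filter values are assumptions.
#
#   result = collection.query(
#       query_texts=["how do metadata filters work?"],
#       n_results=3,
#       where={"topic": "filtering"},
#       include=["documents", "distances"],
#   )
#   # result["ids"][0] holds the ids of the 3 nearest neighbors for the first query.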
|
||||
|
||||
def modify(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
) -> None:
|
||||
"""Modify the collection name or metadata
|
||||
|
||||
Args:
|
||||
name: The updated name for the collection. Optional.
|
||||
metadata: The updated metadata for the collection. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
self._validate_modify_request(metadata)
|
||||
|
||||
# Note there is a race condition here where the metadata can be updated
|
||||
# but another thread sees the cached local metadata.
|
||||
# TODO: fixme
|
||||
self._client._modify(
|
||||
id=self.id,
|
||||
new_name=name,
|
||||
new_metadata=metadata,
|
||||
new_configuration=configuration,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
self._update_model_after_modify_success(name, metadata, configuration)
|
||||
|
||||
def fork(
|
||||
self,
|
||||
new_name: str,
|
||||
) -> "Collection":
|
||||
"""Fork the current collection under a new name. The returning collection should contain identical data to the current collection.
|
||||
This is an experimental API that only works for Hosted Chroma for now.
|
||||
|
||||
Args:
|
||||
new_name: The name of the new collection.
|
||||
|
||||
Returns:
|
||||
Collection: A new collection with the specified name and containing identical data to the current collection.
|
||||
"""
|
||||
model = self._client._fork(
|
||||
collection_id=self.id,
|
||||
new_name=new_name,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
return Collection(
|
||||
client=self._client,
|
||||
model=model,
|
||||
embedding_function=self._embedding_function,
|
||||
data_loader=self._data_loader,
|
||||
)
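# Illustrative usage sketch (not part of the original file): forking requires a
# Hosted Chroma deployment; the collection names are assumptions.
#
#   snapshot = collection.fork(new_name="docs-2024-snapshot")
#   assert snapshot.count() == collection.count()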
|
||||
|
||||
def search(
|
||||
self,
|
||||
searches: OneOrMany[Search],
|
||||
) -> SearchResult:
|
||||
"""Perform hybrid search on the collection.
|
||||
This is an experimental API that only works for Hosted Chroma for now.
|
||||
|
||||
Args:
|
||||
searches: A single Search object or a list of Search objects, each containing:
|
||||
- where: Where expression for filtering
|
||||
- rank: Ranking expression for hybrid search (defaults to Val(0.0))
|
||||
- limit: Limit configuration for pagination (defaults to no limit)
|
||||
- select: Select configuration for keys to return (defaults to empty)
|
||||
|
||||
Returns:
|
||||
SearchResult: Column-major format response with:
|
||||
- ids: List of result IDs for each search payload
|
||||
- documents: Optional documents for each payload
|
||||
- embeddings: Optional embeddings for each payload
|
||||
- metadatas: Optional metadata for each payload
|
||||
- scores: Optional scores for each payload
|
||||
- select: List of selected keys for each payload
|
||||
|
||||
Raises:
|
||||
NotImplementedError: For local/segment API implementations
|
||||
|
||||
Examples:
|
||||
# Using builder pattern with Key constants
|
||||
from chromadb.execution.expression import (
|
||||
Search, Key, K, Knn, Val
|
||||
)
|
||||
|
||||
# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
|
||||
search = (Search()
|
||||
.where((K("category") == "science") & (K("score") > 0.5))
|
||||
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
|
||||
.limit(10, offset=0)
|
||||
.select(K.DOCUMENT, K.SCORE, "title"))
|
||||
|
||||
# Direct construction
|
||||
from chromadb.execution.expression import (
|
||||
Search, Eq, And, Gt, Knn, Limit, Select, Key
|
||||
)
|
||||
|
||||
search = Search(
|
||||
where=And([Eq("category", "science"), Gt("score", 0.5)]),
|
||||
rank=Knn(query=[0.1, 0.2, 0.3]),
|
||||
limit=Limit(offset=0, limit=10),
|
||||
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
|
||||
)
|
||||
|
||||
# Single search
|
||||
result = collection.search(search)
|
||||
|
||||
# Multiple searches at once
|
||||
searches = [
|
||||
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
|
||||
Search().where(K("type") == "paper").rank(Knn(query=[0.3, 0.4]))
|
||||
]
|
||||
results = collection.search(searches)
|
||||
"""
|
||||
# Convert single search to list for consistent handling
|
||||
searches_list = maybe_cast_one_to_many(searches)
|
||||
if searches_list is None:
|
||||
searches_list = []
|
||||
|
||||
# Embed any string queries in Knn objects
|
||||
embedded_searches = [
|
||||
self._embed_search_string_queries(search) for search in searches_list
|
||||
]
|
||||
|
||||
return self._client._search(
|
||||
collection_id=self.id,
|
||||
searches=cast(List[Search], embedded_searches),
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
ids: OneOrMany[ID],
|
||||
embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
metadatas: Optional[OneOrMany[Metadata]] = None,
|
||||
documents: Optional[OneOrMany[Document]] = None,
|
||||
images: Optional[OneOrMany[Image]] = None,
|
||||
uris: Optional[OneOrMany[URI]] = None,
|
||||
) -> None:
|
||||
"""Update the embeddings, metadatas or documents for provided ids.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to update
|
||||
embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
|
||||
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
|
||||
documents: The documents to associate with the embeddings. Optional.
|
||||
images: The images to associate with the embeddings. Optional.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
update_request = self._validate_and_prepare_update_request(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
images=images,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
self._client._update(
|
||||
collection_id=self.id,
|
||||
ids=update_request["ids"],
|
||||
embeddings=update_request["embeddings"],
|
||||
metadatas=update_request["metadatas"],
|
||||
documents=update_request["documents"],
|
||||
uris=update_request["uris"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def upsert(
|
||||
self,
|
||||
ids: OneOrMany[ID],
|
||||
embeddings: Optional[
|
||||
Union[
|
||||
OneOrMany[Embedding],
|
||||
OneOrMany[PyEmbedding],
|
||||
]
|
||||
] = None,
|
||||
metadatas: Optional[OneOrMany[Metadata]] = None,
|
||||
documents: Optional[OneOrMany[Document]] = None,
|
||||
images: Optional[OneOrMany[Image]] = None,
|
||||
uris: Optional[OneOrMany[URI]] = None,
|
||||
) -> None:
|
||||
"""Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to update
|
||||
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
|
||||
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
|
||||
documents: The documents to associate with the embeddings. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
upsert_request = self._validate_and_prepare_upsert_request(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
images=images,
|
||||
uris=uris,
|
||||
)
|
||||
|
||||
self._client._upsert(
|
||||
collection_id=self.id,
|
||||
ids=upsert_request["ids"],
|
||||
embeddings=upsert_request["embeddings"],
|
||||
metadatas=upsert_request["metadatas"],
|
||||
documents=upsert_request["documents"],
|
||||
uris=upsert_request["uris"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
) -> None:
|
||||
"""Delete the embeddings based on ids and/or a where filter
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to delete
|
||||
where: A Where type dict used to filter the deletion by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
|
||||
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. `{"$contains": "hello"}`. Optional.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: If you don't provide either ids, where, or where_document
|
||||
"""
|
||||
delete_request = self._validate_and_prepare_delete_request(
|
||||
ids, where, where_document
|
||||
)
|
||||
|
||||
self._client._delete(
|
||||
collection_id=self.id,
|
||||
ids=delete_request["ids"],
|
||||
where=delete_request["where"],
|
||||
where_document=delete_request["where_document"],
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
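# Illustrative usage sketch (not part of the original file): deleting by id and
# by metadata filter; the ids and filter values are assumptions.
#
#   collection.delete(ids=["doc-1"])
#   collection.delete(where={"topic": "obsolete"})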
|
||||
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
name: str,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> "AttachedFunction":
|
||||
"""Attach a function to this collection.
|
||||
|
||||
Args:
|
||||
function_id: Built-in function identifier (e.g., "record_counter")
|
||||
name: Unique name for this attached function
|
||||
output_collection: Name of the collection where function output will be stored
|
||||
params: Optional dictionary with function-specific parameters
|
||||
|
||||
Returns:
|
||||
AttachedFunction: Object representing the attached function
|
||||
|
||||
Example:
|
||||
>>> attached_fn = collection.attach_function(
|
||||
... function_id="record_counter",
|
||||
... name="mycoll_stats_fn",
|
||||
... output_collection="mycoll_stats",
|
||||
... params={"threshold": 100}
|
||||
... )
|
||||
"""
|
||||
return self._client.attach_function(
|
||||
function_id=function_id,
|
||||
name=name,
|
||||
input_collection_id=self.id,
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
@@ -0,0 +1,644 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from chromadb import (
|
||||
CollectionMetadata,
|
||||
Embeddings,
|
||||
GetResult,
|
||||
IDs,
|
||||
Where,
|
||||
WhereDocument,
|
||||
Include,
|
||||
Documents,
|
||||
Metadatas,
|
||||
QueryResult,
|
||||
URIs,
|
||||
)
|
||||
from chromadb.api import ServerAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
UpdateCollectionConfiguration,
|
||||
create_collection_configuration_to_json_str,
|
||||
update_collection_configuration_to_json_str,
|
||||
)
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
|
||||
from chromadb.telemetry.product import ProductTelemetryClient
|
||||
from chromadb.telemetry.product.events import (
|
||||
CollectionAddEvent,
|
||||
CollectionDeleteEvent,
|
||||
CollectionGetEvent,
|
||||
CollectionUpdateEvent,
|
||||
CollectionQueryEvent,
|
||||
ClientCreateCollectionEvent,
|
||||
)
|
||||
|
||||
from chromadb.api.types import (
|
||||
IncludeMetadataDocuments,
|
||||
IncludeMetadataDocumentsDistances,
|
||||
IncludeMetadataDocumentsEmbeddings,
|
||||
Schema,
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
# TODO(hammadb): Unify imports across types vs root __init__.py
|
||||
from chromadb.types import Database, Tenant, Collection as CollectionModel
|
||||
from chromadb.execution.expression.plan import Search
|
||||
import chromadb_rust_bindings
|
||||
|
||||
|
||||
from typing import Optional, Sequence, List, Dict, Any
|
||||
from overrides import override
|
||||
from uuid import UUID
|
||||
import json
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import resource
|
||||
elif platform.system() == "Windows":
|
||||
import ctypes
|
||||
|
||||
|
||||
# RustBindingsAPI is an implementation of ServerAPI which shims
|
||||
# the Rust bindings to the Python API, providing a full implementation
|
||||
# of the API. It could be that bindings was a direct implementation of
|
||||
# ServerAPI, but in order to prevent propagating the bindings types
|
||||
# into the Python API, we have to shim it here so we can convert into
|
||||
# the legacy Python types.
|
||||
# TODO(hammadb): Propagate the types from the bindings into the Python API
|
||||
# and remove the python-level types entirely.
|
||||
class RustBindingsAPI(ServerAPI):
|
||||
bindings: chromadb_rust_bindings.Bindings
|
||||
hnsw_cache_size: int
|
||||
product_telemetry_client: ProductTelemetryClient
|
||||
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
self.product_telemetry_client = self.require(ProductTelemetryClient)
|
||||
|
||||
if platform.system() != "Windows":
|
||||
max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
|
||||
else:
|
||||
max_file_handles = ctypes.windll.msvcrt._getmaxstdio() # type: ignore
|
||||
self.hnsw_cache_size = (
|
||||
max_file_handles
|
||||
# This is integer division in Python 3, and not a comment.
|
||||
# Each HNSW index has 4 data files and 1 metadata file
|
||||
// 5
|
||||
)
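# Worked example (illustrative, not part of the original file): with a typical
# soft limit of 1024 open file handles, the cache is sized to 1024 // 5 = 204
# HNSW indices, since each index keeps roughly 5 files open.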
|
||||
|
||||
@override
|
||||
def start(self) -> None:
|
||||
# Construct the SqliteConfig
|
||||
# TODO: We should add a "config converter"
|
||||
if self._system.settings.require("is_persistent"):
|
||||
persist_path = self._system.settings.require("persist_directory")
|
||||
sqlite_persist_path = persist_path + "/chroma.sqlite3"
|
||||
else:
|
||||
persist_path = None
|
||||
sqlite_persist_path = None
|
||||
hash_type = self._system.settings.require("migrations_hash_algorithm")
|
||||
hash_type_bindings = (
|
||||
chromadb_rust_bindings.MigrationHash.MD5
|
||||
if hash_type == "md5"
|
||||
else chromadb_rust_bindings.MigrationHash.SHA256
|
||||
)
|
||||
migration_mode = self._system.settings.require("migrations")
|
||||
migration_mode_bindings = (
|
||||
chromadb_rust_bindings.MigrationMode.Apply
|
||||
if migration_mode == "apply"
|
||||
else chromadb_rust_bindings.MigrationMode.Validate
|
||||
)
|
||||
sqlite_config = chromadb_rust_bindings.SqliteDBConfig(
|
||||
hash_type=hash_type_bindings,
|
||||
migration_mode=migration_mode_bindings,
|
||||
url=sqlite_persist_path,
|
||||
)
|
||||
|
||||
self.bindings = chromadb_rust_bindings.Bindings(
|
||||
allow_reset=self._system.settings.require("allow_reset"),
|
||||
sqlite_db_config=sqlite_config,
|
||||
persist_path=persist_path,
|
||||
hnsw_cache_size=self.hnsw_cache_size,
|
||||
)
|
||||
|
||||
@override
|
||||
def stop(self) -> None:
|
||||
del self.bindings
|
||||
|
||||
# ////////////////////////////// Admin API //////////////////////////////
|
||||
|
||||
@override
|
||||
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
return self.bindings.create_database(name, tenant)
|
||||
|
||||
@override
|
||||
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
database = self.bindings.get_database(name, tenant)
|
||||
return {
|
||||
"id": database.id,
|
||||
"name": database.name,
|
||||
"tenant": database.tenant,
|
||||
}
|
||||
|
||||
@override
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
return self.bindings.delete_database(name, tenant)
|
||||
|
||||
@override
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
databases = self.bindings.list_databases(limit, offset, tenant)
|
||||
return [
|
||||
{
|
||||
"id": database.id,
|
||||
"name": database.name,
|
||||
"tenant": database.tenant,
|
||||
}
|
||||
for database in databases
|
||||
]
|
||||
|
||||
@override
|
||||
def create_tenant(self, name: str) -> None:
|
||||
return self.bindings.create_tenant(name)
|
||||
|
||||
@override
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
tenant = self.bindings.get_tenant(name)
|
||||
return Tenant(name=tenant.name)
|
||||
|
||||
# ////////////////////////////// Base API //////////////////////////////
|
||||
|
||||
@override
|
||||
def heartbeat(self) -> int:
|
||||
return self.bindings.heartbeat()
|
||||
|
||||
@override
|
||||
def count_collections(
|
||||
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
|
||||
) -> int:
|
||||
return self.bindings.count_collections(tenant, database)
|
||||
|
||||
@override
|
||||
def list_collections(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Sequence[CollectionModel]:
|
||||
collections = self.bindings.list_collections(limit, offset, tenant, database)
|
||||
return [
|
||||
CollectionModel(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
serialized_schema=collection.schema,
|
||||
configuration_json=collection.configuration,
|
||||
metadata=collection.metadata,
|
||||
dimension=collection.dimension,
|
||||
tenant=collection.tenant,
|
||||
database=collection.database,
|
||||
)
|
||||
for collection in collections
|
||||
]
|
||||
|
||||
@override
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
# TODO: This event doesn't capture the get_or_create case appropriately
|
||||
# TODO: Re-enable embedding function tracking in create_collection
|
||||
self.product_telemetry_client.capture(
|
||||
ClientCreateCollectionEvent(
|
||||
collection_uuid=str(id),
|
||||
# embedding_function=embedding_function.__class__.__name__,
|
||||
)
|
||||
)
|
||||
if configuration:
|
||||
configuration_json_str = create_collection_configuration_to_json_str(
|
||||
configuration, metadata
|
||||
)
|
||||
else:
|
||||
configuration_json_str = None
|
||||
|
||||
if schema:
|
||||
schema_str = json.dumps(schema.serialize_to_json())
|
||||
else:
|
||||
schema_str = None
|
||||
|
||||
collection = self.bindings.create_collection(
|
||||
name,
|
||||
configuration_json_str,
|
||||
schema_str,
|
||||
metadata,
|
||||
get_or_create,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
collection_model = CollectionModel(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
configuration_json=collection.configuration,
|
||||
serialized_schema=collection.schema,
|
||||
metadata=collection.metadata,
|
||||
dimension=collection.dimension,
|
||||
tenant=collection.tenant,
|
||||
database=collection.database,
|
||||
)
|
||||
return collection_model
|
||||
|
||||
@override
|
||||
def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
collection = self.bindings.get_collection(name, tenant, database)
|
||||
return CollectionModel(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
configuration_json=collection.configuration,
|
||||
serialized_schema=collection.schema,
|
||||
metadata=collection.metadata,
|
||||
dimension=collection.dimension,
|
||||
tenant=collection.tenant,
|
||||
database=collection.database,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_or_create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
return self.create_collection(
|
||||
name, schema, configuration, metadata, True, tenant, database
|
||||
)
|
||||
|
||||
@override
|
||||
def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
self.bindings.delete_collection(name, tenant, database)
|
||||
|
||||
@override
|
||||
def _modify(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
if new_configuration:
|
||||
new_configuration_json_str = update_collection_configuration_to_json_str(
|
||||
new_configuration
|
||||
)
|
||||
else:
|
||||
new_configuration_json_str = None
|
||||
self.bindings.update_collection(
|
||||
str(id), new_name, new_metadata, new_configuration_json_str
|
||||
)
|
||||
|
||||
@override
|
||||
def _fork(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
new_name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
raise NotImplementedError(
|
||||
"Collection forking is not implemented for Local Chroma"
|
||||
)
|
||||
|
||||
@override
|
||||
def _search(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
searches: List[Search],
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> SearchResult:
|
||||
raise NotImplementedError("Search is not implemented for Local Chroma")
|
||||
|
||||
@override
|
||||
def _count(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> int:
|
||||
return self.bindings.count(str(collection_id), tenant, database)
|
||||
|
||||
@override
|
||||
def _peek(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
n: int = 10,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
return self._get(
|
||||
collection_id,
|
||||
limit=n,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
include=IncludeMetadataDocumentsEmbeddings,
|
||||
)
|
||||
|
||||
@override
|
||||
def _get(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocuments,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
ids_amount = len(ids) if ids else 0
|
||||
self.product_telemetry_client.capture(
|
||||
CollectionGetEvent(
|
||||
collection_uuid=str(collection_id),
|
||||
ids_count=ids_amount,
|
||||
limit=limit if limit else 0,
|
||||
include_metadata=ids_amount if "metadatas" in include else 0,
|
||||
include_documents=ids_amount if "documents" in include else 0,
|
||||
include_uris=ids_amount if "uris" in include else 0,
|
||||
)
|
||||
)
|
||||
|
||||
rust_response = self.bindings.get(
|
||||
str(collection_id),
|
||||
ids,
|
||||
json.dumps(where) if where else None,
|
||||
limit,
|
||||
offset or 0,
|
||||
json.dumps(where_document) if where_document else None,
|
||||
include,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
|
||||
return GetResult(
|
||||
ids=rust_response.ids,
|
||||
embeddings=rust_response.embeddings,
|
||||
documents=rust_response.documents,
|
||||
uris=rust_response.uris,
|
||||
included=include,
|
||||
data=None,
|
||||
metadatas=rust_response.metadatas,
|
||||
)
|
||||
|
||||
@override
|
||||
def _add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
self.product_telemetry_client.capture(
|
||||
CollectionAddEvent(
|
||||
collection_uuid=str(collection_id),
|
||||
add_amount=len(ids),
|
||||
with_metadata=len(ids) if metadatas is not None else 0,
|
||||
with_documents=len(ids) if documents is not None else 0,
|
||||
with_uris=len(ids) if uris is not None else 0,
|
||||
)
|
||||
)
|
||||
|
||||
return self.bindings.add(
|
||||
ids,
|
||||
str(collection_id),
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
|
||||
@override
|
||||
def _update(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
self.product_telemetry_client.capture(
|
||||
CollectionUpdateEvent(
|
||||
collection_uuid=str(collection_id),
|
||||
update_amount=len(ids),
|
||||
with_embeddings=len(embeddings) if embeddings else 0,
|
||||
with_metadata=len(metadatas) if metadatas else 0,
|
||||
with_documents=len(documents) if documents else 0,
|
||||
with_uris=len(uris) if uris else 0,
|
||||
)
|
||||
)
|
||||
|
||||
return self.bindings.update(
|
||||
str(collection_id),
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
|
||||
@override
|
||||
def _upsert(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
return self.bindings.upsert(
|
||||
str(collection_id),
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
|
||||
@override
|
||||
def _query(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
query_embeddings: Embeddings,
|
||||
ids: Optional[IDs] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocumentsDistances,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> QueryResult:
|
||||
query_amount = len(query_embeddings)
|
||||
filtered_ids_amount = len(ids) if ids else 0
|
||||
self.product_telemetry_client.capture(
|
||||
CollectionQueryEvent(
|
||||
collection_uuid=str(collection_id),
|
||||
query_amount=query_amount,
|
||||
filtered_ids_amount=filtered_ids_amount,
|
||||
n_results=n_results,
|
||||
with_metadata_filter=query_amount if where is not None else 0,
|
||||
with_document_filter=query_amount if where_document is not None else 0,
|
||||
include_metadatas=query_amount if "metadatas" in include else 0,
|
||||
include_documents=query_amount if "documents" in include else 0,
|
||||
include_uris=query_amount if "uris" in include else 0,
|
||||
include_distances=query_amount if "distances" in include else 0,
|
||||
)
|
||||
)
|
||||
|
||||
rust_response = self.bindings.query(
|
||||
str(collection_id),
|
||||
ids,
|
||||
query_embeddings,
|
||||
n_results,
|
||||
json.dumps(where) if where else None,
|
||||
json.dumps(where_document) if where_document else None,
|
||||
include,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
|
||||
return QueryResult(
|
||||
ids=rust_response.ids,
|
||||
embeddings=rust_response.embeddings,
|
||||
documents=rust_response.documents,
|
||||
uris=rust_response.uris,
|
||||
included=include,
|
||||
data=None,
|
||||
metadatas=rust_response.metadatas,
|
||||
distances=rust_response.distances,
|
||||
)
|
||||
|
||||
@override
|
||||
def _delete(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
self.product_telemetry_client.capture(
|
||||
CollectionDeleteEvent(
|
||||
# NOTE: the delete amount is not observable from python
|
||||
# TODO: Fix this when posthog is pushed into Rust frontend
|
||||
collection_uuid=str(collection_id),
|
||||
delete_amount=0,
|
||||
)
|
||||
)
|
||||
|
||||
return self.bindings.delete(
|
||||
str(collection_id),
|
||||
ids,
|
||||
json.dumps(where) if where else None,
|
||||
json.dumps(where_document) if where_document else None,
|
||||
tenant,
|
||||
database,
|
||||
)
|
||||
|
||||
@override
|
||||
def reset(self) -> bool:
|
||||
return self.bindings.reset()
|
||||
|
||||
@override
|
||||
def get_version(self) -> str:
|
||||
return self.bindings.get_version()
|
||||
|
||||
@override
|
||||
def get_settings(self) -> Settings:
|
||||
return self._system.settings
|
||||
|
||||
@override
|
||||
def get_max_batch_size(self) -> int:
|
||||
return self.bindings.get_max_batch_size()
|
||||
|
||||
@override
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
|
||||
"The Rust bindings (embedded mode) do not support attached function operations."
|
||||
)
|
||||
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""Attached functions are not supported in the Rust bindings (local embedded mode)."""
|
||||
raise NotImplementedError(
|
||||
"Attached functions are only supported when connecting to a Chroma server via HttpClient. "
|
||||
"The Rust bindings (embedded mode) do not support attached function operations."
|
||||
)
|
||||
|
||||
# TODO: Remove this if it's not planned to be used
|
||||
@override
|
||||
def get_user_identity(self) -> UserIdentity:
|
||||
return UserIdentity(
|
||||
user_id="",
|
||||
tenant=DEFAULT_TENANT,
|
||||
databases=[DEFAULT_DATABASE],
|
||||
)
|
||||
@@ -0,0 +1,96 @@
|
||||
from typing import ClassVar, Dict
|
||||
import uuid
|
||||
|
||||
from chromadb.api import ServerAPI
|
||||
from chromadb.config import Settings, System
|
||||
from chromadb.telemetry.product import ProductTelemetryClient
|
||||
from chromadb.telemetry.product.events import ClientStartEvent
|
||||
|
||||
|
||||
class SharedSystemClient:
|
||||
_identifier_to_system: ClassVar[Dict[str, System]] = {}
|
||||
_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings = Settings(),
|
||||
) -> None:
|
||||
self._identifier = SharedSystemClient._get_identifier_from_settings(settings)
|
||||
SharedSystemClient._create_system_if_not_exists(self._identifier, settings)
|
||||
|
||||
@classmethod
|
||||
def _create_system_if_not_exists(
|
||||
cls, identifier: str, settings: Settings
|
||||
) -> System:
|
||||
if identifier not in cls._identifier_to_system:
|
||||
new_system = System(settings)
|
||||
cls._identifier_to_system[identifier] = new_system
|
||||
|
||||
new_system.instance(ProductTelemetryClient)
|
||||
new_system.instance(ServerAPI)
|
||||
|
||||
new_system.start()
|
||||
else:
|
||||
previous_system = cls._identifier_to_system[identifier]
|
||||
|
||||
# For now, the settings must match
|
||||
if previous_system.settings != settings:
|
||||
raise ValueError(
|
||||
f"An instance of Chroma already exists for {identifier} with different settings"
|
||||
)
|
||||
|
||||
return cls._identifier_to_system[identifier]
|
||||
|
||||
@staticmethod
|
||||
def _get_identifier_from_settings(settings: Settings) -> str:
|
||||
identifier = ""
|
||||
api_impl = settings.chroma_api_impl
|
||||
|
||||
if api_impl is None:
|
||||
raise ValueError("Chroma API implementation must be set in settings")
|
||||
elif api_impl in [
|
||||
"chromadb.api.segment.SegmentAPI",
|
||||
"chromadb.api.rust.RustBindingsAPI",
|
||||
]:
|
||||
if settings.is_persistent:
|
||||
identifier = settings.persist_directory
|
||||
else:
|
||||
identifier = (
|
||||
"ephemeral" # TODO: support pathing and multiple ephemeral clients
|
||||
)
|
||||
elif api_impl in [
|
||||
"chromadb.api.fastapi.FastAPI",
|
||||
"chromadb.api.async_fastapi.AsyncFastAPI",
|
||||
]:
|
||||
# FastAPI clients can all use unique system identifiers since their configurations can be independent, e.g. different auth tokens
|
||||
identifier = str(uuid.uuid4())
|
||||
else:
|
||||
raise ValueError(f"Unsupported Chroma API implementation {api_impl}")
|
||||
|
||||
return identifier
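# Illustrative sketch (not part of the original file) of how identifiers are
# derived; the persist directory value is an assumption:
#
#   Settings(chroma_api_impl="chromadb.api.rust.RustBindingsAPI",
#            is_persistent=True, persist_directory="./chroma")   -> "./chroma"
#   Settings(chroma_api_impl="chromadb.api.rust.RustBindingsAPI",
#            is_persistent=False)                                -> "ephemeral"
#   Settings(chroma_api_impl="chromadb.api.fastapi.FastAPI")     -> a fresh uuid4 string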
|
||||
|
||||
@staticmethod
|
||||
def _populate_data_from_system(system: System) -> str:
|
||||
identifier = SharedSystemClient._get_identifier_from_settings(system.settings)
|
||||
SharedSystemClient._identifier_to_system[identifier] = system
|
||||
return identifier
|
||||
|
||||
@classmethod
|
||||
def from_system(cls, system: System) -> "SharedSystemClient":
|
||||
"""Create a client from an existing system. This is useful for testing and debugging."""
|
||||
|
||||
SharedSystemClient._populate_data_from_system(system)
|
||||
instance = cls(system.settings)
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def clear_system_cache() -> None:
|
||||
SharedSystemClient._identifier_to_system = {}
|
||||
|
||||
@property
|
||||
def _system(self) -> System:
|
||||
return SharedSystemClient._identifier_to_system[self._identifier]
|
||||
|
||||
def _submit_client_start_event(self) -> None:
|
||||
telemetry_client = self._system.instance(ProductTelemetryClient)
|
||||
telemetry_client.capture(ClientStartEvent())
|
||||
@@ -0,0 +1,7 @@
|
||||
import chromadb
|
||||
import chromadb.config
|
||||
from chromadb.server.fastapi import FastAPI
|
||||
|
||||
settings = chromadb.config.Settings()
|
||||
server = FastAPI(settings)
|
||||
app = server.app()
|
||||
@@ -0,0 +1,237 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Dict,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from chromadb.config import (
|
||||
Component,
|
||||
System,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
ClientAuthHeaders = Dict[str, SecretStr]
|
||||
|
||||
|
||||
class ClientAuthProvider(Component):
|
||||
"""
|
||||
ClientAuthProvider is responsible for providing authentication headers for
|
||||
client requests. Client implementations (in our case, just the FastAPI
|
||||
client) must inject these headers into their requests.
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
|
||||
@abstractmethod
|
||||
def authenticate(self) -> ClientAuthHeaders:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserIdentity:
|
||||
"""
|
||||
UserIdentity represents the identity of a user. In general, not all fields
|
||||
will be populated, and the fields that are populated will depend on the
|
||||
authentication provider.
|
||||
|
||||
The idea is that the AuthenticationProvider is responsible for populating
|
||||
_all_ information known about the user, and the AuthorizationProvider is
|
||||
responsible for making decisions based on that information.
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
tenant: Optional[str] = None
|
||||
databases: Optional[List[str]] = None
|
||||
# This can be used for any additional auth context which needs to be
|
||||
# propagated from the authentication provider to the authorization
|
||||
# provider.
|
||||
attributes: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ServerAuthenticationProvider(Component):
|
||||
"""
|
||||
ServerAuthenticationProvider is responsible for authenticating requests. If
|
||||
a ServerAuthenticationProvider is configured, it will be called by the
|
||||
server to authenticate requests. If no ServerAuthenticationProvider is
|
||||
configured, all requests will be authenticated.
|
||||
|
||||
The ServerAuthenticationProvider should return a UserIdentity object if the
|
||||
request is authenticated for use by the ServerAuthorizationProvider.
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._ignore_auth_paths: Dict[
|
||||
str, List[str]
|
||||
] = system.settings.chroma_server_auth_ignore_paths
|
||||
self.overwrite_singleton_tenant_database_access_from_auth = (
|
||||
system.settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
|
||||
pass
|
||||
|
||||
def ignore_operation(self, verb: str, path: str) -> bool:
|
||||
if (
|
||||
path in self._ignore_auth_paths.keys()
|
||||
and verb.upper() in self._ignore_auth_paths[path]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def read_creds_or_creds_file(self) -> List[str]:
|
||||
_creds_file = None
|
||||
_creds = None
|
||||
|
||||
if self._system.settings.chroma_server_authn_credentials_file:
|
||||
_creds_file = str(
|
||||
self._system.settings["chroma_server_authn_credentials_file"]
|
||||
)
|
||||
if self._system.settings.chroma_server_authn_credentials:
|
||||
_creds = str(self._system.settings["chroma_server_authn_credentials"])
|
||||
if not _creds_file and not _creds:
|
||||
raise ValueError(
|
||||
"No credentials file or credentials found in "
|
||||
"[chroma_server_authn_credentials]."
|
||||
)
|
||||
if _creds_file and _creds:
|
||||
raise ValueError(
|
||||
"Both credentials file and credentials found."
|
||||
"Please provide only one."
|
||||
)
|
||||
if _creds:
|
||||
return [c for c in _creds.split("\n") if c]
|
||||
elif _creds_file:
|
||||
with open(_creds_file, "r") as f:
|
||||
return f.readlines()
|
||||
raise ValueError("Should never happen")
|
||||
|
||||
def singleton_tenant_database_if_applicable(
|
||||
self, user: Optional[UserIdentity]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
is False, this function always returns (None, None).
|
||||
|
||||
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
is True, this function follows this logic:
|
||||
- If the user only has access to a single tenant, this function will
|
||||
return that tenant as its first return value.
|
||||
- If the user only has access to a single database, this function will
|
||||
return that database as its second return value.
- If the user has access to multiple tenants and/or databases, including "*",
this function will return None for the corresponding value(s).
|
||||
"""
|
||||
if not self.overwrite_singleton_tenant_database_access_from_auth or not user:
|
||||
return None, None
|
||||
tenant = None
|
||||
database = None
|
||||
if user.tenant and user.tenant != "*":
|
||||
tenant = user.tenant
|
||||
if user.databases and len(user.databases) == 1 and user.databases[0] != "*":
|
||||
database = user.databases[0]
|
||||
return tenant, database
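# Illustrative sketch (not part of the original file), assuming the overwrite
# setting is enabled; tenant and database names are assumptions:
#
#   UserIdentity(user_id="u", tenant="acme", databases=["prod"])   -> ("acme", "prod")
#   UserIdentity(user_id="u", tenant="*", databases=["a", "b"])    -> (None, None)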
|
||||
|
||||
|
||||
class AuthzAction(str, Enum):
|
||||
"""
|
||||
The set of actions that can be authorized by the authorization provider.
|
||||
"""
|
||||
|
||||
RESET = "system:reset"
|
||||
CREATE_TENANT = "tenant:create_tenant"
|
||||
GET_TENANT = "tenant:get_tenant"
|
||||
CREATE_DATABASE = "db:create_database"
|
||||
GET_DATABASE = "db:get_database"
|
||||
DELETE_DATABASE = "db:delete_database"
|
||||
LIST_DATABASES = "db:list_databases"
|
||||
LIST_COLLECTIONS = "db:list_collections"
|
||||
COUNT_COLLECTIONS = "db:count_collections"
|
||||
CREATE_COLLECTION = "db:create_collection"
|
||||
GET_OR_CREATE_COLLECTION = "db:get_or_create_collection"
|
||||
GET_COLLECTION = "collection:get_collection"
|
||||
DELETE_COLLECTION = "collection:delete_collection"
|
||||
UPDATE_COLLECTION = "collection:update_collection"
|
||||
ADD = "collection:add"
|
||||
DELETE = "collection:delete"
|
||||
GET = "collection:get"
|
||||
QUERY = "collection:query"
|
||||
COUNT = "collection:count"
|
||||
UPDATE = "collection:update"
|
||||
UPSERT = "collection:upsert"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthzResource:
|
||||
"""
|
||||
The resource being accessed in an authorization request.
|
||||
"""
|
||||
|
||||
tenant: Optional[str]
|
||||
database: Optional[str]
|
||||
collection: Optional[str]
|
||||
|
||||
|
||||
class ServerAuthorizationProvider(Component):
|
||||
"""
|
||||
ServerAuthorizationProvider is responsible for authorizing requests. If a
|
||||
ServerAuthorizationProvider is configured, it will be called by the server
|
||||
to authorize requests. If no ServerAuthorizationProvider is configured, all
|
||||
requests will be authorized.
|
||||
|
||||
ServerAuthorizationProvider should raise an exception if the request is not
|
||||
authorized.
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
|
||||
@abstractmethod
|
||||
def authorize_or_raise(
|
||||
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def read_config_or_config_file(self) -> List[str]:
|
||||
_config_file = None
|
||||
_config = None
|
||||
if self._system.settings.chroma_server_authz_config_file:
|
||||
_config_file = self._system.settings["chroma_server_authz_config_file"]
|
||||
if self._system.settings.chroma_server_authz_config:
|
||||
_config = str(self._system.settings["chroma_server_authz_config"])
|
||||
if not _config_file and not _config:
|
||||
raise ValueError(
|
||||
"No authz configuration file or authz configuration found."
|
||||
)
|
||||
if _config_file and _config:
|
||||
raise ValueError(
|
||||
"Both authz configuration file and authz configuration found."
|
||||
"Please provide only one."
|
||||
)
|
||||
if _config:
|
||||
return [c for c in _config.split("\n") if c]
|
||||
elif _config_file:
|
||||
with open(_config_file, "r") as f:
|
||||
return f.readlines()
|
||||
raise ValueError("Should never happen")
|
||||
@@ -0,0 +1,146 @@
|
||||
import base64
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import bcrypt
|
||||
import logging
|
||||
|
||||
from overrides import override
|
||||
from pydantic import SecretStr
|
||||
|
||||
from chromadb.auth import (
|
||||
UserIdentity,
|
||||
ServerAuthenticationProvider,
|
||||
ClientAuthProvider,
|
||||
ClientAuthHeaders,
|
||||
AuthError,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.errors import ChromaAuthError
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["BasicAuthenticationServerProvider", "BasicAuthClientProvider"]
|
||||
|
||||
AUTHORIZATION_HEADER = "Authorization"
|
||||
|
||||
|
||||
class BasicAuthClientProvider(ClientAuthProvider):
|
||||
"""
|
||||
Client auth provider for basic auth. The credentials are passed as a
|
||||
base64-encoded string in the Authorization header prepended with "Basic ".
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
system.settings.require("chroma_client_auth_credentials")
|
||||
self._creds = SecretStr(str(system.settings.chroma_client_auth_credentials))
|
||||
|
||||
@override
|
||||
def authenticate(self) -> ClientAuthHeaders:
|
||||
encoded = base64.b64encode(
|
||||
f"{self._creds.get_secret_value()}".encode("utf-8")
|
||||
).decode("utf-8")
|
||||
return {
|
||||
AUTHORIZATION_HEADER: SecretStr(f"Basic {encoded}"),
|
||||
}
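# Illustrative sketch (not part of the original file): with
# chroma_client_auth_credentials set to "admin:secret" (an assumed value), the
# header produced above is equivalent to:
#
#   import base64
#   base64.b64encode(b"admin:secret").decode("utf-8")  # -> "YWRtaW46c2VjcmV0"
#   # i.e. {"Authorization": "Basic YWRtaW46c2VjcmV0"}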
|
||||
|
||||
|
||||
class BasicAuthenticationServerProvider(ServerAuthenticationProvider):
|
||||
"""
|
||||
Server auth provider for basic auth. The credentials are read from
|
||||
`chroma_server_authn_credentials_file` and each line must be in the format
|
||||
<username>:<bcrypt passwd>.
|
||||
|
||||
Expects tokens to be passed as a base64-encoded string in the Authorization
|
||||
header prepended with "Basic".
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
|
||||
self._creds: Dict[str, SecretStr] = {}
|
||||
creds = self.read_creds_or_creds_file()
|
||||
|
||||
for line in creds:
|
||||
if not line.strip():
|
||||
continue
|
||||
_raw_creds = [v for v in line.strip().split(":")]
|
||||
if (
|
||||
_raw_creds
|
||||
and _raw_creds[0]
|
||||
and len(_raw_creds) != 2
|
||||
or not all(_raw_creds)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid htpasswd credentials found: {_raw_creds}. "
|
||||
"Lines must be exactly <username>:<bcrypt passwd>."
|
||||
)
|
||||
username = _raw_creds[0]
|
||||
password = _raw_creds[1]
|
||||
if username in self._creds:
|
||||
raise ValueError(
|
||||
"Duplicate username found in "
|
||||
"[chroma_server_authn_credentials]. "
|
||||
"Usernames must be unique."
|
||||
)
|
||||
self._creds[username] = SecretStr(password)
|
||||
|
||||
@trace_method(
|
||||
"BasicAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
|
||||
)
|
||||
@override
|
||||
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
|
||||
try:
|
||||
if AUTHORIZATION_HEADER.lower() not in headers.keys():
|
||||
raise AuthError(AUTHORIZATION_HEADER + " header not found")
|
||||
_auth_header = headers[AUTHORIZATION_HEADER.lower()]
|
||||
_auth_header = re.sub(r"^Basic ", "", _auth_header)
|
||||
_auth_header = _auth_header.strip()
|
||||
|
||||
base64_decoded = base64.b64decode(_auth_header).decode("utf-8")
|
||||
if ":" not in base64_decoded:
|
||||
raise AuthError("Invalid Authorization header format")
|
||||
username, password = base64_decoded.split(":", 1)
|
||||
username = str(username) # convert to string to prevent header injection
|
||||
password = str(password) # convert to string to prevent header injection
|
||||
if username not in self._creds:
|
||||
raise AuthError("Invalid username or password")
|
||||
|
||||
_pwd_check = bcrypt.checkpw(
|
||||
password.encode("utf-8"),
|
||||
self._creds[username].get_secret_value().encode("utf-8"),
|
||||
)
|
||||
if not _pwd_check:
|
||||
raise AuthError("Invalid username or password")
|
||||
return UserIdentity(user_id=username)
|
||||
except AuthError as e:
|
||||
logger.error(
|
||||
f"BasicAuthenticationServerProvider.authenticate failed: {repr(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
# Get the last call stack
|
||||
last_call_stack = tb[-1]
|
||||
line_number = last_call_stack.lineno
|
||||
filename = last_call_stack.filename
|
||||
logger.error(
|
||||
"BasicAuthenticationServerProvider.authenticate failed: "
|
||||
f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
|
||||
)
|
||||
time.sleep(
|
||||
random.uniform(0.001, 0.005)
|
||||
) # add some jitter to avoid timing attacks
|
||||
raise ChromaAuthError()
|
||||
@@ -0,0 +1,75 @@
|
||||
import logging
|
||||
from typing import Dict, Set
|
||||
from overrides import override
|
||||
import yaml
|
||||
from chromadb.auth import (
|
||||
AuthzAction,
|
||||
AuthzResource,
|
||||
UserIdentity,
|
||||
ServerAuthorizationProvider,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from fastapi import HTTPException
|
||||
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
|
||||
"""
|
||||
A simple Role-Based Access Control (RBAC) authorization provider. This
|
||||
provider reads a configuration file that maps users to roles, and roles to
|
||||
actions. The provider then checks if the user has the action they are
|
||||
attempting to perform.
|
||||
|
||||
For an example of an RBAC configuration file, see
|
||||
examples/basic_functionality/authz/authz.yaml.
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
self._config = yaml.safe_load("\n".join(self.read_config_or_config_file()))
|
||||
|
||||
# We favor preprocessing here to avoid having to parse the config file
|
||||
# on every request. This AuthorizationProvider does not support
|
||||
# per-resource authorization so we just map the user ID to the
|
||||
# permissions they have. We're not worried about the size of this dict
|
||||
# since users are all specified in the file -- anyone with a gigantic
|
||||
# number of users can roll their own AuthorizationProvider.
|
||||
self._permissions: Dict[str, Set[str]] = {}
|
||||
for user in self._config["users"]:
|
||||
_actions = self._config["roles_mapping"][user["role"]]["actions"]
|
||||
self._permissions[user["id"]] = set(_actions)
|
||||
logger.info(
|
||||
"Authorization Provider SimpleRBACAuthorizationProvider " "initialized"
|
||||
)
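# Illustrative sketch (not part of the original file) of the expected config
# shape after YAML parsing; user ids, roles, and actions are assumptions:
#
#   self._config == {
#       "roles_mapping": {"admin": {"actions": ["collection:add", "collection:query"]}},
#       "users": [{"id": "alice", "role": "admin"}],
#   }
#   # which yields self._permissions == {"alice": {"collection:add", "collection:query"}}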
|
||||
|
||||
@trace_method(
|
||||
"SimpleRBACAuthorizationProvider.authorize",
|
||||
OpenTelemetryGranularity.ALL,
|
||||
)
|
||||
@override
|
||||
def authorize_or_raise(
|
||||
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
|
||||
) -> None:
|
||||
policy_decision = False
|
||||
if (
|
||||
user.user_id in self._permissions
|
||||
and action in self._permissions[user.user_id]
|
||||
):
|
||||
policy_decision = True
|
||||
|
||||
logger.debug(
|
||||
f"Authorization decision: Access "
|
||||
f"{'granted' if policy_decision else 'denied'} for "
|
||||
f"user [{user.user_id}] attempting to "
|
||||
f"[{action}] [{resource}]"
|
||||
)
|
||||
if not policy_decision:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
@@ -0,0 +1,235 @@
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import cast, Dict, List, Optional, TypedDict, TypeVar
|
||||
|
||||
|
||||
from overrides import override
|
||||
from pydantic import SecretStr
|
||||
import yaml
|
||||
|
||||
from chromadb.auth import (
|
||||
ServerAuthenticationProvider,
|
||||
ClientAuthProvider,
|
||||
ClientAuthHeaders,
|
||||
UserIdentity,
|
||||
AuthError,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.errors import ChromaAuthError
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"TokenAuthenticationServerProvider",
|
||||
"TokenAuthClientProvider",
|
||||
"TokenTransportHeader",
|
||||
]
|
||||
|
||||
|
||||
class TokenTransportHeader(str, Enum):
|
||||
"""
|
||||
Acceptable token transport headers.
|
||||
"""
|
||||
|
||||
# I don't love having this enum here -- it's weird to have an enum
|
||||
# for just two values and it's weird to have users pass X_CHROMA_TOKEN
|
||||
# to configure "x-chroma-token". But I also like having a single source
|
||||
# of truth, so 🤷🏻♂️
|
||||
AUTHORIZATION = "Authorization"
|
||||
X_CHROMA_TOKEN = "X-Chroma-Token"
|
||||
|
||||
|
||||
valid_token_chars = set(string.digits + string.ascii_letters + string.punctuation)
|
||||
|
||||
|
||||
def _check_token(token: str) -> None:
|
||||
token_str = str(token)
|
||||
if not all(c in valid_token_chars for c in token_str):
|
||||
raise ValueError(
|
||||
"Invalid token. Must contain only ASCII letters, digits, and punctuation."
|
||||
)
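# --- Illustrative sketch, not part of the original file ---
# _check_token accepts any token made of ASCII letters, digits, and
# punctuation, and rejects everything else (spaces, newlines, non-ASCII).
_check_token("s3cr3t-token_123")  # passes silently
try:
    _check_token("has a space")  # " " is not in valid_token_chars
except ValueError:
    pass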
|
||||
|
||||
allowed_token_headers = [
|
||||
TokenTransportHeader.AUTHORIZATION.value,
|
||||
TokenTransportHeader.X_CHROMA_TOKEN.value,
|
||||
]
|
||||
|
||||
|
||||
def _check_allowed_token_headers(token_header: str) -> None:
|
||||
if token_header not in allowed_token_headers:
|
||||
raise ValueError(
|
||||
f"Invalid token transport header: {token_header}. "
|
||||
f"Must be one of {allowed_token_headers}"
|
||||
)
|
||||
|
||||
|
||||
class TokenAuthClientProvider(ClientAuthProvider):
|
||||
"""
|
||||
Client auth provider for token-based auth. Header key will be either
|
||||
"Authorization" or "X-Chroma-Token" depending on
|
||||
`chroma_auth_token_transport_header`. If the header is "Authorization",
|
||||
the token is passed as a bearer token.
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
|
||||
system.settings.require("chroma_client_auth_credentials")
|
||||
self._token = SecretStr(str(system.settings.chroma_client_auth_credentials))
|
||||
_check_token(self._token.get_secret_value())
|
||||
|
||||
if system.settings.chroma_auth_token_transport_header:
|
||||
_check_allowed_token_headers(
|
||||
system.settings.chroma_auth_token_transport_header
|
||||
)
|
||||
self._token_transport_header = TokenTransportHeader(
|
||||
system.settings.chroma_auth_token_transport_header
|
||||
)
|
||||
else:
|
||||
self._token_transport_header = TokenTransportHeader.AUTHORIZATION
|
||||
|
||||
@override
|
||||
def authenticate(self) -> ClientAuthHeaders:
|
||||
val = self._token.get_secret_value()
|
||||
if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
|
||||
val = f"Bearer {val}"
|
||||
return {
|
||||
self._token_transport_header.value: SecretStr(val),
|
||||
}
|
||||
|
||||
|
||||
class User(TypedDict):
|
||||
"""
|
||||
A simple User class for use in this module only. If you need a generic
|
||||
way to represent a User, please use UserIdentity as this class keeps
|
||||
track of sensitive tokens.
|
||||
"""
|
||||
|
||||
id: str
|
||||
role: str
|
||||
tenant: Optional[str]
|
||||
databases: Optional[List[str]]
|
||||
tokens: List[str]
|
||||
|
||||
|
||||
class TokenAuthenticationServerProvider(ServerAuthenticationProvider):
|
||||
"""
|
||||
Server authentication provider for token-based auth. The provider will
|
||||
- On initialization, read the users from the file specified in
|
||||
`chroma_server_authn_credentials_file`. This file must be a well-formed
|
||||
YAML file with a top-level array called `users`. Each user must have
|
||||
an `id` field and a `tokens` (string array) field.
|
||||
- On each request, check the token in the header specified by
|
||||
`chroma_auth_token_transport_header`. If the configured header is
|
||||
"Authorization", the token is expected to be a bearer token.
|
||||
- If the token is valid, the server will return the user identity
|
||||
associated with the token.
|
||||
"""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
if system.settings.chroma_auth_token_transport_header:
|
||||
_check_allowed_token_headers(
|
||||
system.settings.chroma_auth_token_transport_header
|
||||
)
|
||||
self._token_transport_header = TokenTransportHeader(
|
||||
system.settings.chroma_auth_token_transport_header
|
||||
)
|
||||
else:
|
||||
self._token_transport_header = TokenTransportHeader.AUTHORIZATION
|
||||
|
||||
self._token_user_mapping: Dict[str, User] = {}
|
||||
creds = self.read_creds_or_creds_file()
|
||||
|
||||
# If we only get one cred, assume it's just a valid token.
|
||||
if len(creds) == 1:
|
||||
self._token_user_mapping[creds[0]] = User(
|
||||
id="anonymous",
|
||||
tenant="*",
|
||||
databases=["*"],
|
||||
role="anonymous",
|
||||
tokens=[creds[0]],
|
||||
)
|
||||
return
|
||||
|
||||
self._users = cast(List[User], yaml.safe_load("\n".join(creds))["users"])
|
||||
for user in self._users:
|
||||
if "tokens" not in user:
|
||||
raise ValueError("User missing tokens")
|
||||
if "tenant" not in user:
|
||||
user["tenant"] = "*"
|
||||
if "databases" not in user:
|
||||
user["databases"] = ["*"]
|
||||
for token in user["tokens"]:
|
||||
_check_token(token)
|
||||
if (
|
||||
token in self._token_user_mapping
|
||||
and self._token_user_mapping[token] != user
|
||||
):
|
||||
raise ValueError(
|
||||
f"Token {token} already in use: wanted to use it for "
|
||||
f"user {user['id']} but it's already in use by "
|
||||
f"user {self._token_user_mapping[token]}"
|
||||
)
|
||||
self._token_user_mapping[token] = user
|
||||
|
||||
@trace_method(
|
||||
"TokenAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
|
||||
)
|
||||
@override
|
||||
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
|
||||
try:
|
||||
if self._token_transport_header.value.lower() not in headers.keys():
|
||||
raise AuthError(
|
||||
f"Authorization header '{self._token_transport_header.value}' not found"
|
||||
)
|
||||
token = headers[self._token_transport_header.value.lower()]
|
||||
if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
|
||||
if not token.startswith("Bearer "):
|
||||
raise AuthError("Bearer not found in Authorization header")
|
||||
token = re.sub(r"^Bearer ", "", token)
|
||||
|
||||
token = token.strip()
|
||||
_check_token(token)
|
||||
|
||||
if token not in self._token_user_mapping:
|
||||
raise AuthError("Invalid credentials: Token not found")
|
||||
|
||||
user_identity = UserIdentity(
|
||||
user_id=self._token_user_mapping[token]["id"],
|
||||
tenant=self._token_user_mapping[token]["tenant"],
|
||||
databases=self._token_user_mapping[token]["databases"],
|
||||
)
|
||||
return user_identity
|
||||
except AuthError as e:
|
||||
logger.debug(
|
||||
f"TokenAuthenticationServerProvider.authenticate failed: {repr(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
# Get the last call stack
|
||||
last_call_stack = tb[-1]
|
||||
line_number = last_call_stack.lineno
|
||||
filename = last_call_stack.filename
|
||||
logger.debug(
|
||||
"TokenAuthenticationServerProvider.authenticate failed: "
|
||||
f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
|
||||
)
|
||||
time.sleep(
|
||||
random.uniform(0.001, 0.005)
|
||||
) # add some jitter to avoid timing attacks
|
||||
raise ChromaAuthError()
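# --- Illustrative sketch, not part of the original file ---
# A hedged example of configuring a client to send one of the tokens accepted
# by this server provider. The Settings fields used here are defined in
# chromadb.config; the provider path assumes this module is importable as
# chromadb.auth.token_authn.
#
#   from chromadb.config import Settings
#
#   client_settings = Settings(
#       chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
#       chroma_client_auth_credentials="s3cr3t-token_123",
#       chroma_auth_token_transport_header="X-Chroma-Token",  # or "Authorization"
#   )
#
# With "Authorization", authenticate() returns
# {"Authorization": SecretStr("Bearer s3cr3t-token_123")}; with "X-Chroma-Token"
# it returns {"X-Chroma-Token": SecretStr("s3cr3t-token_123")}.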
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
|
||||
from chromadb.errors import ChromaAuthError
|
||||
|
||||
|
||||
def _singleton_tenant_database_if_applicable(
|
||||
user_identity: UserIdentity,
|
||||
overwrite_singleton_tenant_database_access_from_auth: bool,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
is False, this function always returns (None, None).
|
||||
|
||||
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
is True, follows the following logic:
|
||||
- If the user only has access to a single tenant, this function will
|
||||
return that tenant as its first return value.
|
||||
- If the user only has access to a single database, this function will
return that database as its second return value.
- If the user has access to multiple tenants and/or databases (including
"*"), this function will return None for the corresponding value(s).
|
||||
"""
|
||||
if not overwrite_singleton_tenant_database_access_from_auth:
|
||||
return None, None
|
||||
tenant = None
|
||||
database = None
|
||||
user_tenant = user_identity.tenant
|
||||
user_databases = user_identity.databases
|
||||
if user_tenant and user_tenant != "*":
|
||||
tenant = user_tenant
|
||||
if user_databases:
|
||||
user_databases_set = set(user_databases)
|
||||
if len(user_databases_set) == 1 and "*" not in user_databases_set:
|
||||
database = list(user_databases_set)[0]
|
||||
return tenant, database
|
||||
|
||||
|
||||
def maybe_set_tenant_and_database(
|
||||
user_identity: UserIdentity,
|
||||
overwrite_singleton_tenant_database_access_from_auth: bool,
|
||||
user_provided_tenant: Optional[str] = None,
|
||||
user_provided_database: Optional[str] = None,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
(
|
||||
new_tenant,
|
||||
new_database,
|
||||
) = _singleton_tenant_database_if_applicable(
|
||||
user_identity=user_identity,
|
||||
overwrite_singleton_tenant_database_access_from_auth=overwrite_singleton_tenant_database_access_from_auth,
|
||||
)
|
||||
|
||||
# The only error case is if the user provides a tenant and database that
|
||||
# don't match what we resolved from auth. This can incorrectly happen when
|
||||
# there is no auth provider set, but overwrite_singleton_tenant_database_access_from_auth
|
||||
# is set to True. In this case, we'll resolve tenant/database to the default
|
||||
# values, which might not match the provided values. Thus, it's important
|
||||
# to ensure that the flag is set to True only when there is an auth provider.
|
||||
if (
|
||||
user_provided_tenant
|
||||
and user_provided_tenant != DEFAULT_TENANT
|
||||
and new_tenant
|
||||
and new_tenant != user_provided_tenant
|
||||
):
|
||||
raise ChromaAuthError(f"Tenant {user_provided_tenant} does not match {new_tenant} from the server. Are you sure the tenant is correct?")
|
||||
if (
|
||||
user_provided_database
|
||||
and user_provided_database != DEFAULT_DATABASE
|
||||
and new_database
|
||||
and new_database != user_provided_database
|
||||
):
|
||||
raise ChromaAuthError(f"Database {user_provided_database} does not match {new_database} from the server. Are you sure the database is correct?")
|
||||
|
||||
if (
|
||||
not user_provided_tenant or user_provided_tenant == DEFAULT_TENANT
|
||||
) and new_tenant:
|
||||
user_provided_tenant = new_tenant
|
||||
if (
|
||||
not user_provided_database or user_provided_database == DEFAULT_DATABASE
|
||||
) and new_database:
|
||||
user_provided_database = new_database
|
||||
|
||||
return user_provided_tenant, user_provided_database
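# --- Illustrative sketch, not part of the original file ---
# Example of the resolution above when the overwrite flag is enabled and the
# authenticated user is scoped to exactly one tenant and one database. The
# tenant/database names are made up for illustration.
_example_tenant, _example_database = maybe_set_tenant_and_database(
    user_identity=UserIdentity(user_id="alice", tenant="acme", databases=["prod"]),
    overwrite_singleton_tenant_database_access_from_auth=True,
    user_provided_tenant=DEFAULT_TENANT,  # caller left the default
    user_provided_database=DEFAULT_DATABASE,  # caller left the default
)
# _example_tenant == "acme" and _example_database == "prod"; passing a
# non-default tenant or database that disagrees with auth raises ChromaAuthError.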
Binary file not shown.
@@ -0,0 +1,122 @@
|
||||
from typing import Dict, List, Mapping, Optional, Sequence, Union, Any
|
||||
from typing_extensions import Literal, Final
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
# Type tag constants
|
||||
TYPE_KEY: Final[str] = "#type"
|
||||
SPARSE_VECTOR_TYPE_VALUE: Final[str] = "sparse_vector"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
"""Represents a sparse vector using parallel arrays for indices and values.
|
||||
|
||||
Attributes:
|
||||
indices: List of dimension indices (must be non-negative integers, sorted in strictly ascending order)
|
||||
values: List of values corresponding to each index (floats)
|
||||
|
||||
Note:
|
||||
- Indices must be sorted in strictly ascending order (no duplicates)
|
||||
- Indices and values must have the same length
|
||||
- All validations are performed in __post_init__
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
values: List[float]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate the sparse vector structure."""
|
||||
if not isinstance(self.indices, list):
|
||||
raise ValueError(
|
||||
f"Expected SparseVector indices to be a list, got {type(self.indices).__name__}"
|
||||
)
|
||||
|
||||
if not isinstance(self.values, list):
|
||||
raise ValueError(
|
||||
f"Expected SparseVector values to be a list, got {type(self.values).__name__}"
|
||||
)
|
||||
|
||||
if len(self.indices) != len(self.values):
|
||||
raise ValueError(
|
||||
f"SparseVector indices and values must have the same length, "
|
||||
f"got {len(self.indices)} indices and {len(self.values)} values"
|
||||
)
|
||||
|
||||
for i, idx in enumerate(self.indices):
|
||||
if not isinstance(idx, int):
|
||||
raise ValueError(
|
||||
f"SparseVector indices must be integers, got {type(idx).__name__} at position {i}"
|
||||
)
|
||||
if idx < 0:
|
||||
raise ValueError(
|
||||
f"SparseVector indices must be non-negative, got {idx} at position {i}"
|
||||
)
|
||||
|
||||
for i, val in enumerate(self.values):
|
||||
if not isinstance(val, (int, float)):
|
||||
raise ValueError(
|
||||
f"SparseVector values must be numbers, got {type(val).__name__} at position {i}"
|
||||
)
|
||||
|
||||
# Validate indices are sorted in strictly ascending order
|
||||
if len(self.indices) > 1:
|
||||
for i in range(1, len(self.indices)):
|
||||
if self.indices[i] <= self.indices[i - 1]:
|
||||
raise ValueError(
|
||||
f"SparseVector indices must be sorted in strictly ascending order, "
|
||||
f"found indices[{i}]={self.indices[i]} <= indices[{i-1}]={self.indices[i-1]}"
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize to transport format with type tag."""
|
||||
return {
|
||||
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
|
||||
"indices": self.indices,
|
||||
"values": self.values,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> "SparseVector":
|
||||
"""Deserialize from transport format (strict - requires #type field)."""
|
||||
if d.get(TYPE_KEY) != SPARSE_VECTOR_TYPE_VALUE:
|
||||
raise ValueError(
|
||||
f"Expected {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {d.get(TYPE_KEY)}"
|
||||
)
|
||||
return cls(indices=d["indices"], values=d["values"])
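# --- Illustrative sketch, not part of the original file ---
# Round-trip example: indices must be non-negative and strictly ascending and
# must line up one-to-one with values.
_example_sv = SparseVector(indices=[0, 7, 42], values=[0.5, 1.25, -3.0])
assert _example_sv.to_dict() == {
    TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
    "indices": [0, 7, 42],
    "values": [0.5, 1.25, -3.0],
}
assert SparseVector.from_dict(_example_sv.to_dict()) == _example_sv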
|
||||
|
||||
Metadata = Mapping[str, Optional[Union[str, int, float, bool, SparseVector]]]
|
||||
UpdateMetadata = Mapping[str, Union[int, float, str, bool, SparseVector, None]]
|
||||
PyVector = Union[Sequence[float], Sequence[int]]
|
||||
Vector = NDArray[Union[np.int32, np.float32]] # TODO: Specify that the vector is 1D
|
||||
# Metadata Query Grammar
|
||||
LiteralValue = Union[str, int, float, bool]
|
||||
LogicalOperator = Union[Literal["$and"], Literal["$or"]]
|
||||
WhereOperator = Union[
|
||||
Literal["$gt"],
|
||||
Literal["$gte"],
|
||||
Literal["$lt"],
|
||||
Literal["$lte"],
|
||||
Literal["$ne"],
|
||||
Literal["$eq"],
|
||||
]
|
||||
InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]]
|
||||
OperatorExpression = Union[
|
||||
Dict[Union[WhereOperator, LogicalOperator], LiteralValue],
|
||||
Dict[InclusionExclusionOperator, List[LiteralValue]],
|
||||
]
|
||||
|
||||
Where = Dict[
|
||||
Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]
|
||||
]
|
||||
|
||||
WhereDocumentOperator = Union[
|
||||
Literal["$contains"],
|
||||
Literal["$not_contains"],
|
||||
Literal["$regex"],
|
||||
Literal["$not_regex"],
|
||||
LogicalOperator,
|
||||
]
|
||||
WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]]
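# --- Illustrative sketch, not part of the original file ---
# Example filters that satisfy the Where / WhereDocument grammars above; the
# metadata keys and values are made up for illustration.
_example_where: Where = {
    "$and": [
        {"category": {"$eq": "news"}},
        {"year": {"$gte": 2020}},
        {"source": {"$in": ["rss", "api"]}},
    ]
}
_example_where_document: WhereDocument = {
    "$and": [
        {"$contains": "chroma"},
        {"$not_contains": "deprecated"},
    ]
}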
@@ -0,0 +1,195 @@
|
||||
from typing import List, Optional, Sequence
|
||||
from uuid import UUID
|
||||
from chromadb import CollectionMetadata, Embeddings, IDs
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
Documents,
|
||||
Embeddings,
|
||||
IDs,
|
||||
Metadatas,
|
||||
URIs,
|
||||
Include,
|
||||
)
|
||||
from chromadb.types import Tenant, Collection as CollectionModel
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
|
||||
from enum import Enum
|
||||
|
||||
class DatabaseFromBindings:
|
||||
id: UUID
|
||||
name: str
|
||||
tenant: str
|
||||
|
||||
# Result Types
|
||||
|
||||
class GetResponse:
|
||||
ids: IDs
|
||||
embeddings: Embeddings
|
||||
documents: Documents
|
||||
uris: URIs
|
||||
metadatas: Metadatas
|
||||
include: Include
|
||||
|
||||
class QueryResponse:
|
||||
ids: List[IDs]
|
||||
embeddings: Optional[List[Embeddings]]
|
||||
documents: Optional[List[Documents]]
|
||||
uris: Optional[List[URIs]]
|
||||
metadatas: Optional[List[Metadatas]]
|
||||
distances: Optional[List[List[float]]]
|
||||
include: Include
|
||||
|
||||
class GetTenantResponse:
|
||||
name: str
|
||||
|
||||
# SqliteDBConfig types
|
||||
class MigrationMode(Enum):
|
||||
Apply = 0
|
||||
Validate = 1
|
||||
|
||||
class MigrationHash(Enum):
|
||||
SHA256 = 0
|
||||
MD5 = 1
|
||||
|
||||
class SqliteDBConfig:
|
||||
url: str
|
||||
hash_type: MigrationHash
|
||||
migration_mode: MigrationMode
|
||||
|
||||
def __init__(
|
||||
self, url: str, hash_type: MigrationHash, migration_mode: MigrationMode
|
||||
) -> None: ...
|
||||
|
||||
class Bindings:
|
||||
def __init__(
|
||||
self,
|
||||
allow_reset: bool,
|
||||
sqlite_db_config: SqliteDBConfig,
|
||||
persist_path: str,
|
||||
hnsw_cache_size: int,
|
||||
) -> None: ...
|
||||
def heartbeat(self) -> int: ...
|
||||
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: ...
|
||||
def get_database(
|
||||
self, name: str, tenant: str = DEFAULT_TENANT
|
||||
) -> DatabaseFromBindings: ...
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: ...
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[DatabaseFromBindings]: ...
|
||||
def create_tenant(self, name: str) -> None: ...
|
||||
def get_tenant(self, name: str) -> GetTenantResponse: ...
|
||||
def count_collections(
|
||||
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
|
||||
) -> int: ...
|
||||
def list_collections(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Sequence[CollectionModel]: ...
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
configuration_json_str: Optional[str] = None,
|
||||
schema_str: Optional[str] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel: ...
|
||||
def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel: ...
|
||||
def update_collection(
|
||||
self,
|
||||
id: str,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration_json_str: Optional[str] = None,
|
||||
) -> None: ...
|
||||
def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None: ...
|
||||
def add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: str,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool: ...
|
||||
def update(
|
||||
self,
|
||||
collection_id: str,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool: ...
|
||||
def upsert(
|
||||
self,
|
||||
collection_id: str,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool: ...
|
||||
def delete(
|
||||
self,
|
||||
collection_id: str,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[str] = None,
|
||||
where_document: Optional[str] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None: ...
|
||||
def count(
|
||||
self,
|
||||
collection_id: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> int: ...
|
||||
def get(
|
||||
self,
|
||||
collection_id: str,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[str] = None,
|
||||
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResponse: ...
|
||||
def query(
|
||||
self,
|
||||
collection_id: str,
|
||||
query_embeddings: Embeddings,
|
||||
n_results: int = 10,
|
||||
where: Optional[str] = None,
|
||||
where_document: Optional[str] = None,
|
||||
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> QueryResponse: ...
|
||||
def reset(self) -> bool: ...
|
||||
def get_version(self) -> str: ...
|
||||
@@ -0,0 +1,56 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
import chromadb_rust_bindings
|
||||
import requests
|
||||
from packaging.version import parse
|
||||
|
||||
import chromadb
|
||||
|
||||
|
||||
def build_cli_args(**kwargs):
|
||||
args = []
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
args.append(f"--{key}")
|
||||
elif value is not None:
|
||||
args.extend([f"--{key}", str(value)])
|
||||
return args
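# --- Illustrative sketch, not part of the original file ---
# build_cli_args turns keyword arguments into CLI flags: truthy booleans become
# bare flags, None values are dropped, and everything else becomes a
# "--key value" pair. The flag names below are illustrative only.
#
#   build_cli_args(path="./chroma-data", port=8000, watch=True, host=None)
#   -> ["--path", "./chroma-data", "--port", "8000", "--watch"]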
|
||||
|
||||
def update():
|
||||
try:
|
||||
url = f"https://api.github.com/repos/chroma-core/chroma/releases"
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
releases = response.json()
|
||||
|
||||
version_pattern = re.compile(r'^\d+\.\d+\.\d+$')
|
||||
numeric_releases = [r["tag_name"] for r in releases if version_pattern.fullmatch(r["tag_name"])]
|
||||
|
||||
if not numeric_releases:
|
||||
print("Couldn't fetch the latest Chroma version")
|
||||
return
|
||||
|
||||
latest = max(numeric_releases, key=parse)
|
||||
if latest == chromadb.__version__:
|
||||
print("Your Chroma version is up-to-date")
|
||||
return
|
||||
|
||||
print(
|
||||
f"A new version of Chroma is available!\nIf you're using pip, run 'pip install --upgrade chromadb' to upgrade to version {latest}")
|
||||
|
||||
except Exception as e:
|
||||
print("Couldn't fetch the latest Chroma version")
|
||||
|
||||
|
||||
def app():
|
||||
args = sys.argv
|
||||
if ["chroma", "update"] in args:
|
||||
update()
|
||||
return
|
||||
try:
|
||||
chromadb_rust_bindings.cli(args)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
@@ -0,0 +1,40 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
def set_log_file_path(
|
||||
log_config_path: str, new_filename: str = "chroma.log"
|
||||
) -> Dict[str, Any]:
|
||||
"""This works with the standard log_config.yml file.
|
||||
It will not work with custom log configs that may use different handlers"""
|
||||
with open(f"{log_config_path}", "r") as file:
|
||||
log_config = yaml.safe_load(file)
|
||||
for handler in log_config["handlers"].values():
|
||||
if handler.get("class") == "logging.handlers.RotatingFileHandler":
|
||||
handler["filename"] = new_filename
|
||||
|
||||
return log_config
|
||||
|
||||
|
||||
def get_directory_size(directory: str) -> int:
|
||||
"""Get the size of a directory in bytes"""
|
||||
total = 0
|
||||
with os.scandir(directory) as it:
|
||||
for entry in it:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += get_directory_size(entry.path)
|
||||
return total
|
||||
|
||||
|
||||
# https://stackoverflow.com/a/1094933
|
||||
def sizeof_fmt(num: int, suffix: str = "B") -> str:
|
||||
n: float = float(num)
|
||||
for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
|
||||
if abs(n) < 1024.0:
|
||||
return f"{n:3.1f}{unit}{suffix}"
|
||||
n /= 1024.0
|
||||
return f"{n:.1f}Yi{suffix}"
@@ -0,0 +1,503 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from graphlib import TopologicalSorter
|
||||
from typing import Optional, List, Any, Dict, Set, Iterable, Union
|
||||
from typing import Type, TypeVar, cast
|
||||
|
||||
from overrides import EnforceOverrides
|
||||
from overrides import override
|
||||
from typing_extensions import Literal
|
||||
import platform
|
||||
|
||||
in_pydantic_v2 = False
|
||||
try:
|
||||
from pydantic import BaseSettings
|
||||
except ImportError:
|
||||
in_pydantic_v2 = True
|
||||
from pydantic.v1 import BaseSettings
|
||||
from pydantic.v1 import validator
|
||||
|
||||
if not in_pydantic_v2:
|
||||
from pydantic import validator # type: ignore # noqa
|
||||
|
||||
# The thin client will have a flag to control which implementations to use
|
||||
is_thin_client = False
|
||||
try:
|
||||
from chromadb.is_thin_client import is_thin_client # type: ignore
|
||||
except ImportError:
|
||||
is_thin_client = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LEGACY_ERROR = """\033[91mYou are using a deprecated configuration of Chroma.
|
||||
|
||||
\033[94mIf you do not have data you wish to migrate, you only need to change how you construct
|
||||
your Chroma client. Please see the "New Clients" section of https://docs.trychroma.com/deployment/migration.
|
||||
________________________________________________________________________________________________
|
||||
|
||||
If you do have data you wish to migrate, we have a migration tool you can use in order to
|
||||
migrate your data to the new Chroma architecture.
|
||||
Please `pip install chroma-migrate` and run `chroma-migrate` to migrate your data and then
|
||||
change how you construct your Chroma client.
|
||||
|
||||
See https://docs.trychroma.com/deployment/migration for more information or join our discord at https://discord.gg/MMeYNTmh3x for help!\033[0m"""
|
||||
|
||||
_legacy_config_keys = {
|
||||
"chroma_db_impl",
|
||||
}
|
||||
|
||||
_legacy_config_values = {
|
||||
"duckdb",
|
||||
"duckdb+parquet",
|
||||
"clickhouse",
|
||||
"local",
|
||||
"rest",
|
||||
"chromadb.db.duckdb.DuckDB",
|
||||
"chromadb.db.duckdb.PersistentDuckDB",
|
||||
"chromadb.db.clickhouse.Clickhouse",
|
||||
"chromadb.api.local.LocalAPI",
|
||||
}
|
||||
|
||||
|
||||
# Map specific abstract types to the setting which specifies which
|
||||
# concrete implementation to use.
|
||||
# Please keep these sorted. We're civilized people.
|
||||
_abstract_type_keys: Dict[str, str] = {
|
||||
# TODO: Don't use concrete types here to avoid circular deps. Strings are
|
||||
# fine for right here!
|
||||
# NOTE: this is to support legacy api construction. Use ServerAPI instead
|
||||
"chromadb.api.API": "chroma_api_impl",
|
||||
"chromadb.api.ServerAPI": "chroma_api_impl",
|
||||
"chromadb.api.async_api.AsyncServerAPI": "chroma_api_impl",
|
||||
"chromadb.auth.ClientAuthProvider": "chroma_client_auth_provider",
|
||||
"chromadb.auth.ServerAuthenticationProvider": "chroma_server_authn_provider",
|
||||
"chromadb.auth.ServerAuthorizationProvider": "chroma_server_authz_provider",
|
||||
"chromadb.db.system.SysDB": "chroma_sysdb_impl",
|
||||
"chromadb.execution.executor.abstract.Executor": "chroma_executor_impl",
|
||||
"chromadb.ingest.Consumer": "chroma_consumer_impl",
|
||||
"chromadb.ingest.Producer": "chroma_producer_impl",
|
||||
"chromadb.quota.QuotaEnforcer": "chroma_quota_enforcer_impl",
|
||||
"chromadb.rate_limit.RateLimitEnforcer": "chroma_rate_limit_enforcer_impl",
|
||||
"chromadb.rate_limit.AsyncRateLimitEnforcer": "chroma_async_rate_limit_enforcer_impl",
|
||||
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
|
||||
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
|
||||
"chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",
|
||||
"chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl",
|
||||
}
|
||||
|
||||
DEFAULT_TENANT = "default_tenant"
|
||||
DEFAULT_DATABASE = "default_database"
|
||||
|
||||
|
||||
class APIVersion(str, Enum):
|
||||
V1 = "/api/v1"
|
||||
V2 = "/api/v2"
|
||||
|
||||
|
||||
# NOTE(hammadb) 1/13/2024 - This has to be in config.py instead of being localized to the module
|
||||
# that uses it because of a circular import issue. This is a temporary solution until we can
|
||||
# refactor the code to remove the circular import.
|
||||
class RoutingMode(Enum):
|
||||
"""
|
||||
Routing mode for the segment directory
|
||||
|
||||
node - Assign based on the node name, used in production with multi-node settings with the assumption that
|
||||
there is one query service pod per node. This is useful for when there is a disk based cache on the
|
||||
node that we want to route to.
|
||||
|
||||
id - Assign based on the member id, used in development and testing environments where the node name is not
|
||||
guaranteed to be unique. (I.e a local development kubernetes env). Or when there are multiple query service
|
||||
pods per node.
|
||||
"""
|
||||
|
||||
NODE = "node"
|
||||
ID = "id"
|
||||
|
||||
|
||||
class Settings(BaseSettings): # type: ignore
|
||||
# ==============
|
||||
# Generic config
|
||||
# ==============
|
||||
|
||||
environment: str = ""
|
||||
|
||||
# Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" or "chromadb.api.rust.RustBindingsAPI"
|
||||
chroma_api_impl: str = "chromadb.api.rust.RustBindingsAPI"
|
||||
|
||||
@validator("chroma_server_nofile", pre=True, always=True, allow_reuse=True)
|
||||
def empty_str_to_none(cls, v: str) -> Optional[str]:
|
||||
if type(v) is str and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
chroma_server_nofile: Optional[int] = None
|
||||
# the number of maximum threads to handle synchronous tasks in the FastAPI server
|
||||
chroma_server_thread_pool_size: int = 40
|
||||
|
||||
# ==================
|
||||
# Client-mode config
|
||||
# ==================
|
||||
|
||||
tenant_id: str = "default"
|
||||
topic_namespace: str = "default"
|
||||
|
||||
chroma_server_host: Optional[str] = None
|
||||
chroma_server_headers: Optional[Dict[str, str]] = None
|
||||
chroma_server_http_port: Optional[int] = None
|
||||
chroma_server_ssl_enabled: Optional[bool] = False
|
||||
|
||||
chroma_server_ssl_verify: Optional[Union[bool, str]] = None
|
||||
chroma_server_api_default_path: Optional[APIVersion] = APIVersion.V2
|
||||
# eg ["http://localhost:8000"]
|
||||
chroma_server_cors_allow_origins: List[str] = []
|
||||
|
||||
# ==================
|
||||
# Server config
|
||||
# ==================
|
||||
|
||||
is_persistent: bool = False
|
||||
persist_directory: str = "./chroma"
|
||||
|
||||
chroma_memory_limit_bytes: int = 0
|
||||
chroma_segment_cache_policy: Optional[str] = None
|
||||
|
||||
allow_reset: bool = False
|
||||
|
||||
# ===========================
|
||||
# {Client, Server} auth{n, z}
|
||||
# ===========================
|
||||
|
||||
# The header to use for the token. Defaults to "Authorization".
|
||||
chroma_auth_token_transport_header: Optional[str] = None
|
||||
|
||||
# ================
|
||||
# Client auth{n,z}
|
||||
# ================
|
||||
|
||||
# The provider for client auth. See chromadb/auth/__init__.py
|
||||
chroma_client_auth_provider: Optional[str] = None
|
||||
# If needed by the provider (e.g. BasicAuthClientProvider),
|
||||
# the credentials to use.
|
||||
chroma_client_auth_credentials: Optional[str] = None
|
||||
|
||||
# ================
|
||||
# Server auth{n,z}
|
||||
# ================
|
||||
|
||||
chroma_server_auth_ignore_paths: Dict[str, List[str]] = {
|
||||
f"{APIVersion.V2}": ["GET"],
|
||||
f"{APIVersion.V2}/heartbeat": ["GET"],
|
||||
f"{APIVersion.V2}/version": ["GET"],
|
||||
f"{APIVersion.V1}": ["GET"],
|
||||
f"{APIVersion.V1}/heartbeat": ["GET"],
|
||||
f"{APIVersion.V1}/version": ["GET"],
|
||||
}
|
||||
# Overwrite singleton tenant and database access from the auth provider
|
||||
# if applicable. See chromadb/auth/utils/__init__.py's
|
||||
# authenticate_and_authorize_or_raise method.
|
||||
chroma_overwrite_singleton_tenant_database_access_from_auth: bool = False
|
||||
|
||||
# ============
|
||||
# Server authn
|
||||
# ============
|
||||
|
||||
chroma_server_authn_provider: Optional[str] = None
|
||||
# Only one of the below may be specified.
|
||||
chroma_server_authn_credentials: Optional[str] = None
|
||||
chroma_server_authn_credentials_file: Optional[str] = None
|
||||
|
||||
# ============
|
||||
# Server authz
|
||||
# ============
|
||||
|
||||
chroma_server_authz_provider: Optional[str] = None
|
||||
# Only one of the below may be specified.
|
||||
chroma_server_authz_config: Optional[str] = None
|
||||
chroma_server_authz_config_file: Optional[str] = None
|
||||
|
||||
# =========
|
||||
# Telemetry
|
||||
# =========
|
||||
|
||||
chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog"
|
||||
# Required for backwards compatibility
|
||||
chroma_telemetry_impl: str = chroma_product_telemetry_impl
|
||||
|
||||
anonymized_telemetry: bool = True
|
||||
|
||||
chroma_otel_collection_endpoint: Optional[str] = ""
|
||||
chroma_otel_service_name: Optional[str] = "chromadb"
|
||||
chroma_otel_collection_headers: Dict[str, str] = {}
|
||||
chroma_otel_granularity: Optional[str] = None
|
||||
|
||||
# ==========
|
||||
# Migrations
|
||||
# ==========
|
||||
|
||||
migrations: Literal["none", "validate", "apply"] = "apply"
|
||||
# You cannot change the hash_algorithm after migrations have already been
# applied; this is intended to be a first-time setup configuration.
migrations_hash_algorithm: Literal["md5", "sha256"] = "md5"
|
||||
|
||||
# ==================
|
||||
# Distributed Chroma
|
||||
# ==================
|
||||
|
||||
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
|
||||
chroma_segment_directory_routing_mode: RoutingMode = RoutingMode.ID
|
||||
|
||||
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
|
||||
worker_memberlist_name: str = "query-service-memberlist"
|
||||
|
||||
chroma_coordinator_host = "localhost"
|
||||
# TODO this is the sysdb port. Should probably rename it.
|
||||
chroma_server_grpc_port: Optional[int] = None
|
||||
chroma_sysdb_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
|
||||
|
||||
chroma_producer_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
|
||||
chroma_consumer_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
|
||||
|
||||
chroma_segment_manager_impl: str = (
|
||||
"chromadb.segment.impl.manager.local.LocalSegmentManager"
|
||||
)
|
||||
chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor"
|
||||
chroma_query_replication_factor: int = 2
|
||||
|
||||
chroma_logservice_host = "localhost"
|
||||
chroma_logservice_port = 50052
|
||||
|
||||
chroma_quota_provider_impl: Optional[str] = None
|
||||
chroma_rate_limiting_provider_impl: Optional[str] = None
|
||||
|
||||
chroma_quota_enforcer_impl: str = (
|
||||
"chromadb.quota.simple_quota_enforcer.SimpleQuotaEnforcer"
|
||||
)
|
||||
|
||||
chroma_rate_limit_enforcer_impl: str = (
|
||||
"chromadb.rate_limit.simple_rate_limit.SimpleRateLimitEnforcer"
|
||||
)
|
||||
|
||||
chroma_async_rate_limit_enforcer_impl: str = (
|
||||
"chromadb.rate_limit.simple_rate_limit.SimpleAsyncRateLimitEnforcer"
|
||||
)
|
||||
|
||||
# ==========
|
||||
# gRPC service config
|
||||
# ==========
|
||||
chroma_logservice_request_timeout_seconds: int = 3
|
||||
chroma_sysdb_request_timeout_seconds: int = 3
|
||||
chroma_query_request_timeout_seconds: int = 60
|
||||
|
||||
# ======
|
||||
# Legacy
|
||||
# ======
|
||||
|
||||
chroma_db_impl: Optional[str] = None
|
||||
chroma_collection_assignment_policy_impl: str = (
|
||||
"chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy"
|
||||
)
|
||||
|
||||
# =======
|
||||
# Methods
|
||||
# =======
|
||||
|
||||
def require(self, key: str) -> Any:
|
||||
"""Return the value of a required config key, or raise an exception if it is not
|
||||
set"""
|
||||
val = self[key]
|
||||
if val is None:
|
||||
raise ValueError(f"Missing required config value '{key}'")
|
||||
return val
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
val = getattr(self, key)
|
||||
# Error on legacy config values
|
||||
if isinstance(val, str) and val in _legacy_config_values:
|
||||
raise ValueError(LEGACY_ERROR)
|
||||
return val
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
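# --- Illustrative sketch, not part of the original file ---
# Settings is a pydantic BaseSettings, so values can come from keyword
# arguments, environment variables, or the .env file configured above. A
# minimal HTTP-client configuration using fields defined in this class might
# look like:
#
#   settings = Settings(
#       chroma_api_impl="chromadb.api.fastapi.FastAPI",
#       chroma_server_host="localhost",
#       chroma_server_http_port=8000,
#       chroma_server_ssl_enabled=False,
#   )
#   settings.require("chroma_server_host")  # -> "localhost"
#   settings.require("chroma_db_impl")      # raises ValueError (unset legacy key)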
|
||||
|
||||
T = TypeVar("T", bound="Component")
|
||||
|
||||
|
||||
class Component(ABC, EnforceOverrides):
|
||||
_dependencies: Set["Component"]
|
||||
_system: "System"
|
||||
_running: bool
|
||||
|
||||
def __init__(self, system: "System"):
|
||||
self._dependencies = set()
|
||||
self._system = system
|
||||
self._running = False
|
||||
|
||||
def require(self, type: Type[T]) -> T:
|
||||
"""Get a Component instance of the given type, and register as a dependency of
|
||||
that instance."""
|
||||
inst = self._system.instance(type)
|
||||
self._dependencies.add(inst)
|
||||
return inst
|
||||
|
||||
def dependencies(self) -> Set["Component"]:
|
||||
"""Return the full set of components this component depends on."""
|
||||
return self._dependencies
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Idempotently stop this component's execution and free all associated
|
||||
resources."""
|
||||
logger.debug(f"Stopping component {self.__class__.__name__}")
|
||||
self._running = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""Idempotently start this component's execution"""
|
||||
logger.debug(f"Starting component {self.__class__.__name__}")
|
||||
self._running = True
|
||||
|
||||
def reset_state(self) -> None:
|
||||
"""Reset this component's state to its initial blank state. Only intended to be
|
||||
called from tests."""
|
||||
logger.debug(f"Resetting component {self.__class__.__name__}")
|
||||
|
||||
|
||||
class System(Component):
|
||||
settings: Settings
|
||||
_instances: Dict[Type[Component], Component]
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
if is_thin_client:
|
||||
# The thin client is a system with only the API component
|
||||
if settings["chroma_api_impl"] not in [
|
||||
"chromadb.api.fastapi.FastAPI",
|
||||
"chromadb.api.async_fastapi.AsyncFastAPI",
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"Chroma is running in http-only client mode, and can only be run with 'chromadb.api.fastapi.FastAPI' or 'chromadb.api.async_fastapi.AsyncFastAPI' as the chroma_api_impl. \
|
||||
see https://docs.trychroma.com/guides#using-the-python-http-only-client for more information."
|
||||
)
|
||||
# Validate settings don't contain any legacy config values
|
||||
for key in _legacy_config_keys:
|
||||
if settings[key] is not None:
|
||||
raise ValueError(LEGACY_ERROR)
|
||||
|
||||
if (
|
||||
settings["chroma_segment_cache_policy"] is not None
|
||||
and settings["chroma_segment_cache_policy"] != "LRU"
|
||||
):
|
||||
logger.error(
|
||||
"Failed to set chroma_segment_cache_policy: Only LRU is available."
|
||||
)
|
||||
if settings["chroma_memory_limit_bytes"] == 0:
|
||||
logger.error(
|
||||
"Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require."
|
||||
)
|
||||
|
||||
# Apply the nofile limit if set
|
||||
if settings["chroma_server_nofile"] is not None:
|
||||
if platform.system() != "Windows":
|
||||
import resource
|
||||
|
||||
curr_soft, curr_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
desired_soft = settings["chroma_server_nofile"]
|
||||
# Validate
|
||||
if desired_soft > curr_hard:
|
||||
logging.warning(
|
||||
f"chroma_server_nofile cannot be set to a value greater than the current hard limit of {curr_hard}. Keeping soft limit at {curr_soft}"
|
||||
)
|
||||
# Apply
|
||||
elif desired_soft > curr_soft:
|
||||
try:
|
||||
resource.setrlimit(
|
||||
resource.RLIMIT_NOFILE, (desired_soft, curr_hard)
|
||||
)
|
||||
logger.info(f"Set chroma_server_nofile to {desired_soft}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to set chroma_server_nofile to {desired_soft}: {e} nofile soft limit will remain at {curr_soft}"
|
||||
)
|
||||
# Don't apply if reducing the limit
|
||||
elif desired_soft < curr_soft:
|
||||
logger.warning(
|
||||
f"chroma_server_nofile is set to {desired_soft}, but this is less than current soft limit of {curr_soft}. chroma_server_nofile will not be set."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"chroma_server_nofile is not supported on Windows. chroma_server_nofile will not be set."
|
||||
)
|
||||
|
||||
self.settings = settings
|
||||
self._instances = {}
|
||||
super().__init__(self)
|
||||
|
||||
def instance(self, type: Type[T]) -> T:
|
||||
"""Return an instance of the component type specified. If the system is running,
|
||||
the component will be started as well."""
|
||||
|
||||
if inspect.isabstract(type):
|
||||
type_fqn = get_fqn(type)
|
||||
if type_fqn not in _abstract_type_keys:
|
||||
raise ValueError(f"Cannot instantiate abstract type: {type}")
|
||||
key = _abstract_type_keys[type_fqn]
|
||||
fqn = self.settings.require(key)
|
||||
type = get_class(fqn, type)
|
||||
|
||||
if type not in self._instances:
|
||||
impl = type(self)
|
||||
self._instances[type] = impl
|
||||
if self._running:
|
||||
impl.start()
|
||||
|
||||
inst = self._instances[type]
|
||||
return cast(T, inst)
|
||||
|
||||
def components(self) -> Iterable[Component]:
|
||||
"""Return the full set of all components and their dependencies in dependency
|
||||
order."""
|
||||
sorter: TopologicalSorter[Component] = TopologicalSorter()
|
||||
for component in self._instances.values():
|
||||
sorter.add(component, *component.dependencies())
|
||||
|
||||
return sorter.static_order()
|
||||
|
||||
@override
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
for component in self.components():
|
||||
component.start()
|
||||
|
||||
@override
|
||||
def stop(self) -> None:
|
||||
super().stop()
|
||||
for component in reversed(list(self.components())):
|
||||
component.stop()
|
||||
|
||||
@override
|
||||
def reset_state(self) -> None:
|
||||
"""Reset the state of this system and all constituents in reverse dependency order"""
|
||||
if not self.settings.allow_reset:
|
||||
raise ValueError(
|
||||
"Resetting is not allowed by this configuration (to enable it, set `allow_reset` to `True` in your Settings() or include `ALLOW_RESET=TRUE` in your environment variables)"
|
||||
)
|
||||
for component in reversed(list(self.components())):
|
||||
component.reset_state()
|
||||
|
||||
|
||||
C = TypeVar("C")
|
||||
|
||||
|
||||
def get_class(fqn: str, type: Type[C]) -> Type[C]:
|
||||
"""Given a fully qualifed class name, import the module and return the class"""
|
||||
module_name, class_name = fqn.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
return cast(Type[C], cls)
|
||||
|
||||
|
||||
def get_fqn(cls: Type[object]) -> str:
|
||||
"""Given a class, return its fully qualified name"""
|
||||
return f"{cls.__module__}.{cls.__name__}"
@@ -0,0 +1,122 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Sequence, Optional, Tuple
|
||||
from uuid import UUID
|
||||
from chromadb.api.types import (
|
||||
Embeddings,
|
||||
Documents,
|
||||
IDs,
|
||||
Metadatas,
|
||||
Metadata,
|
||||
Where,
|
||||
WhereDocument,
|
||||
)
|
||||
from chromadb.config import Component
|
||||
|
||||
|
||||
class DB(Component):
|
||||
@abstractmethod
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
metadata: Optional[Metadata] = None,
|
||||
get_or_create: bool = False,
|
||||
) -> Sequence: # type: ignore
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection(self, name: str) -> Sequence: # type: ignore
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_collections(
|
||||
self, limit: Optional[int] = None, offset: Optional[int] = None
|
||||
) -> Sequence: # type: ignore
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_collections(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[Metadata] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_uuid_from_name(self, collection_name: str) -> UUID:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
collection_uuid: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas],
|
||||
documents: Optional[Documents],
|
||||
ids: List[str],
|
||||
) -> List[UUID]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
where: Optional[Where] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
collection_uuid: Optional[UUID] = None,
|
||||
ids: Optional[IDs] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
) -> Sequence: # type: ignore
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
collection_uuid: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count(self, collection_id: UUID) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
where: Optional[Where] = None,
|
||||
collection_uuid: Optional[UUID] = None,
|
||||
ids: Optional[IDs] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_nearest_neighbors(
|
||||
self,
|
||||
collection_uuid: UUID,
|
||||
where: Optional[Where] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
n_results: int = 10,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
) -> Tuple[List[List[UUID]], List[List[float]]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_ids(
|
||||
self, uuids: List[UUID], columns: Optional[List[str]] = None
|
||||
) -> Sequence: # type: ignore
|
||||
pass
|
||||
@@ -0,0 +1,180 @@
|
||||
from typing import Any, Optional, Sequence, Tuple, Type
|
||||
from types import TracebackType
|
||||
from typing_extensions import Protocol, Self, Literal
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import local
|
||||
from overrides import override, EnforceOverrides
|
||||
import pypika
|
||||
import pypika.queries
|
||||
from chromadb.config import System, Component
|
||||
from uuid import UUID
|
||||
from itertools import islice, count
|
||||
|
||||
|
||||
class Cursor(Protocol):
|
||||
"""Reifies methods we use from a DBAPI2 Cursor since DBAPI2 is not typed."""
|
||||
|
||||
def execute(self, sql: str, params: Optional[Tuple[Any, ...]] = None) -> Self:
|
||||
...
|
||||
|
||||
def executescript(self, script: str) -> Self:
|
||||
...
|
||||
|
||||
def executemany(
|
||||
self, sql: str, params: Optional[Sequence[Tuple[Any, ...]]] = None
|
||||
) -> Self:
|
||||
...
|
||||
|
||||
def fetchone(self) -> Tuple[Any, ...]:
|
||||
...
|
||||
|
||||
def fetchall(self) -> Sequence[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
|
||||
class TxWrapper(ABC, EnforceOverrides):
|
||||
"""Wrapper class for DBAPI 2.0 Connection objects, with which clients can implement transactions.
|
||||
Makes two guarantees that basic DBAPI 2.0 connections do not:
|
||||
|
||||
- __enter__ returns a Cursor object consistently (instead of a Connection like some do)
|
||||
- Always re-raises an exception if one was thrown from the body
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __enter__(self) -> Cursor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
pass
|
||||
|
||||
|
||||
class SqlDB(Component):
|
||||
"""DBAPI 2.0 interface wrapper to ensure consistent behavior between implementations"""
|
||||
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
|
||||
@abstractmethod
|
||||
def tx(self) -> TxWrapper:
|
||||
"""Return a transaction wrapper"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def querybuilder() -> Type[pypika.Query]:
|
||||
"""Return a PyPika Query builder of an appropriate subtype for this database
|
||||
implementation (see
|
||||
https://pypika.readthedocs.io/en/latest/3_advanced.html#handling-different-database-platforms)
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def parameter_format() -> str:
|
||||
"""Return the appropriate parameter format for this database implementation.
|
||||
Will be called with str.format(i) where i is the numeric index of the parameter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
|
||||
"""Convert a UUID to a value that can be passed to the DB driver"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
|
||||
"""Convert a value from the DB driver to a UUID"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def unique_constraint_error() -> Type[BaseException]:
|
||||
"""Return the exception type that the DB raises when a unique constraint is
|
||||
violated"""
|
||||
pass
|
||||
|
||||
def param(self, idx: int) -> pypika.Parameter:
|
||||
"""Return a PyPika Parameter object for the given index"""
|
||||
return pypika.Parameter(self.parameter_format().format(idx))
|
||||
|
||||
|
||||
_context = local()
|
||||
|
||||
|
||||
class ParameterValue(pypika.Parameter): # type: ignore
|
||||
"""
|
||||
Wrapper class for PyPika parameters that allows the values for Parameters
|
||||
to be expressed inline while building a query. See get_sql() for
|
||||
detailed usage information.
|
||||
"""
|
||||
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
@override
|
||||
def get_sql(self, **kwargs: Any) -> str:
|
||||
if isinstance(self.value, (list, tuple)):
|
||||
_context.values.extend(self.value)
|
||||
indexes = islice(_context.generator, len(self.value))
|
||||
placeholders = ", ".join(_context.formatstr.format(i) for i in indexes)
|
||||
val = f"({placeholders})"
|
||||
else:
|
||||
_context.values.append(self.value)
|
||||
val = _context.formatstr.format(next(_context.generator))
|
||||
|
||||
return str(val)
|
||||
|
||||
|
||||
def get_sql(
|
||||
query: pypika.queries.QueryBuilder, formatstr: str = "?"
|
||||
) -> Tuple[str, Tuple[Any, ...]]:
|
||||
"""
|
||||
Wrapper for pypika's get_sql method that allows the values for Parameters
|
||||
to be expressed inline while building a query, and that returns a tuple of the
|
||||
SQL string and parameters. This makes it easier to construct complex queries
|
||||
programmatically and automatically matches up the generated SQL with the required
|
||||
parameter vector.
|
||||
|
||||
Doing so requires using the ParameterValue class defined in this module instead
|
||||
of the base pypika.Parameter class.
|
||||
|
||||
Usage Example:
|
||||
|
||||
q = (
|
||||
pypika.Query().from_("table")
|
||||
.select("col1")
|
||||
.where("col2"==ParameterValue("foo"))
|
||||
.where("col3"==ParameterValue("bar"))
|
||||
)
|
||||
|
||||
sql, params = get_sql(q)
|
||||
|
||||
cursor.execute(sql, params)
|
||||
|
||||
Note how it is not necessary to construct the parameter vector manually: it
will always be generated with the parameter values in the same order as the
emitted SQL string.
|
||||
|
||||
The format string should match the parameter format for the database being used.
|
||||
It will be called with str.format(i) where i is the numeric index of the parameter.
|
||||
For example, PostgreSQL uses numbered parameters like `$1`, `$2`, etc., so the
format string should be `"${}"`.
|
||||
|
||||
See https://pypika.readthedocs.io/en/latest/2_tutorial.html#parametrized-queries for more
|
||||
information on parameterized queries in PyPika.
|
||||
"""
|
||||
|
||||
_context.values = []
|
||||
_context.generator = count(1)
|
||||
_context.formatstr = formatstr
|
||||
sql = query.get_sql()
|
||||
params = tuple(_context.values)
|
||||
return sql, params
|
||||
@@ -0,0 +1,558 @@
|
||||
from typing import List, Optional, Sequence, Tuple, Union, cast
|
||||
from uuid import UUID
|
||||
from overrides import overrides
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
create_collection_configuration_to_json_str,
|
||||
UpdateCollectionConfiguration,
|
||||
update_collection_configuration_to_json_str,
|
||||
CollectionMetadata,
|
||||
)
|
||||
from chromadb.api.types import Schema
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
|
||||
from chromadb.db.system import SysDB
|
||||
from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError
|
||||
from chromadb.proto.convert import (
|
||||
from_proto_collection,
|
||||
from_proto_segment,
|
||||
to_proto_update_metadata,
|
||||
to_proto_segment,
|
||||
to_proto_segment_scope,
|
||||
)
|
||||
from chromadb.proto.coordinator_pb2 import (
|
||||
CreateCollectionRequest,
|
||||
CreateDatabaseRequest,
|
||||
CreateSegmentRequest,
|
||||
CreateTenantRequest,
|
||||
CountCollectionsRequest,
|
||||
CountCollectionsResponse,
|
||||
DeleteCollectionRequest,
|
||||
DeleteDatabaseRequest,
|
||||
DeleteSegmentRequest,
|
||||
GetCollectionsRequest,
|
||||
GetCollectionsResponse,
|
||||
GetCollectionSizeRequest,
|
||||
GetCollectionSizeResponse,
|
||||
GetCollectionWithSegmentsRequest,
|
||||
GetCollectionWithSegmentsResponse,
|
||||
GetDatabaseRequest,
|
||||
GetSegmentsRequest,
|
||||
GetTenantRequest,
|
||||
ListDatabasesRequest,
|
||||
UpdateCollectionRequest,
|
||||
UpdateSegmentRequest,
|
||||
)
|
||||
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
|
||||
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
|
||||
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
from chromadb.types import (
|
||||
Collection,
|
||||
CollectionAndSegments,
|
||||
Database,
|
||||
Metadata,
|
||||
OptionalArgument,
|
||||
Segment,
|
||||
SegmentScope,
|
||||
Tenant,
|
||||
Unspecified,
|
||||
UpdateMetadata,
|
||||
)
|
||||
from google.protobuf.empty_pb2 import Empty
|
||||
import grpc
|
||||
|
||||
|
||||
class GrpcSysDB(SysDB):
|
||||
"""A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
|
||||
called the 'Coordinator'. This implementation is used by Chroma frontend servers
|
||||
to call a remote SysDB (Coordinator) service."""
|
||||
|
||||
_sys_db_stub: SysDBStub
|
||||
_channel: grpc.Channel
|
||||
_coordinator_url: str
|
||||
_coordinator_port: int
|
||||
_request_timeout_seconds: int
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._coordinator_url = system.settings.require("chroma_coordinator_host")
|
||||
# TODO: break out coordinator_port into a separate setting?
|
||||
self._coordinator_port = system.settings.require("chroma_server_grpc_port")
|
||||
self._request_timeout_seconds = system.settings.require(
|
||||
"chroma_sysdb_request_timeout_seconds"
|
||||
)
|
||||
return super().__init__(system)
|
||||
|
||||
@overrides
|
||||
def start(self) -> None:
|
||||
self._channel = grpc.insecure_channel(
|
||||
f"{self._coordinator_url}:{self._coordinator_port}",
|
||||
options=[("grpc.max_concurrent_streams", 1000)],
|
||||
)
|
||||
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
|
||||
self._channel = grpc.intercept_channel(self._channel, *interceptors)
|
||||
self._sys_db_stub = SysDBStub(self._channel) # type: ignore
|
||||
return super().start()
|
||||
|
||||
@overrides
|
||||
def stop(self) -> None:
|
||||
self._channel.close()
|
||||
return super().stop()
|
||||
|
||||
@overrides
|
||||
def reset_state(self) -> None:
|
||||
self._sys_db_stub.ResetState(Empty())
|
||||
return super().reset_state()
|
||||
|
||||
@overrides
|
||||
def create_database(
|
||||
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
|
||||
) -> None:
|
||||
try:
|
||||
request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
|
||||
response = self._sys_db_stub.CreateDatabase(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to create database name {name} and database id {id} for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
try:
|
||||
request = GetDatabaseRequest(name=name, tenant=tenant)
|
||||
response = self._sys_db_stub.GetDatabase(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
return Database(
|
||||
id=UUID(hex=response.database.id),
|
||||
name=response.database.name,
|
||||
tenant=response.database.tenant,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to get database {name} for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
try:
|
||||
request = DeleteDatabaseRequest(name=name, tenant=tenant)
|
||||
self._sys_db_stub.DeleteDatabase(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to delete database {name} for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
try:
|
||||
request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
|
||||
response = self._sys_db_stub.ListDatabases(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
results: List[Database] = []
|
||||
for proto_database in response.databases:
|
||||
results.append(
|
||||
Database(
|
||||
id=UUID(hex=proto_database.id),
|
||||
name=proto_database.name,
|
||||
tenant=proto_database.tenant,
|
||||
)
|
||||
)
|
||||
return results
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to list databases for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def create_tenant(self, name: str) -> None:
|
||||
try:
|
||||
request = CreateTenantRequest(name=name)
|
||||
response = self._sys_db_stub.CreateTenant(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(f"Failed to create tenant {name} due to error: {e}")
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
try:
|
||||
request = GetTenantRequest(name=name)
|
||||
response = self._sys_db_stub.GetTenant(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
return Tenant(
|
||||
name=response.tenant.name,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(f"Failed to get tenant {name} due to error: {e}")
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def create_segment(self, segment: Segment) -> None:
|
||||
try:
|
||||
proto_segment = to_proto_segment(segment)
|
||||
request = CreateSegmentRequest(
|
||||
segment=proto_segment,
|
||||
)
|
||||
response = self._sys_db_stub.CreateSegment(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(f"Failed to create segment {segment}, error: {e}")
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def delete_segment(self, collection: UUID, id: UUID) -> None:
|
||||
try:
|
||||
request = DeleteSegmentRequest(
|
||||
id=id.hex,
|
||||
collection=collection.hex,
|
||||
)
|
||||
response = self._sys_db_stub.DeleteSegment(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to delete segment with id {id} for collection {collection} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_segments(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: Optional[UUID] = None,
|
||||
type: Optional[str] = None,
|
||||
scope: Optional[SegmentScope] = None,
|
||||
) -> Sequence[Segment]:
|
||||
try:
|
||||
request = GetSegmentsRequest(
|
||||
id=id.hex if id else None,
|
||||
type=type,
|
||||
scope=to_proto_segment_scope(scope) if scope else None,
|
||||
collection=collection.hex,
|
||||
)
|
||||
response = self._sys_db_stub.GetSegments(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
results: List[Segment] = []
|
||||
for proto_segment in response.segments:
|
||||
segment = from_proto_segment(proto_segment)
|
||||
results.append(segment)
|
||||
return results
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to get segment id {id}, type {type}, scope {scope} for collection {collection} due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def update_segment(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: UUID,
|
||||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
|
||||
) -> None:
|
||||
try:
|
||||
write_metadata = None
|
||||
if metadata != Unspecified():
|
||||
write_metadata = cast(Union[UpdateMetadata, None], metadata)
|
||||
|
||||
request = UpdateSegmentRequest(
|
||||
id=id.hex,
|
||||
collection=collection.hex,
|
||||
metadata=to_proto_update_metadata(write_metadata)
|
||||
if write_metadata
|
||||
else None,
|
||||
)
|
||||
|
||||
if metadata is None:
|
||||
request.ClearField("metadata")
|
||||
request.reset_metadata = True
|
||||
|
||||
self._sys_db_stub.UpdateSegment(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to update segment with id {id} for collection {collection}, error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def create_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
name: str,
|
||||
schema: Optional[Schema],
|
||||
configuration: CreateCollectionConfiguration,
|
||||
segments: Sequence[Segment],
|
||||
metadata: Optional[Metadata] = None,
|
||||
dimension: Optional[int] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple[Collection, bool]:
|
||||
try:
|
||||
request = CreateCollectionRequest(
|
||||
id=id.hex,
|
||||
name=name,
|
||||
configuration_json_str=create_collection_configuration_to_json_str(
|
||||
configuration, cast(CollectionMetadata, metadata)
|
||||
),
|
||||
metadata=to_proto_update_metadata(metadata) if metadata else None,
|
||||
dimension=dimension,
|
||||
get_or_create=get_or_create,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
segments=[to_proto_segment(segment) for segment in segments],
|
||||
)
|
||||
response = self._sys_db_stub.CreateCollection(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
collection = from_proto_collection(response.collection)
|
||||
return collection, response.created
|
||||
except grpc.RpcError as e:
|
||||
logger.error(
|
||||
f"Failed to create collection id {id}, name {name} for database {database} and tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def delete_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
try:
|
||||
request = DeleteCollectionRequest(
|
||||
id=id.hex,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
response = self._sys_db_stub.DeleteCollection(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.error(
|
||||
f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
|
||||
)
|
||||
e = cast(grpc.Call, e)
|
||||
logger.error(
|
||||
f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_collections(
|
||||
self,
|
||||
id: Optional[UUID] = None,
|
||||
name: Optional[str] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> Sequence[Collection]:
|
||||
try:
|
||||
# TODO: implement limit and offset in the gRPC service
|
||||
request = None
|
||||
if id is not None:
|
||||
request = GetCollectionsRequest(
|
||||
id=id.hex,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if name is not None:
|
||||
if tenant is None and database is None:
|
||||
raise ValueError(
|
||||
"If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
|
||||
)
|
||||
request = GetCollectionsRequest(
|
||||
name=name,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if id is None and name is None:
|
||||
request = GetCollectionsRequest(
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
response: GetCollectionsResponse = self._sys_db_stub.GetCollections(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
results: List[Collection] = []
|
||||
for collection in response.collections:
|
||||
results.append(from_proto_collection(collection))
|
||||
return results
|
||||
except grpc.RpcError as e:
|
||||
logger.error(
|
||||
f"Failed to get collections with id {id}, name {name}, tenant {tenant}, database {database} due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def count_collections(
|
||||
self,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: Optional[str] = None,
|
||||
) -> int:
|
||||
try:
|
||||
if database is None or database == "":
|
||||
request = CountCollectionsRequest(tenant=tenant)
|
||||
response: CountCollectionsResponse = self._sys_db_stub.CountCollections(
|
||||
request
|
||||
)
|
||||
return response.count
|
||||
else:
|
||||
request = CountCollectionsRequest(
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
response: CountCollectionsResponse = self._sys_db_stub.CountCollections(
|
||||
request
|
||||
)
|
||||
return response.count
|
||||
except grpc.RpcError as e:
|
||||
logger.error(f"Failed to count collections due to error: {e}")
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_collection_size(self, id: UUID) -> int:
|
||||
try:
|
||||
request = GetCollectionSizeRequest(id=id.hex)
|
||||
response: GetCollectionSizeResponse = self._sys_db_stub.GetCollectionSize(
|
||||
request
|
||||
)
|
||||
return response.total_records_post_compaction
|
||||
except grpc.RpcError as e:
|
||||
logger.error(f"Failed to get collection {id} size due to error: {e}")
|
||||
raise InternalError()
|
||||
|
||||
@trace_method(
|
||||
"SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
@overrides
|
||||
def get_collection_with_segments(
|
||||
self, collection_id: UUID
|
||||
) -> CollectionAndSegments:
|
||||
try:
|
||||
request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
|
||||
response: GetCollectionWithSegmentsResponse = (
|
||||
self._sys_db_stub.GetCollectionWithSegments(request)
|
||||
)
|
||||
return CollectionAndSegments(
|
||||
collection=from_proto_collection(response.collection),
|
||||
segments=[from_proto_segment(segment) for segment in response.segments],
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
logger.error(
|
||||
f"Failed to get collection {collection_id} and its segments due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def update_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
name: OptionalArgument[str] = Unspecified(),
|
||||
dimension: OptionalArgument[Optional[int]] = Unspecified(),
|
||||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
|
||||
configuration: OptionalArgument[
|
||||
Optional[UpdateCollectionConfiguration]
|
||||
] = Unspecified(),
|
||||
) -> None:
|
||||
try:
|
||||
write_name = None
|
||||
if name != Unspecified():
|
||||
write_name = cast(str, name)
|
||||
|
||||
write_dimension = None
|
||||
if dimension != Unspecified():
|
||||
write_dimension = cast(Union[int, None], dimension)
|
||||
|
||||
write_metadata = None
|
||||
if metadata != Unspecified():
|
||||
write_metadata = cast(Union[UpdateMetadata, None], metadata)
|
||||
|
||||
write_configuration = None
|
||||
if configuration != Unspecified():
|
||||
write_configuration = cast(
|
||||
Union[UpdateCollectionConfiguration, None], configuration
|
||||
)
|
||||
|
||||
request = UpdateCollectionRequest(
|
||||
id=id.hex,
|
||||
name=write_name,
|
||||
dimension=write_dimension,
|
||||
metadata=to_proto_update_metadata(write_metadata)
|
||||
if write_metadata
|
||||
else None,
|
||||
configuration_json_str=update_collection_configuration_to_json_str(
|
||||
write_configuration
|
||||
)
|
||||
if write_configuration
|
||||
else None,
|
||||
)
|
||||
if metadata is None:
|
||||
request.ClearField("metadata")
|
||||
request.reset_metadata = True
|
||||
|
||||
response = self._sys_db_stub.UpdateCollection(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
e = cast(grpc.Call, e)
|
||||
logger.error(
|
||||
f"Failed to update collection id {id}, name {name} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
def reset_and_wait_for_ready(self) -> None:
|
||||
self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)
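The except-blocks above all apply the same translation rule: gRPC NOT_FOUND maps to NotFoundError, ALREADY_EXISTS maps to UniqueConstraintError, and everything else maps to InternalError. A standalone sketch of that mapping, shown only to make the pattern explicit (the helper name is hypothetical and not part of the class):

import grpc

from chromadb.errors import InternalError, NotFoundError, UniqueConstraintError


def translate_rpc_error(e: grpc.RpcError) -> Exception:
    # Mirrors the error handling in the GrpcSysDB methods above.
    if e.code() == grpc.StatusCode.NOT_FOUND:
        return NotFoundError()
    if e.code() == grpc.StatusCode.ALREADY_EXISTS:
        return UniqueConstraintError()
    return InternalError()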
|
||||
@@ -0,0 +1,497 @@
|
||||
from concurrent import futures
|
||||
from typing import Any, Dict, List, cast
|
||||
from uuid import UUID
|
||||
from overrides import overrides
|
||||
import json
|
||||
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
|
||||
from chromadb.proto.convert import (
|
||||
from_proto_metadata,
|
||||
from_proto_update_metadata,
|
||||
from_proto_segment,
|
||||
from_proto_segment_scope,
|
||||
to_proto_collection,
|
||||
to_proto_segment,
|
||||
)
|
||||
import chromadb.proto.chroma_pb2 as proto
|
||||
from chromadb.proto.coordinator_pb2 import (
|
||||
CreateCollectionRequest,
|
||||
CreateCollectionResponse,
|
||||
CreateDatabaseRequest,
|
||||
CreateDatabaseResponse,
|
||||
CreateSegmentRequest,
|
||||
CreateSegmentResponse,
|
||||
CreateTenantRequest,
|
||||
CreateTenantResponse,
|
||||
CountCollectionsRequest,
|
||||
CountCollectionsResponse,
|
||||
DeleteCollectionRequest,
|
||||
DeleteCollectionResponse,
|
||||
DeleteSegmentRequest,
|
||||
DeleteSegmentResponse,
|
||||
GetCollectionsRequest,
|
||||
GetCollectionsResponse,
|
||||
GetCollectionSizeRequest,
|
||||
GetCollectionSizeResponse,
|
||||
GetCollectionWithSegmentsRequest,
|
||||
GetCollectionWithSegmentsResponse,
|
||||
GetDatabaseRequest,
|
||||
GetDatabaseResponse,
|
||||
GetSegmentsRequest,
|
||||
GetSegmentsResponse,
|
||||
GetTenantRequest,
|
||||
GetTenantResponse,
|
||||
ResetStateResponse,
|
||||
UpdateCollectionRequest,
|
||||
UpdateCollectionResponse,
|
||||
UpdateSegmentRequest,
|
||||
UpdateSegmentResponse,
|
||||
)
|
||||
from chromadb.proto.coordinator_pb2_grpc import (
|
||||
SysDBServicer,
|
||||
add_SysDBServicer_to_server,
|
||||
)
|
||||
import grpc
|
||||
from google.protobuf.empty_pb2 import Empty
|
||||
from chromadb.types import Collection, Metadata, Segment, SegmentScope
|
||||
|
||||
|
||||
class GrpcMockSysDB(SysDBServicer, Component):
|
||||
"""A mock sysdb implementation that can be used for testing the grpc client. It stores
|
||||
state in simple python data structures instead of a database."""
|
||||
|
||||
_server: grpc.Server
|
||||
_server_port: int
|
||||
_segments: Dict[str, Segment] = {}
|
||||
_collection_to_segments: Dict[str, List[str]] = {}
|
||||
_tenants_to_databases_to_collections: Dict[
|
||||
str, Dict[str, Dict[str, Collection]]
|
||||
] = {}
|
||||
_tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {}
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._server_port = system.settings.require("chroma_server_grpc_port")
|
||||
return super().__init__(system)
|
||||
|
||||
@overrides
|
||||
def start(self) -> None:
|
||||
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
add_SysDBServicer_to_server(self, self._server) # type: ignore
|
||||
self._server.add_insecure_port(f"[::]:{self._server_port}")
|
||||
self._server.start()
|
||||
return super().start()
|
||||
|
||||
@overrides
|
||||
def stop(self) -> None:
|
||||
self._server.stop(None)
|
||||
return super().stop()
|
||||
|
||||
@overrides
|
||||
def reset_state(self) -> None:
|
||||
self._segments = {}
|
||||
self._tenants_to_databases_to_collections = {}
|
||||
# Create defaults
|
||||
self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {}
|
||||
self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {}
|
||||
self._tenants_to_database_to_id[DEFAULT_TENANT] = {}
|
||||
self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0)
|
||||
return super().reset_state()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def CreateDatabase(
|
||||
self, request: CreateDatabaseRequest, context: grpc.ServicerContext
|
||||
) -> CreateDatabaseResponse:
|
||||
tenant = request.tenant
|
||||
database = request.name
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
if database in self._tenants_to_databases_to_collections[tenant]:
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS, f"Database {database} already exists"
|
||||
)
|
||||
self._tenants_to_databases_to_collections[tenant][database] = {}
|
||||
self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
|
||||
return CreateDatabaseResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetDatabase(
|
||||
self, request: GetDatabaseRequest, context: grpc.ServicerContext
|
||||
) -> GetDatabaseResponse:
|
||||
tenant = request.tenant
|
||||
database = request.name
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
if database not in self._tenants_to_databases_to_collections[tenant]:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
|
||||
id = self._tenants_to_database_to_id[tenant][database]
|
||||
return GetDatabaseResponse(
|
||||
database=proto.Database(id=id.hex, name=database, tenant=tenant),
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def CreateTenant(
|
||||
self, request: CreateTenantRequest, context: grpc.ServicerContext
|
||||
) -> CreateTenantResponse:
|
||||
tenant = request.name
|
||||
if tenant in self._tenants_to_databases_to_collections:
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS, f"Tenant {tenant} already exists"
|
||||
)
|
||||
self._tenants_to_databases_to_collections[tenant] = {}
|
||||
self._tenants_to_database_to_id[tenant] = {}
|
||||
return CreateTenantResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetTenant(
|
||||
self, request: GetTenantRequest, context: grpc.ServicerContext
|
||||
) -> GetTenantResponse:
|
||||
tenant = request.name
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
return GetTenantResponse(
|
||||
tenant=proto.Tenant(name=tenant),
|
||||
)
|
||||
|
||||
# We are forced to use check_signature=False because the generated proto code
|
||||
# does not have type annotations for the request and response objects.
|
||||
# TODO: investigate generating types for the request and response objects
|
||||
@overrides(check_signature=False)
|
||||
def CreateSegment(
|
||||
self, request: CreateSegmentRequest, context: grpc.ServicerContext
|
||||
) -> CreateSegmentResponse:
|
||||
segment = from_proto_segment(request.segment)
|
||||
return self.CreateSegmentHelper(segment, context)
|
||||
|
||||
def CreateSegmentHelper(
|
||||
self, segment: Segment, context: grpc.ServicerContext
|
||||
) -> CreateSegmentResponse:
|
||||
if segment["id"].hex in self._segments:
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS,
|
||||
f"Segment {segment['id']} already exists",
|
||||
)
|
||||
self._segments[segment["id"].hex] = segment
|
||||
return CreateSegmentResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def DeleteSegment(
|
||||
self, request: DeleteSegmentRequest, context: grpc.ServicerContext
|
||||
) -> DeleteSegmentResponse:
|
||||
id_to_delete = request.id
|
||||
if id_to_delete in self._segments:
|
||||
del self._segments[id_to_delete]
|
||||
return DeleteSegmentResponse()
|
||||
else:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Segment {id_to_delete} not found"
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetSegments(
|
||||
self, request: GetSegmentsRequest, context: grpc.ServicerContext
|
||||
) -> GetSegmentsResponse:
|
||||
target_id = UUID(hex=request.id) if request.HasField("id") else None
|
||||
target_type = request.type if request.HasField("type") else None
|
||||
target_scope = (
|
||||
from_proto_segment_scope(request.scope)
|
||||
if request.HasField("scope")
|
||||
else None
|
||||
)
|
||||
target_collection = UUID(hex=request.collection)
|
||||
|
||||
found_segments = []
|
||||
for segment in self._segments.values():
|
||||
if target_id and segment["id"] != target_id:
|
||||
continue
|
||||
if target_type and segment["type"] != target_type:
|
||||
continue
|
||||
if target_scope and segment["scope"] != target_scope:
|
||||
continue
|
||||
if target_collection and segment["collection"] != target_collection:
|
||||
continue
|
||||
found_segments.append(segment)
|
||||
return GetSegmentsResponse(
|
||||
segments=[to_proto_segment(segment) for segment in found_segments]
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def UpdateSegment(
|
||||
self, request: UpdateSegmentRequest, context: grpc.ServicerContext
|
||||
) -> UpdateSegmentResponse:
|
||||
id_to_update = UUID(request.id)
|
||||
if id_to_update.hex not in self._segments:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Segment {id_to_update} not found"
|
||||
)
|
||||
else:
|
||||
segment = self._segments[id_to_update.hex]
|
||||
if request.HasField("metadata"):
|
||||
target = cast(Dict[str, Any], segment["metadata"])
|
||||
if segment["metadata"] is None:
|
||||
segment["metadata"] = {}
|
||||
self._merge_metadata(target, request.metadata)
|
||||
if request.HasField("reset_metadata") and request.reset_metadata:
|
||||
segment["metadata"] = {}
|
||||
return UpdateSegmentResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def CreateCollection(
|
||||
self, request: CreateCollectionRequest, context: grpc.ServicerContext
|
||||
) -> CreateCollectionResponse:
|
||||
collection_name = request.name
|
||||
tenant = request.tenant
|
||||
database = request.database
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
if database not in self._tenants_to_databases_to_collections[tenant]:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
|
||||
|
||||
# Check if the collection already exists globally by id
|
||||
for (
|
||||
search_tenant,
|
||||
databases,
|
||||
) in self._tenants_to_databases_to_collections.items():
|
||||
for search_database, search_collections in databases.items():
|
||||
if request.id in search_collections:
|
||||
if (
|
||||
search_tenant != request.tenant
|
||||
or search_database != request.database
|
||||
):
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS,
|
||||
f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
|
||||
)
|
||||
elif not request.get_or_create:
|
||||
# If the id exists for this tenant and database, and we are not doing a get_or_create, then
|
||||
# we should return an already exists error
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS,
|
||||
f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
|
||||
)
|
||||
|
||||
# Check if the collection already exists in this database by name
|
||||
collections = self._tenants_to_databases_to_collections[tenant][database]
|
||||
matches = [c for c in collections.values() if c["name"] == collection_name]
|
||||
assert len(matches) <= 1
|
||||
if len(matches) > 0:
|
||||
if request.get_or_create:
|
||||
existing_collection = matches[0]
|
||||
return CreateCollectionResponse(
|
||||
collection=to_proto_collection(existing_collection),
|
||||
created=False,
|
||||
)
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS,
|
||||
f"Collection {collection_name} already exists",
|
||||
)
|
||||
|
||||
configuration_json = json.loads(request.configuration_json_str)
|
||||
|
||||
id = UUID(hex=request.id)
|
||||
new_collection = Collection(
|
||||
id=id,
|
||||
name=request.name,
|
||||
configuration_json=configuration_json,
|
||||
serialized_schema=None,
|
||||
metadata=from_proto_metadata(request.metadata),
|
||||
dimension=request.dimension,
|
||||
database=database,
|
||||
tenant=tenant,
|
||||
version=0,
|
||||
)
|
||||
|
||||
# Check that segments are unique and do not already exist
|
||||
# Keep a track of the segments that are being added
|
||||
segments_added = []
|
||||
# Create segments for the collection
|
||||
for segment_proto in request.segments:
|
||||
segment = from_proto_segment(segment_proto)
|
||||
if segment["id"].hex in self._segments:
|
||||
# Remove the already added segment since we need to roll back
|
||||
for s in segments_added:
|
||||
self.DeleteSegment(DeleteSegmentRequest(id=s), context)
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS,
|
||||
f"Segment {segment['id']} already exists",
|
||||
)
|
||||
self.CreateSegmentHelper(segment, context)
|
||||
segments_added.append(segment["id"].hex)
|
||||
|
||||
collections[request.id] = new_collection
|
||||
collection_unique_key = f"{tenant}:{database}:{request.id}"
|
||||
self._collection_to_segments[collection_unique_key] = segments_added
|
||||
return CreateCollectionResponse(
|
||||
collection=to_proto_collection(new_collection),
|
||||
created=True,
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def DeleteCollection(
|
||||
self, request: DeleteCollectionRequest, context: grpc.ServicerContext
|
||||
) -> DeleteCollectionResponse:
|
||||
collection_id = request.id
|
||||
tenant = request.tenant
|
||||
database = request.database
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
if database not in self._tenants_to_databases_to_collections[tenant]:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
|
||||
collections = self._tenants_to_databases_to_collections[tenant][database]
|
||||
if collection_id in collections:
|
||||
del collections[collection_id]
|
||||
collection_unique_key = f"{tenant}:{database}:{collection_id}"
|
||||
segment_ids = self._collection_to_segments[collection_unique_key]
|
||||
if segment_ids: # Delete segments if provided.
|
||||
for segment_id in segment_ids:
|
||||
del self._segments[segment_id]
|
||||
return DeleteCollectionResponse()
|
||||
else:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Collection {collection_id} not found"
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetCollections(
|
||||
self, request: GetCollectionsRequest, context: grpc.ServicerContext
|
||||
) -> GetCollectionsResponse:
|
||||
target_id = UUID(hex=request.id) if request.HasField("id") else None
|
||||
target_name = request.name if request.HasField("name") else None
|
||||
|
||||
allCollections = {}
|
||||
for tenant, databases in self._tenants_to_databases_to_collections.items():
|
||||
for database, collections in databases.items():
|
||||
if request.tenant != "" and tenant != request.tenant:
|
||||
continue
|
||||
if request.database != "" and database != request.database:
|
||||
continue
|
||||
allCollections.update(collections)
|
||||
print(
|
||||
f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
|
||||
)
|
||||
found_collections = []
|
||||
for collection in allCollections.values():
|
||||
if target_id and collection["id"] != target_id:
|
||||
continue
|
||||
if target_name and collection["name"] != target_name:
|
||||
continue
|
||||
found_collections.append(collection)
|
||||
return GetCollectionsResponse(
|
||||
collections=[
|
||||
to_proto_collection(collection) for collection in found_collections
|
||||
]
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def CountCollections(
|
||||
self, request: CountCollectionsRequest, context: grpc.ServicerContext
|
||||
) -> CountCollectionsResponse:
|
||||
request = GetCollectionsRequest(
|
||||
tenant=request.tenant,
|
||||
database=request.database,
|
||||
)
|
||||
collections = self.GetCollections(request, context)
|
||||
return CountCollectionsResponse(count=len(collections.collections))
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetCollectionSize(
|
||||
self, request: GetCollectionSizeRequest, context: grpc.ServicerContext
|
||||
) -> GetCollectionSizeResponse:
|
||||
return GetCollectionSizeResponse(
|
||||
total_records_post_compaction=0,
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetCollectionWithSegments(
|
||||
self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
|
||||
) -> GetCollectionWithSegmentsResponse:
|
||||
allCollections = {}
|
||||
for tenant, databases in self._tenants_to_databases_to_collections.items():
|
||||
for database, collections in databases.items():
|
||||
allCollections.update(collections)
|
||||
print(
|
||||
f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
|
||||
)
|
||||
collection = allCollections.get(request.id, None)
|
||||
if collection is None:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found"
|
||||
)
|
||||
collection_unique_key = (
|
||||
f"{collection.tenant}:{collection.database}:{request.id}"
|
||||
)
|
||||
segments = [
|
||||
self._segments[id]
|
||||
for id in self._collection_to_segments[collection_unique_key]
|
||||
]
|
||||
if {segment["scope"] for segment in segments} != {
|
||||
SegmentScope.METADATA,
|
||||
SegmentScope.RECORD,
|
||||
SegmentScope.VECTOR,
|
||||
}:
|
||||
context.abort(
|
||||
grpc.StatusCode.INTERNAL,
|
||||
f"Incomplete segments for collection {collection}: {segments}",
|
||||
)
|
||||
|
||||
return GetCollectionWithSegmentsResponse(
|
||||
collection=to_proto_collection(collection),
|
||||
segments=[to_proto_segment(segment) for segment in segments],
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def UpdateCollection(
|
||||
self, request: UpdateCollectionRequest, context: grpc.ServicerContext
|
||||
) -> UpdateCollectionResponse:
|
||||
id_to_update = UUID(request.id)
|
||||
# Find the collection with this id
|
||||
collections = {}
|
||||
for tenant, databases in self._tenants_to_databases_to_collections.items():
|
||||
for database, maybe_collections in databases.items():
|
||||
if id_to_update.hex in maybe_collections:
|
||||
collections = maybe_collections
|
||||
|
||||
if id_to_update.hex not in collections:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Collection {id_to_update} not found"
|
||||
)
|
||||
else:
|
||||
collection = collections[id_to_update.hex]
|
||||
if request.HasField("name"):
|
||||
collection["name"] = request.name
|
||||
if request.HasField("dimension"):
|
||||
collection["dimension"] = request.dimension
|
||||
if request.HasField("metadata"):
|
||||
# TODO: In the SQLite SysDB we have technical debt where we
|
||||
# replace the entire metadata dict with the new one. We should
|
||||
# fix that by merging it. For now we just do the same thing here
|
||||
|
||||
update_metadata = from_proto_update_metadata(request.metadata)
|
||||
cleaned_metadata = None
|
||||
if update_metadata is not None:
|
||||
cleaned_metadata = {}
|
||||
for key, value in update_metadata.items():
|
||||
if value is not None:
|
||||
cleaned_metadata[key] = value
|
||||
|
||||
collection["metadata"] = cleaned_metadata
|
||||
elif request.HasField("reset_metadata"):
|
||||
if request.reset_metadata:
|
||||
collection["metadata"] = {}
|
||||
|
||||
return UpdateCollectionResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def ResetState(
|
||||
self, request: Empty, context: grpc.ServicerContext
|
||||
) -> ResetStateResponse:
|
||||
self.reset_state()
|
||||
return ResetStateResponse()
|
||||
|
||||
def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
|
||||
target_metadata = cast(Dict[str, Any], target)
|
||||
source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source))
|
||||
target_metadata.update(source_metadata)
|
||||
# If a key has a None value, remove it from the metadata
|
||||
for key, value in source_metadata.items():
|
||||
if value is None and key in target:
|
||||
del target_metadata[key]
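The merge rule implemented by _merge_metadata is: copy every key from the update into the target, then drop any key whose updated value is None. The same semantics with plain dicts:

target = {"a": 1, "b": 2}
source = {"b": 3, "c": None}

target.update(source)  # {"a": 1, "b": 3, "c": None}
for key, value in source.items():
    if value is None and key in target:
        del target[key]  # a None value means "delete this key"

assert target == {"a": 1, "b": 3}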
|
||||
@@ -0,0 +1,273 @@
|
||||
import logging
|
||||
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
|
||||
from chromadb.db.migrations import MigratableDB, Migration
|
||||
from chromadb.config import System, Settings
|
||||
import chromadb.db.base as base
|
||||
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
|
||||
from chromadb.db.mixins.sysdb import SqlSysDB
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
import sqlite3
|
||||
from overrides import override
|
||||
import pypika
|
||||
from typing import Sequence, cast, Optional, Type, Any
|
||||
from typing_extensions import Literal
|
||||
from types import TracebackType
|
||||
import os
|
||||
from uuid import UUID
|
||||
from threading import local
|
||||
from importlib_resources import files
|
||||
from importlib_resources.abc import Traversable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TxWrapper(base.TxWrapper):
|
||||
_conn: Connection
|
||||
_pool: Pool
|
||||
|
||||
def __init__(self, conn_pool: Pool, stack: local):
|
||||
self._tx_stack = stack
|
||||
self._conn = conn_pool.connect()
|
||||
self._pool = conn_pool
|
||||
|
||||
@override
|
||||
def __enter__(self) -> base.Cursor:
|
||||
if len(self._tx_stack.stack) == 0:
|
||||
self._conn.execute("PRAGMA case_sensitive_like = ON")
|
||||
self._conn.execute("BEGIN;")
|
||||
self._tx_stack.stack.append(self)
|
||||
return self._conn.cursor() # type: ignore
|
||||
|
||||
@override
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
self._tx_stack.stack.pop()
|
||||
if len(self._tx_stack.stack) == 0:
|
||||
if exc_type is None:
|
||||
self._conn.commit()
|
||||
else:
|
||||
self._conn.rollback()
|
||||
self._conn.cursor().close()
|
||||
self._pool.return_to_pool(self._conn)
|
||||
return False
|
||||
|
||||
|
||||
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
|
||||
_conn_pool: Pool
|
||||
_settings: Settings
|
||||
_migration_imports: Sequence[Traversable]
|
||||
_db_file: str
|
||||
_tx_stack: local
|
||||
_is_persistent: bool
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._settings = system.settings
|
||||
self._migration_imports = [
|
||||
files("chromadb.migrations.embeddings_queue"),
|
||||
files("chromadb.migrations.sysdb"),
|
||||
files("chromadb.migrations.metadb"),
|
||||
]
|
||||
self._is_persistent = self._settings.require("is_persistent")
|
||||
self._opentelemetry_client = system.require(OpenTelemetryClient)
|
||||
if not self._is_persistent:
|
||||
# In order to allow sqlite to be shared between multiple threads, we need to use a
|
||||
# URI connection string with shared cache.
|
||||
# See https://www.sqlite.org/sharedcache.html
|
||||
# https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
|
||||
self._db_file = "file::memory:?cache=shared"
|
||||
self._conn_pool = LockPool(self._db_file, is_uri=True)
|
||||
else:
|
||||
self._db_file = (
|
||||
self._settings.require("persist_directory") + "/chroma.sqlite3"
|
||||
)
|
||||
if not os.path.exists(self._db_file):
|
||||
os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
|
||||
self._conn_pool = PerThreadPool(self._db_file)
|
||||
self._tx_stack = local()
|
||||
super().__init__(system)
|
||||
|
||||
@trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
with self.tx() as cur:
|
||||
cur.execute("PRAGMA foreign_keys = ON")
|
||||
cur.execute("PRAGMA case_sensitive_like = ON")
|
||||
self.initialize_migrations()
|
||||
|
||||
if (
|
||||
# (don't attempt to access .config if migrations haven't been run)
|
||||
self._settings.require("migrations") == "apply"
|
||||
and self.config.get_parameter("automatically_purge").value is False
|
||||
):
|
||||
logger.warning(
|
||||
"⚠️ It looks like you upgraded from a version below 0.5.6 and could benefit from vacuuming your database. Run chromadb utils vacuum --help for more information."
|
||||
)
|
||||
|
||||
@trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def stop(self) -> None:
|
||||
super().stop()
|
||||
self._conn_pool.close()
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def querybuilder() -> Type[pypika.Query]:
|
||||
return pypika.Query # type: ignore
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def parameter_format() -> str:
|
||||
return "?"
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def migration_scope() -> str:
|
||||
return "sqlite"
|
||||
|
||||
@override
|
||||
def migration_dirs(self) -> Sequence[Traversable]:
|
||||
return self._migration_imports
|
||||
|
||||
@override
|
||||
def tx(self) -> TxWrapper:
|
||||
if not hasattr(self._tx_stack, "stack"):
|
||||
self._tx_stack.stack = []
|
||||
return TxWrapper(self._conn_pool, stack=self._tx_stack)
|
||||
|
||||
@trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def reset_state(self) -> None:
|
||||
if not self._settings.require("allow_reset"):
|
||||
raise ValueError(
|
||||
"Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
|
||||
)
|
||||
with self.tx() as cur:
|
||||
# Drop all tables
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table'
|
||||
"""
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
|
||||
self._conn_pool.close()
|
||||
self.start()
|
||||
super().reset_state()
|
||||
|
||||
@trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def setup_migrations(self) -> None:
|
||||
with self.tx() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
dir TEXT NOT NULL,
|
||||
version INTEGER NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
sql TEXT NOT NULL,
|
||||
hash TEXT NOT NULL,
|
||||
PRIMARY KEY (dir, version)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
@trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def migrations_initialized(self) -> bool:
|
||||
with self.tx() as cur:
|
||||
cur.execute(
|
||||
"""SELECT count(*) FROM sqlite_master
|
||||
WHERE type='table' AND name='migrations'"""
|
||||
)
|
||||
|
||||
if cur.fetchone()[0] == 0:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
|
||||
with self.tx() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT dir, version, filename, sql, hash
|
||||
FROM migrations
|
||||
WHERE dir = ?
|
||||
ORDER BY version ASC
|
||||
""",
|
||||
(dir.name,),
|
||||
)
|
||||
|
||||
migrations = []
|
||||
for row in cur.fetchall():
|
||||
found_dir = cast(str, row[0])
|
||||
found_version = cast(int, row[1])
|
||||
found_filename = cast(str, row[2])
|
||||
found_sql = cast(str, row[3])
|
||||
found_hash = cast(str, row[4])
|
||||
migrations.append(
|
||||
Migration(
|
||||
dir=found_dir,
|
||||
version=found_version,
|
||||
filename=found_filename,
|
||||
sql=found_sql,
|
||||
hash=found_hash,
|
||||
scope=self.migration_scope(),
|
||||
)
|
||||
)
|
||||
return migrations
|
||||
|
||||
@override
|
||||
def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
|
||||
cur.executescript(migration["sql"])
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO migrations (dir, version, filename, sql, hash)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
migration["dir"],
|
||||
migration["version"],
|
||||
migration["filename"],
|
||||
migration["sql"],
|
||||
migration["hash"],
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
|
||||
return UUID(value) if value is not None else None
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
|
||||
return str(uuid) if uuid is not None else None
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def unique_constraint_error() -> Type[BaseException]:
|
||||
return sqlite3.IntegrityError
|
||||
|
||||
def vacuum(self, timeout: int = 5) -> None:
|
||||
"""Runs VACUUM on the database. `timeout` is the maximum time to wait for an exclusive lock in seconds."""
|
||||
conn = self._conn_pool.connect()
|
||||
conn.execute(f"PRAGMA busy_timeout = {int(timeout) * 1000}")
|
||||
conn.execute("VACUUM")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO maintenance_log (operation, timestamp)
|
||||
VALUES ('vacuum', CURRENT_TIMESTAMP)
|
||||
"""
|
||||
)
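Because TxWrapper only issues BEGIN for the outermost wrapper on a thread and only commits or rolls back when that outermost wrapper exits, nested tx() calls share a single transaction. A minimal sketch, assuming db is a started SqliteDB instance obtained from a System:

with db.tx() as cur:  # outermost wrapper: BEGIN
    cur.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)")
    with db.tx() as inner:  # nested wrapper reuses the open transaction
        inner.execute("INSERT INTO t (x) VALUES (1)")
    # nothing has been committed yet
# the commit happens here, when the outermost wrapper exits without an exception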
|
||||
@@ -0,0 +1,163 @@
|
||||
import sqlite3
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Set
|
||||
import threading
|
||||
from overrides import override
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class Connection:
|
||||
"""A threadpool connection that returns itself to the pool on close()"""
|
||||
|
||||
_pool: "Pool"
|
||||
_db_file: str
|
||||
_conn: sqlite3.Connection
|
||||
|
||||
def __init__(
|
||||
self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
|
||||
):
|
||||
self._pool = pool
|
||||
self._db_file = db_file
|
||||
self._conn = sqlite3.connect(
|
||||
db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs
|
||||
) # type: ignore
|
||||
self._conn.isolation_level = None # Handle commits explicitly
|
||||
|
||||
def execute(self, sql: str, parameters=...) -> sqlite3.Cursor: # type: ignore
|
||||
if parameters is ...:
|
||||
return self._conn.execute(sql)
|
||||
return self._conn.execute(sql, parameters)
|
||||
|
||||
def commit(self) -> None:
|
||||
self._conn.commit()
|
||||
|
||||
def rollback(self) -> None:
|
||||
self._conn.rollback()
|
||||
|
||||
def cursor(self) -> sqlite3.Cursor:
|
||||
return self._conn.cursor()
|
||||
|
||||
def close_actual(self) -> None:
|
||||
"""Actually closes the connection to the db"""
|
||||
self._conn.close()
|
||||
|
||||
|
||||
class Pool(ABC):
|
||||
"""Abstract base class for a pool of connections to a sqlite database."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, db_file: str, is_uri: bool) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def connect(self, *args: Any, **kwargs: Any) -> Connection:
|
||||
"""Return a connection from the pool."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def return_to_pool(self, conn: Connection) -> None:
|
||||
"""Return a connection to the pool."""
|
||||
pass
|
||||
|
||||
|
||||
class LockPool(Pool):
|
||||
"""A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time.
|
||||
This is used because sqlite does not support multithreaded access with connection timeouts when using the
|
||||
shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
|
||||
"""
|
||||
|
||||
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
|
||||
_lock: threading.RLock
|
||||
_connection: threading.local
|
||||
_db_file: str
|
||||
_is_uri: bool
|
||||
|
||||
def __init__(self, db_file: str, is_uri: bool = False):
|
||||
self._connections = set()
|
||||
self._connection = threading.local()
|
||||
self._lock = threading.RLock()
|
||||
self._db_file = db_file
|
||||
self._is_uri = is_uri
|
||||
|
||||
@override
|
||||
def connect(self, *args: Any, **kwargs: Any) -> Connection:
|
||||
self._lock.acquire()
|
||||
if hasattr(self._connection, "conn") and self._connection.conn is not None:
|
||||
return self._connection.conn # type: ignore # cast doesn't work here for some reason
|
||||
else:
|
||||
new_connection = Connection(
|
||||
self, self._db_file, self._is_uri, *args, **kwargs
|
||||
)
|
||||
self._connection.conn = new_connection
|
||||
self._connections.add(weakref.ref(new_connection))
|
||||
return new_connection
|
||||
|
||||
@override
|
||||
def return_to_pool(self, conn: Connection) -> None:
|
||||
try:
|
||||
self._lock.release()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
for conn in self._connections:
|
||||
if conn() is not None:
|
||||
conn().close_actual() # type: ignore
|
||||
self._connections.clear()
|
||||
self._connection = threading.local()
|
||||
try:
|
||||
self._lock.release()
|
||||
except RuntimeError:
|
||||
pass
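LockPool is what SqliteDB uses above for the in-memory, shared-cache database: connect() takes the pool's re-entrant lock and hands back the calling thread's single connection, and return_to_pool() releases the lock. A minimal usage sketch:

pool = LockPool("file::memory:?cache=shared", is_uri=True)

conn = pool.connect()  # acquires the lock
conn.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)")
conn.execute("INSERT INTO t (x) VALUES (1)")
conn.commit()
pool.return_to_pool(conn)  # releases the lock so another thread can connect

pool.close()  # closes every connection the pool has handed out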
|
||||
|
||||
|
||||
class PerThreadPool(Pool):
|
||||
"""Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be
|
||||
extended to do so and block on connect() if the cap is reached.
|
||||
"""
|
||||
|
||||
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
|
||||
_lock: threading.Lock
|
||||
_connection: threading.local
|
||||
_db_file: str
|
||||
_is_uri: bool
|
||||
|
||||
def __init__(self, db_file: str, is_uri: bool = False):
|
||||
self._connections = set()
|
||||
self._connection = threading.local()
|
||||
self._lock = threading.Lock()
|
||||
self._db_file = db_file
|
||||
self._is_uri = is_uri
|
||||
|
||||
@override
|
||||
def connect(self, *args: Any, **kwargs: Any) -> Connection:
|
||||
if hasattr(self._connection, "conn") and self._connection.conn is not None:
|
||||
return self._connection.conn # type: ignore # cast doesn't work here for some reason
|
||||
else:
|
||||
new_connection = Connection(
|
||||
self, self._db_file, self._is_uri, *args, **kwargs
|
||||
)
|
||||
self._connection.conn = new_connection
|
||||
with self._lock:
|
||||
self._connections.add(weakref.ref(new_connection))
|
||||
return new_connection
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
for conn in self._connections:
|
||||
if conn() is not None:
|
||||
conn().close_actual() # type: ignore
|
||||
self._connections.clear()
|
||||
self._connection = threading.local()
|
||||
|
||||
@override
|
||||
def return_to_pool(self, conn: Connection) -> None:
|
||||
pass # Each thread gets its own connection, so we don't need to return it to the pool
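PerThreadPool, by contrast, gives every thread its own connection to the same file and never blocks; SqliteDB uses it in persistent mode. A small sketch (the database path is only an example):

import os
import tempfile
import threading

db_file = os.path.join(tempfile.gettempdir(), "per_thread_pool_example.sqlite3")
pool = PerThreadPool(db_file)


def worker() -> None:
    conn = pool.connect()  # lazily creates this thread's connection
    conn.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)")
    conn.execute("INSERT INTO t (x) VALUES (1)")
    conn.commit()
    pool.return_to_pool(conn)  # no-op: connections stay with their thread


threads = [threading.Thread(target=worker) for _ in range(4)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
pool.close()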
|
||||
@@ -0,0 +1,276 @@
|
||||
import sys
|
||||
from typing import Sequence
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
from importlib_resources.abc import Traversable
|
||||
import re
|
||||
import hashlib
|
||||
from chromadb.db.base import SqlDB, Cursor
|
||||
from abc import abstractmethod
|
||||
from chromadb.config import System, Settings
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
|
||||
class MigrationFile(TypedDict):
|
||||
path: NotRequired[Traversable]
|
||||
dir: str
|
||||
filename: str
|
||||
version: int
|
||||
scope: str
|
||||
|
||||
|
||||
class Migration(MigrationFile):
|
||||
hash: str
|
||||
sql: str
|
||||
|
||||
|
||||
class UninitializedMigrationsError(Exception):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Migrations have not been initialized")
|
||||
|
||||
|
||||
class UnappliedMigrationsError(Exception):
|
||||
def __init__(self, dir: str, version: int):
|
||||
self.dir = dir
|
||||
self.version = version
|
||||
super().__init__(
|
||||
f"Unapplied migrations in {dir}, starting with version {version}"
|
||||
)
|
||||
|
||||
|
||||
class InconsistentVersionError(Exception):
|
||||
def __init__(self, dir: str, db_version: int, source_version: int):
|
||||
super().__init__(
|
||||
f"Inconsistent migration versions in {dir}:"
|
||||
+ f" db version was {db_version}, source version was {source_version}."
|
||||
+ " Has the migration sequence been modified since being applied to the DB?"
|
||||
)
|
||||
|
||||
|
||||
class InconsistentHashError(Exception):
|
||||
def __init__(self, path: str, db_hash: str, source_hash: str):
|
||||
super().__init__(
|
||||
f"Inconsistent hashes in {path}:"
|
||||
+ f" db hash was {db_hash}, source hash was {source_hash}."
|
||||
+ " Was the migration file modified after being applied to the DB?"
|
||||
)
|
||||
|
||||
|
||||
class InvalidHashError(Exception):
|
||||
def __init__(self, alg: str):
|
||||
super().__init__(f"Invalid hash algorithm specified: {alg}")
|
||||
|
||||
|
||||
class InvalidMigrationFilename(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MigratableDB(SqlDB):
|
||||
"""Simple base class for databases which support basic migrations.
|
||||
|
||||
Migrations are SQL files stored as package resources and accessed via
|
||||
importlib_resources.
|
||||
|
||||
All migrations in the same directory are assumed to be dependent on previous
|
||||
migrations in the same directory, where "previous" is defined by lexicographic
|
||||
ordering of filenames.
|
||||
|
||||
Migrations have an ascending numeric version number and a hash of the file contents.
|
||||
When migrations are applied, the hashes of previous migrations are checked to ensure
|
||||
that the database is consistent with the source repository. If they are not, an
|
||||
error is thrown and no migrations will be applied.
|
||||
|
||||
Migration files must follow the naming convention:
|
||||
<version>-<description>.<scope>.sql, where <version> is a 5-digit zero-padded
|
||||
integer, <description> is a short textual description, and <scope> is a short string
|
||||
identifying the database implementation.
|
||||
"""
|
||||
|
||||
_settings: Settings
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
self._settings = system.settings
|
||||
self._opentelemetry_client = system.require(OpenTelemetryClient)
|
||||
super().__init__(system)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def migration_scope() -> str:
|
||||
"""The database implementation to use for migrations (e.g, sqlite, pgsql)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def migration_dirs(self) -> Sequence[Traversable]:
|
||||
"""Directories containing the migration sequences that should be applied to this
|
||||
DB."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def setup_migrations(self) -> None:
|
||||
"""Idempotently creates the migrations table"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def migrations_initialized(self) -> bool:
|
||||
"""Return true if the migrations table exists"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
|
||||
"""Return a list of all migrations already applied to this database, from the
|
||||
given source directory, in ascending order."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply_migration(self, cur: Cursor, migration: Migration) -> None:
|
||||
"""Apply a single migration to the database"""
|
||||
pass
|
||||
|
||||
def initialize_migrations(self) -> None:
|
||||
"""Initialize migrations for this DB"""
|
||||
migrate = self._settings.require("migrations")
|
||||
|
||||
if migrate == "validate":
|
||||
self.validate_migrations()
|
||||
|
||||
if migrate == "apply":
|
||||
self.apply_migrations()
|
||||
|
||||
@trace_method("MigratableDB.validate_migrations", OpenTelemetryGranularity.ALL)
|
||||
def validate_migrations(self) -> None:
|
||||
"""Validate all migrations and throw an exception if there are any unapplied
|
||||
migrations in the source repo."""
|
||||
if not self.migrations_initialized():
|
||||
raise UninitializedMigrationsError()
|
||||
for dir in self.migration_dirs():
|
||||
db_migrations = self.db_migrations(dir)
|
||||
source_migrations = find_migrations(
|
||||
dir,
|
||||
self.migration_scope(),
|
||||
self._settings.require("migrations_hash_algorithm"),
|
||||
)
|
||||
unapplied_migrations = verify_migration_sequence(
|
||||
db_migrations, source_migrations
|
||||
)
|
||||
if len(unapplied_migrations) > 0:
|
||||
version = unapplied_migrations[0]["version"]
|
||||
raise UnappliedMigrationsError(dir=dir.name, version=version)
|
||||
|
||||
@trace_method("MigratableDB.apply_migrations", OpenTelemetryGranularity.ALL)
|
||||
def apply_migrations(self) -> None:
|
||||
"""Validate existing migrations, and apply all new ones."""
|
||||
self.setup_migrations()
|
||||
for dir in self.migration_dirs():
|
||||
db_migrations = self.db_migrations(dir)
|
||||
source_migrations = find_migrations(
|
||||
dir,
|
||||
self.migration_scope(),
|
||||
self._settings.require("migrations_hash_algorithm"),
|
||||
)
|
||||
unapplied_migrations = verify_migration_sequence(
|
||||
db_migrations, source_migrations
|
||||
)
|
||||
with self.tx() as cur:
|
||||
for migration in unapplied_migrations:
|
||||
self.apply_migration(cur, migration)
|
||||
|
||||
|
||||
# Format is <version>-<name>.<scope>.sql
|
||||
# e.g, 00001-users.sqlite.sql
|
||||
filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql")
|
||||
|
||||
|
||||
def _parse_migration_filename(
|
||||
dir: str, filename: str, path: Traversable
|
||||
) -> MigrationFile:
|
||||
"""Parse a migration filename into a MigrationFile object"""
|
||||
match = filename_regex.match(filename)
|
||||
if match is None:
|
||||
raise InvalidMigrationFilename("Invalid migration filename: " + filename)
|
||||
version, _, scope = match.groups()
|
||||
return {
|
||||
"path": path,
|
||||
"dir": dir,
|
||||
"filename": filename,
|
||||
"version": int(version),
|
||||
"scope": scope,
|
||||
}
|
||||
|
||||
|
||||
def verify_migration_sequence(
|
||||
db_migrations: Sequence[Migration],
|
||||
source_migrations: Sequence[Migration],
|
||||
) -> Sequence[Migration]:
|
||||
"""Given a list of migrations already applied to a database, and a list of
|
||||
migrations from the source code, validate that the applied migrations are correct
|
||||
and match the expected migrations.
|
||||
|
||||
Throws an exception if any migrations are missing, out of order, or if the source
|
||||
hash does not match.
|
||||
|
||||
Returns a list of all unapplied migrations, or an empty list if all migrations are
|
||||
applied and the database is up to date."""
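# e.g. with db versions [1, 2] and source versions [1, 2, 3] this returns only
# version 3; any version or hash mismatch in the shared prefix raises instead.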
|
||||
|
||||
for db_migration, source_migration in zip(db_migrations, source_migrations):
|
||||
if db_migration["version"] != source_migration["version"]:
|
||||
raise InconsistentVersionError(
|
||||
dir=db_migration["dir"],
|
||||
db_version=db_migration["version"],
|
||||
source_version=source_migration["version"],
|
||||
)
|
||||
|
||||
if db_migration["hash"] != source_migration["hash"]:
|
||||
raise InconsistentHashError(
|
||||
path=db_migration["dir"] + "/" + db_migration["filename"],
|
||||
db_hash=db_migration["hash"],
|
||||
source_hash=source_migration["hash"],
|
||||
)
|
||||
|
||||
return source_migrations[len(db_migrations) :]
|
||||
|
||||
|
||||
def find_migrations(
|
||||
dir: Traversable, scope: str, hash_alg: str = "md5"
|
||||
) -> Sequence[Migration]:
|
||||
"""Return a list of all migration present in the given directory, in ascending
|
||||
order. Filter by scope."""
|
||||
files = [
|
||||
_parse_migration_filename(dir.name, t.name, t)
|
||||
for t in dir.iterdir()
|
||||
if t.name.endswith(".sql")
|
||||
]
|
||||
files = list(filter(lambda f: f["scope"] == scope, files))
|
||||
files = sorted(files, key=lambda f: f["version"])
|
||||
return [_read_migration_file(f, hash_alg) for f in files]
|
||||
|
||||
|
||||
def _read_migration_file(file: MigrationFile, hash_alg: str) -> Migration:
|
||||
"""Read a migration file"""
|
||||
if "path" not in file or not file["path"].is_file():
|
||||
raise FileNotFoundError(
|
||||
f"No migration file found for dir {file['dir']} with filename {file['filename']} and scope {file['scope']} at version {file['version']}"
|
||||
)
|
||||
sql = file["path"].read_text()
|
||||
|
||||
if hash_alg == "md5":
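# hashlib.md5 gained the usedforsecurity flag in Python 3.9; older
# interpreters fall back to the plain call below.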
|
||||
hash = (
|
||||
hashlib.md5(sql.encode("utf-8"), usedforsecurity=False).hexdigest()
|
||||
if sys.version_info >= (3, 9)
|
||||
else hashlib.md5(sql.encode("utf-8")).hexdigest()
|
||||
)
|
||||
elif hash_alg == "sha256":
|
||||
hash = hashlib.sha256(sql.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
raise InvalidHashError(alg=hash_alg)
|
||||
|
||||
return {
|
||||
"hash": hash,
|
||||
"sql": sql,
|
||||
"dir": file["dir"],
|
||||
"filename": file["filename"],
|
||||
"version": file["version"],
|
||||
"scope": file["scope"],
|
||||
}
|
||||
@@ -0,0 +1,507 @@
|
||||
from functools import cached_property
|
||||
import json
|
||||
from chromadb.api.configuration import (
|
||||
ConfigurationParameter,
|
||||
EmbeddingsQueueConfigurationInternal,
|
||||
)
|
||||
from chromadb.db.base import SqlDB, ParameterValue, get_sql
|
||||
from chromadb.errors import BatchSizeExceededError
|
||||
from chromadb.ingest import (
|
||||
Producer,
|
||||
Consumer,
|
||||
ConsumerCallbackFn,
|
||||
decode_vector,
|
||||
encode_vector,
|
||||
)
|
||||
from chromadb.types import (
|
||||
OperationRecord,
|
||||
LogRecord,
|
||||
ScalarEncoding,
|
||||
SeqId,
|
||||
Operation,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
from overrides import override
|
||||
from collections import defaultdict
|
||||
from typing import Sequence, Optional, Dict, Set, Tuple, cast
|
||||
from uuid import UUID
|
||||
from pypika import Table, functions
|
||||
import uuid
|
||||
import logging
|
||||
from chromadb.ingest.impl.utils import create_topic_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_operation_codes = {
|
||||
Operation.ADD: 0,
|
||||
Operation.UPDATE: 1,
|
||||
Operation.UPSERT: 2,
|
||||
Operation.DELETE: 3,
|
||||
}
|
||||
_operation_codes_inv = {v: k for k, v in _operation_codes.items()}
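# e.g. _operation_codes_inv[0] is Operation.ADD; used when decoding WAL rows in _backfill.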
|
||||
|
||||
# Set in conftest.py to rethrow errors in the "async" path during testing
|
||||
# https://doc.pytest.org/en/latest/example/simple.html#detect-if-running-from-within-a-pytest-run
|
||||
_called_from_test = False
|
||||
|
||||
|
||||
class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
||||
"""A SQL database that stores embeddings, allowing a traditional RDBMS to be used as
|
||||
the primary ingest queue and satisfying the top level Producer/Consumer interfaces.
|
||||
|
||||
Note that this class is only suitable for use cases where the producer and consumer
|
||||
are in the same process.
|
||||
|
||||
This is because notification of new embeddings happens solely in-process: this
|
||||
implementation does not actively listen to the database for new records added by
|
||||
other processes.
|
||||
"""
|
||||
|
||||
class Subscription:
|
||||
id: UUID
|
||||
topic_name: str
|
||||
start: int
|
||||
end: int
|
||||
callback: ConsumerCallbackFn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: UUID,
|
||||
topic_name: str,
|
||||
start: int,
|
||||
end: int,
|
||||
callback: ConsumerCallbackFn,
|
||||
):
|
||||
self.id = id
|
||||
self.topic_name = topic_name
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.callback = callback
|
||||
|
||||
_subscriptions: Dict[str, Set[Subscription]]
|
||||
_max_batch_size: Optional[int]
|
||||
_tenant: str
|
||||
_topic_namespace: str
|
||||
# How many variables are in the insert statement for a single record
|
||||
VARIABLES_PER_RECORD = 6
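# (operation, topic, id, vector, encoding, metadata) -- matches the columns
# inserted per record in submit_embeddings.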
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._subscriptions = defaultdict(set)
|
||||
self._max_batch_size = None
|
||||
self._opentelemetry_client = system.require(OpenTelemetryClient)
|
||||
self._tenant = system.settings.require("tenant_id")
|
||||
self._topic_namespace = system.settings.require("topic_namespace")
|
||||
super().__init__(system)
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.reset_state", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def reset_state(self) -> None:
|
||||
super().reset_state()
|
||||
self._subscriptions = defaultdict(set)
|
||||
|
||||
# Invalidate the cached property
|
||||
try:
|
||||
del self.config
|
||||
except AttributeError:
|
||||
# Cached property hasn't been accessed yet
|
||||
pass
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.delete_topic", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def delete_log(self, collection_id: UUID) -> None:
|
||||
topic_name = create_topic_name(
|
||||
self._tenant, self._topic_namespace, collection_id
|
||||
)
|
||||
t = Table("embeddings_queue")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.topic == ParameterValue(topic_name))
|
||||
.delete()
|
||||
)
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.purge_log", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def purge_log(self, collection_id: UUID) -> None:
|
||||
# (We need to purge on a per topic/collection basis, because the maximum sequence ID is tracked on a per topic/collection basis.)
|
||||
|
||||
segments_t = Table("segments")
|
||||
segment_ids_q = (
|
||||
self.querybuilder()
|
||||
.from_(segments_t)
|
||||
# This coalesce prevents a correctness bug when > 1 segments exist and:
|
||||
# - > 1 has written to the max_seq_id table
|
||||
# - > 1 has never written to the max_seq_id table
|
||||
# In that case, we should not delete any WAL entries as we can't be sure that all segments are caught up.
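# e.g. if two segments report max seq ids 10 and 7, rows below 7 are purged;
# if any segment lacks a max_seq_id row, the coalesced -1 keeps the whole log.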
|
||||
.select(functions.Coalesce(Table("max_seq_id").seq_id, -1))
|
||||
.where(
|
||||
segments_t.collection == ParameterValue(self.uuid_to_db(collection_id))
|
||||
)
|
||||
.left_join(Table("max_seq_id"))
|
||||
.on(segments_t.id == Table("max_seq_id").segment_id)
|
||||
)
|
||||
|
||||
topic_name = create_topic_name(
|
||||
self._tenant, self._topic_namespace, collection_id
|
||||
)
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(segment_ids_q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
results = cur.fetchall()
|
||||
if results:
|
||||
min_seq_id = min(row[0] for row in results)
|
||||
else:
|
||||
return
|
||||
|
||||
t = Table("embeddings_queue")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.seq_id < ParameterValue(min_seq_id))
|
||||
.where(t.topic == ParameterValue(topic_name))
|
||||
.delete()
|
||||
)
|
||||
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.submit_embedding", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def submit_embedding(
|
||||
self, collection_id: UUID, embedding: OperationRecord
|
||||
) -> SeqId:
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
return self.submit_embeddings(collection_id, [embedding])[0]
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.submit_embeddings", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def submit_embeddings(
|
||||
self, collection_id: UUID, embeddings: Sequence[OperationRecord]
|
||||
) -> Sequence[SeqId]:
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
if len(embeddings) == 0:
|
||||
return []
|
||||
|
||||
if len(embeddings) > self.max_batch_size:
|
||||
raise BatchSizeExceededError(
|
||||
f"""
|
||||
Cannot submit more than {self.max_batch_size:,} embeddings at once.
|
||||
Please submit your embeddings in batches of size
|
||||
{self.max_batch_size:,} or less.
|
||||
"""
|
||||
)
|
||||
|
||||
# This creates the persisted configuration if it doesn't exist.
|
||||
# It should be run as soon as possible (before any WAL mutations) since the default configuration depends on the WAL size.
|
||||
# (We can't run this in __init__()/start() because the migrations have not been run at that point and the table may not be available.)
|
||||
_ = self.config
|
||||
|
||||
topic_name = create_topic_name(
|
||||
self._tenant, self._topic_namespace, collection_id
|
||||
)
|
||||
|
||||
t = Table("embeddings_queue")
|
||||
insert = (
|
||||
self.querybuilder()
|
||||
.into(t)
|
||||
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
|
||||
)
|
||||
id_to_idx: Dict[str, int] = {}
|
||||
for embedding in embeddings:
|
||||
(
|
||||
embedding_bytes,
|
||||
encoding,
|
||||
metadata,
|
||||
) = self._prepare_vector_encoding_metadata(embedding)
|
||||
insert = insert.insert(
|
||||
ParameterValue(_operation_codes[embedding["operation"]]),
|
||||
ParameterValue(topic_name),
|
||||
ParameterValue(embedding["id"]),
|
||||
ParameterValue(embedding_bytes),
|
||||
ParameterValue(encoding),
|
||||
ParameterValue(metadata),
|
||||
)
|
||||
id_to_idx[embedding["id"]] = len(id_to_idx)
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(insert, self.parameter_format())
|
||||
# The returning clause does not guarantee order, so we need to reorder
|
||||
# the results. https://www.sqlite.org/lang_returning.html
|
||||
sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING
|
||||
results = cur.execute(sql, params).fetchall()
|
||||
# Reorder the results
|
||||
seq_ids = [cast(SeqId, None)] * len(
|
||||
results
|
||||
) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list
|
||||
embedding_records = []
|
||||
for seq_id, id in results:
|
||||
seq_ids[id_to_idx[id]] = seq_id
|
||||
submit_embedding_record = embeddings[id_to_idx[id]]
|
||||
# We allow notifying consumers out of order relative to one call to
|
||||
# submit_embeddings so we do not reorder the records before submitting them
|
||||
embedding_record = LogRecord(
|
||||
log_offset=seq_id,
|
||||
record=OperationRecord(
|
||||
id=id,
|
||||
embedding=submit_embedding_record["embedding"],
|
||||
encoding=submit_embedding_record["encoding"],
|
||||
metadata=submit_embedding_record["metadata"],
|
||||
operation=submit_embedding_record["operation"],
|
||||
),
|
||||
)
|
||||
embedding_records.append(embedding_record)
|
||||
self._notify_all(topic_name, embedding_records)
|
||||
|
||||
if self.config.get_parameter("automatically_purge").value:
|
||||
self.purge_log(collection_id)
|
||||
|
||||
return seq_ids
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.subscribe", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def subscribe(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
consume_fn: ConsumerCallbackFn,
|
||||
start: Optional[SeqId] = None,
|
||||
end: Optional[SeqId] = None,
|
||||
id: Optional[UUID] = None,
|
||||
) -> UUID:
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
topic_name = create_topic_name(
|
||||
self._tenant, self._topic_namespace, collection_id
|
||||
)
|
||||
|
||||
subscription_id = id or uuid.uuid4()
|
||||
start, end = self._validate_range(start, end)
|
||||
|
||||
subscription = self.Subscription(
|
||||
subscription_id, topic_name, start, end, consume_fn
|
||||
)
|
||||
|
||||
# Backfill first, so if it errors we do not add the subscription
|
||||
self._backfill(subscription)
|
||||
self._subscriptions[topic_name].add(subscription)
|
||||
|
||||
return subscription_id
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue.unsubscribe", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def unsubscribe(self, subscription_id: UUID) -> None:
|
||||
for topic_name, subscriptions in self._subscriptions.items():
|
||||
for subscription in subscriptions:
|
||||
if subscription.id == subscription_id:
|
||||
subscriptions.remove(subscription)
|
||||
if len(subscriptions) == 0:
|
||||
del self._subscriptions[topic_name]
|
||||
return
|
||||
|
||||
@override
|
||||
def min_seqid(self) -> SeqId:
|
||||
return -1
|
||||
|
||||
@override
|
||||
def max_seqid(self) -> SeqId:
|
||||
return 2**63 - 1
|
||||
|
||||
@property
|
||||
@trace_method("SqlEmbeddingsQueue.max_batch_size", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def max_batch_size(self) -> int:
|
||||
if self._max_batch_size is None:
|
||||
with self.tx() as cur:
|
||||
cur.execute("PRAGMA compile_options;")
|
||||
compile_options = cur.fetchall()
|
||||
|
||||
for option in compile_options:
|
||||
if "MAX_VARIABLE_NUMBER" in option[0]:
|
||||
# The pragma returns a string like 'MAX_VARIABLE_NUMBER=999'
|
||||
self._max_batch_size = int(option[0].split("=")[1]) // (
|
||||
self.VARIABLES_PER_RECORD
|
||||
)
|
||||
|
||||
if self._max_batch_size is None:
|
||||
# This value is the default for sqlite3 versions < 3.32.0
|
||||
# It is the safest value to use if we can't find the pragma for some
|
||||
# reason
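# With 999 variables and VARIABLES_PER_RECORD == 6 this caps batches at 166 records.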
|
||||
self._max_batch_size = 999 // self.VARIABLES_PER_RECORD
|
||||
return self._max_batch_size
|
||||
|
||||
@trace_method(
|
||||
"SqlEmbeddingsQueue._prepare_vector_encoding_metadata",
|
||||
OpenTelemetryGranularity.ALL,
|
||||
)
|
||||
def _prepare_vector_encoding_metadata(
|
||||
self, embedding: OperationRecord
|
||||
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
|
||||
if embedding["embedding"] is not None:
|
||||
encoding_type = cast(ScalarEncoding, embedding["encoding"])
|
||||
encoding = encoding_type.value
|
||||
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
|
||||
else:
|
||||
embedding_bytes = None
|
||||
encoding = None
|
||||
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
|
||||
return embedding_bytes, encoding, metadata
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue._backfill", OpenTelemetryGranularity.ALL)
|
||||
def _backfill(self, subscription: Subscription) -> None:
|
||||
"""Backfill the given subscription with any currently matching records in the
|
||||
DB"""
|
||||
t = Table("embeddings_queue")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.topic == ParameterValue(subscription.topic_name))
|
||||
.where(t.seq_id > ParameterValue(subscription.start))
|
||||
.where(t.seq_id <= ParameterValue(subscription.end))
|
||||
.select(t.seq_id, t.operation, t.id, t.vector, t.encoding, t.metadata)
|
||||
.orderby(t.seq_id)
|
||||
)
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
for row in rows:
|
||||
if row[3]:
|
||||
encoding = ScalarEncoding(row[4])
|
||||
vector = decode_vector(row[3], encoding)
|
||||
else:
|
||||
encoding = None
|
||||
vector = None
|
||||
self._notify_one(
|
||||
subscription,
|
||||
[
|
||||
LogRecord(
|
||||
log_offset=row[0],
|
||||
record=OperationRecord(
|
||||
operation=_operation_codes_inv[row[1]],
|
||||
id=row[2],
|
||||
embedding=vector,
|
||||
encoding=encoding,
|
||||
metadata=json.loads(row[5]) if row[5] else None,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue._validate_range", OpenTelemetryGranularity.ALL)
|
||||
def _validate_range(
|
||||
self, start: Optional[SeqId], end: Optional[SeqId]
|
||||
) -> Tuple[int, int]:
|
||||
"""Validate and normalize the start and end SeqIDs for a subscription using this
|
||||
impl."""
|
||||
start = start or self._next_seq_id()
|
||||
end = end or self.max_seqid()
|
||||
if not isinstance(start, int) or not isinstance(end, int):
|
||||
raise TypeError("SeqIDs must be integers for sql-based EmbeddingsDB")
|
||||
if start >= end:
|
||||
raise ValueError(f"Invalid SeqID range: {start} to {end}")
|
||||
return start, end
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue._next_seq_id", OpenTelemetryGranularity.ALL)
|
||||
def _next_seq_id(self) -> int:
|
||||
"""Get the next SeqID for this database."""
|
||||
t = Table("embeddings_queue")
|
||||
q = self.querybuilder().from_(t).select(functions.Max(t.seq_id))
|
||||
with self.tx() as cur:
|
||||
cur.execute(q.get_sql())
|
||||
return int(cur.fetchone()[0]) + 1
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue._notify_all", OpenTelemetryGranularity.ALL)
|
||||
def _notify_all(self, topic: str, embeddings: Sequence[LogRecord]) -> None:
|
||||
"""Send a notification to each subscriber of the given topic."""
|
||||
if self._running:
|
||||
for sub in self._subscriptions[topic]:
|
||||
self._notify_one(sub, embeddings)
|
||||
|
||||
@trace_method("SqlEmbeddingsQueue._notify_one", OpenTelemetryGranularity.ALL)
|
||||
def _notify_one(self, sub: Subscription, embeddings: Sequence[LogRecord]) -> None:
|
||||
"""Send a notification to a single subscriber."""
|
||||
# Filter out any embeddings that are not in the subscription range
|
||||
should_unsubscribe = False
|
||||
filtered_embeddings = []
|
||||
for embedding in embeddings:
|
||||
if embedding["log_offset"] <= sub.start:
|
||||
continue
|
||||
if embedding["log_offset"] > sub.end:
|
||||
should_unsubscribe = True
|
||||
break
|
||||
filtered_embeddings.append(embedding)
|
||||
|
||||
# Log errors instead of throwing them to preserve async semantics
|
||||
# for consistency between local and distributed configurations
|
||||
try:
|
||||
if len(filtered_embeddings) > 0:
|
||||
sub.callback(filtered_embeddings)
|
||||
if should_unsubscribe:
|
||||
self.unsubscribe(sub.id)
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
f"Exception occurred invoking consumer for subscription {sub.id.hex}"
|
||||
+ f"to topic {sub.topic_name} %s",
|
||||
str(e),
|
||||
)
|
||||
if _called_from_test:
|
||||
raise e
|
||||
|
||||
@cached_property
|
||||
def config(self) -> EmbeddingsQueueConfigurationInternal:
|
||||
t = Table("embeddings_queue_config")
|
||||
q = self.querybuilder().from_(t).select(t.config_json_str).limit(1)
|
||||
|
||||
with self.tx() as cur:
|
||||
cur.execute(q.get_sql())
|
||||
result = cur.fetchone()
|
||||
|
||||
if result is None:
|
||||
is_fresh_system = self._get_wal_size() == 0
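# Automatic purging defaults to on only for brand-new systems (empty WAL);
# existing deployments keep their full log unless reconfigured via set_config.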
|
||||
config = EmbeddingsQueueConfigurationInternal(
|
||||
[ConfigurationParameter("automatically_purge", is_fresh_system)]
|
||||
)
|
||||
self.set_config(config)
|
||||
return config
|
||||
|
||||
return EmbeddingsQueueConfigurationInternal.from_json_str(result[0])
|
||||
|
||||
def set_config(self, config: EmbeddingsQueueConfigurationInternal) -> None:
|
||||
with self.tx() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO embeddings_queue_config (id, config_json_str)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(
|
||||
1,
|
||||
config.to_json_str(),
|
||||
),
|
||||
)
|
||||
|
||||
# Invalidate the cached property
|
||||
try:
|
||||
del self.config
|
||||
except AttributeError:
|
||||
# Cached property hasn't been accessed yet
|
||||
pass
|
||||
|
||||
def _get_wal_size(self) -> int:
|
||||
t = Table("embeddings_queue")
|
||||
q = self.querybuilder().from_(t).select(functions.Count("*"))
|
||||
|
||||
with self.tx() as cur:
|
||||
cur.execute(q.get_sql())
|
||||
return int(cur.fetchone()[0])
|
||||
@@ -0,0 +1,986 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
|
||||
from uuid import UUID
|
||||
from overrides import override
|
||||
from pypika import Table, Column
|
||||
from itertools import groupby
|
||||
|
||||
from chromadb.api.types import Schema
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
|
||||
from chromadb.db.base import Cursor, SqlDB, ParameterValue, get_sql
|
||||
from chromadb.db.system import SysDB
|
||||
from chromadb.errors import (
|
||||
NotFoundError,
|
||||
UniqueConstraintError,
|
||||
)
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
add_attributes_to_current_span,
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
from chromadb.ingest import Producer
|
||||
from chromadb.types import (
|
||||
CollectionAndSegments,
|
||||
Database,
|
||||
OptionalArgument,
|
||||
Segment,
|
||||
Metadata,
|
||||
Collection,
|
||||
SegmentScope,
|
||||
Tenant,
|
||||
Unspecified,
|
||||
UpdateMetadata,
|
||||
)
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
UpdateCollectionConfiguration,
|
||||
create_collection_configuration_to_json_str,
|
||||
load_collection_configuration_from_json_str,
|
||||
CollectionConfiguration,
|
||||
create_collection_configuration_to_json,
|
||||
collection_configuration_to_json,
|
||||
collection_configuration_to_json_str,
|
||||
overwrite_collection_configuration,
|
||||
update_collection_configuration_from_legacy_update_metadata,
|
||||
CollectionMetadata,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SqlSysDB(SqlDB, SysDB):
|
||||
# Used only to delete log streams on collection deletion.
|
||||
# TODO: refactor to remove this dependency into a separate interface
|
||||
_producer: Producer
|
||||
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
self._opentelemetry_client = system.require(OpenTelemetryClient)
|
||||
|
||||
@trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
self._producer = self._system.instance(Producer)
|
||||
|
||||
@override
|
||||
def create_database(
|
||||
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
|
||||
) -> None:
|
||||
with self.tx() as cur:
|
||||
# Get the tenant id for the tenant name and then insert the database with the id, name and tenant id
|
||||
databases = Table("databases")
|
||||
tenants = Table("tenants")
|
||||
insert_database = (
|
||||
self.querybuilder()
|
||||
.into(databases)
|
||||
.columns(databases.id, databases.name, databases.tenant_id)
|
||||
.insert(
|
||||
ParameterValue(self.uuid_to_db(id)),
|
||||
ParameterValue(name),
|
||||
self.querybuilder()
|
||||
.select(tenants.id)
|
||||
.from_(tenants)
|
||||
.where(tenants.id == ParameterValue(tenant)),
|
||||
)
|
||||
)
|
||||
sql, params = get_sql(insert_database, self.parameter_format())
|
||||
try:
|
||||
cur.execute(sql, params)
|
||||
except self.unique_constraint_error() as e:
|
||||
raise UniqueConstraintError(
|
||||
f"Database {name} already exists for tenant {tenant}"
|
||||
) from e
|
||||
|
||||
@override
|
||||
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
with self.tx() as cur:
|
||||
databases = Table("databases")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(databases)
|
||||
.select(databases.id, databases.name)
|
||||
.where(databases.name == ParameterValue(name))
|
||||
.where(databases.tenant_id == ParameterValue(tenant))
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
row = cur.execute(sql, params).fetchone()
|
||||
if not row:
|
||||
raise NotFoundError(
|
||||
f"Database {name} not found for tenant {tenant}. Are you sure it exists?"
|
||||
)
|
||||
if row[0] is None:
|
||||
raise NotFoundError(
|
||||
f"Database {name} not found for tenant {tenant}. Are you sure it exists?"
|
||||
)
|
||||
id: UUID = cast(UUID, self.uuid_from_db(row[0]))
|
||||
return Database(
|
||||
id=id,
|
||||
name=row[1],
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
@override
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
with self.tx() as cur:
|
||||
databases = Table("databases")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(databases)
|
||||
.where(databases.name == ParameterValue(name))
|
||||
.where(databases.tenant_id == ParameterValue(tenant))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
sql = sql + " RETURNING id"
|
||||
result = cur.execute(sql, params).fetchone()
|
||||
if not result:
|
||||
raise NotFoundError(f"Database {name} not found for tenant {tenant}")
|
||||
|
||||
# As of 01/09/2025, cascading deletes don't work because foreign keys are not enabled.
|
||||
# See https://github.com/chroma-core/chroma/issues/3456.
|
||||
collections = Table("collections")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(collections)
|
||||
.where(collections.database_id == ParameterValue(result[0]))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
|
||||
@override
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
with self.tx() as cur:
|
||||
databases = Table("databases")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(databases)
|
||||
.select(databases.id, databases.name)
|
||||
.where(databases.tenant_id == ParameterValue(tenant))
|
||||
.offset(offset)
|
||||
.limit(
|
||||
sys.maxsize if limit is None else limit
|
||||
) # SQLite requires that a limit is provided to use offset
|
||||
.orderby(databases.created_at)
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
rows = cur.execute(sql, params).fetchall()
|
||||
return [
|
||||
Database(
|
||||
id=cast(UUID, self.uuid_from_db(row[0])),
|
||||
name=row[1],
|
||||
tenant=tenant,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
@override
|
||||
def create_tenant(self, name: str) -> None:
|
||||
with self.tx() as cur:
|
||||
tenants = Table("tenants")
|
||||
insert_tenant = (
|
||||
self.querybuilder()
|
||||
.into(tenants)
|
||||
.columns(tenants.id)
|
||||
.insert(ParameterValue(name))
|
||||
)
|
||||
sql, params = get_sql(insert_tenant, self.parameter_format())
|
||||
try:
|
||||
cur.execute(sql, params)
|
||||
except self.unique_constraint_error() as e:
|
||||
raise UniqueConstraintError(f"Tenant {name} already exists") from e
|
||||
|
||||
@override
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
with self.tx() as cur:
|
||||
tenants = Table("tenants")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(tenants)
|
||||
.select(tenants.id)
|
||||
.where(tenants.id == ParameterValue(name))
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
row = cur.execute(sql, params).fetchone()
|
||||
if not row:
|
||||
raise NotFoundError(f"Tenant {name} not found")
|
||||
return Tenant(name=name)
|
||||
|
||||
# Create a segment using the passed cursor, so that the other changes
|
||||
# can be in the same transaction.
|
||||
def create_segment_with_tx(self, cur: Cursor, segment: Segment) -> None:
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"segment_id": str(segment["id"]),
|
||||
"segment_type": segment["type"],
|
||||
"segment_scope": segment["scope"].value,
|
||||
"collection": str(segment["collection"]),
|
||||
}
|
||||
)
|
||||
|
||||
segments = Table("segments")
|
||||
insert_segment = (
|
||||
self.querybuilder()
|
||||
.into(segments)
|
||||
.columns(
|
||||
segments.id,
|
||||
segments.type,
|
||||
segments.scope,
|
||||
segments.collection,
|
||||
)
|
||||
.insert(
|
||||
ParameterValue(self.uuid_to_db(segment["id"])),
|
||||
ParameterValue(segment["type"]),
|
||||
ParameterValue(segment["scope"].value),
|
||||
ParameterValue(self.uuid_to_db(segment["collection"])),
|
||||
)
|
||||
)
|
||||
sql, params = get_sql(insert_segment, self.parameter_format())
|
||||
try:
|
||||
cur.execute(sql, params)
|
||||
except self.unique_constraint_error() as e:
|
||||
raise UniqueConstraintError(
|
||||
f"Segment {segment['id']} already exists"
|
||||
) from e
|
||||
|
||||
# Insert segment metadata if it exists
|
||||
metadata_t = Table("segment_metadata")
|
||||
if segment["metadata"]:
|
||||
try:
|
||||
self._insert_metadata(
|
||||
cur,
|
||||
metadata_t,
|
||||
metadata_t.segment_id,
|
||||
segment["id"],
|
||||
segment["metadata"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting segment metadata: {e}")
|
||||
raise
|
||||
|
||||
# TODO(rohit): Investigate and remove this method completely.
|
||||
@trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def create_segment(self, segment: Segment) -> None:
|
||||
with self.tx() as cur:
|
||||
self.create_segment_with_tx(cur, segment)
|
||||
|
||||
@trace_method("SqlSysDB.create_collection", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def create_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
name: str,
|
||||
schema: Optional[Schema],
|
||||
configuration: CreateCollectionConfiguration,
|
||||
segments: Sequence[Segment],
|
||||
metadata: Optional[Metadata] = None,
|
||||
dimension: Optional[int] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple[Collection, bool]:
|
||||
if id is None and not get_or_create:
|
||||
raise ValueError("id must be specified if get_or_create is False")
|
||||
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"collection_id": str(id),
|
||||
"collection_name": name,
|
||||
}
|
||||
)
|
||||
|
||||
existing = self.get_collections(name=name, tenant=tenant, database=database)
|
||||
if existing:
|
||||
if get_or_create:
|
||||
collection = existing[0]
|
||||
return (
|
||||
self.get_collections(
|
||||
id=collection.id, tenant=tenant, database=database
|
||||
)[0],
|
||||
False,
|
||||
)
|
||||
else:
|
||||
raise UniqueConstraintError(f"Collection {name} already exists")
|
||||
|
||||
collection = Collection(
|
||||
id=id,
|
||||
name=name,
|
||||
configuration_json=create_collection_configuration_to_json(
|
||||
configuration, cast(CollectionMetadata, metadata)
|
||||
),
|
||||
serialized_schema=None,
|
||||
metadata=metadata,
|
||||
dimension=dimension,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
version=0,
|
||||
)
|
||||
|
||||
with self.tx() as cur:
|
||||
collections = Table("collections")
|
||||
databases = Table("databases")
|
||||
|
||||
insert_collection = (
|
||||
self.querybuilder()
|
||||
.into(collections)
|
||||
.columns(
|
||||
collections.id,
|
||||
collections.name,
|
||||
collections.config_json_str,
|
||||
collections.dimension,
|
||||
collections.database_id,
|
||||
)
|
||||
.insert(
|
||||
ParameterValue(self.uuid_to_db(collection["id"])),
|
||||
ParameterValue(collection["name"]),
|
||||
ParameterValue(
|
||||
create_collection_configuration_to_json_str(
|
||||
configuration, cast(CollectionMetadata, metadata)
|
||||
)
|
||||
),
|
||||
ParameterValue(collection["dimension"]),
|
||||
# Get the database id for the database with the given name and tenant
|
||||
self.querybuilder()
|
||||
.select(databases.id)
|
||||
.from_(databases)
|
||||
.where(databases.name == ParameterValue(database))
|
||||
.where(databases.tenant_id == ParameterValue(tenant)),
|
||||
)
|
||||
)
|
||||
sql, params = get_sql(insert_collection, self.parameter_format())
|
||||
try:
|
||||
cur.execute(sql, params)
|
||||
except self.unique_constraint_error() as e:
|
||||
raise UniqueConstraintError(
|
||||
f"Collection {collection['id']} already exists"
|
||||
) from e
|
||||
metadata_t = Table("collection_metadata")
|
||||
if collection["metadata"]:
|
||||
self._insert_metadata(
|
||||
cur,
|
||||
metadata_t,
|
||||
metadata_t.collection_id,
|
||||
collection.id,
|
||||
collection["metadata"],
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
self.create_segment_with_tx(cur, segment)
|
||||
|
||||
return collection, True
|
||||
|
||||
@trace_method("SqlSysDB.get_segments", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def get_segments(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: Optional[UUID] = None,
|
||||
type: Optional[str] = None,
|
||||
scope: Optional[SegmentScope] = None,
|
||||
) -> Sequence[Segment]:
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"segment_id": str(id),
|
||||
"segment_type": type if type else "",
|
||||
"segment_scope": scope.value if scope else "",
|
||||
"collection": str(collection),
|
||||
}
|
||||
)
|
||||
segments_t = Table("segments")
|
||||
metadata_t = Table("segment_metadata")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(segments_t)
|
||||
.select(
|
||||
segments_t.id,
|
||||
segments_t.type,
|
||||
segments_t.scope,
|
||||
segments_t.collection,
|
||||
metadata_t.key,
|
||||
metadata_t.str_value,
|
||||
metadata_t.int_value,
|
||||
metadata_t.float_value,
|
||||
metadata_t.bool_value,
|
||||
)
|
||||
.left_join(metadata_t)
|
||||
.on(segments_t.id == metadata_t.segment_id)
|
||||
.orderby(segments_t.id)
|
||||
)
|
||||
if id:
|
||||
q = q.where(segments_t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
if type:
|
||||
q = q.where(segments_t.type == ParameterValue(type))
|
||||
if scope:
|
||||
q = q.where(segments_t.scope == ParameterValue(scope.value))
|
||||
if collection:
|
||||
q = q.where(
|
||||
segments_t.collection == ParameterValue(self.uuid_to_db(collection))
|
||||
)
|
||||
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
rows = cur.execute(sql, params).fetchall()
|
||||
by_segment = groupby(rows, lambda r: cast(object, r[0]))
|
||||
segments = []
|
||||
for segment_id, segment_rows in by_segment:
|
||||
id = self.uuid_from_db(str(segment_id))
|
||||
rows = list(segment_rows)
|
||||
type = str(rows[0][1])
|
||||
scope = SegmentScope(str(rows[0][2]))
|
||||
collection = self.uuid_from_db(rows[0][3]) # type: ignore[assignment]
|
||||
metadata = self._metadata_from_rows(rows)
|
||||
segments.append(
|
||||
Segment(
|
||||
id=cast(UUID, id),
|
||||
type=type,
|
||||
scope=scope,
|
||||
collection=collection,
|
||||
metadata=metadata,
|
||||
file_paths={},
|
||||
)
|
||||
)
|
||||
|
||||
return segments
|
||||
|
||||
@trace_method("SqlSysDB.get_collections", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def get_collections(
|
||||
self,
|
||||
id: Optional[UUID] = None,
|
||||
name: Optional[str] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> Sequence[Collection]:
|
||||
"""Get collections by name, embedding function and/or metadata"""
|
||||
|
||||
if name is not None and (tenant is None or database is None):
|
||||
raise ValueError(
|
||||
"If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
|
||||
)
|
||||
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"collection_id": str(id),
|
||||
"collection_name": name if name else "",
|
||||
}
|
||||
)
|
||||
|
||||
collections_t = Table("collections")
|
||||
metadata_t = Table("collection_metadata")
|
||||
databases_t = Table("databases")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(collections_t)
|
||||
.select(
|
||||
collections_t.id,
|
||||
collections_t.name,
|
||||
collections_t.config_json_str,
|
||||
collections_t.dimension,
|
||||
databases_t.name,
|
||||
databases_t.tenant_id,
|
||||
metadata_t.key,
|
||||
metadata_t.str_value,
|
||||
metadata_t.int_value,
|
||||
metadata_t.float_value,
|
||||
metadata_t.bool_value,
|
||||
)
|
||||
.left_join(metadata_t)
|
||||
.on(collections_t.id == metadata_t.collection_id)
|
||||
.left_join(databases_t)
|
||||
.on(collections_t.database_id == databases_t.id)
|
||||
.orderby(collections_t.id)
|
||||
)
|
||||
if id:
|
||||
q = q.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
if name:
|
||||
q = q.where(collections_t.name == ParameterValue(name))
|
||||
|
||||
# Only if we have a name, tenant and database do we need to filter databases
|
||||
# Given an id, we can uniquely identify the collection so we don't need to filter databases
|
||||
if id is None and tenant and database:
|
||||
databases_t = Table("databases")
|
||||
q = q.where(
|
||||
collections_t.database_id
|
||||
== self.querybuilder()
|
||||
.select(databases_t.id)
|
||||
.from_(databases_t)
|
||||
.where(databases_t.name == ParameterValue(database))
|
||||
.where(databases_t.tenant_id == ParameterValue(tenant))
|
||||
)
|
||||
# can't set limit and offset here because this is metadata and we haven't reduced yet
|
||||
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
rows = cur.execute(sql, params).fetchall()
|
||||
by_collection = groupby(rows, lambda r: cast(object, r[0]))
|
||||
collections = []
|
||||
for collection_id, collection_rows in by_collection:
|
||||
id = self.uuid_from_db(str(collection_id))
|
||||
rows = list(collection_rows)
|
||||
name = str(rows[0][1])
|
||||
metadata = self._metadata_from_rows(rows)
|
||||
dimension = int(rows[0][3]) if rows[0][3] else None
|
||||
if rows[0][2] is not None:
|
||||
configuration = load_collection_configuration_from_json_str(
|
||||
rows[0][2]
|
||||
)
|
||||
else:
|
||||
# 07/2024: This is a legacy case where we don't have a collection
|
||||
# configuration stored in the database. This non-destructively migrates
|
||||
# the collection to have a configuration, and takes into account any
|
||||
# HNSW params that might be in the existing metadata.
|
||||
configuration = self._insert_config_from_legacy_params(
|
||||
collection_id, metadata
|
||||
)
|
||||
|
||||
collections.append(
|
||||
Collection(
|
||||
id=cast(UUID, id),
|
||||
name=name,
|
||||
configuration_json=collection_configuration_to_json(
|
||||
configuration
|
||||
),
|
||||
serialized_schema=None,
|
||||
metadata=metadata,
|
||||
dimension=dimension,
|
||||
tenant=str(rows[0][5]),
|
||||
database=str(rows[0][4]),
|
||||
version=0,
|
||||
)
|
||||
)
|
||||
|
||||
# apply limit and offset
|
||||
if limit is not None:
|
||||
if offset is None:
|
||||
offset = 0
|
||||
collections = collections[offset : offset + limit]
|
||||
else:
|
||||
collections = collections[offset:]
|
||||
|
||||
return collections
|
||||
|
||||
@override
|
||||
def get_collection_with_segments(
|
||||
self, collection_id: UUID
|
||||
) -> CollectionAndSegments:
|
||||
collections = self.get_collections(id=collection_id)
|
||||
if len(collections) == 0:
|
||||
raise NotFoundError(f"Collection {collection_id} does not exist.")
|
||||
return CollectionAndSegments(
|
||||
collection=collections[0],
|
||||
segments=self.get_segments(collection=collection_id),
|
||||
)
|
||||
|
||||
@trace_method("SqlSysDB.delete_segment", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def delete_segment(self, collection: UUID, id: UUID) -> None:
|
||||
"""Delete a segment from the SysDB"""
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"segment_id": str(id),
|
||||
}
|
||||
)
|
||||
t = Table("segments")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
.delete()
|
||||
)
|
||||
with self.tx() as cur:
|
||||
# no need for explicit del from metadata table because of ON DELETE CASCADE
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
sql = sql + " RETURNING id"
|
||||
result = cur.execute(sql, params).fetchone()
|
||||
if not result:
|
||||
raise NotFoundError(f"Segment {id} not found")
|
||||
|
||||
# Used by delete_collection to delete all segments for a collection along with
|
||||
# the collection itself in a single transaction.
|
||||
def delete_segments_for_collection(self, cur: Cursor, collection: UUID) -> None:
|
||||
segments_t = Table("segments")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(segments_t)
|
||||
.where(segments_t.collection == ParameterValue(self.uuid_to_db(collection)))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
|
||||
@trace_method("SqlSysDB.delete_collection", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def delete_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Delete a collection and all associated segments from the SysDB. Deletes
|
||||
the log stream for this collection as well."""
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"collection_id": str(id),
|
||||
}
|
||||
)
|
||||
t = Table("collections")
|
||||
databases_t = Table("databases")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
.where(
|
||||
t.database_id
|
||||
== self.querybuilder()
|
||||
.select(databases_t.id)
|
||||
.from_(databases_t)
|
||||
.where(databases_t.name == ParameterValue(database))
|
||||
.where(databases_t.tenant_id == ParameterValue(tenant))
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
with self.tx() as cur:
|
||||
# no need for explicit del from metadata table because of ON DELETE CASCADE
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
sql = sql + " RETURNING id"
|
||||
result = cur.execute(sql, params).fetchone()
|
||||
if not result:
|
||||
raise NotFoundError(f"Collection {id} not found")
|
||||
# Delete segments.
|
||||
self.delete_segments_for_collection(cur, id)
|
||||
|
||||
self._producer.delete_log(result[0])
|
||||
|
||||
@trace_method("SqlSysDB.update_segment", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def update_segment(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: UUID,
|
||||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
|
||||
) -> None:
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"segment_id": str(id),
|
||||
"collection": str(collection),
|
||||
}
|
||||
)
|
||||
segments_t = Table("segments")
|
||||
metadata_t = Table("segment_metadata")
|
||||
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.update(segments_t)
|
||||
.where(segments_t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
.set(segments_t.collection, ParameterValue(self.uuid_to_db(collection)))
|
||||
)
|
||||
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
if sql: # pypika emits a blank string if nothing to do
|
||||
cur.execute(sql, params)
|
||||
|
||||
if metadata is None:
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(metadata_t)
|
||||
.where(metadata_t.segment_id == ParameterValue(self.uuid_to_db(id)))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
elif metadata != Unspecified():
|
||||
metadata = cast(UpdateMetadata, metadata)
|
||||
self._insert_metadata(
|
||||
cur,
|
||||
metadata_t,
|
||||
metadata_t.segment_id,
|
||||
id,
|
||||
metadata,
|
||||
set(metadata.keys()),
|
||||
)
|
||||
|
||||
@trace_method("SqlSysDB.update_collection", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def update_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
name: OptionalArgument[str] = Unspecified(),
|
||||
dimension: OptionalArgument[Optional[int]] = Unspecified(),
|
||||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
|
||||
configuration: OptionalArgument[
|
||||
Optional[UpdateCollectionConfiguration]
|
||||
] = Unspecified(),
|
||||
) -> None:
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"collection_id": str(id),
|
||||
}
|
||||
)
|
||||
collections_t = Table("collections")
|
||||
metadata_t = Table("collection_metadata")
|
||||
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.update(collections_t)
|
||||
.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
)
|
||||
|
||||
if not name == Unspecified():
|
||||
q = q.set(collections_t.name, ParameterValue(name))
|
||||
|
||||
if not dimension == Unspecified():
|
||||
q = q.set(collections_t.dimension, ParameterValue(dimension))
|
||||
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
if sql: # pypika emits a blank string if nothing to do
|
||||
sql = sql + " RETURNING id"
|
||||
result = cur.execute(sql, params)
|
||||
if not result.fetchone():
|
||||
raise NotFoundError(f"Collection {id} not found")
|
||||
|
||||
# TODO: Update to use better semantics where it's possible to update
|
||||
# individual keys without wiping all the existing metadata.
|
||||
|
||||
# For now, follow current legacy semantics where metadata is fully reset
|
||||
if metadata != Unspecified():
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(metadata_t)
|
||||
.where(
|
||||
metadata_t.collection_id == ParameterValue(self.uuid_to_db(id))
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
if metadata is not None:
|
||||
metadata = cast(UpdateMetadata, metadata)
|
||||
self._insert_metadata(
|
||||
cur,
|
||||
metadata_t,
|
||||
metadata_t.collection_id,
|
||||
id,
|
||||
metadata,
|
||||
set(metadata.keys()),
|
||||
)
|
||||
|
||||
if configuration != Unspecified():
|
||||
update_configuration = cast(
|
||||
UpdateCollectionConfiguration, configuration
|
||||
)
|
||||
self._update_config_json_str(cur, update_configuration, id)
|
||||
else:
|
||||
if metadata != Unspecified():
|
||||
metadata = cast(UpdateMetadata, metadata)
|
||||
if metadata is not None:
|
||||
update_configuration = (
|
||||
update_collection_configuration_from_legacy_update_metadata(
|
||||
metadata
|
||||
)
|
||||
)
|
||||
self._update_config_json_str(cur, update_configuration, id)
|
||||
|
||||
def _update_config_json_str(
|
||||
self, cur: Cursor, update_configuration: UpdateCollectionConfiguration, id: UUID
|
||||
) -> None:
|
||||
collections_t = Table("collections")
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(collections_t)
|
||||
.select(collections_t.config_json_str)
|
||||
.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
row = cur.execute(sql, params).fetchone()
|
||||
if not row:
|
||||
raise NotFoundError(f"Collection {id} not found")
|
||||
config_json_str = row[0]
|
||||
existing_config = load_collection_configuration_from_json_str(config_json_str)
|
||||
new_config = overwrite_collection_configuration(
|
||||
existing_config, update_configuration
|
||||
)
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.update(collections_t)
|
||||
.set(
|
||||
collections_t.config_json_str,
|
||||
ParameterValue(collection_configuration_to_json_str(new_config)),
|
||||
)
|
||||
.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
|
||||
@trace_method("SqlSysDB._metadata_from_rows", OpenTelemetryGranularity.ALL)
|
||||
def _metadata_from_rows(
|
||||
self, rows: Sequence[Tuple[Any, ...]]
|
||||
) -> Optional[Metadata]:
|
||||
"""Given SQL rows, return a metadata map (assuming that the last four columns
|
||||
are the key, str_value, int_value, float_value & bool_value)"""
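# e.g. a row ending in ("colour", "blue", None, None, None) yields {"colour": "blue"}.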
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"num_rows": len(rows),
|
||||
}
|
||||
)
|
||||
metadata: Dict[str, Union[str, int, float, bool]] = {}
|
||||
for row in rows:
|
||||
key = str(row[-5])
|
||||
if row[-4] is not None:
|
||||
metadata[key] = str(row[-4])
|
||||
elif row[-3] is not None:
|
||||
metadata[key] = int(row[-3])
|
||||
elif row[-2] is not None:
|
||||
metadata[key] = float(row[-2])
|
||||
elif row[-1] is not None:
|
||||
metadata[key] = bool(row[-1])
|
||||
return metadata or None
|
||||
|
||||
@trace_method("SqlSysDB._insert_metadata", OpenTelemetryGranularity.ALL)
|
||||
def _insert_metadata(
|
||||
self,
|
||||
cur: Cursor,
|
||||
table: Table,
|
||||
id_col: Column,
|
||||
id: UUID,
|
||||
metadata: UpdateMetadata,
|
||||
clear_keys: Optional[Set[str]] = None,
|
||||
) -> None:
|
||||
# It would be cleaner to use something like ON CONFLICT UPDATE here, but that is
|
||||
# very difficult to do in a portable way (e.g. sqlite and postgres have
|
||||
# completely different syntax)
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"num_keys": len(metadata),
|
||||
}
|
||||
)
|
||||
if clear_keys:
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.from_(table)
|
||||
.where(id_col == ParameterValue(self.uuid_to_db(id)))
|
||||
.where(table.key.isin([ParameterValue(k) for k in clear_keys]))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
cur.execute(sql, params)
|
||||
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.into(table)
|
||||
.columns(
|
||||
id_col,
|
||||
table.key,
|
||||
table.str_value,
|
||||
table.int_value,
|
||||
table.float_value,
|
||||
table.bool_value,
|
||||
)
|
||||
)
|
||||
sql_id = self.uuid_to_db(id)
|
||||
for k, v in metadata.items():
|
||||
# Note: The order is important here because isinstance(v, bool)
|
||||
# and isinstance(v, int) both are true for v of bool type.
|
||||
if isinstance(v, bool):
|
||||
q = q.insert(
|
||||
ParameterValue(sql_id),
|
||||
ParameterValue(k),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
ParameterValue(int(v)),
|
||||
)
|
||||
elif isinstance(v, str):
|
||||
q = q.insert(
|
||||
ParameterValue(sql_id),
|
||||
ParameterValue(k),
|
||||
ParameterValue(v),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
elif isinstance(v, int):
|
||||
q = q.insert(
|
||||
ParameterValue(sql_id),
|
||||
ParameterValue(k),
|
||||
None,
|
||||
ParameterValue(v),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
elif isinstance(v, float):
|
||||
q = q.insert(
|
||||
ParameterValue(sql_id),
|
||||
ParameterValue(k),
|
||||
None,
|
||||
None,
|
||||
ParameterValue(v),
|
||||
None,
|
||||
)
|
||||
elif v is None:
|
||||
continue
|
||||
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
if sql:
|
||||
cur.execute(sql, params)
|
||||
|
||||
def _insert_config_from_legacy_params(
|
||||
self, collection_id: Any, metadata: Optional[Metadata]
|
||||
) -> CollectionConfiguration:
|
||||
"""Insert the configuration from legacy metadata params into the collections table, and return the configuration object."""
|
||||
|
||||
# This is a legacy case where we don't have configuration stored in the database
|
||||
# This is non-destructive, we don't delete or overwrite any keys in the metadata
|
||||
|
||||
collections_t = Table("collections")
|
||||
|
||||
create_collection_config = CreateCollectionConfiguration()
|
||||
# Write the configuration into the database
|
||||
configuration_json_str = create_collection_configuration_to_json_str(
|
||||
create_collection_config, cast(CollectionMetadata, metadata)
|
||||
)
|
||||
q = (
|
||||
self.querybuilder()
|
||||
.update(collections_t)
|
||||
.set(
|
||||
collections_t.config_json_str,
|
||||
ParameterValue(configuration_json_str),
|
||||
)
|
||||
.where(collections_t.id == ParameterValue(collection_id))
|
||||
)
|
||||
sql, params = get_sql(q, self.parameter_format())
|
||||
with self.tx() as cur:
|
||||
cur.execute(sql, params)
|
||||
return load_collection_configuration_from_json_str(configuration_json_str)
|
||||
|
||||
@override
|
||||
def get_collection_size(self, id: UUID) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def count_collections(
|
||||
self,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Gets the number of collections for the (tenant, database) combination."""
|
||||
# TODO(Sanket): Implement this efficiently using a count query.
|
||||
# Note, the underlying get_collections api always requires a database
|
||||
# to be specified. In the sysdb implementation in go code, it does not
|
||||
# filter on database if it is set to "". This is a bad API and
|
||||
# should be fixed. For now, we will replicate the behavior.
|
||||
request_database: str = "" if database is None or database == "" else database
|
||||
return len(self.get_collections(tenant=tenant, database=request_database))
|
||||
@@ -0,0 +1,189 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Sequence, Tuple
|
||||
from uuid import UUID
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
UpdateCollectionConfiguration,
|
||||
)
|
||||
from chromadb.api.types import Schema
|
||||
from chromadb.types import (
|
||||
Collection,
|
||||
CollectionAndSegments,
|
||||
Database,
|
||||
Tenant,
|
||||
Metadata,
|
||||
Segment,
|
||||
SegmentScope,
|
||||
OptionalArgument,
|
||||
Unspecified,
|
||||
UpdateMetadata,
|
||||
)
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component
|
||||
|
||||
|
||||
class SysDB(Component):
|
||||
"""Data interface for Chroma's System database"""
|
||||
|
||||
@abstractmethod
|
||||
def create_database(
|
||||
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
|
||||
) -> None:
|
||||
"""Create a new database in the System database. Raises an Error if the Database
|
||||
already exists."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
"""Get a database by name and tenant. Raises an Error if the Database does not
|
||||
exist."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
"""Delete a database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
"""List all databases for a tenant."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_tenant(self, name: str) -> None:
|
||||
"""Create a new tenant in the System database. The name must be unique.
|
||||
Raises an Error if the Tenant already exists."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
"""Get a tenant by name. Raises an Error if the Tenant does not exist."""
|
||||
pass
|
||||
|
||||
# TODO: Investigate and remove this method, as segment creation is done as
|
||||
# part of collection creation.
|
||||
@abstractmethod
|
||||
def create_segment(self, segment: Segment) -> None:
|
||||
"""Create a new segment in the System database. Raises an Error if the ID
|
||||
already exists."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_segment(self, collection: UUID, id: UUID) -> None:
|
||||
"""Delete a segment from the System database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_segments(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: Optional[UUID] = None,
|
||||
type: Optional[str] = None,
|
||||
scope: Optional[SegmentScope] = None,
|
||||
) -> Sequence[Segment]:
|
||||
"""Find segments by id, type, scope or collection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_segment(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: UUID,
|
||||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
|
||||
) -> None:
|
||||
"""Update a segment. Unspecified fields will be left unchanged. For the
|
||||
metadata, keys with None values will be removed and keys not present in the
|
||||
UpdateMetadata dict will be left unchanged."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
name: str,
|
||||
schema: Optional[Schema],
|
||||
configuration: CreateCollectionConfiguration,
|
||||
segments: Sequence[Segment],
|
||||
metadata: Optional[Metadata] = None,
|
||||
dimension: Optional[int] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Tuple[Collection, bool]:
|
||||
"""Create a new collection and associated resources
|
||||
in the SysDB. If get_or_create is True, the
|
||||
collection will be created if one with the same name does not exist.
|
||||
The metadata will be updated using the same protocol as update_collection. If get_or_create
|
||||
is False and the collection already exists, an error will be raised.
|
||||
|
||||
Returns a tuple of the created collection and a boolean indicating whether the
|
||||
collection was created or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Delete a collection, all associated segments and any associate resources (log stream)
|
||||
from the SysDB and the system at large."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collections(
|
||||
self,
|
||||
id: Optional[UUID] = None,
|
||||
name: Optional[str] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> Sequence[Collection]:
|
||||
"""Find collections by id or name. If name is provided, tenant and database must also be provided."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_collections(
|
||||
self,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Gets the number of collections for the (tenant, database) combination."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_with_segments(
|
||||
self, collection_id: UUID
|
||||
) -> CollectionAndSegments:
|
||||
"""Get a consistent snapshot of a collection by id. This will return a collection with segment
|
||||
information that matches the collection version and log position.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_collection(
|
||||
self,
|
||||
id: UUID,
|
||||
name: OptionalArgument[str] = Unspecified(),
|
||||
dimension: OptionalArgument[Optional[int]] = Unspecified(),
|
||||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
|
||||
configuration: OptionalArgument[
|
||||
Optional[UpdateCollectionConfiguration]
|
||||
] = Unspecified(),
|
||||
) -> None:
|
||||
"""Update a collection. Unspecified fields will be left unchanged. For metadata,
|
||||
keys with None values will be removed and keys not present in the UpdateMetadata
|
||||
dict will be left unchanged."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_size(self, id: UUID) -> int:
|
||||
"""Returns the number of records in a collection."""
|
||||
pass
|
||||
@@ -0,0 +1,194 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Optional, Type
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
|
||||
class ChromaError(Exception, EnforceOverrides):
|
||||
trace_id: Optional[str] = None
|
||||
|
||||
def code(self) -> int:
|
||||
"""Return an appropriate HTTP response code for this error"""
|
||||
return 400 # Bad Request
|
||||
|
||||
def message(self) -> str:
|
||||
return ", ".join(self.args)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def name(cls) -> str:
|
||||
"""Return the error name"""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidDimensionException(ChromaError):
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "InvalidDimension"
|
||||
|
||||
|
||||
class IDAlreadyExistsError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 409 # Conflict
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "IDAlreadyExists"
|
||||
|
||||
|
||||
class ChromaAuthError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 403
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "AuthError"
|
||||
|
||||
@overrides
|
||||
def message(self) -> str:
|
||||
return "Forbidden"
|
||||
|
||||
|
||||
class DuplicateIDError(ChromaError):
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "DuplicateID"
|
||||
|
||||
|
||||
class InvalidArgumentError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 400
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "InvalidArgument"
|
||||
|
||||
|
||||
class InvalidUUIDError(ChromaError):
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "InvalidUUID"
|
||||
|
||||
|
||||
class InvalidHTTPVersion(ChromaError):
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "InvalidHTTPVersion"
|
||||
|
||||
|
||||
class AuthorizationError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 401
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "AuthorizationError"
|
||||
|
||||
|
||||
class NotFoundError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 404
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "NotFoundError"
|
||||
|
||||
|
||||
class UniqueConstraintError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 409
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "UniqueConstraintError"
|
||||
|
||||
|
||||
class BatchSizeExceededError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 413
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "BatchSizeExceededError"
|
||||
|
||||
|
||||
class VersionMismatchError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 500
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "VersionMismatchError"
|
||||
|
||||
|
||||
class InternalError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 500
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "InternalError"
|
||||
|
||||
|
||||
class RateLimitError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 429
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "RateLimitError"
|
||||
|
||||
|
||||
class QuotaError(ChromaError):
|
||||
@overrides
|
||||
def code(self) -> int:
|
||||
return 400
|
||||
|
||||
@classmethod
|
||||
@overrides
|
||||
def name(cls) -> str:
|
||||
return "QuotaError"
|
||||
|
||||
|
||||
error_types: Dict[str, Type[ChromaError]] = {
|
||||
"InvalidDimension": InvalidDimensionException,
|
||||
"InvalidArgumentError": InvalidArgumentError,
|
||||
"IDAlreadyExists": IDAlreadyExistsError,
|
||||
"DuplicateID": DuplicateIDError,
|
||||
"InvalidUUID": InvalidUUIDError,
|
||||
"InvalidHTTPVersion": InvalidHTTPVersion,
|
||||
"AuthorizationError": AuthorizationError,
|
||||
"NotFoundError": NotFoundError,
|
||||
"BatchSizeExceededError": BatchSizeExceededError,
|
||||
"VersionMismatchError": VersionMismatchError,
|
||||
"RateLimitError": RateLimitError,
|
||||
"AuthError": ChromaAuthError,
|
||||
"UniqueConstraintError": UniqueConstraintError,
|
||||
"QuotaError": QuotaError,
|
||||
"InternalError": InternalError,
|
||||
# Catch-all for any other errors
|
||||
"ChromaError": ChromaError,
|
||||
}
|
||||
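The error_types table above drives error rehydration on the client side. A minimal sketch, assuming this file ships as chromadb.errors in the installed package; the helper name raise_chroma_error is hypothetical and only illustrates the lookup.

from chromadb.errors import ChromaError, error_types

def raise_chroma_error(name: str, message: str) -> None:
    # Look up the exception class by its wire name; fall back to the
    # catch-all ChromaError entry when the name is unknown.
    error_cls = error_types.get(name, ChromaError)
    raise error_cls(message)

# raise_chroma_error("NotFoundError", "collection does not exist")
# would raise NotFoundError, whose code() is 404.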
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from chromadb.api.types import GetResult, QueryResult
|
||||
from chromadb.config import Component
|
||||
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
|
||||
|
||||
|
||||
class Executor(Component):
|
||||
@abstractmethod
|
||||
def count(self, plan: CountPlan) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, plan: GetPlan) -> GetResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def knn(self, plan: KNNPlan) -> QueryResult:
|
||||
pass
|
||||
@@ -0,0 +1,242 @@
|
||||
import threading
|
||||
import random
|
||||
from typing import Callable, Dict, List, Optional, TypeVar
|
||||
import grpc
|
||||
from overrides import overrides
|
||||
from chromadb.api.types import GetResult, Metadata, QueryResult
|
||||
from chromadb.config import System
|
||||
from chromadb.execution.executor.abstract import Executor
|
||||
from chromadb.execution.expression.operator import Scan
|
||||
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
|
||||
from chromadb.proto import convert
|
||||
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
|
||||
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
|
||||
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
Retrying,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
retry_if_exception,
|
||||
)
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
|
||||
def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
|
||||
"""Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
|
||||
if not metadata:
|
||||
return None
|
||||
result = {}
|
||||
for k, v in metadata.items():
|
||||
if not k.startswith("chroma:"):
|
||||
result[k] = v
|
||||
if len(result) == 0:
|
||||
return None
|
||||
return result
|
||||
|
||||
|
||||
def _uri(metadata: Optional[Metadata]) -> Optional[str]:
|
||||
"""Retrieve the uri (if any) from a Metadata map"""
|
||||
|
||||
if metadata and "chroma:uri" in metadata:
|
||||
return str(metadata["chroma:uri"])
|
||||
return None
|
||||
|
||||
|
||||
# Type variables for input and output types of the round-robin retry function
|
||||
I = TypeVar("I") # noqa: E741
|
||||
O = TypeVar("O") # noqa: E741
|
||||
|
||||
|
||||
class DistributedExecutor(Executor):
|
||||
_mtx: threading.Lock
|
||||
_grpc_stub_pool: Dict[str, QueryExecutorStub]
|
||||
_manager: DistributedSegmentManager
|
||||
_request_timeout_seconds: int
|
||||
_query_replication_factor: int
|
||||
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
self._mtx = threading.Lock()
|
||||
self._grpc_stub_pool = {}
|
||||
self._manager = self.require(DistributedSegmentManager)
|
||||
self._request_timeout_seconds = system.settings.require(
|
||||
"chroma_query_request_timeout_seconds"
|
||||
)
|
||||
self._query_replication_factor = system.settings.require(
|
||||
"chroma_query_replication_factor"
|
||||
)
|
||||
|
||||
def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
|
||||
"""
|
||||
Retry a list of functions in a round-robin fashion until one of them succeeds.
|
||||
|
||||
funcs: List of functions to retry
|
||||
args: Arguments to pass to each function
|
||||
|
||||
"""
|
||||
attempt_count = 0
|
||||
sleep_span: Optional[Span] = None
|
||||
|
||||
def before_sleep(_: RetryCallState) -> None:
|
||||
# HACK(hammadb) 1/14/2024 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
|
||||
# This should really use our component system to get the tracer. Since our grpc utils use this pattern
|
||||
# we are copying it here. This should be removed once we have a better way to get the tracer
|
||||
from chromadb.telemetry.opentelemetry import tracer
|
||||
|
||||
nonlocal sleep_span
|
||||
if tracer is not None:
|
||||
sleep_span = tracer.start_span("Waiting to retry RPC")
|
||||
|
||||
for attempt in Retrying(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential_jitter(0.1, jitter=0.1),
|
||||
reraise=True,
|
||||
retry=retry_if_exception(
|
||||
lambda x: isinstance(x, grpc.RpcError)
|
||||
and x.code() in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
|
||||
),
|
||||
before_sleep=before_sleep,
|
||||
):
|
||||
if sleep_span is not None:
|
||||
sleep_span.end()
|
||||
sleep_span = None
|
||||
|
||||
with attempt:
|
||||
return funcs[attempt_count % len(funcs)](args)
|
||||
attempt_count += 1
|
||||
|
||||
# NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
|
||||
raise Exception("Unreachable code error - should never reach here")
|
||||
|
||||
@overrides
|
||||
def count(self, plan: CountPlan) -> int:
|
||||
endpoints = self._get_grpc_endpoints(plan.scan)
|
||||
count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
|
||||
count_result = self._round_robin_retry(
|
||||
count_funcs, convert.to_proto_count_plan(plan)
|
||||
)
|
||||
return convert.from_proto_count_result(count_result)
|
||||
|
||||
@overrides
|
||||
def get(self, plan: GetPlan) -> GetResult:
|
||||
endpoints = self._get_grpc_endpoints(plan.scan)
|
||||
get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
|
||||
get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
|
||||
records = convert.from_proto_get_result(get_result)
|
||||
|
||||
ids = [record["id"] for record in records]
|
||||
embeddings = (
|
||||
[record["embedding"] for record in records]
|
||||
if plan.projection.embedding
|
||||
else None
|
||||
)
|
||||
documents = (
|
||||
[record["document"] for record in records]
|
||||
if plan.projection.document
|
||||
else None
|
||||
)
|
||||
uris = (
|
||||
[_uri(record["metadata"]) for record in records]
|
||||
if plan.projection.uri
|
||||
else None
|
||||
)
|
||||
metadatas = (
|
||||
[_clean_metadata(record["metadata"]) for record in records]
|
||||
if plan.projection.metadata
|
||||
else None
|
||||
)
|
||||
|
||||
# TODO: Fix typing
|
||||
return GetResult(
|
||||
ids=ids,
|
||||
embeddings=embeddings, # type: ignore[typeddict-item]
|
||||
documents=documents, # type: ignore[typeddict-item]
|
||||
uris=uris, # type: ignore[typeddict-item]
|
||||
data=None,
|
||||
metadatas=metadatas, # type: ignore[typeddict-item]
|
||||
included=plan.projection.included,
|
||||
)
|
||||
|
||||
@overrides
|
||||
def knn(self, plan: KNNPlan) -> QueryResult:
|
||||
endpoints = self._get_grpc_endpoints(plan.scan)
|
||||
knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
|
||||
knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
|
||||
results = convert.from_proto_knn_batch_result(knn_result)
|
||||
|
||||
ids = [[record["record"]["id"] for record in records] for records in results]
|
||||
embeddings = (
|
||||
[
|
||||
[record["record"]["embedding"] for record in records]
|
||||
for records in results
|
||||
]
|
||||
if plan.projection.embedding
|
||||
else None
|
||||
)
|
||||
documents = (
|
||||
[
|
||||
[record["record"]["document"] for record in records]
|
||||
for records in results
|
||||
]
|
||||
if plan.projection.document
|
||||
else None
|
||||
)
|
||||
uris = (
|
||||
[
|
||||
[_uri(record["record"]["metadata"]) for record in records]
|
||||
for records in results
|
||||
]
|
||||
if plan.projection.uri
|
||||
else None
|
||||
)
|
||||
metadatas = (
|
||||
[
|
||||
[_clean_metadata(record["record"]["metadata"]) for record in records]
|
||||
for records in results
|
||||
]
|
||||
if plan.projection.metadata
|
||||
else None
|
||||
)
|
||||
distances = (
|
||||
[[record["distance"] for record in records] for records in results]
|
||||
if plan.projection.rank
|
||||
else None
|
||||
)
|
||||
|
||||
# TODO: Fix typing
|
||||
return QueryResult(
|
||||
ids=ids,
|
||||
embeddings=embeddings, # type: ignore[typeddict-item]
|
||||
documents=documents, # type: ignore[typeddict-item]
|
||||
uris=uris, # type: ignore[typeddict-item]
|
||||
data=None,
|
||||
metadatas=metadatas, # type: ignore[typeddict-item]
|
||||
distances=distances, # type: ignore[typeddict-item]
|
||||
included=plan.projection.included,
|
||||
)
|
||||
|
||||
def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
|
||||
# Since the grpc endpoint is determined by the collection uuid,
|
||||
# the endpoint should be the same for all segments of the same collection
|
||||
grpc_urls = self._manager.get_endpoints(
|
||||
scan.record, self._query_replication_factor
|
||||
)
|
||||
# Shuffle the grpc urls to distribute the load evenly
|
||||
random.shuffle(grpc_urls)
|
||||
return grpc_urls
|
||||
|
||||
def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
|
||||
with self._mtx:
|
||||
if grpc_url not in self._grpc_stub_pool:
|
||||
channel = grpc.insecure_channel(
|
||||
grpc_url,
|
||||
options=[
|
||||
("grpc.max_concurrent_streams", 1000),
|
||||
("grpc.max_receive_message_length", 32000000), # 32 MB
|
||||
],
|
||||
)
|
||||
interceptors = [OtelInterceptor()]
|
||||
channel = grpc.intercept_channel(channel, *interceptors)
|
||||
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)
|
||||
return self._grpc_stub_pool[grpc_url]
|
||||
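A standalone sketch of the round-robin retry pattern used by DistributedExecutor above, with the tracing span and grpc status-code filtering stripped out; the function name is hypothetical and only illustrates the control flow.

from typing import Callable, List, TypeVar
from tenacity import Retrying, stop_after_attempt, wait_exponential_jitter

I = TypeVar("I")
O = TypeVar("O")

def round_robin_retry(funcs: List[Callable[[I], O]], args: I) -> O:
    attempt_count = 0
    for attempt in Retrying(
        stop=stop_after_attempt(5),
        wait=wait_exponential_jitter(0.1, jitter=0.1),
        reraise=True,
    ):
        with attempt:
            # Attempt 0 calls funcs[0], attempt 1 calls funcs[1], wrapping around,
            # so each retry is sent to a different replica endpoint.
            return funcs[attempt_count % len(funcs)](args)
        attempt_count += 1
    raise RuntimeError("unreachable: Retrying() either returns or raises")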
@@ -0,0 +1,205 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from chromadb.api.types import GetResult, Metadata, QueryResult
|
||||
from chromadb.config import System
|
||||
from chromadb.execution.executor.abstract import Executor
|
||||
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
|
||||
from chromadb.segment import MetadataReader, VectorReader
|
||||
from chromadb.segment.impl.manager.local import LocalSegmentManager
|
||||
from chromadb.types import Collection, VectorQuery, VectorQueryResult
|
||||
|
||||
|
||||
def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
|
||||
"""Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
|
||||
if not metadata:
|
||||
return None
|
||||
result = {}
|
||||
for k, v in metadata.items():
|
||||
if not k.startswith("chroma:"):
|
||||
result[k] = v
|
||||
if len(result) == 0:
|
||||
return None
|
||||
return result
|
||||
|
||||
|
||||
def _doc(metadata: Optional[Metadata]) -> Optional[str]:
|
||||
"""Retrieve the document (if any) from a Metadata map"""
|
||||
|
||||
if metadata and "chroma:document" in metadata:
|
||||
return str(metadata["chroma:document"])
|
||||
return None
|
||||
|
||||
|
||||
def _uri(metadata: Optional[Metadata]) -> Optional[str]:
|
||||
"""Retrieve the uri (if any) from a Metadata map"""
|
||||
|
||||
if metadata and "chroma:uri" in metadata:
|
||||
return str(metadata["chroma:uri"])
|
||||
return None
|
||||
|
||||
|
||||
class LocalExecutor(Executor):
|
||||
_manager: LocalSegmentManager
|
||||
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
self._manager = self.require(LocalSegmentManager)
|
||||
|
||||
@overrides
|
||||
def count(self, plan: CountPlan) -> int:
|
||||
return self._metadata_segment(plan.scan.collection).count(plan.scan.version)
|
||||
|
||||
@overrides
|
||||
def get(self, plan: GetPlan) -> GetResult:
|
||||
records = self._metadata_segment(plan.scan.collection).get_metadata(
|
||||
request_version_context=plan.scan.version,
|
||||
where=plan.filter.where,
|
||||
where_document=plan.filter.where_document,
|
||||
ids=plan.filter.user_ids,
|
||||
limit=plan.limit.limit,
|
||||
offset=plan.limit.offset,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
ids = [r["id"] for r in records]
|
||||
embeddings = None
|
||||
documents = None
|
||||
uris = None
|
||||
metadatas = None
|
||||
included = list()
|
||||
|
||||
if plan.projection.embedding:
|
||||
if len(records) > 0:
|
||||
vectors = self._vector_segment(plan.scan.collection).get_vectors(
|
||||
ids=ids, request_version_context=plan.scan.version
|
||||
)
|
||||
embeddings = [v["embedding"] for v in vectors]
|
||||
else:
|
||||
embeddings = list()
|
||||
included.append("embeddings")
|
||||
|
||||
if plan.projection.document:
|
||||
documents = [_doc(r["metadata"]) for r in records]
|
||||
included.append("documents")
|
||||
|
||||
if plan.projection.uri:
|
||||
uris = [_uri(r["metadata"]) for r in records]
|
||||
included.append("uris")
|
||||
|
||||
if plan.projection.metadata:
|
||||
metadatas = [_clean_metadata(r["metadata"]) for r in records]
|
||||
included.append("metadatas")
|
||||
|
||||
# TODO: Fix typing
|
||||
return GetResult(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=documents, # type: ignore[typeddict-item]
|
||||
uris=uris, # type: ignore[typeddict-item]
|
||||
data=None,
|
||||
metadatas=metadatas, # type: ignore[typeddict-item]
|
||||
included=included,
|
||||
)
|
||||
|
||||
@overrides
|
||||
def knn(self, plan: KNNPlan) -> QueryResult:
|
||||
prefiltered_ids = None
|
||||
if plan.filter.user_ids or plan.filter.where or plan.filter.where_document:
|
||||
records = self._metadata_segment(plan.scan.collection).get_metadata(
|
||||
request_version_context=plan.scan.version,
|
||||
where=plan.filter.where,
|
||||
where_document=plan.filter.where_document,
|
||||
ids=plan.filter.user_ids,
|
||||
limit=None,
|
||||
offset=0,
|
||||
include_metadata=False,
|
||||
)
|
||||
prefiltered_ids = [r["id"] for r in records]
|
||||
|
||||
knns: Sequence[Sequence[VectorQueryResult]] = [[]] * len(plan.knn.embeddings)
|
||||
|
||||
# Query vectors only when the user did not specify a filter or when the filter
|
||||
# yields non-empty ids. Otherwise, the user specified a filter but it yields
|
||||
# no matching ids, in which case we can return an empty result.
|
||||
if prefiltered_ids is None or len(prefiltered_ids) > 0:
|
||||
query = VectorQuery(
|
||||
vectors=plan.knn.embeddings,
|
||||
k=plan.knn.fetch,
|
||||
allowed_ids=prefiltered_ids,
|
||||
include_embeddings=plan.projection.embedding,
|
||||
options=None,
|
||||
request_version_context=plan.scan.version,
|
||||
)
|
||||
knns = self._vector_segment(plan.scan.collection).query_vectors(query)
|
||||
|
||||
ids = [[r["id"] for r in result] for result in knns]
|
||||
embeddings = None
|
||||
documents = None
|
||||
uris = None
|
||||
metadatas = None
|
||||
distances = None
|
||||
included = list()
|
||||
|
||||
if plan.projection.embedding:
|
||||
embeddings = [[r["embedding"] for r in result] for result in knns]
|
||||
included.append("embeddings")
|
||||
|
||||
if plan.projection.rank:
|
||||
distances = [[r["distance"] for r in result] for result in knns]
|
||||
included.append("distances")
|
||||
|
||||
if plan.projection.document or plan.projection.metadata or plan.projection.uri:
|
||||
merged_ids = list(set([id for result in ids for id in result]))
|
||||
hydrated_records = self._metadata_segment(
|
||||
plan.scan.collection
|
||||
).get_metadata(
|
||||
request_version_context=plan.scan.version,
|
||||
where=None,
|
||||
where_document=None,
|
||||
ids=merged_ids,
|
||||
limit=None,
|
||||
offset=0,
|
||||
include_metadata=True,
|
||||
)
|
||||
metadata_by_id = {r["id"]: r["metadata"] for r in hydrated_records}
|
||||
|
||||
if plan.projection.document:
|
||||
documents = [
|
||||
[_doc(metadata_by_id.get(id, None)) for id in result]
|
||||
for result in ids
|
||||
]
|
||||
included.append("documents")
|
||||
|
||||
if plan.projection.uri:
|
||||
uris = [
|
||||
[_uri(metadata_by_id.get(id, None)) for id in result]
|
||||
for result in ids
|
||||
]
|
||||
included.append("uris")
|
||||
|
||||
if plan.projection.metadata:
|
||||
metadatas = [
|
||||
[_clean_metadata(metadata_by_id.get(id, None)) for id in result]
|
||||
for result in ids
|
||||
]
|
||||
included.append("metadatas")
|
||||
|
||||
# TODO: Fix typing
|
||||
return QueryResult(
|
||||
ids=ids,
|
||||
embeddings=embeddings, # type: ignore[typeddict-item]
|
||||
documents=documents, # type: ignore[typeddict-item]
|
||||
uris=uris, # type: ignore[typeddict-item]
|
||||
data=None,
|
||||
metadatas=metadatas, # type: ignore[typeddict-item]
|
||||
distances=distances,
|
||||
included=included,
|
||||
)
|
||||
|
||||
def _metadata_segment(self, collection: Collection) -> MetadataReader:
|
||||
return self._manager.get_segment(collection.id, MetadataReader)
|
||||
|
||||
def _vector_segment(self, collection: Collection) -> VectorReader:
|
||||
return self._manager.get_segment(collection.id, VectorReader)
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Chromadb execution expression module for search operations.
|
||||
"""
|
||||
|
||||
from chromadb.execution.expression.operator import (
|
||||
# Field proxy for building Where conditions
|
||||
Key,
|
||||
K,
|
||||
# Where expressions
|
||||
Where,
|
||||
And,
|
||||
Or,
|
||||
Eq,
|
||||
Ne,
|
||||
Gt,
|
||||
Gte,
|
||||
Lt,
|
||||
Lte,
|
||||
In,
|
||||
Nin,
|
||||
Regex,
|
||||
NotRegex,
|
||||
Contains,
|
||||
NotContains,
|
||||
# Search configuration
|
||||
Limit,
|
||||
Select,
|
||||
# Rank expressions
|
||||
Rank,
|
||||
Abs,
|
||||
Div,
|
||||
Exp,
|
||||
Log,
|
||||
Max,
|
||||
Min,
|
||||
Mul,
|
||||
Knn,
|
||||
Rrf,
|
||||
Sub,
|
||||
Sum,
|
||||
Val,
|
||||
)
|
||||
|
||||
from chromadb.execution.expression.plan import (
|
||||
Search,
|
||||
)
|
||||
|
||||
SearchWhere = Where
|
||||
|
||||
__all__ = [
|
||||
# Main search class
|
||||
"Search",
|
||||
# Field proxy
|
||||
"Key",
|
||||
"K",
|
||||
# Where expressions
|
||||
"SearchWhere",
|
||||
"Where",
|
||||
"And",
|
||||
"Or",
|
||||
"Eq",
|
||||
"Ne",
|
||||
"Gt",
|
||||
"Gte",
|
||||
"Lt",
|
||||
"Lte",
|
||||
"In",
|
||||
"Nin",
|
||||
"Regex",
|
||||
"NotRegex",
|
||||
"Contains",
|
||||
"NotContains",
|
||||
# Search configuration
|
||||
"Limit",
|
||||
"Select",
|
||||
# Rank expressions
|
||||
"Rank",
|
||||
"Abs",
|
||||
"Div",
|
||||
"Exp",
|
||||
"Log",
|
||||
"Max",
|
||||
"Min",
|
||||
"Mul",
|
||||
"Knn",
|
||||
"Rrf",
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Val",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
@@ -0,0 +1,263 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Union, Set, Optional
|
||||
|
||||
from chromadb.execution.expression.operator import (
|
||||
KNN,
|
||||
Filter,
|
||||
Limit,
|
||||
Projection,
|
||||
Scan,
|
||||
Rank,
|
||||
Select,
|
||||
Val,
|
||||
Where,
|
||||
Key,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountPlan:
|
||||
scan: Scan
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetPlan:
|
||||
scan: Scan
|
||||
filter: Filter = field(default_factory=Filter)
|
||||
limit: Limit = field(default_factory=Limit)
|
||||
projection: Projection = field(default_factory=Projection)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KNNPlan:
|
||||
scan: Scan
|
||||
knn: KNN
|
||||
filter: Filter = field(default_factory=Filter)
|
||||
projection: Projection = field(default_factory=Projection)
|
||||
|
||||
|
||||
class Search:
|
||||
"""Payload for hybrid search operations.
|
||||
|
||||
Can be constructed directly or using builder pattern:
|
||||
|
||||
Direct construction with expressions:
|
||||
Search(
|
||||
where=Key("status") == "active",
|
||||
rank=Knn(query=[0.1, 0.2]),
|
||||
limit=Limit(limit=10),
|
||||
select=Select(keys={Key.DOCUMENT})
|
||||
)
|
||||
|
||||
Direct construction with dicts:
|
||||
Search(
|
||||
where={"status": "active"},
|
||||
rank={"$knn": {"query": [0.1, 0.2]}},
|
||||
limit=10, # Creates Limit(limit=10, offset=0)
|
||||
select=["#document", "#score"]
|
||||
)
|
||||
|
||||
Builder pattern:
|
||||
(Search()
|
||||
.where(Key("status") == "active")
|
||||
.rank(Knn(query=[0.1, 0.2]))
|
||||
.limit(10)
|
||||
.select(Key.DOCUMENT))
|
||||
|
||||
Builder pattern with dicts:
|
||||
(Search()
|
||||
.where({"status": "active"})
|
||||
.rank({"$knn": {"query": [0.1, 0.2]}})
|
||||
.limit(10)
|
||||
.select(Key.DOCUMENT))
|
||||
|
||||
Filter by IDs:
|
||||
Search().where(Key.ID.is_in(["id1", "id2", "id3"]))
|
||||
|
||||
Combined with metadata filtering:
|
||||
Search().where((Key.ID.is_in(["id1", "id2"])) & (Key("status") == "active"))
|
||||
|
||||
Empty Search() is valid and will use defaults:
|
||||
- where: None (no filtering)
|
||||
- rank: None (no ranking - results ordered by default order)
|
||||
- limit: No limit
|
||||
- select: Empty selection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Union[Where, Dict[str, Any]]] = None,
|
||||
rank: Optional[Union[Rank, Dict[str, Any]]] = None,
|
||||
limit: Optional[Union[Limit, Dict[str, Any], int]] = None,
|
||||
select: Optional[Union[Select, Dict[str, Any], List[str], Set[str]]] = None,
|
||||
):
|
||||
"""Initialize a Search with optional parameters.
|
||||
|
||||
Args:
|
||||
where: Where expression or dict for filtering results (defaults to None - no filtering)
|
||||
Dict will be converted using Where.from_dict()
|
||||
rank: Rank expression or dict for scoring (defaults to None - no ranking)
|
||||
Dict will be converted using Rank.from_dict()
|
||||
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
|
||||
limit: Limit configuration for pagination (defaults to no limit)
|
||||
Can be a Limit object, a dict for Limit.from_dict(), or an int
|
||||
When passing an int, it creates Limit(limit=value, offset=0)
|
||||
select: Select configuration for keys (defaults to empty selection)
|
||||
Can be a Select object, a dict for Select.from_dict(),
|
||||
or a list/set of strings (e.g., ["#document", "#score"])
|
||||
"""
|
||||
# Handle where parameter
|
||||
if where is None:
|
||||
self._where = None
|
||||
elif isinstance(where, Where):
|
||||
self._where = where
|
||||
elif isinstance(where, dict):
|
||||
self._where = Where.from_dict(where)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"where must be a Where object, dict, or None, got {type(where).__name__}"
|
||||
)
|
||||
|
||||
# Handle rank parameter
|
||||
if rank is None:
|
||||
self._rank = None
|
||||
elif isinstance(rank, Rank):
|
||||
self._rank = rank
|
||||
elif isinstance(rank, dict):
|
||||
self._rank = Rank.from_dict(rank)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"rank must be a Rank object, dict, or None, got {type(rank).__name__}"
|
||||
)
|
||||
|
||||
# Handle limit parameter
|
||||
if limit is None:
|
||||
self._limit = Limit()
|
||||
elif isinstance(limit, Limit):
|
||||
self._limit = limit
|
||||
elif isinstance(limit, int):
|
||||
self._limit = Limit.from_dict({"limit": limit, "offset": 0})
|
||||
elif isinstance(limit, dict):
|
||||
self._limit = Limit.from_dict(limit)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"limit must be a Limit object, dict, int, or None, got {type(limit).__name__}"
|
||||
)
|
||||
|
||||
# Handle select parameter
|
||||
if select is None:
|
||||
self._select = Select()
|
||||
elif isinstance(select, Select):
|
||||
self._select = select
|
||||
elif isinstance(select, dict):
|
||||
self._select = Select.from_dict(select)
|
||||
elif isinstance(select, (list, set)):
|
||||
# Convert list/set of strings to Select object
|
||||
self._select = Select.from_dict({"keys": list(select)})
|
||||
else:
|
||||
raise TypeError(
|
||||
f"select must be a Select object, dict, list, set, or None, got {type(select).__name__}"
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the Search to a dictionary for JSON serialization"""
|
||||
return {
|
||||
"filter": self._where.to_dict() if self._where is not None else None,
|
||||
"rank": self._rank.to_dict() if self._rank is not None else None,
|
||||
"limit": self._limit.to_dict(),
|
||||
"select": self._select.to_dict(),
|
||||
}
|
||||
|
||||
# Builder methods for chaining
|
||||
def select_all(self) -> "Search":
|
||||
"""Select all predefined keys (document, embedding, metadata, score)"""
|
||||
new_select = Select(keys={Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE})
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=self._limit, select=new_select
|
||||
)
|
||||
|
||||
def select(self, *keys: Union[Key, str]) -> "Search":
|
||||
"""Select specific keys
|
||||
|
||||
Args:
|
||||
*keys: Variable number of Key objects or string key names
|
||||
|
||||
Returns:
|
||||
New Search object with updated select configuration
|
||||
"""
|
||||
new_select = Select(keys=set(keys))
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=self._limit, select=new_select
|
||||
)
|
||||
|
||||
def where(self, where: Optional[Union[Where, Dict[str, Any]]]) -> "Search":
|
||||
"""Set the where clause for filtering
|
||||
|
||||
Args:
|
||||
where: A Where expression, dict, or None for filtering
|
||||
Dicts will be converted using Where.from_dict()
|
||||
|
||||
Example:
|
||||
search.where((Key("status") == "active") & (Key("score") > 0.5))
|
||||
search.where({"status": "active"})
|
||||
search.where({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
|
||||
"""
|
||||
# Convert dict to Where if needed
|
||||
if where is None:
|
||||
converted_where = None
|
||||
elif isinstance(where, Where):
|
||||
converted_where = where
|
||||
elif isinstance(where, dict):
|
||||
converted_where = Where.from_dict(where)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"where must be a Where object, dict, or None, got {type(where).__name__}"
|
||||
)
|
||||
|
||||
return Search(
|
||||
where=converted_where, rank=self._rank, limit=self._limit, select=self._select
|
||||
)
|
||||
|
||||
def rank(self, rank_expr: Optional[Union[Rank, Dict[str, Any]]]) -> "Search":
|
||||
"""Set the ranking expression
|
||||
|
||||
Args:
|
||||
rank_expr: A Rank expression, dict, or None for scoring
|
||||
Dicts will be converted using Rank.from_dict()
|
||||
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
|
||||
|
||||
Example:
|
||||
search.rank(Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2)
|
||||
search.rank({"$knn": {"query": [0.1, 0.2]}})
|
||||
search.rank({"$sum": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.5}]})
|
||||
"""
|
||||
# Convert dict to Rank if needed
|
||||
if rank_expr is None:
|
||||
converted_rank = None
|
||||
elif isinstance(rank_expr, Rank):
|
||||
converted_rank = rank_expr
|
||||
elif isinstance(rank_expr, dict):
|
||||
converted_rank = Rank.from_dict(rank_expr)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"rank_expr must be a Rank object, dict, or None, got {type(rank_expr).__name__}"
|
||||
)
|
||||
|
||||
return Search(
|
||||
where=self._where, rank=converted_rank, limit=self._limit, select=self._select
|
||||
)
|
||||
|
||||
def limit(self, limit: int, offset: int = 0) -> "Search":
|
||||
"""Set the limit and offset for pagination
|
||||
|
||||
Args:
|
||||
limit: Maximum number of results to return
|
||||
offset: Number of results to skip (default: 0)
|
||||
|
||||
Example:
|
||||
search.limit(20, offset=10)
|
||||
"""
|
||||
new_limit = Limit(offset=offset, limit=limit)
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=new_limit, select=self._select
|
||||
)
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,121 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Optional, Sequence
|
||||
from chromadb.types import (
|
||||
OperationRecord,
|
||||
LogRecord,
|
||||
SeqId,
|
||||
Vector,
|
||||
ScalarEncoding,
|
||||
)
|
||||
from chromadb.config import Component
|
||||
from uuid import UUID
|
||||
import numpy as np
|
||||
|
||||
|
||||
def encode_vector(vector: Vector, encoding: ScalarEncoding) -> bytes:
|
||||
"""Encode a vector into a byte array."""
|
||||
|
||||
if encoding == ScalarEncoding.FLOAT32:
|
||||
return np.array(vector, dtype=np.float32).tobytes()
|
||||
elif encoding == ScalarEncoding.INT32:
|
||||
return np.array(vector, dtype=np.int32).tobytes()
|
||||
else:
|
||||
raise ValueError(f"Unsupported encoding: {encoding.value}")
|
||||
|
||||
|
||||
def decode_vector(vector: bytes, encoding: ScalarEncoding) -> Vector:
|
||||
"""Decode a byte array into a vector"""
|
||||
|
||||
if encoding == ScalarEncoding.FLOAT32:
|
||||
return np.frombuffer(vector, dtype=np.float32)
|
||||
elif encoding == ScalarEncoding.INT32:
|
||||
return np.frombuffer(vector, dtype=np.int32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported encoding: {encoding.value}")
|
||||
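A quick roundtrip sketch for the two helpers above (run in the same module, so encode_vector and decode_vector are in scope); the sample values are illustrative assumptions.

import numpy as np
from chromadb.types import ScalarEncoding

vec = np.array([0.1, 0.2, 0.3], dtype=np.float32)
raw = encode_vector(vec, ScalarEncoding.FLOAT32)       # 12 bytes: three 4-byte floats
roundtrip = decode_vector(raw, ScalarEncoding.FLOAT32)
assert np.allclose(vec, roundtrip)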
|
||||
|
||||
class Producer(Component):
|
||||
"""Interface for writing embeddings to an ingest stream"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_log(self, collection_id: UUID) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def purge_log(self, collection_id: UUID) -> None:
|
||||
"""Truncates the log for the given collection, removing all seen records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_embedding(
|
||||
self, collection_id: UUID, embedding: OperationRecord
|
||||
) -> SeqId:
|
||||
"""Add an embedding record to the given collections log. Returns the SeqID of the record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_embeddings(
|
||||
self, collection_id: UUID, embeddings: Sequence[OperationRecord]
|
||||
) -> Sequence[SeqId]:
|
||||
"""Add a batch of embedding records to the given collections log. Returns the SeqIDs of
|
||||
the records. The returned SeqIDs will be in the same order as the given
|
||||
SubmitEmbeddingRecords. However, it is not guaranteed that the SeqIDs will be
|
||||
processed in the same order as the given SubmitEmbeddingRecords. If the number
|
||||
of records exceeds the maximum batch size, an exception will be thrown."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_batch_size(self) -> int:
|
||||
"""Return the maximum number of records that can be submitted in a single call
|
||||
to submit_embeddings."""
|
||||
pass
|
||||
|
||||
|
||||
ConsumerCallbackFn = Callable[[Sequence[LogRecord]], None]
|
||||
|
||||
|
||||
class Consumer(Component):
|
||||
"""Interface for reading embeddings off an ingest stream"""
|
||||
|
||||
@abstractmethod
|
||||
def subscribe(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
consume_fn: ConsumerCallbackFn,
|
||||
start: Optional[SeqId] = None,
|
||||
end: Optional[SeqId] = None,
|
||||
id: Optional[UUID] = None,
|
||||
) -> UUID:
|
||||
"""Register a function that will be called to receive embeddings for a given
|
||||
collection's log stream. The given function may be called any number of times, with any number of
|
||||
records, and may be called concurrently.
|
||||
|
||||
Only records between start (exclusive) and end (inclusive) SeqIDs will be
|
||||
returned. If start is None, the first record returned will be the next record
|
||||
generated, not including those generated before creating the subscription. If
|
||||
end is None, the consumer will consume indefinitely, otherwise it will
|
||||
automatically be unsubscribed when the end SeqID is reached.
|
||||
|
||||
If the function throws an exception, the function may be called again with the
|
||||
same or different records.
|
||||
|
||||
Takes an optional UUID as a unique subscription ID. If no ID is provided, a new
|
||||
ID will be generated and returned."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unsubscribe(self, subscription_id: UUID) -> None:
|
||||
"""Unregister a subscription. The consume function will no longer be invoked,
|
||||
and resources associated with the subscription will be released."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def min_seqid(self) -> SeqId:
|
||||
"""Return the minimum possible SeqID in this implementation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def max_seqid(self) -> SeqId:
|
||||
"""Return the maximum possible SeqID in this implementation."""
|
||||
pass
|
||||
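A hedged usage sketch for the Consumer interface above; `consumer` stands in for any concrete Consumer implementation and `collection_id` for an existing collection's UUID, both hypothetical here.

from typing import Sequence
from chromadb.types import LogRecord

def on_records(records: Sequence[LogRecord]) -> None:
    # Invoked by the consumer with batches of newly written log records.
    for record in records:
        print(record)

subscription_id = consumer.subscribe(
    collection_id=collection_id,
    consume_fn=on_records,
    start=None,  # None: deliver only records written after subscribing
    end=None,    # None: keep consuming until unsubscribe() is called
)
# Later, release the subscription:
consumer.unsubscribe(subscription_id)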
@@ -0,0 +1,49 @@
|
||||
import re
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from chromadb.db.base import SqlDB
|
||||
from chromadb.segment import SegmentManager, VectorReader
|
||||
|
||||
topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"
|
||||
|
||||
|
||||
def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:
|
||||
"""Parse the topic name into the tenant, namespace and topic name"""
|
||||
match = re.match(topic_regex, topic_name)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid topic name: {topic_name}")
|
||||
return match.group("tenant"), match.group("namespace"), match.group("topic")
|
||||
|
||||
|
||||
def create_topic_name(tenant: str, namespace: str, collection_id: UUID) -> str:
|
||||
return f"persistent://{tenant}/{namespace}/{str(collection_id)}"
|
||||
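A small roundtrip example of the topic-name format handled by the two functions above; the UUID is an arbitrary illustrative value.

from uuid import UUID

name = create_topic_name(
    "default_tenant", "default_namespace", UUID("3fa85f64-5717-4562-b3fc-2c963f66afa6")
)
tenant, namespace, topic = parse_topic_name(name)
assert (tenant, namespace, topic) == (
    "default_tenant", "default_namespace", "3fa85f64-5717-4562-b3fc-2c963f66afa6"
)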
|
||||
|
||||
def trigger_vector_segments_max_seq_id_migration(
|
||||
db: SqlDB, segment_manager: SegmentManager
|
||||
) -> None:
|
||||
"""
|
||||
Trigger the migration of vector segments' max_seq_id from the pickled metadata file to SQLite.
|
||||
|
||||
Vector segments migrate this field automatically on init—so this should be used when we know segments are likely unmigrated and unloaded.
|
||||
|
||||
This is a no-op if all vector segments have already migrated their max_seq_id.
|
||||
"""
|
||||
with db.tx() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT collection
|
||||
FROM "segments"
|
||||
WHERE "id" NOT IN (SELECT "segment_id" FROM "max_seq_id") AND
|
||||
"type" = 'urn:chroma:segment/vector/hnsw-local-persisted'
|
||||
"""
|
||||
)
|
||||
collection_ids_with_unmigrated_segments = [row[0] for row in cur.fetchall()]
|
||||
|
||||
if len(collection_ids_with_unmigrated_segments) == 0:
|
||||
return
|
||||
|
||||
for collection_id in collection_ids_with_unmigrated_segments:
|
||||
# Loading the segment triggers the migration on init
|
||||
segment_manager.get_segment(UUID(collection_id), VectorReader)
|
||||
@@ -0,0 +1,37 @@
|
||||
version: 1
|
||||
disable_existing_loggers: False
|
||||
formatters:
|
||||
default:
|
||||
"()": uvicorn.logging.DefaultFormatter
|
||||
format: '%(levelprefix)s [%(asctime)s] %(message)s'
|
||||
use_colors: null
|
||||
datefmt: '%d-%m-%Y %H:%M:%S'
|
||||
access:
|
||||
"()": uvicorn.logging.AccessFormatter
|
||||
format: '%(levelprefix)s [%(asctime)s] %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||
datefmt: '%d-%m-%Y %H:%M:%S'
|
||||
handlers:
|
||||
default:
|
||||
formatter: default
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stderr
|
||||
access:
|
||||
formatter: access
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stdout
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stdout
|
||||
formatter: default
|
||||
file:
|
||||
class : logging.handlers.RotatingFileHandler
|
||||
filename: chroma.log
|
||||
formatter: default
|
||||
loggers:
|
||||
root:
|
||||
level: WARN
|
||||
handlers: [console, file]
|
||||
chromadb:
|
||||
level: DEBUG
|
||||
uvicorn:
|
||||
level: INFO
|
||||
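A minimal sketch of how a YAML logging config like the one above is typically applied, assuming it is saved as log_config.yml and that PyYAML is installed.

import logging
import logging.config
import yaml

with open("log_config.yml") as f:
    logging.config.dictConfig(yaml.safe_load(f))

logging.getLogger("chromadb").debug("logging configured")  # routed per the config above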
@@ -0,0 +1,181 @@
|
||||
import sys
|
||||
|
||||
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
|
||||
import grpc
|
||||
import time
|
||||
from chromadb.ingest import (
|
||||
Producer,
|
||||
Consumer,
|
||||
ConsumerCallbackFn,
|
||||
)
|
||||
from chromadb.proto.convert import to_proto_submit
|
||||
from chromadb.proto.logservice_pb2 import PushLogsRequest, PullLogsRequest, LogRecord
|
||||
from chromadb.proto.logservice_pb2_grpc import LogServiceStub
|
||||
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
|
||||
from chromadb.types import (
|
||||
OperationRecord,
|
||||
SeqId,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
add_attributes_to_current_span,
|
||||
trace_method,
|
||||
)
|
||||
from overrides import override
|
||||
from typing import Sequence, Optional, cast
|
||||
from uuid import UUID
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogService(Producer, Consumer):
|
||||
"""
|
||||
Distributed Chroma Log Service
|
||||
"""
|
||||
|
||||
_log_service_stub: LogServiceStub
|
||||
_request_timeout_seconds: int
|
||||
_channel: grpc.Channel
|
||||
_log_service_url: str
|
||||
_log_service_port: int
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._log_service_url = system.settings.require("chroma_logservice_host")
|
||||
self._log_service_port = system.settings.require("chroma_logservice_port")
|
||||
self._request_timeout_seconds = system.settings.require(
|
||||
"chroma_logservice_request_timeout_seconds"
|
||||
)
|
||||
self._opentelemetry_client = system.require(OpenTelemetryClient)
|
||||
super().__init__(system)
|
||||
|
||||
@trace_method("LogService.start", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def start(self) -> None:
|
||||
self._channel = grpc.insecure_channel(
|
||||
f"{self._log_service_url}:{self._log_service_port}",
|
||||
)
|
||||
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
|
||||
self._channel = grpc.intercept_channel(self._channel, *interceptors)
|
||||
self._log_service_stub = LogServiceStub(self._channel) # type: ignore
|
||||
super().start()
|
||||
|
||||
@trace_method("LogService.stop", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def stop(self) -> None:
|
||||
self._channel.close()
|
||||
super().stop()
|
||||
|
||||
@trace_method("LogService.reset_state", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def reset_state(self) -> None:
|
||||
super().reset_state()
|
||||
|
||||
@trace_method("LogService.delete_log", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def delete_log(self, collection_id: UUID) -> None:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@trace_method("LogService.purge_log", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def purge_log(self, collection_id: UUID) -> None:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@trace_method("LogService.submit_embedding", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def submit_embedding(
|
||||
self, collection_id: UUID, embedding: OperationRecord
|
||||
) -> SeqId:
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
return self.submit_embeddings(collection_id, [embedding])[0]
|
||||
|
||||
@trace_method("LogService.submit_embeddings", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def submit_embeddings(
|
||||
self, collection_id: UUID, embeddings: Sequence[OperationRecord]
|
||||
) -> Sequence[SeqId]:
|
||||
logger.info(
|
||||
f"Submitting {len(embeddings)} embeddings to log for collection {collection_id}"
|
||||
)
|
||||
|
||||
add_attributes_to_current_span(
|
||||
{
|
||||
"records_count": len(embeddings),
|
||||
}
|
||||
)
|
||||
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
if len(embeddings) == 0:
|
||||
return []
|
||||
|
||||
# push records to the log service
|
||||
counts = []
|
||||
protos_to_submit = [to_proto_submit(record) for record in embeddings]
|
||||
counts.append(
|
||||
self.push_logs(
|
||||
collection_id,
|
||||
cast(Sequence[OperationRecord], protos_to_submit),
|
||||
)
|
||||
)
|
||||
|
||||
# This returns counts, which is completely incorrect
|
||||
# TODO: Fix this
|
||||
return counts
|
||||
|
||||
@trace_method("LogService.subscribe", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def subscribe(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
consume_fn: ConsumerCallbackFn,
|
||||
start: Optional[SeqId] = None,
|
||||
end: Optional[SeqId] = None,
|
||||
id: Optional[UUID] = None,
|
||||
) -> UUID:
|
||||
logger.info(f"Subscribing to log for {collection_id}, noop for logservice")
|
||||
return UUID(int=0)
|
||||
|
||||
@trace_method("LogService.unsubscribe", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def unsubscribe(self, subscription_id: UUID) -> None:
|
||||
logger.info(f"Unsubscribing from {subscription_id}, noop for logservice")
|
||||
|
||||
@override
|
||||
def min_seqid(self) -> SeqId:
|
||||
return 0
|
||||
|
||||
@override
|
||||
def max_seqid(self) -> SeqId:
|
||||
return sys.maxsize
|
||||
|
||||
@property
|
||||
@override
|
||||
def max_batch_size(self) -> int:
|
||||
return 100
|
||||
|
||||
def push_logs(self, collection_id: UUID, records: Sequence[OperationRecord]) -> int:
|
||||
request = PushLogsRequest(collection_id=str(collection_id), records=records)
|
||||
response = self._log_service_stub.PushLogs(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
return response.record_count # type: ignore
|
||||
|
||||
def pull_logs(
|
||||
self, collection_id: UUID, start_offset: int, batch_size: int
|
||||
) -> Sequence[LogRecord]:
|
||||
request = PullLogsRequest(
|
||||
collection_id=str(collection_id),
|
||||
start_from_offset=start_offset,
|
||||
batch_size=batch_size,
|
||||
end_timestamp=time.time_ns(),
|
||||
)
|
||||
response = self._log_service_stub.PullLogs(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
return response.records # type: ignore
|
||||
@@ -0,0 +1,10 @@
|
||||
CREATE TABLE embeddings_queue (
|
||||
seq_id INTEGER PRIMARY KEY,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
operation INTEGER NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
id TEXT NOT NULL,
|
||||
vector BLOB,
|
||||
encoding TEXT,
|
||||
metadata TEXT
|
||||
);
|
||||
@@ -0,0 +1,4 @@
|
||||
CREATE TABLE embeddings_queue_config (
|
||||
id INTEGER PRIMARY KEY,
|
||||
config_json_str TEXT
|
||||
);
|
||||
@@ -0,0 +1,24 @@
|
||||
CREATE TABLE embeddings (
|
||||
id INTEGER PRIMARY KEY,
|
||||
segment_id TEXT NOT NULL,
|
||||
embedding_id TEXT NOT NULL,
|
||||
seq_id BLOB NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE (segment_id, embedding_id)
|
||||
);
|
||||
|
||||
CREATE TABLE embedding_metadata (
|
||||
id INTEGER REFERENCES embeddings(id),
|
||||
key TEXT NOT NULL,
|
||||
string_value TEXT,
|
||||
int_value INTEGER,
|
||||
float_value REAL,
|
||||
PRIMARY KEY (id, key)
|
||||
);
|
||||
|
||||
CREATE TABLE max_seq_id (
|
||||
segment_id TEXT PRIMARY KEY,
|
||||
seq_id BLOB NOT NULL
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE embedding_fulltext USING fts5(id, string_value);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- SQLite does not support adding a CHECK constraint with ALTER TABLE; as a result, adding a check
|
||||
-- involves creating a new table and copying the data over. That is overkill for adding
|
||||
-- a boolean-typed column. The application writing to the table needs to ensure data
|
||||
-- integrity.
|
||||
ALTER TABLE embedding_metadata ADD COLUMN bool_value INTEGER
|
||||
@@ -0,0 +1,3 @@
|
||||
CREATE VIRTUAL TABLE embedding_fulltext_search USING fts5(string_value, tokenize='trigram');
|
||||
INSERT INTO embedding_fulltext_search (rowid, string_value) SELECT rowid, string_value FROM embedding_metadata;
|
||||
DROP TABLE embedding_fulltext;
|
||||
@@ -0,0 +1,3 @@
|
||||
CREATE INDEX IF NOT EXISTS embedding_metadata_int_value ON embedding_metadata (key, int_value) WHERE int_value IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS embedding_metadata_float_value ON embedding_metadata (key, float_value) WHERE float_value IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS embedding_metadata_string_value ON embedding_metadata (key, string_value) WHERE string_value IS NOT NULL;
|
||||
@@ -0,0 +1,27 @@
|
||||
ALTER TABLE max_seq_id ADD COLUMN int_seq_id INTEGER;
|
||||
|
||||
-- Convert an 8-byte-wide big-endian integer stored as a blob to a native 64-bit integer.
|
||||
-- Adapted from https://stackoverflow.com/a/70296198.
|
||||
UPDATE max_seq_id SET int_seq_id = (
|
||||
SELECT (
|
||||
(instr('123456789ABCDEF', substr(hex(seq_id), -1 , 1)) << 0)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -2 , 1)) << 4)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -3 , 1)) << 8)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -4 , 1)) << 12)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -5 , 1)) << 16)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -6 , 1)) << 20)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -7 , 1)) << 24)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -8 , 1)) << 28)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -9 , 1)) << 32)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -10, 1)) << 36)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -11, 1)) << 40)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -12, 1)) << 44)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -13, 1)) << 48)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -14, 1)) << 52)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -15, 1)) << 56)
|
||||
| (instr('123456789ABCDEF', substr(hex(seq_id), -16, 1)) << 60)
|
||||
)
|
||||
);
|
||||
|
||||
ALTER TABLE max_seq_id DROP COLUMN seq_id;
|
||||
ALTER TABLE max_seq_id RENAME COLUMN int_seq_id TO seq_id;
|
||||
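The nibble-shifting SQL above is equivalent to interpreting the stored blob as a big-endian 64-bit integer; a Python sketch of the same conversion, with an illustrative value:

import struct

def blob_to_seq_id(blob: bytes) -> int:
    # ">Q" = big-endian unsigned 64-bit, matching the hex-digit shifts in the migration.
    return struct.unpack(">Q", blob)[0]

assert blob_to_seq_id(bytes([0, 0, 0, 0, 0, 0, 1, 42])) == 298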
@@ -0,0 +1,15 @@
|
||||
CREATE TABLE collections (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
UNIQUE (name)
|
||||
);
|
||||
|
||||
CREATE TABLE collection_metadata (
|
||||
collection_id TEXT REFERENCES collections(id) ON DELETE CASCADE,
|
||||
key TEXT NOT NULL,
|
||||
str_value TEXT,
|
||||
int_value INTEGER,
|
||||
float_value REAL,
|
||||
PRIMARY KEY (collection_id, key)
|
||||
);
|
||||
@@ -0,0 +1,16 @@
|
||||
CREATE TABLE segments (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
topic TEXT,
|
||||
collection TEXT REFERENCES collection(id)
|
||||
);
|
||||
|
||||
CREATE TABLE segment_metadata (
|
||||
segment_id TEXT REFERENCES segments(id) ON DELETE CASCADE,
|
||||
key TEXT NOT NULL,
|
||||
str_value TEXT,
|
||||
int_value INTEGER,
|
||||
float_value REAL,
|
||||
PRIMARY KEY (segment_id, key)
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE collections ADD COLUMN dimension INTEGER;
|
||||
@@ -0,0 +1,29 @@
CREATE TABLE IF NOT EXISTS tenants (
    id TEXT PRIMARY KEY,
    UNIQUE (id)
);

CREATE TABLE IF NOT EXISTS databases (
    id TEXT PRIMARY KEY, -- unique globally
    name TEXT NOT NULL, -- unique per tenant
    tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
    UNIQUE (tenant_id, name) -- Ensure that a tenant has only one database with a given name
);

CREATE TABLE IF NOT EXISTS collections_tmp (
    id TEXT PRIMARY KEY, -- unique globally
    name TEXT NOT NULL, -- unique per database
    topic TEXT NOT NULL,
    dimension INTEGER,
    database_id TEXT NOT NULL REFERENCES databases(id) ON DELETE CASCADE,
    UNIQUE (name, database_id)
);

-- Create default tenant and database
INSERT OR REPLACE INTO tenants (id) VALUES ('default_tenant'); -- The default tenant id is 'default_tenant' others are UUIDs
INSERT OR REPLACE INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default_database', 'default_tenant');

INSERT OR REPLACE INTO collections_tmp (id, name, topic, dimension, database_id)
    SELECT id, name, topic, dimension, '00000000-0000-0000-0000-000000000000' FROM collections;
DROP TABLE collections;
ALTER TABLE collections_tmp RENAME TO collections;
@@ -0,0 +1,4 @@
-- Remove the topic column from the Collections and Segments tables

ALTER TABLE collections DROP COLUMN topic;
ALTER TABLE segments DROP COLUMN topic;
@@ -0,0 +1,6 @@
-- SQLite does not support adding a CHECK constraint with ALTER TABLE; as a result, adding
-- a check would involve creating a new table and copying the data over, which is overkill
-- for adding a boolean-typed column. The application writing to the table needs to ensure
-- data integrity.
ALTER TABLE collection_metadata ADD COLUMN bool_value INTEGER;
ALTER TABLE segment_metadata ADD COLUMN bool_value INTEGER;
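Because bool_value is a plain INTEGER with no CHECK constraint, writers are expected to store only 0 or 1. A hypothetical Python sketch of that application-side discipline (the table and column names come from the migration above; the helper itself is illustrative only):

import sqlite3

def set_bool_metadata(conn: sqlite3.Connection, collection_id: str, key: str, value: bool) -> None:
    # Coerce the Python bool to 0/1 before writing, since SQLite will not
    # reject other integers in bool_value on its own.
    conn.execute(
        "INSERT OR REPLACE INTO collection_metadata (collection_id, key, bool_value) "
        "VALUES (?, ?, ?)",
        (collection_id, key, 1 if value else 0),
    )
    conn.commit()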
@@ -0,0 +1,2 @@
-- Stores collection configuration dictionaries.
ALTER TABLE collections ADD COLUMN config_json_str TEXT;
@@ -0,0 +1,7 @@
-- Records when database maintenance operations are performed.
-- At time of creation, this table is only used to record vacuum operations.
CREATE TABLE maintenance_log (
    id INT PRIMARY KEY,
    timestamp INT NOT NULL,
    operation TEXT NOT NULL
);
@@ -0,0 +1,11 @@
-- This makes segments.collection non-nullable.
CREATE TABLE segments_temp (
    id TEXT PRIMARY KEY,
    type TEXT NOT NULL,
    scope TEXT NOT NULL,
    collection TEXT REFERENCES collections(id) NOT NULL
);

INSERT INTO segments_temp SELECT * FROM segments;
DROP TABLE segments;
ALTER TABLE segments_temp RENAME TO segments;
2 backend_service/venv/lib/python3.13/site-packages/chromadb/proto/.gitignore vendored Normal file
@@ -0,0 +1,2 @@
*_pb2.py*
*_pb2_grpc.py
@@ -0,0 +1,688 @@
from typing import Dict, Optional, Sequence, Tuple, TypedDict, Union, cast
from uuid import UUID
import json

import numpy as np
from numpy.typing import NDArray

import chromadb.proto.chroma_pb2 as chroma_pb
import chromadb.proto.query_executor_pb2 as query_pb
from chromadb.api.collection_configuration import (
    collection_configuration_to_json_str,
)
from chromadb.api.types import Embedding, Where, WhereDocument
from chromadb.execution.expression.operator import (
    KNN,
    Filter,
    Limit,
    Projection,
    Scan,
)
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.types import (
    Collection,
    LogRecord,
    Metadata,
    Operation,
    OperationRecord,
    RequestVersionContext,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
    UpdateMetadata,
    Vector,
    VectorEmbeddingRecord,
    VectorQueryResult,
)


class ProjectionRecord(TypedDict):
    id: str
    document: Optional[str]
    embedding: Optional[Vector]
    metadata: Optional[Metadata]


class KNNProjectionRecord(TypedDict):
    record: ProjectionRecord
    distance: Optional[float]


# TODO: Unit tests for this file, handling optional states etc
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vector:
    if encoding == ScalarEncoding.FLOAT32:
        as_bytes = np.array(vector, dtype=np.float32).tobytes()
        proto_encoding = chroma_pb.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_bytes = np.array(vector, dtype=np.int32).tobytes()
        proto_encoding = chroma_pb.ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \
            or {ScalarEncoding.INT32}"
        )

    return chroma_pb.Vector(
        dimension=vector.size, vector=as_bytes, encoding=proto_encoding
    )


def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncoding]:
    encoding = vector.encoding
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == chroma_pb.ScalarEncoding.FLOAT32:
        as_array = np.frombuffer(vector.vector, dtype=np.float32)
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == chroma_pb.ScalarEncoding.INT32:
        as_array = np.frombuffer(vector.vector, dtype=np.int32)
        out_encoding = ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of \
            {chroma_pb.ScalarEncoding.FLOAT32} or {chroma_pb.ScalarEncoding.INT32}"
        )

    return (as_array, out_encoding)
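A small round-trip sketch of what the two vector helpers do at the byte level, using only numpy (the chroma_pb.Vector message itself is left out, so this is an illustration rather than a test of the real protos):

import numpy as np

embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32)

# Serialize the way to_proto_vector does for FLOAT32 ...
as_bytes = embedding.tobytes()
dimension = embedding.size

# ... and decode the way from_proto_vector does.
decoded = np.frombuffer(as_bytes, dtype=np.float32)
assert decoded.shape == (dimension,)
assert np.allclose(decoded, embedding)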
def from_proto_operation(operation: chroma_pb.Operation) -> Operation:
    if operation == chroma_pb.Operation.ADD:
        return Operation.ADD
    elif operation == chroma_pb.Operation.UPDATE:
        return Operation.UPDATE
    elif operation == chroma_pb.Operation.UPSERT:
        return Operation.UPSERT
    elif operation == chroma_pb.Operation.DELETE:
        return Operation.DELETE
    else:
        # TODO: full error
        raise RuntimeError(f"Unknown operation {operation}")


def from_proto_metadata(metadata: chroma_pb.UpdateMetadata) -> Optional[Metadata]:
    return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False))


def from_proto_update_metadata(
    metadata: chroma_pb.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    return cast(
        Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True)
    )


def _from_proto_metadata_handle_none(
    metadata: chroma_pb.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    if not metadata.metadata:
        return None
    out_metadata: Dict[str, Union[str, int, float, bool, None]] = {}
    for key, value in metadata.metadata.items():
        if value.HasField("bool_value"):
            out_metadata[key] = value.bool_value
        elif value.HasField("string_value"):
            out_metadata[key] = value.string_value
        elif value.HasField("int_value"):
            out_metadata[key] = value.int_value
        elif value.HasField("float_value"):
            out_metadata[key] = value.float_value
        elif is_update:
            out_metadata[key] = None
        else:
            raise ValueError(f"Metadata key {key} value cannot be None")
    return out_metadata

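The is_update flag above is the whole difference between the two public helpers: for an update, a value with no field set is read back as None (meaning "delete this key"), while for plain metadata it is an error. A sketch, assuming the chromadb package and its generated protos are importable:

import chromadb.proto.chroma_pb2 as chroma_pb

msg = chroma_pb.UpdateMetadata(
    metadata={
        "note": chroma_pb.UpdateMetadataValue(string_value="hi"),
        "stale_key": chroma_pb.UpdateMetadataValue(),  # no value field set
    }
)
# from_proto_update_metadata(msg) -> {"note": "hi", "stale_key": None}
# from_proto_metadata(msg)        -> raises ValueError for "stale_key"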
def to_proto_update_metadata(metadata: UpdateMetadata) -> chroma_pb.UpdateMetadata:
    return chroma_pb.UpdateMetadata(
        metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()}
    )


def from_proto_submit(
    operation_record: chroma_pb.OperationRecord, seq_id: SeqId
) -> LogRecord:
    embedding, encoding = from_proto_vector(operation_record.vector)
    record = LogRecord(
        log_offset=seq_id,
        record=OperationRecord(
            id=operation_record.id,
            embedding=embedding,
            encoding=encoding,
            metadata=from_proto_update_metadata(operation_record.metadata),
            operation=from_proto_operation(operation_record.operation),
        ),
    )
    return record


def from_proto_segment(segment: chroma_pb.Segment) -> Segment:
    return Segment(
        id=UUID(hex=segment.id),
        type=segment.type,
        scope=from_proto_segment_scope(segment.scope),
        collection=UUID(hex=segment.collection),
        metadata=from_proto_metadata(segment.metadata)
        if segment.HasField("metadata")
        else None,
        file_paths={
            name: [path for path in paths.paths]
            for name, paths in segment.file_paths.items()
        },
    )


def to_proto_segment(segment: Segment) -> chroma_pb.Segment:
    return chroma_pb.Segment(
        id=segment["id"].hex,
        type=segment["type"],
        scope=to_proto_segment_scope(segment["scope"]),
        collection=segment["collection"].hex,
        metadata=None
        if segment["metadata"] is None
        else to_proto_update_metadata(segment["metadata"]),
        file_paths={
            name: chroma_pb.FilePaths(paths=paths)
            for name, paths in segment["file_paths"].items()
        },
    )


def from_proto_segment_scope(segment_scope: chroma_pb.SegmentScope) -> SegmentScope:
    if segment_scope == chroma_pb.SegmentScope.VECTOR:
        return SegmentScope.VECTOR
    elif segment_scope == chroma_pb.SegmentScope.METADATA:
        return SegmentScope.METADATA
    elif segment_scope == chroma_pb.SegmentScope.RECORD:
        return SegmentScope.RECORD
    else:
        raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_segment_scope(segment_scope: SegmentScope) -> chroma_pb.SegmentScope:
    if segment_scope == SegmentScope.VECTOR:
        return chroma_pb.SegmentScope.VECTOR
    elif segment_scope == SegmentScope.METADATA:
        return chroma_pb.SegmentScope.METADATA
    elif segment_scope == SegmentScope.RECORD:
        return chroma_pb.SegmentScope.RECORD
    else:
        raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_metadata_update_value(
    value: Union[str, int, float, bool, None]
) -> chroma_pb.UpdateMetadataValue:
    # Be careful with the order here. Since bools are a subtype of int in python,
    # isinstance(value, bool) and isinstance(value, int) both return true
    # for a value of bool type.
    if isinstance(value, bool):
        return chroma_pb.UpdateMetadataValue(bool_value=value)
    elif isinstance(value, str):
        return chroma_pb.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return chroma_pb.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return chroma_pb.UpdateMetadataValue(float_value=value)
    # None is used to delete the metadata key.
    elif value is None:
        return chroma_pb.UpdateMetadataValue()
    else:
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of str, int, \
            float, bool, or None"
        )

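The ordering caveat in to_proto_metadata_update_value is easy to verify in plain Python; a minimal, dependency-free sketch:

# bool is a subclass of int, so an int check placed first would swallow booleans.
assert isinstance(True, bool) and isinstance(True, int)
assert not isinstance(1, bool)

def classify(value):
    # Same dispatch order as to_proto_metadata_update_value: bool before int.
    if isinstance(value, bool):
        return "bool_value"
    if isinstance(value, int):
        return "int_value"
    return "other"

assert classify(True) == "bool_value"
assert classify(1) == "int_value"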
def from_proto_collection(collection: chroma_pb.Collection) -> Collection:
    return Collection(
        id=UUID(hex=collection.id),
        name=collection.name,
        configuration_json=json.loads(collection.configuration_json_str),
        metadata=from_proto_metadata(collection.metadata)
        if collection.HasField("metadata")
        else None,
        dimension=collection.dimension
        if collection.HasField("dimension") and collection.dimension
        else None,
        database=collection.database,
        tenant=collection.tenant,
        version=collection.version,
        log_position=collection.log_position,
    )


def to_proto_collection(collection: Collection) -> chroma_pb.Collection:
    return chroma_pb.Collection(
        id=collection["id"].hex,
        name=collection["name"],
        configuration_json_str=collection_configuration_to_json_str(
            collection.get_configuration()
        ),
        metadata=None
        if collection["metadata"] is None
        else to_proto_update_metadata(collection["metadata"]),
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
        log_position=collection["log_position"],
        version=collection["version"],
    )


def to_proto_operation(operation: Operation) -> chroma_pb.Operation:
    if operation == Operation.ADD:
        return chroma_pb.Operation.ADD
    elif operation == Operation.UPDATE:
        return chroma_pb.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return chroma_pb.Operation.UPSERT
    elif operation == Operation.DELETE:
        return chroma_pb.Operation.DELETE
    else:
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, \
            {Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: OperationRecord,
) -> chroma_pb.OperationRecord:
    vector = None
    if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
        vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])

    metadata = None
    if submit_record["metadata"] is not None:
        metadata = to_proto_update_metadata(submit_record["metadata"])

    return chroma_pb.OperationRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=metadata,
        operation=to_proto_operation(submit_record["operation"]),
    )

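For orientation, the OperationRecord consumed by to_proto_submit is a mapping with the keys read above (id, embedding, encoding, metadata, operation). A hypothetical example value, with shapes inferred from this file rather than taken from an official sample:

import numpy as np
from chromadb.types import Operation, ScalarEncoding

example_record = {
    "id": "doc-1",
    "embedding": np.array([0.1, 0.2, 0.3], dtype=np.float32),
    "encoding": ScalarEncoding.FLOAT32,
    "metadata": {"source": "unit-test", "draft": True, "stale_key": None},
    "operation": Operation.UPSERT,
}
# to_proto_submit(example_record) would yield a chroma_pb.OperationRecord with a
# FLOAT32 vector, the metadata map above, and Operation.UPSERT.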
def from_proto_vector_embedding_record(
    embedding_record: chroma_pb.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    return VectorEmbeddingRecord(
        id=embedding_record.id,
        embedding=from_proto_vector(embedding_record.vector)[0],
    )


def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> chroma_pb.VectorEmbeddingRecord:
    return chroma_pb.VectorEmbeddingRecord(
        id=embedding_record["id"],
        vector=to_proto_vector(embedding_record["embedding"], encoding),
    )


def from_proto_vector_query_result(
    vector_query_result: chroma_pb.VectorQueryResult,
) -> VectorQueryResult:
    return VectorQueryResult(
        id=vector_query_result.id,
        distance=vector_query_result.distance,
        embedding=from_proto_vector(vector_query_result.vector)[0],
    )


def from_proto_request_version_context(
    request_version_context: chroma_pb.RequestVersionContext,
) -> RequestVersionContext:
    return RequestVersionContext(
        collection_version=request_version_context.collection_version,
        log_position=request_version_context.log_position,
    )


def to_proto_request_version_context(
    request_version_context: RequestVersionContext,
) -> chroma_pb.RequestVersionContext:
    return chroma_pb.RequestVersionContext(
        collection_version=request_version_context["collection_version"],
        log_position=request_version_context["log_position"],
    )

def to_proto_where(where: Where) -> chroma_pb.Where:
    response = chroma_pb.Where()
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")

        if key == "$and" or key == "$or":
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {value}"
                )
            children: chroma_pb.WhereChildren = chroma_pb.WhereChildren(
                children=[to_proto_where(w) for w in value]
            )
            if key == "$and":
                children.operator = chroma_pb.BooleanOperator.AND
            else:
                children.operator = chroma_pb.BooleanOperator.OR

            response.children.CopyFrom(children)
            return response

        # At this point we know we're at a direct comparison. It can either
        # be of the form {"key": "value"} or {"key": {"$operator": "value"}}.

        dc = chroma_pb.DirectComparison()
        dc.key = key

        if not isinstance(value, dict):
            # {'key': 'value'} case
            if type(value) is str:
                ssc = chroma_pb.SingleStringComparison()
                ssc.value = value
                ssc.comparator = chroma_pb.GenericComparator.EQ
                dc.single_string_operand.CopyFrom(ssc)
            elif type(value) is bool:
                sbc = chroma_pb.SingleBoolComparison()
                sbc.value = value
                sbc.comparator = chroma_pb.GenericComparator.EQ
                dc.single_bool_operand.CopyFrom(sbc)
            elif type(value) is int:
                sic = chroma_pb.SingleIntComparison()
                sic.value = value
                sic.generic_comparator = chroma_pb.GenericComparator.EQ
                dc.single_int_operand.CopyFrom(sic)
            elif type(value) is float:
                sdc = chroma_pb.SingleDoubleComparison()
                sdc.value = value
                sdc.generic_comparator = chroma_pb.GenericComparator.EQ
                dc.single_double_operand.CopyFrom(sdc)
            else:
                raise ValueError(
                    f"Expected where value to be a string, bool, int, or float, got {value}"
                )
        else:
            for operator, operand in value.items():
                if operator in ["$in", "$nin"]:
                    if not isinstance(operand, list):
                        raise ValueError(
                            f"Expected where value for $in or $nin to be a list of values, got {value}"
                        )
                    if len(operand) == 0 or not all(
                        isinstance(x, type(operand[0])) for x in operand
                    ):
                        raise ValueError(
                            f"Expected where operand value to be a non-empty list, and all values to be of the same type, "
                            f"got {operand}"
                        )
                    list_operator = None
                    if operator == "$in":
                        list_operator = chroma_pb.ListOperator.IN
                    else:
                        list_operator = chroma_pb.ListOperator.NIN
                    if type(operand[0]) is str:
                        slo = chroma_pb.StringListComparison()
                        for x in operand:
                            slo.values.extend([x])  # type: ignore
                        slo.list_operator = list_operator
                        dc.string_list_operand.CopyFrom(slo)
                    elif type(operand[0]) is bool:
                        blo = chroma_pb.BoolListComparison()
                        for x in operand:
                            blo.values.extend([x])  # type: ignore
                        blo.list_operator = list_operator
                        dc.bool_list_operand.CopyFrom(blo)
                    elif type(operand[0]) is int:
                        ilo = chroma_pb.IntListComparison()
                        for x in operand:
                            ilo.values.extend([x])  # type: ignore
                        ilo.list_operator = list_operator
                        dc.int_list_operand.CopyFrom(ilo)
                    elif type(operand[0]) is float:
                        dlo = chroma_pb.DoubleListComparison()
                        for x in operand:
                            dlo.values.extend([x])  # type: ignore
                        dlo.list_operator = list_operator
                        dc.double_list_operand.CopyFrom(dlo)
                    else:
                        raise ValueError(
                            f"Expected where operand value to be a list of strings, bools, ints, or floats, got {operand}"
                        )
                elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]:
                    # Direct comparison to a single value.
                    if type(operand) is str:
                        ssc = chroma_pb.SingleStringComparison()
                        ssc.value = operand
                        if operator == "$eq":
                            ssc.comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            ssc.comparator = chroma_pb.GenericComparator.NE
                        else:
                            raise ValueError(
                                f"Expected where operator to be $eq or $ne, got {operator}"
                            )
                        dc.single_string_operand.CopyFrom(ssc)
                    elif type(operand) is bool:
                        sbc = chroma_pb.SingleBoolComparison()
                        sbc.value = operand
                        if operator == "$eq":
                            sbc.comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sbc.comparator = chroma_pb.GenericComparator.NE
                        else:
                            raise ValueError(
                                f"Expected where operator to be $eq or $ne, got {operator}"
                            )
                        dc.single_bool_operand.CopyFrom(sbc)
                    elif type(operand) is int:
                        sic = chroma_pb.SingleIntComparison()
                        sic.value = operand
                        if operator == "$eq":
                            sic.generic_comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sic.generic_comparator = chroma_pb.GenericComparator.NE
                        elif operator == "$gt":
                            sic.number_comparator = chroma_pb.NumberComparator.GT
                        elif operator == "$lt":
                            sic.number_comparator = chroma_pb.NumberComparator.LT
                        elif operator == "$gte":
                            sic.number_comparator = chroma_pb.NumberComparator.GTE
                        elif operator == "$lte":
                            sic.number_comparator = chroma_pb.NumberComparator.LTE
                        else:
                            raise ValueError(
                                f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}"
                            )
                        dc.single_int_operand.CopyFrom(sic)
                    elif type(operand) is float:
                        sfc = chroma_pb.SingleDoubleComparison()
                        sfc.value = operand
                        if operator == "$eq":
                            sfc.generic_comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sfc.generic_comparator = chroma_pb.GenericComparator.NE
                        elif operator == "$gt":
                            sfc.number_comparator = chroma_pb.NumberComparator.GT
                        elif operator == "$lt":
                            sfc.number_comparator = chroma_pb.NumberComparator.LT
                        elif operator == "$gte":
                            sfc.number_comparator = chroma_pb.NumberComparator.GTE
                        elif operator == "$lte":
                            sfc.number_comparator = chroma_pb.NumberComparator.LTE
                        else:
                            raise ValueError(
                                f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}"
                            )
                        dc.single_double_operand.CopyFrom(sfc)
                    else:
                        raise ValueError(
                            f"Expected where operand value to be a string, bool, int, or float, got {operand}"
                        )
                else:
                    # This case should never happen, as we've already
                    # handled the case for direct comparisons.
                    pass

        response.direct_comparison.CopyFrom(dc)
        return response

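As a usage illustration, these are the dictionary shapes the converter above accepts; the operator names come from the branches in the function, the values are made up:

# Implicit equality on a single key.
where_simple = {"source": "wiki"}

# Explicit operator on a single key.
where_range = {"page": {"$gte": 10}}

# Membership test; all list elements must share one type.
where_in = {"lang": {"$in": ["en", "de", "fr"]}}

# Boolean combination; each child is itself a one-operator where expression.
where_combined = {"$and": [{"source": "wiki"}, {"page": {"$lt": 100}}]}

# Any of these can be passed to to_proto_where(...) to build a chroma_pb.Where message.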
def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDocument:
    response = chroma_pb.WhereDocument()
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where_document to have exactly one operator, got {where_document}"
        )

    for operator, operand in where_document.items():
        if operator == "$and" or operator == "$or":
            # Nested "$and" or "$or" expression.
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}"
                )
            children: chroma_pb.WhereDocumentChildren = chroma_pb.WhereDocumentChildren(
                children=[to_proto_where_document(w) for w in operand]
            )
            if operator == "$and":
                children.operator = chroma_pb.BooleanOperator.AND
            else:
                children.operator = chroma_pb.BooleanOperator.OR

            response.children.CopyFrom(children)
        else:
            # Direct "$contains" or "$not_contains" comparison to a single
            # value.
            if not isinstance(operand, str):
                raise ValueError(
                    f"Expected where_document operand to be a string, got {operand}"
                )
            dwd = chroma_pb.DirectWhereDocument()
            dwd.document = operand
            if operator == "$contains":
                dwd.operator = chroma_pb.WhereDocumentOperator.CONTAINS
            elif operator == "$not_contains":
                dwd.operator = chroma_pb.WhereDocumentOperator.NOT_CONTAINS
            else:
                raise ValueError(
                    f"Expected where_document operator to be one of $contains, $not_contains, got {operator}"
                )
            response.direct.CopyFrom(dwd)

    return response

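The corresponding where_document shapes (again, example values only):

where_doc_simple = {"$contains": "neural network"}
where_doc_negated = {"$not_contains": "deprecated"}
where_doc_combined = {"$or": [{"$contains": "vector"}, {"$contains": "embedding"}]}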
def to_proto_scan(scan: Scan) -> query_pb.ScanOperator:
    return query_pb.ScanOperator(
        collection=to_proto_collection(scan.collection),
        knn=to_proto_segment(scan.knn),
        metadata=to_proto_segment(scan.metadata),
        record=to_proto_segment(scan.record),
    )


def to_proto_filter(filter: Filter) -> query_pb.FilterOperator:
    return query_pb.FilterOperator(
        ids=chroma_pb.UserIds(ids=filter.user_ids)
        if filter.user_ids is not None
        else None,
        where=to_proto_where(filter.where) if filter.where else None,
        where_document=to_proto_where_document(filter.where_document)
        if filter.where_document
        else None,
    )


def to_proto_knn(knn: KNN) -> query_pb.KNNOperator:
    return query_pb.KNNOperator(
        embeddings=[
            to_proto_vector(vector=embedding, encoding=ScalarEncoding.FLOAT32)
            for embedding in knn.embeddings
        ],
        fetch=knn.fetch,
    )


def to_proto_limit(limit: Limit) -> query_pb.LimitOperator:
    return query_pb.LimitOperator(offset=limit.offset, limit=limit.limit)


def to_proto_projection(projection: Projection) -> query_pb.ProjectionOperator:
    return query_pb.ProjectionOperator(
        document=projection.document,
        embedding=projection.embedding,
        metadata=projection.metadata,
    )


def to_proto_knn_projection(projection: Projection) -> query_pb.KNNProjectionOperator:
    return query_pb.KNNProjectionOperator(
        projection=to_proto_projection(projection), distance=projection.rank
    )


def to_proto_count_plan(count: CountPlan) -> query_pb.CountPlan:
    return query_pb.CountPlan(scan=to_proto_scan(count.scan))


def from_proto_count_result(result: query_pb.CountResult) -> int:
    return result.count


def to_proto_get_plan(get: GetPlan) -> query_pb.GetPlan:
    return query_pb.GetPlan(
        scan=to_proto_scan(get.scan),
        filter=to_proto_filter(get.filter),
        limit=to_proto_limit(get.limit),
        projection=to_proto_projection(get.projection),
    )


def from_proto_projection_record(record: query_pb.ProjectionRecord) -> ProjectionRecord:
    return ProjectionRecord(
        id=record.id,
        document=record.document if record.document else None,
        embedding=from_proto_vector(record.embedding)[0]
        if record.embedding is not None
        else None,
        metadata=from_proto_metadata(record.metadata),
    )


def from_proto_get_result(result: query_pb.GetResult) -> Sequence[ProjectionRecord]:
    return [from_proto_projection_record(record) for record in result.records]


def to_proto_knn_plan(knn: KNNPlan) -> query_pb.KNNPlan:
    return query_pb.KNNPlan(
        scan=to_proto_scan(knn.scan),
        filter=to_proto_filter(knn.filter),
        knn=to_proto_knn(knn.knn),
        projection=to_proto_knn_projection(knn.projection),
    )


def from_proto_knn_projection_record(
    record: query_pb.KNNProjectionRecord,
) -> KNNProjectionRecord:
    return KNNProjectionRecord(
        record=from_proto_projection_record(record.record), distance=record.distance
    )


def from_proto_knn_batch_result(
    results: query_pb.KNNBatchResult,
) -> Sequence[Sequence[KNNProjectionRecord]]:
    return [
        [from_proto_knn_projection_record(record) for record in result.records]
        for result in results.results
    ]
Some files were not shown because too many files have changed in this diff.