1119 lines
37 KiB
Python
1119 lines
37 KiB
Python
import multiprocessing
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import tempfile
|
|
import time
|
|
from typing import (
|
|
Any,
|
|
Generator,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Callable,
|
|
cast,
|
|
)
|
|
from uuid import UUID
|
|
|
|
import hypothesis
|
|
import pytest
|
|
import uvicorn
|
|
from httpx import ConnectError
|
|
from typing_extensions import Protocol
|
|
|
|
from chromadb.api.async_fastapi import AsyncFastAPI
|
|
from chromadb.api.fastapi import FastAPI
|
|
import chromadb.server.fastapi
|
|
from chromadb.api import ClientAPI, ServerAPI, BaseAPI
|
|
from chromadb.config import Settings, System
|
|
from chromadb.db.mixins import embeddings_queue
|
|
from chromadb.ingest import Producer
|
|
from chromadb.types import SeqId, OperationRecord
|
|
from chromadb.api.client import Client as ClientCreator, AdminClient
|
|
from chromadb.api.async_client import (
|
|
AsyncAdminClient,
|
|
AsyncClient as AsyncClientCreator,
|
|
)
|
|
from chromadb.utils.async_to_sync import async_class_to_sync
|
|
import logging
|
|
import sys
|
|
import numpy as np
|
|
from unittest.mock import MagicMock
|
|
from pytest import MonkeyPatch
|
|
from chromadb.api.types import Documents, Embeddings
|
|
import uuid
|
|
|
|
# Module-level logger for this test-support module.
logger = logging.getLogger(__name__)

# Names of the Hypothesis profiles that PROPERTY_TESTING_PRESET may select.
VALID_PRESETS = ["fast", "normal", "slow"]
# Profile selected through the environment; "fast" when the variable is unset.
CURRENT_PRESET = os.getenv("PROPERTY_TESTING_PRESET", "fast")

# Fail at import time when the environment names an unknown preset.
if CURRENT_PRESET not in VALID_PRESETS:
    raise ValueError(
        f"Invalid property testing preset: {CURRENT_PRESET}. Must be one of {VALID_PRESETS}."
    )

# Parent profile: the three presets below are registered with it as their
# parent and only override example/step counts.
hypothesis.settings.register_profile(
    "base",
    deadline=90000,
    suppress_health_check=[
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
    verbosity=hypothesis.Verbosity.verbose,
)

hypothesis.settings.register_profile(
    "fast", hypothesis.settings.get_profile("base"), max_examples=50
)
# Hypothesis's default max_examples is 100
hypothesis.settings.register_profile(
    "normal", hypothesis.settings.get_profile("base"), max_examples=100
)
hypothesis.settings.register_profile(
    "slow",
    hypothesis.settings.get_profile("base"),
    max_examples=500,
    stateful_step_count=100,
)

# Activate the preset chosen above for the whole test session.
hypothesis.settings.load_profile(CURRENT_PRESET)
|
|
|
|
# Check if we are running in a mode where SPANN is disabled
# (Rust bindings test OR Rust single-node integration test)
# True when either environment variable is set to exactly "1".
is_spann_disabled_mode = (
    os.getenv("CHROMA_RUST_BINDINGS_TEST_ONLY") == "1"
    or os.getenv("CHROMA_INTEGRATION_TEST_ONLY") == "1"
)
# Reason strings; presumably passed to pytest skip markers by tests that
# depend on SPANN being disabled/enabled -- confirm against the callers.
skip_reason_spann_disabled = (
    "SPANN creation/modification disallowed in Rust bindings or integration test mode"
)
skip_reason_spann_enabled = (
    "SPANN creation/modification allowed in Rust bindings or integration test mode"
)
|
|
|
|
|
|
def reset(api: BaseAPI) -> None:
    """Reset ``api`` by delegating to its ``reset()`` method."""
    api.reset()
|
|
|
|
|
|
def override_hypothesis_profile(
    fast: Optional[hypothesis.settings] = None,
    normal: Optional[hypothesis.settings] = None,
    slow: Optional[hypothesis.settings] = None,
) -> Optional[hypothesis.settings]:
    """Build Hypothesis settings with an override for the active preset.

    Each argument supplies settings to apply only when the matching preset
    (``CURRENT_PRESET``) is active. For instance, to change ``max_examples``
    only under the 'fast' preset:

    override_hypothesis_profile(
        fast=hypothesis.settings(max_examples=50),
    )

    Only ``deadline``, ``max_examples``, ``stateful_step_count`` and
    ``suppress_health_check`` are taken from the override; they are layered on
    top of ``hypothesis.settings.default``. When no override was given for the
    active preset, the default settings are returned unchanged.
    """
    candidates = {"fast": fast, "normal": normal, "slow": slow}
    selected = candidates.get(CURRENT_PRESET)

    # Nothing to override for the active preset: hand back the default profile.
    if selected is None:
        return cast(hypothesis.settings, hypothesis.settings.default)

    permitted_keys = (
        "deadline",
        "max_examples",
        "stateful_step_count",
        "suppress_health_check",
    )
    overrides = {}
    for name, value in selected.__dict__.items():
        if name in permitted_keys:
            overrides[name] = value

    return hypothesis.settings(hypothesis.settings.default, **overrides)
|
|
|
|
|
|
# True unless CHROMA_CLUSTER_TEST_ONLY is set to exactly "1".
NOT_CLUSTER_ONLY = os.getenv("CHROMA_CLUSTER_TEST_ONLY") != "1"
# NOTE(review): presumably the number of seconds tests wait for compaction;
# the unit is not visible here -- confirm against the callers.
COMPACTION_SLEEP = 120
|
|
|
|
|
|
def skip_if_not_cluster() -> pytest.MarkDecorator:
    """Return a ``skipif`` mark that skips a test when ``NOT_CLUSTER_ONLY`` is true."""
    reason = "Requires Kubernetes to be running with a valid config"
    return pytest.mark.skipif(NOT_CLUSTER_ONLY, reason=reason)
|
|
|
|
|
|
def generate_self_signed_certificate() -> None:
    """Generate a self-signed certificate and key for the SSL test server.

    Runs ``openssl req`` with the ``openssl.cnf`` located next to this file and
    writes ``serverkey.pem`` and ``servercert.pem`` relative to the current
    working directory.

    Raises:
        FileNotFoundError: if ``openssl.cnf`` is not found next to this file.
        subprocess.CalledProcessError: if ``openssl`` exits with a non-zero
            status.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "openssl.cnf"
    )
    print(f"Config path: {config_path}")  # Debug print to verify path
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:4096",
            "-keyout",
            "serverkey.pem",
            "-out",
            "servercert.pem",
            "-days",
            "365",
            "-nodes",
            "-subj",
            "/CN=localhost",
            "-config",
            config_path,
        ],
        # Surface an openssl failure here instead of letting the SSL server
        # fail later on missing or stale certificate files.
        check=True,
    )
|
|
|
|
|
|
def find_free_port() -> int:
    """Return a TCP port number that the OS reported as free.

    A socket is bound to port 0 so the OS picks a port; the socket is closed
    before returning, so the port is not reserved for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        bound_address = probe.getsockname()
        return bound_address[1]  # type: ignore
    finally:
        probe.close()
|
|
|
|
|
|
def _run_server(
    port: int,
    is_persistent: bool = False,
    persist_directory: Optional[str] = None,
    chroma_server_authn_provider: Optional[str] = None,
    chroma_server_authn_credentials_file: Optional[str] = None,
    chroma_server_authn_credentials: Optional[str] = None,
    chroma_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
    chroma_overwrite_singleton_tenant_database_access_from_auth: Optional[bool] = False,
) -> None:
    """Run a Chroma server locally.

    Builds SQLite-backed ``Settings`` from the arguments and serves the
    FastAPI app with uvicorn on ``0.0.0.0:port``. Persistence is enabled only
    when ``is_persistent`` is set and a ``persist_directory`` is given.
    """
    # Options shared by the persistent and the in-memory configuration.
    options = dict(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        is_persistent=False,
        allow_reset=True,
        chroma_server_authn_provider=chroma_server_authn_provider,
        chroma_server_authn_credentials_file=chroma_server_authn_credentials_file,
        chroma_server_authn_credentials=chroma_server_authn_credentials,
        chroma_auth_token_transport_header=chroma_auth_token_transport_header,
        chroma_server_authz_provider=chroma_server_authz_provider,
        chroma_server_authz_config_file=chroma_server_authz_config_file,
        chroma_overwrite_singleton_tenant_database_access_from_auth=chroma_overwrite_singleton_tenant_database_access_from_auth,
    )
    if is_persistent and persist_directory:
        options["is_persistent"] = is_persistent
        options["persist_directory"] = persist_directory

    settings = Settings(**options)
    server = chromadb.server.fastapi.FastAPI(settings)
    uvicorn.run(
        server.app(),
        host="0.0.0.0",
        port=port,
        log_level="error",
        timeout_keep_alive=30,
        ssl_keyfile=chroma_server_ssl_keyfile,
        ssl_certfile=chroma_server_ssl_certfile,
    )
|
|
|
|
|
|
def _await_server(api: ServerAPI, attempts: int = 0) -> None:
    """Block until ``api.heartbeat()`` succeeds.

    After each ``ConnectError`` the call sleeps 4 seconds and retries; once
    ``attempts`` exceeds 15 the error is re-raised to the caller.
    """
    while True:
        try:
            api.heartbeat()
            return
        except ConnectError:
            if attempts > 15:
                raise
            time.sleep(4)
            attempts += 1
|
|
|
|
|
|
def _fastapi_fixture(
    is_persistent: bool = False,
    chroma_api_impl: str = "chromadb.api.fastapi.FastAPI",
    chroma_server_authn_provider: Optional[str] = None,
    chroma_client_auth_provider: Optional[str] = None,
    chroma_server_authn_credentials_file: Optional[str] = None,
    chroma_server_authn_credentials: Optional[str] = None,
    chroma_client_auth_credentials: Optional[str] = None,
    chroma_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
    chroma_overwrite_singleton_tenant_database_access_from_auth: Optional[bool] = False,
) -> Generator[System, None, None]:
    """Fixture generator that launches a server in a separate process, and yields a
    fastapi client connected to it.

    The server runs ``_run_server`` in a spawned daemon process on a free
    port. The yielded ``System`` is configured as a client for that port. On
    teardown the system is stopped, the process is killed and joined, and the
    temporary persist directory (persistent mode only) is removed. Teardown
    also runs when startup or the consuming test raises.
    """

    port = find_free_port()
    ctx = multiprocessing.get_context("spawn")
    args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[bool],
    ] = (
        port,
        False,
        None,
        chroma_server_authn_provider,
        chroma_server_authn_credentials_file,
        chroma_server_authn_credentials,
        chroma_auth_token_transport_header,
        chroma_server_authz_provider,
        chroma_server_authz_config_file,
        chroma_server_ssl_certfile,
        chroma_server_ssl_keyfile,
        chroma_overwrite_singleton_tenant_database_access_from_auth,
    )

    def run(args: Any) -> Generator[System, None, None]:
        proc = ctx.Process(target=_run_server, args=args, daemon=True)
        proc.start()
        # From here on the server process exists, so it must be reaped even
        # if client setup or the heartbeat wait fails.
        try:
            settings = Settings(
                chroma_api_impl=chroma_api_impl,
                chroma_server_host="localhost",
                chroma_server_http_port=port,
                allow_reset=True,
                chroma_client_auth_provider=chroma_client_auth_provider,
                chroma_client_auth_credentials=chroma_client_auth_credentials,
                chroma_auth_token_transport_header=chroma_auth_token_transport_header,
                chroma_server_ssl_verify=chroma_server_ssl_certfile,
                chroma_server_ssl_enabled=True if chroma_server_ssl_certfile else False,
                chroma_overwrite_singleton_tenant_database_access_from_auth=chroma_overwrite_singleton_tenant_database_access_from_auth,
            )
            system = System(settings)
            api = system.instance(ServerAPI)
            system.start()
            # Only stop the system once start() has succeeded.
            try:
                _await_server(
                    api if isinstance(api, FastAPI) else async_class_to_sync(api)
                )
                yield system
            finally:
                system.stop()
        finally:
            proc.kill()
            proc.join()

    if is_persistent:
        persist_directory = tempfile.TemporaryDirectory()
        args = (
            port,
            is_persistent,
            persist_directory.name,
            chroma_server_authn_provider,
            chroma_server_authn_credentials_file,
            chroma_server_authn_credentials,
            chroma_auth_token_transport_header,
            chroma_server_authz_provider,
            chroma_server_authz_config_file,
            chroma_server_ssl_certfile,
            chroma_server_ssl_keyfile,
            chroma_overwrite_singleton_tenant_database_access_from_auth,
        )

        try:
            yield from run(args)
        finally:
            try:
                persist_directory.cleanup()

            # (Older versions of Python throw NotADirectoryError sometimes instead of PermissionError)
            # (when we drop support for Python < 3.10, we should use ignore_cleanup_errors=True with the context manager instead)
            except (PermissionError, NotADirectoryError) as e:
                # todo: what's holding onto directory contents on Windows?
                if os.name == "nt":
                    pass
                else:
                    raise e

    else:
        yield from run(args)
|
|
|
|
|
|
def fastapi() -> Generator[System, None, None]:
    """Return a ``_fastapi_fixture`` generator for a non-persistent server with the default (sync) client."""
    system_generator = _fastapi_fixture(is_persistent=False)
    return system_generator
|
|
|
|
|
|
def async_fastapi() -> Generator[System, None, None]:
    """Return a ``_fastapi_fixture`` generator for a non-persistent server whose client uses ``AsyncFastAPI``."""
    async_impl = "chromadb.api.async_fastapi.AsyncFastAPI"
    return _fastapi_fixture(is_persistent=False, chroma_api_impl=async_impl)
|
|
|
|
|
|
def fastapi_persistent() -> Generator[System, None, None]:
    """Return a ``_fastapi_fixture`` generator for a server persisting to a temporary directory."""
    system_generator = _fastapi_fixture(is_persistent=True)
    return system_generator
|
|
|
|
|
|
def fastapi_ssl() -> Generator[System, None, None]:
    """Generate a self-signed certificate, then return a ``_fastapi_fixture`` generator using it.

    The certificate is generated immediately when this function is called,
    not when the returned generator is first advanced.
    """
    generate_self_signed_certificate()
    ssl_options = {
        "chroma_server_ssl_certfile": "./servercert.pem",
        "chroma_server_ssl_keyfile": "./serverkey.pem",
    }
    return _fastapi_fixture(is_persistent=False, **ssl_options)
|
|
|
|
|
|
@pytest.fixture()
def basic_http_client() -> Generator[System, None, None]:
    """Yield a ``System`` configured as an HTTP client for an already running server.

    The target defaults to ``localhost:8000``. It can be overridden with the
    ``CHROMA_SERVER_HOST`` environment variable in ``host[:port]`` form; when
    the port part is missing or empty the default port is kept. The fixture
    waits for a successful heartbeat before starting the system and stops the
    system on teardown.
    """
    port = 8000
    host = "localhost"

    server_address = os.getenv("CHROMA_SERVER_HOST")
    if server_address:
        parts = server_address.split(":")
        host = parts[0]
        # A bare host name (no ":port") used to raise IndexError here.
        if len(parts) > 1 and parts[1]:
            port = int(parts[1])

    settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_http_port=port,
        chroma_server_host=host,
        allow_reset=True,
    )
    system = System(settings)
    api = system.instance(ServerAPI)
    _await_server(api)
    system.start()
    yield system
    system.stop()
|
|
|
|
|
|
def fastapi_server_basic_auth_valid_cred_single_user() -> Generator[System, None, None]:
    """Yield a system for a basic-auth server with one user and matching client credentials.

    The htpasswd file is written to a temporary file that is removed once the
    fixture generator finishes or is closed.
    """
    # This (and similar usage below) should use the delete_on_close parameter
    # instead of delete=False, but it's only available in Python 3.12 and later.
    # We must explicitly close the file before spawning a subprocess to avoid
    # file locking issues on Windows.
    with tempfile.NamedTemporaryFile("w", suffix=".htpasswd", delete=False) as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    # Leaving the ``with`` block closed the file; it stays on disk because of
    # delete=False and is removed below.

    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_authn_provider="chromadb.auth.basic_authn.BasicAuthenticationServerProvider",
            chroma_server_authn_credentials_file=f.name,
            chroma_client_auth_provider="chromadb.auth.basic_authn.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin",
        )
    finally:
        # Best-effort removal: a leftover temp file must not fail the test run.
        try:
            os.remove(f.name)
        except OSError:
            pass
|
|
|
|
|
|
def fastapi_server_basic_auth_valid_cred_multiple_users() -> (
    Generator[System, None, None]
):
    """Yield a system for a basic-auth server with several users; the client logs in as ``admin``.

    The htpasswd file is written to a temporary file that is removed once the
    fixture generator finishes or is closed.
    """
    creds = {
        "user": "$2y$10$kY9hn.Wlfcj7n1Cnjmy1kuIhEFIVBsfbNWLQ5ahoKmdc2HLA4oP6i",
        "user2": "$2y$10$CymQ63tic/DRj8dD82915eoM4ke3d6RaNKU4dj4IVJlHyea0yeGDS",
        "admin": "$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS",
    }
    # delete=False keeps the file on disk after it is closed so the spawned
    # server process can read it; it is removed explicitly below.
    with tempfile.NamedTemporaryFile("w", suffix=".htpasswd", delete=False) as f:
        for user, cred in creds.items():
            f.write(f"{user}:{cred}\n")

    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_authn_provider="chromadb.auth.basic_authn.BasicAuthenticationServerProvider",
            chroma_server_authn_credentials_file=f.name,
            chroma_client_auth_provider="chromadb.auth.basic_authn.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin",
        )
    finally:
        # Best-effort removal: a leftover temp file must not fail the test run.
        try:
            os.remove(f.name)
        except OSError:
            pass
|
|
|
|
|
|
def fastapi_server_basic_auth_invalid_cred() -> Generator[System, None, None]:
    """Yield a system for a basic-auth server whose client credentials (``admin:admin1``) differ from the ones used by the valid-credential fixtures.

    The htpasswd file is written to a temporary file that is removed once the
    fixture generator finishes or is closed.
    """
    # delete=False keeps the file on disk after it is closed so the spawned
    # server process can read it; it is removed explicitly below.
    with tempfile.NamedTemporaryFile("w", suffix=".htpasswd", delete=False) as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")

    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_authn_provider="chromadb.auth.basic_authn.BasicAuthenticationServerProvider",
            chroma_server_authn_credentials_file=f.name,
            chroma_client_auth_provider="chromadb.auth.basic_authn.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin1",
        )
    finally:
        # Best-effort removal: a leftover temp file must not fail the test run.
        try:
            os.remove(f.name)
        except OSError:
            pass
|
|
|
|
|
|
def fastapi_server_basic_authn_rbac_authz() -> Generator[System, None, None]:
    """Yield a system for a server using basic authentication plus simple RBAC authorization.

    Two temporary files are written (htpasswd credentials and the RBAC
    configuration); both are removed once the fixture generator finishes or is
    closed.
    """
    # delete=False keeps the files on disk after they are closed so the
    # spawned server process can read them; they are removed explicitly below.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".htpasswd", delete=False
    ) as server_authn_file:
        server_authn_file.write(
            "admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n"
        )

    try:
        with tempfile.NamedTemporaryFile(
            "w", suffix=".authz", delete=False
        ) as server_authz_file:
            server_authz_file.write(
                """
roles_mapping:
  admin:
    actions:
      [
        "system:reset",
        "tenant:create_tenant",
        "tenant:get_tenant",
        "db:create_database",
        "db:get_database",
        "db:list_collections",
        "db:create_collection",
        "db:get_or_create_collection",
        "collection:get_collection",
        "collection:delete_collection",
        "collection:update_collection",
        "collection:add",
        "collection:delete",
        "collection:get",
        "collection:query",
        "collection:peek",
        "collection:update",
        "collection:upsert",
        "collection:count",
      ]
users:
  - id: admin
    role: admin
"""
            )

        try:
            yield from _fastapi_fixture(
                is_persistent=False,
                chroma_client_auth_provider="chromadb.auth.basic_authn.BasicAuthClientProvider",
                chroma_client_auth_credentials="admin:admin",
                chroma_server_authn_provider="chromadb.auth.basic_authn.BasicAuthenticationServerProvider",
                chroma_server_authn_credentials_file=server_authn_file.name,
                chroma_server_authz_provider="chromadb.auth.simple_rbac_authz.SimpleRBACAuthorizationProvider",
                chroma_server_authz_config_file=server_authz_file.name,
            )
        finally:
            # Best-effort removal: a leftover temp file must not fail the run.
            try:
                os.remove(server_authz_file.name)
            except OSError:
                pass
    finally:
        try:
            os.remove(server_authn_file.name)
        except OSError:
            pass
|
|
|
|
|
|
def fastapi_fixture_admin_and_singleton_tenant_db_user() -> (
    Generator[System, None, None]
):
    """Yield a system for a token-auth server with an admin user and a single-tenant/database user.

    The client authenticates with ``admin-token``. The credentials file is a
    temporary file that is removed once the fixture generator finishes or is
    closed.
    """
    # delete=False keeps the file on disk after it is closed so the spawned
    # server process can read it; it is removed explicitly below.
    with tempfile.NamedTemporaryFile("w", suffix=".authn", delete=False) as f:
        f.write(
            """
users:
  - id: admin
    tokens:
      - admin-token
  - id: singleton_user
    tenant: singleton_tenant
    databases:
      - singleton_database
    tokens:
      - singleton-token
"""
        )

    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_overwrite_singleton_tenant_database_access_from_auth=True,
            chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
            chroma_client_auth_credentials="admin-token",
            chroma_server_authn_provider="chromadb.auth.token_authn.TokenAuthenticationServerProvider",
            chroma_server_authn_credentials_file=f.name,
        )
    finally:
        # Best-effort removal: a leftover temp file must not fail the test run.
        try:
            os.remove(f.name)
        except OSError:
            pass
|
|
|
|
|
|
@pytest.fixture()
def integration() -> Generator[System, None, None]:
    """Yield a started system using the sync FastAPI client.

    Only ``allow_reset`` and the API implementation are set here; the
    remaining settings are expected to come from environment variables, for
    externally configured integration tests.
    """
    env_system = System(
        Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            allow_reset=True,
        )
    )
    env_system.start()
    yield env_system
    env_system.stop()
|
|
|
|
|
|
@pytest.fixture()
def async_integration() -> Generator[System, None, None]:
    """Yield a started system using the async FastAPI client.

    Only ``allow_reset`` and the API implementation are set here; the
    remaining settings are expected to come from environment variables, for
    externally configured integration tests.
    """
    env_system = System(
        Settings(
            chroma_api_impl="chromadb.api.async_fastapi.AsyncFastAPI",
            allow_reset=True,
        )
    )
    env_system.start()
    yield env_system
    env_system.stop()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def python_sqlite_ephemeral() -> Generator[System, None, None]:
    """Yield a started, non-persistent system using the segment-based API on SQLite."""
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    ephemeral_system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=False,
            allow_reset=True,
        )
    )
    ephemeral_system.start()
    yield ephemeral_system
    ephemeral_system.stop()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def python_sqlite_persistent() -> Generator[System, None, None]:
    """Yield a started system using the segment-based API on SQLite persisted to a temporary directory.

    The directory is removed on teardown; on Windows a failed removal is
    tolerated.
    """
    storage_dir = tempfile.TemporaryDirectory()
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    persistent_system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=storage_dir.name,
        )
    )
    persistent_system.start()
    yield persistent_system
    persistent_system.stop()

    try:
        storage_dir.cleanup()
    # (Older versions of Python throw NotADirectoryError sometimes instead of PermissionError)
    # (when we drop support for Python < 3.10, we should use ignore_cleanup_errors=True with the context manager instead)
    except (PermissionError, NotADirectoryError):
        # todo: what's holding onto directory contents on Windows?
        if os.name != "nt":
            raise
|
|
|
|
|
|
@pytest.fixture(scope="function")
def rust_sqlite_ephemeral() -> Generator[System, None, None]:
    """Yield a started, non-persistent system using the Rust bindings API."""
    sqlite_impl = "chromadb.db.impl.sqlite.SqliteDB"
    ephemeral_system = System(
        Settings(
            chroma_api_impl="chromadb.api.rust.RustBindingsAPI",
            chroma_sysdb_impl=sqlite_impl,
            chroma_producer_impl=sqlite_impl,
            chroma_consumer_impl=sqlite_impl,
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=False,
            allow_reset=True,
            persist_directory="",
        )
    )
    ephemeral_system.start()
    yield ephemeral_system
    ephemeral_system.stop()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def rust_sqlite_persistent() -> Generator[System, None, None]:
    """Yield a started system using the Rust bindings API persisted to a temporary directory.

    The directory is removed explicitly on teardown, handled the same way as
    in ``python_sqlite_persistent``; on Windows a failed removal is tolerated.
    """
    save_path = tempfile.TemporaryDirectory()
    settings = Settings(
        chroma_api_impl="chromadb.api.rust.RustBindingsAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        is_persistent=True,
        allow_reset=True,
        persist_directory=save_path.name,
    )
    system = System(settings)
    system.start()
    yield system
    system.stop()

    # Previously the directory was only removed implicitly when the
    # TemporaryDirectory object was finalized.
    try:
        save_path.cleanup()
    # (Older versions of Python throw NotADirectoryError sometimes instead of PermissionError)
    except (PermissionError, NotADirectoryError) as e:
        if os.name == "nt":
            pass
        else:
            raise e
|
|
|
|
|
|
@pytest.fixture(
    params=(
        ["python_sqlite_ephemeral"]
        if "CHROMA_RUST_BINDINGS_TEST_ONLY" not in os.environ
        else ["rust_sqlite_ephemeral"]
    )
)
def sqlite(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Resolve to the ephemeral SQLite fixture: the Rust one when CHROMA_RUST_BINDINGS_TEST_ONLY is set, otherwise the Python one."""
    selected_fixture = request.param
    return request.getfixturevalue(selected_fixture)
|
|
|
|
|
|
@pytest.fixture(
    params=(
        ["python_sqlite_persistent"]
        if "CHROMA_RUST_BINDINGS_TEST_ONLY" not in os.environ
        else ["rust_sqlite_persistent"]
    )
)
def sqlite_persistent(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Resolve to the persistent SQLite fixture: the Rust one when CHROMA_RUST_BINDINGS_TEST_ONLY is set, otherwise the Python one."""
    selected_fixture = request.param
    return request.getfixturevalue(selected_fixture)
|
|
|
|
|
|
def filtered_fixture_names() -> List[str]:
    """Return the fixture names to parameterize over, based on environment variables.

    Precedence, highest first: CHROMA_RUST_BINDINGS_TEST_ONLY,
    CHROMA_CLUSTER_TEST_ONLY, CHROMA_INTEGRATION_TEST_ONLY. Without any of
    those, the default list is returned, extended with the integration
    fixtures when CHROMA_INTEGRATION_TEST is set. Only the presence of a
    variable is checked, not its value.
    """
    env = os.environ
    integration_names = ["integration", "async_integration"]

    if "CHROMA_RUST_BINDINGS_TEST_ONLY" in env:
        return ["rust_sqlite_ephemeral", "rust_sqlite_persistent"]
    if "CHROMA_CLUSTER_TEST_ONLY" in env:
        return ["basic_http_client"]
    if "CHROMA_INTEGRATION_TEST_ONLY" in env:
        return integration_names

    names = [
        "fastapi",
        "async_fastapi",
        "fastapi_persistent",
        "sqlite_fixture",
        "sqlite_persistent",
    ]
    if "CHROMA_INTEGRATION_TEST" in env:
        names.extend(integration_names)
    return names
|
|
|
|
|
|
def filtered_http_server_fixture_names() -> List[str]:
    """Return ``filtered_fixture_names()`` minus the four named in-process SQLite fixtures."""
    # NOTE(review): "sqlite_fixture" and "sqlite_persistent" from the default
    # list are not excluded here -- confirm whether that is intended.
    excluded = (
        "python_sqlite_ephemeral",
        "python_sqlite_persistent",
        "rust_sqlite_ephemeral",
        "rust_sqlite_persistent",
    )
    http_names = []
    for name in filtered_fixture_names():
        if name in excluded:
            continue
        http_names.append(name)
    return http_names
|
|
|
|
|
|
def system_fixtures_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for servers with valid basic-auth credentials."""
    return [
        fastapi_server_basic_auth_valid_cred_single_user,
        fastapi_server_basic_auth_valid_cred_multiple_users,
    ]
|
|
|
|
|
|
def system_fixtures_authn_rbac_authz() -> (
    List[Callable[[], Generator[System, None, None]]]
):
    """Return the fixture generators for servers with basic authn plus RBAC authz."""
    return [fastapi_server_basic_authn_rbac_authz]
|
|
|
|
|
|
def system_fixtures_root_and_singleton_tenant_db_user() -> (
    List[Callable[[], Generator[System, None, None]]]
):
    """Return the fixture generators for the admin + singleton tenant/database user setup."""
    return [fastapi_fixture_admin_and_singleton_tenant_db_user]
|
|
|
|
|
|
def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for servers reached with invalid client credentials."""
    return [fastapi_server_basic_auth_invalid_cred]
|
|
|
|
|
|
def system_fixtures_ssl() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for SSL-enabled servers."""
    return [fastapi_ssl]
|
|
|
|
|
|
@pytest.fixture(scope="module", params=system_fixtures_wrong_auth())
def system_wrong_auth(
    request: pytest.FixtureRequest,
) -> Generator[ServerAPI, None, None]:
    """Module-scoped fixture delegating to each generator from ``system_fixtures_wrong_auth()``."""
    make_system = request.param
    yield from make_system()
|
|
|
|
|
|
@pytest.fixture(scope="module", params=system_fixtures_authn_rbac_authz())
def system_authn_rbac_authz(
    request: pytest.FixtureRequest,
) -> Generator[ServerAPI, None, None]:
    """Module-scoped fixture delegating to each generator from ``system_fixtures_authn_rbac_authz()``."""
    make_system = request.param
    yield from make_system()
|
|
|
|
|
|
@pytest.fixture(params=filtered_http_server_fixture_names())
def system_http_server(
    request: pytest.FixtureRequest,
) -> Generator[ServerAPI, None, None]:
    """Resolve, by name, each fixture listed by ``filtered_http_server_fixture_names()``."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
|
|
|
|
|
|
@pytest.fixture(scope="function", params=filtered_fixture_names())
def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
    """Resolve, by name, each fixture listed by ``filtered_fixture_names()``."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
|
|
|
|
|
|
@pytest.fixture(scope="module", params=system_fixtures_ssl())
def system_ssl(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
    """Module-scoped fixture delegating to each generator from ``system_fixtures_ssl()``."""
    make_system = request.param
    yield from make_system()
|
|
|
|
|
|
@pytest.fixture(scope="module", params=system_fixtures_auth())
def system_auth(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
    """Module-scoped fixture delegating to each generator from ``system_fixtures_auth()``."""
    make_system = request.param
    yield from make_system()
|
|
|
|
|
|
# AsyncClientCreator passed through async_class_to_sync; the fixtures in this
# file use it wherever the system is configured with AsyncFastAPI.
@async_class_to_sync
class AsyncClientCreatorSync(AsyncClientCreator):
    pass
|
|
|
|
|
|
# AsyncAdminClient passed through async_class_to_sync; ClientFactories uses
# it wherever the system is configured with AsyncFastAPI.
@async_class_to_sync
class AsyncAdminClientSync(AsyncAdminClient):
    pass
|
|
|
|
|
|
@pytest.fixture(scope="function")
def api(system: System) -> Generator[ServerAPI, None, None]:
    """Reset the system's state and yield its ``ServerAPI`` instance.

    An ``AsyncFastAPI`` instance is passed through ``async_class_to_sync``
    before being yielded.
    """
    system.reset_state()
    server_api = system.instance(ServerAPI)

    needs_wrapping = isinstance(server_api, AsyncFastAPI)
    yield async_class_to_sync(server_api) if needs_wrapping else server_api
|
|
|
|
|
|
class ClientFactories:
    """This allows consuming tests to be parameterized by async/sync versions of the client and papers over the async implementation.
    If you don't need to manually construct clients, use the `client` fixture instead.

    Every client created through a factory method is recorded in
    ``_created_clients`` so that teardown code can call
    ``.clear_system_cache()`` on it.
    """

    _system: System
    # Need to track created clients so we can call .clear_system_cache() during teardown.
    # Only the type is declared here; the list is created per instance in
    # __init__ so that separate factories never share tracked clients.
    _created_clients: List[ClientAPI]

    def __init__(self, system: System):
        self._system = system
        self._created_clients = []

    def _uses_async_api(self) -> bool:
        """Return True when the system is configured with the AsyncFastAPI implementation."""
        return (
            self._system.settings.chroma_api_impl
            == "chromadb.api.async_fastapi.AsyncFastAPI"
        )

    def _track(self, client: Any) -> Any:
        """Record ``client`` for teardown and return it unchanged."""
        self._created_clients.append(client)
        return client

    def create_client(self, *args: Any, **kwargs: Any) -> ClientCreator:
        """Create a client from ``args``/``kwargs``, defaulting ``settings`` to the system's settings."""
        if kwargs.get("settings") is None:
            kwargs["settings"] = self._system.settings

        if self._uses_async_api():
            client = cast(ClientCreator, AsyncClientCreatorSync.create(*args, **kwargs))
        else:
            client = ClientCreator(*args, **kwargs)
        return cast(ClientCreator, self._track(client))

    def create_client_from_system(self) -> ClientCreator:
        """Create a client bound to the wrapped system."""
        if self._uses_async_api():
            client = cast(
                ClientCreator, AsyncClientCreatorSync.from_system_async(self._system)
            )
        else:
            client = ClientCreator.from_system(self._system)
        return cast(ClientCreator, self._track(client))

    def create_admin_client(self, *args: Any, **kwargs: Any) -> AdminClient:
        """Create an admin client from ``args``/``kwargs``."""
        if self._uses_async_api():
            client = cast(AdminClient, AsyncAdminClientSync(*args, **kwargs))
        else:
            client = AdminClient(*args, **kwargs)
        return cast(AdminClient, self._track(client))

    def create_admin_client_from_system(self) -> AdminClient:
        """Create an admin client bound to the wrapped system."""
        if self._uses_async_api():
            client = cast(AdminClient, AsyncAdminClientSync.from_system(self._system))
        else:
            client = AdminClient.from_system(self._system)
        return cast(AdminClient, self._track(client))
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client_factories(system: System) -> Generator[ClientFactories, None, None]:
    """Yield a ``ClientFactories`` for a freshly reset system.

    On teardown, ``clear_system_cache()`` is called on every client the
    factories tracked, most recently created first.
    """
    system.reset_state()
    factories = ClientFactories(system)
    yield factories

    tracked = factories._created_clients
    while tracked:
        tracked.pop().clear_system_cache()
|
|
|
|
|
|
def create_isolated_database(client: ClientAPI) -> None:
    """Create a uniquely named database and switch ``client`` to it.

    The database is created through a sync ``AdminClient`` built from the
    client's settings and is named ``test_<uuid4>``.
    """
    settings = client.get_settings()
    if settings.chroma_api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        # NOTE(review): this assigns on the object returned by get_settings();
        # if that is the client's live settings object rather than a copy, the
        # client's own configuration changes too -- confirm.
        settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"

    admin_client = AdminClient(settings)
    database_name = f"test_{uuid.uuid4()}"
    admin_client.create_database(database_name)
    client.set_database(database_name)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client(system: System) -> Generator[ClientAPI, None, None]:
    """Yield a client for a freshly reset system, then clear its system cache.

    A system configured with AsyncFastAPI gets the sync-wrapped async client.
    """
    system.reset_state()

    is_async = (
        system.settings.chroma_api_impl == "chromadb.api.async_fastapi.AsyncFastAPI"
    )
    if is_async:
        created = cast(Any, AsyncClientCreatorSync.from_system_async(system))
    else:
        created = ClientCreator.from_system(system)

    yield created
    created.clear_system_cache()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def http_client(system_http_server: System) -> Generator[ClientAPI, None, None]:
    """Yield a client for a freshly reset HTTP-server system, then clear its system cache.

    A system configured with AsyncFastAPI gets the sync-wrapped async client.
    """
    system_http_server.reset_state()

    is_async = (
        system_http_server.settings.chroma_api_impl
        == "chromadb.api.async_fastapi.AsyncFastAPI"
    )
    if is_async:
        created = cast(Any, AsyncClientCreatorSync.from_system_async(system_http_server))
    else:
        created = ClientCreator.from_system(system_http_server)

    yield created
    created.clear_system_cache()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client_ssl(system_ssl: System) -> Generator[ClientAPI, None, None]:
    """Yield a client for the freshly reset SSL system, then clear its system cache."""
    system_ssl.reset_state()
    ssl_client = ClientCreator.from_system(system_ssl)
    yield ssl_client
    ssl_client.clear_system_cache()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def api_wrong_cred(
    system_wrong_auth: System,
) -> Generator[ServerAPI, None, None]:
    """Reset the wrong-credentials system and yield its ``ServerAPI`` instance."""
    system_wrong_auth.reset_state()
    yield system_wrong_auth.instance(ServerAPI)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def api_with_authn_rbac_authz(
    system_authn_rbac_authz: System,
) -> Generator[ServerAPI, None, None]:
    """Reset the authn + RBAC system and yield its ``ServerAPI`` instance."""
    system_authn_rbac_authz.reset_state()
    yield system_authn_rbac_authz.instance(ServerAPI)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]:
    """Reset ``system_auth`` and yield its ``ServerAPI`` instance."""
    system_auth.reset_state()
    yield system_auth.instance(ServerAPI)
|
|
|
|
|
|
# Producer / Consumer fixtures #
|
|
|
|
|
|
class ProducerFn(Protocol):
    """Call signature shared by ``produce_n_single`` and ``produce_n_batch``.

    A conforming callable receives a producer, a collection id, an iterator of
    operation records and a count ``n``, and returns a pair of sequences: the
    operation records and the sequence ids.
    """

    def __call__(
        self,
        producer: Producer,
        collection_id: UUID,
        embeddings: Iterator[OperationRecord],
        n: int,
    ) -> Tuple[Sequence[OperationRecord], Sequence[SeqId]]:
        """Return ``(records, seq_ids)`` for the given producer and inputs."""
        ...
|
|
|
|
|
|
def produce_n_single(
    producer: Producer,
    collection_id: UUID,
    embeddings: Iterator[OperationRecord],
    n: int,
) -> Tuple[Sequence[OperationRecord], Sequence[SeqId]]:
    """Submit ``n`` records from ``embeddings`` to ``producer`` one at a time.

    Each record is pulled with ``next`` and handed to
    ``producer.submit_embedding`` before the following record is pulled.

    Returns:
        Two lists in submission order: the records that were submitted and the
        value ``submit_embedding`` returned for each of them.
    """
    records = []
    sequence_ids = []
    for _ in range(n):
        record = next(embeddings)
        sequence_ids.append(producer.submit_embedding(collection_id, record))
        records.append(record)
    return records, sequence_ids
|
|
|
|
|
|
def produce_n_batch(
    producer: Producer,
    collection_id: UUID,
    embeddings: Iterator[OperationRecord],
    n: int,
) -> Tuple[Sequence[OperationRecord], Sequence[SeqId]]:
    """Submit ``n`` records from ``embeddings`` to ``producer`` in one batch.

    All ``n`` records are pulled from the iterator first and then passed to a
    single ``producer.submit_embeddings`` call.

    Returns:
        The list of submitted records and whatever ``submit_embeddings``
        returned for that batch.
    """
    batch = [next(embeddings) for _ in range(n)]
    sequence_ids = producer.submit_embeddings(collection_id, batch)
    return batch, sequence_ids
|
|
|
|
|
|
def produce_fn_fixtures() -> List[ProducerFn]:
    """Return the producer helpers used as params of the ``produce_fns`` fixture."""
    producer_fns: List[ProducerFn] = [produce_n_single, produce_n_batch]
    return producer_fns
|
|
|
|
|
|
@pytest.fixture(scope="module", params=produce_fn_fixtures())
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Yield the current param, one of the functions from ``produce_fn_fixtures()``."""
    producer_fn = request.param
    yield producer_fn
|
|
|
|
|
|
def pytest_configure(config):  # type: ignore
    """Set the ``_called_from_test`` flag on the ``embeddings_queue`` module.

    ``config`` is accepted to match the hook signature but is not used.
    """
    embeddings_queue._called_from_test = True
|
|
|
|
|
|
def is_client_in_process(client: ClientAPI) -> bool:
    """Tell whether ``client`` runs in-process.

    Returns True if the client is in-process (a SQLite client), False if it's
    out-of-process (a HTTP client). The decision is based solely on whether the
    client's settings have ``chroma_server_http_port`` set to ``None``.
    """
    http_port = client.get_settings().chroma_server_http_port
    return http_port is None
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def log_tests(request: pytest.FixtureRequest) -> Generator[None, None, None]:
    """Emit a debug log line before and after every test (autouse)."""
    # Read the node name once so both messages report the same value.
    name = request.node.name

    logger.debug(f"Starting test: {name}")
    yield  # the test body runs here
    logger.debug(f"Finished test: {name}")
|
|
|
|
|
|
@pytest.fixture
def mock_embeddings() -> Callable[[Documents], Embeddings]:
    """Provide a stand-in embedding function for tests.

    The returned callable ignores the content of its input and produces one
    float32 vector ``[0.1, 0.2, 0.3]`` per input item.
    """
    fixed_values = [0.1, 0.2, 0.3]

    def _mock_embeddings(input: Documents) -> Embeddings:
        # A fresh array is built for every item, so results never share storage.
        return [np.array(fixed_values, dtype=np.float32) for _item in input]

    return _mock_embeddings
|
|
|
|
|
|
@pytest.fixture
def mock_common_deps(monkeypatch: MonkeyPatch) -> MonkeyPatch:
    """Install ``MagicMock`` stand-ins for common optional dependencies.

    Each top-level package listed below is registered in ``sys.modules`` as a
    ``MagicMock``, and selected attributes on those packages are replaced with
    their own ``MagicMock`` objects. ``open_clip.create_model_and_transforms``
    is configured to return a model whose ``encode_text`` / ``encode_image``
    return a fixed ``[[0.1, 0.2, 0.3]]`` array.

    Returns:
        The same ``monkeypatch`` object, so tests can layer further patches on
        top. Everything installed here goes through ``monkeypatch`` and is
        therefore undone when the fixture is torn down.
    """
    # Create mock modules
    mock_modules = {
        "PIL": MagicMock(),
        "torch": MagicMock(),
        "openai": MagicMock(),
        "cohere": MagicMock(),
        "sentence_transformers": MagicMock(),
        "ollama": MagicMock(),
        "InstructorEmbedding": MagicMock(),
        "voyageai": MagicMock(),
        "text2vec": MagicMock(),
        "open_clip": MagicMock(),
        "boto3": MagicMock(),
    }

    # Register every mock in the existing sys.modules dict via
    # monkeypatch.setitem. The dict is mutated in place rather than rebound to
    # a copy: the Python docs warn that replacing sys.modules "will not
    # necessarily work as expected", because the import machinery keeps its
    # own reference to the original dict. With a rebound copy, a real module
    # that is already imported can bypass the mock, and modules imported while
    # the copy is active are dropped from the cache at teardown.
    for module_name, module_mock in mock_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module_mock)

    # Mock submodules and attributes
    mock_attributes = {
        "PIL.Image": MagicMock(),
        "sentence_transformers.SentenceTransformer": MagicMock(),
        "ollama.Client": MagicMock(),
        "InstructorEmbedding.INSTRUCTOR": MagicMock(),
        "voyageai.Client": MagicMock(),
        "text2vec.SentenceModel": MagicMock(),
    }

    # Setup OpenCLIP mock with specific behavior
    mock_model = MagicMock()
    mock_model.encode_text.return_value = np.array([[0.1, 0.2, 0.3]])
    mock_model.encode_image.return_value = np.array([[0.1, 0.2, 0.3]])
    mock_modules["open_clip"].create_model_and_transforms.return_value = (
        mock_model,
        MagicMock(),
        mock_model,
    )

    # Mock all attributes. The dotted string targets are resolved by importing
    # the package name, which now yields the mocks registered above.
    for path, mock in mock_attributes.items():
        monkeypatch.setattr(path, mock, raising=False)

    return monkeypatch
|