增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -31,7 +31,7 @@ from chromadb.utils.embedding_functions import (
)
from chromadb.api.models.Collection import Collection
from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.errors import InvalidArgumentError, InternalError
from chromadb.errors import InvalidArgumentError
from chromadb.execution.expression import Knn, Search
from chromadb.types import Collection as CollectionModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
@@ -509,7 +509,7 @@ def test_sparse_vector_not_allowed_locally(
schema = Schema()
schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig())
with pytest.raises(
InternalError, match="Sparse vector indexing is not enabled in local"
InvalidArgumentError, match="Sparse vector indexing is not enabled in local"
):
_create_isolated_collection(client_factories, schema=schema)
@@ -1011,12 +1011,26 @@ def test_collection_fork_inherits_and_isolates_schema(
client_factories,
schema=schema,
)
parent_version_before_add = get_collection_version(client, collection.name)
parent_ids = [f"parent-{i}" for i in range(251)]
parent_docs = [f"parent doc {i}" for i in range(251)]
parent_metadatas: List[Mapping[str, Any]] = [
{"shared_key": f"parent_{i}"} for i in range(251)
]
collection.add(
ids=["parent-1"],
documents=["parent doc"],
metadatas=[{"shared_key": "parent"}],
ids=parent_ids,
documents=parent_docs,
metadatas=parent_metadatas,
)
# Wait for parent to compact before forking. Otherwise, the fork inherits
# uncompacted logs, and compaction of those inherited logs could increment
# the fork's version before the fork's own data is compacted.
wait_for_version_increase(client, collection.name, parent_version_before_add)
assert collection.schema is not None
parent_schema_json = collection.schema.serialize_to_json()
@@ -1050,8 +1064,8 @@ def test_collection_fork_inherits_and_isolates_schema(
assert reloaded_parent.schema is not None
assert "child_only" not in reloaded_parent.schema.keys
parent_results = reloaded_parent.get(where={"shared_key": "parent"})
assert set(parent_results["ids"]) == {"parent-1"}
parent_results = reloaded_parent.get(where={"shared_key": "parent_10"})
assert set(parent_results["ids"]) == {"parent-10"}
child_results = forked.get(where={"child_only": "value_10"})
assert set(child_results["ids"]) == {"fork-10"}

View File

@@ -7,11 +7,21 @@ for automatically processing collections.
import pytest
from chromadb.api.client import Client as ClientCreator
from chromadb.api.functions import (
RECORD_COUNTER_FUNCTION,
STATISTICS_FUNCTION,
Function,
)
from chromadb.config import System
from chromadb.errors import ChromaError, NotFoundError
from chromadb.test.utils.wait_for_version_increase import (
get_collection_version,
wait_for_version_increase,
)
from time import sleep
def test_function_attach_and_detach(basic_http_client: System) -> None:
def test_count_function_attach_and_detach(basic_http_client: System) -> None:
"""Test creating and removing a function with the record_counter operator"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
@@ -22,45 +32,39 @@ def test_function_attach_and_detach(basic_http_client: System) -> None:
metadata={"description": "Sample documents for task processing"},
)
# Add initial documents
collection.add(
ids=["doc1", "doc2", "doc3"],
documents=[
"The quick brown fox jumps over the lazy dog",
"Machine learning is a subset of artificial intelligence",
"Python is a popular programming language",
],
metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}],
)
# Verify collection has documents
assert collection.count() == 3
# Create a task that counts records in the collection
attached_fn = collection.attach_function(
attached_fn, created = collection.attach_function(
name="count_my_docs",
function_id="record_counter", # Built-in operator that counts records
function=RECORD_COUNTER_FUNCTION,
output_collection="my_documents_counts",
params=None,
)
# Verify task creation succeeded
assert attached_fn is not None
assert created is True
initial_version = get_collection_version(client, collection.name)
# Add more documents
# Add documents
collection.add(
ids=["doc4", "doc5"],
documents=[
"Chroma is a vector database",
"Tasks automate data processing",
],
ids=["doc_{}".format(i) for i in range(0, 300)],
documents=["test document"] * 300,
)
# Verify documents were added
assert collection.count() == 5
assert collection.count() == 300
wait_for_version_increase(client, collection.name, initial_version)
# Give some time to invalidate the frontend query cache
sleep(60)
result = client.get_collection("my_documents_counts").get("function_output")
assert result["metadatas"] is not None
assert result["metadatas"][0]["total_count"] == 300
# Remove the task
success = attached_fn.detach(
success = collection.detach_function(
attached_fn.name,
delete_output_collection=True,
)
@@ -79,13 +83,42 @@ def test_task_with_invalid_function(basic_http_client: System) -> None:
# Attempt to create task with non-existent function should raise ChromaError
with pytest.raises(ChromaError, match="function not found"):
collection.attach_function(
function=Function._NONEXISTENT_TEST_ONLY,
name="invalid_task",
function_id="nonexistent_function",
output_collection="output_collection",
params=None,
)
def test_attach_function_returns_function_name(basic_http_client: System) -> None:
    """Check that attached functions expose ``function_name``.

    Attaches the record counter to a fresh collection, asserts the name fields
    on the returned object, and asserts that ``get_attached_function`` returns
    an object equal to the one returned by ``attach_function``.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    fn_name = "my_counter"
    source = client.create_collection(name="test_function_name")
    source.add(ids=["id1"], documents=["doc1"])

    attached_fn, was_created = source.attach_function(
        name=fn_name,
        function=RECORD_COUNTER_FUNCTION,
        output_collection="output_collection",
        params=None,
    )

    # The response carries the function's name and the user-chosen
    # attachment name.
    assert was_created is True
    assert attached_fn.function_name == "record_counter"
    assert attached_fn.name == fn_name

    # A lookup by attachment name must yield an equal object.
    looked_up = source.get_attached_function(fn_name)
    assert looked_up == attached_fn

    # Clean up
    source.detach_function(attached_fn.name, delete_output_collection=True)
def test_function_multiple_collections(basic_http_client: System) -> None:
"""Test attaching functions on multiple collections"""
client = ClientCreator.from_system(basic_http_client)
@@ -95,93 +128,163 @@ def test_function_multiple_collections(basic_http_client: System) -> None:
collection1 = client.create_collection(name="collection_1")
collection1.add(ids=["id1", "id2"], documents=["doc1", "doc2"])
attached_fn1 = collection1.attach_function(
attached_fn1, created1 = collection1.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_1",
function_id="record_counter",
output_collection="output_1",
params=None,
)
assert attached_fn1 is not None
assert created1 is True
# Create second collection and task
collection2 = client.create_collection(name="collection_2")
collection2.add(ids=["id3", "id4"], documents=["doc3", "doc4"])
attached_fn2 = collection2.attach_function(
attached_fn2, created2 = collection2.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_2",
function_id="record_counter",
output_collection="output_2",
params=None,
)
assert attached_fn2 is not None
assert created2 is True
# Task IDs should be different
assert attached_fn1.id != attached_fn2.id
# Clean up
assert attached_fn1.detach(delete_output_collection=True) is True
assert attached_fn2.detach(delete_output_collection=True) is True
assert (
collection1.detach_function(attached_fn1.name, delete_output_collection=True)
is True
)
assert (
collection2.detach_function(attached_fn2.name, delete_output_collection=True)
is True
)
def test_functions_multiple_attached_functions(basic_http_client: System) -> None:
"""Test attaching multiple functions on the same collection"""
def test_functions_one_attached_function_per_collection(
basic_http_client: System,
) -> None:
"""Test that only one attached function is allowed per collection"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
# Create a single collection
collection = client.create_collection(name="multi_task_collection")
collection = client.create_collection(name="single_task_collection")
collection.add(ids=["id1", "id2", "id3"], documents=["doc1", "doc2", "doc3"])
# Create first task on the collection
attached_fn1 = collection.attach_function(
attached_fn1, created = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_1",
function_id="record_counter",
output_collection="output_1",
params=None,
)
assert attached_fn1 is not None
assert created is True
# Create second task on the SAME collection with a different name
attached_fn2 = collection.attach_function(
# Attempt to create a second task with a different name should fail
# (only one attached function allowed per collection)
with pytest.raises(
ChromaError,
match="collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
):
collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_2",
output_collection="output_2",
params=None,
)
# Attempt to create a task with the same name but different function_id should also fail
with pytest.raises(
ChromaError,
match=r"collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
):
collection.attach_function(
function=STATISTICS_FUNCTION,
name="task_1",
output_collection="output_different", # Different output collection
params=None,
)
# Detach the first function
assert (
collection.detach_function(attached_fn1.name, delete_output_collection=True)
is True
)
# Now we should be able to attach a new function
attached_fn2, created2 = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="task_2",
function_id="record_counter",
output_collection="output_2",
params=None,
)
assert attached_fn2 is not None
assert created2 is True
assert attached_fn2.id != attached_fn1.id
# Task IDs should be different even though they're on the same collection
assert attached_fn1.id != attached_fn2.id
# Create third task on the same collection
attached_fn3 = collection.attach_function(
name="task_3",
function_id="record_counter",
output_collection="output_3",
params=None,
# Clean up
assert (
collection.detach_function(attached_fn2.name, delete_output_collection=True)
is True
)
assert attached_fn3 is not None
assert attached_fn3.id != attached_fn1.id
assert attached_fn3.id != attached_fn2.id
# Attempt to create a task with duplicate name on same collection should fail
with pytest.raises(ChromaError, match="already exists"):
def test_attach_function_with_invalid_params(basic_http_client: System) -> None:
"""Test that attach_function with non-empty params raises an error"""
client = ClientCreator.from_system(basic_http_client)
client.reset()
collection = client.create_collection(name="test_invalid_params")
collection.add(ids=["id1"], documents=["test document"])
# Attempt to create task with non-empty params should fail
# (no functions currently accept parameters)
with pytest.raises(
ChromaError,
match="params must be empty - no functions currently accept parameters",
):
collection.attach_function(
name="task_1", # Duplicate name
function_id="record_counter",
output_collection="output_duplicate",
params=None,
name="invalid_params_task",
function=RECORD_COUNTER_FUNCTION,
output_collection="output_collection",
params={"some_key": "some_value"},
)
# Clean up - remove each task individually
assert attached_fn1.detach(delete_output_collection=True) is True
assert attached_fn2.detach(delete_output_collection=True) is True
assert attached_fn3.detach(delete_output_collection=True) is True
def test_attach_function_output_collection_already_exists(
    basic_http_client: System,
) -> None:
    """Check that attaching fails when the requested output name is taken.

    A second, unrelated collection is created under the name requested for the
    output; ``attach_function`` is then expected to raise ``ChromaError``.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    taken_name = "existing_output_collection"

    # Input side of the would-be attachment.
    source = client.create_collection(name="input_collection")
    source.add(ids=["id1"], documents=["test document"])

    # Occupy the name that will be requested for the output collection.
    client.create_collection(name=taken_name)

    expected_error = r"Output collection \[existing_output_collection\] already exists"
    with pytest.raises(ChromaError, match=expected_error):
        source.attach_function(
            function=RECORD_COUNTER_FUNCTION,
            name="my_task",
            output_collection=taken_name,
            params=None,
        )
def test_function_remove_nonexistent(basic_http_client: System) -> None:
@@ -191,15 +294,294 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None:
collection = client.create_collection(name="test_collection")
collection.add(ids=["id1"], documents=["test"])
attached_fn = collection.attach_function(
attached_fn, _ = collection.attach_function(
function=RECORD_COUNTER_FUNCTION,
name="test_function",
function_id="record_counter",
output_collection="output_collection",
params=None,
)
attached_fn.detach(delete_output_collection=True)
collection.detach_function(attached_fn.name, delete_output_collection=True)
# Trying to detach this function again should raise NotFoundError
with pytest.raises(NotFoundError, match="does not exist"):
attached_fn.detach(delete_output_collection=True)
collection.detach_function(attached_fn.name, delete_output_collection=True)
def test_attach_to_output_collection_fails(basic_http_client: System) -> None:
    """Check that a function's output collection cannot itself take a function."""
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    # Attach a counter to an input collection; the returned pair is not needed.
    source = client.create_collection(name="input_collection")
    source.add(ids=["id1"], documents=["test"])
    source.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name="test_function",
        output_collection="output_collection",
        params=None,
    )

    # Fetch the collection named as the output and try to attach to it.
    derived = client.get_collection(name="output_collection")
    expected_error = "cannot attach function to an output collection"
    with pytest.raises(ChromaError, match=expected_error):
        derived.attach_function(
            function=RECORD_COUNTER_FUNCTION,
            name="test_function_2",
            output_collection="output_collection_2",
            params=None,
        )
def test_delete_output_collection_detaches_function(basic_http_client: System) -> None:
    """Check that deleting the output collection removes the attached function.

    After the output collection is deleted through the client, looking the
    function up on the input collection is expected to raise ``NotFoundError``.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    fn_name = "my_function"
    output_name = "output_collection"

    source = client.create_collection(name="input_collection")
    source.add(ids=["id1"], documents=["test"])

    handle, was_created = source.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name=fn_name,
        output_collection=output_name,
        params=None,
    )
    assert handle is not None
    assert was_created is True

    # Remove the output collection directly rather than detaching first.
    client.delete_collection(output_name)

    # The attachment must no longer be discoverable by name.
    with pytest.raises(NotFoundError):
        source.get_attached_function(fn_name)
def test_delete_orphaned_output_collection(basic_http_client: System) -> None:
    """Check that an output collection left behind by a detach can be deleted.

    The function is detached with ``delete_output_collection=False`` so its
    output collection outlives it; that collection is then deleted directly
    through the client.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    fn_name = "my_function"
    output_name = "output_collection"

    source = client.create_collection(name="input_collection")
    source.add(ids=["id1"], documents=["test"])

    handle, was_created = source.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name=fn_name,
        output_collection=output_name,
        params=None,
    )
    assert handle is not None
    assert was_created is True

    # Detach while keeping the output collection, which orphans it.
    source.detach_function(handle.name, delete_output_collection=False)

    # Delete the orphaned output collection directly; this must not raise.
    client.delete_collection(output_name)

    # Looking the detached function up by name must fail.
    with pytest.raises(NotFoundError):
        source.get_attached_function(fn_name)

    # The output collection itself must be gone as well.
    with pytest.raises(NotFoundError):
        client.get_collection(output_name)
def test_partial_attach_function_repair(
    basic_http_client: System,
) -> None:
    """Check that an output collection name is reusable after its owner detaches.

    Steps:
      1. Attach a record counter to ``my_document`` writing to
         ``my_documents_counts``.
      2. Attaching from a second collection with the same output name must
         raise ``ChromaError`` while the first attachment exists.
      3. Detach the first attachment, deleting its output collection.
      4. The same attach call from the second collection must now succeed and
         report ``created is True``.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    # First input collection and its attachment.
    collection = client.get_or_create_collection(
        name="my_document",
    )
    attached_fn, created = collection.attach_function(
        name="count_my_docs",
        function=RECORD_COUNTER_FUNCTION,
        output_collection="my_documents_counts",
        params=None,
    )
    assert created is True
    assert attached_fn is not None

    collection2 = client.get_or_create_collection(
        name="my_document2",
    )

    # The output name is still owned by the first attachment, so this must
    # fail. The result is deliberately not bound: `attached_fn` has to keep
    # referring to the first attachment for the detach below.
    with pytest.raises(
        ChromaError, match=r"Output collection \[my_documents_counts\] already exists"
    ):
        collection2.attach_function(
            name="count_my_docs",
            function=RECORD_COUNTER_FUNCTION,
            output_collection="my_documents_counts",
            params=None,
        )

    # Detach the first attachment and drop its output collection.
    assert (
        collection.detach_function(attached_fn.name, delete_output_collection=True)
        is True
    )

    # With the name released, the second collection can attach.
    attached_fn, created = collection2.attach_function(
        name="count_my_docs",
        function=RECORD_COUNTER_FUNCTION,
        output_collection="my_documents_counts",
        params=None,
    )
    assert attached_fn is not None
    assert created is True
def test_output_collection_created_with_schema(basic_http_client: System) -> None:
    """Check that the output collection's schema names its source function.

    After attaching, the output collection is fetched and the value returned by
    ``_model.pretty_schema()`` is checked for ``source_attached_function_id``.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    source = client.create_collection(name="input_collection")
    source.add(ids=["id1"], documents=["test"])

    handle, was_created = source.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name="my_function",
        output_collection="output_collection",
        params=None,
    )
    assert handle is not None
    assert was_created is True

    # Attaching is expected to have created the output collection.
    derived = client.get_collection(name="output_collection")
    assert derived is not None

    # Inspect the schema through the collection's private model: the marker
    # key must appear in what pretty_schema() returns.
    rendered_schema = derived._model.pretty_schema()
    assert "source_attached_function_id" in rendered_schema

    # Clean up
    source.detach_function(handle.name, delete_output_collection=True)
def test_count_function_attach_and_detach_attach_attach(
    basic_http_client: System,
) -> None:
    """Check attach -> count -> detach -> attach -> repeated attach.

    The record counter is attached, 300 documents are added and the output
    collection's ``total_count`` is checked. The function is then detached,
    attached again (``created is True``) and attached once more with the same
    arguments (``created is False``).
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    doc_count = 300

    collection = client.get_or_create_collection(
        name="my_document",
        metadata={"description": "Sample documents for task processing"},
    )

    # First attachment of the counter.
    attached_fn, created = collection.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name="count_my_docs",
        output_collection="my_documents_counts",
        params=None,
    )
    assert created is True
    assert attached_fn is not None

    # Remember the version so the wait below can detect an increase.
    version_before_add = get_collection_version(client, collection.name)

    collection.add(
        ids=["doc_{}".format(i) for i in range(0, doc_count)],
        documents=["test document"] * doc_count,
    )
    assert collection.count() == doc_count

    wait_for_version_increase(client, collection.name, version_before_add)
    # Give some time to invalidate the frontend query cache
    sleep(60)

    counts = client.get_collection("my_documents_counts").get("function_output")
    assert counts["metadatas"] is not None
    assert counts["metadatas"][0]["total_count"] == doc_count

    # Detach and drop the output collection.
    detached = collection.detach_function(
        attached_fn.name, delete_output_collection=True
    )
    assert detached is True

    # Attaching again after the detach creates a new attachment.
    attached_fn, created = collection.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name="count_my_docs",
        output_collection="my_documents_counts",
        params=None,
    )
    assert attached_fn is not None
    assert created is True

    # Repeating the identical attach must report that nothing new was created.
    attached_fn, created = collection.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name="count_my_docs",
        output_collection="my_documents_counts",
        params=None,
    )
    assert created is False
    assert attached_fn is not None
def test_attach_function_idempotency(basic_http_client: System) -> None:
    """Check that repeating an identical attach returns ``created=False``.

    Both calls must return an object, the second must report that nothing was
    created, and both objects must carry the same ``id``.
    """
    client = ClientCreator.from_system(basic_http_client)
    client.reset()

    fn_name = "my_function"
    output_name = "output_collection"

    source = client.create_collection(name="idempotency_test")
    source.add(ids=["id1"], documents=["test document"])

    # Initial attach: expected to create the attachment.
    first_fn, first_created = source.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name=fn_name,
        output_collection=output_name,
        params=None,
    )
    assert first_fn is not None
    assert first_created is True

    # Same arguments again: expected to be a no-op returning the existing one.
    second_fn, second_created = source.attach_function(
        function=RECORD_COUNTER_FUNCTION,
        name=fn_name,
        output_collection=output_name,
        params=None,
    )
    assert second_fn is not None
    assert second_created is False

    # Identity of the attachment is unchanged across the two calls.
    assert first_fn.id == second_fn.id

    # Clean up
    source.detach_function(first_fn.name, delete_output_collection=True)

View File

@@ -1,7 +1,9 @@
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from chromadb import SparseVector
from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
DEFAULT_CHROMA_BM25_STOPWORDS,
ChromaBm25EmbeddingFunction,
@@ -137,3 +139,64 @@ def test_validate_config_update_allows_known_keys() -> None:
embedder.validate_config_update(
embedder.get_config(), {"k": 1.1, "stopwords": ["custom"]}
)
def test_multithreaded_usage() -> None:
    """Embed many texts concurrently through one shared embedder.

    A single ``ChromaBm25EmbeddingFunction`` is called from a 10-worker thread
    pool, one text per task. Any exception raised by a task fails the test;
    afterwards every embedding is checked for non-empty, sorted indices and
    positive, finite values.
    """
    embedder = ChromaBm25EmbeddingFunction()

    # NOTE(review): the continuation lines of these literals are reproduced
    # exactly as seen in review; confirm their leading whitespace against the
    # checked-in file.
    base_texts = [
        """The gravitational wave background from massive black hole binaries emit bursts of
gravitational waves at periapse. Such events may be directly resolvable in the Galactic
centre. However, if the star does not spiral in, the emitted GWs are not resolvable for
extra-galactic MBHs, but constitute a source of background noise. We estimate the power
spectrum of this extreme mass ratio burst background.""",
        """Dynamics of planets in exoplanetary systems with multiple stars showing how the
gravitational interactions between the stars and planets affect the orbital stability
and long-term evolution of the planetary system architectures.""",
        """Diurnal Thermal Tides in a Non-rotating atmosphere with realistic heating profiles
and temperature gradients that demonstrate the complex interplay between radiation
and atmospheric dynamics in planetary atmospheres.""",
        """Intermittent turbulence, noise and waves in stellar atmospheres create complex
patterns of energy transport and momentum deposition that influence the structure
and evolution of stellar interiors and surfaces.""",
        """Superconductivity in quantum materials and condensed matter physics systems
exhibiting novel quantum phenomena including topological phases, strongly correlated
electron systems, and exotic superconducting pairing mechanisms.""",
        """Machine learning models require careful tuning of hyperparameters including learning
rates, regularization coefficients, and architectural choices that demonstrate the
complex interplay between optimization algorithms and model capacity.""",
        """Natural language processing enables text understanding through sophisticated
algorithms that analyze semantic relationships, syntactic structures, and contextual
information to extract meaningful representations from unstructured textual data.""",
        """Vector databases store high-dimensional embeddings efficiently using advanced
indexing techniques including approximate nearest neighbor search algorithms that
balance accuracy and computational efficiency for large-scale similarity search.""",
    ]
    texts = base_texts * 30
    num_threads = 10

    def embed_one(text: str) -> SparseVector:
        # One-element batch in, its single embedding out.
        return embedder([text])[0]

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = [pool.submit(embed_one, text) for text in texts]
        collected = []
        for finished in as_completed(pending):
            try:
                collected.append(finished.result())
            except Exception as e:
                # Surface a worker failure as a test failure with context.
                pytest.fail(
                    f"Threading error detected: {type(e).__name__}: {e}. "
                    "This indicates the stemmer is not thread-safe when cached."
                )

    # Every submitted text must have produced exactly one embedding.
    assert len(collected) == len(texts)

    for sparse in collected:
        assert sparse.indices
        assert len(sparse.indices) == len(sparse.values)
        assert _is_sorted(sparse.indices)
        for weight in sparse.values:
            assert weight > 0
            assert math.isfinite(weight)

View File

@@ -33,12 +33,14 @@ def test_get_builtins_holds() -> None:
"GoogleGenerativeAiEmbeddingFunction",
"GooglePalmEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"GoogleGenaiEmbeddingFunction",
"HuggingFaceEmbeddingFunction",
"HuggingFaceEmbeddingServer",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"NomicEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OllamaEmbeddingFunction",
"OpenAIEmbeddingFunction",

View File

@@ -1,7 +1,7 @@
import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing import Any, Optional, List, Dict, Union, cast, Tuple
from typing_extensions import TypedDict
import uuid
import numpy as np
@@ -9,6 +9,10 @@ import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from chromadb.test.api.test_schema_e2e import (
SimpleEmbeddingFunction,
DeterministicSparseEmbeddingFunction,
)
from chromadb.test.conftest import NOT_CLUSTER_ONLY
from dataclasses import dataclass
from chromadb.api.types import (
@@ -17,12 +21,27 @@ from chromadb.api.types import (
EmbeddingFunction,
Embeddings,
Metadata,
Schema,
CollectionMetadata,
VectorIndexConfig,
SparseVectorIndexConfig,
StringInvertedIndexConfig,
IntInvertedIndexConfig,
FloatInvertedIndexConfig,
BoolInvertedIndexConfig,
HnswIndexConfig,
SpannIndexConfig,
Space,
)
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
from chromadb.test.conftest import is_spann_disabled_mode
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
CreateSpannConfiguration,
CreateHNSWConfiguration,
)
from chromadb.utils.embedding_functions import (
register_embedding_function,
)
# Set the random seed for reproducibility
@@ -266,6 +285,365 @@ class ExternalCollection:
embedding_function: Optional[types.EmbeddingFunction[Embeddable]]
@register_embedding_function
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
    """``SimpleEmbeddingFunction`` variant for persistence tests; its default space is "ip"."""

    def default_space(self) -> str:  # type: ignore[override]
        """Return the name of the default space, ``"ip"``."""
        return "ip"
@st.composite
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
    """Build a ``VectorIndexConfig`` in which every field is independently optional.

    ``space``, ``embedding_function`` and ``source_key`` are each set or left as
    ``None`` on a coin flip. At most one of ``hnsw`` / ``spann`` is populated,
    and each of its parameters is likewise set or left as ``None``.
    """

    def maybe(strategy: SearchStrategy[Any]) -> Any:
        # Draw a boolean first; only on True draw from `strategy`.
        # Keeping this order keeps the sequence of draws stable.
        if draw(st.booleans()):
            return draw(strategy)
        return None

    space = maybe(st.sampled_from(["cosine", "l2", "ip"]))

    # The embedding function is only constructed when a dimension was drawn.
    dim = maybe(st.integers(min_value=1, max_value=1000))
    embedding_function = None
    if dim is not None:
        embedding_function = SimpleIpEmbeddingFunction(dim=dim)

    source_key = maybe(st.one_of(st.just("#document"), safe_text))

    hnsw = None
    spann = None
    index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))
    if index_choice == "hnsw":
        hnsw = HnswIndexConfig(
            ef_construction=maybe(st.integers(min_value=1, max_value=1000)),
            max_neighbors=maybe(st.integers(min_value=1, max_value=1000)),
            ef_search=maybe(st.integers(min_value=1, max_value=1000)),
            sync_threshold=maybe(st.integers(min_value=2, max_value=10000)),
            resize_factor=maybe(st.floats(min_value=1.0, max_value=5.0)),
        )
    elif index_choice == "spann":
        spann = SpannIndexConfig(
            search_nprobe=maybe(st.integers(min_value=1, max_value=128)),
            write_nprobe=maybe(st.integers(min_value=1, max_value=64)),
            ef_construction=maybe(st.integers(min_value=1, max_value=200)),
            ef_search=maybe(st.integers(min_value=1, max_value=200)),
            max_neighbors=maybe(st.integers(min_value=1, max_value=64)),
            reassign_neighbor_count=maybe(st.integers(min_value=1, max_value=64)),
            split_threshold=maybe(st.integers(min_value=50, max_value=200)),
            merge_threshold=maybe(st.integers(min_value=25, max_value=100)),
        )

    return VectorIndexConfig(
        space=cast(Space, space),
        embedding_function=embedding_function,
        source_key=source_key,
        hnsw=hnsw,
        spann=spann,
    )
@st.composite
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
    """Build a ``SparseVectorIndexConfig`` with optional fields.

    ``embedding_function`` and ``source_key`` are set together or both left as
    ``None``; ``bm25`` is independently either a drawn boolean or ``None``.
    """
    # NOTE(review): source_key is drawn only when an embedding function is
    # chosen; this nesting was reconstructed from the code shape -- confirm.
    with_embedding = draw(st.booleans())
    embedding_function = (
        DeterministicSparseEmbeddingFunction() if with_embedding else None
    )
    source_key = (
        draw(st.one_of(st.just("#document"), safe_text)) if with_embedding else None
    )

    # The condition is drawn first, then (only if True) the bm25 value.
    bm25 = draw(st.booleans()) if draw(st.booleans()) else None

    return SparseVectorIndexConfig(
        embedding_function=embedding_function,
        source_key=source_key,
        bm25=bm25,
    )
@st.composite
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
    """Build ``None`` or a ``Schema`` mutated by up to five index operations.

    Each operation is a ``create_index`` or ``delete_index`` on one of six
    config kinds. Inverted-index kinds support both operations; vector and
    sparse-vector kinds are only ever created, and at most one sparse-vector
    index is created per schema.
    """
    if draw(st.booleans()):
        return None

    # Config classes for the kinds that are handled identically.
    inverted_configs = {
        "string_inverted": StringInvertedIndexConfig,
        "int_inverted": IntInvertedIndexConfig,
        "float_inverted": FloatInvertedIndexConfig,
        "bool_inverted": BoolInvertedIndexConfig,
    }
    config_kinds = [*inverted_configs, "vector", "sparse_vector"]

    schema = Schema()
    op_count = draw(st.integers(min_value=0, max_value=5))
    has_sparse_index = False

    for _ in range(op_count):
        operation = draw(st.sampled_from(["create_index", "delete_index"]))
        config_type = draw(st.sampled_from(config_kinds))

        # The key flag is always drawn; a key is only drawn for non-vector
        # kinds because vector indexes are set globally.
        wants_key = draw(st.booleans())
        key = draw(safe_text) if wants_key and config_type != "vector" else None

        inverted_cls = inverted_configs.get(config_type)
        if inverted_cls is not None:
            if operation == "create_index":
                schema.create_index(config=inverted_cls(), key=key)
            elif operation == "delete_index":
                schema.delete_index(config=inverted_cls(), key=key)
            continue

        # Vector, FTS, and sparse_vector deletion is not currently supported
        if operation != "create_index":
            continue

        if config_type == "vector":
            vector_config = draw(vector_index_config_strategy())
            schema.create_index(config=vector_config, key=None)
        elif (
            config_type == "sparse_vector"
            and not is_spann_disabled_mode
            and not has_sparse_index
        ):
            sparse_config = draw(sparse_vector_index_config_strategy())
            # A sparse vector index must be attached to a key.
            if key is None:
                key = draw(safe_text)
            schema.create_index(config=sparse_config, key=key)
            has_sparse_index = True

    return schema
@st.composite
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
    """Draw collection metadata holding a random subset of ``hnsw:*`` parameters.

    Every parameter is included independently: a boolean is drawn first and,
    only when it is true, a value for that parameter is drawn.

    Returns:
        The generated metadata, or ``None`` when no parameter was selected.
    """
    # (metadata key, value strategy) pairs. The tuple order fixes the order
    # in which draws happen, so it must not be rearranged.
    optional_parameters = (
        ("hnsw:space", st.sampled_from(["cosine", "l2", "ip"])),
        ("hnsw:construction_ef", st.integers(min_value=1, max_value=1000)),
        ("hnsw:search_ef", st.integers(min_value=1, max_value=1000)),
        ("hnsw:M", st.integers(min_value=1, max_value=1000)),
        ("hnsw:resize_factor", st.floats(min_value=1.0, max_value=5.0)),
        ("hnsw:sync_threshold", st.integers(min_value=2, max_value=10000)),
    )

    metadata: CollectionMetadata = {}
    for parameter_name, value_strategy in optional_parameters:
        if draw(st.booleans()):
            metadata[parameter_name] = draw(value_strategy)

    # An empty dict is reported as "no metadata at all".
    return metadata or None
@st.composite
def create_configuration_strategy(
    draw: st.DrawFn,
) -> Optional[CreateCollectionConfiguration]:
    """Draw a ``CreateCollectionConfiguration`` with at most one index section.

    An embedding function is added independently of the index choice. The
    index part is one of: an HNSW section holding only ``space``, a SPANN
    section holding only ``space``, a fully populated HNSW section, a fully
    populated SPANN section, or nothing. HNSW and SPANN are never both set.

    Returns:
        The generated configuration, or ``None`` when nothing was set.
    """
    space_strategy = st.sampled_from(["cosine", "l2", "ip"])
    configuration: CreateCollectionConfiguration = {}

    # Embedding function is optional and unrelated to the index flavor.
    if draw(st.booleans()):
        embedding_dim = draw(st.integers(min_value=1, max_value=1000))
        configuration["embedding_function"] = SimpleIpEmbeddingFunction(
            dim=embedding_dim
        )

    index_flavor = draw(
        st.sampled_from(
            ["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
        )
    )

    if index_flavor == "space_only_hnsw":
        configuration["hnsw"] = CreateHNSWConfiguration(space=draw(space_strategy))
    elif index_flavor == "space_only_spann":
        configuration["spann"] = CreateSpannConfiguration(space=draw(space_strategy))
    elif index_flavor == "hnsw":
        # Full HNSW section; "space" itself stays optional.
        hnsw_config: CreateHNSWConfiguration = {}
        if draw(st.booleans()):
            hnsw_config["space"] = draw(space_strategy)
        hnsw_config.update(
            {
                "ef_construction": draw(st.integers(min_value=1, max_value=1000)),
                "ef_search": draw(st.integers(min_value=1, max_value=1000)),
                "max_neighbors": draw(st.integers(min_value=1, max_value=1000)),
                "sync_threshold": draw(st.integers(min_value=2, max_value=10000)),
                "resize_factor": draw(st.floats(min_value=1.0, max_value=5.0)),
            }
        )
        configuration["hnsw"] = hnsw_config
    elif index_flavor == "spann":
        # Full SPANN section; "space" itself stays optional.
        spann_config: CreateSpannConfiguration = {}
        if draw(st.booleans()):
            spann_config["space"] = draw(space_strategy)
        spann_config.update(
            {
                "search_nprobe": draw(st.integers(min_value=1, max_value=128)),
                "write_nprobe": draw(st.integers(min_value=1, max_value=64)),
                "ef_construction": draw(st.integers(min_value=1, max_value=200)),
                "ef_search": draw(st.integers(min_value=1, max_value=200)),
                "max_neighbors": draw(st.integers(min_value=1, max_value=64)),
                "reassign_neighbor_count": draw(
                    st.integers(min_value=1, max_value=64)
                ),
                "split_threshold": draw(st.integers(min_value=50, max_value=200)),
                "merge_threshold": draw(st.integers(min_value=25, max_value=100)),
            }
        )
        configuration["spann"] = spann_config
    # "none": leave the index part of the configuration unset.

    return configuration or None
@dataclass
class CollectionInputCombination:
"""
Input tuple for collection creation tests.

Bundles one generated combination of the optional ``metadata``,
``configuration`` and ``schema`` inputs, together with bookkeeping about
which combination was generated.
"""
# Collection metadata, or None when this combination carries none.
metadata: Optional[CollectionMetadata]
# Create-collection configuration, or None when this combination carries none.
configuration: Optional[CreateCollectionConfiguration]
# Schema, or None when this combination carries none.
schema: Optional[Schema]
# Dict summary of the schema's vector index config (as produced by
# vector_index_to_dict), or None when no schema was generated.
schema_vector_info: Optional[Dict[str, Any]]
# Label naming the generated combination, e.g. "none", "metadata", "schema".
kind: str
def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with the entries of ``items`` whose value is not ``None``.

    Falsy values other than ``None`` (``0``, ``""``, ``[]`` ...) are kept, and
    the input mapping is left untouched.
    """
    kept: Dict[str, Any] = {}
    for name, value in items.items():
        if value is not None:
            kept[name] = value
    return kept
def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
    """Summarize a vector index config as a plain dict.

    Returns:
        A dict with four keys:
        ``space`` (taken from the config as-is),
        ``hnsw`` / ``spann`` (``model_dump(exclude_none=True)`` of the
        respective section, or ``None`` when the section is falsy), and
        ``embedding_function_default_space`` (result of the embedding
        function's ``default_space()``, or ``None`` when there is no embedding
        function or it does not expose that method).
    """
    embedder = config.embedding_function
    embedding_default_space: Optional[str] = None
    if embedder is not None and hasattr(embedder, "default_space"):
        embedding_default_space = cast(str, embedder.default_space())

    hnsw_settings = None
    if config.hnsw:
        hnsw_settings = config.hnsw.model_dump(exclude_none=True)

    spann_settings = None
    if config.spann:
        spann_settings = config.spann.model_dump(exclude_none=True)

    summary: Dict[str, Any] = {"space": config.space}
    summary["hnsw"] = hnsw_settings
    summary["spann"] = spann_settings
    summary["embedding_function_default_space"] = embedding_default_space
    return summary
@st.composite
def _schema_input_strategy(
    draw: st.DrawFn,
) -> Tuple[Schema, Dict[str, Any]]:
    """Draw a schema with one vector index, plus a dict summary of that index.

    Returns:
        ``(schema, summary)`` where ``schema`` had the drawn vector config
        registered with ``key=None`` and ``summary`` is
        ``vector_index_to_dict`` of the same config.
    """
    drawn_vector_config = draw(vector_index_config_strategy())

    generated_schema = Schema()
    generated_schema.create_index(config=drawn_vector_config, key=None)

    vector_summary = vector_index_to_dict(drawn_vector_config)
    return generated_schema, vector_summary
@st.composite
def metadata_configuration_schema_strategy(
    draw: st.DrawFn,
) -> CollectionInputCombination:
    """Draw a compatible combination of metadata, configuration and schema.

    One of five shapes is picked: nothing, metadata only, configuration only,
    metadata plus configuration, or schema only. A schema is never combined
    with metadata or configuration. The picked shape is recorded in the
    result's ``kind`` field.
    """

    def _is_populated_metadata(value: Optional[CollectionMetadata]) -> bool:
        # Reject both "no metadata" and an empty metadata dict.
        return value is not None and len(value) > 0

    def _has_index_section(
        value: Optional[CreateCollectionConfiguration],
    ) -> bool:
        # Require a non-empty "hnsw" or "spann" section.
        if value is None:
            return False
        hnsw_section = value.get("hnsw")
        if hnsw_section is not None and len(hnsw_section) > 0:
            return True
        spann_section = value.get("spann")
        return spann_section is not None and len(spann_section) > 0

    choice = draw(
        st.sampled_from(
            [
                "none",
                "metadata",
                "configuration",
                "metadata_configuration",
                "schema",
            ]
        )
    )
    wants_metadata = choice in ("metadata", "metadata_configuration")
    wants_configuration = choice in ("configuration", "metadata_configuration")

    metadata: Optional[CollectionMetadata] = None
    configuration: Optional[CreateCollectionConfiguration] = None
    schema: Optional[Schema] = None
    schema_info: Optional[Dict[str, Any]] = None

    # Draw order (metadata, configuration, schema) is kept fixed.
    if wants_metadata:
        metadata = draw(metadata_with_hnsw_strategy().filter(_is_populated_metadata))
    if wants_configuration:
        configuration = draw(
            create_configuration_strategy().filter(_has_index_section)
        )
    if choice == "schema":
        schema, schema_info = draw(_schema_input_strategy())

    return CollectionInputCombination(
        metadata=metadata,
        configuration=configuration,
        schema=schema,
        schema_vector_info=schema_info,
        kind=choice,
    )
@dataclass
class Collection(ExternalCollection):
"""
@@ -344,7 +722,7 @@ def collections(
spann_config: CreateSpannConfiguration = {
"space": spann_space,
"write_nprobe": 4,
"reassign_neighbor_count": 4
"reassign_neighbor_count": 4,
}
collection_config = {
"spann": spann_config,
@@ -395,7 +773,7 @@ def collections(
known_document_keywords=known_document_keywords,
has_embeddings=has_embeddings,
embedding_function=embedding_function,
collection_config=collection_config
collection_config=collection_config,
)
@@ -421,7 +799,9 @@ def metadata(
del metadata[key] # type: ignore
# Finally, add in some of the known keys for the collection
sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
k: st.just(v) for k, v in collection.known_metadata_keys.items()
k: st.just(v)
for k, v in collection.known_metadata_keys.items()
if isinstance(v, (str, int, float))
}
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore
# We don't allow submitting empty metadata

View File

@@ -31,7 +31,7 @@ collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="co
normal=hypothesis.settings(max_examples=500),
fast=hypothesis.settings(max_examples=200),
),
max_examples=2
max_examples=2,
)
def test_add_miniscule(
client: ClientAPI,
@@ -332,7 +332,8 @@ def test_out_of_order_ids(client: ClientAPI) -> None:
]
coll = client.create_collection(
"test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore
"test",
embedding_function=lambda input: [[1, 2, 3] for _ in input], # type: ignore
)
embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
coll.add(ids=ooo_ids, embeddings=embeddings)
@@ -369,3 +370,155 @@ def test_add_partial(client: ClientAPI) -> None:
assert results["ids"] == ["1", "2", "3"]
assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
assert results["documents"] == ["a", "b", None]
@pytest.mark.skipif(
    NOT_CLUSTER_ONLY,
    reason="GroupBy is only supported in distributed mode",
)
def test_search_group_by(client: ClientAPI) -> None:
    """Exercise ``Search.group_by`` with one key, several keys, and several ranking keys."""
    from chromadb.execution.expression.operator import GroupBy, MinK, Key
    from chromadb.execution.expression.plan import Search
    from chromadb.execution.expression import Knn

    create_isolated_database(client)
    coll = client.create_collection(name="test_group_by")

    # 12 records: 3 categories x 2 years x 2 records per (category, year).
    ids = [
        "sci_2023_1",
        "sci_2023_2",
        "sci_2024_1",
        "sci_2024_2",
        "tech_2023_1",
        "tech_2023_2",
        "tech_2024_1",
        "tech_2024_2",
        "arts_2023_1",
        "arts_2023_2",
        "arts_2024_1",
        "arts_2024_2",
    ]
    # Rows line up with `ids`. Science rows point along the query direction
    # and drift away from it row by row; tech and arts rows have no component
    # along the query, and later rows spread their weight over two axes.
    embeddings = cast(
        Embeddings,
        [
            [1.0, 0.0, 0.0, 0.0],  # sci_2023_1 (identical to the query)
            [0.9, 0.1, 0.0, 0.0],  # sci_2023_2
            [0.8, 0.2, 0.0, 0.0],  # sci_2024_1
            [0.7, 0.3, 0.0, 0.0],  # sci_2024_2
            [0.0, 1.0, 0.0, 0.0],  # tech_2023_1
            [0.0, 0.9, 0.1, 0.0],  # tech_2023_2
            [0.0, 0.8, 0.2, 0.0],  # tech_2024_1
            [0.0, 0.7, 0.3, 0.0],  # tech_2024_2
            [0.0, 0.0, 1.0, 0.0],  # arts_2023_1
            [0.0, 0.0, 0.9, 0.1],  # arts_2023_2
            [0.0, 0.0, 0.8, 0.2],  # arts_2024_1
            [0.0, 0.0, 0.7, 0.3],  # arts_2024_2
        ],
    )
    metadatas: Metadatas = [
        {"category": "science", "year": 2023, "priority": 1},
        {"category": "science", "year": 2023, "priority": 2},
        {"category": "science", "year": 2024, "priority": 1},
        {"category": "science", "year": 2024, "priority": 3},
        {"category": "tech", "year": 2023, "priority": 2},
        {"category": "tech", "year": 2023, "priority": 1},
        {"category": "tech", "year": 2024, "priority": 1},
        {"category": "tech", "year": 2024, "priority": 2},
        {"category": "arts", "year": 2023, "priority": 3},
        {"category": "arts", "year": 2023, "priority": 1},
        {"category": "arts", "year": 2024, "priority": 2},
        {"category": "arts", "year": 2024, "priority": 1},
    ]
    documents = [f"doc_{record_id}" for record_id in ids]

    coll.add(
        ids=ids,
        embeddings=embeddings,
        metadatas=metadatas,
        documents=documents,
    )

    query = [1.0, 0.0, 0.0, 0.0]

    # Each case pairs a GroupBy with the exact set of ids it must return.
    # Every case is expected to yield 6 ids in total.
    grouping_cases = [
        # Case 1: one grouping key - the 2 best-scoring records per category.
        (
            GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2)),
            {
                "sci_2023_1",
                "sci_2023_2",
                "tech_2024_2",
                "tech_2024_1",
                "arts_2024_2",
                "arts_2024_1",
            },
        ),
        # Case 2: two grouping keys - the single best record for each of the
        # six (category, year) pairs.
        (
            GroupBy(
                keys=[Key("category"), Key("year")],
                aggregate=MinK(keys=Key.SCORE, k=1),
            ),
            {
                "sci_2023_1",
                "sci_2024_1",
                "tech_2023_2",
                "tech_2024_2",
                "arts_2023_2",
                "arts_2024_2",
            },
        ),
        # Case 3: two ranking keys - lowest priority first, score breaks ties;
        # top 2 per category.
        (
            GroupBy(
                keys=Key("category"),
                aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
            ),
            {
                "sci_2023_1",
                "sci_2024_1",
                "tech_2024_1",
                "tech_2023_2",
                "arts_2024_2",
                "arts_2023_2",
            },
        ),
    ]

    for group_by, expected_ids in grouping_cases:
        results = coll.search(
            Search().rank(Knn(query=query, limit=12)).group_by(group_by).limit(12)
        )
        assert results["ids"] is not None
        returned_ids = results["ids"][0]
        assert len(returned_ids) == 6
        assert set(returned_ids) == expected_ids

View File

@@ -1983,7 +1983,9 @@ def test_sparse_vector_in_metadata_validation():
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
invalid_metadata_4 = {
"text": "non-numeric value",
"sparse_embedding": SparseVector(indices=[0, 1], values=[0.1, "not_a_number"]), # type: ignore
"sparse_embedding": SparseVector(
indices=[0, 1], values=[0.1, "not_a_number"]
), # type: ignore
}
# Test 7: Multiple sparse vectors in metadata
@@ -2683,6 +2685,59 @@ def test_rrf_to_dict() -> None:
print("All RRF tests passed!")
def test_group_by_serialization() -> None:
    """Check ``GroupBy``/``MinK``/``MaxK`` dict output, round-tripping and errors."""
    import pytest
    from chromadb.execution.expression.operator import (
        GroupBy,
        MinK,
        MaxK,
        Key,
        Aggregate,
    )

    # A bare Key (not wrapped in a list) is serialized as a one-element list.
    single_key_group = GroupBy(
        keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3)
    )
    expected_single = {
        "keys": ["category"],
        "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
    }
    assert single_key_group.to_dict() == expected_single

    # Several grouping keys combined with a MaxK aggregate.
    multi_key_group = GroupBy(
        keys=[Key("year"), Key("category")],
        aggregate=MaxK(keys=[Key.SCORE, Key("priority")], k=5),
    )
    expected_multi = {
        "keys": ["year", "category"],
        "aggregate": {"$max_k": {"keys": ["#score", "priority"], "k": 5}},
    }
    assert multi_key_group.to_dict() == expected_multi

    # to_dict -> from_dict -> to_dict must be stable.
    original = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
    rebuilt = GroupBy.from_dict(original.to_dict())
    assert rebuilt.to_dict() == original.to_dict()

    # A default GroupBy serializes to {} and {} deserializes back to it.
    assert GroupBy().to_dict() == {}
    assert GroupBy.from_dict({}).to_dict() == {}

    # Malformed GroupBy payloads, each paired with the expected error text.
    rejected_payloads = [
        (
            {"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}},
            "requires 'keys' field",
        ),
        ({"keys": ["category"]}, "requires 'aggregate' field"),
        (
            {"keys": [], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}},
            "keys cannot be empty",
        ),
    ]
    for payload, error_pattern in rejected_payloads:
        with pytest.raises(ValueError, match=error_pattern):
            GroupBy.from_dict(payload)

    # Unknown aggregate operators are rejected by Aggregate itself.
    with pytest.raises(ValueError, match="Unknown aggregate operator"):
        Aggregate.from_dict({"$unknown": {"keys": ["#score"], "k": 3}})
# Expression API Tests - Testing dict support and from_dict methods
class TestSearchDictSupport:
"""Test Search class dict input support."""
@@ -2786,6 +2841,49 @@ class TestSearchDictSupport:
with pytest.raises(TypeError, match="select must be"):
Search(select=123)
def test_search_with_group_by(self):
    """Search takes ``group_by`` as a dict, as a GroupBy, or via the builder."""
    import pytest
    from chromadb.execution.expression.plan import Search
    from chromadb.execution.expression.operator import GroupBy, MinK, Key

    # Dict input is converted into a GroupBy instance.
    group_by_payload = {
        "keys": ["category"],
        "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
    }
    search_from_dict = Search(group_by=group_by_payload)
    assert isinstance(search_from_dict._group_by, GroupBy)

    # A GroupBy passed to the constructor is stored as the same object.
    group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
    search_from_object = Search(group_by=group_by)
    assert search_from_object._group_by is group_by

    # The builder method also yields a search with an aggregate set.
    built_search = Search().group_by(group_by)
    assert built_search._group_by.aggregate is not None

    # Wrong type -> TypeError; incomplete dict -> ValueError.
    with pytest.raises(TypeError, match="group_by must be"):
        Search(group_by="invalid")
    with pytest.raises(ValueError, match="requires 'aggregate' field"):
        Search(group_by={"keys": ["category"]})
def test_search_group_by_serialization(self):
    """``Search.to_dict()`` always has a ``group_by`` entry, empty when unset."""
    from chromadb.execution.expression.plan import Search
    from chromadb.execution.expression.operator import GroupBy, MinK, Key, Knn

    # No group_by configured: the entry is an empty dict.
    ungrouped_search = Search().rank(Knn(query=[0.1, 0.2])).limit(10)
    assert ungrouped_search.to_dict()["group_by"] == {}

    # group_by configured: the entry carries both keys and aggregate.
    grouping = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
    serialized_group_by = Search().group_by(grouping).to_dict()["group_by"]
    assert serialized_group_by["keys"] == ["category"]
    assert serialized_group_by["aggregate"] == {
        "$min_k": {"keys": ["#score"], "k": 3}
    }
class TestWhereFromDict:
"""Test Where.from_dict() conversion."""
@@ -3310,3 +3408,27 @@ class TestRoundTripConversion:
return d1 == d2
assert compare_search_dicts(new_dict, search_dict)
def test_search_round_trip_with_group_by(self):
    """A Search's serialized ``group_by`` rebuilds into an equal ``group_by``."""
    from chromadb.execution.expression.plan import Search
    from chromadb.execution.expression.operator import Key, GroupBy, MinK

    grouping = GroupBy(
        keys=[Key("category")],
        aggregate=MinK(keys=[Key.SCORE], k=3),
    )
    original = Search(where=Key("status") == "active", group_by=grouping)

    # Serialized form of the group_by section.
    search_dict = original.to_dict()
    group_by_dict = search_dict["group_by"]
    assert group_by_dict["keys"] == ["category"]
    assert group_by_dict["aggregate"] == {"$min_k": {"keys": ["#score"], "k": 3}}

    # Rebuild a Search from that section alone and compare the section again.
    restored = Search(group_by=GroupBy.from_dict(group_by_dict))
    assert restored.to_dict()["group_by"] == group_by_dict

View File

@@ -1,10 +1,11 @@
import asyncio
from typing import Any, Callable, Generator, cast
from unittest.mock import patch
from typing import Any, Callable, Generator, cast, Dict, Tuple
from unittest.mock import MagicMock, patch
import chromadb
from chromadb.config import Settings
from chromadb.config import Settings, System
from chromadb.api import ClientAPI
import chromadb.server.fastapi
from chromadb.api.fastapi import FastAPI
import pytest
import tempfile
import os
@@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings(
str(e)
== "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
)
def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
    """Build a stand-in client factory that records its keyword arguments.

    Returns:
        ``(factory, captured)``. Calling ``factory`` merges the keyword
        arguments of that call into ``captured`` and returns a ``MagicMock``
        whose ``headers`` attribute is an empty dict. Positional arguments
        are accepted and ignored.
    """
    recorded_kwargs: Dict[str, Any] = {}

    # Accepts arbitrary positional args so it can stand in for httpx.Client.
    def factory(*_: Any, **kwargs: Any) -> Any:
        fake_session = MagicMock()
        fake_session.headers = {}
        recorded_kwargs.update(kwargs)
        return fake_session

    return factory, recorded_kwargs
def test_fastapi_uses_http_limits_from_settings() -> None:
    """FastAPI passes the HTTP limit settings through to the httpx client."""
    factory, captured = make_sync_client_factory()
    system = System(
        Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host="localhost",
            chroma_server_http_port=9000,
            chroma_server_ssl_verify=True,
            chroma_http_keepalive_secs=12.5,
            chroma_http_max_connections=64,
            chroma_http_max_keepalive_connections=16,
        )
    )

    # `require` is stubbed out for two lookups, and httpx.Client is replaced
    # by the recording factory so its kwargs can be inspected afterwards.
    require_patch = patch.object(
        FastAPI, "require", side_effect=[MagicMock(), MagicMock()]
    )
    client_patch = patch("chromadb.api.fastapi.httpx.Client", side_effect=factory)
    with require_patch, client_patch:
        api = FastAPI(system)
        api.stop()

    limits = captured["limits"]
    observed_limits = (
        limits.keepalive_expiry,
        limits.max_connections,
        limits.max_keepalive_connections,
    )
    assert observed_limits == (12.5, 64, 16)
    assert captured["timeout"] is None
    assert captured["verify"] is True

View File

@@ -189,3 +189,21 @@ def test_runtime_dependencies() -> None:
assert data.starts == ["D", "C"]
system.stop()
assert data.stops == ["C", "D"]
def test_http_client_setting_defaults() -> None:
    """Default Settings: 40 s keepalive, no explicit connection limits."""
    default_settings = Settings()

    assert default_settings.chroma_http_keepalive_secs == 40.0
    # Both connection limits are unset by default.
    for limit_name in (
        "chroma_http_max_connections",
        "chroma_http_max_keepalive_connections",
    ):
        assert getattr(default_settings, limit_name) is None
def test_http_client_setting_overrides() -> None:
    """HTTP client values passed to Settings are read back unchanged."""
    overrides = {
        "chroma_http_keepalive_secs": 5.5,
        "chroma_http_max_connections": 123,
        "chroma_http_max_keepalive_connections": 17,
    }
    settings = Settings(**overrides)

    for setting_name, expected_value in overrides.items():
        assert getattr(settings, setting_name) == expected_value

View File

@@ -1,5 +1,5 @@
import pytest
from typing import List, Any, Callable
from typing import List, Any, Callable, Dict
from jsonschema import ValidationError
from unittest.mock import MagicMock, create_autospec
from chromadb.utils.embedding_functions.schemas import (
@@ -7,7 +7,10 @@ from chromadb.utils.embedding_functions.schemas import (
load_schema,
get_available_schemas,
)
from chromadb.utils.embedding_functions import known_embedding_functions
from chromadb.utils.embedding_functions import (
known_embedding_functions,
sparse_known_embedding_functions,
)
from chromadb.api.types import Documents, Embeddings
from pytest import MonkeyPatch
@@ -143,3 +146,306 @@ class TestEmbeddingFunctionSchemas:
if schema.get("additionalProperties", True) is False:
with pytest.raises(ValidationError):
validate_config_schema(test_config, schema_name)
def _create_valid_config_from_schema(
    self, schema: Dict[str, Any]
) -> Dict[str, Any]:
    """Build a config holding a generated value for each required field.

    Only fields that are listed under ``required`` and also declared under
    ``properties`` are filled in. When the schema lacks either key, the
    result is an empty dict.
    """
    if "required" not in schema or "properties" not in schema:
        return {}

    declared_fields = schema["properties"]
    return {
        field: self._get_value_from_field_schema(declared_fields[field])
        for field in schema["required"]
        if field in declared_fields
    }
def _get_value_from_field_schema(self, field_schema: Dict[str, Any]) -> Any:
    """Produce a value intended to satisfy a single field schema.

    Resolution order:
    1. ``enum`` present -> its first entry.
    2. no ``type`` -> the string ``"dummy"``.
    3. ``type`` given as a list -> first entry other than ``"null"``
       (or the first entry when all are ``"null"``), then continue below.
    4. ``"object"`` -> dict of the required, declared nested properties,
       each resolved recursively.
    5. ``"array"`` -> empty list.
    6. anything else -> delegated to ``self._get_dummy_value``.
    """
    if "enum" in field_schema:
        return field_schema["enum"][0]

    declared_type = field_schema.get("type")
    if declared_type is None:
        # No type information at all: fall back to a placeholder string.
        return "dummy"

    if isinstance(declared_type, list):
        # Prefer a concrete type over "null" when the field is nullable.
        concrete_types = [
            candidate for candidate in declared_type if candidate != "null"
        ]
        declared_type = (concrete_types or declared_type)[0]

    if declared_type == "object":
        nested_values: Dict[str, Any] = {}
        if "properties" in field_schema:
            nested_properties = field_schema["properties"]
            for prop in field_schema.get("required", []):
                if prop in nested_properties:
                    nested_values[prop] = self._get_value_from_field_schema(
                        nested_properties[prop]
                    )
        return nested_values

    if declared_type == "array":
        return []

    # Primitive types share the class's existing dummy-value helper.
    return self._get_dummy_value(declared_type)
def _has_custom_validation(self, ef_class: Any) -> bool:
    """Report whether ``ef_class.validate_config`` rejects an invalid config.

    The probe calls ``validate_config`` with a config that contains only an
    unrecognizable key. If the call returns normally, the class is treated
    as using a non-validating implementation and ``False`` is returned. Any
    exception is taken as evidence of real validation and yields ``True``.
    """
    try:
        try:
            ef_class.validate_config({"__invalid_test_config__": True})
        except (ValidationError, ValueError, FileNotFoundError):
            # A validation-style error: the config was actually checked.
            return True
        else:
            # Accepted an obviously invalid config: nothing is validated.
            return False
    except Exception:
        # Any other failure still means validation was attempted
        # (e.g. the schema could not be found).
        return True
def _setup_env_vars_for_ef(
    self, ef_name: str, mock_common_deps: MonkeyPatch
) -> None:
    """Set the environment variables an embedding function needs to be built.

    Args:
        ef_name: Registered name of the embedding function.
        mock_common_deps: MonkeyPatch used to set the variables.
    """
    # Embedding function name -> environment variable holding its API key.
    api_key_env_vars = {
        "cohere": "CHROMA_COHERE_API_KEY",
        "openai": "CHROMA_OPENAI_API_KEY",
        "huggingface": "CHROMA_HUGGINGFACE_API_KEY",
        "huggingface_server": "CHROMA_HUGGINGFACE_API_KEY",
        "google_palm": "CHROMA_GOOGLE_PALM_API_KEY",
        "google_generative_ai": "CHROMA_GOOGLE_GENAI_API_KEY",
        "google_vertex": "CHROMA_GOOGLE_VERTEX_API_KEY",
        "jina": "CHROMA_JINA_API_KEY",
        "mistral": "MISTRAL_API_KEY",
        "morph": "MORPH_API_KEY",
        "voyageai": "CHROMA_VOYAGE_API_KEY",
        "cloudflare_workers_ai": "CHROMA_CLOUDFLARE_API_KEY",
        "together_ai": "CHROMA_TOGETHER_AI_API_KEY",
        "baseten": "CHROMA_BASETEN_API_KEY",
        "roboflow": "CHROMA_ROBOFLOW_API_KEY",
        "amazon_bedrock": "AWS_ACCESS_KEY_ID",  # AWS uses different env vars
        "chroma-cloud-qwen": "CHROMA_API_KEY",
        # Sparse embedding functions
        "chroma-cloud-splade": "CHROMA_API_KEY",
    }

    api_key_variable = api_key_env_vars.get(ef_name)
    if api_key_variable is not None:
        mock_common_deps.setenv(api_key_variable, "test-api-key")

    # Bedrock additionally needs a secret key and a region.
    if ef_name == "amazon_bedrock":
        for variable, value in (
            ("AWS_SECRET_ACCESS_KEY", "test-secret-key"),
            ("AWS_REGION", "us-east-1"),
        ):
            mock_common_deps.setenv(variable, value)
def _create_ef_instance(
    self, ef_name: str, ef_class: Any, mock_common_deps: MonkeyPatch
) -> Any:
    """Instantiate ``ef_class`` with environment and third-party modules faked.

    Steps:
    1. Set the environment variables for ``ef_name``.
    2. Register ``MagicMock`` stand-ins in ``sys.modules`` for modules that
       embedding functions import inside ``__init__``.
    3. Call the constructor with the arguments known to be needed for
       ``ef_name``; unknown names are tried with no arguments first and then
       with ``model_name="test-model"``.

    Args:
        ef_name: Registered name of the embedding function.
        ef_class: Embedding function class to instantiate.
        mock_common_deps: MonkeyPatch used for env vars and ``sys.modules``.

    Returns:
        The constructed embedding function instance.
    """
    import sys

    # Environment variables must be in place before construction.
    self._setup_env_vars_for_ef(ef_name, mock_common_deps)

    # --- PIL ---------------------------------------------------------------
    fake_pil = MagicMock()
    fake_pil_image = MagicMock()

    # --- boto3: Session(...) always hands back one pre-configured session --
    fake_session = MagicMock()
    fake_session.region_name = "us-east-1"
    fake_session.profile_name = None
    fake_session.client.return_value = MagicMock()
    fake_boto3 = MagicMock()
    fake_boto3.Session = MagicMock(return_value=fake_session)

    # --- vertexai: init() and TextEmbeddingModel.from_pretrained() ---------
    fake_text_embedding_model = MagicMock()
    fake_text_embedding_model.from_pretrained.return_value = MagicMock()
    fake_vertexai_lm = MagicMock()
    fake_vertexai_lm.TextEmbeddingModel = fake_text_embedding_model
    fake_vertexai = MagicMock()
    fake_vertexai.language_models = fake_vertexai_lm
    fake_vertexai.init = MagicMock()

    # --- google.generativeai (needs the parent `google` module as well) ----
    fake_google_genai = MagicMock()
    fake_google_genai.configure = MagicMock()  # For palm.configure()
    fake_google_genai.GenerativeModel = MagicMock(return_value=MagicMock())
    fake_google = MagicMock()
    fake_google.generativeai = fake_google_genai

    # --- jina --------------------------------------------------------------
    fake_jina = MagicMock()
    fake_jina.Client = MagicMock()

    # --- mistralai: Mistral(...).embeddings.create(...).data ---------------
    fake_mistral_factory = MagicMock()
    fake_mistral_factory.return_value.embeddings.create.return_value.data = [
        MagicMock(embedding=[0.1, 0.2, 0.3])
    ]
    fake_mistralai = MagicMock()
    fake_mistralai.Mistral = fake_mistral_factory

    # Register the fakes through monkeypatch so they are undone after the test.
    fake_modules = {
        "PIL": fake_pil,
        "PIL.Image": fake_pil_image,
        "google": fake_google,
        "google.generativeai": fake_google_genai,
        "vertexai": fake_vertexai,
        "vertexai.language_models": fake_vertexai_lm,
        "boto3": fake_boto3,
        "jina": fake_jina,
        "mistralai": fake_mistralai,
    }
    for module_name, fake_module in fake_modules.items():
        mock_common_deps.setitem(sys.modules, module_name, fake_module)

    # Embedding functions whose constructor takes a fixed set of keyword
    # arguments (an empty dict means "call with no arguments, no fallback").
    fixed_constructor_kwargs: Dict[str, Dict[str, Any]] = {
        "cloudflare_workers_ai": {
            "model_name": "test-model",
            "account_id": "test-account-id",
        },
        # Baseten needs api_key explicitly passed even with env var
        "baseten": {
            "api_key": "test-api-key",
            "api_base": "https://test.api.baseten.co",
        },
        "huggingface_server": {"url": "http://localhost:8080"},
        "google_vertex": {"project_id": "test-project", "region": "us-central1"},
        "mistral": {"model": "mistral-embed"},
        "roboflow": {},  # No model_name needed
    }
    if ef_name in fixed_constructor_kwargs:
        return ef_class(**fixed_constructor_kwargs[ef_name])

    if ef_name == "amazon_bedrock":
        # Bedrock takes a boto3 session; boto3 is already faked above.
        bedrock_session = fake_boto3.Session(region_name="us-east-1")
        return ef_class(
            session=bedrock_session,
            model_name="amazon.titan-embed-text-v1",
        )

    if ef_name == "chroma-cloud-qwen":
        from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
            ChromaCloudQwenEmbeddingModel,
        )

        return ef_class(
            model=ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
            task="nl_to_code",
        )

    # Generic path: no arguments first, then a common minimal argument set.
    try:
        return ef_class()
    except Exception:
        return ef_class(model_name="test-model")
@pytest.mark.parametrize("ef_name", get_embedding_function_names())
def test_validate_config_with_schema(
    self,
    ef_name: str,
    mock_embeddings: Callable[[Documents], Embeddings],
    mock_common_deps: MonkeyPatch,
) -> None:
    """An embedding function's own ``get_config()`` must pass its ``validate_config``.

    The test is skipped when the class has no usable ``validate_config``, when
    that method does not actually validate, or when an instance cannot be
    constructed with the faked dependencies.
    """
    ef_class = known_embedding_functions[ef_name]

    # Preconditions; each failure skips with its own reason.
    if not hasattr(ef_class, "validate_config"):
        pytest.skip(f"{ef_name} does not have validate_config method")
    # Static methods are callable on the class.
    if not callable(getattr(ef_class, "validate_config", None)):
        pytest.skip(f"{ef_name} validate_config is not callable")
    if not self._has_custom_validation(ef_class):
        pytest.skip(
            f"{ef_name} uses base validate_config implementation (no validation)"
        )

    # Build a real instance so the config comes from the real get_config().
    try:
        ef_instance = self._create_ef_instance(ef_name, ef_class, mock_common_deps)
    except Exception as e:
        pytest.skip(
            f"{ef_name} requires arguments that we cannot provide without external deps: {e}"
        )

    # Replace __call__ on the instance with a stub returning canned embeddings.
    stubbed_call = MagicMock(return_value=mock_embeddings(["test"]))
    mock_common_deps.setattr(ef_instance, "__call__", stubbed_call)

    reported_config = ef_instance.get_config()
    # Optional fields reported as None are omitted rather than validated as
    # null, matching the usual JSON-schema convention for optional fields.
    present_fields = {
        name: value for name, value in reported_config.items() if value is not None
    }

    ef_class.validate_config(present_fields)
def test_validate_config_sparse_embedding_functions(
    self,
    mock_embeddings: Callable[[Documents], Embeddings],
    mock_common_deps: MonkeyPatch,
) -> None:
    """Each sparse embedding function's ``get_config()`` must pass its validator.

    Classes are silently passed over when they have no usable
    ``validate_config``, when it does not actually validate, or when an
    instance cannot be constructed with the faked dependencies.
    """
    for ef_name, ef_class in sparse_known_embedding_functions.items():
        # Same preconditions as the dense test, but skipping means `continue`.
        # Static methods are callable on the class.
        lacks_usable_validator = (
            not hasattr(ef_class, "validate_config")
            or not callable(getattr(ef_class, "validate_config", None))
            or not self._has_custom_validation(ef_class)
        )
        if lacks_usable_validator:
            continue

        # Build a real instance so the config comes from the real get_config().
        try:
            ef_instance = self._create_ef_instance(
                ef_name, ef_class, mock_common_deps
            )
        except Exception:
            continue  # Skip if we can't create instance

        # Replace __call__ on the instance with a stub returning canned embeddings.
        stubbed_call = MagicMock(return_value=mock_embeddings(["test"]))
        mock_common_deps.setattr(ef_instance, "__call__", stubbed_call)

        reported_config = ef_instance.get_config()
        # Optional fields reported as None are omitted rather than validated
        # as null, matching the usual JSON-schema convention.
        present_fields = {
            name: value
            for name, value in reported_config.items()
            if value is not None
        }

        ef_class.validate_config(present_fields)