增加环绕侦察场景适配
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -31,7 +31,7 @@ from chromadb.utils.embedding_functions import (
|
||||
)
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.errors import InvalidArgumentError, InternalError
|
||||
from chromadb.errors import InvalidArgumentError
|
||||
from chromadb.execution.expression import Knn, Search
|
||||
from chromadb.types import Collection as CollectionModel
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
|
||||
@@ -509,7 +509,7 @@ def test_sparse_vector_not_allowed_locally(
|
||||
schema = Schema()
|
||||
schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig())
|
||||
with pytest.raises(
|
||||
InternalError, match="Sparse vector indexing is not enabled in local"
|
||||
InvalidArgumentError, match="Sparse vector indexing is not enabled in local"
|
||||
):
|
||||
_create_isolated_collection(client_factories, schema=schema)
|
||||
|
||||
@@ -1011,12 +1011,26 @@ def test_collection_fork_inherits_and_isolates_schema(
|
||||
client_factories,
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
parent_version_before_add = get_collection_version(client, collection.name)
|
||||
|
||||
parent_ids = [f"parent-{i}" for i in range(251)]
|
||||
parent_docs = [f"parent doc {i}" for i in range(251)]
|
||||
parent_metadatas: List[Mapping[str, Any]] = [
|
||||
{"shared_key": f"parent_{i}"} for i in range(251)
|
||||
]
|
||||
|
||||
collection.add(
|
||||
ids=["parent-1"],
|
||||
documents=["parent doc"],
|
||||
metadatas=[{"shared_key": "parent"}],
|
||||
ids=parent_ids,
|
||||
documents=parent_docs,
|
||||
metadatas=parent_metadatas,
|
||||
)
|
||||
|
||||
# Wait for parent to compact before forking. Otherwise, the fork inherits
|
||||
# uncompacted logs, and compaction of those inherited logs could increment
|
||||
# the fork's version before the fork's own data is compacted.
|
||||
wait_for_version_increase(client, collection.name, parent_version_before_add)
|
||||
|
||||
assert collection.schema is not None
|
||||
parent_schema_json = collection.schema.serialize_to_json()
|
||||
|
||||
@@ -1050,8 +1064,8 @@ def test_collection_fork_inherits_and_isolates_schema(
|
||||
assert reloaded_parent.schema is not None
|
||||
assert "child_only" not in reloaded_parent.schema.keys
|
||||
|
||||
parent_results = reloaded_parent.get(where={"shared_key": "parent"})
|
||||
assert set(parent_results["ids"]) == {"parent-1"}
|
||||
parent_results = reloaded_parent.get(where={"shared_key": "parent_10"})
|
||||
assert set(parent_results["ids"]) == {"parent-10"}
|
||||
|
||||
child_results = forked.get(where={"child_only": "value_10"})
|
||||
assert set(child_results["ids"]) == {"fork-10"}
|
||||
|
||||
@@ -7,11 +7,21 @@ for automatically processing collections.
|
||||
|
||||
import pytest
|
||||
from chromadb.api.client import Client as ClientCreator
|
||||
from chromadb.api.functions import (
|
||||
RECORD_COUNTER_FUNCTION,
|
||||
STATISTICS_FUNCTION,
|
||||
Function,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.errors import ChromaError, NotFoundError
|
||||
from chromadb.test.utils.wait_for_version_increase import (
|
||||
get_collection_version,
|
||||
wait_for_version_increase,
|
||||
)
|
||||
from time import sleep
|
||||
|
||||
|
||||
def test_function_attach_and_detach(basic_http_client: System) -> None:
|
||||
def test_count_function_attach_and_detach(basic_http_client: System) -> None:
|
||||
"""Test creating and removing a function with the record_counter operator"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
@@ -22,45 +32,39 @@ def test_function_attach_and_detach(basic_http_client: System) -> None:
|
||||
metadata={"description": "Sample documents for task processing"},
|
||||
)
|
||||
|
||||
# Add initial documents
|
||||
collection.add(
|
||||
ids=["doc1", "doc2", "doc3"],
|
||||
documents=[
|
||||
"The quick brown fox jumps over the lazy dog",
|
||||
"Machine learning is a subset of artificial intelligence",
|
||||
"Python is a popular programming language",
|
||||
],
|
||||
metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}],
|
||||
)
|
||||
|
||||
# Verify collection has documents
|
||||
assert collection.count() == 3
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn = collection.attach_function(
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function_id="record_counter", # Built-in operator that counts records
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Verify task creation succeeded
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
initial_version = get_collection_version(client, collection.name)
|
||||
|
||||
# Add more documents
|
||||
# Add documents
|
||||
collection.add(
|
||||
ids=["doc4", "doc5"],
|
||||
documents=[
|
||||
"Chroma is a vector database",
|
||||
"Tasks automate data processing",
|
||||
],
|
||||
ids=["doc_{}".format(i) for i in range(0, 300)],
|
||||
documents=["test document"] * 300,
|
||||
)
|
||||
|
||||
# Verify documents were added
|
||||
assert collection.count() == 5
|
||||
assert collection.count() == 300
|
||||
|
||||
wait_for_version_increase(client, collection.name, initial_version)
|
||||
# Give some time to invalidate the frontend query cache
|
||||
sleep(60)
|
||||
|
||||
result = client.get_collection("my_documents_counts").get("function_output")
|
||||
assert result["metadatas"] is not None
|
||||
assert result["metadatas"][0]["total_count"] == 300
|
||||
|
||||
# Remove the task
|
||||
success = attached_fn.detach(
|
||||
success = collection.detach_function(
|
||||
attached_fn.name,
|
||||
delete_output_collection=True,
|
||||
)
|
||||
|
||||
@@ -79,13 +83,42 @@ def test_task_with_invalid_function(basic_http_client: System) -> None:
|
||||
# Attempt to create task with non-existent function should raise ChromaError
|
||||
with pytest.raises(ChromaError, match="function not found"):
|
||||
collection.attach_function(
|
||||
function=Function._NONEXISTENT_TEST_ONLY,
|
||||
name="invalid_task",
|
||||
function_id="nonexistent_function",
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_attach_function_returns_function_name(basic_http_client: System) -> None:
|
||||
"""Test that attach_function and get_attached_function return function_name field instead of UUID"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
collection = client.create_collection(name="test_function_name")
|
||||
collection.add(ids=["id1"], documents=["doc1"])
|
||||
|
||||
# Attach a function and verify function_name field in response
|
||||
attached_fn, created = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="my_counter",
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Verify the attached function has function_name (not function_id UUID)
|
||||
assert created is True
|
||||
assert attached_fn.function_name == "record_counter"
|
||||
assert attached_fn.name == "my_counter"
|
||||
|
||||
# Get the attached function and verify function_name field is also present
|
||||
retrieved_fn = collection.get_attached_function("my_counter")
|
||||
assert retrieved_fn == attached_fn
|
||||
|
||||
# Clean up
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
|
||||
def test_function_multiple_collections(basic_http_client: System) -> None:
|
||||
"""Test attaching functions on multiple collections"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
@@ -95,93 +128,163 @@ def test_function_multiple_collections(basic_http_client: System) -> None:
|
||||
collection1 = client.create_collection(name="collection_1")
|
||||
collection1.add(ids=["id1", "id2"], documents=["doc1", "doc2"])
|
||||
|
||||
attached_fn1 = collection1.attach_function(
|
||||
attached_fn1, created1 = collection1.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_1",
|
||||
function_id="record_counter",
|
||||
output_collection="output_1",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn1 is not None
|
||||
assert created1 is True
|
||||
|
||||
# Create second collection and task
|
||||
collection2 = client.create_collection(name="collection_2")
|
||||
collection2.add(ids=["id3", "id4"], documents=["doc3", "doc4"])
|
||||
|
||||
attached_fn2 = collection2.attach_function(
|
||||
attached_fn2, created2 = collection2.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_2",
|
||||
function_id="record_counter",
|
||||
output_collection="output_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn2 is not None
|
||||
assert created2 is True
|
||||
|
||||
# Task IDs should be different
|
||||
assert attached_fn1.id != attached_fn2.id
|
||||
|
||||
# Clean up
|
||||
assert attached_fn1.detach(delete_output_collection=True) is True
|
||||
assert attached_fn2.detach(delete_output_collection=True) is True
|
||||
assert (
|
||||
collection1.detach_function(attached_fn1.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
collection2.detach_function(attached_fn2.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_functions_multiple_attached_functions(basic_http_client: System) -> None:
|
||||
"""Test attaching multiple functions on the same collection"""
|
||||
def test_functions_one_attached_function_per_collection(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test that only one attached function is allowed per collection"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a single collection
|
||||
collection = client.create_collection(name="multi_task_collection")
|
||||
collection = client.create_collection(name="single_task_collection")
|
||||
collection.add(ids=["id1", "id2", "id3"], documents=["doc1", "doc2", "doc3"])
|
||||
|
||||
# Create first task on the collection
|
||||
attached_fn1 = collection.attach_function(
|
||||
attached_fn1, created = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_1",
|
||||
function_id="record_counter",
|
||||
output_collection="output_1",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn1 is not None
|
||||
assert created is True
|
||||
|
||||
# Create second task on the SAME collection with a different name
|
||||
attached_fn2 = collection.attach_function(
|
||||
# Attempt to create a second task with a different name should fail
|
||||
# (only one attached function allowed per collection)
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match="collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
|
||||
):
|
||||
collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_2",
|
||||
output_collection="output_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Attempt to create a task with the same name but different function_id should also fail
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match=r"collection already has an attached function: name=task_1, function=record_counter, output_collection=output_1",
|
||||
):
|
||||
collection.attach_function(
|
||||
function=STATISTICS_FUNCTION,
|
||||
name="task_1",
|
||||
output_collection="output_different", # Different output collection
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Detach the first function
|
||||
assert (
|
||||
collection.detach_function(attached_fn1.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
# Now we should be able to attach a new function
|
||||
attached_fn2, created2 = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="task_2",
|
||||
function_id="record_counter",
|
||||
output_collection="output_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
assert attached_fn2 is not None
|
||||
assert created2 is True
|
||||
assert attached_fn2.id != attached_fn1.id
|
||||
|
||||
# Task IDs should be different even though they're on the same collection
|
||||
assert attached_fn1.id != attached_fn2.id
|
||||
|
||||
# Create third task on the same collection
|
||||
attached_fn3 = collection.attach_function(
|
||||
name="task_3",
|
||||
function_id="record_counter",
|
||||
output_collection="output_3",
|
||||
params=None,
|
||||
# Clean up
|
||||
assert (
|
||||
collection.detach_function(attached_fn2.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
assert attached_fn3 is not None
|
||||
assert attached_fn3.id != attached_fn1.id
|
||||
assert attached_fn3.id != attached_fn2.id
|
||||
|
||||
# Attempt to create a task with duplicate name on same collection should fail
|
||||
with pytest.raises(ChromaError, match="already exists"):
|
||||
def test_attach_function_with_invalid_params(basic_http_client: System) -> None:
|
||||
"""Test that attach_function with non-empty params raises an error"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
collection = client.create_collection(name="test_invalid_params")
|
||||
collection.add(ids=["id1"], documents=["test document"])
|
||||
|
||||
# Attempt to create task with non-empty params should fail
|
||||
# (no functions currently accept parameters)
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match="params must be empty - no functions currently accept parameters",
|
||||
):
|
||||
collection.attach_function(
|
||||
name="task_1", # Duplicate name
|
||||
function_id="record_counter",
|
||||
output_collection="output_duplicate",
|
||||
params=None,
|
||||
name="invalid_params_task",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params={"some_key": "some_value"},
|
||||
)
|
||||
|
||||
# Clean up - remove each task individually
|
||||
assert attached_fn1.detach(delete_output_collection=True) is True
|
||||
assert attached_fn2.detach(delete_output_collection=True) is True
|
||||
assert attached_fn3.detach(delete_output_collection=True) is True
|
||||
|
||||
def test_attach_function_output_collection_already_exists(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test that attach_function fails when output collection name already exists"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a collection that will be used as input
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test document"])
|
||||
|
||||
# Create another collection with the name we want to use for output
|
||||
client.create_collection(name="existing_output_collection")
|
||||
|
||||
# Attempt to create task with output collection name that already exists
|
||||
with pytest.raises(
|
||||
ChromaError,
|
||||
match=r"Output collection \[existing_output_collection\] already exists",
|
||||
):
|
||||
input_collection.attach_function(
|
||||
name="my_task",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="existing_output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_function_remove_nonexistent(basic_http_client: System) -> None:
|
||||
@@ -191,15 +294,294 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None:
|
||||
|
||||
collection = client.create_collection(name="test_collection")
|
||||
collection.add(ids=["id1"], documents=["test"])
|
||||
attached_fn = collection.attach_function(
|
||||
attached_fn, _ = collection.attach_function(
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
name="test_function",
|
||||
function_id="record_counter",
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
|
||||
attached_fn.detach(delete_output_collection=True)
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
# Trying to detach this function again should raise NotFoundError
|
||||
with pytest.raises(NotFoundError, match="does not exist"):
|
||||
attached_fn.detach(delete_output_collection=True)
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
|
||||
def test_attach_to_output_collection_fails(basic_http_client: System) -> None:
|
||||
"""Test that attaching a function to an output collection fails"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
_, _ = input_collection.attach_function(
|
||||
name="test_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
output_collection = client.get_collection(name="output_collection")
|
||||
|
||||
with pytest.raises(
|
||||
ChromaError, match="cannot attach function to an output collection"
|
||||
):
|
||||
_ = output_collection.attach_function(
|
||||
name="test_function_2",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection_2",
|
||||
params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_delete_output_collection_detaches_function(basic_http_client: System) -> None:
|
||||
"""Test that deleting an output collection also detaches the attached function"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection and attach a function
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
attached_fn, created = input_collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
# Delete the output collection directly
|
||||
client.delete_collection("output_collection")
|
||||
|
||||
# The attached function should now be gone - trying to get it should raise NotFoundError
|
||||
with pytest.raises(NotFoundError):
|
||||
input_collection.get_attached_function("my_function")
|
||||
|
||||
|
||||
def test_delete_orphaned_output_collection(basic_http_client: System) -> None:
|
||||
"""Test that deleting an output collection from a recently detached function works"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection and attach a function
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
attached_fn, created = input_collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
input_collection.detach_function(attached_fn.name, delete_output_collection=False)
|
||||
|
||||
# Delete the output collection directly
|
||||
client.delete_collection("output_collection")
|
||||
|
||||
# The attached function should still exist but be marked as detached
|
||||
with pytest.raises(NotFoundError):
|
||||
input_collection.get_attached_function("my_function")
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
# Try to use the function - it should fail since it's detached
|
||||
client.get_collection("output_collection")
|
||||
|
||||
def test_partial_attach_function_repair(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test creating and removing a function with the record_counter operator"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a collection
|
||||
collection = client.get_or_create_collection(
|
||||
name="my_document",
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert created is True
|
||||
|
||||
# Verify task creation succeeded
|
||||
assert attached_fn is not None
|
||||
|
||||
collection2 = client.get_or_create_collection(
|
||||
name="my_document2",
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
# This should fail
|
||||
with pytest.raises(
|
||||
ChromaError, match=r"Output collection \[my_documents_counts\] already exists"
|
||||
):
|
||||
attached_fn, _ = collection2.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Detach the function
|
||||
assert (
|
||||
collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
is True
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn, created = collection2.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
|
||||
def test_output_collection_created_with_schema(basic_http_client: System) -> None:
|
||||
"""Test that output collections are created with the source_attached_function_id in the schema"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create input collection and attach a function
|
||||
input_collection = client.create_collection(name="input_collection")
|
||||
input_collection.add(ids=["id1"], documents=["test"])
|
||||
|
||||
attached_fn, created = input_collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
# Get the output collection - it should exist
|
||||
output_collection = client.get_collection(name="output_collection")
|
||||
assert output_collection is not None
|
||||
|
||||
# The source_attached_function_id is stored in the schema (not metadata)
|
||||
# We can't directly access the schema from the client, but we verify the collection exists
|
||||
# and the attached function orchestrator will use this field internally
|
||||
assert "source_attached_function_id" in output_collection._model.pretty_schema()
|
||||
|
||||
# Clean up
|
||||
input_collection.detach_function(attached_fn.name, delete_output_collection=True)
|
||||
|
||||
|
||||
def test_count_function_attach_and_detach_attach_attach(
|
||||
basic_http_client: System,
|
||||
) -> None:
|
||||
"""Test creating and removing a function with the record_counter operator"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
# Create a collection
|
||||
collection = client.get_or_create_collection(
|
||||
name="my_document",
|
||||
metadata={"description": "Sample documents for task processing"},
|
||||
)
|
||||
|
||||
# Create a task that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
|
||||
# Verify task creation succeeded
|
||||
assert created is True
|
||||
assert attached_fn is not None
|
||||
initial_version = get_collection_version(client, collection.name)
|
||||
|
||||
# Add documents
|
||||
collection.add(
|
||||
ids=["doc_{}".format(i) for i in range(0, 300)],
|
||||
documents=["test document"] * 300,
|
||||
)
|
||||
|
||||
# Verify documents were added
|
||||
assert collection.count() == 300
|
||||
|
||||
wait_for_version_increase(client, collection.name, initial_version)
|
||||
# Give some time to invalidate the frontend query cache
|
||||
sleep(60)
|
||||
|
||||
result = client.get_collection("my_documents_counts").get("function_output")
|
||||
assert result["metadatas"] is not None
|
||||
assert result["metadatas"][0]["total_count"] == 300
|
||||
|
||||
# Remove the task
|
||||
success = collection.detach_function(
|
||||
attached_fn.name, delete_output_collection=True
|
||||
)
|
||||
|
||||
# Verify task removal succeeded
|
||||
assert success is True
|
||||
|
||||
# Attach a function that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn is not None
|
||||
assert created is True
|
||||
|
||||
# Attach a function that counts records in the collection
|
||||
attached_fn, created = collection.attach_function(
|
||||
name="count_my_docs",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="my_documents_counts",
|
||||
params=None,
|
||||
)
|
||||
assert created is False
|
||||
assert attached_fn is not None
|
||||
|
||||
def test_attach_function_idempotency(basic_http_client: System) -> None:
|
||||
"""Test that attach_function is idempotent - calling it twice with same params returns created=False"""
|
||||
client = ClientCreator.from_system(basic_http_client)
|
||||
client.reset()
|
||||
|
||||
collection = client.create_collection(name="idempotency_test")
|
||||
collection.add(ids=["id1"], documents=["test document"])
|
||||
|
||||
# First attach - should be newly created
|
||||
attached_fn1, created1 = collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn1 is not None
|
||||
assert created1 is True
|
||||
|
||||
# Second attach with identical params - should be idempotent (created=False)
|
||||
attached_fn2, created2 = collection.attach_function(
|
||||
name="my_function",
|
||||
function=RECORD_COUNTER_FUNCTION,
|
||||
output_collection="output_collection",
|
||||
params=None,
|
||||
)
|
||||
assert attached_fn2 is not None
|
||||
assert created2 is False
|
||||
|
||||
# Both should return the same function ID
|
||||
assert attached_fn1.id == attached_fn2.id
|
||||
|
||||
# Clean up
|
||||
collection.detach_function(attached_fn1.name, delete_output_collection=True)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import math
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
|
||||
from chromadb import SparseVector
|
||||
from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
|
||||
DEFAULT_CHROMA_BM25_STOPWORDS,
|
||||
ChromaBm25EmbeddingFunction,
|
||||
@@ -137,3 +139,64 @@ def test_validate_config_update_allows_known_keys() -> None:
|
||||
embedder.validate_config_update(
|
||||
embedder.get_config(), {"k": 1.1, "stopwords": ["custom"]}
|
||||
)
|
||||
|
||||
|
||||
def test_multithreaded_usage() -> None:
|
||||
embedder = ChromaBm25EmbeddingFunction()
|
||||
base_texts = [
|
||||
"""The gravitational wave background from massive black hole binaries emit bursts of
|
||||
gravitational waves at periapse. Such events may be directly resolvable in the Galactic
|
||||
centre. However, if the star does not spiral in, the emitted GWs are not resolvable for
|
||||
extra-galactic MBHs, but constitute a source of background noise. We estimate the power
|
||||
spectrum of this extreme mass ratio burst background.""",
|
||||
"""Dynamics of planets in exoplanetary systems with multiple stars showing how the
|
||||
gravitational interactions between the stars and planets affect the orbital stability
|
||||
and long-term evolution of the planetary system architectures.""",
|
||||
"""Diurnal Thermal Tides in a Non-rotating atmosphere with realistic heating profiles
|
||||
and temperature gradients that demonstrate the complex interplay between radiation
|
||||
and atmospheric dynamics in planetary atmospheres.""",
|
||||
"""Intermittent turbulence, noise and waves in stellar atmospheres create complex
|
||||
patterns of energy transport and momentum deposition that influence the structure
|
||||
and evolution of stellar interiors and surfaces.""",
|
||||
"""Superconductivity in quantum materials and condensed matter physics systems
|
||||
exhibiting novel quantum phenomena including topological phases, strongly correlated
|
||||
electron systems, and exotic superconducting pairing mechanisms.""",
|
||||
"""Machine learning models require careful tuning of hyperparameters including learning
|
||||
rates, regularization coefficients, and architectural choices that demonstrate the
|
||||
complex interplay between optimization algorithms and model capacity.""",
|
||||
"""Natural language processing enables text understanding through sophisticated
|
||||
algorithms that analyze semantic relationships, syntactic structures, and contextual
|
||||
information to extract meaningful representations from unstructured textual data.""",
|
||||
"""Vector databases store high-dimensional embeddings efficiently using advanced
|
||||
indexing techniques including approximate nearest neighbor search algorithms that
|
||||
balance accuracy and computational efficiency for large-scale similarity search.""",
|
||||
]
|
||||
texts = base_texts * 30
|
||||
|
||||
num_threads = 10
|
||||
|
||||
def process_single_text(text: str) -> SparseVector:
|
||||
return embedder([text])[0]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(process_single_text, text) for text in texts]
|
||||
all_results = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
embedding = future.result()
|
||||
all_results.append(embedding)
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Threading error detected: {type(e).__name__}: {e}. "
|
||||
"This indicates the stemmer is not thread-safe when cached."
|
||||
)
|
||||
|
||||
assert len(all_results) == len(texts)
|
||||
|
||||
for embedding in all_results:
|
||||
assert embedding.indices
|
||||
assert len(embedding.indices) == len(embedding.values)
|
||||
assert _is_sorted(embedding.indices)
|
||||
for value in embedding.values:
|
||||
assert value > 0
|
||||
assert math.isfinite(value)
|
||||
|
||||
@@ -33,12 +33,14 @@ def test_get_builtins_holds() -> None:
|
||||
"GoogleGenerativeAiEmbeddingFunction",
|
||||
"GooglePalmEmbeddingFunction",
|
||||
"GoogleVertexEmbeddingFunction",
|
||||
"GoogleGenaiEmbeddingFunction",
|
||||
"HuggingFaceEmbeddingFunction",
|
||||
"HuggingFaceEmbeddingServer",
|
||||
"InstructorEmbeddingFunction",
|
||||
"JinaEmbeddingFunction",
|
||||
"MistralEmbeddingFunction",
|
||||
"MorphEmbeddingFunction",
|
||||
"NomicEmbeddingFunction",
|
||||
"ONNXMiniLM_L6_V2",
|
||||
"OllamaEmbeddingFunction",
|
||||
"OpenAIEmbeddingFunction",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import hypothesis
|
||||
import hypothesis.strategies as st
|
||||
from typing import Any, Optional, List, Dict, Union, cast
|
||||
from typing import Any, Optional, List, Dict, Union, cast, Tuple
|
||||
from typing_extensions import TypedDict
|
||||
import uuid
|
||||
import numpy as np
|
||||
@@ -9,6 +9,10 @@ import numpy.typing as npt
|
||||
import chromadb.api.types as types
|
||||
import re
|
||||
from hypothesis.strategies._internal.strategies import SearchStrategy
|
||||
from chromadb.test.api.test_schema_e2e import (
|
||||
SimpleEmbeddingFunction,
|
||||
DeterministicSparseEmbeddingFunction,
|
||||
)
|
||||
from chromadb.test.conftest import NOT_CLUSTER_ONLY
|
||||
from dataclasses import dataclass
|
||||
from chromadb.api.types import (
|
||||
@@ -17,12 +21,27 @@ from chromadb.api.types import (
|
||||
EmbeddingFunction,
|
||||
Embeddings,
|
||||
Metadata,
|
||||
Schema,
|
||||
CollectionMetadata,
|
||||
VectorIndexConfig,
|
||||
SparseVectorIndexConfig,
|
||||
StringInvertedIndexConfig,
|
||||
IntInvertedIndexConfig,
|
||||
FloatInvertedIndexConfig,
|
||||
BoolInvertedIndexConfig,
|
||||
HnswIndexConfig,
|
||||
SpannIndexConfig,
|
||||
Space,
|
||||
)
|
||||
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
|
||||
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
|
||||
from chromadb.test.conftest import is_spann_disabled_mode
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
CreateSpannConfiguration,
|
||||
CreateHNSWConfiguration,
|
||||
)
|
||||
from chromadb.utils.embedding_functions import (
|
||||
register_embedding_function,
|
||||
)
|
||||
|
||||
# Set the random seed for reproducibility
|
||||
@@ -266,6 +285,365 @@ class ExternalCollection:
|
||||
embedding_function: Optional[types.EmbeddingFunction[Embeddable]]
|
||||
|
||||
|
||||
@register_embedding_function
|
||||
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
|
||||
"""Simple embedding function with ip space for persistence tests."""
|
||||
|
||||
def default_space(self) -> str: # type: ignore[override]
|
||||
return "ip"
|
||||
|
||||
|
||||
@st.composite
|
||||
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
|
||||
"""Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann."""
|
||||
space = None
|
||||
embedding_function = None
|
||||
source_key = None
|
||||
hnsw = None
|
||||
spann = None
|
||||
|
||||
if draw(st.booleans()):
|
||||
space = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
|
||||
if draw(st.booleans()):
|
||||
embedding_function = SimpleIpEmbeddingFunction(
|
||||
dim=draw(st.integers(min_value=1, max_value=1000))
|
||||
)
|
||||
|
||||
if draw(st.booleans()):
|
||||
source_key = draw(st.one_of(st.just("#document"), safe_text))
|
||||
|
||||
index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))
|
||||
|
||||
if index_choice == "hnsw":
|
||||
hnsw = HnswIndexConfig(
|
||||
ef_construction=draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
max_neighbors=draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
ef_search=draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
sync_threshold=draw(st.integers(min_value=2, max_value=10000))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
resize_factor=draw(st.floats(min_value=1.0, max_value=5.0))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
)
|
||||
elif index_choice == "spann":
|
||||
spann = SpannIndexConfig(
|
||||
search_nprobe=draw(st.integers(min_value=1, max_value=128))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
write_nprobe=draw(st.integers(min_value=1, max_value=64))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
ef_construction=draw(st.integers(min_value=1, max_value=200))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
ef_search=draw(st.integers(min_value=1, max_value=200))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
max_neighbors=draw(st.integers(min_value=1, max_value=64))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
split_threshold=draw(st.integers(min_value=50, max_value=200))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
merge_threshold=draw(st.integers(min_value=25, max_value=100))
|
||||
if draw(st.booleans())
|
||||
else None,
|
||||
)
|
||||
|
||||
return VectorIndexConfig(
|
||||
space=cast(Space, space),
|
||||
embedding_function=embedding_function,
|
||||
source_key=source_key,
|
||||
hnsw=hnsw,
|
||||
spann=spann,
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
|
||||
"""Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25."""
|
||||
embedding_function = None
|
||||
source_key = None
|
||||
bm25 = None
|
||||
|
||||
if draw(st.booleans()):
|
||||
embedding_function = DeterministicSparseEmbeddingFunction()
|
||||
source_key = draw(st.one_of(st.just("#document"), safe_text))
|
||||
|
||||
if draw(st.booleans()):
|
||||
bm25 = draw(st.booleans())
|
||||
|
||||
return SparseVectorIndexConfig(
|
||||
embedding_function=embedding_function,
|
||||
source_key=source_key,
|
||||
bm25=bm25,
|
||||
)
|
||||
|
||||
|
||||
@st.composite
|
||||
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
|
||||
"""Generate a Schema object with various create_index/delete_index operations."""
|
||||
if draw(st.booleans()):
|
||||
return None
|
||||
|
||||
schema = Schema()
|
||||
|
||||
# Decide how many operations to perform
|
||||
num_operations = draw(st.integers(min_value=0, max_value=5))
|
||||
sparse_index_created = False
|
||||
|
||||
for _ in range(num_operations):
|
||||
operation = draw(st.sampled_from(["create_index", "delete_index"]))
|
||||
config_type = draw(
|
||||
st.sampled_from(
|
||||
[
|
||||
"string_inverted",
|
||||
"int_inverted",
|
||||
"float_inverted",
|
||||
"bool_inverted",
|
||||
"vector",
|
||||
"sparse_vector",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Decide if we're setting on a key or globally
|
||||
use_key = draw(st.booleans())
|
||||
key = None
|
||||
if use_key and config_type != "vector":
|
||||
# Vector indexes can't be set on specific keys, only globally
|
||||
key = draw(safe_text)
|
||||
|
||||
if operation == "create_index":
|
||||
if config_type == "string_inverted":
|
||||
schema.create_index(config=StringInvertedIndexConfig(), key=key)
|
||||
elif config_type == "int_inverted":
|
||||
schema.create_index(config=IntInvertedIndexConfig(), key=key)
|
||||
elif config_type == "float_inverted":
|
||||
schema.create_index(config=FloatInvertedIndexConfig(), key=key)
|
||||
elif config_type == "bool_inverted":
|
||||
schema.create_index(config=BoolInvertedIndexConfig(), key=key)
|
||||
elif config_type == "vector":
|
||||
vector_config = draw(vector_index_config_strategy())
|
||||
schema.create_index(config=vector_config, key=None)
|
||||
elif (
|
||||
config_type == "sparse_vector"
|
||||
and not is_spann_disabled_mode
|
||||
and not sparse_index_created
|
||||
):
|
||||
sparse_config = draw(sparse_vector_index_config_strategy())
|
||||
# Sparse vector MUST have a key
|
||||
if key is None:
|
||||
key = draw(safe_text)
|
||||
schema.create_index(config=sparse_config, key=key)
|
||||
sparse_index_created = True
|
||||
|
||||
elif operation == "delete_index":
|
||||
if config_type == "string_inverted":
|
||||
schema.delete_index(config=StringInvertedIndexConfig(), key=key)
|
||||
elif config_type == "int_inverted":
|
||||
schema.delete_index(config=IntInvertedIndexConfig(), key=key)
|
||||
elif config_type == "float_inverted":
|
||||
schema.delete_index(config=FloatInvertedIndexConfig(), key=key)
|
||||
elif config_type == "bool_inverted":
|
||||
schema.delete_index(config=BoolInvertedIndexConfig(), key=key)
|
||||
# Vector, FTS, and sparse_vector deletion is not currently supported
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@st.composite
|
||||
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
|
||||
"""Generate metadata with hnsw parameters."""
|
||||
metadata: CollectionMetadata = {}
|
||||
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:construction_ef"] = draw(
|
||||
st.integers(min_value=1, max_value=1000)
|
||||
)
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
|
||||
if draw(st.booleans()):
|
||||
metadata["hnsw:sync_threshold"] = draw(
|
||||
st.integers(min_value=2, max_value=10000)
|
||||
)
|
||||
|
||||
return metadata if metadata else None
|
||||
|
||||
|
||||
@st.composite
|
||||
def create_configuration_strategy(
|
||||
draw: st.DrawFn,
|
||||
) -> Optional[CreateCollectionConfiguration]:
|
||||
"""Generate CreateCollectionConfiguration with mutual exclusivity rules."""
|
||||
configuration: CreateCollectionConfiguration = {}
|
||||
|
||||
# Optionally set embedding_function (independent)
|
||||
if draw(st.booleans()):
|
||||
configuration["embedding_function"] = SimpleIpEmbeddingFunction(
|
||||
dim=draw(st.integers(min_value=1, max_value=1000))
|
||||
)
|
||||
|
||||
# Decide: set space only, OR set hnsw config, OR set spann config
|
||||
config_choice = draw(
|
||||
st.sampled_from(
|
||||
["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
|
||||
)
|
||||
)
|
||||
|
||||
if config_choice == "space_only_hnsw":
|
||||
configuration["hnsw"] = CreateHNSWConfiguration(
|
||||
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
)
|
||||
elif config_choice == "space_only_spann":
|
||||
configuration["spann"] = CreateSpannConfiguration(
|
||||
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
)
|
||||
elif config_choice == "hnsw":
|
||||
# Set hnsw config (optionally with space)
|
||||
hnsw_config: CreateHNSWConfiguration = {}
|
||||
if draw(st.booleans()):
|
||||
hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000))
|
||||
hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000))
|
||||
hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
|
||||
configuration["hnsw"] = hnsw_config
|
||||
elif config_choice == "spann":
|
||||
# Set spann config (optionally with space)
|
||||
spann_config: CreateSpannConfiguration = {}
|
||||
if draw(st.booleans()):
|
||||
spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
|
||||
spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128))
|
||||
spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64))
|
||||
spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200))
|
||||
spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200))
|
||||
spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64))
|
||||
spann_config["reassign_neighbor_count"] = draw(
|
||||
st.integers(min_value=1, max_value=64)
|
||||
)
|
||||
spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200))
|
||||
spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100))
|
||||
configuration["spann"] = spann_config
|
||||
|
||||
return configuration if configuration else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CollectionInputCombination:
|
||||
"""
|
||||
Input tuple for collection creation tests.
|
||||
"""
|
||||
|
||||
metadata: Optional[CollectionMetadata]
|
||||
configuration: Optional[CreateCollectionConfiguration]
|
||||
schema: Optional[Schema]
|
||||
schema_vector_info: Optional[Dict[str, Any]]
|
||||
kind: str
|
||||
|
||||
|
||||
def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: v for k, v in items.items() if v is not None}
|
||||
|
||||
|
||||
def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
|
||||
embedding_default_space: Optional[str] = None
|
||||
if config.embedding_function is not None and hasattr(
|
||||
config.embedding_function, "default_space"
|
||||
):
|
||||
embedding_default_space = cast(str, config.embedding_function.default_space())
|
||||
|
||||
return {
|
||||
"space": config.space,
|
||||
"hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None,
|
||||
"spann": config.spann.model_dump(exclude_none=True) if config.spann else None,
|
||||
"embedding_function_default_space": embedding_default_space,
|
||||
}
|
||||
|
||||
|
||||
@st.composite
|
||||
def _schema_input_strategy(
|
||||
draw: st.DrawFn,
|
||||
) -> Tuple[Schema, Dict[str, Any]]:
|
||||
schema = Schema()
|
||||
vector_config = draw(vector_index_config_strategy())
|
||||
schema.create_index(config=vector_config, key=None)
|
||||
return schema, vector_index_to_dict(vector_config)
|
||||
|
||||
|
||||
@st.composite
|
||||
def metadata_configuration_schema_strategy(
|
||||
draw: st.DrawFn,
|
||||
) -> CollectionInputCombination:
|
||||
"""
|
||||
Generate compatible combinations of metadata, configuration, and schema inputs.
|
||||
"""
|
||||
|
||||
choice = draw(
|
||||
st.sampled_from(
|
||||
[
|
||||
"none",
|
||||
"metadata",
|
||||
"configuration",
|
||||
"metadata_configuration",
|
||||
"schema",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
metadata: Optional[CollectionMetadata] = None
|
||||
configuration: Optional[CreateCollectionConfiguration] = None
|
||||
schema: Optional[Schema] = None
|
||||
schema_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
if choice in ("metadata", "metadata_configuration"):
|
||||
metadata = draw(
|
||||
metadata_with_hnsw_strategy().filter(
|
||||
lambda value: value is not None and len(value) > 0
|
||||
)
|
||||
)
|
||||
|
||||
if choice in ("configuration", "metadata_configuration"):
|
||||
configuration = draw(
|
||||
create_configuration_strategy().filter(
|
||||
lambda value: value is not None
|
||||
and (
|
||||
(value.get("hnsw") is not None and len(value["hnsw"]) > 0)
|
||||
or (value.get("spann") is not None and len(value["spann"]) > 0)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if choice == "schema":
|
||||
schema, schema_info = draw(_schema_input_strategy())
|
||||
|
||||
return CollectionInputCombination(
|
||||
metadata=metadata,
|
||||
configuration=configuration,
|
||||
schema=schema,
|
||||
schema_vector_info=schema_info,
|
||||
kind=choice,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Collection(ExternalCollection):
|
||||
"""
|
||||
@@ -344,7 +722,7 @@ def collections(
|
||||
spann_config: CreateSpannConfiguration = {
|
||||
"space": spann_space,
|
||||
"write_nprobe": 4,
|
||||
"reassign_neighbor_count": 4
|
||||
"reassign_neighbor_count": 4,
|
||||
}
|
||||
collection_config = {
|
||||
"spann": spann_config,
|
||||
@@ -395,7 +773,7 @@ def collections(
|
||||
known_document_keywords=known_document_keywords,
|
||||
has_embeddings=has_embeddings,
|
||||
embedding_function=embedding_function,
|
||||
collection_config=collection_config
|
||||
collection_config=collection_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -421,7 +799,9 @@ def metadata(
|
||||
del metadata[key] # type: ignore
|
||||
# Finally, add in some of the known keys for the collection
|
||||
sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
|
||||
k: st.just(v) for k, v in collection.known_metadata_keys.items()
|
||||
k: st.just(v)
|
||||
for k, v in collection.known_metadata_keys.items()
|
||||
if isinstance(v, (str, int, float))
|
||||
}
|
||||
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore
|
||||
# We don't allow submitting empty metadata
|
||||
|
||||
@@ -31,7 +31,7 @@ collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="co
|
||||
normal=hypothesis.settings(max_examples=500),
|
||||
fast=hypothesis.settings(max_examples=200),
|
||||
),
|
||||
max_examples=2
|
||||
max_examples=2,
|
||||
)
|
||||
def test_add_miniscule(
|
||||
client: ClientAPI,
|
||||
@@ -332,7 +332,8 @@ def test_out_of_order_ids(client: ClientAPI) -> None:
|
||||
]
|
||||
|
||||
coll = client.create_collection(
|
||||
"test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore
|
||||
"test",
|
||||
embedding_function=lambda input: [[1, 2, 3] for _ in input], # type: ignore
|
||||
)
|
||||
embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
|
||||
coll.add(ids=ooo_ids, embeddings=embeddings)
|
||||
@@ -369,3 +370,155 @@ def test_add_partial(client: ClientAPI) -> None:
|
||||
assert results["ids"] == ["1", "2", "3"]
|
||||
assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
|
||||
assert results["documents"] == ["a", "b", None]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
NOT_CLUSTER_ONLY,
|
||||
reason="GroupBy is only supported in distributed mode",
|
||||
)
|
||||
def test_search_group_by(client: ClientAPI) -> None:
|
||||
"""Test GroupBy with single key, multiple keys, and multiple ranking keys."""
|
||||
from chromadb.execution.expression.operator import GroupBy, MinK, Key
|
||||
from chromadb.execution.expression.plan import Search
|
||||
from chromadb.execution.expression import Knn
|
||||
|
||||
create_isolated_database(client)
|
||||
|
||||
coll = client.create_collection(name="test_group_by")
|
||||
|
||||
# Test data: 12 records across 3 categories and 2 years
|
||||
# Embeddings are designed so science docs are closest to query [1,0,0,0]
|
||||
ids = [
|
||||
"sci_2023_1",
|
||||
"sci_2023_2",
|
||||
"sci_2024_1",
|
||||
"sci_2024_2",
|
||||
"tech_2023_1",
|
||||
"tech_2023_2",
|
||||
"tech_2024_1",
|
||||
"tech_2024_2",
|
||||
"arts_2023_1",
|
||||
"arts_2023_2",
|
||||
"arts_2024_1",
|
||||
"arts_2024_2",
|
||||
]
|
||||
embeddings = cast(
|
||||
Embeddings,
|
||||
[
|
||||
# Science - closest to [1,0,0,0]
|
||||
[1.0, 0.0, 0.0, 0.0], # sci_2023_1: score ~0.0
|
||||
[0.9, 0.1, 0.0, 0.0], # sci_2023_2: score ~0.141
|
||||
[0.8, 0.2, 0.0, 0.0], # sci_2024_1: score ~0.283
|
||||
[0.7, 0.3, 0.0, 0.0], # sci_2024_2: score ~0.424
|
||||
# Tech - farther from [1,0,0,0]
|
||||
[0.0, 1.0, 0.0, 0.0], # tech_2023_1: score ~1.414
|
||||
[0.0, 0.9, 0.1, 0.0], # tech_2023_2: score ~1.345
|
||||
[0.0, 0.8, 0.2, 0.0], # tech_2024_1: score ~1.281
|
||||
[0.0, 0.7, 0.3, 0.0], # tech_2024_2: score ~1.221
|
||||
# Arts - farther from [1,0,0,0]
|
||||
[0.0, 0.0, 1.0, 0.0], # arts_2023_1: score ~1.414
|
||||
[0.0, 0.0, 0.9, 0.1], # arts_2023_2: score ~1.345
|
||||
[0.0, 0.0, 0.8, 0.2], # arts_2024_1: score ~1.281
|
||||
[0.0, 0.0, 0.7, 0.3], # arts_2024_2: score ~1.221
|
||||
],
|
||||
)
|
||||
metadatas: Metadatas = [
|
||||
{"category": "science", "year": 2023, "priority": 1},
|
||||
{"category": "science", "year": 2023, "priority": 2},
|
||||
{"category": "science", "year": 2024, "priority": 1},
|
||||
{"category": "science", "year": 2024, "priority": 3},
|
||||
{"category": "tech", "year": 2023, "priority": 2},
|
||||
{"category": "tech", "year": 2023, "priority": 1},
|
||||
{"category": "tech", "year": 2024, "priority": 1},
|
||||
{"category": "tech", "year": 2024, "priority": 2},
|
||||
{"category": "arts", "year": 2023, "priority": 3},
|
||||
{"category": "arts", "year": 2023, "priority": 1},
|
||||
{"category": "arts", "year": 2024, "priority": 2},
|
||||
{"category": "arts", "year": 2024, "priority": 1},
|
||||
]
|
||||
documents = [f"doc_{id}" for id in ids]
|
||||
|
||||
coll.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
query = [1.0, 0.0, 0.0, 0.0]
|
||||
|
||||
# Test 1: Single key grouping - top 2 per category by score
|
||||
# Expected: 2 best from each category (science, tech, arts)
|
||||
# - science: sci_2023_1 (0.0), sci_2023_2 (0.141)
|
||||
# - tech: tech_2024_2 (1.221), tech_2024_1 (1.281)
|
||||
# - arts: arts_2024_2 (1.221), arts_2024_1 (1.281)
|
||||
results1 = coll.search(
|
||||
Search()
|
||||
.rank(Knn(query=query, limit=12))
|
||||
.group_by(GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2)))
|
||||
.limit(12)
|
||||
)
|
||||
assert results1["ids"] is not None
|
||||
result1_ids = results1["ids"][0]
|
||||
assert len(result1_ids) == 6
|
||||
expected1 = {
|
||||
"sci_2023_1",
|
||||
"sci_2023_2",
|
||||
"tech_2024_2",
|
||||
"tech_2024_1",
|
||||
"arts_2024_2",
|
||||
"arts_2024_1",
|
||||
}
|
||||
assert set(result1_ids) == expected1
|
||||
|
||||
# Test 2: Multiple key grouping - top 1 per (category, year) combination
|
||||
# 6 groups: (science,2023), (science,2024), (tech,2023), (tech,2024), (arts,2023), (arts,2024)
|
||||
results2 = coll.search(
|
||||
Search()
|
||||
.rank(Knn(query=query, limit=12))
|
||||
.group_by(
|
||||
GroupBy(
|
||||
keys=[Key("category"), Key("year")],
|
||||
aggregate=MinK(keys=Key.SCORE, k=1),
|
||||
)
|
||||
)
|
||||
.limit(12)
|
||||
)
|
||||
assert results2["ids"] is not None
|
||||
result2_ids = results2["ids"][0]
|
||||
assert len(result2_ids) == 6
|
||||
expected2 = {
|
||||
"sci_2023_1",
|
||||
"sci_2024_1",
|
||||
"tech_2023_2",
|
||||
"tech_2024_2",
|
||||
"arts_2023_2",
|
||||
"arts_2024_2",
|
||||
}
|
||||
assert set(result2_ids) == expected2
|
||||
|
||||
# Test 3: Multiple ranking keys - priority first, then score as tiebreaker
|
||||
# Top 2 per category, sorted by priority (ascending), then score (ascending)
|
||||
results3 = coll.search(
|
||||
Search()
|
||||
.rank(Knn(query=query, limit=12))
|
||||
.group_by(
|
||||
GroupBy(
|
||||
keys=Key("category"),
|
||||
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
|
||||
)
|
||||
)
|
||||
.limit(12)
|
||||
)
|
||||
assert results3["ids"] is not None
|
||||
result3_ids = results3["ids"][0]
|
||||
assert len(result3_ids) == 6
|
||||
expected3 = {
|
||||
"sci_2023_1",
|
||||
"sci_2024_1",
|
||||
"tech_2024_1",
|
||||
"tech_2023_2",
|
||||
"arts_2024_2",
|
||||
"arts_2023_2",
|
||||
}
|
||||
assert set(result3_ids) == expected3
|
||||
|
||||
@@ -1983,7 +1983,9 @@ def test_sparse_vector_in_metadata_validation():
|
||||
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
|
||||
invalid_metadata_4 = {
|
||||
"text": "non-numeric value",
|
||||
"sparse_embedding": SparseVector(indices=[0, 1], values=[0.1, "not_a_number"]), # type: ignore
|
||||
"sparse_embedding": SparseVector(
|
||||
indices=[0, 1], values=[0.1, "not_a_number"]
|
||||
), # type: ignore
|
||||
}
|
||||
|
||||
# Test 7: Multiple sparse vectors in metadata
|
||||
@@ -2683,6 +2685,59 @@ def test_rrf_to_dict() -> None:
|
||||
print("All RRF tests passed!")
|
||||
|
||||
|
||||
def test_group_by_serialization() -> None:
|
||||
"""Test GroupBy, MinK, and MaxK serialization and deserialization."""
|
||||
import pytest
|
||||
from chromadb.execution.expression.operator import (
|
||||
GroupBy,
|
||||
MinK,
|
||||
MaxK,
|
||||
Key,
|
||||
Aggregate,
|
||||
)
|
||||
|
||||
# to_dict with OneOrMany keys
|
||||
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
|
||||
assert group_by.to_dict() == {
|
||||
"keys": ["category"],
|
||||
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
|
||||
}
|
||||
|
||||
# to_dict with multiple keys and MaxK
|
||||
group_by = GroupBy(
|
||||
keys=[Key("year"), Key("category")],
|
||||
aggregate=MaxK(keys=[Key.SCORE, Key("priority")], k=5),
|
||||
)
|
||||
assert group_by.to_dict() == {
|
||||
"keys": ["year", "category"],
|
||||
"aggregate": {"$max_k": {"keys": ["#score", "priority"], "k": 5}},
|
||||
}
|
||||
|
||||
# Round-trip
|
||||
original = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
|
||||
assert GroupBy.from_dict(original.to_dict()).to_dict() == original.to_dict()
|
||||
|
||||
# Empty GroupBy serializes to {} and from_dict({}) returns default GroupBy
|
||||
empty_group_by = GroupBy()
|
||||
assert empty_group_by.to_dict() == {}
|
||||
assert GroupBy.from_dict({}).to_dict() == {}
|
||||
|
||||
# Error cases
|
||||
with pytest.raises(ValueError, match="requires 'keys' field"):
|
||||
GroupBy.from_dict({"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}})
|
||||
|
||||
with pytest.raises(ValueError, match="requires 'aggregate' field"):
|
||||
GroupBy.from_dict({"keys": ["category"]})
|
||||
|
||||
with pytest.raises(ValueError, match="keys cannot be empty"):
|
||||
GroupBy.from_dict(
|
||||
{"keys": [], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown aggregate operator"):
|
||||
Aggregate.from_dict({"$unknown": {"keys": ["#score"], "k": 3}})
|
||||
|
||||
|
# Expression API Tests - Testing dict support and from_dict methods
class TestSearchDictSupport:
    """Test Search class dict input support."""
@@ -2786,6 +2841,49 @@ class TestSearchDictSupport:
        with pytest.raises(TypeError, match="select must be"):
            Search(select=123)

    def test_search_with_group_by(self):
        """Test Search accepts group_by as dict, object, and builder method."""
        import pytest
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import GroupBy, MinK, Key

        # Dict input
        search = Search(
            group_by={
                "keys": ["category"],
                "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
            }
        )
        assert isinstance(search._group_by, GroupBy)

        # Object input and builder method
        group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
        assert Search(group_by=group_by)._group_by is group_by
        assert Search().group_by(group_by)._group_by.aggregate is not None

        # Invalid inputs
        with pytest.raises(TypeError, match="group_by must be"):
            Search(group_by="invalid")
        with pytest.raises(ValueError, match="requires 'aggregate' field"):
            Search(group_by={"keys": ["category"]})

    def test_search_group_by_serialization(self):
        """Test Search serializes group_by correctly."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import GroupBy, MinK, Key, Knn

        # Without group_by - empty dict
        search = Search().rank(Knn(query=[0.1, 0.2])).limit(10)
        assert search.to_dict()["group_by"] == {}

        # With group_by - has keys and aggregate
        search = Search().group_by(
            GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
        )
        result = search.to_dict()["group_by"]
        assert result["keys"] == ["category"]
        assert result["aggregate"] == {"$min_k": {"keys": ["#score"], "k": 3}}


class TestWhereFromDict:
    """Test Where.from_dict() conversion."""
@@ -3310,3 +3408,27 @@ class TestRoundTripConversion:
            return d1 == d2

        assert compare_search_dicts(new_dict, search_dict)

    def test_search_round_trip_with_group_by(self):
        """Test Search round-trip with group_by."""
        from chromadb.execution.expression.plan import Search
        from chromadb.execution.expression.operator import Key, GroupBy, MinK

        original = Search(
            where=Key("status") == "active",
            group_by=GroupBy(
                keys=[Key("category")],
                aggregate=MinK(keys=[Key.SCORE], k=3),
            ),
        )

        # Verify group_by round-trip
        search_dict = original.to_dict()
        assert search_dict["group_by"]["keys"] == ["category"]
        assert search_dict["group_by"]["aggregate"] == {
            "$min_k": {"keys": ["#score"], "k": 3}
        }

        # Reconstruct and compare group_by
        restored = Search(group_by=GroupBy.from_dict(search_dict["group_by"]))
        assert restored.to_dict()["group_by"] == search_dict["group_by"]

@@ -1,10 +1,11 @@
import asyncio
from typing import Any, Callable, Generator, cast
from unittest.mock import patch
from typing import Any, Callable, Generator, cast, Dict, Tuple
from unittest.mock import MagicMock, patch
import chromadb
from chromadb.config import Settings
from chromadb.config import Settings, System
from chromadb.api import ClientAPI
import chromadb.server.fastapi
from chromadb.api.fastapi import FastAPI
import pytest
import tempfile
import os
@@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings(
        str(e)
        == "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
    )


def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
    captured: Dict[str, Any] = {}

    # takes any positional args to match httpx.Client
    def factory(*_: Any, **kwargs: Any) -> Any:
        captured.update(kwargs)
        session = MagicMock()
        session.headers = {}
        return session

    return factory, captured


def test_fastapi_uses_http_limits_from_settings() -> None:
    settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_host="localhost",
        chroma_server_http_port=9000,
        chroma_server_ssl_verify=True,
        chroma_http_keepalive_secs=12.5,
        chroma_http_max_connections=64,
        chroma_http_max_keepalive_connections=16,
    )
    system = System(settings)

    factory, captured = make_sync_client_factory()

    with patch.object(FastAPI, "require", side_effect=[MagicMock(), MagicMock()]):
        with patch("chromadb.api.fastapi.httpx.Client", side_effect=factory):
            api = FastAPI(system)

    api.stop()
    limits = captured["limits"]
    assert limits.keepalive_expiry == 12.5
    assert limits.max_connections == 64
    assert limits.max_keepalive_connections == 16
    assert captured["timeout"] is None
    assert captured["verify"] is True

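The test above only inspects the keyword arguments captured from the httpx.Client constructor; the mapping it implies, assuming the client wires the settings straight into an httpx.Limits object, looks roughly like this:

import httpx

from chromadb.config import Settings

# Hedged sketch of the Settings -> httpx.Limits mapping implied by the
# assertions above; the real construction lives inside chromadb.api.fastapi.
settings = Settings(
    chroma_http_keepalive_secs=12.5,
    chroma_http_max_connections=64,
    chroma_http_max_keepalive_connections=16,
)
limits = httpx.Limits(
    keepalive_expiry=settings.chroma_http_keepalive_secs,
    max_connections=settings.chroma_http_max_connections,
    max_keepalive_connections=settings.chroma_http_max_keepalive_connections,
)
client = httpx.Client(limits=limits, timeout=None, verify=True)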
@@ -189,3 +189,21 @@ def test_runtime_dependencies() -> None:
    assert data.starts == ["D", "C"]
    system.stop()
    assert data.stops == ["C", "D"]


def test_http_client_setting_defaults() -> None:
    settings = Settings()
    assert settings.chroma_http_keepalive_secs == 40.0
    assert settings.chroma_http_max_connections is None
    assert settings.chroma_http_max_keepalive_connections is None


def test_http_client_setting_overrides() -> None:
    settings = Settings(
        chroma_http_keepalive_secs=5.5,
        chroma_http_max_connections=123,
        chroma_http_max_keepalive_connections=17,
    )
    assert settings.chroma_http_keepalive_secs == 5.5
    assert settings.chroma_http_max_connections == 123
    assert settings.chroma_http_max_keepalive_connections == 17

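In user code these knobs would presumably be supplied through the Settings object passed to the HTTP client; a minimal sketch with placeholder host and port:

import chromadb
from chromadb.config import Settings

# Hedged sketch: only the chroma_http_* fields come from the tests above;
# host and port are placeholders.
client = chromadb.HttpClient(
    host="localhost",
    port=8000,
    settings=Settings(
        chroma_http_keepalive_secs=5.5,
        chroma_http_max_connections=123,
        chroma_http_max_keepalive_connections=17,
    ),
)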
@@ -1,5 +1,5 @@
import pytest
from typing import List, Any, Callable
from typing import List, Any, Callable, Dict
from jsonschema import ValidationError
from unittest.mock import MagicMock, create_autospec
from chromadb.utils.embedding_functions.schemas import (
@@ -7,7 +7,10 @@ from chromadb.utils.embedding_functions.schemas import (
    load_schema,
    get_available_schemas,
)
from chromadb.utils.embedding_functions import known_embedding_functions
from chromadb.utils.embedding_functions import (
    known_embedding_functions,
    sparse_known_embedding_functions,
)
from chromadb.api.types import Documents, Embeddings
from pytest import MonkeyPatch

@@ -143,3 +146,306 @@ class TestEmbeddingFunctionSchemas:
            if schema.get("additionalProperties", True) is False:
                with pytest.raises(ValidationError):
                    validate_config_schema(test_config, schema_name)

    def _create_valid_config_from_schema(
        self, schema: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create a valid config from a schema by filling in required fields"""
        config: Dict[str, Any] = {}

        if "required" in schema and "properties" in schema:
            for field in schema["required"]:
                if field in schema["properties"]:
                    field_schema = schema["properties"][field]
                    config[field] = self._get_value_from_field_schema(field_schema)

        return config

    def _get_value_from_field_schema(self, field_schema: Dict[str, Any]) -> Any:
        """Get a valid value from a field schema"""
        # Handle enums - use first enum value
        if "enum" in field_schema:
            return field_schema["enum"][0]

        # Handle type (could be a list or single value)
        field_type = field_schema.get("type")
        if field_type is None:
            return "dummy"  # Fallback if no type specified

        if isinstance(field_type, list):
            # If null is in the type list, prefer non-null type
            non_null_types = [t for t in field_type if t != "null"]
            field_type = non_null_types[0] if non_null_types else field_type[0]

        if field_type == "object":
            # Handle nested objects
            nested_config = {}
            if "properties" in field_schema:
                nested_required = field_schema.get("required", [])
                for prop in nested_required:
                    if prop in field_schema["properties"]:
                        nested_config[prop] = self._get_value_from_field_schema(
                            field_schema["properties"][prop]
                        )
            return nested_config if nested_config else {}

        if field_type == "array":
            # Return empty array for arrays
            return []

        # Use the existing dummy value method for primitive types
        return self._get_dummy_value(field_type)

    def _has_custom_validation(self, ef_class: Any) -> bool:
        """Check if validate_config actually validates (not just base implementation)"""
        try:
            # Try with an obviously invalid config - if it doesn't raise, it's base implementation
            invalid_config = {"__invalid_test_config__": True}
            try:
                ef_class.validate_config(invalid_config)
                # If we get here without exception, it's using base implementation
                return False
            except (ValidationError, ValueError, FileNotFoundError):
                # If it raises any validation-related error, it's actually validating
                return True
        except Exception:
            # Any other exception means it's trying to validate (e.g., schema not found)
            return True

    def _setup_env_vars_for_ef(
        self, ef_name: str, mock_common_deps: MonkeyPatch
    ) -> None:
        """Set up environment variables needed for embedding function instantiation"""
        # Map of embedding function names to their default API key environment variable names
        api_key_env_vars = {
            "cohere": "CHROMA_COHERE_API_KEY",
            "openai": "CHROMA_OPENAI_API_KEY",
            "huggingface": "CHROMA_HUGGINGFACE_API_KEY",
            "huggingface_server": "CHROMA_HUGGINGFACE_API_KEY",
            "google_palm": "CHROMA_GOOGLE_PALM_API_KEY",
            "google_generative_ai": "CHROMA_GOOGLE_GENAI_API_KEY",
            "google_vertex": "CHROMA_GOOGLE_VERTEX_API_KEY",
            "jina": "CHROMA_JINA_API_KEY",
            "mistral": "MISTRAL_API_KEY",
            "morph": "MORPH_API_KEY",
            "voyageai": "CHROMA_VOYAGE_API_KEY",
            "cloudflare_workers_ai": "CHROMA_CLOUDFLARE_API_KEY",
            "together_ai": "CHROMA_TOGETHER_AI_API_KEY",
            "baseten": "CHROMA_BASETEN_API_KEY",
            "roboflow": "CHROMA_ROBOFLOW_API_KEY",
            "amazon_bedrock": "AWS_ACCESS_KEY_ID",  # AWS uses different env vars
            "chroma-cloud-qwen": "CHROMA_API_KEY",
            # Sparse embedding functions
            "chroma-cloud-splade": "CHROMA_API_KEY",
        }

        # Set API key environment variable if needed
        if ef_name in api_key_env_vars:
            mock_common_deps.setenv(api_key_env_vars[ef_name], "test-api-key")

        # Special cases that need additional environment variables
        if ef_name == "amazon_bedrock":
            mock_common_deps.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
            mock_common_deps.setenv("AWS_REGION", "us-east-1")

    def _create_ef_instance(
        self, ef_name: str, ef_class: Any, mock_common_deps: MonkeyPatch
    ) -> Any:
        """Create an embedding function instance, handling special cases"""
        # Set up environment variables first
        self._setup_env_vars_for_ef(ef_name, mock_common_deps)

        # Mock missing modules that are imported inside __init__ methods
        import sys

        # Create mock modules
        mock_pil = MagicMock()
        mock_pil_image = MagicMock()
        mock_google_genai = MagicMock()
        mock_vertexai = MagicMock()
        mock_vertexai_lm = MagicMock()
        mock_boto3 = MagicMock()
        mock_jina = MagicMock()
        mock_mistralai = MagicMock()

        # Mock boto3.Session for amazon_bedrock
        mock_boto3_session = MagicMock()
        mock_session_instance = MagicMock()
        mock_session_instance.region_name = "us-east-1"
        mock_session_instance.profile_name = None
        mock_session_instance.client.return_value = MagicMock()
        mock_boto3_session.return_value = mock_session_instance
        mock_boto3.Session = mock_boto3_session

        # Mock vertexai.init and TextEmbeddingModel
        mock_text_embedding_model = MagicMock()
        mock_text_embedding_model.from_pretrained.return_value = MagicMock()
        mock_vertexai_lm.TextEmbeddingModel = mock_text_embedding_model
        mock_vertexai.language_models = mock_vertexai_lm
        mock_vertexai.init = MagicMock()

        # Mock google.generativeai - need to set up google module first
        mock_google = MagicMock()
        mock_google_genai.configure = MagicMock()  # For palm.configure()
        mock_google_genai.GenerativeModel = MagicMock(return_value=MagicMock())
        mock_google.generativeai = mock_google_genai

        # Mock jina Client
        mock_jina.Client = MagicMock()

        # Mock mistralai
        mock_mistral_client = MagicMock()
        mock_mistral_client.return_value.embeddings.create.return_value.data = [
            MagicMock(embedding=[0.1, 0.2, 0.3])
        ]
        mock_mistralai.Mistral = mock_mistral_client

        # Add missing modules to sys.modules using monkeypatch
        modules_to_mock = {
            "PIL": mock_pil,
            "PIL.Image": mock_pil_image,
            "google": mock_google,
            "google.generativeai": mock_google_genai,
            "vertexai": mock_vertexai,
            "vertexai.language_models": mock_vertexai_lm,
            "boto3": mock_boto3,
            "jina": mock_jina,
            "mistralai": mock_mistralai,
        }

        for module_name, mock_module in modules_to_mock.items():
            mock_common_deps.setitem(sys.modules, module_name, mock_module)

        # Special cases that need additional arguments
        if ef_name == "cloudflare_workers_ai":
            return ef_class(
                model_name="test-model",
                account_id="test-account-id",
            )
        elif ef_name == "baseten":
            # Baseten needs api_key explicitly passed even with env var
            return ef_class(
                api_key="test-api-key",
                api_base="https://test.api.baseten.co",
            )
        elif ef_name == "amazon_bedrock":
            # Amazon Bedrock needs a boto3 session - create a mock session
            # boto3 is already mocked in sys.modules above
            mock_session = mock_boto3.Session(region_name="us-east-1")
            return ef_class(
                session=mock_session,
                model_name="amazon.titan-embed-text-v1",
            )
        elif ef_name == "huggingface_server":
            return ef_class(url="http://localhost:8080")
        elif ef_name == "google_vertex":
            return ef_class(project_id="test-project", region="us-central1")
        elif ef_name == "mistral":
            return ef_class(model="mistral-embed")
        elif ef_name == "roboflow":
            return ef_class()  # No model_name needed
        elif ef_name == "chroma-cloud-qwen":
            from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
                ChromaCloudQwenEmbeddingModel,
            )

            return ef_class(
                model=ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
                task="nl_to_code",
            )
        else:
            # Try with no args first
            try:
                return ef_class()
            except Exception:
                # If that fails, try with common minimal args
                return ef_class(model_name="test-model")

    @pytest.mark.parametrize("ef_name", get_embedding_function_names())
    def test_validate_config_with_schema(
        self,
        ef_name: str,
        mock_embeddings: Callable[[Documents], Embeddings],
        mock_common_deps: MonkeyPatch,
    ) -> None:
        """Test that validate_config works correctly with actual configs from embedding functions"""
        ef_class = known_embedding_functions[ef_name]

        # Skip if the embedding function doesn't have a validate_config method
        if not hasattr(ef_class, "validate_config"):
            pytest.skip(f"{ef_name} does not have validate_config method")

        # Check if it's callable (static methods are callable on the class)
        if not callable(getattr(ef_class, "validate_config", None)):
            pytest.skip(f"{ef_name} validate_config is not callable")

        # Skip if using base implementation (doesn't actually validate)
        if not self._has_custom_validation(ef_class):
            pytest.skip(
                f"{ef_name} uses base validate_config implementation (no validation)"
            )

        # Create a real instance to get the actual config
        # We'll mock __call__ to avoid needing to actually generate embeddings
        try:
            ef_instance = self._create_ef_instance(ef_name, ef_class, mock_common_deps)
        except Exception as e:
            pytest.skip(
                f"{ef_name} requires arguments that we cannot provide without external deps: {e}"
            )

        # Mock only __call__ to avoid needing to actually generate embeddings
        mock_call = MagicMock(return_value=mock_embeddings(["test"]))
        mock_common_deps.setattr(ef_instance, "__call__", mock_call)

        # Get the actual config from the embedding function (this uses the real get_config method)
        config = ef_instance.get_config()

        # Filter out None values - optional fields with None shouldn't be included in validation
        # This matches common JSON schema practice where optional fields are omitted rather than null
        config = {k: v for k, v in config.items() if v is not None}

        # Validate the actual config using the embedding function's validate_config method
        ef_class.validate_config(config)

    def test_validate_config_sparse_embedding_functions(
        self,
        mock_embeddings: Callable[[Documents], Embeddings],
        mock_common_deps: MonkeyPatch,
    ) -> None:
        """Test validate_config for sparse embedding functions with actual configs"""
        for ef_name, ef_class in sparse_known_embedding_functions.items():
            # Skip if the embedding function doesn't have a validate_config method
            if not hasattr(ef_class, "validate_config"):
                continue

            # Check if it's callable (static methods are callable on the class)
            if not callable(getattr(ef_class, "validate_config", None)):
                continue

            # Skip if using base implementation (doesn't actually validate)
            if not self._has_custom_validation(ef_class):
                continue

            # Create a real instance to get the actual config
            # We'll mock __call__ to avoid needing to actually generate embeddings
            try:
                ef_instance = self._create_ef_instance(
                    ef_name, ef_class, mock_common_deps
                )
            except Exception:
                continue  # Skip if we can't create instance

            # Mock only __call__ to avoid needing to actually generate embeddings
            mock_call = MagicMock(return_value=mock_embeddings(["test"]))
            mock_common_deps.setattr(ef_instance, "__call__", mock_call)

            # Get the actual config from the embedding function (this uses the real get_config method)
            config = ef_instance.get_config()

            # Filter out None values - optional fields with None shouldn't be included in validation
            # This matches common JSON schema practice where optional fields are omitted rather than null
            config = {k: v for k, v in config.items() if v is not None}

            # Validate the actual config using the embedding function's validate_config method
            ef_class.validate_config(config)
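The helpers referenced above (get_embedding_function_names and the mock_common_deps fixture) are defined elsewhere in this test module; a purely illustrative sketch of what they might look like, not the actual definitions:

from typing import List

import pytest
from pytest import MonkeyPatch

from chromadb.utils.embedding_functions import known_embedding_functions


def get_embedding_function_names() -> List[str]:
    # Illustrative: parametrize over every registered dense embedding function.
    return sorted(known_embedding_functions.keys())


@pytest.fixture
def mock_common_deps(monkeypatch: MonkeyPatch) -> MonkeyPatch:
    # Illustrative: the tests only need a MonkeyPatch handle for setenv/setitem/setattr.
    return monkeypatch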