增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,7 +1,7 @@
import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing import Any, Optional, List, Dict, Union, cast, Tuple
from typing_extensions import TypedDict
import uuid
import numpy as np
@@ -9,6 +9,10 @@ import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from chromadb.test.api.test_schema_e2e import (
SimpleEmbeddingFunction,
DeterministicSparseEmbeddingFunction,
)
from chromadb.test.conftest import NOT_CLUSTER_ONLY
from dataclasses import dataclass
from chromadb.api.types import (
@@ -17,12 +21,27 @@ from chromadb.api.types import (
EmbeddingFunction,
Embeddings,
Metadata,
Schema,
CollectionMetadata,
VectorIndexConfig,
SparseVectorIndexConfig,
StringInvertedIndexConfig,
IntInvertedIndexConfig,
FloatInvertedIndexConfig,
BoolInvertedIndexConfig,
HnswIndexConfig,
SpannIndexConfig,
Space,
)
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
from chromadb.test.conftest import is_spann_disabled_mode
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
CreateSpannConfiguration,
CreateHNSWConfiguration,
)
from chromadb.utils.embedding_functions import (
register_embedding_function,
)
# Set the random seed for reproducibility
@@ -266,6 +285,365 @@ class ExternalCollection:
embedding_function: Optional[types.EmbeddingFunction[Embeddable]]
@register_embedding_function
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
"""Simple embedding function with ip space for persistence tests."""
def default_space(self) -> str: # type: ignore[override]
return "ip"
@st.composite
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
"""Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann."""
space = None
embedding_function = None
source_key = None
hnsw = None
spann = None
if draw(st.booleans()):
space = draw(st.sampled_from(["cosine", "l2", "ip"]))
if draw(st.booleans()):
embedding_function = SimpleIpEmbeddingFunction(
dim=draw(st.integers(min_value=1, max_value=1000))
)
if draw(st.booleans()):
source_key = draw(st.one_of(st.just("#document"), safe_text))
index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))
if index_choice == "hnsw":
hnsw = HnswIndexConfig(
ef_construction=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
max_neighbors=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
ef_search=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
sync_threshold=draw(st.integers(min_value=2, max_value=10000))
if draw(st.booleans())
else None,
resize_factor=draw(st.floats(min_value=1.0, max_value=5.0))
if draw(st.booleans())
else None,
)
elif index_choice == "spann":
spann = SpannIndexConfig(
search_nprobe=draw(st.integers(min_value=1, max_value=128))
if draw(st.booleans())
else None,
write_nprobe=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
ef_construction=draw(st.integers(min_value=1, max_value=200))
if draw(st.booleans())
else None,
ef_search=draw(st.integers(min_value=1, max_value=200))
if draw(st.booleans())
else None,
max_neighbors=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
split_threshold=draw(st.integers(min_value=50, max_value=200))
if draw(st.booleans())
else None,
merge_threshold=draw(st.integers(min_value=25, max_value=100))
if draw(st.booleans())
else None,
)
return VectorIndexConfig(
space=cast(Space, space),
embedding_function=embedding_function,
source_key=source_key,
hnsw=hnsw,
spann=spann,
)
@st.composite
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
"""Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25."""
embedding_function = None
source_key = None
bm25 = None
if draw(st.booleans()):
embedding_function = DeterministicSparseEmbeddingFunction()
source_key = draw(st.one_of(st.just("#document"), safe_text))
if draw(st.booleans()):
bm25 = draw(st.booleans())
return SparseVectorIndexConfig(
embedding_function=embedding_function,
source_key=source_key,
bm25=bm25,
)
@st.composite
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
"""Generate a Schema object with various create_index/delete_index operations."""
if draw(st.booleans()):
return None
schema = Schema()
# Decide how many operations to perform
num_operations = draw(st.integers(min_value=0, max_value=5))
sparse_index_created = False
for _ in range(num_operations):
operation = draw(st.sampled_from(["create_index", "delete_index"]))
config_type = draw(
st.sampled_from(
[
"string_inverted",
"int_inverted",
"float_inverted",
"bool_inverted",
"vector",
"sparse_vector",
]
)
)
# Decide if we're setting on a key or globally
use_key = draw(st.booleans())
key = None
if use_key and config_type != "vector":
# Vector indexes can't be set on specific keys, only globally
key = draw(safe_text)
if operation == "create_index":
if config_type == "string_inverted":
schema.create_index(config=StringInvertedIndexConfig(), key=key)
elif config_type == "int_inverted":
schema.create_index(config=IntInvertedIndexConfig(), key=key)
elif config_type == "float_inverted":
schema.create_index(config=FloatInvertedIndexConfig(), key=key)
elif config_type == "bool_inverted":
schema.create_index(config=BoolInvertedIndexConfig(), key=key)
elif config_type == "vector":
vector_config = draw(vector_index_config_strategy())
schema.create_index(config=vector_config, key=None)
elif (
config_type == "sparse_vector"
and not is_spann_disabled_mode
and not sparse_index_created
):
sparse_config = draw(sparse_vector_index_config_strategy())
# Sparse vector MUST have a key
if key is None:
key = draw(safe_text)
schema.create_index(config=sparse_config, key=key)
sparse_index_created = True
elif operation == "delete_index":
if config_type == "string_inverted":
schema.delete_index(config=StringInvertedIndexConfig(), key=key)
elif config_type == "int_inverted":
schema.delete_index(config=IntInvertedIndexConfig(), key=key)
elif config_type == "float_inverted":
schema.delete_index(config=FloatInvertedIndexConfig(), key=key)
elif config_type == "bool_inverted":
schema.delete_index(config=BoolInvertedIndexConfig(), key=key)
# Vector, FTS, and sparse_vector deletion is not currently supported
return schema
@st.composite
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
"""Generate metadata with hnsw parameters."""
metadata: CollectionMetadata = {}
if draw(st.booleans()):
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
if draw(st.booleans()):
metadata["hnsw:construction_ef"] = draw(
st.integers(min_value=1, max_value=1000)
)
if draw(st.booleans()):
metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans()):
metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans()):
metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
if draw(st.booleans()):
metadata["hnsw:sync_threshold"] = draw(
st.integers(min_value=2, max_value=10000)
)
return metadata if metadata else None
@st.composite
def create_configuration_strategy(
draw: st.DrawFn,
) -> Optional[CreateCollectionConfiguration]:
"""Generate CreateCollectionConfiguration with mutual exclusivity rules."""
configuration: CreateCollectionConfiguration = {}
# Optionally set embedding_function (independent)
if draw(st.booleans()):
configuration["embedding_function"] = SimpleIpEmbeddingFunction(
dim=draw(st.integers(min_value=1, max_value=1000))
)
# Decide: set space only, OR set hnsw config, OR set spann config
config_choice = draw(
st.sampled_from(
["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
)
)
if config_choice == "space_only_hnsw":
configuration["hnsw"] = CreateHNSWConfiguration(
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
)
elif config_choice == "space_only_spann":
configuration["spann"] = CreateSpannConfiguration(
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
)
elif config_choice == "hnsw":
# Set hnsw config (optionally with space)
hnsw_config: CreateHNSWConfiguration = {}
if draw(st.booleans()):
hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000))
hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
configuration["hnsw"] = hnsw_config
elif config_choice == "spann":
# Set spann config (optionally with space)
spann_config: CreateSpannConfiguration = {}
if draw(st.booleans()):
spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128))
spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64))
spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200))
spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200))
spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64))
spann_config["reassign_neighbor_count"] = draw(
st.integers(min_value=1, max_value=64)
)
spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200))
spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100))
configuration["spann"] = spann_config
return configuration if configuration else None
@dataclass
class CollectionInputCombination:
"""
Input tuple for collection creation tests.
"""
metadata: Optional[CollectionMetadata]
configuration: Optional[CreateCollectionConfiguration]
schema: Optional[Schema]
schema_vector_info: Optional[Dict[str, Any]]
kind: str
def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in items.items() if v is not None}
def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
embedding_default_space: Optional[str] = None
if config.embedding_function is not None and hasattr(
config.embedding_function, "default_space"
):
embedding_default_space = cast(str, config.embedding_function.default_space())
return {
"space": config.space,
"hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None,
"spann": config.spann.model_dump(exclude_none=True) if config.spann else None,
"embedding_function_default_space": embedding_default_space,
}
@st.composite
def _schema_input_strategy(
draw: st.DrawFn,
) -> Tuple[Schema, Dict[str, Any]]:
schema = Schema()
vector_config = draw(vector_index_config_strategy())
schema.create_index(config=vector_config, key=None)
return schema, vector_index_to_dict(vector_config)
@st.composite
def metadata_configuration_schema_strategy(
draw: st.DrawFn,
) -> CollectionInputCombination:
"""
Generate compatible combinations of metadata, configuration, and schema inputs.
"""
choice = draw(
st.sampled_from(
[
"none",
"metadata",
"configuration",
"metadata_configuration",
"schema",
]
)
)
metadata: Optional[CollectionMetadata] = None
configuration: Optional[CreateCollectionConfiguration] = None
schema: Optional[Schema] = None
schema_info: Optional[Dict[str, Any]] = None
if choice in ("metadata", "metadata_configuration"):
metadata = draw(
metadata_with_hnsw_strategy().filter(
lambda value: value is not None and len(value) > 0
)
)
if choice in ("configuration", "metadata_configuration"):
configuration = draw(
create_configuration_strategy().filter(
lambda value: value is not None
and (
(value.get("hnsw") is not None and len(value["hnsw"]) > 0)
or (value.get("spann") is not None and len(value["spann"]) > 0)
)
)
)
if choice == "schema":
schema, schema_info = draw(_schema_input_strategy())
return CollectionInputCombination(
metadata=metadata,
configuration=configuration,
schema=schema,
schema_vector_info=schema_info,
kind=choice,
)
@dataclass
class Collection(ExternalCollection):
"""
@@ -344,7 +722,7 @@ def collections(
spann_config: CreateSpannConfiguration = {
"space": spann_space,
"write_nprobe": 4,
"reassign_neighbor_count": 4
"reassign_neighbor_count": 4,
}
collection_config = {
"spann": spann_config,
@@ -395,7 +773,7 @@ def collections(
known_document_keywords=known_document_keywords,
has_embeddings=has_embeddings,
embedding_function=embedding_function,
collection_config=collection_config
collection_config=collection_config,
)
@@ -421,7 +799,9 @@ def metadata(
del metadata[key] # type: ignore
# Finally, add in some of the known keys for the collection
sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
k: st.just(v) for k, v in collection.known_metadata_keys.items()
k: st.just(v)
for k, v in collection.known_metadata_keys.items()
if isinstance(v, (str, int, float))
}
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore
# We don't allow submitting empty metadata

View File

@@ -31,7 +31,7 @@ collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="co
normal=hypothesis.settings(max_examples=500),
fast=hypothesis.settings(max_examples=200),
),
max_examples=2
max_examples=2,
)
def test_add_miniscule(
client: ClientAPI,
@@ -332,7 +332,8 @@ def test_out_of_order_ids(client: ClientAPI) -> None:
]
coll = client.create_collection(
"test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore
"test",
embedding_function=lambda input: [[1, 2, 3] for _ in input], # type: ignore
)
embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
coll.add(ids=ooo_ids, embeddings=embeddings)
@@ -369,3 +370,155 @@ def test_add_partial(client: ClientAPI) -> None:
assert results["ids"] == ["1", "2", "3"]
assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
assert results["documents"] == ["a", "b", None]
@pytest.mark.skipif(
NOT_CLUSTER_ONLY,
reason="GroupBy is only supported in distributed mode",
)
def test_search_group_by(client: ClientAPI) -> None:
"""Test GroupBy with single key, multiple keys, and multiple ranking keys."""
from chromadb.execution.expression.operator import GroupBy, MinK, Key
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression import Knn
create_isolated_database(client)
coll = client.create_collection(name="test_group_by")
# Test data: 12 records across 3 categories and 2 years
# Embeddings are designed so science docs are closest to query [1,0,0,0]
ids = [
"sci_2023_1",
"sci_2023_2",
"sci_2024_1",
"sci_2024_2",
"tech_2023_1",
"tech_2023_2",
"tech_2024_1",
"tech_2024_2",
"arts_2023_1",
"arts_2023_2",
"arts_2024_1",
"arts_2024_2",
]
embeddings = cast(
Embeddings,
[
# Science - closest to [1,0,0,0]
[1.0, 0.0, 0.0, 0.0], # sci_2023_1: score ~0.0
[0.9, 0.1, 0.0, 0.0], # sci_2023_2: score ~0.141
[0.8, 0.2, 0.0, 0.0], # sci_2024_1: score ~0.283
[0.7, 0.3, 0.0, 0.0], # sci_2024_2: score ~0.424
# Tech - farther from [1,0,0,0]
[0.0, 1.0, 0.0, 0.0], # tech_2023_1: score ~1.414
[0.0, 0.9, 0.1, 0.0], # tech_2023_2: score ~1.345
[0.0, 0.8, 0.2, 0.0], # tech_2024_1: score ~1.281
[0.0, 0.7, 0.3, 0.0], # tech_2024_2: score ~1.221
# Arts - farther from [1,0,0,0]
[0.0, 0.0, 1.0, 0.0], # arts_2023_1: score ~1.414
[0.0, 0.0, 0.9, 0.1], # arts_2023_2: score ~1.345
[0.0, 0.0, 0.8, 0.2], # arts_2024_1: score ~1.281
[0.0, 0.0, 0.7, 0.3], # arts_2024_2: score ~1.221
],
)
metadatas: Metadatas = [
{"category": "science", "year": 2023, "priority": 1},
{"category": "science", "year": 2023, "priority": 2},
{"category": "science", "year": 2024, "priority": 1},
{"category": "science", "year": 2024, "priority": 3},
{"category": "tech", "year": 2023, "priority": 2},
{"category": "tech", "year": 2023, "priority": 1},
{"category": "tech", "year": 2024, "priority": 1},
{"category": "tech", "year": 2024, "priority": 2},
{"category": "arts", "year": 2023, "priority": 3},
{"category": "arts", "year": 2023, "priority": 1},
{"category": "arts", "year": 2024, "priority": 2},
{"category": "arts", "year": 2024, "priority": 1},
]
documents = [f"doc_{id}" for id in ids]
coll.add(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
)
query = [1.0, 0.0, 0.0, 0.0]
# Test 1: Single key grouping - top 2 per category by score
# Expected: 2 best from each category (science, tech, arts)
# - science: sci_2023_1 (0.0), sci_2023_2 (0.141)
# - tech: tech_2024_2 (1.221), tech_2024_1 (1.281)
# - arts: arts_2024_2 (1.221), arts_2024_1 (1.281)
results1 = coll.search(
Search()
.rank(Knn(query=query, limit=12))
.group_by(GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2)))
.limit(12)
)
assert results1["ids"] is not None
result1_ids = results1["ids"][0]
assert len(result1_ids) == 6
expected1 = {
"sci_2023_1",
"sci_2023_2",
"tech_2024_2",
"tech_2024_1",
"arts_2024_2",
"arts_2024_1",
}
assert set(result1_ids) == expected1
# Test 2: Multiple key grouping - top 1 per (category, year) combination
# 6 groups: (science,2023), (science,2024), (tech,2023), (tech,2024), (arts,2023), (arts,2024)
results2 = coll.search(
Search()
.rank(Knn(query=query, limit=12))
.group_by(
GroupBy(
keys=[Key("category"), Key("year")],
aggregate=MinK(keys=Key.SCORE, k=1),
)
)
.limit(12)
)
assert results2["ids"] is not None
result2_ids = results2["ids"][0]
assert len(result2_ids) == 6
expected2 = {
"sci_2023_1",
"sci_2024_1",
"tech_2023_2",
"tech_2024_2",
"arts_2023_2",
"arts_2024_2",
}
assert set(result2_ids) == expected2
# Test 3: Multiple ranking keys - priority first, then score as tiebreaker
# Top 2 per category, sorted by priority (ascending), then score (ascending)
results3 = coll.search(
Search()
.rank(Knn(query=query, limit=12))
.group_by(
GroupBy(
keys=Key("category"),
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
)
)
.limit(12)
)
assert results3["ids"] is not None
result3_ids = results3["ids"][0]
assert len(result3_ids) == 6
expected3 = {
"sci_2023_1",
"sci_2024_1",
"tech_2024_1",
"tech_2023_2",
"arts_2024_2",
"arts_2023_2",
}
assert set(result3_ids) == expected3