Add adaptation for the surround reconnaissance scenario
@@ -1,7 +1,7 @@
import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing import Any, Optional, List, Dict, Union, cast, Tuple
from typing_extensions import TypedDict
import uuid
import numpy as np
@@ -9,6 +9,10 @@ import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from chromadb.test.api.test_schema_e2e import (
    SimpleEmbeddingFunction,
    DeterministicSparseEmbeddingFunction,
)
from chromadb.test.conftest import NOT_CLUSTER_ONLY
from dataclasses import dataclass
from chromadb.api.types import (
@@ -17,12 +21,27 @@ from chromadb.api.types import (
    EmbeddingFunction,
    Embeddings,
    Metadata,
    Schema,
    CollectionMetadata,
    VectorIndexConfig,
    SparseVectorIndexConfig,
    StringInvertedIndexConfig,
    IntInvertedIndexConfig,
    FloatInvertedIndexConfig,
    BoolInvertedIndexConfig,
    HnswIndexConfig,
    SpannIndexConfig,
    Space,
)
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
from chromadb.test.conftest import is_spann_disabled_mode
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    CreateSpannConfiguration,
    CreateHNSWConfiguration,
)
from chromadb.utils.embedding_functions import (
    register_embedding_function,
)

# Set the random seed for reproducibility
@@ -266,6 +285,365 @@ class ExternalCollection:
    embedding_function: Optional[types.EmbeddingFunction[Embeddable]]


@register_embedding_function
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
    """Simple embedding function with ip space for persistence tests."""

    def default_space(self) -> str:  # type: ignore[override]
        return "ip"


@st.composite
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
    """Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann."""
    space = None
    embedding_function = None
    source_key = None
    hnsw = None
    spann = None

    if draw(st.booleans()):
        space = draw(st.sampled_from(["cosine", "l2", "ip"]))

    if draw(st.booleans()):
        embedding_function = SimpleIpEmbeddingFunction(
            dim=draw(st.integers(min_value=1, max_value=1000))
        )

    if draw(st.booleans()):
        source_key = draw(st.one_of(st.just("#document"), safe_text))

    index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))

    if index_choice == "hnsw":
        hnsw = HnswIndexConfig(
            ef_construction=draw(st.integers(min_value=1, max_value=1000))
            if draw(st.booleans())
            else None,
            max_neighbors=draw(st.integers(min_value=1, max_value=1000))
            if draw(st.booleans())
            else None,
            ef_search=draw(st.integers(min_value=1, max_value=1000))
            if draw(st.booleans())
            else None,
            sync_threshold=draw(st.integers(min_value=2, max_value=10000))
            if draw(st.booleans())
            else None,
            resize_factor=draw(st.floats(min_value=1.0, max_value=5.0))
            if draw(st.booleans())
            else None,
        )
    elif index_choice == "spann":
        spann = SpannIndexConfig(
            search_nprobe=draw(st.integers(min_value=1, max_value=128))
            if draw(st.booleans())
            else None,
            write_nprobe=draw(st.integers(min_value=1, max_value=64))
            if draw(st.booleans())
            else None,
            ef_construction=draw(st.integers(min_value=1, max_value=200))
            if draw(st.booleans())
            else None,
            ef_search=draw(st.integers(min_value=1, max_value=200))
            if draw(st.booleans())
            else None,
            max_neighbors=draw(st.integers(min_value=1, max_value=64))
            if draw(st.booleans())
            else None,
            reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64))
            if draw(st.booleans())
            else None,
            split_threshold=draw(st.integers(min_value=50, max_value=200))
            if draw(st.booleans())
            else None,
            merge_threshold=draw(st.integers(min_value=25, max_value=100))
            if draw(st.booleans())
            else None,
        )

    return VectorIndexConfig(
        space=cast(Space, space),
        embedding_function=embedding_function,
        source_key=source_key,
        hnsw=hnsw,
        spann=spann,
    )

@st.composite
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
    """Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25."""
    embedding_function = None
    source_key = None
    bm25 = None

    if draw(st.booleans()):
        embedding_function = DeterministicSparseEmbeddingFunction()
        source_key = draw(st.one_of(st.just("#document"), safe_text))

    if draw(st.booleans()):
        bm25 = draw(st.booleans())

    return SparseVectorIndexConfig(
        embedding_function=embedding_function,
        source_key=source_key,
        bm25=bm25,
    )


@st.composite
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
    """Generate a Schema object with various create_index/delete_index operations."""
    if draw(st.booleans()):
        return None

    schema = Schema()

    # Decide how many operations to perform
    num_operations = draw(st.integers(min_value=0, max_value=5))
    sparse_index_created = False

    for _ in range(num_operations):
        operation = draw(st.sampled_from(["create_index", "delete_index"]))
        config_type = draw(
            st.sampled_from(
                [
                    "string_inverted",
                    "int_inverted",
                    "float_inverted",
                    "bool_inverted",
                    "vector",
                    "sparse_vector",
                ]
            )
        )

        # Decide if we're setting on a key or globally
        use_key = draw(st.booleans())
        key = None
        if use_key and config_type != "vector":
            # Vector indexes can't be set on specific keys, only globally
            key = draw(safe_text)

        if operation == "create_index":
            if config_type == "string_inverted":
                schema.create_index(config=StringInvertedIndexConfig(), key=key)
            elif config_type == "int_inverted":
                schema.create_index(config=IntInvertedIndexConfig(), key=key)
            elif config_type == "float_inverted":
                schema.create_index(config=FloatInvertedIndexConfig(), key=key)
            elif config_type == "bool_inverted":
                schema.create_index(config=BoolInvertedIndexConfig(), key=key)
            elif config_type == "vector":
                vector_config = draw(vector_index_config_strategy())
                schema.create_index(config=vector_config, key=None)
            elif (
                config_type == "sparse_vector"
                and not is_spann_disabled_mode
                and not sparse_index_created
            ):
                sparse_config = draw(sparse_vector_index_config_strategy())
                # Sparse vector MUST have a key
                if key is None:
                    key = draw(safe_text)
                schema.create_index(config=sparse_config, key=key)
                sparse_index_created = True

        elif operation == "delete_index":
            if config_type == "string_inverted":
                schema.delete_index(config=StringInvertedIndexConfig(), key=key)
            elif config_type == "int_inverted":
                schema.delete_index(config=IntInvertedIndexConfig(), key=key)
            elif config_type == "float_inverted":
                schema.delete_index(config=FloatInvertedIndexConfig(), key=key)
            elif config_type == "bool_inverted":
                schema.delete_index(config=BoolInvertedIndexConfig(), key=key)
            # Vector, FTS, and sparse_vector deletion is not currently supported

    return schema

@st.composite
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
    """Generate metadata with hnsw parameters."""
    metadata: CollectionMetadata = {}

    if draw(st.booleans()):
        metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
    if draw(st.booleans()):
        metadata["hnsw:construction_ef"] = draw(
            st.integers(min_value=1, max_value=1000)
        )
    if draw(st.booleans()):
        metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000))
    if draw(st.booleans()):
        metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000))
    if draw(st.booleans()):
        metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
    if draw(st.booleans()):
        metadata["hnsw:sync_threshold"] = draw(
            st.integers(min_value=2, max_value=10000)
        )

    return metadata if metadata else None


@st.composite
def create_configuration_strategy(
    draw: st.DrawFn,
) -> Optional[CreateCollectionConfiguration]:
    """Generate CreateCollectionConfiguration with mutual exclusivity rules."""
    configuration: CreateCollectionConfiguration = {}

    # Optionally set embedding_function (independent)
    if draw(st.booleans()):
        configuration["embedding_function"] = SimpleIpEmbeddingFunction(
            dim=draw(st.integers(min_value=1, max_value=1000))
        )

    # Decide: set space only, OR set hnsw config, OR set spann config
    config_choice = draw(
        st.sampled_from(
            ["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
        )
    )

    if config_choice == "space_only_hnsw":
        configuration["hnsw"] = CreateHNSWConfiguration(
            space=draw(st.sampled_from(["cosine", "l2", "ip"]))
        )
    elif config_choice == "space_only_spann":
        configuration["spann"] = CreateSpannConfiguration(
            space=draw(st.sampled_from(["cosine", "l2", "ip"]))
        )
    elif config_choice == "hnsw":
        # Set hnsw config (optionally with space)
        hnsw_config: CreateHNSWConfiguration = {}
        if draw(st.booleans()):
            hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
        hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000))
        hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000))
        hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000))
        hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000))
        hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
        configuration["hnsw"] = hnsw_config
    elif config_choice == "spann":
        # Set spann config (optionally with space)
        spann_config: CreateSpannConfiguration = {}
        if draw(st.booleans()):
            spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
        spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128))
        spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64))
        spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200))
        spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200))
        spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64))
        spann_config["reassign_neighbor_count"] = draw(
            st.integers(min_value=1, max_value=64)
        )
        spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200))
        spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100))
        configuration["spann"] = spann_config

    return configuration if configuration else None

@dataclass
class CollectionInputCombination:
    """
    Input tuple for collection creation tests.
    """

    metadata: Optional[CollectionMetadata]
    configuration: Optional[CreateCollectionConfiguration]
    schema: Optional[Schema]
    schema_vector_info: Optional[Dict[str, Any]]
    kind: str


def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
    return {k: v for k, v in items.items() if v is not None}


def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
    embedding_default_space: Optional[str] = None
    if config.embedding_function is not None and hasattr(
        config.embedding_function, "default_space"
    ):
        embedding_default_space = cast(str, config.embedding_function.default_space())

    return {
        "space": config.space,
        "hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None,
        "spann": config.spann.model_dump(exclude_none=True) if config.spann else None,
        "embedding_function_default_space": embedding_default_space,
    }

@st.composite
def _schema_input_strategy(
    draw: st.DrawFn,
) -> Tuple[Schema, Dict[str, Any]]:
    schema = Schema()
    vector_config = draw(vector_index_config_strategy())
    schema.create_index(config=vector_config, key=None)
    return schema, vector_index_to_dict(vector_config)


@st.composite
def metadata_configuration_schema_strategy(
    draw: st.DrawFn,
) -> CollectionInputCombination:
    """
    Generate compatible combinations of metadata, configuration, and schema inputs.
    """

    choice = draw(
        st.sampled_from(
            [
                "none",
                "metadata",
                "configuration",
                "metadata_configuration",
                "schema",
            ]
        )
    )

    metadata: Optional[CollectionMetadata] = None
    configuration: Optional[CreateCollectionConfiguration] = None
    schema: Optional[Schema] = None
    schema_info: Optional[Dict[str, Any]] = None

    if choice in ("metadata", "metadata_configuration"):
        metadata = draw(
            metadata_with_hnsw_strategy().filter(
                lambda value: value is not None and len(value) > 0
            )
        )

    if choice in ("configuration", "metadata_configuration"):
        configuration = draw(
            create_configuration_strategy().filter(
                lambda value: value is not None
                and (
                    (value.get("hnsw") is not None and len(value["hnsw"]) > 0)
                    or (value.get("spann") is not None and len(value["spann"]) > 0)
                )
            )
        )

    if choice == "schema":
        schema, schema_info = draw(_schema_input_strategy())

    return CollectionInputCombination(
        metadata=metadata,
        configuration=configuration,
        schema=schema,
        schema_vector_info=schema_info,
        kind=choice,
    )


@dataclass
class Collection(ExternalCollection):
    """
@@ -344,7 +722,7 @@ def collections(
        spann_config: CreateSpannConfiguration = {
            "space": spann_space,
            "write_nprobe": 4,
            "reassign_neighbor_count": 4
            "reassign_neighbor_count": 4,
        }
        collection_config = {
            "spann": spann_config,
@@ -395,7 +773,7 @@ def collections(
        known_document_keywords=known_document_keywords,
        has_embeddings=has_embeddings,
        embedding_function=embedding_function,
        collection_config=collection_config
        collection_config=collection_config,
    )


@@ -421,7 +799,9 @@ def metadata(
        del metadata[key]  # type: ignore
    # Finally, add in some of the known keys for the collection
    sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
        k: st.just(v) for k, v in collection.known_metadata_keys.items()
        k: st.just(v)
        for k, v in collection.known_metadata_keys.items()
        if isinstance(v, (str, int, float))
    }
    metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict)))  # type: ignore
    # We don't allow submitting empty metadata

@@ -31,7 +31,7 @@ collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="co
        normal=hypothesis.settings(max_examples=500),
        fast=hypothesis.settings(max_examples=200),
    ),
    max_examples=2
    max_examples=2,
)
def test_add_miniscule(
    client: ClientAPI,
@@ -332,7 +332,8 @@ def test_out_of_order_ids(client: ClientAPI) -> None:
    ]

    coll = client.create_collection(
        "test", embedding_function=lambda input: [[1, 2, 3] for _ in input]  # type: ignore
        "test",
        embedding_function=lambda input: [[1, 2, 3] for _ in input],  # type: ignore
    )
    embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
    coll.add(ids=ooo_ids, embeddings=embeddings)
@@ -369,3 +370,155 @@ def test_add_partial(client: ClientAPI) -> None:
    assert results["ids"] == ["1", "2", "3"]
    assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
    assert results["documents"] == ["a", "b", None]


@pytest.mark.skipif(
    NOT_CLUSTER_ONLY,
    reason="GroupBy is only supported in distributed mode",
)
def test_search_group_by(client: ClientAPI) -> None:
    """Test GroupBy with single key, multiple keys, and multiple ranking keys."""
    from chromadb.execution.expression.operator import GroupBy, MinK, Key
    from chromadb.execution.expression.plan import Search
    from chromadb.execution.expression import Knn

    create_isolated_database(client)

    coll = client.create_collection(name="test_group_by")

    # Test data: 12 records across 3 categories and 2 years
    # Embeddings are designed so science docs are closest to query [1,0,0,0]
    ids = [
        "sci_2023_1",
        "sci_2023_2",
        "sci_2024_1",
        "sci_2024_2",
        "tech_2023_1",
        "tech_2023_2",
        "tech_2024_1",
        "tech_2024_2",
        "arts_2023_1",
        "arts_2023_2",
        "arts_2024_1",
        "arts_2024_2",
    ]
    embeddings = cast(
        Embeddings,
        [
            # Science - closest to [1,0,0,0]
            [1.0, 0.0, 0.0, 0.0],  # sci_2023_1: score ~0.0
            [0.9, 0.1, 0.0, 0.0],  # sci_2023_2: score ~0.141
            [0.8, 0.2, 0.0, 0.0],  # sci_2024_1: score ~0.283
            [0.7, 0.3, 0.0, 0.0],  # sci_2024_2: score ~0.424
            # Tech - farther from [1,0,0,0]
            [0.0, 1.0, 0.0, 0.0],  # tech_2023_1: score ~1.414
            [0.0, 0.9, 0.1, 0.0],  # tech_2023_2: score ~1.345
            [0.0, 0.8, 0.2, 0.0],  # tech_2024_1: score ~1.281
            [0.0, 0.7, 0.3, 0.0],  # tech_2024_2: score ~1.221
            # Arts - farther from [1,0,0,0]
            [0.0, 0.0, 1.0, 0.0],  # arts_2023_1: score ~1.414
            [0.0, 0.0, 0.9, 0.1],  # arts_2023_2: score ~1.345
            [0.0, 0.0, 0.8, 0.2],  # arts_2024_1: score ~1.281
            [0.0, 0.0, 0.7, 0.3],  # arts_2024_2: score ~1.221
        ],
    )
    metadatas: Metadatas = [
        {"category": "science", "year": 2023, "priority": 1},
        {"category": "science", "year": 2023, "priority": 2},
        {"category": "science", "year": 2024, "priority": 1},
        {"category": "science", "year": 2024, "priority": 3},
        {"category": "tech", "year": 2023, "priority": 2},
        {"category": "tech", "year": 2023, "priority": 1},
        {"category": "tech", "year": 2024, "priority": 1},
        {"category": "tech", "year": 2024, "priority": 2},
        {"category": "arts", "year": 2023, "priority": 3},
        {"category": "arts", "year": 2023, "priority": 1},
        {"category": "arts", "year": 2024, "priority": 2},
        {"category": "arts", "year": 2024, "priority": 1},
    ]
    documents = [f"doc_{id}" for id in ids]

    coll.add(
        ids=ids,
        embeddings=embeddings,
        metadatas=metadatas,
        documents=documents,
    )

    query = [1.0, 0.0, 0.0, 0.0]

    # Test 1: Single key grouping - top 2 per category by score
    # Expected: 2 best from each category (science, tech, arts)
    # - science: sci_2023_1 (0.0), sci_2023_2 (0.141)
    # - tech: tech_2024_2 (1.221), tech_2024_1 (1.281)
    # - arts: arts_2024_2 (1.221), arts_2024_1 (1.281)
    results1 = coll.search(
        Search()
        .rank(Knn(query=query, limit=12))
        .group_by(GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2)))
        .limit(12)
    )
    assert results1["ids"] is not None
    result1_ids = results1["ids"][0]
    assert len(result1_ids) == 6
    expected1 = {
        "sci_2023_1",
        "sci_2023_2",
        "tech_2024_2",
        "tech_2024_1",
        "arts_2024_2",
        "arts_2024_1",
    }
    assert set(result1_ids) == expected1

    # Test 2: Multiple key grouping - top 1 per (category, year) combination
    # 6 groups: (science,2023), (science,2024), (tech,2023), (tech,2024), (arts,2023), (arts,2024)
    results2 = coll.search(
        Search()
        .rank(Knn(query=query, limit=12))
        .group_by(
            GroupBy(
                keys=[Key("category"), Key("year")],
                aggregate=MinK(keys=Key.SCORE, k=1),
            )
        )
        .limit(12)
    )
    assert results2["ids"] is not None
    result2_ids = results2["ids"][0]
    assert len(result2_ids) == 6
    expected2 = {
        "sci_2023_1",
        "sci_2024_1",
        "tech_2023_2",
        "tech_2024_2",
        "arts_2023_2",
        "arts_2024_2",
    }
    assert set(result2_ids) == expected2

    # Test 3: Multiple ranking keys - priority first, then score as tiebreaker
    # Top 2 per category, sorted by priority (ascending), then score (ascending)
    results3 = coll.search(
        Search()
        .rank(Knn(query=query, limit=12))
        .group_by(
            GroupBy(
                keys=Key("category"),
                aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
            )
        )
        .limit(12)
    )
    assert results3["ids"] is not None
    result3_ids = results3["ids"][0]
    assert len(result3_ids) == 6
    expected3 = {
        "sci_2023_1",
        "sci_2024_1",
        "tech_2024_1",
        "tech_2023_2",
        "arts_2024_2",
        "arts_2023_2",
    }
    assert set(result3_ids) == expected3
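For orientation (not part of this commit): a strategy such as metadata_configuration_schema_strategy above is normally consumed through Hypothesis's @given decorator. The sketch below shows one possible shape of such a property-based test; the test name, the client fixture wiring, and the schema= keyword on create_collection are illustrative assumptions, not confirmed API from this diff.

# Illustrative sketch only; assumes a pytest `client` fixture providing a
# chromadb ClientAPI, as used elsewhere in this test suite.
import uuid
from hypothesis import given, settings

@given(combo=metadata_configuration_schema_strategy())
@settings(max_examples=10)
def test_collection_input_combinations_sketch(client, combo) -> None:
    # Each CollectionInputCombination carries at most one input style
    # (metadata, configuration, or schema), mirroring the generation rules above.
    coll = client.create_collection(
        name=f"combo_{uuid.uuid4().hex[:8]}",
        metadata=combo.metadata,
        configuration=combo.configuration,
        schema=combo.schema,  # keyword assumed here for illustration
    )
    assert coll is not None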