Add adaptation for the orbit reconnaissance scenario

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions


@@ -39,6 +39,11 @@ from chromadb.execution.expression.operator import (
Sub,
Sum,
Val,
# GroupBy and Aggregate expressions
Aggregate,
MinK,
MaxK,
GroupBy,
)
from chromadb.execution.expression.plan import (
@@ -87,4 +92,9 @@ __all__ = [
"Sub",
"Sum",
"Val",
# GroupBy and Aggregate expressions
"Aggregate",
"MinK",
"MaxK",
"GroupBy",
]
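The re-exports above make the new grouping types part of the expression package's public surface. A minimal usage sketch, not part of the diff, assuming the import path given in the hunk header above and the docstring examples shown further down:

    from chromadb.execution.expression.operator import GroupBy, Key, MinK

    # Keep the top 3 results per "category" group, mirroring the docstring examples.
    group = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))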


@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Set, Any, Union
from typing import Optional, List, Dict, Set, Any, Union, cast
import numpy as np
from numpy.typing import NDArray
@@ -7,9 +7,11 @@ from chromadb.api.types import (
Embeddings,
IDs,
Include,
OneOrMany,
SparseVector,
TYPE_KEY,
SPARSE_VECTOR_TYPE_VALUE,
maybe_cast_one_to_many,
normalize_embeddings,
validate_embeddings,
)
@@ -1024,7 +1026,7 @@ class Knn(Rank):
- A dense vector (list or numpy array)
- A sparse vector (SparseVector dict)
key: The embedding key to search against. Can be:
- "#embedding" (default) - searches the main embedding field
- Key.EMBEDDING (default) - searches the main embedding field
- A metadata field name (e.g., "my_custom_field") - searches that metadata field
limit: Maximum number of results to consider (default: 16)
default: Default score for records not in KNN results (default: None)
@@ -1054,7 +1056,7 @@ class Knn(Rank):
"NDArray[np.float64]",
"NDArray[np.int32]",
]
key: str = "#embedding"
key: Union[Key, str] = K.EMBEDDING
limit: int = 16
default: Optional[float] = None
return_rank: bool = False
@@ -1069,8 +1071,12 @@ class Knn(Rank):
# Convert numpy array to list
query_value = query_value.tolist()
key_value = self.key
if isinstance(key_value, Key):
key_value = key_value.name
# Build result dict - only include non-default values to keep JSON clean
result = {"query": query_value, "key": self.key, "limit": self.limit}
result = {"query": query_value, "key": key_value, "limit": self.limit}
# Only include optional fields if they're set to non-default values
if self.default is not None:
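A quick illustration of the key change above (a sketch, not taken from the diff): key now accepts either a Key constant or a metadata field name, and serialization normalizes a Key to its .name string before the result dict is built.

    # Assumes Knn and Key come from the same module as this hunk.
    knn_default = Knn(query=[0.1, 0.2, 0.3])                        # key defaults to the main embedding field
    knn_custom = Knn(query=[0.1, 0.2, 0.3], key="my_custom_field")  # embedding stored under a metadata field
    # Both produce a plain-string "key" entry in the serialized payload, since a Key
    # instance is converted via key_value.name before the dict is assembled.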
@@ -1291,3 +1297,216 @@ class Select:
# Convert to set while preserving the Key instances
return Select(keys=set(key_list))
# GroupBy and Aggregate types for grouping search results
def _keys_to_strings(keys: OneOrMany[Union[Key, str]]) -> List[str]:
"""Convert OneOrMany[Key|str] to List[str] for serialization."""
keys_list = cast(List[Union[Key, str]], maybe_cast_one_to_many(keys))
return [k.name if isinstance(k, Key) else k for k in keys_list]
def _strings_to_keys(keys: Union[List[Any], tuple[Any, ...]]) -> List[Union[Key, str]]:
"""Convert List[str] to List[Key] for deserialization."""
return [Key(k) if isinstance(k, str) else k for k in keys]
def _parse_k_aggregate(
op: str, data: Dict[str, Any]
) -> tuple[List[Union[Key, str]], int]:
"""Parse common fields for MinK/MaxK from dict.
Args:
op: The operator name (e.g., "$min_k" or "$max_k")
data: The dict containing the operator
Returns:
Tuple of (keys, k) where keys is List[Union[Key, str]] and k is int
Raises:
TypeError: If data types are invalid
ValueError: If required fields are missing or invalid
"""
agg_data = data[op]
if not isinstance(agg_data, dict):
raise TypeError(f"{op} requires a dict, got {type(agg_data).__name__}")
if "keys" not in agg_data:
raise ValueError(f"{op} requires 'keys' field")
if "k" not in agg_data:
raise ValueError(f"{op} requires 'k' field")
keys = agg_data["keys"]
if not isinstance(keys, (list, tuple)):
raise TypeError(f"{op} keys must be a list, got {type(keys).__name__}")
if not keys:
raise ValueError(f"{op} keys cannot be empty")
k = agg_data["k"]
if not isinstance(k, int):
raise TypeError(f"{op} k must be an integer, got {type(k).__name__}")
if k <= 0:
raise ValueError(f"{op} k must be positive, got {k}")
return _strings_to_keys(keys), k
@dataclass
class Aggregate:
"""Base class for aggregation expressions within groups.
Aggregations determine which records to keep from each group:
- MinK: Keep k records with minimum values (ascending order)
- MaxK: Keep k records with maximum values (descending order)
Examples:
# Keep top 3 by score per group (single key)
MinK(keys=Key.SCORE, k=3)
# Keep top 5 by priority, then score as tiebreaker (multiple keys)
MinK(keys=[Key("priority"), Key.SCORE], k=5)
# Keep bottom 2 by score per group
MaxK(keys=Key.SCORE, k=2)
"""
def to_dict(self) -> Dict[str, Any]:
"""Convert the Aggregate expression to a dictionary for JSON serialization"""
raise NotImplementedError("Subclasses must implement to_dict()")
@staticmethod
def from_dict(data: Dict[str, Any]) -> "Aggregate":
"""Create Aggregate expression from dictionary.
Supports:
- {"$min_k": {"keys": [...], "k": n}} -> MinK(keys=[...], k=n)
- {"$max_k": {"keys": [...], "k": n}} -> MaxK(keys=[...], k=n)
"""
if not isinstance(data, dict):
raise TypeError(f"Expected dict for Aggregate, got {type(data).__name__}")
if not data:
raise ValueError("Aggregate dict cannot be empty")
if len(data) != 1:
raise ValueError(
f"Aggregate dict must contain exactly one operator, got {len(data)}"
)
op = next(iter(data.keys()))
if op == "$min_k":
keys, k = _parse_k_aggregate(op, data)
return MinK(keys=keys, k=k)
elif op == "$max_k":
keys, k = _parse_k_aggregate(op, data)
return MaxK(keys=keys, k=k)
else:
raise ValueError(f"Unknown aggregate operator: {op}")
@dataclass
class MinK(Aggregate):
"""Keep k records with minimum aggregate key values per group"""
keys: OneOrMany[Union[Key, str]]
k: int
def to_dict(self) -> Dict[str, Any]:
return {"$min_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
@dataclass
class MaxK(Aggregate):
"""Keep k records with maximum aggregate key values per group"""
keys: OneOrMany[Union[Key, str]]
k: int
def to_dict(self) -> Dict[str, Any]:
return {"$max_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
@dataclass
class GroupBy:
"""Group results by metadata keys and aggregate within each group.
Groups search results by one or more metadata fields, then applies an
aggregation (MinK or MaxK) to select records within each group.
The final output is flattened and sorted by score.
Args:
keys: Metadata key(s) to group by. Can be a single key or a list of keys.
E.g., Key("category") or [Key("category"), Key("author")]
aggregate: Aggregation to apply within each group (MinK or MaxK)
Note: Both keys and aggregate must be specified together.
Examples:
# Top 3 documents per category (single key)
GroupBy(
keys=Key("category"),
aggregate=MinK(keys=Key.SCORE, k=3)
)
# Top 2 per (year, category) combination (multiple keys)
GroupBy(
keys=[Key("year"), Key("category")],
aggregate=MinK(keys=Key.SCORE, k=2)
)
# Top 1 per category by priority, score as tiebreaker
GroupBy(
keys=Key("category"),
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=1)
)
"""
keys: OneOrMany[Union[Key, str]] = field(default_factory=list)
aggregate: Optional[Aggregate] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert the GroupBy to a dictionary for JSON serialization"""
# Default GroupBy (no keys, no aggregate) serializes to {}
if not self.keys or self.aggregate is None:
return {}
result: Dict[str, Any] = {"keys": _keys_to_strings(self.keys)}
result["aggregate"] = self.aggregate.to_dict()
return result
@staticmethod
def from_dict(data: Dict[str, Any]) -> "GroupBy":
"""Create GroupBy from dictionary.
Examples:
- {} -> GroupBy() (default, no grouping)
- {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
"""
if not isinstance(data, dict):
raise TypeError(f"Expected dict for GroupBy, got {type(data).__name__}")
# Empty dict returns default GroupBy (no grouping)
if not data:
return GroupBy()
# Non-empty dict requires keys and aggregate
if "keys" not in data:
raise ValueError("GroupBy requires 'keys' field")
if "aggregate" not in data:
raise ValueError("GroupBy requires 'aggregate' field")
keys = data["keys"]
if not isinstance(keys, (list, tuple)):
raise TypeError(f"GroupBy keys must be a list, got {type(keys).__name__}")
if not keys:
raise ValueError("GroupBy keys cannot be empty")
aggregate_data = data["aggregate"]
if not isinstance(aggregate_data, dict):
raise TypeError(
f"GroupBy aggregate must be a dict, got {type(aggregate_data).__name__}"
)
aggregate = Aggregate.from_dict(aggregate_data)
return GroupBy(keys=_strings_to_keys(keys), aggregate=aggregate)
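Putting the new types together, the dict shape shown in the docstrings round-trips through to_dict()/from_dict(). A sketch based only on the code above; the "#score" name follows the GroupBy.from_dict docstring example:

    group = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
    payload = group.to_dict()
    # -> {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
    restored = GroupBy.from_dict(payload)   # string keys come back as Key instances
    empty = GroupBy().to_dict()             # -> {}  (default "no grouping" form)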


@@ -4,12 +4,12 @@ from typing import List, Dict, Any, Union, Set, Optional
from chromadb.execution.expression.operator import (
KNN,
Filter,
GroupBy,
Limit,
Projection,
Scan,
Rank,
Select,
Val,
Where,
Key,
)
@@ -77,9 +77,18 @@ class Search:
Combined with metadata filtering:
Search().where((Key.ID.is_in(["id1", "id2"])) & (Key("status") == "active"))
With group_by:
(Search()
.rank(Knn(query=[0.1, 0.2]))
.group_by(GroupBy(
keys=[Key("category")],
aggregate=MinK(keys=[Key.SCORE], k=3)
)))
Empty Search() is valid and will use defaults:
- where: None (no filtering)
- rank: None (no ranking - results ordered by default order)
- group_by: None (no grouping)
- limit: No limit
- select: Empty selection
"""
@@ -88,6 +97,7 @@ class Search:
self,
where: Optional[Union[Where, Dict[str, Any]]] = None,
rank: Optional[Union[Rank, Dict[str, Any]]] = None,
group_by: Optional[Union[GroupBy, Dict[str, Any]]] = None,
limit: Optional[Union[Limit, Dict[str, Any], int]] = None,
select: Optional[Union[Select, Dict[str, Any], List[str], Set[str]]] = None,
):
@@ -99,11 +109,13 @@ class Search:
rank: Rank expression or dict for scoring (defaults to None - no ranking)
Dict will be converted using Rank.from_dict()
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
group_by: GroupBy configuration for grouping and aggregating results (defaults to None)
Dict will be converted using GroupBy.from_dict()
limit: Limit configuration for pagination (defaults to no limit)
Can be a Limit object, a dict for Limit.from_dict(), or an int
When passing an int, it creates Limit(limit=value, offset=0)
select: Select configuration for keys (defaults to empty selection)
Can be a Select object, a dict for Select.from_dict(),
Can be a Select object, a dict for Select.from_dict(),
or a list/set of strings (e.g., ["#document", "#score"])
"""
# Handle where parameter
@@ -117,7 +129,7 @@ class Search:
raise TypeError(
f"where must be a Where object, dict, or None, got {type(where).__name__}"
)
# Handle rank parameter
if rank is None:
self._rank = None
@@ -129,7 +141,19 @@ class Search:
raise TypeError(
f"rank must be a Rank object, dict, or None, got {type(rank).__name__}"
)
# Handle group_by parameter
if group_by is None:
self._group_by = GroupBy()
elif isinstance(group_by, GroupBy):
self._group_by = group_by
elif isinstance(group_by, dict):
self._group_by = GroupBy.from_dict(group_by)
else:
raise TypeError(
f"group_by must be a GroupBy object, dict, or None, got {type(group_by).__name__}"
)
# Handle limit parameter
if limit is None:
self._limit = Limit()
@@ -143,7 +167,7 @@ class Search:
raise TypeError(
f"limit must be a Limit object, dict, int, or None, got {type(limit).__name__}"
)
# Handle select parameter
if select is None:
self._select = Select()
@@ -164,6 +188,7 @@ class Search:
return {
"filter": self._where.to_dict() if self._where is not None else None,
"rank": self._rank.to_dict() if self._rank is not None else None,
"group_by": self._group_by.to_dict(),
"limit": self._limit.to_dict(),
"select": self._select.to_dict(),
}
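For context (a sketch, not part of the diff): because group_by defaults to an empty GroupBy(), an unconfigured search serializes the new field as an empty dict rather than omitting it.

    payload = Search().to_dict()
    assert payload["group_by"] == {}   # default GroupBy() serializes to {}
    assert payload["filter"] is None   # no where clause
    assert payload["rank"] is None     # no ranking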
@@ -173,7 +198,11 @@ class Search:
"""Select all predefined keys (document, embedding, metadata, score)"""
new_select = Select(keys={Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE})
return Search(
where=self._where, rank=self._rank, limit=self._limit, select=new_select
where=self._where,
rank=self._rank,
group_by=self._group_by,
limit=self._limit,
select=new_select,
)
def select(self, *keys: Union[Key, str]) -> "Search":
@@ -187,7 +216,11 @@ class Search:
"""
new_select = Select(keys=set(keys))
return Search(
where=self._where, rank=self._rank, limit=self._limit, select=new_select
where=self._where,
rank=self._rank,
group_by=self._group_by,
limit=self._limit,
select=new_select,
)
def where(self, where: Optional[Union[Where, Dict[str, Any]]]) -> "Search":
@@ -202,20 +235,12 @@ class Search:
search.where({"status": "active"})
search.where({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
"""
# Convert dict to Where if needed
if where is None:
converted_where = None
elif isinstance(where, Where):
converted_where = where
elif isinstance(where, dict):
converted_where = Where.from_dict(where)
else:
raise TypeError(
f"where must be a Where object, dict, or None, got {type(where).__name__}"
)
return Search(
where=converted_where, rank=self._rank, limit=self._limit, select=self._select
where=where,
rank=self._rank,
group_by=self._group_by,
limit=self._limit,
select=self._select,
)
def rank(self, rank_expr: Optional[Union[Rank, Dict[str, Any]]]) -> "Search":
@@ -231,20 +256,37 @@ class Search:
search.rank({"$knn": {"query": [0.1, 0.2]}})
search.rank({"$sum": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.5}]})
"""
# Convert dict to Rank if needed
if rank_expr is None:
converted_rank = None
elif isinstance(rank_expr, Rank):
converted_rank = rank_expr
elif isinstance(rank_expr, dict):
converted_rank = Rank.from_dict(rank_expr)
else:
raise TypeError(
f"rank_expr must be a Rank object, dict, or None, got {type(rank_expr).__name__}"
)
return Search(
where=self._where, rank=converted_rank, limit=self._limit, select=self._select
where=self._where,
rank=rank_expr,
group_by=self._group_by,
limit=self._limit,
select=self._select,
)
def group_by(self, group_by: Optional[Union[GroupBy, Dict[str, Any]]]) -> "Search":
"""Set the group_by configuration for grouping and aggregating results
Args:
group_by: A GroupBy object, dict, or None for grouping
Dicts will be converted using GroupBy.from_dict()
Example:
search.group_by(GroupBy(
keys=[Key("category")],
aggregate=MinK(keys=[Key.SCORE], k=3)
))
search.group_by({
"keys": ["category"],
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}
})
"""
return Search(
where=self._where,
rank=self._rank,
group_by=group_by,
limit=self._limit,
select=self._select,
)
def limit(self, limit: int, offset: int = 0) -> "Search":
@@ -259,5 +301,9 @@ class Search:
"""
new_limit = Limit(offset=offset, limit=limit)
return Search(
where=self._where, rank=self._rank, limit=new_limit, select=self._select
where=self._where,
rank=self._rank,
group_by=self._group_by,
limit=new_limit,
select=self._select,
)
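End to end, the new builder method chains like the existing ones. A hedged sketch of the grouped-search flow described in the Search docstring, assuming Search lives in chromadb.execution.expression.plan and Knn/GroupBy/MinK/Key in chromadb.execution.expression.operator, per the imports shown above:

    search = (
        Search()
        .rank(Knn(query=[0.1, 0.2]))
        .group_by(GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3)))
        .limit(10)
    )
    body = search.to_dict()  # carries a "group_by" entry alongside filter/rank/limit/select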