增加环绕侦察场景适配
This commit is contained in:
@@ -39,6 +39,11 @@ from chromadb.execution.expression.operator import (
|
||||
Sub,
|
||||
Sum,
|
||||
Val,
|
||||
# GroupBy and Aggregate expressions
|
||||
Aggregate,
|
||||
MinK,
|
||||
MaxK,
|
||||
GroupBy,
|
||||
)
|
||||
|
||||
from chromadb.execution.expression.plan import (
|
||||
@@ -87,4 +92,9 @@ __all__ = [
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Val",
|
||||
# GroupBy and Aggregate expressions
|
||||
"Aggregate",
|
||||
"MinK",
|
||||
"MaxK",
|
||||
"GroupBy",
|
||||
]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Set, Any, Union
|
||||
from typing import Optional, List, Dict, Set, Any, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -7,9 +7,11 @@ from chromadb.api.types import (
|
||||
Embeddings,
|
||||
IDs,
|
||||
Include,
|
||||
OneOrMany,
|
||||
SparseVector,
|
||||
TYPE_KEY,
|
||||
SPARSE_VECTOR_TYPE_VALUE,
|
||||
maybe_cast_one_to_many,
|
||||
normalize_embeddings,
|
||||
validate_embeddings,
|
||||
)
|
||||
@@ -1024,7 +1026,7 @@ class Knn(Rank):
|
||||
- A dense vector (list or numpy array)
|
||||
- A sparse vector (SparseVector dict)
|
||||
key: The embedding key to search against. Can be:
|
||||
- "#embedding" (default) - searches the main embedding field
|
||||
- Key.EMBEDDING (default) - searches the main embedding field
|
||||
- A metadata field name (e.g., "my_custom_field") - searches that metadata field
|
||||
limit: Maximum number of results to consider (default: 16)
|
||||
default: Default score for records not in KNN results (default: None)
|
||||
@@ -1054,7 +1056,7 @@ class Knn(Rank):
|
||||
"NDArray[np.float64]",
|
||||
"NDArray[np.int32]",
|
||||
]
|
||||
key: str = "#embedding"
|
||||
key: Union[Key, str] = K.EMBEDDING
|
||||
limit: int = 16
|
||||
default: Optional[float] = None
|
||||
return_rank: bool = False
|
||||
@@ -1069,8 +1071,12 @@ class Knn(Rank):
|
||||
# Convert numpy array to list
|
||||
query_value = query_value.tolist()
|
||||
|
||||
key_value = self.key
|
||||
if isinstance(key_value, Key):
|
||||
key_value = key_value.name
|
||||
|
||||
# Build result dict - only include non-default values to keep JSON clean
|
||||
result = {"query": query_value, "key": self.key, "limit": self.limit}
|
||||
result = {"query": query_value, "key": key_value, "limit": self.limit}
|
||||
|
||||
# Only include optional fields if they're set to non-default values
|
||||
if self.default is not None:
|
||||
@@ -1291,3 +1297,216 @@ class Select:
|
||||
|
||||
# Convert to set while preserving the Key instances
|
||||
return Select(keys=set(key_list))
|
||||
|
||||
|
||||
# GroupBy and Aggregate types for grouping search results
|
||||
|
||||
|
||||
def _keys_to_strings(keys: OneOrMany[Union[Key, str]]) -> List[str]:
|
||||
"""Convert OneOrMany[Key|str] to List[str] for serialization."""
|
||||
keys_list = cast(List[Union[Key, str]], maybe_cast_one_to_many(keys))
|
||||
return [k.name if isinstance(k, Key) else k for k in keys_list]
|
||||
|
||||
|
||||
def _strings_to_keys(keys: Union[List[Any], tuple[Any, ...]]) -> List[Union[Key, str]]:
|
||||
"""Convert List[str] to List[Key] for deserialization."""
|
||||
return [Key(k) if isinstance(k, str) else k for k in keys]
|
||||
|
||||
|
||||
def _parse_k_aggregate(
|
||||
op: str, data: Dict[str, Any]
|
||||
) -> tuple[List[Union[Key, str]], int]:
|
||||
"""Parse common fields for MinK/MaxK from dict.
|
||||
|
||||
Args:
|
||||
op: The operator name (e.g., "$min_k" or "$max_k")
|
||||
data: The dict containing the operator
|
||||
|
||||
Returns:
|
||||
Tuple of (keys, k) where keys is List[Union[Key, str]] and k is int
|
||||
|
||||
Raises:
|
||||
TypeError: If data types are invalid
|
||||
ValueError: If required fields are missing or invalid
|
||||
"""
|
||||
agg_data = data[op]
|
||||
if not isinstance(agg_data, dict):
|
||||
raise TypeError(f"{op} requires a dict, got {type(agg_data).__name__}")
|
||||
if "keys" not in agg_data:
|
||||
raise ValueError(f"{op} requires 'keys' field")
|
||||
if "k" not in agg_data:
|
||||
raise ValueError(f"{op} requires 'k' field")
|
||||
|
||||
keys = agg_data["keys"]
|
||||
if not isinstance(keys, (list, tuple)):
|
||||
raise TypeError(f"{op} keys must be a list, got {type(keys).__name__}")
|
||||
if not keys:
|
||||
raise ValueError(f"{op} keys cannot be empty")
|
||||
|
||||
k = agg_data["k"]
|
||||
if not isinstance(k, int):
|
||||
raise TypeError(f"{op} k must be an integer, got {type(k).__name__}")
|
||||
if k <= 0:
|
||||
raise ValueError(f"{op} k must be positive, got {k}")
|
||||
|
||||
return _strings_to_keys(keys), k
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aggregate:
|
||||
"""Base class for aggregation expressions within groups.
|
||||
|
||||
Aggregations determine which records to keep from each group:
|
||||
- MinK: Keep k records with minimum values (ascending order)
|
||||
- MaxK: Keep k records with maximum values (descending order)
|
||||
|
||||
Examples:
|
||||
# Keep top 3 by score per group (single key)
|
||||
MinK(keys=Key.SCORE, k=3)
|
||||
|
||||
# Keep top 5 by priority, then score as tiebreaker (multiple keys)
|
||||
MinK(keys=[Key("priority"), Key.SCORE], k=5)
|
||||
|
||||
# Keep bottom 2 by score per group
|
||||
MaxK(keys=Key.SCORE, k=2)
|
||||
"""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the Aggregate expression to a dictionary for JSON serialization"""
|
||||
raise NotImplementedError("Subclasses must implement to_dict()")
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "Aggregate":
|
||||
"""Create Aggregate expression from dictionary.
|
||||
|
||||
Supports:
|
||||
- {"$min_k": {"keys": [...], "k": n}} -> MinK(keys=[...], k=n)
|
||||
- {"$max_k": {"keys": [...], "k": n}} -> MaxK(keys=[...], k=n)
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected dict for Aggregate, got {type(data).__name__}")
|
||||
|
||||
if not data:
|
||||
raise ValueError("Aggregate dict cannot be empty")
|
||||
|
||||
if len(data) != 1:
|
||||
raise ValueError(
|
||||
f"Aggregate dict must contain exactly one operator, got {len(data)}"
|
||||
)
|
||||
|
||||
op = next(iter(data.keys()))
|
||||
|
||||
if op == "$min_k":
|
||||
keys, k = _parse_k_aggregate(op, data)
|
||||
return MinK(keys=keys, k=k)
|
||||
elif op == "$max_k":
|
||||
keys, k = _parse_k_aggregate(op, data)
|
||||
return MaxK(keys=keys, k=k)
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregate operator: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinK(Aggregate):
|
||||
"""Keep k records with minimum aggregate key values per group"""
|
||||
|
||||
keys: OneOrMany[Union[Key, str]]
|
||||
k: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {"$min_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaxK(Aggregate):
|
||||
"""Keep k records with maximum aggregate key values per group"""
|
||||
|
||||
keys: OneOrMany[Union[Key, str]]
|
||||
k: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {"$max_k": {"keys": _keys_to_strings(self.keys), "k": self.k}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupBy:
|
||||
"""Group results by metadata keys and aggregate within each group.
|
||||
|
||||
Groups search results by one or more metadata fields, then applies an
|
||||
aggregation (MinK or MaxK) to select records within each group.
|
||||
The final output is flattened and sorted by score.
|
||||
|
||||
Args:
|
||||
keys: Metadata key(s) to group by. Can be a single key or a list of keys.
|
||||
E.g., Key("category") or [Key("category"), Key("author")]
|
||||
aggregate: Aggregation to apply within each group (MinK or MaxK)
|
||||
|
||||
Note: Both keys and aggregate must be specified together.
|
||||
|
||||
Examples:
|
||||
# Top 3 documents per category (single key)
|
||||
GroupBy(
|
||||
keys=Key("category"),
|
||||
aggregate=MinK(keys=Key.SCORE, k=3)
|
||||
)
|
||||
|
||||
# Top 2 per (year, category) combination (multiple keys)
|
||||
GroupBy(
|
||||
keys=[Key("year"), Key("category")],
|
||||
aggregate=MinK(keys=Key.SCORE, k=2)
|
||||
)
|
||||
|
||||
# Top 1 per category by priority, score as tiebreaker
|
||||
GroupBy(
|
||||
keys=Key("category"),
|
||||
aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=1)
|
||||
)
|
||||
"""
|
||||
|
||||
keys: OneOrMany[Union[Key, str]] = field(default_factory=list)
|
||||
aggregate: Optional[Aggregate] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the GroupBy to a dictionary for JSON serialization"""
|
||||
# Default GroupBy (no keys, no aggregate) serializes to {}
|
||||
if not self.keys or self.aggregate is None:
|
||||
return {}
|
||||
result: Dict[str, Any] = {"keys": _keys_to_strings(self.keys)}
|
||||
result["aggregate"] = self.aggregate.to_dict()
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "GroupBy":
|
||||
"""Create GroupBy from dictionary.
|
||||
|
||||
Examples:
|
||||
- {} -> GroupBy() (default, no grouping)
|
||||
- {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected dict for GroupBy, got {type(data).__name__}")
|
||||
|
||||
# Empty dict returns default GroupBy (no grouping)
|
||||
if not data:
|
||||
return GroupBy()
|
||||
|
||||
# Non-empty dict requires keys and aggregate
|
||||
if "keys" not in data:
|
||||
raise ValueError("GroupBy requires 'keys' field")
|
||||
if "aggregate" not in data:
|
||||
raise ValueError("GroupBy requires 'aggregate' field")
|
||||
|
||||
keys = data["keys"]
|
||||
if not isinstance(keys, (list, tuple)):
|
||||
raise TypeError(f"GroupBy keys must be a list, got {type(keys).__name__}")
|
||||
if not keys:
|
||||
raise ValueError("GroupBy keys cannot be empty")
|
||||
|
||||
aggregate_data = data["aggregate"]
|
||||
if not isinstance(aggregate_data, dict):
|
||||
raise TypeError(
|
||||
f"GroupBy aggregate must be a dict, got {type(aggregate_data).__name__}"
|
||||
)
|
||||
aggregate = Aggregate.from_dict(aggregate_data)
|
||||
|
||||
return GroupBy(keys=_strings_to_keys(keys), aggregate=aggregate)
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import List, Dict, Any, Union, Set, Optional
|
||||
from chromadb.execution.expression.operator import (
|
||||
KNN,
|
||||
Filter,
|
||||
GroupBy,
|
||||
Limit,
|
||||
Projection,
|
||||
Scan,
|
||||
Rank,
|
||||
Select,
|
||||
Val,
|
||||
Where,
|
||||
Key,
|
||||
)
|
||||
@@ -77,9 +77,18 @@ class Search:
|
||||
Combined with metadata filtering:
|
||||
Search().where((Key.ID.is_in(["id1", "id2"])) & (Key("status") == "active"))
|
||||
|
||||
With group_by:
|
||||
(Search()
|
||||
.rank(Knn(query=[0.1, 0.2]))
|
||||
.group_by(GroupBy(
|
||||
keys=[Key("category")],
|
||||
aggregate=MinK(keys=[Key.SCORE], k=3)
|
||||
)))
|
||||
|
||||
Empty Search() is valid and will use defaults:
|
||||
- where: None (no filtering)
|
||||
- rank: None (no ranking - results ordered by default order)
|
||||
- group_by: None (no grouping)
|
||||
- limit: No limit
|
||||
- select: Empty selection
|
||||
"""
|
||||
@@ -88,6 +97,7 @@ class Search:
|
||||
self,
|
||||
where: Optional[Union[Where, Dict[str, Any]]] = None,
|
||||
rank: Optional[Union[Rank, Dict[str, Any]]] = None,
|
||||
group_by: Optional[Union[GroupBy, Dict[str, Any]]] = None,
|
||||
limit: Optional[Union[Limit, Dict[str, Any], int]] = None,
|
||||
select: Optional[Union[Select, Dict[str, Any], List[str], Set[str]]] = None,
|
||||
):
|
||||
@@ -99,11 +109,13 @@ class Search:
|
||||
rank: Rank expression or dict for scoring (defaults to None - no ranking)
|
||||
Dict will be converted using Rank.from_dict()
|
||||
Note: Primitive numbers are not accepted - use {"$val": number} for constant ranks
|
||||
group_by: GroupBy configuration for grouping and aggregating results (defaults to None)
|
||||
Dict will be converted using GroupBy.from_dict()
|
||||
limit: Limit configuration for pagination (defaults to no limit)
|
||||
Can be a Limit object, a dict for Limit.from_dict(), or an int
|
||||
When passing an int, it creates Limit(limit=value, offset=0)
|
||||
select: Select configuration for keys (defaults to empty selection)
|
||||
Can be a Select object, a dict for Select.from_dict(),
|
||||
Can be a Select object, a dict for Select.from_dict(),
|
||||
or a list/set of strings (e.g., ["#document", "#score"])
|
||||
"""
|
||||
# Handle where parameter
|
||||
@@ -117,7 +129,7 @@ class Search:
|
||||
raise TypeError(
|
||||
f"where must be a Where object, dict, or None, got {type(where).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# Handle rank parameter
|
||||
if rank is None:
|
||||
self._rank = None
|
||||
@@ -129,7 +141,19 @@ class Search:
|
||||
raise TypeError(
|
||||
f"rank must be a Rank object, dict, or None, got {type(rank).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# Handle group_by parameter
|
||||
if group_by is None:
|
||||
self._group_by = GroupBy()
|
||||
elif isinstance(group_by, GroupBy):
|
||||
self._group_by = group_by
|
||||
elif isinstance(group_by, dict):
|
||||
self._group_by = GroupBy.from_dict(group_by)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"group_by must be a GroupBy object, dict, or None, got {type(group_by).__name__}"
|
||||
)
|
||||
|
||||
# Handle limit parameter
|
||||
if limit is None:
|
||||
self._limit = Limit()
|
||||
@@ -143,7 +167,7 @@ class Search:
|
||||
raise TypeError(
|
||||
f"limit must be a Limit object, dict, int, or None, got {type(limit).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# Handle select parameter
|
||||
if select is None:
|
||||
self._select = Select()
|
||||
@@ -164,6 +188,7 @@ class Search:
|
||||
return {
|
||||
"filter": self._where.to_dict() if self._where is not None else None,
|
||||
"rank": self._rank.to_dict() if self._rank is not None else None,
|
||||
"group_by": self._group_by.to_dict(),
|
||||
"limit": self._limit.to_dict(),
|
||||
"select": self._select.to_dict(),
|
||||
}
|
||||
@@ -173,7 +198,11 @@ class Search:
|
||||
"""Select all predefined keys (document, embedding, metadata, score)"""
|
||||
new_select = Select(keys={Key.DOCUMENT, Key.EMBEDDING, Key.METADATA, Key.SCORE})
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=self._limit, select=new_select
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=new_select,
|
||||
)
|
||||
|
||||
def select(self, *keys: Union[Key, str]) -> "Search":
|
||||
@@ -187,7 +216,11 @@ class Search:
|
||||
"""
|
||||
new_select = Select(keys=set(keys))
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=self._limit, select=new_select
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=new_select,
|
||||
)
|
||||
|
||||
def where(self, where: Optional[Union[Where, Dict[str, Any]]]) -> "Search":
|
||||
@@ -202,20 +235,12 @@ class Search:
|
||||
search.where({"status": "active"})
|
||||
search.where({"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
|
||||
"""
|
||||
# Convert dict to Where if needed
|
||||
if where is None:
|
||||
converted_where = None
|
||||
elif isinstance(where, Where):
|
||||
converted_where = where
|
||||
elif isinstance(where, dict):
|
||||
converted_where = Where.from_dict(where)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"where must be a Where object, dict, or None, got {type(where).__name__}"
|
||||
)
|
||||
|
||||
return Search(
|
||||
where=converted_where, rank=self._rank, limit=self._limit, select=self._select
|
||||
where=where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
def rank(self, rank_expr: Optional[Union[Rank, Dict[str, Any]]]) -> "Search":
|
||||
@@ -231,20 +256,37 @@ class Search:
|
||||
search.rank({"$knn": {"query": [0.1, 0.2]}})
|
||||
search.rank({"$sum": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.5}]})
|
||||
"""
|
||||
# Convert dict to Rank if needed
|
||||
if rank_expr is None:
|
||||
converted_rank = None
|
||||
elif isinstance(rank_expr, Rank):
|
||||
converted_rank = rank_expr
|
||||
elif isinstance(rank_expr, dict):
|
||||
converted_rank = Rank.from_dict(rank_expr)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"rank_expr must be a Rank object, dict, or None, got {type(rank_expr).__name__}"
|
||||
)
|
||||
|
||||
return Search(
|
||||
where=self._where, rank=converted_rank, limit=self._limit, select=self._select
|
||||
where=self._where,
|
||||
rank=rank_expr,
|
||||
group_by=self._group_by,
|
||||
limit=self._limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
def group_by(self, group_by: Optional[Union[GroupBy, Dict[str, Any]]]) -> "Search":
|
||||
"""Set the group_by configuration for grouping and aggregating results
|
||||
|
||||
Args:
|
||||
group_by: A GroupBy object, dict, or None for grouping
|
||||
Dicts will be converted using GroupBy.from_dict()
|
||||
|
||||
Example:
|
||||
search.group_by(GroupBy(
|
||||
keys=[Key("category")],
|
||||
aggregate=MinK(keys=[Key.SCORE], k=3)
|
||||
))
|
||||
search.group_by({
|
||||
"keys": ["category"],
|
||||
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}
|
||||
})
|
||||
"""
|
||||
return Search(
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=group_by,
|
||||
limit=self._limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
def limit(self, limit: int, offset: int = 0) -> "Search":
|
||||
@@ -259,5 +301,9 @@ class Search:
|
||||
"""
|
||||
new_limit = Limit(offset=offset, limit=limit)
|
||||
return Search(
|
||||
where=self._where, rank=self._rank, limit=new_limit, select=self._select
|
||||
where=self._where,
|
||||
rank=self._rank,
|
||||
group_by=self._group_by,
|
||||
limit=new_limit,
|
||||
select=self._select,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user