增加环绕侦察场景适配
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api import ServerAPI # noqa: F401
|
||||
@@ -13,7 +14,7 @@ class AttachedFunction:
|
||||
client: "ServerAPI",
|
||||
id: UUID,
|
||||
name: str,
|
||||
function_id: str,
|
||||
function_name: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]],
|
||||
@@ -26,7 +27,7 @@ class AttachedFunction:
|
||||
client: The API client
|
||||
id: Unique identifier for this attached function
|
||||
name: Name of this attached function instance
|
||||
function_id: The function identifier (e.g., "record_counter")
|
||||
function_name: The function name (e.g., "record_counter", "statistics")
|
||||
input_collection_id: ID of the input collection
|
||||
output_collection: Name of the output collection
|
||||
params: Function-specific parameters
|
||||
@@ -36,7 +37,7 @@ class AttachedFunction:
|
||||
self._client = client
|
||||
self._id = id
|
||||
self._name = name
|
||||
self._function_id = function_id
|
||||
self._function_name = function_name
|
||||
self._input_collection_id = input_collection_id
|
||||
self._output_collection = output_collection
|
||||
self._params = params
|
||||
@@ -54,9 +55,9 @@ class AttachedFunction:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def function_id(self) -> str:
|
||||
"""The function identifier."""
|
||||
return self._function_id
|
||||
def function_name(self) -> str:
|
||||
"""The function name."""
|
||||
return self._function_name
|
||||
|
||||
@property
|
||||
def input_collection_id(self) -> UUID:
|
||||
@@ -73,29 +74,69 @@ class AttachedFunction:
|
||||
"""The function parameters."""
|
||||
return self._params
|
||||
|
||||
def detach(self, delete_output_collection: bool = False) -> bool:
|
||||
"""Detach this function and prevent any further runs.
|
||||
@staticmethod
|
||||
def _normalize_params(params: Optional[Any]) -> Dict[str, Any]:
|
||||
"""Normalize params to a consistent dict format.
|
||||
|
||||
Args:
|
||||
delete_output_collection: Whether to also delete the output collection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
|
||||
Example:
|
||||
>>> success = attached_fn.detach(delete_output_collection=True)
|
||||
Handles None, empty strings, JSON strings, and dicts.
|
||||
"""
|
||||
return self._client.detach_function(
|
||||
attached_function_id=self._id,
|
||||
delete_output=delete_output_collection,
|
||||
tenant=self._tenant,
|
||||
database=self._database,
|
||||
)
|
||||
if params is None:
|
||||
return {}
|
||||
if isinstance(params, str):
|
||||
try:
|
||||
result = json.loads(params) if params else {}
|
||||
return result if isinstance(result, dict) else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
if isinstance(params, dict):
|
||||
return params
|
||||
return {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AttachedFunction(id={self._id}, name='{self._name}', "
|
||||
f"function_id='{self._function_id}', "
|
||||
f"function_name='{self._function_name}', "
|
||||
f"input_collection_id={self._input_collection_id}, "
|
||||
f"output_collection='{self._output_collection}')"
|
||||
)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare two AttachedFunction objects for equality."""
|
||||
if not isinstance(other, AttachedFunction):
|
||||
return False
|
||||
|
||||
# Normalize params: handle None, {}, and JSON strings
|
||||
self_params = self._normalize_params(self._params)
|
||||
other_params = self._normalize_params(other._params)
|
||||
|
||||
return (
|
||||
self._id == other._id
|
||||
and self._name == other._name
|
||||
and self._function_name == other._function_name
|
||||
and self._input_collection_id == other._input_collection_id
|
||||
and self._output_collection == other._output_collection
|
||||
and self_params == other_params
|
||||
and self._tenant == other._tenant
|
||||
and self._database == other._database
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the AttachedFunction."""
|
||||
# Normalize params using the same logic as __eq__
|
||||
normalized_params = self._normalize_params(self._params)
|
||||
params_tuple = (
|
||||
tuple(sorted(normalized_params.items())) if normalized_params else ()
|
||||
)
|
||||
|
||||
return hash(
|
||||
(
|
||||
self._id,
|
||||
self._name,
|
||||
self._function_name,
|
||||
self._input_collection_id,
|
||||
self._output_collection,
|
||||
params_tuple,
|
||||
self._tenant,
|
||||
self._database,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
|
||||
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any, Tuple
|
||||
|
||||
from chromadb.api.models.CollectionCommon import CollectionCommon
|
||||
from chromadb.api.types import (
|
||||
@@ -25,6 +25,8 @@ from chromadb.execution.expression.plan import Search
|
||||
|
||||
import logging
|
||||
|
||||
from chromadb.api.functions import Function
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
|
||||
@@ -500,30 +502,36 @@ class Collection(CollectionCommon["ServerAPI"]):
|
||||
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
function: Function,
|
||||
name: str,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to this collection.
|
||||
|
||||
Args:
|
||||
function_id: Built-in function identifier (e.g., "record_counter")
|
||||
function: A Function enum value (e.g., STATISTICS_FUNCTION, RECORD_COUNTER_FUNCTION)
|
||||
name: Unique name for this attached function
|
||||
output_collection: Name of the collection where function output will be stored
|
||||
params: Optional dictionary with function-specific parameters
|
||||
|
||||
Returns:
|
||||
AttachedFunction: Object representing the attached function
|
||||
Tuple of (AttachedFunction, created) where created is True if newly created,
|
||||
False if already existed (idempotent request)
|
||||
|
||||
Example:
|
||||
>>> from chromadb.api.functions import STATISTICS_FUNCTION
|
||||
>>> attached_fn = collection.attach_function(
|
||||
... function_id="record_counter",
|
||||
... function=STATISTICS_FUNCTION,
|
||||
... name="mycoll_stats_fn",
|
||||
... output_collection="mycoll_stats",
|
||||
... params={"threshold": 100}
|
||||
... )
|
||||
>>> if created:
|
||||
... print("New function attached")
|
||||
... else:
|
||||
... print("Function already existed")
|
||||
"""
|
||||
function_id = function.value if isinstance(function, Function) else function
|
||||
return self._client.attach_function(
|
||||
function_id=function_id,
|
||||
name=name,
|
||||
@@ -533,3 +541,47 @@ class Collection(CollectionCommon["ServerAPI"]):
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def get_attached_function(self, name: str) -> "AttachedFunction":
|
||||
"""Get an attached function by name for this collection.
|
||||
|
||||
Args:
|
||||
name: Name of the attached function
|
||||
|
||||
Returns:
|
||||
AttachedFunction: The attached function object
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the attached function doesn't exist
|
||||
"""
|
||||
return self._client.get_attached_function(
|
||||
name=name,
|
||||
input_collection_id=self.id,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
def detach_function(
|
||||
self,
|
||||
name: str,
|
||||
delete_output_collection: bool = False,
|
||||
) -> bool:
|
||||
"""Detach a function from this collection.
|
||||
|
||||
Args:
|
||||
name: The name of the attached function
|
||||
delete_output_collection: Whether to also delete the output collection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
|
||||
Example:
|
||||
>>> success = collection.detach_function("my_function", delete_output_collection=True)
|
||||
"""
|
||||
return self._client.detach_function(
|
||||
name=name,
|
||||
input_collection_id=self.id,
|
||||
delete_output=delete_output_collection,
|
||||
tenant=self.tenant,
|
||||
database=self.database,
|
||||
)
|
||||
|
||||
@@ -1021,6 +1021,7 @@ class CollectionCommon(Generic[ClientT]):
|
||||
return Search(
|
||||
where=search._where,
|
||||
rank=embedded_rank,
|
||||
group_by=search._group_by,
|
||||
limit=search._limit,
|
||||
select=search._select,
|
||||
)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user