增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,6 +1,6 @@
import orjson
import logging
from typing import Any, Dict, Optional, cast, Tuple, List
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
from typing import Sequence
from uuid import UUID
import httpx
@@ -79,8 +79,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
default_api_path=system.settings.chroma_server_api_default_path,
)
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._session = httpx.Client(timeout=None, limits=limits)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(
timeout=None,
limits=self.http_limits,
verify=self._settings.chroma_server_ssl_verify,
)
else:
self._session = httpx.Client(timeout=None, limits=self.http_limits)
self._header = system.settings.chroma_server_headers or {}
self._header["Content-Type"] = "application/json"
@@ -90,8 +96,6 @@ class FastAPI(BaseHTTPClient, ServerAPI):
+ " (https://github.com/chroma-core/chroma)"
)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
if self._header is not None:
self._session.headers.update(self._header)
@@ -101,6 +105,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
for header, value in _headers.items():
self._session.headers[header] = value.get_secret_value()
@override
def get_request_headers(self) -> Mapping[str, str]:
return dict(self._session.headers)
@override
def get_api_url(self) -> str:
return self._api_url
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
# If the request has json in kwargs, use orjson to serialize it,
# remove it from kwargs, and add it to the content parameter
@@ -492,7 +504,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
@@ -700,7 +712,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
@@ -761,7 +773,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
) -> Tuple["AttachedFunction", bool]:
"""Attach a function to a collection."""
resp_json = self._make_request(
"post",
@@ -774,23 +786,56 @@ class FastAPI(BaseHTTPClient, ServerAPI):
},
)
return AttachedFunction(
attached_function = AttachedFunction(
client=self,
id=UUID(resp_json["attached_function"]["id"]),
name=resp_json["attached_function"]["name"],
function_id=resp_json["attached_function"]["function_id"],
function_name=resp_json["attached_function"]["function_name"],
input_collection_id=input_collection_id,
output_collection=output_collection,
params=params,
tenant=tenant,
database=database,
)
created = resp_json.get(
"created", True
) # Default to True for backwards compatibility
return (attached_function, created)
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
@override
def get_attached_function(
self,
name: str,
input_collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> "AttachedFunction":
"""Get an attached function by name for a specific collection."""
resp_json = self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/{name}",
)
af = resp_json["attached_function"]
return AttachedFunction(
client=self,
id=UUID(af["id"]),
name=af["name"],
function_name=af["function_name"],
input_collection_id=input_collection_id,
output_collection=af["output_collection"],
params=af.get("params"),
tenant=tenant,
database=database,
)
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
@override
def detach_function(
self,
attached_function_id: UUID,
name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
@@ -798,7 +843,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
"""Detach a function and prevent any further runs."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
json={
"delete_output": delete_output,
},