增加环绕侦察场景适配
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import orjson
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, cast, Tuple, List
|
||||
from typing import Any, Dict, Mapping, Optional, cast, Tuple, List
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
import httpx
|
||||
@@ -79,8 +79,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
default_api_path=system.settings.chroma_server_api_default_path,
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._session = httpx.Client(timeout=None, limits=limits)
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(
|
||||
timeout=None,
|
||||
limits=self.http_limits,
|
||||
verify=self._settings.chroma_server_ssl_verify,
|
||||
)
|
||||
else:
|
||||
self._session = httpx.Client(timeout=None, limits=self.http_limits)
|
||||
|
||||
self._header = system.settings.chroma_server_headers or {}
|
||||
self._header["Content-Type"] = "application/json"
|
||||
@@ -90,8 +96,6 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
|
||||
if self._header is not None:
|
||||
self._session.headers.update(self._header)
|
||||
|
||||
@@ -101,6 +105,14 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
for header, value in _headers.items():
|
||||
self._session.headers[header] = value.get_secret_value()
|
||||
|
||||
@override
|
||||
def get_request_headers(self) -> Mapping[str, str]:
|
||||
return dict(self._session.headers)
|
||||
|
||||
@override
|
||||
def get_api_url(self) -> str:
|
||||
return self._api_url
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
|
||||
# If the request has json in kwargs, use orjson to serialize it,
|
||||
# remove it from kwargs, and add it to the content parameter
|
||||
@@ -492,7 +504,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
@@ -700,7 +712,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
metadatas=metadata_batches,
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
@@ -761,7 +773,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
) -> Tuple["AttachedFunction", bool]:
|
||||
"""Attach a function to a collection."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
@@ -774,23 +786,56 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
},
|
||||
)
|
||||
|
||||
return AttachedFunction(
|
||||
attached_function = AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(resp_json["attached_function"]["id"]),
|
||||
name=resp_json["attached_function"]["name"],
|
||||
function_id=resp_json["attached_function"]["function_id"],
|
||||
function_name=resp_json["attached_function"]["function_name"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
created = resp_json.get(
|
||||
"created", True
|
||||
) # Default to True for backwards compatibility
|
||||
return (attached_function, created)
|
||||
|
||||
@trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def get_attached_function(
|
||||
self,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Get an attached function by name for a specific collection."""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/{name}",
|
||||
)
|
||||
|
||||
af = resp_json["attached_function"]
|
||||
return AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(af["id"]),
|
||||
name=af["name"],
|
||||
function_name=af["function_name"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=af["output_collection"],
|
||||
params=af.get("params"),
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
@@ -798,7 +843,7 @@ class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
"""Detach a function and prevent any further runs."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/attached_functions/{name}/detach",
|
||||
json={
|
||||
"delete_output": delete_output,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user