chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,806 @@
|
||||
import orjson
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, cast, Tuple, List
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
import httpx
|
||||
import urllib.parse
|
||||
from overrides import override
|
||||
|
||||
from chromadb.api.models.AttachedFunction import AttachedFunction
|
||||
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
UpdateCollectionConfiguration,
|
||||
update_collection_configuration_to_json,
|
||||
create_collection_configuration_to_json,
|
||||
)
|
||||
from chromadb import __version__
|
||||
from chromadb.api.base_http_client import BaseHTTPClient
|
||||
from chromadb.types import Database, Tenant, Collection as CollectionModel
|
||||
from chromadb.api import ServerAPI
|
||||
from chromadb.execution.expression.plan import Search
|
||||
|
||||
from chromadb.api.types import (
|
||||
Documents,
|
||||
Embeddings,
|
||||
IDs,
|
||||
Include,
|
||||
Schema,
|
||||
Metadatas,
|
||||
URIs,
|
||||
Where,
|
||||
WhereDocument,
|
||||
GetResult,
|
||||
QueryResult,
|
||||
SearchResult,
|
||||
CollectionMetadata,
|
||||
validate_batch,
|
||||
convert_np_embeddings_to_list,
|
||||
IncludeMetadataDocuments,
|
||||
IncludeMetadataDocumentsDistances,
|
||||
)
|
||||
|
||||
from chromadb.api.types import (
|
||||
IncludeMetadataDocumentsEmbeddings,
|
||||
optional_embeddings_to_base64_strings,
|
||||
serialize_metadata,
|
||||
deserialize_metadata,
|
||||
)
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.auth import (
|
||||
ClientAuthProvider,
|
||||
)
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
from chromadb.telemetry.product import ProductTelemetryClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FastAPI(BaseHTTPClient, ServerAPI):
|
||||
def __init__(self, system: System):
|
||||
super().__init__(system)
|
||||
system.settings.require("chroma_server_host")
|
||||
system.settings.require("chroma_server_http_port")
|
||||
|
||||
self._opentelemetry_client = self.require(OpenTelemetryClient)
|
||||
self._product_telemetry_client = self.require(ProductTelemetryClient)
|
||||
self._settings = system.settings
|
||||
|
||||
self._api_url = FastAPI.resolve_url(
|
||||
chroma_server_host=str(system.settings.chroma_server_host),
|
||||
chroma_server_http_port=system.settings.chroma_server_http_port,
|
||||
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
|
||||
default_api_path=system.settings.chroma_server_api_default_path,
|
||||
)
|
||||
|
||||
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
|
||||
self._session = httpx.Client(timeout=None, limits=limits)
|
||||
|
||||
self._header = system.settings.chroma_server_headers or {}
|
||||
self._header["Content-Type"] = "application/json"
|
||||
self._header["User-Agent"] = (
|
||||
"Chroma Python Client v"
|
||||
+ __version__
|
||||
+ " (https://github.com/chroma-core/chroma)"
|
||||
)
|
||||
|
||||
if self._settings.chroma_server_ssl_verify is not None:
|
||||
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
|
||||
if self._header is not None:
|
||||
self._session.headers.update(self._header)
|
||||
|
||||
if system.settings.chroma_client_auth_provider:
|
||||
self._auth_provider = self.require(ClientAuthProvider)
|
||||
_headers = self._auth_provider.authenticate()
|
||||
for header, value in _headers.items():
|
||||
self._session.headers[header] = value.get_secret_value()
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
|
||||
# If the request has json in kwargs, use orjson to serialize it,
|
||||
# remove it from kwargs, and add it to the content parameter
|
||||
# This is because httpx uses a slower json serializer
|
||||
if "json" in kwargs:
|
||||
data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
|
||||
kwargs["content"] = data
|
||||
|
||||
# Unlike requests, httpx does not automatically escape the path
|
||||
escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
|
||||
url = self._api_url + escaped_path
|
||||
|
||||
response = self._session.request(method, url, **cast(Any, kwargs))
|
||||
BaseHTTPClient._raise_chroma_error(response)
|
||||
return orjson.loads(response.text)
|
||||
|
||||
@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def heartbeat(self) -> int:
|
||||
"""Returns the current server time in nanoseconds to check if the server is alive"""
|
||||
resp_json = self._make_request("get", "/heartbeat")
|
||||
return int(resp_json["nanosecond heartbeat"])
|
||||
|
||||
# Migrated to rust in distributed.
|
||||
@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def create_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> None:
|
||||
"""Creates a database"""
|
||||
self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases",
|
||||
json={"name": name},
|
||||
)
|
||||
|
||||
# Migrated to rust in distributed.
|
||||
@trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Database:
|
||||
"""Returns a database"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{name}",
|
||||
)
|
||||
return Database(
|
||||
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def delete_database(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> None:
|
||||
"""Deletes a database"""
|
||||
self._make_request(
|
||||
"delete",
|
||||
f"/tenants/{tenant}/databases/{name}",
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
"""Returns a list of all databases"""
|
||||
json_databases = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases",
|
||||
params=BaseHTTPClient._clean_params(
|
||||
{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
),
|
||||
)
|
||||
databases = [
|
||||
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
|
||||
for db in json_databases
|
||||
]
|
||||
return databases
|
||||
|
||||
@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def create_tenant(self, name: str) -> None:
|
||||
self._make_request("post", "/tenants", json={"name": name})
|
||||
|
||||
@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
resp_json = self._make_request("get", "/tenants/" + name)
|
||||
return Tenant(name=resp_json["name"])
|
||||
|
||||
@trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_user_identity(self) -> UserIdentity:
|
||||
return UserIdentity(**self._make_request("get", "/auth/identity"))
|
||||
|
||||
@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def list_collections(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> Sequence[CollectionModel]:
|
||||
"""Returns a list of all collections"""
|
||||
json_collections = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections",
|
||||
params=BaseHTTPClient._clean_params(
|
||||
{
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
),
|
||||
)
|
||||
collection_models = [
|
||||
CollectionModel.from_json(json_collection)
|
||||
for json_collection in json_collections
|
||||
]
|
||||
|
||||
return collection_models
|
||||
|
||||
@trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def count_collections(
|
||||
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
|
||||
) -> int:
|
||||
"""Returns a count of collections"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections_count",
|
||||
)
|
||||
return cast(int, resp_json)
|
||||
|
||||
@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
get_or_create: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Creates a collection"""
|
||||
config_json = (
|
||||
create_collection_configuration_to_json(configuration, metadata)
|
||||
if configuration
|
||||
else None
|
||||
)
|
||||
serialized_schema = schema.serialize_to_json() if schema else None
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections",
|
||||
json={
|
||||
"name": name,
|
||||
"metadata": metadata,
|
||||
"configuration": config_json,
|
||||
"schema": serialized_schema,
|
||||
"get_or_create": get_or_create,
|
||||
},
|
||||
)
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
|
||||
return model
|
||||
|
||||
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Returns a collection"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{name}",
|
||||
)
|
||||
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
return model
|
||||
|
||||
@trace_method(
|
||||
"FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
@override
|
||||
def get_or_create_collection(
|
||||
self,
|
||||
name: str,
|
||||
schema: Optional[Schema] = None,
|
||||
configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
return self.create_collection(
|
||||
name=name,
|
||||
metadata=metadata,
|
||||
configuration=configuration,
|
||||
schema=schema,
|
||||
get_or_create=True,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _modify(
|
||||
self,
|
||||
id: UUID,
|
||||
new_name: Optional[str] = None,
|
||||
new_metadata: Optional[CollectionMetadata] = None,
|
||||
new_configuration: Optional[UpdateCollectionConfiguration] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Updates a collection"""
|
||||
self._make_request(
|
||||
"put",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{id}",
|
||||
json={
|
||||
"new_metadata": new_metadata,
|
||||
"new_name": new_name,
|
||||
"new_configuration": update_collection_configuration_to_json(
|
||||
new_configuration
|
||||
)
|
||||
if new_configuration
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._fork", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _fork(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
new_name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> CollectionModel:
|
||||
"""Forks a collection"""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
|
||||
json={"new_name": new_name},
|
||||
)
|
||||
model = CollectionModel.from_json(resp_json)
|
||||
return model
|
||||
|
||||
@trace_method("FastAPI._search", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _search(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
searches: List[Search],
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> SearchResult:
|
||||
"""Performs hybrid search on a collection"""
|
||||
# Convert Search objects to dictionaries
|
||||
payload = {"searches": [s.to_dict() for s in searches]}
|
||||
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
# Deserialize metadatas: convert transport format to SparseVector instances
|
||||
metadata_batches = resp_json.get("metadatas", None)
|
||||
if metadata_batches is not None:
|
||||
# SearchResult has nested structure: List[Optional[List[Optional[Metadata]]]]
|
||||
resp_json["metadatas"] = [
|
||||
[
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
if metadatas is not None
|
||||
else None
|
||||
for metadatas in metadata_batches
|
||||
]
|
||||
|
||||
return SearchResult(resp_json)
|
||||
|
||||
@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def delete_collection(
|
||||
self,
|
||||
name: str,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Deletes a collection"""
|
||||
self._make_request(
|
||||
"delete",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{name}",
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _count(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> int:
|
||||
"""Returns the number of embeddings in the database"""
|
||||
resp_json = self._make_request(
|
||||
"get",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
|
||||
)
|
||||
return cast(int, resp_json)
|
||||
|
||||
@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _peek(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
n: int = 10,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
return cast(
|
||||
GetResult,
|
||||
self._get(
|
||||
collection_id,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
limit=n,
|
||||
include=IncludeMetadataDocumentsEmbeddings,
|
||||
),
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _get(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocuments,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> GetResult:
|
||||
# Servers do not support receiving "data", as that is hydrated by the client as a loadable
|
||||
filtered_include = [i for i in include if i != "data"]
|
||||
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
|
||||
json={
|
||||
"ids": ids,
|
||||
"where": where,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"where_document": where_document,
|
||||
"include": filtered_include,
|
||||
},
|
||||
)
|
||||
|
||||
# Deserialize metadatas: convert transport format to SparseVector instances
|
||||
metadatas = resp_json.get("metadatas", None)
|
||||
if metadatas is not None:
|
||||
metadatas = [
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
|
||||
return GetResult(
|
||||
ids=resp_json["ids"],
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadatas, # type: ignore
|
||||
documents=resp_json.get("documents", None),
|
||||
data=None,
|
||||
uris=resp_json.get("uris", None),
|
||||
included=include,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def _delete(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: Optional[IDs] = None,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> None:
|
||||
"""Deletes embeddings from the database"""
|
||||
self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
|
||||
json={
|
||||
"ids": ids,
|
||||
"where": where,
|
||||
"where_document": where_document,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
|
||||
def _submit_batch(
|
||||
self,
|
||||
batch: Tuple[
|
||||
IDs,
|
||||
Optional[Embeddings],
|
||||
Optional[Metadatas],
|
||||
Optional[Documents],
|
||||
Optional[URIs],
|
||||
],
|
||||
url: str,
|
||||
) -> None:
|
||||
"""
|
||||
Submits a batch of embeddings to the database
|
||||
"""
|
||||
# Serialize metadatas: convert SparseVector instances to transport format
|
||||
serialized_metadatas = None
|
||||
if batch[2] is not None:
|
||||
serialized_metadatas = [
|
||||
serialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in batch[2]
|
||||
]
|
||||
|
||||
data = {
|
||||
"ids": batch[0],
|
||||
"embeddings": optional_embeddings_to_base64_strings(batch[1])
|
||||
if self.supports_base64_encoding()
|
||||
else batch[1],
|
||||
"metadatas": serialized_metadatas,
|
||||
"documents": batch[3],
|
||||
"uris": batch[4],
|
||||
}
|
||||
|
||||
self._make_request("post", url, json=data)
|
||||
|
||||
@trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _add(
|
||||
self,
|
||||
ids: IDs,
|
||||
collection_id: UUID,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""
|
||||
Adds a batch of embeddings to the database
|
||||
- pass in column oriented data lists
|
||||
"""
|
||||
batch = (
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
|
||||
self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _update(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""
|
||||
Updates a batch of embeddings in the database
|
||||
- pass in column oriented data lists
|
||||
"""
|
||||
batch = (
|
||||
ids,
|
||||
embeddings if embeddings is not None else None,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
|
||||
self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _upsert(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
ids: IDs,
|
||||
embeddings: Embeddings,
|
||||
metadatas: Optional[Metadatas] = None,
|
||||
documents: Optional[Documents] = None,
|
||||
uris: Optional[URIs] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""
|
||||
Upserts a batch of embeddings in the database
|
||||
- pass in column oriented data lists
|
||||
"""
|
||||
batch = (
|
||||
ids,
|
||||
embeddings,
|
||||
metadatas,
|
||||
documents,
|
||||
uris,
|
||||
)
|
||||
validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
|
||||
self._submit_batch(
|
||||
batch,
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
|
||||
)
|
||||
return True
|
||||
|
||||
@trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def _query(
|
||||
self,
|
||||
collection_id: UUID,
|
||||
query_embeddings: Embeddings,
|
||||
ids: Optional[IDs] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[Where] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Include = IncludeMetadataDocumentsDistances,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> QueryResult:
|
||||
# Clients do not support receiving "data", as that is hydrated by the client as a loadable
|
||||
filtered_include = [i for i in include if i != "data"]
|
||||
|
||||
"""Gets the nearest neighbors of a single embedding"""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
|
||||
json={
|
||||
"ids": ids,
|
||||
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
|
||||
if query_embeddings is not None
|
||||
else None,
|
||||
"n_results": n_results,
|
||||
"where": where,
|
||||
"where_document": where_document,
|
||||
"include": filtered_include,
|
||||
},
|
||||
)
|
||||
|
||||
# Deserialize metadatas: convert transport format to SparseVector instances
|
||||
metadata_batches = resp_json.get("metadatas", None)
|
||||
if metadata_batches is not None:
|
||||
metadata_batches = [
|
||||
[
|
||||
deserialize_metadata(metadata) if metadata is not None else None
|
||||
for metadata in metadatas
|
||||
]
|
||||
if metadatas is not None
|
||||
else None
|
||||
for metadatas in metadata_batches
|
||||
]
|
||||
|
||||
return QueryResult(
|
||||
ids=resp_json["ids"],
|
||||
distances=resp_json.get("distances", None),
|
||||
embeddings=resp_json.get("embeddings", None),
|
||||
metadatas=metadata_batches, # type: ignore
|
||||
documents=resp_json.get("documents", None),
|
||||
uris=resp_json.get("uris", None),
|
||||
data=None,
|
||||
included=include,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def reset(self) -> bool:
|
||||
"""Resets the database"""
|
||||
resp_json = self._make_request("post", "/reset")
|
||||
return cast(bool, resp_json)
|
||||
|
||||
@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_version(self) -> str:
|
||||
"""Returns the version of the server"""
|
||||
resp_json = self._make_request("get", "/version")
|
||||
return cast(str, resp_json)
|
||||
|
||||
@override
|
||||
def get_settings(self) -> Settings:
|
||||
"""Returns the settings of the client"""
|
||||
return self._settings
|
||||
|
||||
@trace_method("FastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION)
|
||||
def get_pre_flight_checks(self) -> Any:
|
||||
if self.pre_flight_checks is None:
|
||||
resp_json = self._make_request("get", "/pre-flight-checks")
|
||||
self.pre_flight_checks = resp_json
|
||||
return self.pre_flight_checks
|
||||
|
||||
@trace_method(
|
||||
"FastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
def supports_base64_encoding(self) -> bool:
|
||||
pre_flight_checks = self.get_pre_flight_checks()
|
||||
b64_encoding_enabled = cast(
|
||||
bool, pre_flight_checks.get("supports_base64_encoding", False)
|
||||
)
|
||||
return b64_encoding_enabled
|
||||
|
||||
@trace_method("FastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
|
||||
@override
|
||||
def get_max_batch_size(self) -> int:
|
||||
pre_flight_checks = self.get_pre_flight_checks()
|
||||
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
|
||||
return max_batch_size
|
||||
|
||||
@trace_method("FastAPI.attach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def attach_function(
|
||||
self,
|
||||
function_id: str,
|
||||
name: str,
|
||||
input_collection_id: UUID,
|
||||
output_collection: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> "AttachedFunction":
|
||||
"""Attach a function to a collection."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/attach",
|
||||
json={
|
||||
"name": name,
|
||||
"function_id": function_id,
|
||||
"output_collection": output_collection,
|
||||
"params": params,
|
||||
},
|
||||
)
|
||||
|
||||
return AttachedFunction(
|
||||
client=self,
|
||||
id=UUID(resp_json["attached_function"]["id"]),
|
||||
name=resp_json["attached_function"]["name"],
|
||||
function_id=resp_json["attached_function"]["function_id"],
|
||||
input_collection_id=input_collection_id,
|
||||
output_collection=output_collection,
|
||||
params=params,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
|
||||
@trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def detach_function(
|
||||
self,
|
||||
attached_function_id: UUID,
|
||||
delete_output: bool = False,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: str = DEFAULT_DATABASE,
|
||||
) -> bool:
|
||||
"""Detach a function and prevent any further runs."""
|
||||
resp_json = self._make_request(
|
||||
"post",
|
||||
f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
|
||||
json={
|
||||
"delete_output": delete_output,
|
||||
},
|
||||
)
|
||||
return cast(bool, resp_json["success"])
|
||||
Reference in New Issue
Block a user