807 lines
27 KiB
Python
807 lines
27 KiB
Python
import orjson
|
|
import logging
|
|
from typing import Any, Dict, Optional, cast, Tuple, List
|
|
from typing import Sequence
|
|
from uuid import UUID
|
|
import httpx
|
|
import urllib.parse
|
|
from overrides import override
|
|
|
|
from chromadb.api.models.AttachedFunction import AttachedFunction
|
|
|
|
from chromadb.api.collection_configuration import (
|
|
CreateCollectionConfiguration,
|
|
UpdateCollectionConfiguration,
|
|
update_collection_configuration_to_json,
|
|
create_collection_configuration_to_json,
|
|
)
|
|
from chromadb import __version__
|
|
from chromadb.api.base_http_client import BaseHTTPClient
|
|
from chromadb.types import Database, Tenant, Collection as CollectionModel
|
|
from chromadb.api import ServerAPI
|
|
from chromadb.execution.expression.plan import Search
|
|
|
|
from chromadb.api.types import (
|
|
Documents,
|
|
Embeddings,
|
|
IDs,
|
|
Include,
|
|
Schema,
|
|
Metadatas,
|
|
URIs,
|
|
Where,
|
|
WhereDocument,
|
|
GetResult,
|
|
QueryResult,
|
|
SearchResult,
|
|
CollectionMetadata,
|
|
validate_batch,
|
|
convert_np_embeddings_to_list,
|
|
IncludeMetadataDocuments,
|
|
IncludeMetadataDocumentsDistances,
|
|
)
|
|
|
|
from chromadb.api.types import (
|
|
IncludeMetadataDocumentsEmbeddings,
|
|
optional_embeddings_to_base64_strings,
|
|
serialize_metadata,
|
|
deserialize_metadata,
|
|
)
|
|
from chromadb.auth import UserIdentity
|
|
from chromadb.auth import (
|
|
ClientAuthProvider,
|
|
)
|
|
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
|
|
from chromadb.telemetry.opentelemetry import (
|
|
OpenTelemetryClient,
|
|
OpenTelemetryGranularity,
|
|
trace_method,
|
|
)
|
|
from chromadb.telemetry.product import ProductTelemetryClient
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FastAPI(BaseHTTPClient, ServerAPI):
    """Synchronous HTTP implementation of ``ServerAPI``.

    Each API method is translated into one HTTP request against the URL
    resolved from the ``chroma_server_*`` settings. All requests go through a
    single shared ``httpx.Client`` session created in ``__init__``.
    """

    def __init__(self, system: System):
        """Create the HTTP session from the settings carried by ``system``.

        The ``chroma_server_host`` and ``chroma_server_http_port`` settings are
        required. If ``chroma_client_auth_provider`` is set, the headers
        returned by the auth provider are added to the session.
        """
        super().__init__(system)
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")

        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._settings = system.settings

        self._api_url = FastAPI.resolve_url(
            chroma_server_host=str(system.settings.chroma_server_host),
            chroma_server_http_port=system.settings.chroma_server_http_port,
            chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
            default_api_path=system.settings.chroma_server_api_default_path,
        )

        # Copy the configured headers so the entries added below do not leak
        # back into the shared settings object.
        self._header = dict(system.settings.chroma_server_headers or {})
        self._header["Content-Type"] = "application/json"
        self._header["User-Agent"] = (
            "Chroma Python Client v"
            + __version__
            + " (https://github.com/chroma-core/chroma)"
        )

        # Build exactly one client. Previously, setting chroma_server_ssl_verify
        # replaced the session with a fresh client, which silently dropped the
        # `timeout=None` / keep-alive limits and left the first client unclosed.
        limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
        client_kwargs: Dict[str, Any] = {"timeout": None, "limits": limits}
        if self._settings.chroma_server_ssl_verify is not None:
            client_kwargs["verify"] = self._settings.chroma_server_ssl_verify
        self._session = httpx.Client(**client_kwargs)
        self._session.headers.update(self._header)

        if system.settings.chroma_client_auth_provider:
            self._auth_provider = self.require(ClientAuthProvider)
            _headers = self._auth_provider.authenticate()
            for header, value in _headers.items():
                # Header values are secret wrappers; unwrap them for the wire.
                self._session.headers[header] = value.get_secret_value()

    def _make_request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send one request to the server and return the decoded JSON body.

        Args:
            method: HTTP method name passed to the session (e.g. ``"get"``).
            path: Path appended to the resolved API URL; it is percent-escaped
                here, with ``/`` left as-is.
            **kwargs: Extra keyword arguments forwarded to
                ``httpx.Client.request``. A ``json`` keyword is serialized with
                orjson and sent as ``content`` instead.

        Returns:
            The response body parsed with ``orjson.loads``.

        Raises:
            Whatever ``BaseHTTPClient._raise_chroma_error`` raises for the
            response (presumably on error status codes - defined outside this
            file).
        """
        # If the request has json in kwargs, use orjson to serialize it,
        # remove it from kwargs, and add it to the content parameter
        # This is because httpx uses a slower json serializer
        if "json" in kwargs:
            data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
            kwargs["content"] = data

        # Unlike requests, httpx does not automatically escape the path
        escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
        url = self._api_url + escaped_path

        response = self._session.request(method, url, **kwargs)
        BaseHTTPClient._raise_chroma_error(response)
        return orjson.loads(response.text)

    @staticmethod
    def _deserialize_metadata_list(
        metadatas: Optional[List[Any]],
    ) -> Optional[List[Any]]:
        """Run ``deserialize_metadata`` on every non-``None`` entry of a list.

        ``None`` entries, and a ``None`` list, are passed through unchanged.
        """
        if metadatas is None:
            return None
        return [
            deserialize_metadata(metadata) if metadata is not None else None
            for metadata in metadatas
        ]

    @staticmethod
    def _deserialize_metadata_batches(
        metadata_batches: Optional[List[Any]],
    ) -> Optional[List[Any]]:
        """Apply ``_deserialize_metadata_list`` to each inner list of a nested list."""
        if metadata_batches is None:
            return None
        return [
            FastAPI._deserialize_metadata_list(metadatas)
            for metadatas in metadata_batches
        ]

    @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
    @override
    def heartbeat(self) -> int:
        """Returns the current server time in nanoseconds to check if the server is alive"""
        resp_json = self._make_request("get", "/heartbeat")
        return int(resp_json["nanosecond heartbeat"])

    # Migrated to rust in distributed.
    @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
    @override
    def create_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Creates a database named ``name`` under ``tenant``."""
        self._make_request(
            "post",
            f"/tenants/{tenant}/databases",
            json={"name": name},
        )

    # Migrated to rust in distributed.
    @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
    @override
    def get_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> Database:
        """Returns the database ``name`` of ``tenant``, built from the response's
        ``id``, ``name`` and ``tenant`` fields."""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{name}",
        )
        return Database(
            id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
        )

    @trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Deletes the database ``name`` of ``tenant``."""
        self._make_request(
            "delete",
            f"/tenants/{tenant}/databases/{name}",
        )

    @trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
    @override
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        """Returns the databases of ``tenant``.

        ``limit`` and ``offset`` are sent as query parameters after being passed
        through ``BaseHTTPClient._clean_params``.
        """
        json_databases = self._make_request(
            "get",
            f"/tenants/{tenant}/databases",
            params=BaseHTTPClient._clean_params(
                {
                    "limit": limit,
                    "offset": offset,
                }
            ),
        )
        return [
            Database(id=db["id"], name=db["name"], tenant=db["tenant"])
            for db in json_databases
        ]

    @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def create_tenant(self, name: str) -> None:
        """Creates a tenant named ``name``."""
        self._make_request("post", "/tenants", json={"name": name})

    @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def get_tenant(self, name: str) -> Tenant:
        """Returns the tenant ``name``, built from the response's ``name`` field."""
        resp_json = self._make_request("get", f"/tenants/{name}")
        return Tenant(name=resp_json["name"])

    @trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
    @override
    def get_user_identity(self) -> UserIdentity:
        """Returns a ``UserIdentity`` built from the ``/auth/identity`` response fields."""
        return UserIdentity(**self._make_request("get", "/auth/identity"))

    @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        """Returns the collections of ``database`` in ``tenant``.

        ``limit`` and ``offset`` are sent as query parameters after being passed
        through ``BaseHTTPClient._clean_params``.
        """
        json_collections = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections",
            params=BaseHTTPClient._clean_params(
                {
                    "limit": limit,
                    "offset": offset,
                }
            ),
        )
        return [
            CollectionModel.from_json(json_collection)
            for json_collection in json_collections
        ]

    @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        """Returns a count of collections in ``database`` of ``tenant``."""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections_count",
        )
        return cast(int, resp_json)

    @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Creates a collection and returns the model parsed from the response.

        ``configuration`` and ``schema`` are converted to their JSON forms when
        provided and sent as ``None`` otherwise. ``get_or_create`` is forwarded
        to the server unchanged.
        """
        config_json = (
            create_collection_configuration_to_json(configuration, metadata)
            if configuration
            else None
        )
        serialized_schema = schema.serialize_to_json() if schema else None
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections",
            json={
                "name": name,
                "metadata": metadata,
                "configuration": config_json,
                "schema": serialized_schema,
                "get_or_create": get_or_create,
            },
        )
        return CollectionModel.from_json(resp_json)

    @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def get_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Returns the collection ``name`` as a model parsed from the response."""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections/{name}",
        )
        return CollectionModel.from_json(resp_json)

    @trace_method(
        "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
    )
    @override
    def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Calls ``create_collection`` with ``get_or_create=True``."""
        return self.create_collection(
            name=name,
            metadata=metadata,
            configuration=configuration,
            schema=schema,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )

    @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Updates a collection's name, metadata and/or configuration.

        ``new_configuration`` is converted to its JSON form when provided and
        sent as ``None`` otherwise.
        """
        configuration_json = (
            update_collection_configuration_to_json(new_configuration)
            if new_configuration
            else None
        )
        self._make_request(
            "put",
            f"/tenants/{tenant}/databases/{database}/collections/{id}",
            json={
                "new_metadata": new_metadata,
                "new_name": new_name,
                "new_configuration": configuration_json,
            },
        )

    @trace_method("FastAPI._fork", OpenTelemetryGranularity.OPERATION)
    @override
    def _fork(
        self,
        collection_id: UUID,
        new_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Forks a collection under ``new_name`` and returns the model parsed
        from the response."""
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
            json={"new_name": new_name},
        )
        return CollectionModel.from_json(resp_json)

    @trace_method("FastAPI._search", OpenTelemetryGranularity.OPERATION)
    @override
    def _search(
        self,
        collection_id: UUID,
        searches: List[Search],
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> SearchResult:
        """Performs hybrid search on a collection"""
        # Convert Search objects to dictionaries
        payload = {"searches": [s.to_dict() for s in searches]}

        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
            json=payload,
        )

        # Deserialize metadatas: convert transport format to SparseVector instances
        # SearchResult has nested structure: List[Optional[List[Optional[Metadata]]]]
        metadata_batches = resp_json.get("metadatas", None)
        if metadata_batches is not None:
            # Only overwrite the key when the server actually sent metadatas.
            resp_json["metadatas"] = FastAPI._deserialize_metadata_batches(
                metadata_batches
            )

        return SearchResult(resp_json)

    @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Deletes the collection ``name``."""
        self._make_request(
            "delete",
            f"/tenants/{tenant}/databases/{database}/collections/{name}",
        )

    @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
    @override
    def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Returns the count reported by the collection's ``/count`` endpoint."""
        resp_json = self._make_request(
            "get",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
        )
        return cast(int, resp_json)

    @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Returns up to ``n`` records via ``_get``, including metadatas,
        documents and embeddings."""
        return cast(
            GetResult,
            self._get(
                collection_id,
                tenant=tenant,
                database=database,
                limit=n,
                include=IncludeMetadataDocumentsEmbeddings,
            ),
        )

    @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Fetches records from a collection by ids and/or filters.

        Returns a ``GetResult`` whose ``data`` is always ``None`` and whose
        ``included`` is the caller's original ``include`` list.
        """
        # Servers do not support receiving "data", as that is hydrated by the client as a loadable
        filtered_include = [i for i in include if i != "data"]

        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/get",
            json={
                "ids": ids,
                "where": where,
                "limit": limit,
                "offset": offset,
                "where_document": where_document,
                "include": filtered_include,
            },
        )

        # Deserialize metadatas: convert transport format to SparseVector instances
        metadatas = FastAPI._deserialize_metadata_list(
            resp_json.get("metadatas", None)
        )

        return GetResult(
            ids=resp_json["ids"],
            embeddings=resp_json.get("embeddings", None),
            metadatas=metadatas,  # type: ignore
            documents=resp_json.get("documents", None),
            data=None,
            uris=resp_json.get("uris", None),
            included=include,
        )

    @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Deletes embeddings from the database"""
        self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
            json={
                "ids": ids,
                "where": where,
                "where_document": where_document,
            },
        )
        return None

    @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
    def _submit_batch(
        self,
        batch: Tuple[
            IDs,
            Optional[Embeddings],
            Optional[Metadatas],
            Optional[Documents],
            Optional[URIs],
        ],
        url: str,
    ) -> None:
        """
        Submits a batch of embeddings to the database

        Args:
            batch: ``(ids, embeddings, metadatas, documents, uris)``.
            url: Request path to POST the batch to.
        """
        ids, embeddings, metadatas, documents, uris = batch

        # Serialize metadatas: convert SparseVector instances to transport format
        serialized_metadatas = None
        if metadatas is not None:
            serialized_metadatas = [
                serialize_metadata(metadata) if metadata is not None else None
                for metadata in metadatas
            ]

        # Embeddings are base64-encoded only when the server's pre-flight
        # checks report support for it.
        data = {
            "ids": ids,
            "embeddings": optional_embeddings_to_base64_strings(embeddings)
            if self.supports_base64_encoding()
            else embeddings,
            "metadatas": serialized_metadatas,
            "documents": documents,
            "uris": uris,
        }

        self._make_request("post", url, json=data)

    @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """
        Adds a batch of embeddings to the database
        - pass in column oriented data lists

        The batch is validated against ``get_max_batch_size()`` before it is
        sent. Returns ``True`` once the request has completed.
        """
        batch = (
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(
            batch,
            f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/add",
        )
        return True

    @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """
        Updates a batch of embeddings in the database
        - pass in column oriented data lists

        The batch is validated against ``get_max_batch_size()`` before it is
        sent. Returns ``True`` once the request has completed.
        """
        batch = (
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(
            batch,
            f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/update",
        )
        return True

    @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """
        Upserts a batch of embeddings in the database
        - pass in column oriented data lists

        The batch is validated against ``get_max_batch_size()`` before it is
        sent. Returns ``True`` once the request has completed.
        """
        batch = (
            ids,
            embeddings,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(
            batch,
            f"/tenants/{tenant}/databases/{database}/collections/{str(collection_id)}/upsert",
        )
        return True

    @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> QueryResult:
        """Gets the nearest neighbors of the given query embeddings.

        Returns a ``QueryResult`` whose ``data`` is always ``None`` and whose
        ``included`` is the caller's original ``include`` list.
        """
        # "data" is not sent to the server, as that is hydrated by the client as a loadable
        filtered_include = [i for i in include if i != "data"]

        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
            json={
                "ids": ids,
                "query_embeddings": convert_np_embeddings_to_list(query_embeddings)
                if query_embeddings is not None
                else None,
                "n_results": n_results,
                "where": where,
                "where_document": where_document,
                "include": filtered_include,
            },
        )

        # Deserialize metadatas: convert transport format to SparseVector instances
        metadata_batches = FastAPI._deserialize_metadata_batches(
            resp_json.get("metadatas", None)
        )

        return QueryResult(
            ids=resp_json["ids"],
            distances=resp_json.get("distances", None),
            embeddings=resp_json.get("embeddings", None),
            metadatas=metadata_batches,  # type: ignore
            documents=resp_json.get("documents", None),
            uris=resp_json.get("uris", None),
            data=None,
            included=include,
        )

    @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
    @override
    def reset(self) -> bool:
        """Resets the database"""
        resp_json = self._make_request("post", "/reset")
        return cast(bool, resp_json)

    @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
    @override
    def get_version(self) -> str:
        """Returns the version of the server"""
        resp_json = self._make_request("get", "/version")
        return cast(str, resp_json)

    @override
    def get_settings(self) -> Settings:
        """Returns the settings of the client"""
        return self._settings

    @trace_method("FastAPI.get_pre_flight_checks", OpenTelemetryGranularity.OPERATION)
    def get_pre_flight_checks(self) -> Any:
        """Returns the server's pre-flight checks, fetching them at most once.

        The response is cached on ``self.pre_flight_checks``; the request is
        only made while that attribute is ``None``.
        """
        if self.pre_flight_checks is None:
            resp_json = self._make_request("get", "/pre-flight-checks")
            self.pre_flight_checks = resp_json
        return self.pre_flight_checks

    @trace_method(
        "FastAPI.supports_base64_encoding", OpenTelemetryGranularity.OPERATION
    )
    def supports_base64_encoding(self) -> bool:
        """Returns the ``supports_base64_encoding`` pre-flight value
        (``False`` when the server does not report it)."""
        pre_flight_checks = self.get_pre_flight_checks()
        return cast(bool, pre_flight_checks.get("supports_base64_encoding", False))

    @trace_method("FastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
    @override
    def get_max_batch_size(self) -> int:
        """Returns the ``max_batch_size`` pre-flight value
        (``-1`` when the server does not report it)."""
        pre_flight_checks = self.get_pre_flight_checks()
        return cast(int, pre_flight_checks.get("max_batch_size", -1))

    @trace_method("FastAPI.attach_function", OpenTelemetryGranularity.ALL)
    @override
    def attach_function(
        self,
        function_id: str,
        name: str,
        input_collection_id: UUID,
        output_collection: str,
        params: Optional[Dict[str, Any]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "AttachedFunction":
        """Attach a function to a collection.

        The returned ``AttachedFunction`` takes its ``id``, ``name`` and
        ``function_id`` from the response's ``attached_function`` object; the
        remaining fields echo the arguments passed in.
        """
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/functions/attach",
            json={
                "name": name,
                "function_id": function_id,
                "output_collection": output_collection,
                "params": params,
            },
        )
        attached = resp_json["attached_function"]

        return AttachedFunction(
            client=self,
            id=UUID(attached["id"]),
            name=attached["name"],
            function_id=attached["function_id"],
            input_collection_id=input_collection_id,
            output_collection=output_collection,
            params=params,
            tenant=tenant,
            database=database,
        )

    @trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL)
    @override
    def detach_function(
        self,
        attached_function_id: UUID,
        delete_output: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Detach a function and prevent any further runs.

        Returns the ``success`` field of the response.
        """
        resp_json = self._make_request(
            "post",
            f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach",
            json={
                "delete_output": delete_output,
            },
        )
        return cast(bool, resp_json["success"])
|