chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,558 @@
|
||||
from typing import List, Optional, Sequence, Tuple, Union, cast
|
||||
from uuid import UUID
|
||||
from overrides import overrides
|
||||
from chromadb.api.collection_configuration import (
|
||||
CreateCollectionConfiguration,
|
||||
create_collection_configuration_to_json_str,
|
||||
UpdateCollectionConfiguration,
|
||||
update_collection_configuration_to_json_str,
|
||||
CollectionMetadata,
|
||||
)
|
||||
from chromadb.api.types import Schema
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
|
||||
from chromadb.db.system import SysDB
|
||||
from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError
|
||||
from chromadb.proto.convert import (
|
||||
from_proto_collection,
|
||||
from_proto_segment,
|
||||
to_proto_update_metadata,
|
||||
to_proto_segment,
|
||||
to_proto_segment_scope,
|
||||
)
|
||||
from chromadb.proto.coordinator_pb2 import (
|
||||
CreateCollectionRequest,
|
||||
CreateDatabaseRequest,
|
||||
CreateSegmentRequest,
|
||||
CreateTenantRequest,
|
||||
CountCollectionsRequest,
|
||||
CountCollectionsResponse,
|
||||
DeleteCollectionRequest,
|
||||
DeleteDatabaseRequest,
|
||||
DeleteSegmentRequest,
|
||||
GetCollectionsRequest,
|
||||
GetCollectionsResponse,
|
||||
GetCollectionSizeRequest,
|
||||
GetCollectionSizeResponse,
|
||||
GetCollectionWithSegmentsRequest,
|
||||
GetCollectionWithSegmentsResponse,
|
||||
GetDatabaseRequest,
|
||||
GetSegmentsRequest,
|
||||
GetTenantRequest,
|
||||
ListDatabasesRequest,
|
||||
UpdateCollectionRequest,
|
||||
UpdateSegmentRequest,
|
||||
)
|
||||
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
|
||||
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
|
||||
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
from chromadb.types import (
|
||||
Collection,
|
||||
CollectionAndSegments,
|
||||
Database,
|
||||
Metadata,
|
||||
OptionalArgument,
|
||||
Segment,
|
||||
SegmentScope,
|
||||
Tenant,
|
||||
Unspecified,
|
||||
UpdateMetadata,
|
||||
)
|
||||
from google.protobuf.empty_pb2 import Empty
|
||||
import grpc
|
||||
|
||||
|
||||
class GrpcSysDB(SysDB):
|
||||
"""A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
|
||||
called the 'Coordinator'. This implementation is used by Chroma frontend servers
|
||||
to call a remote SysDB (Coordinator) service."""
|
||||
|
||||
_sys_db_stub: SysDBStub
|
||||
_channel: grpc.Channel
|
||||
_coordinator_url: str
|
||||
_coordinator_port: int
|
||||
_request_timeout_seconds: int
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._coordinator_url = system.settings.require("chroma_coordinator_host")
|
||||
# TODO: break out coordinator_port into a separate setting?
|
||||
self._coordinator_port = system.settings.require("chroma_server_grpc_port")
|
||||
self._request_timeout_seconds = system.settings.require(
|
||||
"chroma_sysdb_request_timeout_seconds"
|
||||
)
|
||||
return super().__init__(system)
|
||||
|
||||
@overrides
|
||||
def start(self) -> None:
|
||||
self._channel = grpc.insecure_channel(
|
||||
f"{self._coordinator_url}:{self._coordinator_port}",
|
||||
options=[("grpc.max_concurrent_streams", 1000)],
|
||||
)
|
||||
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
|
||||
self._channel = grpc.intercept_channel(self._channel, *interceptors)
|
||||
self._sys_db_stub = SysDBStub(self._channel) # type: ignore
|
||||
return super().start()
|
||||
|
||||
@overrides
|
||||
def stop(self) -> None:
|
||||
self._channel.close()
|
||||
return super().stop()
|
||||
|
||||
@overrides
|
||||
def reset_state(self) -> None:
|
||||
self._sys_db_stub.ResetState(Empty())
|
||||
return super().reset_state()
|
||||
|
||||
@overrides
|
||||
def create_database(
|
||||
self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
|
||||
) -> None:
|
||||
try:
|
||||
request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
|
||||
response = self._sys_db_stub.CreateDatabase(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to create database name {name} and database id {id} for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
|
||||
try:
|
||||
request = GetDatabaseRequest(name=name, tenant=tenant)
|
||||
response = self._sys_db_stub.GetDatabase(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
return Database(
|
||||
id=UUID(hex=response.database.id),
|
||||
name=response.database.name,
|
||||
tenant=response.database.tenant,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to get database {name} for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
|
||||
try:
|
||||
request = DeleteDatabaseRequest(name=name, tenant=tenant)
|
||||
self._sys_db_stub.DeleteDatabase(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to delete database {name} for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError
|
||||
|
||||
@overrides
|
||||
def list_databases(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
) -> Sequence[Database]:
|
||||
try:
|
||||
request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
|
||||
response = self._sys_db_stub.ListDatabases(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
results: List[Database] = []
|
||||
for proto_database in response.databases:
|
||||
results.append(
|
||||
Database(
|
||||
id=UUID(hex=proto_database.id),
|
||||
name=proto_database.name,
|
||||
tenant=proto_database.tenant,
|
||||
)
|
||||
)
|
||||
return results
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to list databases for tenant {tenant} due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def create_tenant(self, name: str) -> None:
|
||||
try:
|
||||
request = CreateTenantRequest(name=name)
|
||||
response = self._sys_db_stub.CreateTenant(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(f"Failed to create tenant {name} due to error: {e}")
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_tenant(self, name: str) -> Tenant:
|
||||
try:
|
||||
request = GetTenantRequest(name=name)
|
||||
response = self._sys_db_stub.GetTenant(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
return Tenant(
|
||||
name=response.tenant.name,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(f"Failed to get tenant {name} due to error: {e}")
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def create_segment(self, segment: Segment) -> None:
|
||||
try:
|
||||
proto_segment = to_proto_segment(segment)
|
||||
request = CreateSegmentRequest(
|
||||
segment=proto_segment,
|
||||
)
|
||||
response = self._sys_db_stub.CreateSegment(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(f"Failed to create segment {segment}, error: {e}")
|
||||
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
||||
raise UniqueConstraintError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def delete_segment(self, collection: UUID, id: UUID) -> None:
|
||||
try:
|
||||
request = DeleteSegmentRequest(
|
||||
id=id.hex,
|
||||
collection=collection.hex,
|
||||
)
|
||||
response = self._sys_db_stub.DeleteSegment(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to delete segment with id {id} for collection {collection} due to error: {e}"
|
||||
)
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_segments(
|
||||
self,
|
||||
collection: UUID,
|
||||
id: Optional[UUID] = None,
|
||||
type: Optional[str] = None,
|
||||
scope: Optional[SegmentScope] = None,
|
||||
) -> Sequence[Segment]:
|
||||
try:
|
||||
request = GetSegmentsRequest(
|
||||
id=id.hex if id else None,
|
||||
type=type,
|
||||
scope=to_proto_segment_scope(scope) if scope else None,
|
||||
collection=collection.hex,
|
||||
)
|
||||
response = self._sys_db_stub.GetSegments(
|
||||
request, timeout=self._request_timeout_seconds
|
||||
)
|
||||
results: List[Segment] = []
|
||||
for proto_segment in response.segments:
|
||||
segment = from_proto_segment(proto_segment)
|
||||
results.append(segment)
|
||||
return results
|
||||
except grpc.RpcError as e:
|
||||
logger.info(
|
||||
f"Failed to get segment id {id}, type {type}, scope {scope} for collection {collection} due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
    @overrides
    def update_segment(
        self,
        collection: UUID,
        id: UUID,
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update a segment's metadata via the remote SysDB.

        `metadata` uses the Unspecified() sentinel:
          - Unspecified() (default): leave metadata untouched.
          - None: reset (clear) the segment's metadata.
          - a mapping: send it as the metadata update.

        Raises InternalError on any RPC failure.
        """
        try:
            # Unwrap the sentinel into a concrete value (or None).
            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)

            request = UpdateSegmentRequest(
                id=id.hex,
                collection=collection.hex,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
            )

            # An explicit None means "reset": clear the proto field and set
            # the flag so the server wipes existing metadata.
            if metadata is None:
                request.ClearField("metadata")
                request.reset_metadata = True

            self._sys_db_stub.UpdateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to update segment with id {id} for collection {collection}, error: {e}"
            )
            raise InternalError()
|
||||
|
||||
    @overrides
    def create_collection(
        self,
        id: UUID,
        name: str,
        schema: Optional[Schema],
        configuration: CreateCollectionConfiguration,
        segments: Sequence[Segment],
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple[Collection, bool]:
        """Create a collection (and its segments) on the remote SysDB.

        Returns the collection plus a bool: True if newly created, False if
        `get_or_create=True` and it already existed.

        Raises UniqueConstraintError on a name/id conflict, InternalError
        for any other RPC failure.

        NOTE(review): `schema` is accepted but never placed in the request —
        presumably handled elsewhere; confirm this is intentional.
        """
        try:
            request = CreateCollectionRequest(
                id=id.hex,
                name=name,
                # The configuration (merged with metadata) travels as a JSON
                # string rather than a structured proto field.
                configuration_json_str=create_collection_configuration_to_json_str(
                    configuration, cast(CollectionMetadata, metadata)
                ),
                metadata=to_proto_update_metadata(metadata) if metadata else None,
                dimension=dimension,
                get_or_create=get_or_create,
                tenant=tenant,
                database=database,
                segments=[to_proto_segment(segment) for segment in segments],
            )
            response = self._sys_db_stub.CreateCollection(
                request, timeout=self._request_timeout_seconds
            )
            collection = from_proto_collection(response.collection)
            return collection, response.created
        except grpc.RpcError as e:
            logger.error(
                f"Failed to create collection id {id}, name {name} for database {database} and tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()
|
||||
|
||||
    @overrides
    def delete_collection(
        self,
        id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete collection `id` from `tenant`/`database` on the remote SysDB.

        Raises NotFoundError if the collection does not exist, InternalError
        for any other RPC failure.
        """
        try:
            request = DeleteCollectionRequest(
                id=id.hex,
                tenant=tenant,
                database=database,
            )
            response = self._sys_db_stub.DeleteCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.error(
                f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
            )
            # Narrow to grpc.Call so e.code() type-checks; the extra log line
            # records the exact status code for debugging.
            e = cast(grpc.Call, e)
            logger.error(
                f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()
|
||||
|
||||
    @overrides
    def get_collections(
        self,
        id: Optional[UUID] = None,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        """Fetch collections: by id, by name (scoped to tenant/database), or
        all collections in a tenant/database when neither filter is given.

        Raises ValueError if `name` is given while tenant and database are
        both None; InternalError on any RPC failure.
        """
        try:
            # TODO: implement limit and offset in the gRPC service
            request = None
            if id is not None:
                # Lookup by id: tenant/database not needed in the request.
                request = GetCollectionsRequest(
                    id=id.hex,
                    limit=limit,
                    offset=offset,
                )
            if name is not None:
                # NOTE(review): both params have non-None defaults, so this
                # guard only fires when a caller passes None explicitly.
                if tenant is None and database is None:
                    raise ValueError(
                        "If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
                    )
                request = GetCollectionsRequest(
                    name=name,
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            if id is None and name is None:
                # No filter: list everything in the tenant/database.
                request = GetCollectionsRequest(
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            response: GetCollectionsResponse = self._sys_db_stub.GetCollections(
                request, timeout=self._request_timeout_seconds
            )
            results: List[Collection] = []
            for collection in response.collections:
                results.append(from_proto_collection(collection))
            return results
        except grpc.RpcError as e:
            logger.error(
                f"Failed to get collections with id {id}, name {name}, tenant {tenant}, database {database} due to error: {e}"
            )
            raise InternalError()
|
||||
|
||||
@overrides
|
||||
def count_collections(
|
||||
self,
|
||||
tenant: str = DEFAULT_TENANT,
|
||||
database: Optional[str] = None,
|
||||
) -> int:
|
||||
try:
|
||||
if database is None or database == "":
|
||||
request = CountCollectionsRequest(tenant=tenant)
|
||||
response: CountCollectionsResponse = self._sys_db_stub.CountCollections(
|
||||
request
|
||||
)
|
||||
return response.count
|
||||
else:
|
||||
request = CountCollectionsRequest(
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
)
|
||||
response: CountCollectionsResponse = self._sys_db_stub.CountCollections(
|
||||
request
|
||||
)
|
||||
return response.count
|
||||
except grpc.RpcError as e:
|
||||
logger.error(f"Failed to count collections due to error: {e}")
|
||||
raise InternalError()
|
||||
|
||||
@overrides
|
||||
def get_collection_size(self, id: UUID) -> int:
|
||||
try:
|
||||
request = GetCollectionSizeRequest(id=id.hex)
|
||||
response: GetCollectionSizeResponse = self._sys_db_stub.GetCollectionSize(
|
||||
request
|
||||
)
|
||||
return response.total_records_post_compaction
|
||||
except grpc.RpcError as e:
|
||||
logger.error(f"Failed to get collection {id} size due to error: {e}")
|
||||
raise InternalError()
|
||||
|
||||
@trace_method(
|
||||
"SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION
|
||||
)
|
||||
@overrides
|
||||
def get_collection_with_segments(
|
||||
self, collection_id: UUID
|
||||
) -> CollectionAndSegments:
|
||||
try:
|
||||
request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
|
||||
response: GetCollectionWithSegmentsResponse = (
|
||||
self._sys_db_stub.GetCollectionWithSegments(request)
|
||||
)
|
||||
return CollectionAndSegments(
|
||||
collection=from_proto_collection(response.collection),
|
||||
segments=[from_proto_segment(segment) for segment in response.segments],
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
if e.code() == grpc.StatusCode.NOT_FOUND:
|
||||
raise NotFoundError()
|
||||
logger.error(
|
||||
f"Failed to get collection {collection_id} and its segments due to error: {e}"
|
||||
)
|
||||
raise InternalError()
|
||||
|
||||
    @overrides
    def update_collection(
        self,
        id: UUID,
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
        configuration: OptionalArgument[
            Optional[UpdateCollectionConfiguration]
        ] = Unspecified(),
    ) -> None:
        """Update collection `id` on the remote SysDB.

        Each argument uses the Unspecified() sentinel to distinguish "leave
        unchanged" (the default) from an explicit value; for `metadata`, an
        explicit None means "reset the metadata".

        Raises NotFoundError if the collection does not exist,
        UniqueConstraintError on a name conflict, InternalError for any
        other RPC failure.
        """
        try:
            # Unwrap each sentinel-guarded argument into a concrete value
            # (or None when the argument was not supplied).
            write_name = None
            if name != Unspecified():
                write_name = cast(str, name)

            write_dimension = None
            if dimension != Unspecified():
                write_dimension = cast(Union[int, None], dimension)

            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)

            write_configuration = None
            if configuration != Unspecified():
                write_configuration = cast(
                    Union[UpdateCollectionConfiguration, None], configuration
                )

            request = UpdateCollectionRequest(
                id=id.hex,
                name=write_name,
                dimension=write_dimension,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
                # Configuration travels as a JSON string, mirroring
                # create_collection.
                configuration_json_str=update_collection_configuration_to_json_str(
                    write_configuration
                )
                if write_configuration
                else None,
            )
            # An explicit None for metadata means "reset": clear the proto
            # field and set the flag so the server wipes existing metadata.
            if metadata is None:
                request.ClearField("metadata")
                request.reset_metadata = True

            response = self._sys_db_stub.UpdateCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            # Narrow to grpc.Call so e.code() type-checks.
            e = cast(grpc.Call, e)
            logger.error(
                f"Failed to update collection id {id}, name {name} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()
|
||||
|
||||
def reset_and_wait_for_ready(self) -> None:
|
||||
self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)
|
||||
@@ -0,0 +1,497 @@
|
||||
from concurrent import futures
|
||||
from typing import Any, Dict, List, cast
|
||||
from uuid import UUID
|
||||
from overrides import overrides
|
||||
import json
|
||||
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
|
||||
from chromadb.proto.convert import (
|
||||
from_proto_metadata,
|
||||
from_proto_update_metadata,
|
||||
from_proto_segment,
|
||||
from_proto_segment_scope,
|
||||
to_proto_collection,
|
||||
to_proto_segment,
|
||||
)
|
||||
import chromadb.proto.chroma_pb2 as proto
|
||||
from chromadb.proto.coordinator_pb2 import (
|
||||
CreateCollectionRequest,
|
||||
CreateCollectionResponse,
|
||||
CreateDatabaseRequest,
|
||||
CreateDatabaseResponse,
|
||||
CreateSegmentRequest,
|
||||
CreateSegmentResponse,
|
||||
CreateTenantRequest,
|
||||
CreateTenantResponse,
|
||||
CountCollectionsRequest,
|
||||
CountCollectionsResponse,
|
||||
DeleteCollectionRequest,
|
||||
DeleteCollectionResponse,
|
||||
DeleteSegmentRequest,
|
||||
DeleteSegmentResponse,
|
||||
GetCollectionsRequest,
|
||||
GetCollectionsResponse,
|
||||
GetCollectionSizeRequest,
|
||||
GetCollectionSizeResponse,
|
||||
GetCollectionWithSegmentsRequest,
|
||||
GetCollectionWithSegmentsResponse,
|
||||
GetDatabaseRequest,
|
||||
GetDatabaseResponse,
|
||||
GetSegmentsRequest,
|
||||
GetSegmentsResponse,
|
||||
GetTenantRequest,
|
||||
GetTenantResponse,
|
||||
ResetStateResponse,
|
||||
UpdateCollectionRequest,
|
||||
UpdateCollectionResponse,
|
||||
UpdateSegmentRequest,
|
||||
UpdateSegmentResponse,
|
||||
)
|
||||
from chromadb.proto.coordinator_pb2_grpc import (
|
||||
SysDBServicer,
|
||||
add_SysDBServicer_to_server,
|
||||
)
|
||||
import grpc
|
||||
from google.protobuf.empty_pb2 import Empty
|
||||
from chromadb.types import Collection, Metadata, Segment, SegmentScope
|
||||
|
||||
|
||||
class GrpcMockSysDB(SysDBServicer, Component):
|
||||
"""A mock sysdb implementation that can be used for testing the grpc client. It stores
|
||||
state in simple python data structures instead of a database."""
|
||||
|
||||
_server: grpc.Server
|
||||
_server_port: int
|
||||
_segments: Dict[str, Segment] = {}
|
||||
_collection_to_segments: Dict[str, List[str]] = {}
|
||||
_tenants_to_databases_to_collections: Dict[
|
||||
str, Dict[str, Dict[str, Collection]]
|
||||
] = {}
|
||||
_tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {}
|
||||
|
||||
def __init__(self, system: System):
|
||||
self._server_port = system.settings.require("chroma_server_grpc_port")
|
||||
return super().__init__(system)
|
||||
|
||||
@overrides
|
||||
def start(self) -> None:
|
||||
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
add_SysDBServicer_to_server(self, self._server) # type: ignore
|
||||
self._server.add_insecure_port(f"[::]:{self._server_port}")
|
||||
self._server.start()
|
||||
return super().start()
|
||||
|
||||
@overrides
|
||||
def stop(self) -> None:
|
||||
self._server.stop(None)
|
||||
return super().stop()
|
||||
|
||||
@overrides
|
||||
def reset_state(self) -> None:
|
||||
self._segments = {}
|
||||
self._tenants_to_databases_to_collections = {}
|
||||
# Create defaults
|
||||
self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {}
|
||||
self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {}
|
||||
self._tenants_to_database_to_id[DEFAULT_TENANT] = {}
|
||||
self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0)
|
||||
return super().reset_state()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def CreateDatabase(
|
||||
self, request: CreateDatabaseRequest, context: grpc.ServicerContext
|
||||
) -> CreateDatabaseResponse:
|
||||
tenant = request.tenant
|
||||
database = request.name
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
if database in self._tenants_to_databases_to_collections[tenant]:
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS, f"Database {database} already exists"
|
||||
)
|
||||
self._tenants_to_databases_to_collections[tenant][database] = {}
|
||||
self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
|
||||
return CreateDatabaseResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetDatabase(
|
||||
self, request: GetDatabaseRequest, context: grpc.ServicerContext
|
||||
) -> GetDatabaseResponse:
|
||||
tenant = request.tenant
|
||||
database = request.name
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
if database not in self._tenants_to_databases_to_collections[tenant]:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
|
||||
id = self._tenants_to_database_to_id[tenant][database]
|
||||
return GetDatabaseResponse(
|
||||
database=proto.Database(id=id.hex, name=database, tenant=tenant),
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def CreateTenant(
|
||||
self, request: CreateTenantRequest, context: grpc.ServicerContext
|
||||
) -> CreateTenantResponse:
|
||||
tenant = request.name
|
||||
if tenant in self._tenants_to_databases_to_collections:
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS, f"Tenant {tenant} already exists"
|
||||
)
|
||||
self._tenants_to_databases_to_collections[tenant] = {}
|
||||
self._tenants_to_database_to_id[tenant] = {}
|
||||
return CreateTenantResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetTenant(
|
||||
self, request: GetTenantRequest, context: grpc.ServicerContext
|
||||
) -> GetTenantResponse:
|
||||
tenant = request.name
|
||||
if tenant not in self._tenants_to_databases_to_collections:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
|
||||
return GetTenantResponse(
|
||||
tenant=proto.Tenant(name=tenant),
|
||||
)
|
||||
|
||||
# We are forced to use check_signature=False because the generated proto code
|
||||
# does not have type annotations for the request and response objects.
|
||||
# TODO: investigate generating types for the request and response objects
|
||||
@overrides(check_signature=False)
|
||||
def CreateSegment(
|
||||
self, request: CreateSegmentRequest, context: grpc.ServicerContext
|
||||
) -> CreateSegmentResponse:
|
||||
segment = from_proto_segment(request.segment)
|
||||
return self.CreateSegmentHelper(segment, context)
|
||||
|
||||
def CreateSegmentHelper(
|
||||
self, segment: Segment, context: grpc.ServicerContext
|
||||
) -> CreateSegmentResponse:
|
||||
if segment["id"].hex in self._segments:
|
||||
context.abort(
|
||||
grpc.StatusCode.ALREADY_EXISTS,
|
||||
f"Segment {segment['id']} already exists",
|
||||
)
|
||||
self._segments[segment["id"].hex] = segment
|
||||
return CreateSegmentResponse()
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def DeleteSegment(
|
||||
self, request: DeleteSegmentRequest, context: grpc.ServicerContext
|
||||
) -> DeleteSegmentResponse:
|
||||
id_to_delete = request.id
|
||||
if id_to_delete in self._segments:
|
||||
del self._segments[id_to_delete]
|
||||
return DeleteSegmentResponse()
|
||||
else:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Segment {id_to_delete} not found"
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def GetSegments(
|
||||
self, request: GetSegmentsRequest, context: grpc.ServicerContext
|
||||
) -> GetSegmentsResponse:
|
||||
target_id = UUID(hex=request.id) if request.HasField("id") else None
|
||||
target_type = request.type if request.HasField("type") else None
|
||||
target_scope = (
|
||||
from_proto_segment_scope(request.scope)
|
||||
if request.HasField("scope")
|
||||
else None
|
||||
)
|
||||
target_collection = UUID(hex=request.collection)
|
||||
|
||||
found_segments = []
|
||||
for segment in self._segments.values():
|
||||
if target_id and segment["id"] != target_id:
|
||||
continue
|
||||
if target_type and segment["type"] != target_type:
|
||||
continue
|
||||
if target_scope and segment["scope"] != target_scope:
|
||||
continue
|
||||
if target_collection and segment["collection"] != target_collection:
|
||||
continue
|
||||
found_segments.append(segment)
|
||||
return GetSegmentsResponse(
|
||||
segments=[to_proto_segment(segment) for segment in found_segments]
|
||||
)
|
||||
|
||||
@overrides(check_signature=False)
|
||||
def UpdateSegment(
|
||||
self, request: UpdateSegmentRequest, context: grpc.ServicerContext
|
||||
) -> UpdateSegmentResponse:
|
||||
id_to_update = UUID(request.id)
|
||||
if id_to_update.hex not in self._segments:
|
||||
context.abort(
|
||||
grpc.StatusCode.NOT_FOUND, f"Segment {id_to_update} not found"
|
||||
)
|
||||
else:
|
||||
segment = self._segments[id_to_update.hex]
|
||||
if request.HasField("metadata"):
|
||||
target = cast(Dict[str, Any], segment["metadata"])
|
||||
if segment["metadata"] is None:
|
||||
segment["metadata"] = {}
|
||||
self._merge_metadata(target, request.metadata)
|
||||
if request.HasField("reset_metadata") and request.reset_metadata:
|
||||
segment["metadata"] = {}
|
||||
return UpdateSegmentResponse()
|
||||
|
||||
    @overrides(check_signature=False)
    def CreateCollection(
        self, request: CreateCollectionRequest, context: grpc.ServicerContext
    ) -> CreateCollectionResponse:
        """Create a collection (and its segments) in the in-memory mock store.

        Mirrors real SysDB semantics: aborts NOT_FOUND for an unknown
        tenant/database, ALREADY_EXISTS for id/name collisions (unless
        `get_or_create`), and rolls back any segments created before a
        conflicting one is hit.
        """
        collection_name = request.name
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")

        # Check if the collection already exists globally by id
        for (
            search_tenant,
            databases,
        ) in self._tenants_to_databases_to_collections.items():
            for search_database, search_collections in databases.items():
                if request.id in search_collections:
                    if (
                        search_tenant != request.tenant
                        or search_database != request.database
                    ):
                        # Same id in a *different* tenant/database is always a
                        # conflict, regardless of get_or_create.
                        context.abort(
                            grpc.StatusCode.ALREADY_EXISTS,
                            f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    elif not request.get_or_create:
                        # If the id exists for this tenant and database, and we are not doing a get_or_create, then
                        # we should return an already exists error
                        context.abort(
                            grpc.StatusCode.ALREADY_EXISTS,
                            f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )

        # Check if the collection already exists in this database by name
        collections = self._tenants_to_databases_to_collections[tenant][database]
        matches = [c for c in collections.values() if c["name"] == collection_name]
        assert len(matches) <= 1
        if len(matches) > 0:
            if request.get_or_create:
                existing_collection = matches[0]
                return CreateCollectionResponse(
                    collection=to_proto_collection(existing_collection),
                    created=False,
                )
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS,
                f"Collection {collection_name} already exists",
            )

        # The configuration travels as a JSON string; parse it for storage.
        configuration_json = json.loads(request.configuration_json_str)

        id = UUID(hex=request.id)
        new_collection = Collection(
            id=id,
            name=request.name,
            configuration_json=configuration_json,
            serialized_schema=None,
            metadata=from_proto_metadata(request.metadata),
            dimension=request.dimension,
            database=database,
            tenant=tenant,
            version=0,
        )

        # Check that segments are unique and do not already exist
        # Keep a track of the segments that are being added
        segments_added: List[str] = []
        # Create segments for the collection
        for segment_proto in request.segments:
            segment = from_proto_segment(segment_proto)
            if segment["id"].hex in self._segments:
                # Remove the already added segment since we need to roll back
                for s in segments_added:
                    self.DeleteSegment(DeleteSegmentRequest(id=s), context)
                context.abort(
                    grpc.StatusCode.ALREADY_EXISTS,
                    f"Segment {segment['id']} already exists",
                )
            self.CreateSegmentHelper(segment, context)
            segments_added.append(segment["id"].hex)

        collections[request.id] = new_collection
        # Segments are tracked per (tenant, database, collection) so that
        # DeleteCollection can remove them later.
        collection_unique_key = f"{tenant}:{database}:{request.id}"
        self._collection_to_segments[collection_unique_key] = segments_added
        return CreateCollectionResponse(
            collection=to_proto_collection(new_collection),
            created=True,
        )
|
||||
|
||||
    @overrides(check_signature=False)
    def DeleteCollection(
        self, request: DeleteCollectionRequest, context: grpc.ServicerContext
    ) -> DeleteCollectionResponse:
        """Delete a collection and its tracked segments from the mock store.

        Aborts with NOT_FOUND when the tenant, database, or collection is
        unknown.
        """
        collection_id = request.id
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
        collections = self._tenants_to_databases_to_collections[tenant][database]
        if collection_id in collections:
            del collections[collection_id]
            # Segments were recorded under this composite key at creation time.
            collection_unique_key = f"{tenant}:{database}:{collection_id}"
            segment_ids = self._collection_to_segments[collection_unique_key]
            if segment_ids:  # Delete segments if provided.
                for segment_id in segment_ids:
                    del self._segments[segment_id]
            return DeleteCollectionResponse()
        else:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Collection {collection_id} not found"
            )
|
||||
|
||||
@overrides(check_signature=False)
def GetCollections(
    self, request: GetCollectionsRequest, context: grpc.ServicerContext
) -> GetCollectionsResponse:
    """List collections, optionally filtered by id, name, tenant and database.

    An empty `tenant` / `database` on the request means "match any"; `id` and
    `name` are optional proto fields checked via HasField.
    """
    target_id = UUID(hex=request.id) if request.HasField("id") else None
    target_name = request.name if request.HasField("name") else None

    all_collections = {}
    for tenant, databases in self._tenants_to_databases_to_collections.items():
        for database, collections in databases.items():
            if request.tenant != "" and tenant != request.tenant:
                continue
            if request.database != "" and database != request.database:
                continue
            all_collections.update(collections)
            # Fix: removed leftover debug print of every tenant/database's
            # collections on each request.
    found_collections = []
    for collection in all_collections.values():
        if target_id and collection["id"] != target_id:
            continue
        if target_name and collection["name"] != target_name:
            continue
        found_collections.append(collection)
    return GetCollectionsResponse(
        collections=[
            to_proto_collection(collection) for collection in found_collections
        ]
    )
|
||||
|
||||
@overrides(check_signature=False)
def CountCollections(
    self, request: CountCollectionsRequest, context: grpc.ServicerContext
) -> CountCollectionsResponse:
    """Count collections for a tenant/database by delegating to GetCollections.

    Fix: no longer rebinds the `request` parameter — a shadowed parameter name
    obscures which message is in play and invites subtle bugs.
    """
    list_request = GetCollectionsRequest(
        tenant=request.tenant,
        database=request.database,
    )
    collections = self.GetCollections(list_request, context)
    return CountCollectionsResponse(count=len(collections.collections))
|
||||
|
||||
@overrides(check_signature=False)
def GetCollectionSize(
    self, request: GetCollectionSizeRequest, context: grpc.ServicerContext
) -> GetCollectionSizeResponse:
    """Stub: this mock does not track record counts, so it always reports 0."""
    response = GetCollectionSizeResponse(
        total_records_post_compaction=0,
    )
    return response
|
||||
|
||||
@overrides(check_signature=False)
def GetCollectionWithSegments(
    self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
) -> GetCollectionWithSegmentsResponse:
    """Fetch a collection together with its segments.

    Aborts NOT_FOUND if the id is unknown, and INTERNAL if the collection does
    not have exactly the metadata/record/vector segment scopes.
    """
    all_collections = {}
    for databases in self._tenants_to_databases_to_collections.values():
        for collections in databases.values():
            all_collections.update(collections)
            # Fix: removed leftover debug print that dumped every
            # tenant/database's collections on each request.
    collection = all_collections.get(request.id, None)
    if collection is None:
        context.abort(
            grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found"
        )
    # NOTE(review): attribute access (collection.tenant) here vs dict-style
    # access (collection["id"]) elsewhere — assumes the Collection model
    # supports both; confirm.
    collection_unique_key = (
        f"{collection.tenant}:{collection.database}:{request.id}"
    )
    segments = [
        self._segments[id]
        for id in self._collection_to_segments[collection_unique_key]
    ]
    # A well-formed collection has exactly one segment per scope.
    if {segment["scope"] for segment in segments} != {
        SegmentScope.METADATA,
        SegmentScope.RECORD,
        SegmentScope.VECTOR,
    }:
        context.abort(
            grpc.StatusCode.INTERNAL,
            f"Incomplete segments for collection {collection}: {segments}",
        )

    return GetCollectionWithSegmentsResponse(
        collection=to_proto_collection(collection),
        segments=[to_proto_segment(segment) for segment in segments],
    )
|
||||
|
||||
@overrides(check_signature=False)
def UpdateCollection(
    self, request: UpdateCollectionRequest, context: grpc.ServicerContext
) -> UpdateCollectionResponse:
    """Apply name/dimension/metadata updates to a collection, wherever it lives."""
    id_to_update = UUID(request.id)
    # Scan every tenant/database for the map that owns this collection id.
    collections = {}
    for databases in self._tenants_to_databases_to_collections.values():
        for maybe_collections in databases.values():
            if id_to_update.hex in maybe_collections:
                collections = maybe_collections

    if id_to_update.hex not in collections:
        context.abort(
            grpc.StatusCode.NOT_FOUND, f"Collection {id_to_update} not found"
        )
    else:
        collection = collections[id_to_update.hex]
        if request.HasField("name"):
            collection["name"] = request.name
        if request.HasField("dimension"):
            collection["dimension"] = request.dimension
        if request.HasField("metadata"):
            # TODO: IN SysDB SQlite we have technical debt where we
            # replace the entire metadata dict with the new one. We should
            # fix that by merging it. For now we just do the same thing here:
            # the whole dict is replaced, dropping keys whose value is None.
            update_metadata = from_proto_update_metadata(request.metadata)
            cleaned_metadata = None
            if update_metadata is not None:
                cleaned_metadata = {
                    key: value
                    for key, value in update_metadata.items()
                    if value is not None
                }
            collection["metadata"] = cleaned_metadata
        elif request.HasField("reset_metadata"):
            if request.reset_metadata:
                collection["metadata"] = {}

    return UpdateCollectionResponse()
|
||||
|
||||
@overrides(check_signature=False)
def ResetState(
    self, request: Empty, context: grpc.ServicerContext
) -> ResetStateResponse:
    """Reset this servicer's in-memory state by delegating to reset_state()."""
    self.reset_state()
    return ResetStateResponse()
|
||||
|
||||
def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
    """Merge `source` into `target` in place.

    A key whose update value is None is treated as a deletion; every other
    key is set to its new value.
    """
    target_metadata = cast(Dict[str, Any], target)
    source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source))
    for key, value in source_metadata.items():
        if value is None:
            # Deletion marker: drop the key (if present) instead of storing None.
            target_metadata.pop(key, None)
        else:
            target_metadata[key] = value
|
||||
@@ -0,0 +1,273 @@
|
||||
import logging
|
||||
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
|
||||
from chromadb.db.migrations import MigratableDB, Migration
|
||||
from chromadb.config import System, Settings
|
||||
import chromadb.db.base as base
|
||||
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
|
||||
from chromadb.db.mixins.sysdb import SqlSysDB
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryClient,
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
import sqlite3
|
||||
from overrides import override
|
||||
import pypika
|
||||
from typing import Sequence, cast, Optional, Type, Any
|
||||
from typing_extensions import Literal
|
||||
from types import TracebackType
|
||||
import os
|
||||
from uuid import UUID
|
||||
from threading import local
|
||||
from importlib_resources import files
|
||||
from importlib_resources.abc import Traversable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TxWrapper(base.TxWrapper):
    """Context manager yielding a cursor inside a (possibly nested) transaction.

    A per-thread stack tracks nesting: only the outermost __enter__ issues
    BEGIN, and only the outermost __exit__ commits or rolls back, so nested
    ``with db.tx()`` blocks share a single underlying sqlite transaction.
    """

    _conn: Connection  # pooled connection held for the lifetime of this wrapper
    _pool: Pool  # pool the connection is handed back to on exit

    def __init__(self, conn_pool: Pool, stack: local):
        self._tx_stack = stack
        self._conn = conn_pool.connect()
        self._pool = conn_pool

    @override
    def __enter__(self) -> base.Cursor:
        # Begin a real transaction only at the outermost nesting level.
        if len(self._tx_stack.stack) == 0:
            self._conn.execute("PRAGMA case_sensitive_like = ON")
            self._conn.execute("BEGIN;")
        self._tx_stack.stack.append(self)
        return self._conn.cursor()  # type: ignore

    @override
    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        self._tx_stack.stack.pop()
        # Commit or roll back only when the outermost wrapper exits; inner
        # levels just unwind the stack.
        if len(self._tx_stack.stack) == 0:
            if exc_type is None:
                self._conn.commit()
            else:
                self._conn.rollback()
            self._conn.cursor().close()
        # Every wrapper (nested or not) returns its connection reference.
        self._pool.return_to_pool(self._conn)
        return False  # never swallow the exception
|
||||
|
||||
|
||||
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
    """Sqlite-backed system DB, embeddings queue, and migration engine.

    Non-persistent mode uses a shared-cache in-memory database guarded by a
    LockPool; persistent mode uses a file under the configured persist
    directory with one connection per thread.
    """

    _conn_pool: Pool
    _settings: Settings
    _migration_imports: Sequence[Traversable]
    _db_file: str
    _tx_stack: local  # per-thread stack of open TxWrapper transactions
    _is_persistent: bool

    def __init__(self, system: System):
        self._settings = system.settings
        # Migration SQL ships as package resources, one directory per subsystem.
        self._migration_imports = [
            files("chromadb.migrations.embeddings_queue"),
            files("chromadb.migrations.sysdb"),
            files("chromadb.migrations.metadb"),
        ]
        self._is_persistent = self._settings.require("is_persistent")
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        if not self._is_persistent:
            # In order to allow sqlite to be shared between multiple threads, we need to use a
            # URI connection string with shared cache.
            # See https://www.sqlite.org/sharedcache.html
            # https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
            self._db_file = "file::memory:?cache=shared"
            self._conn_pool = LockPool(self._db_file, is_uri=True)
        else:
            self._db_file = (
                self._settings.require("persist_directory") + "/chroma.sqlite3"
            )
            if not os.path.exists(self._db_file):
                os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
            self._conn_pool = PerThreadPool(self._db_file)
        self._tx_stack = local()
        super().__init__(system)

    @trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL)
    @override
    def start(self) -> None:
        """Open the database, set connection pragmas, and run pending migrations."""
        super().start()
        with self.tx() as cur:
            cur.execute("PRAGMA foreign_keys = ON")
            cur.execute("PRAGMA case_sensitive_like = ON")
        self.initialize_migrations()

        if (
            # (don't attempt to access .config if migrations haven't been run)
            self._settings.require("migrations") == "apply"
            and self.config.get_parameter("automatically_purge").value is False
        ):
            logger.warning(
                "⚠️ It looks like you upgraded from a version below 0.5.6 and could benefit from vacuuming your database. Run chromadb utils vacuum --help for more information."
            )

    @trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL)
    @override
    def stop(self) -> None:
        """Stop the component and close every pooled connection."""
        super().stop()
        self._conn_pool.close()

    @staticmethod
    @override
    def querybuilder() -> Type[pypika.Query]:
        return pypika.Query  # type: ignore

    @staticmethod
    @override
    def parameter_format() -> str:
        # Sqlite uses positional '?' placeholders.
        return "?"

    @staticmethod
    @override
    def migration_scope() -> str:
        return "sqlite"

    @override
    def migration_dirs(self) -> Sequence[Traversable]:
        return self._migration_imports

    @override
    def tx(self) -> TxWrapper:
        """Return a transaction context manager bound to this thread's tx stack."""
        if not hasattr(self._tx_stack, "stack"):
            self._tx_stack.stack = []
        return TxWrapper(self._conn_pool, stack=self._tx_stack)

    @trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL)
    @override
    def reset_state(self) -> None:
        """Drop every table, then restart so migrations recreate the schema."""
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        with self.tx() as cur:
            # Drop all tables
            cur.execute(
                """
                SELECT name FROM sqlite_master
                WHERE type='table'
                """
            )
            for row in cur.fetchall():
                cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
        self._conn_pool.close()
        self.start()
        super().reset_state()

    @trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL)
    @override
    def setup_migrations(self) -> None:
        """Create the migrations bookkeeping table if it does not exist."""
        with self.tx() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS migrations (
                    dir TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    filename TEXT NOT NULL,
                    sql TEXT NOT NULL,
                    hash TEXT NOT NULL,
                    PRIMARY KEY (dir, version)
                )
                """
            )

    @trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL)
    @override
    def migrations_initialized(self) -> bool:
        """Return True iff the migrations bookkeeping table exists."""
        with self.tx() as cur:
            cur.execute(
                """SELECT count(*) FROM sqlite_master
                   WHERE type='table' AND name='migrations'"""
            )
            return cur.fetchone()[0] != 0

    @trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL)
    @override
    def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
        """Return the migrations already applied for `dir`, oldest first."""
        with self.tx() as cur:
            cur.execute(
                """
                SELECT dir, version, filename, sql, hash
                FROM migrations
                WHERE dir = ?
                ORDER BY version ASC
                """,
                (dir.name,),
            )

            migrations = []
            for row in cur.fetchall():
                found_dir = cast(str, row[0])
                found_version = cast(int, row[1])
                found_filename = cast(str, row[2])
                found_sql = cast(str, row[3])
                found_hash = cast(str, row[4])
                migrations.append(
                    Migration(
                        dir=found_dir,
                        version=found_version,
                        filename=found_filename,
                        sql=found_sql,
                        hash=found_hash,
                        scope=self.migration_scope(),
                    )
                )
            return migrations

    @override
    def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
        """Run a migration's SQL and record it in the bookkeeping table."""
        cur.executescript(migration["sql"])
        cur.execute(
            """
            INSERT INTO migrations (dir, version, filename, sql, hash)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                migration["dir"],
                migration["version"],
                migration["filename"],
                migration["sql"],
                migration["hash"],
            ),
        )

    @staticmethod
    @override
    def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
        return UUID(value) if value is not None else None

    @staticmethod
    @override
    def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
        return str(uuid) if uuid is not None else None

    @staticmethod
    @override
    def unique_constraint_error() -> Type[BaseException]:
        return sqlite3.IntegrityError

    def vacuum(self, timeout: int = 5) -> None:
        """Runs VACUUM on the database. `timeout` is the maximum time to wait for an exclusive lock in seconds."""
        conn = self._conn_pool.connect()
        try:
            conn.execute(f"PRAGMA busy_timeout = {int(timeout) * 1000}")
            conn.execute("VACUUM")
            conn.execute(
                """
                INSERT INTO maintenance_log (operation, timestamp)
                VALUES ('vacuum', CURRENT_TIMESTAMP)
                """
            )
        finally:
            # Fix: always hand the connection back. With LockPool, leaking it
            # leaves the pool's lock held forever and deadlocks every
            # subsequent database access on other threads.
            self._conn_pool.return_to_pool(conn)
|
||||
@@ -0,0 +1,163 @@
|
||||
import sqlite3
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Set
|
||||
import threading
|
||||
from overrides import override
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class Connection:
    """Thin wrapper around a sqlite3 connection owned by a Pool.

    Commits are always explicit (autocommit is disabled via
    isolation_level = None), and the real socket/file handle is only torn
    down by close_actual().
    """

    _pool: "Pool"
    _db_file: str
    _conn: sqlite3.Connection

    def __init__(
        self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
    ):
        self._pool = pool
        self._db_file = db_file
        self._conn = sqlite3.connect(
            db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs
        )  # type: ignore
        # Commits are handled explicitly by callers, never implicitly.
        self._conn.isolation_level = None

    def execute(self, sql: str, parameters=...) -> sqlite3.Cursor:  # type: ignore
        # Ellipsis is the "no parameters supplied" sentinel, so an explicit
        # empty tuple is still forwarded to sqlite.
        if parameters is Ellipsis:
            return self._conn.execute(sql)
        return self._conn.execute(sql, parameters)

    def commit(self) -> None:
        self._conn.commit()

    def rollback(self) -> None:
        self._conn.rollback()

    def cursor(self) -> sqlite3.Cursor:
        return self._conn.cursor()

    def close_actual(self) -> None:
        """Actually closes the connection to the db"""
        self._conn.close()
|
||||
|
||||
|
||||
class Pool(ABC):
    """Abstract base class for a pool of connections to a sqlite database."""

    @abstractmethod
    def __init__(self, db_file: str, is_uri: bool) -> None:
        """Create a pool for `db_file`; `is_uri` enables sqlite URI filenames."""
        pass

    @abstractmethod
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return a connection from the pool."""
        pass

    @abstractmethod
    def close(self) -> None:
        """Close all connections in the pool."""
        pass

    @abstractmethod
    def return_to_pool(self, conn: Connection) -> None:
        """Return a connection to the pool."""
        pass
|
||||
|
||||
|
||||
class LockPool(Pool):
    """A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time.
    This is used because sqlite does not support multithreaded access with connection timeouts when using the
    shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]  # weak refs to every connection created, for close()
    _lock: threading.RLock  # serializes all database use across threads
    _connection: threading.local  # per-thread cached Connection
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.RLock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        # NOTE: the lock is acquired here and only released by
        # return_to_pool() or close() — every connect() must be paired with a
        # return_to_pool(). The RLock makes re-entrant connect() on the same
        # thread safe.
        self._lock.acquire()
        if hasattr(self._connection, "conn") and self._connection.conn is not None:
            return self._connection.conn  # type: ignore # cast doesn't work here for some reason
        else:
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
            self._connection.conn = new_connection
            # Hold only a weak reference so the pool never keeps dead
            # connections alive.
            self._connections.add(weakref.ref(new_connection))
            return new_connection

    @override
    def return_to_pool(self, conn: Connection) -> None:
        # Releases the serialization lock; the thread-local connection object
        # is kept for reuse. RuntimeError means this thread did not hold the
        # lock, which is tolerated deliberately.
        try:
            self._lock.release()
        except RuntimeError:
            pass

    @override
    def close(self) -> None:
        # Tear down every still-live connection and discard thread-local
        # caches; also drop the lock if this thread happens to hold it.
        for conn in self._connections:
            if conn() is not None:
                conn().close_actual()  # type: ignore
        self._connections.clear()
        self._connection = threading.local()
        try:
            self._lock.release()
        except RuntimeError:
            pass
|
||||
|
||||
|
||||
class PerThreadPool(Pool):
    """Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be
    extended to do so and block on connect() if the cap is reached.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]  # weak refs to all vended connections, for close()
    _lock: threading.Lock  # guards _connections only, not database access
    _connection: threading.local  # holds this thread's cached Connection
    _db_file: str
    # Fix: annotation was declared as `_is_uri_` while the attribute actually
    # assigned and read is `_is_uri`.
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.Lock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return this thread's connection, creating it on first use."""
        if hasattr(self._connection, "conn") and self._connection.conn is not None:
            return self._connection.conn  # type: ignore # cast doesn't work here for some reason
        else:
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
            self._connection.conn = new_connection
            with self._lock:
                self._connections.add(weakref.ref(new_connection))
            return new_connection

    @override
    def close(self) -> None:
        """Close every still-live connection ever vended by this pool."""
        with self._lock:
            for conn in self._connections:
                if conn() is not None:
                    conn().close_actual()  # type: ignore
            self._connections.clear()
        self._connection = threading.local()

    @override
    def return_to_pool(self, conn: Connection) -> None:
        pass  # Each thread gets its own connection, so we don't need to return it to the pool
|
||||
Reference in New Issue
Block a user