chore: 添加虚拟环境到仓库

- 添加 backend_service/venv 虚拟环境
- 包含所有Python依赖包
- 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

View File

@@ -0,0 +1,558 @@
from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
from overrides import overrides
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
create_collection_configuration_to_json_str,
UpdateCollectionConfiguration,
update_collection_configuration_to_json_str,
CollectionMetadata,
)
from chromadb.api.types import Schema
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
from chromadb.db.system import SysDB
from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError
from chromadb.proto.convert import (
from_proto_collection,
from_proto_segment,
to_proto_update_metadata,
to_proto_segment,
to_proto_segment_scope,
)
from chromadb.proto.coordinator_pb2 import (
CreateCollectionRequest,
CreateDatabaseRequest,
CreateSegmentRequest,
CreateTenantRequest,
CountCollectionsRequest,
CountCollectionsResponse,
DeleteCollectionRequest,
DeleteDatabaseRequest,
DeleteSegmentRequest,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionSizeRequest,
GetCollectionSizeResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
GetSegmentsRequest,
GetTenantRequest,
ListDatabasesRequest,
UpdateCollectionRequest,
UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
Collection,
CollectionAndSegments,
Database,
Metadata,
OptionalArgument,
Segment,
SegmentScope,
Tenant,
Unspecified,
UpdateMetadata,
)
from google.protobuf.empty_pb2 import Empty
import grpc
class GrpcSysDB(SysDB):
    """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
    called the 'Coordinator'. This implementation is used by Chroma frontend servers
    to call a remote SysDB (Coordinator) service."""

    _sys_db_stub: SysDBStub
    _channel: grpc.Channel
    _coordinator_url: str
    _coordinator_port: int
    _request_timeout_seconds: int

    def __init__(self, system: System):
        self._coordinator_url = system.settings.require("chroma_coordinator_host")
        # TODO: break out coordinator_port into a separate setting?
        self._coordinator_port = system.settings.require("chroma_server_grpc_port")
        self._request_timeout_seconds = system.settings.require(
            "chroma_sysdb_request_timeout_seconds"
        )
        return super().__init__(system)

    @overrides
    def start(self) -> None:
        """Open the gRPC channel to the coordinator and build the stub."""
        self._channel = grpc.insecure_channel(
            f"{self._coordinator_url}:{self._coordinator_port}",
            options=[("grpc.max_concurrent_streams", 1000)],
        )
        interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
        self._channel = grpc.intercept_channel(self._channel, *interceptors)
        self._sys_db_stub = SysDBStub(self._channel)  # type: ignore
        return super().start()

    @overrides
    def stop(self) -> None:
        self._channel.close()
        return super().stop()

    @overrides
    def reset_state(self) -> None:
        self._sys_db_stub.ResetState(Empty())
        return super().reset_state()

    @overrides
    def create_database(self, id: UUID, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a database; raises UniqueConstraintError if it already exists."""
        try:
            request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
            self._sys_db_stub.CreateDatabase(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to create database name {name} and database id {id} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Fetch a database by name; raises NotFoundError if missing."""
        try:
            request = GetDatabaseRequest(name=name, tenant=tenant)
            response = self._sys_db_stub.GetDatabase(
                request, timeout=self._request_timeout_seconds
            )
            return Database(
                id=UUID(hex=response.database.id),
                name=response.database.name,
                tenant=response.database.tenant,
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to get database {name} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Delete a database by name; raises NotFoundError if missing."""
        try:
            request = DeleteDatabaseRequest(name=name, tenant=tenant)
            self._sys_db_stub.DeleteDatabase(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to delete database {name} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            # Fix: raise an InternalError *instance* (the original raised the
            # bare class here, inconsistent with every other raise in this file).
            raise InternalError()

    @overrides
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        """List databases for a tenant with optional pagination."""
        try:
            request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
            response = self._sys_db_stub.ListDatabases(
                request, timeout=self._request_timeout_seconds
            )
            return [
                Database(
                    id=UUID(hex=proto_database.id),
                    name=proto_database.name,
                    tenant=proto_database.tenant,
                )
                for proto_database in response.databases
            ]
        except grpc.RpcError as e:
            logger.info(
                f"Failed to list databases for tenant {tenant} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def create_tenant(self, name: str) -> None:
        """Create a tenant; raises UniqueConstraintError if it already exists."""
        try:
            request = CreateTenantRequest(name=name)
            self._sys_db_stub.CreateTenant(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to create tenant {name} due to error: {e}")
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def get_tenant(self, name: str) -> Tenant:
        """Fetch a tenant by name; raises NotFoundError if missing."""
        try:
            request = GetTenantRequest(name=name)
            response = self._sys_db_stub.GetTenant(
                request, timeout=self._request_timeout_seconds
            )
            return Tenant(
                name=response.tenant.name,
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to get tenant {name} due to error: {e}")
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def create_segment(self, segment: Segment) -> None:
        """Register a new segment; raises UniqueConstraintError on duplicates."""
        try:
            proto_segment = to_proto_segment(segment)
            request = CreateSegmentRequest(
                segment=proto_segment,
            )
            self._sys_db_stub.CreateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to create segment {segment}, error: {e}")
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def delete_segment(self, collection: UUID, id: UUID) -> None:
        """Delete a segment belonging to a collection; NotFoundError if missing."""
        try:
            request = DeleteSegmentRequest(
                id=id.hex,
                collection=collection.hex,
            )
            self._sys_db_stub.DeleteSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to delete segment with id {id} for collection {collection} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def get_segments(
        self,
        collection: UUID,
        id: Optional[UUID] = None,
        type: Optional[str] = None,
        scope: Optional[SegmentScope] = None,
    ) -> Sequence[Segment]:
        """Fetch segments for a collection, optionally filtered by id/type/scope."""
        try:
            request = GetSegmentsRequest(
                id=id.hex if id else None,
                type=type,
                scope=to_proto_segment_scope(scope) if scope else None,
                collection=collection.hex,
            )
            response = self._sys_db_stub.GetSegments(
                request, timeout=self._request_timeout_seconds
            )
            return [
                from_proto_segment(proto_segment)
                for proto_segment in response.segments
            ]
        except grpc.RpcError as e:
            logger.info(
                f"Failed to get segment id {id}, type {type}, scope {scope} for collection {collection} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def update_segment(
        self,
        collection: UUID,
        id: UUID,
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update a segment's metadata.

        Passing ``metadata=None`` explicitly clears the metadata (the request's
        reset_metadata flag is set); leaving it Unspecified changes nothing.
        """
        try:
            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)
            request = UpdateSegmentRequest(
                id=id.hex,
                collection=collection.hex,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
            )
            if metadata is None:
                request.ClearField("metadata")
                request.reset_metadata = True
            self._sys_db_stub.UpdateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to update segment with id {id} for collection {collection}, error: {e}"
            )
            raise InternalError()

    @overrides
    def create_collection(
        self,
        id: UUID,
        name: str,
        schema: Optional[Schema],
        configuration: CreateCollectionConfiguration,
        segments: Sequence[Segment],
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple[Collection, bool]:
        """Create a collection (atomically with its segments).

        Returns (collection, created) where created is False when
        get_or_create found an existing collection.
        """
        try:
            request = CreateCollectionRequest(
                id=id.hex,
                name=name,
                configuration_json_str=create_collection_configuration_to_json_str(
                    configuration, cast(CollectionMetadata, metadata)
                ),
                metadata=to_proto_update_metadata(metadata) if metadata else None,
                dimension=dimension,
                get_or_create=get_or_create,
                tenant=tenant,
                database=database,
                segments=[to_proto_segment(segment) for segment in segments],
            )
            response = self._sys_db_stub.CreateCollection(
                request, timeout=self._request_timeout_seconds
            )
            collection = from_proto_collection(response.collection)
            return collection, response.created
        except grpc.RpcError as e:
            logger.error(
                f"Failed to create collection id {id}, name {name} for database {database} and tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def delete_collection(
        self,
        id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete a collection by id; raises NotFoundError if missing."""
        try:
            request = DeleteCollectionRequest(
                id=id.hex,
                tenant=tenant,
                database=database,
            )
            self._sys_db_stub.DeleteCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.error(
                f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
            )
            e = cast(grpc.Call, e)
            logger.error(
                f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def get_collections(
        self,
        id: Optional[UUID] = None,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        """Query collections by id, by (name, tenant, database), or list all."""
        try:
            # TODO: implement limit and offset in the gRPC service
            request = None
            if id is not None:
                request = GetCollectionsRequest(
                    id=id.hex,
                    limit=limit,
                    offset=offset,
                )
            if name is not None:
                if tenant is None and database is None:
                    raise ValueError(
                        "If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
                    )
                request = GetCollectionsRequest(
                    name=name,
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            if id is None and name is None:
                request = GetCollectionsRequest(
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            response: GetCollectionsResponse = self._sys_db_stub.GetCollections(
                request, timeout=self._request_timeout_seconds
            )
            return [
                from_proto_collection(collection)
                for collection in response.collections
            ]
        except grpc.RpcError as e:
            logger.error(
                f"Failed to get collections with id {id}, name {name}, tenant {tenant}, database {database} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: Optional[str] = None,
    ) -> int:
        """Count collections for a tenant, optionally restricted to a database."""
        try:
            # An empty/None database means "count across all databases".
            if database is None or database == "":
                request = CountCollectionsRequest(tenant=tenant)
            else:
                request = CountCollectionsRequest(
                    tenant=tenant,
                    database=database,
                )
            # Consistency fix: apply the request timeout like every other RPC here.
            response: CountCollectionsResponse = self._sys_db_stub.CountCollections(
                request, timeout=self._request_timeout_seconds
            )
            return response.count
        except grpc.RpcError as e:
            logger.error(f"Failed to count collections due to error: {e}")
            raise InternalError()

    @overrides
    def get_collection_size(self, id: UUID) -> int:
        """Return the post-compaction record count for a collection."""
        try:
            request = GetCollectionSizeRequest(id=id.hex)
            # Consistency fix: apply the request timeout like every other RPC here.
            response: GetCollectionSizeResponse = self._sys_db_stub.GetCollectionSize(
                request, timeout=self._request_timeout_seconds
            )
            return response.total_records_post_compaction
        except grpc.RpcError as e:
            logger.error(f"Failed to get collection {id} size due to error: {e}")
            raise InternalError()

    @trace_method(
        "SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION
    )
    @overrides
    def get_collection_with_segments(
        self, collection_id: UUID
    ) -> CollectionAndSegments:
        """Fetch a collection together with all of its segments in one RPC."""
        try:
            request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
            # Consistency fix: apply the request timeout like every other RPC here.
            response: GetCollectionWithSegmentsResponse = (
                self._sys_db_stub.GetCollectionWithSegments(
                    request, timeout=self._request_timeout_seconds
                )
            )
            return CollectionAndSegments(
                collection=from_proto_collection(response.collection),
                segments=[from_proto_segment(segment) for segment in response.segments],
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            logger.error(
                f"Failed to get collection {collection_id} and its segments due to error: {e}"
            )
            raise InternalError()

    @overrides
    def update_collection(
        self,
        id: UUID,
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
        configuration: OptionalArgument[
            Optional[UpdateCollectionConfiguration]
        ] = Unspecified(),
    ) -> None:
        """Update any subset of a collection's mutable fields.

        Unspecified() arguments are left untouched; ``metadata=None`` clears
        the metadata via the reset_metadata flag.
        """
        try:
            write_name = None
            if name != Unspecified():
                write_name = cast(str, name)
            write_dimension = None
            if dimension != Unspecified():
                write_dimension = cast(Union[int, None], dimension)
            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)
            write_configuration = None
            if configuration != Unspecified():
                write_configuration = cast(
                    Union[UpdateCollectionConfiguration, None], configuration
                )
            request = UpdateCollectionRequest(
                id=id.hex,
                name=write_name,
                dimension=write_dimension,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
                configuration_json_str=update_collection_configuration_to_json_str(
                    write_configuration
                )
                if write_configuration
                else None,
            )
            if metadata is None:
                request.ClearField("metadata")
                request.reset_metadata = True
            self._sys_db_stub.UpdateCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            e = cast(grpc.Call, e)
            logger.error(
                f"Failed to update collection id {id}, name {name} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    def reset_and_wait_for_ready(self) -> None:
        """Reset remote state, blocking until the channel is ready."""
        self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)

View File

@@ -0,0 +1,497 @@
from concurrent import futures
from typing import Any, Dict, List, cast
from uuid import UUID
from overrides import overrides
import json
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
from chromadb.proto.convert import (
from_proto_metadata,
from_proto_update_metadata,
from_proto_segment,
from_proto_segment_scope,
to_proto_collection,
to_proto_segment,
)
import chromadb.proto.chroma_pb2 as proto
from chromadb.proto.coordinator_pb2 import (
CreateCollectionRequest,
CreateCollectionResponse,
CreateDatabaseRequest,
CreateDatabaseResponse,
CreateSegmentRequest,
CreateSegmentResponse,
CreateTenantRequest,
CreateTenantResponse,
CountCollectionsRequest,
CountCollectionsResponse,
DeleteCollectionRequest,
DeleteCollectionResponse,
DeleteSegmentRequest,
DeleteSegmentResponse,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionSizeRequest,
GetCollectionSizeResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
GetDatabaseResponse,
GetSegmentsRequest,
GetSegmentsResponse,
GetTenantRequest,
GetTenantResponse,
ResetStateResponse,
UpdateCollectionRequest,
UpdateCollectionResponse,
UpdateSegmentRequest,
UpdateSegmentResponse,
)
from chromadb.proto.coordinator_pb2_grpc import (
SysDBServicer,
add_SysDBServicer_to_server,
)
import grpc
from google.protobuf.empty_pb2 import Empty
from chromadb.types import Collection, Metadata, Segment, SegmentScope
class GrpcMockSysDB(SysDBServicer, Component):
    """A mock sysdb implementation that can be used for testing the grpc client. It stores
    state in simple python data structures instead of a database."""

    _server: grpc.Server
    _server_port: int
    # State containers. These are created per-instance in __init__; the
    # original code used class-level mutable defaults, which are shared
    # across all instances of the class.
    _segments: Dict[str, Segment]
    _collection_to_segments: Dict[str, List[str]]
    _tenants_to_databases_to_collections: Dict[str, Dict[str, Dict[str, Collection]]]
    _tenants_to_database_to_id: Dict[str, Dict[str, UUID]]

    def __init__(self, system: System):
        self._server_port = system.settings.require("chroma_server_grpc_port")
        self._segments = {}
        self._collection_to_segments = {}
        self._tenants_to_databases_to_collections = {}
        self._tenants_to_database_to_id = {}
        return super().__init__(system)

    @overrides
    def start(self) -> None:
        """Start an insecure gRPC server serving this mock on the configured port."""
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        add_SysDBServicer_to_server(self, self._server)  # type: ignore
        self._server.add_insecure_port(f"[::]:{self._server_port}")
        self._server.start()
        return super().start()

    @overrides
    def stop(self) -> None:
        self._server.stop(None)
        return super().stop()

    @overrides
    def reset_state(self) -> None:
        """Drop all stored state and re-create the default tenant/database."""
        self._segments = {}
        # Fix: also clear the collection->segments map and the database-id map;
        # previously these leaked entries across resets.
        self._collection_to_segments = {}
        self._tenants_to_databases_to_collections = {}
        self._tenants_to_database_to_id = {}
        # Create defaults
        self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {}
        self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {}
        self._tenants_to_database_to_id[DEFAULT_TENANT] = {}
        self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0)
        return super().reset_state()

    @overrides(check_signature=False)
    def CreateDatabase(
        self, request: CreateDatabaseRequest, context: grpc.ServicerContext
    ) -> CreateDatabaseResponse:
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database in self._tenants_to_databases_to_collections[tenant]:
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS, f"Database {database} already exists"
            )
        self._tenants_to_databases_to_collections[tenant][database] = {}
        self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
        return CreateDatabaseResponse()

    @overrides(check_signature=False)
    def GetDatabase(
        self, request: GetDatabaseRequest, context: grpc.ServicerContext
    ) -> GetDatabaseResponse:
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
        id = self._tenants_to_database_to_id[tenant][database]
        return GetDatabaseResponse(
            database=proto.Database(id=id.hex, name=database, tenant=tenant),
        )

    @overrides(check_signature=False)
    def CreateTenant(
        self, request: CreateTenantRequest, context: grpc.ServicerContext
    ) -> CreateTenantResponse:
        tenant = request.name
        if tenant in self._tenants_to_databases_to_collections:
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS, f"Tenant {tenant} already exists"
            )
        self._tenants_to_databases_to_collections[tenant] = {}
        self._tenants_to_database_to_id[tenant] = {}
        return CreateTenantResponse()

    @overrides(check_signature=False)
    def GetTenant(
        self, request: GetTenantRequest, context: grpc.ServicerContext
    ) -> GetTenantResponse:
        tenant = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        return GetTenantResponse(
            tenant=proto.Tenant(name=tenant),
        )

    # We are forced to use check_signature=False because the generated proto code
    # does not have type annotations for the request and response objects.
    # TODO: investigate generating types for the request and response objects
    @overrides(check_signature=False)
    def CreateSegment(
        self, request: CreateSegmentRequest, context: grpc.ServicerContext
    ) -> CreateSegmentResponse:
        segment = from_proto_segment(request.segment)
        return self.CreateSegmentHelper(segment, context)

    def CreateSegmentHelper(
        self, segment: Segment, context: grpc.ServicerContext
    ) -> CreateSegmentResponse:
        """Store a segment, aborting with ALREADY_EXISTS on a duplicate id."""
        if segment["id"].hex in self._segments:
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS,
                f"Segment {segment['id']} already exists",
            )
        self._segments[segment["id"].hex] = segment
        return CreateSegmentResponse()

    @overrides(check_signature=False)
    def DeleteSegment(
        self, request: DeleteSegmentRequest, context: grpc.ServicerContext
    ) -> DeleteSegmentResponse:
        id_to_delete = request.id
        if id_to_delete in self._segments:
            del self._segments[id_to_delete]
            return DeleteSegmentResponse()
        else:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Segment {id_to_delete} not found"
            )

    @overrides(check_signature=False)
    def GetSegments(
        self, request: GetSegmentsRequest, context: grpc.ServicerContext
    ) -> GetSegmentsResponse:
        """Return segments matching all of the provided (optional) filters."""
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_type = request.type if request.HasField("type") else None
        target_scope = (
            from_proto_segment_scope(request.scope)
            if request.HasField("scope")
            else None
        )
        target_collection = UUID(hex=request.collection)
        found_segments = []
        for segment in self._segments.values():
            if target_id and segment["id"] != target_id:
                continue
            if target_type and segment["type"] != target_type:
                continue
            if target_scope and segment["scope"] != target_scope:
                continue
            if target_collection and segment["collection"] != target_collection:
                continue
            found_segments.append(segment)
        return GetSegmentsResponse(
            segments=[to_proto_segment(segment) for segment in found_segments]
        )

    @overrides(check_signature=False)
    def UpdateSegment(
        self, request: UpdateSegmentRequest, context: grpc.ServicerContext
    ) -> UpdateSegmentResponse:
        id_to_update = UUID(request.id)
        if id_to_update.hex not in self._segments:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Segment {id_to_update} not found"
            )
        else:
            segment = self._segments[id_to_update.hex]
            if request.HasField("metadata"):
                if segment["metadata"] is None:
                    segment["metadata"] = {}
                # Fix: capture the merge target *after* replacing a None
                # metadata with {}. The original captured it before, so merging
                # into a segment that had no metadata operated on None.
                target = cast(Dict[str, Any], segment["metadata"])
                self._merge_metadata(target, request.metadata)
            if request.HasField("reset_metadata") and request.reset_metadata:
                segment["metadata"] = {}
            return UpdateSegmentResponse()

    @overrides(check_signature=False)
    def CreateCollection(
        self, request: CreateCollectionRequest, context: grpc.ServicerContext
    ) -> CreateCollectionResponse:
        """Create a collection and its segments, rolling back on duplicates."""
        collection_name = request.name
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
        # Check if the collection already exists globally by id
        for (
            search_tenant,
            databases,
        ) in self._tenants_to_databases_to_collections.items():
            for search_database, search_collections in databases.items():
                if request.id in search_collections:
                    if (
                        search_tenant != request.tenant
                        or search_database != request.database
                    ):
                        context.abort(
                            grpc.StatusCode.ALREADY_EXISTS,
                            f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    elif not request.get_or_create:
                        # If the id exists for this tenant and database, and we are not doing a get_or_create, then
                        # we should return an already exists error
                        context.abort(
                            grpc.StatusCode.ALREADY_EXISTS,
                            f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
        # Check if the collection already exists in this database by name
        collections = self._tenants_to_databases_to_collections[tenant][database]
        matches = [c for c in collections.values() if c["name"] == collection_name]
        assert len(matches) <= 1
        if len(matches) > 0:
            if request.get_or_create:
                existing_collection = matches[0]
                return CreateCollectionResponse(
                    collection=to_proto_collection(existing_collection),
                    created=False,
                )
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS,
                f"Collection {collection_name} already exists",
            )
        configuration_json = json.loads(request.configuration_json_str)
        id = UUID(hex=request.id)
        new_collection = Collection(
            id=id,
            name=request.name,
            configuration_json=configuration_json,
            serialized_schema=None,
            metadata=from_proto_metadata(request.metadata),
            dimension=request.dimension,
            database=database,
            tenant=tenant,
            version=0,
        )
        # Check that segments are unique and do not already exist
        # Keep a track of the segments that are being added
        segments_added = []
        # Create segments for the collection
        for segment_proto in request.segments:
            segment = from_proto_segment(segment_proto)
            if segment["id"].hex in self._segments:
                # Remove the already added segment since we need to roll back
                for s in segments_added:
                    self.DeleteSegment(DeleteSegmentRequest(id=s), context)
                context.abort(
                    grpc.StatusCode.ALREADY_EXISTS,
                    f"Segment {segment['id']} already exists",
                )
            self.CreateSegmentHelper(segment, context)
            segments_added.append(segment["id"].hex)
        collections[request.id] = new_collection
        collection_unique_key = f"{tenant}:{database}:{request.id}"
        self._collection_to_segments[collection_unique_key] = segments_added
        return CreateCollectionResponse(
            collection=to_proto_collection(new_collection),
            created=True,
        )

    @overrides(check_signature=False)
    def DeleteCollection(
        self, request: DeleteCollectionRequest, context: grpc.ServicerContext
    ) -> DeleteCollectionResponse:
        collection_id = request.id
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
        collections = self._tenants_to_databases_to_collections[tenant][database]
        if collection_id in collections:
            del collections[collection_id]
            collection_unique_key = f"{tenant}:{database}:{collection_id}"
            segment_ids = self._collection_to_segments[collection_unique_key]
            if segment_ids:  # Delete segments if provided.
                for segment_id in segment_ids:
                    del self._segments[segment_id]
            return DeleteCollectionResponse()
        else:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Collection {collection_id} not found"
            )

    @overrides(check_signature=False)
    def GetCollections(
        self, request: GetCollectionsRequest, context: grpc.ServicerContext
    ) -> GetCollectionsResponse:
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_name = request.name if request.HasField("name") else None
        allCollections = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            for database, collections in databases.items():
                # Empty tenant/database on the request means "no filter".
                if request.tenant != "" and tenant != request.tenant:
                    continue
                if request.database != "" and database != request.database:
                    continue
                allCollections.update(collections)
                print(
                    f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
                )
        found_collections = []
        for collection in allCollections.values():
            if target_id and collection["id"] != target_id:
                continue
            if target_name and collection["name"] != target_name:
                continue
            found_collections.append(collection)
        return GetCollectionsResponse(
            collections=[
                to_proto_collection(collection) for collection in found_collections
            ]
        )

    @overrides(check_signature=False)
    def CountCollections(
        self, request: CountCollectionsRequest, context: grpc.ServicerContext
    ) -> CountCollectionsResponse:
        # Delegate to GetCollections with the same tenant/database filter.
        get_request = GetCollectionsRequest(
            tenant=request.tenant,
            database=request.database,
        )
        collections = self.GetCollections(get_request, context)
        return CountCollectionsResponse(count=len(collections.collections))

    @overrides(check_signature=False)
    def GetCollectionSize(
        self, request: GetCollectionSizeRequest, context: grpc.ServicerContext
    ) -> GetCollectionSizeResponse:
        # The mock stores no records, so the size is always zero.
        return GetCollectionSizeResponse(
            total_records_post_compaction=0,
        )

    @overrides(check_signature=False)
    def GetCollectionWithSegments(
        self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
    ) -> GetCollectionWithSegmentsResponse:
        allCollections = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            for database, collections in databases.items():
                allCollections.update(collections)
                print(
                    f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
                )
        collection = allCollections.get(request.id, None)
        if collection is None:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found"
            )
        collection_unique_key = (
            f"{collection.tenant}:{collection.database}:{request.id}"
        )
        segments = [
            self._segments[id]
            for id in self._collection_to_segments[collection_unique_key]
        ]
        # A well-formed collection must have exactly one segment per scope.
        if {segment["scope"] for segment in segments} != {
            SegmentScope.METADATA,
            SegmentScope.RECORD,
            SegmentScope.VECTOR,
        }:
            context.abort(
                grpc.StatusCode.INTERNAL,
                f"Incomplete segments for collection {collection}: {segments}",
            )
        return GetCollectionWithSegmentsResponse(
            collection=to_proto_collection(collection),
            segments=[to_proto_segment(segment) for segment in segments],
        )

    @overrides(check_signature=False)
    def UpdateCollection(
        self, request: UpdateCollectionRequest, context: grpc.ServicerContext
    ) -> UpdateCollectionResponse:
        id_to_update = UUID(request.id)
        # Find the collection with this id
        collections = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            for database, maybe_collections in databases.items():
                if id_to_update.hex in maybe_collections:
                    collections = maybe_collections
        if id_to_update.hex not in collections:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Collection {id_to_update} not found"
            )
        else:
            collection = collections[id_to_update.hex]
            if request.HasField("name"):
                collection["name"] = request.name
            if request.HasField("dimension"):
                collection["dimension"] = request.dimension
            if request.HasField("metadata"):
                # TODO: IN SysDB SQlite we have technical debt where we
                # replace the entire metadata dict with the new one. We should
                # fix that by merging it. For now we just do the same thing here
                update_metadata = from_proto_update_metadata(request.metadata)
                cleaned_metadata = None
                if update_metadata is not None:
                    cleaned_metadata = {}
                    for key, value in update_metadata.items():
                        if value is not None:
                            cleaned_metadata[key] = value
                collection["metadata"] = cleaned_metadata
            elif request.HasField("reset_metadata"):
                if request.reset_metadata:
                    collection["metadata"] = {}
            return UpdateCollectionResponse()

    @overrides(check_signature=False)
    def ResetState(
        self, request: Empty, context: grpc.ServicerContext
    ) -> ResetStateResponse:
        self.reset_state()
        return ResetStateResponse()

    def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
        """Merge *source* into *target* in place; a None value deletes the key."""
        target_metadata = cast(Dict[str, Any], target)
        source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source))
        target_metadata.update(source_metadata)
        # If a key has a None value, remove it from the metadata
        for key, value in source_metadata.items():
            if value is None and key in target:
                del target_metadata[key]

View File

@@ -0,0 +1,273 @@
import logging
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
from chromadb.db.migrations import MigratableDB, Migration
from chromadb.config import System, Settings
import chromadb.db.base as base
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
from chromadb.db.mixins.sysdb import SqlSysDB
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
import sqlite3
from overrides import override
import pypika
from typing import Sequence, cast, Optional, Type, Any
from typing_extensions import Literal
from types import TracebackType
import os
from uuid import UUID
from threading import local
from importlib_resources import files
from importlib_resources.abc import Traversable
logger = logging.getLogger(__name__)
class TxWrapper(base.TxWrapper):
    """Re-entrant transaction wrapper backed by a per-thread stack.

    Only the outermost wrapper on the stack issues BEGIN and, on exit,
    COMMIT/ROLLBACK; nested wrappers simply join the outer transaction.
    """

    _conn: Connection
    _pool: Pool

    def __init__(self, conn_pool: Pool, stack: local):
        self._tx_stack = stack
        self._conn = conn_pool.connect()
        self._pool = conn_pool

    @override
    def __enter__(self) -> base.Cursor:
        outermost = len(self._tx_stack.stack) == 0
        if outermost:
            self._conn.execute("PRAGMA case_sensitive_like = ON")
            self._conn.execute("BEGIN;")
        self._tx_stack.stack.append(self)
        return self._conn.cursor()  # type: ignore

    @override
    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        self._tx_stack.stack.pop()
        if self._tx_stack.stack:
            # Still inside an enclosing transaction; defer commit/rollback.
            return False
        # Outermost wrapper: finish the transaction and release the connection.
        if exc_type is None:
            self._conn.commit()
        else:
            self._conn.rollback()
        # NOTE(review): this opens and immediately closes a fresh cursor —
        # presumably intentional upstream behavior; confirm before removing.
        self._conn.cursor().close()
        self._pool.return_to_pool(self._conn)
        return False
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
    """Sqlite-backed implementation of the system DB, the embeddings queue, and
    the migration machinery, operating on a single file (persistent mode) or a
    shared-cache in-memory database (non-persistent mode)."""

    _conn_pool: Pool
    _settings: Settings
    _migration_imports: Sequence[Traversable]
    _db_file: str
    _tx_stack: local  # per-thread stack of open TxWrapper transactions
    _is_persistent: bool

    def __init__(self, system: System):
        self._settings = system.settings
        # Directories scanned (in this order) for migration files.
        self._migration_imports = [
            files("chromadb.migrations.embeddings_queue"),
            files("chromadb.migrations.sysdb"),
            files("chromadb.migrations.metadb"),
        ]
        self._is_persistent = self._settings.require("is_persistent")
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        if not self._is_persistent:
            # In order to allow sqlite to be shared between multiple threads, we need to use a
            # URI connection string with shared cache.
            # See https://www.sqlite.org/sharedcache.html
            # https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
            self._db_file = "file::memory:?cache=shared"
            self._conn_pool = LockPool(self._db_file, is_uri=True)
        else:
            self._db_file = (
                self._settings.require("persist_directory") + "/chroma.sqlite3"
            )
            if not os.path.exists(self._db_file):
                os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
            self._conn_pool = PerThreadPool(self._db_file)
        self._tx_stack = local()
        super().__init__(system)

    @trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL)
    @override
    def start(self) -> None:
        """Start the component: set session pragmas, run migrations, and warn if
        the database predates automatic purging and would benefit from a vacuum."""
        super().start()
        with self.tx() as cur:
            cur.execute("PRAGMA foreign_keys = ON")
            cur.execute("PRAGMA case_sensitive_like = ON")
        self.initialize_migrations()
        if (
            # (don't attempt to access .config if migrations haven't been run)
            self._settings.require("migrations") == "apply"
            and self.config.get_parameter("automatically_purge").value is False
        ):
            logger.warning(
                "⚠️ It looks like you upgraded from a version below 0.5.6 and could benefit from vacuuming your database. Run chromadb utils vacuum --help for more information."
            )

    @trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL)
    @override
    def stop(self) -> None:
        """Stop the component and close every pooled connection."""
        super().stop()
        self._conn_pool.close()

    @staticmethod
    @override
    def querybuilder() -> Type[pypika.Query]:
        return pypika.Query  # type: ignore

    @staticmethod
    @override
    def parameter_format() -> str:
        # sqlite uses "?" positional placeholders.
        return "?"

    @staticmethod
    @override
    def migration_scope() -> str:
        return "sqlite"

    @override
    def migration_dirs(self) -> Sequence[Traversable]:
        return self._migration_imports

    @override
    def tx(self) -> TxWrapper:
        """Return a (re-entrant) transaction context manager for this thread."""
        if not hasattr(self._tx_stack, "stack"):
            self._tx_stack.stack = []
        return TxWrapper(self._conn_pool, stack=self._tx_stack)

    @trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL)
    @override
    def reset_state(self) -> None:
        """Drop every table and restart the component, rebuilding the schema.

        Raises:
            ValueError: if `allow_reset` is not enabled in settings.
        """
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        with self.tx() as cur:
            # Drop all tables
            cur.execute(
                """
                SELECT name FROM sqlite_master
                WHERE type='table'
                """
            )
            for row in cur.fetchall():
                cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
        self._conn_pool.close()
        self.start()
        super().reset_state()

    @trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL)
    @override
    def setup_migrations(self) -> None:
        """Create the migrations bookkeeping table if it does not exist."""
        with self.tx() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS migrations (
                    dir TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    filename TEXT NOT NULL,
                    sql TEXT NOT NULL,
                    hash TEXT NOT NULL,
                    PRIMARY KEY (dir, version)
                )
                """
            )

    @trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL)
    @override
    def migrations_initialized(self) -> bool:
        """Return True iff the migrations bookkeeping table exists."""
        with self.tx() as cur:
            cur.execute(
                """SELECT count(*) FROM sqlite_master
                   WHERE type='table' AND name='migrations'"""
            )
            return cur.fetchone()[0] != 0

    @trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL)
    @override
    def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
        """Return the migrations already applied for `dir`, ordered by version."""
        with self.tx() as cur:
            cur.execute(
                """
                SELECT dir, version, filename, sql, hash
                FROM migrations
                WHERE dir = ?
                ORDER BY version ASC
                """,
                (dir.name,),
            )
            migrations = []
            for row in cur.fetchall():
                found_dir = cast(str, row[0])
                found_version = cast(int, row[1])
                found_filename = cast(str, row[2])
                found_sql = cast(str, row[3])
                found_hash = cast(str, row[4])
                migrations.append(
                    Migration(
                        dir=found_dir,
                        version=found_version,
                        filename=found_filename,
                        sql=found_sql,
                        hash=found_hash,
                        scope=self.migration_scope(),
                    )
                )
            return migrations

    @override
    def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
        """Execute a migration's SQL and record it in the bookkeeping table."""
        cur.executescript(migration["sql"])
        cur.execute(
            """
            INSERT INTO migrations (dir, version, filename, sql, hash)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                migration["dir"],
                migration["version"],
                migration["filename"],
                migration["sql"],
                migration["hash"],
            ),
        )

    @staticmethod
    @override
    def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
        return UUID(value) if value is not None else None

    @staticmethod
    @override
    def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
        return str(uuid) if uuid is not None else None

    @staticmethod
    @override
    def unique_constraint_error() -> Type[BaseException]:
        return sqlite3.IntegrityError

    def vacuum(self, timeout: int = 5) -> None:
        """Runs VACUUM on the database. `timeout` is the maximum time to wait for an exclusive lock in seconds."""
        conn = self._conn_pool.connect()
        try:
            conn.execute(f"PRAGMA busy_timeout = {int(timeout) * 1000}")
            conn.execute("VACUUM")
            conn.execute(
                """
                INSERT INTO maintenance_log (operation, timestamp)
                VALUES ('vacuum', CURRENT_TIMESTAMP)
                """
            )
        finally:
            # BUGFIX: always hand the connection back. Without this a LockPool's
            # lock stayed held forever after vacuum (deadlocking later callers),
            # and the connection leaked whenever VACUUM raised (e.g. busy timeout).
            self._conn_pool.return_to_pool(conn)

View File

@@ -0,0 +1,163 @@
import sqlite3
import weakref
from abc import ABC, abstractmethod
from typing import Any, Set
import threading
from overrides import override
from typing_extensions import Annotated
class Connection:
    """A threadpool connection that returns itself to the pool on close()"""

    _pool: "Pool"
    _db_file: str
    _conn: sqlite3.Connection

    def __init__(
        self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
    ):
        self._pool = pool
        self._db_file = db_file
        self._conn = sqlite3.connect(
            db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs
        )  # type: ignore
        # Disable implicit transaction management; callers commit explicitly.
        self._conn.isolation_level = None

    def execute(self, sql: str, parameters=...) -> sqlite3.Cursor:  # type: ignore
        # Ellipsis is the "no parameters supplied" sentinel, so callers may
        # pass None or () as genuine parameter values.
        if parameters is not ...:
            return self._conn.execute(sql, parameters)
        return self._conn.execute(sql)

    def commit(self) -> None:
        self._conn.commit()

    def rollback(self) -> None:
        self._conn.rollback()

    def cursor(self) -> sqlite3.Cursor:
        return self._conn.cursor()

    def close_actual(self) -> None:
        """Actually closes the connection to the db"""
        self._conn.close()
class Pool(ABC):
    """Interface for a pool of Connections to a single sqlite database."""

    @abstractmethod
    def __init__(self, db_file: str, is_uri: bool) -> None:
        ...

    @abstractmethod
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Hand out a connection from the pool."""
        ...

    @abstractmethod
    def close(self) -> None:
        """Tear down every connection owned by the pool."""
        ...

    @abstractmethod
    def return_to_pool(self, conn: Connection) -> None:
        """Give a previously handed-out connection back to the pool."""
        ...
class LockPool(Pool):
    """A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time.

    This is used because sqlite does not support multithreaded access with connection timeouts when using the
    shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.RLock
    _connection: threading.local  # holds this thread's cached Connection as .conn
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.RLock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Acquire the pool lock and return this thread's connection, creating it
        on first use. The lock is released by return_to_pool()."""
        self._lock.acquire()
        if hasattr(self._connection, "conn") and self._connection.conn is not None:
            return self._connection.conn  # type: ignore # cast doesn't work here for some reason
        else:
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
            self._connection.conn = new_connection
            # Track weakly so the pool does not keep dead threads' connections alive.
            self._connections.add(weakref.ref(new_connection))
            return new_connection

    @override
    def return_to_pool(self, conn: Connection) -> None:
        """Release the pool lock; safe to call even if this thread does not hold it."""
        try:
            self._lock.release()
        except RuntimeError:
            pass

    @override
    def close(self) -> None:
        """Close every live connection and reset per-thread state."""
        for ref in self._connections:
            # BUGFIX: dereference the weakref exactly once. Calling ref() twice
            # (check then use) races with garbage collection: the referent can
            # die between the calls, making the second call return None.
            connection = ref()
            if connection is not None:
                connection.close_actual()
        self._connections.clear()
        self._connection = threading.local()
        try:
            self._lock.release()
        except RuntimeError:
            pass
class PerThreadPool(Pool):
    """Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be
    extended to do so and block on connect() if the cap is reached.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.Lock  # guards _connections only; connections are per-thread
    _connection: threading.local  # holds this thread's cached Connection as .conn
    _db_file: str
    # BUGFIX: annotation previously read `_is_uri_` (trailing-underscore typo);
    # the attribute assigned in __init__ is `_is_uri`.
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.Lock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return this thread's connection, creating it on first use."""
        if hasattr(self._connection, "conn") and self._connection.conn is not None:
            return self._connection.conn  # type: ignore # cast doesn't work here for some reason
        else:
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
            self._connection.conn = new_connection
            with self._lock:
                self._connections.add(weakref.ref(new_connection))
            return new_connection

    @override
    def close(self) -> None:
        """Close every live connection and reset per-thread state."""
        with self._lock:
            for ref in self._connections:
                # BUGFIX: dereference the weakref exactly once; a second ref()
                # call can return None if the referent is collected in between.
                connection = ref()
                if connection is not None:
                    connection.close_actual()
            self._connections.clear()
            self._connection = threading.local()

    @override
    def return_to_pool(self, conn: Connection) -> None:
        pass  # Each thread gets its own connection, so we don't need to return it to the pool