chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,125 @@
from typing import Optional, Sequence, TypeVar
from abc import abstractmethod
from chromadb.types import (
    Collection,
    MetadataEmbeddingRecord,
    Operation,
    RequestVersionContext,
    VectorEmbeddingRecord,
    Where,
    WhereDocument,
    VectorQuery,
    VectorQueryResult,
    Segment,
    SeqId,
    Metadata,
)
from chromadb.config import Component, System
from uuid import UUID
from enum import Enum


class SegmentType(Enum):
    SQLITE = "urn:chroma:segment/metadata/sqlite"
    HNSW_LOCAL_MEMORY = "urn:chroma:segment/vector/hnsw-local-memory"
    HNSW_LOCAL_PERSISTED = "urn:chroma:segment/vector/hnsw-local-persisted"
    HNSW_DISTRIBUTED = "urn:chroma:segment/vector/hnsw-distributed"
    BLOCKFILE_RECORD = "urn:chroma:segment/record/blockfile"
    BLOCKFILE_METADATA = "urn:chroma:segment/metadata/blockfile"


class SegmentImplementation(Component):
    @abstractmethod
    def __init__(self, system: System, segment: Segment):
        pass

    @abstractmethod
    def count(self, request_version_context: RequestVersionContext) -> int:
        """Get the number of embeddings in this segment"""
        pass

    @abstractmethod
    def max_seqid(self) -> SeqId:
        """Get the maximum SeqID currently indexed by this segment"""
        pass

    @staticmethod
    def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
        """Given an arbitrary metadata map (e.g., from a collection), validate it and
        return metadata (if any) that is applicable and should be applied to the
        segment. Validation errors will be reported to the user."""
        return None

    @abstractmethod
    def delete(self) -> None:
        """Delete the segment and all its data"""
        ...


S = TypeVar("S", bound=SegmentImplementation)


class MetadataReader(SegmentImplementation):
    """Embedding metadata segment interface"""

    @abstractmethod
    def get_metadata(
        self,
        request_version_context: RequestVersionContext,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        include_metadata: bool = True,
    ) -> Sequence[MetadataEmbeddingRecord]:
        """Query for embedding metadata."""
        pass


class VectorReader(SegmentImplementation):
    """Embedding vector segment interface"""

    @abstractmethod
    def get_vectors(
        self,
        request_version_context: RequestVersionContext,
        ids: Optional[Sequence[str]] = None,
    ) -> Sequence[VectorEmbeddingRecord]:
        """Get embeddings from the segment. If no IDs are provided, all embeddings are
        returned."""
        pass

    @abstractmethod
    def query_vectors(
        self, query: VectorQuery
    ) -> Sequence[Sequence[VectorQueryResult]]:
        """Given a vector query, return the top-k nearest neighbors for each vector
        in the query."""
        pass


class SegmentManager(Component):
    """Interface for a pluggable strategy for creating, retrieving and instantiating
    segments as required"""

    @abstractmethod
    def prepare_segments_for_new_collection(
        self, collection: Collection
    ) -> Sequence[Segment]:
        """Return the segments required for a new collection. Returns only segment
        data, does not persist to the SysDB."""
        pass

    @abstractmethod
    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        """Delete any local state for all the segments associated with a collection,
        and return a sequence of their IDs. Does not update the SysDB."""
        pass

    @abstractmethod
    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        """Signal to the segment manager that a collection is about to be used, so
        that it can preload segments as needed. This is only a hint, and
        implementations are free to ignore it."""
        pass
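The abstract interfaces above are implemented by concrete segment classes that are constructed from a `System` and a `Segment` record. As a rough illustration of the contract, here is a minimal sketch of a toy implementation; the class name and its in-memory counter are invented for this example and are not part of chromadb:

class InMemoryCountingSegment(SegmentImplementation):
    """Illustrative sketch only: a segment that just tracks a count and a seq id."""

    def __init__(self, system: System, segment: Segment):
        super().__init__(system)
        self._segment = segment
        self._count = 0
        self._max_seq_id: SeqId = 0

    def count(self, request_version_context: RequestVersionContext) -> int:
        return self._count

    def max_seqid(self) -> SeqId:
        return self._max_seq_id

    def delete(self) -> None:
        self._count = 0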
@@ -0,0 +1,80 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List

from overrides import EnforceOverrides, overrides
from chromadb.config import Component, System
from chromadb.types import Segment


class SegmentDirectory(Component):
    """A segment directory is a data interface that manages the location of segments.
    Concretely, this means that for distributed chroma, it provides the grpc endpoint
    for a segment."""

    @abstractmethod
    def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
        """Return the segment residences for a given segment ID. Will return at most
        n residences, and should only return fewer than n if fewer than n residences
        are available.
        """

    @abstractmethod
    def register_updated_segment_callback(
        self, callback: Callable[[Segment], None]
    ) -> None:
        """Register a callback that will be called when a segment is updated"""
        pass


@dataclass
class Member:
    id: str
    ip: str
    node: str


Memberlist = List[Member]


class MemberlistProvider(Component, EnforceOverrides):
    """Returns the latest memberlist and provides a callback for when it changes. The
    callback may be called from a different thread than the caller's, so callers
    should ensure that they are thread-safe."""

    callbacks: List[Callable[[Memberlist], Any]]

    def __init__(self, system: System):
        self.callbacks = []
        super().__init__(system)

    @abstractmethod
    def get_memberlist(self) -> Memberlist:
        """Returns the latest memberlist"""
        pass

    @abstractmethod
    def set_memberlist_name(self, memberlist: str) -> None:
        """Sets the memberlist that this provider will watch"""
        pass

    @overrides
    def stop(self) -> None:
        """Stops watching the memberlist"""
        self.callbacks = []

    def register_updated_memberlist_callback(
        self, callback: Callable[[Memberlist], Any]
    ) -> None:
        """Registers a callback that will be called when the memberlist changes. May
        be called many times with the same memberlist, so callbacks should be
        idempotent. May be called from a different thread.
        """
        self.callbacks.append(callback)

    def unregister_updated_memberlist_callback(
        self, callback: Callable[[Memberlist], Any]
    ) -> bool:
        """Unregisters a previously registered callback. Returns True if the callback
        was successfully unregistered, False if it was never registered."""
        if callback in self.callbacks:
            self.callbacks.remove(callback)
            return True
        return False
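Because memberlist callbacks may fire from a watch thread, subscribers typically just record the new list and defer any heavier work. A minimal usage sketch, using the `MockMemberlistProvider` defined in the next file, and assuming a default `Settings` is enough to construct a `System`:

from chromadb.config import Settings, System

system = System(Settings())
provider = MockMemberlistProvider(system)  # defined in segment_directory.py below

seen: List[Memberlist] = []
provider.register_updated_memberlist_callback(seen.append)

# Simulate the kind of update a k8s custom-resource watch would deliver.
provider.update_memberlist([Member(id="d", ip="10.0.0.4", node="node4")])
assert seen[-1][0].id == "d"

provider.unregister_updated_memberlist_callback(seen.append)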
@@ -0,0 +1,337 @@
import threading
import time
from typing import Any, Callable, Dict, List, Optional, cast
from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException
from overrides import EnforceOverrides, override
from chromadb.config import RoutingMode, System
from chromadb.segment.distributed import (
    Member,
    Memberlist,
    MemberlistProvider,
    SegmentDirectory,
)
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    add_attributes_to_current_span,
    trace_method,
)
from chromadb.types import Segment
from chromadb.utils.rendezvous_hash import assign, murmur3hasher

# These could go in config, but given that they will rarely change, they are here
# for now to avoid polluting the config file further.
WATCH_TIMEOUT_SECONDS = 60
KUBERNETES_NAMESPACE = "chroma"
KUBERNETES_GROUP = "chroma.cluster"
HEADLESS_SERVICE = "svc.cluster.local"


class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):
    """A mock memberlist provider for testing"""

    _memberlist: Memberlist

    def __init__(self, system: System):
        super().__init__(system)
        self._memberlist = [
            Member(id="a", ip="10.0.0.1", node="node1"),
            Member(id="b", ip="10.0.0.2", node="node2"),
            Member(id="c", ip="10.0.0.3", node="node3"),
        ]

    @override
    def get_memberlist(self) -> Memberlist:
        return self._memberlist

    @override
    def set_memberlist_name(self, memberlist: str) -> None:
        pass  # The mock provider does not need to set the memberlist name

    def update_memberlist(self, memberlist: Memberlist) -> None:
        """Updates the memberlist and calls all registered callbacks. This mocks an
        update from a k8s CR."""
        self._memberlist = memberlist
        for callback in self.callbacks:
            callback(memberlist)


class CustomResourceMemberlistProvider(MemberlistProvider, EnforceOverrides):
    """A memberlist provider that uses a k8s custom resource to store the memberlist"""

    _kubernetes_api: client.CustomObjectsApi
    _memberlist_name: Optional[str]
    _curr_memberlist: Optional[Memberlist]
    _curr_memberlist_mutex: threading.Lock
    _watch_thread: Optional[threading.Thread]
    _kill_watch_thread: threading.Event
    _done_waiting_for_reset: threading.Event

    def __init__(self, system: System):
        super().__init__(system)
        config.load_config()
        self._kubernetes_api = client.CustomObjectsApi()
        self._watch_thread = None
        self._memberlist_name = None
        self._curr_memberlist = None
        self._curr_memberlist_mutex = threading.Lock()
        self._kill_watch_thread = threading.Event()
        self._done_waiting_for_reset = threading.Event()

    @override
    def start(self) -> None:
        if self._memberlist_name is None:
            raise ValueError("Memberlist name must be set before starting")
        self.get_memberlist()
        self._done_waiting_for_reset.clear()
        self._watch_worker_memberlist()
        return super().start()

    @override
    def stop(self) -> None:
        self._curr_memberlist = None
        self._memberlist_name = None

        # Stop the watch thread
        self._kill_watch_thread.set()
        if self._watch_thread is not None:
            self._watch_thread.join()
        self._watch_thread = None
        self._kill_watch_thread.clear()
        self._done_waiting_for_reset.clear()
        return super().stop()

    @override
    def reset_state(self) -> None:
        # Reset the memberlist in kubernetes and wait for the change to propagate
        # back again. Note that the component must be running in order to reset
        # the state.
        if not self._system.settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        if self._memberlist_name:
            self._done_waiting_for_reset.clear()
            self._kubernetes_api.patch_namespaced_custom_object(
                group=KUBERNETES_GROUP,
                version="v1",
                namespace=KUBERNETES_NAMESPACE,
                plural="memberlists",
                name=self._memberlist_name,
                body={
                    "kind": "MemberList",
                    "spec": {"members": []},
                },
            )
            self._done_waiting_for_reset.wait(5.0)
            # TODO: For some reason the above can flake and the memberlist won't be
            # populated. Given that this is a test harness, just sleep for an
            # additional 500ms for now. We should understand why this flaps.
            time.sleep(0.5)

    @override
    def get_memberlist(self) -> Memberlist:
        if self._curr_memberlist is None:
            self._curr_memberlist = self._fetch_memberlist()
        return self._curr_memberlist

    @override
    def set_memberlist_name(self, memberlist: str) -> None:
        self._memberlist_name = memberlist

    def _fetch_memberlist(self) -> Memberlist:
        api_response = self._kubernetes_api.get_namespaced_custom_object(
            group=KUBERNETES_GROUP,
            version="v1",
            namespace=KUBERNETES_NAMESPACE,
            plural="memberlists",
            name=f"{self._memberlist_name}",
        )
        api_response = cast(Dict[str, Any], api_response)
        if "spec" not in api_response:
            return []
        response_spec = cast(Dict[str, Any], api_response["spec"])
        return self._parse_response_memberlist(response_spec)

    def _watch_worker_memberlist(self) -> None:
        # TODO: We may want to make this watch function a library function that can
        # be used by other components that need to watch k8s custom resources.
        def run_watch() -> None:
            w = watch.Watch()

            def do_watch() -> None:
                for event in w.stream(
                    self._kubernetes_api.list_namespaced_custom_object,
                    group=KUBERNETES_GROUP,
                    version="v1",
                    namespace=KUBERNETES_NAMESPACE,
                    plural="memberlists",
                    field_selector=f"metadata.name={self._memberlist_name}",
                    timeout_seconds=WATCH_TIMEOUT_SECONDS,
                ):
                    event = cast(Dict[str, Any], event)
                    response_spec = event["object"]["spec"]
                    response_spec = cast(Dict[str, Any], response_spec)
                    with self._curr_memberlist_mutex:
                        self._curr_memberlist = self._parse_response_memberlist(
                            response_spec
                        )
                    self._notify(self._curr_memberlist)
                    if (
                        self._system.settings.require("allow_reset")
                        and not self._done_waiting_for_reset.is_set()
                        and len(self._curr_memberlist) > 0
                    ):
                        self._done_waiting_for_reset.set()

            # Watch the custom resource for changes. Watch with a timeout and retry
            # so we can gracefully stop this if needed.
            while not self._kill_watch_thread.is_set():
                try:
                    do_watch()
                except ApiException as e:
                    # If status code is 410, the watch has expired and we need to
                    # start a new one.
                    if e.status == 410:
                        pass
            return

        if self._watch_thread is None:
            thread = threading.Thread(target=run_watch, daemon=True)
            thread.start()
            self._watch_thread = thread
        else:
            raise Exception("A watch thread is already running.")

    def _parse_response_memberlist(
        self, api_response_spec: Dict[str, Any]
    ) -> Memberlist:
        if "members" not in api_response_spec:
            return []
        parsed = []
        for m in api_response_spec["members"]:
            id = m["member_id"]
            ip = m["member_ip"] if "member_ip" in m else ""
            node = m["member_node_name"] if "member_node_name" in m else ""
            parsed.append(Member(id=id, ip=ip, node=node))
        return parsed

    def _notify(self, memberlist: Memberlist) -> None:
        for callback in self.callbacks:
            callback(memberlist)

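The `run_watch` loop above shows the usual pattern for long-lived Kubernetes watches: stream with `timeout_seconds` so the loop regularly re-checks a shutdown flag, and treat an HTTP 410 (Gone) from the API server as an expired watch to restart rather than an error. The TODO suggests extracting this into a library function; one possible shape for such a helper, sketched here purely for illustration (the name and signature are invented):

def watch_custom_object_with_retry(
    api: client.CustomObjectsApi,
    name: str,
    on_event: Callable[[Dict[str, Any]], None],
    should_stop: threading.Event,
) -> None:
    """Sketch: watch a namespaced memberlist custom object until should_stop is set."""
    w = watch.Watch()
    while not should_stop.is_set():
        try:
            for event in w.stream(
                api.list_namespaced_custom_object,
                group=KUBERNETES_GROUP,
                version="v1",
                namespace=KUBERNETES_NAMESPACE,
                plural="memberlists",
                field_selector=f"metadata.name={name}",
                timeout_seconds=WATCH_TIMEOUT_SECONDS,
            ):
                on_event(cast(Dict[str, Any], event))
        except ApiException as e:
            if e.status == 410:  # watch expired; loop around and start a new one
                continue
            raise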
class RendezvousHashSegmentDirectory(SegmentDirectory, EnforceOverrides):
    _memberlist_provider: MemberlistProvider
    _curr_memberlist_mutex: threading.Lock
    _curr_memberlist: Optional[Memberlist]
    _routing_mode: RoutingMode

    def __init__(self, system: System):
        super().__init__(system)
        self._memberlist_provider = self.require(MemberlistProvider)
        memberlist_name = system.settings.require("worker_memberlist_name")
        self._memberlist_provider.set_memberlist_name(memberlist_name)
        self._routing_mode = system.settings.require(
            "chroma_segment_directory_routing_mode"
        )

        self._curr_memberlist = None
        self._curr_memberlist_mutex = threading.Lock()

    @override
    def start(self) -> None:
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._memberlist_provider.register_updated_memberlist_callback(
            self._update_memberlist
        )
        return super().start()

    @override
    def stop(self) -> None:
        self._memberlist_provider.unregister_updated_memberlist_callback(
            self._update_memberlist
        )
        return super().stop()

    @override
    def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
        if self._curr_memberlist is None or len(self._curr_memberlist) == 0:
            raise ValueError("Memberlist is not initialized")

        # assign() will throw an error if n is greater than the number of members,
        # so clamp n to the number of members to align with the contract of this
        # method, which is to return at most n endpoints.
        n = min(n, len(self._curr_memberlist))

        # Check if all members in the memberlist have a node set; if so, route
        # using the node.
        # NOTE(@hammadb) 1/8/2024: This is to handle the migration between routing
        # using the member id and routing using the node name. We want to route
        # using the node name over the member id because the node may have a disk
        # cache that we want a stable identifier for over deploys.
        can_use_node_routing = (
            all([m.node != "" and len(m.node) != 0 for m in self._curr_memberlist])
            and self._routing_mode == RoutingMode.NODE
        )
        if can_use_node_routing:
            # If we are using node routing, hash over the node names.
            assignments = assign(
                segment["collection"].hex,
                [m.node for m in self._curr_memberlist],
                murmur3hasher,
                n,
            )
        else:
            # Queries for the same collection should end up on the same endpoint
            assignments = assign(
                segment["collection"].hex,
                [m.id for m in self._curr_memberlist],
                murmur3hasher,
                n,
            )
        assignments_set = set(assignments)
        out_endpoints = []
        for member in self._curr_memberlist:
            is_chosen_with_node_routing = (
                can_use_node_routing and member.node in assignments_set
            )
            is_chosen_with_id_routing = (
                not can_use_node_routing and member.id in assignments_set
            )
            if is_chosen_with_node_routing or is_chosen_with_id_routing:
                # If the member has an ip, use it; otherwise use the member id with
                # the headless service. This is for backwards compatibility with the
                # old memberlist, which only had ids.
                if member.ip is not None and member.ip != "":
                    endpoint = f"{member.ip}:50051"
                    out_endpoints.append(endpoint)
                else:
                    service_name = self.extract_service_name(member.id)
                    endpoint = f"{member.id}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"
                    out_endpoints.append(endpoint)
        return out_endpoints

    @override
    def register_updated_segment_callback(
        self, callback: Callable[[Segment], None]
    ) -> None:
        raise NotImplementedError()

    @trace_method(
        "RendezvousHashSegmentDirectory._update_memberlist",
        OpenTelemetryGranularity.ALL,
    )
    def _update_memberlist(self, memberlist: Memberlist) -> None:
        with self._curr_memberlist_mutex:
            add_attributes_to_current_span(
                {"new_memberlist": [m.id for m in memberlist]}
            )
            self._curr_memberlist = memberlist

    def extract_service_name(self, pod_name: str) -> Optional[str]:
        # Split the pod name on hyphens; the service name is expected to be the
        # prefix before the last hyphen.
        parts = pod_name.split("-")
        if len(parts) > 1:
            return "-".join(parts[:-1])
        return None
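`RendezvousHashSegmentDirectory` routes a collection to worker endpoints with rendezvous (highest-random-weight) hashing: every member scores the collection ID via `murmur3hasher` and the top-n members win, so the same memberlist always yields the same assignment, and a membership change only moves collections that were assigned to the departed member. A small sketch of the underlying call, assuming `assign` returns the n chosen member names, as its usage above suggests:

import uuid
from chromadb.utils.rendezvous_hash import assign, murmur3hasher

members = ["node1", "node2", "node3"]
collection_id = uuid.uuid4()

# Deterministic: the same inputs always produce the same assignment.
chosen = assign(collection_id.hex, members, murmur3hasher, 2)
assert chosen == assign(collection_id.hex, members, murmur3hasher, 2)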
backend_service/venv/lib/python3.13/site-packages/chromadb/segment/impl/manager/cache/cache.py (vendored, new file, 119 lines)
@@ -0,0 +1,119 @@
import threading
import uuid
from typing import Any, Callable
from chromadb.types import Segment
from overrides import override
from typing import Dict, Optional
from abc import ABC, abstractmethod


class SegmentCache(ABC):
    @abstractmethod
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        pass

    @abstractmethod
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        pass

    @abstractmethod
    def set(self, key: uuid.UUID, value: Segment) -> None:
        pass

    @abstractmethod
    def reset(self) -> None:
        pass


class BasicCache(SegmentCache):
    def __init__(self):
        self.cache: Dict[uuid.UUID, Segment] = {}
        self.lock = threading.RLock()

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        with self.lock:
            return self.cache.get(key)

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        with self.lock:
            return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        with self.lock:
            self.cache[key] = value

    @override
    def reset(self) -> None:
        with self.lock:
            self.cache = {}


class SegmentLRUCache(BasicCache):
    """A simple LRU cache implementation that handles objects with dynamic sizes.
    The size of each object is determined by a user-provided size function."""

    def __init__(
        self,
        capacity: int,
        size_func: Callable[[uuid.UUID], int],
        callback: Optional[Callable[[uuid.UUID, Segment], Any]] = None,
    ):
        self.capacity = capacity
        self.size_func = size_func
        self.cache: Dict[uuid.UUID, Segment] = {}
        self.history = []
        self.callback = callback
        self.lock = threading.RLock()

    def _upsert_key(self, key: uuid.UUID):
        if key in self.history:
            self.history.remove(key)
        self.history.append(key)

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        with self.lock:
            self._upsert_key(key)
            if key in self.cache:
                return self.cache[key]
            return None

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        with self.lock:
            if key in self.history:
                self.history.remove(key)
            return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        with self.lock:
            if key in self.cache:
                return
            item_size = self.size_func(key)
            key_sizes = {key: self.size_func(key) for key in self.cache}
            total_size = sum(key_sizes.values())
            index = 0
            # Evict items (oldest first) until the new item fits within capacity
            while total_size + item_size > self.capacity and len(self.history) > index:
                key_delete = self.history[index]
                if key_delete in self.cache:
                    # The eviction callback is optional; guard against None
                    if self.callback is not None:
                        self.callback(key_delete, self.cache[key_delete])
                    del self.cache[key_delete]
                    total_size -= key_sizes[key_delete]
                index += 1

            self.cache[key] = value
            self._upsert_key(key)

    @override
    def reset(self):
        with self.lock:
            self.cache = {}
            self.history = []
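Unlike a count-bounded LRU, `SegmentLRUCache` evicts by total size: on every `set` it re-queries `size_func` for each cached key and evicts from the oldest end of `history` until the new item fits. A usage sketch, with segment values stubbed as plain dicts for brevity:

import uuid
from typing import Dict

sizes: Dict[uuid.UUID, int] = {}
evicted = []

cache = SegmentLRUCache(
    capacity=100,
    size_func=lambda k: sizes[k],
    callback=lambda k, v: evicted.append(k),
)

a, b, c = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()
sizes[a], sizes[b], sizes[c] = 60, 30, 60

cache.set(a, {"id": a})  # type: ignore[arg-type]  # stub Segment for illustration
cache.set(b, {"id": b})  # type: ignore[arg-type]
cache.set(c, {"id": c})  # type: ignore[arg-type]  # 60 + 30 + 60 > 100: evicts a
assert evicted == [a]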
@@ -0,0 +1,97 @@
from threading import Lock
from typing import Dict, List, Sequence
from uuid import UUID, uuid4

from overrides import override

from chromadb.config import System
from chromadb.db.system import SysDB
from chromadb.segment import (
    SegmentImplementation,
    SegmentManager,
    SegmentType,
)
from chromadb.segment.distributed import SegmentDirectory
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    Collection,
    Operation,
    Segment,
    SegmentScope,
)


class DistributedSegmentManager(SegmentManager):
    _sysdb: SysDB
    _system: System
    _instances: Dict[UUID, SegmentImplementation]
    _segment_directory: SegmentDirectory
    _lock: Lock

    def __init__(self, system: System):
        super().__init__(system)
        self._sysdb = self.require(SysDB)
        self._segment_directory = self.require(SegmentDirectory)
        self._system = system
        self._instances = {}
        self._lock = Lock()

    @trace_method(
        "DistributedSegmentManager.prepare_segments_for_new_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def prepare_segments_for_new_collection(
        self, collection: Collection
    ) -> Sequence[Segment]:
        vector_segment = Segment(
            id=uuid4(),
            type=SegmentType.HNSW_DISTRIBUTED.value,
            scope=SegmentScope.VECTOR,
            collection=collection.id,
            metadata=PersistentHnswParams.extract(collection.metadata)
            if collection.metadata
            else None,
            file_paths={},
        )
        metadata_segment = Segment(
            id=uuid4(),
            type=SegmentType.BLOCKFILE_METADATA.value,
            scope=SegmentScope.METADATA,
            collection=collection.id,
            metadata=None,
            file_paths={},
        )
        record_segment = Segment(
            id=uuid4(),
            type=SegmentType.BLOCKFILE_RECORD.value,
            scope=SegmentScope.RECORD,
            collection=collection.id,
            metadata=None,
            file_paths={},
        )
        return [vector_segment, record_segment, metadata_segment]

    @override
    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        # delete_collection deletes segments in distributed mode
        return []

    @trace_method(
        "DistributedSegmentManager.get_endpoint",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def get_endpoints(self, segment: Segment, n: int) -> List[str]:
        return self._segment_directory.get_segment_endpoints(segment, n)

    @trace_method(
        "DistributedSegmentManager.hint_use_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        pass
@@ -0,0 +1,269 @@
from threading import Lock
from chromadb.segment import (
    SegmentImplementation,
    SegmentManager,
    MetadataReader,
    SegmentType,
    VectorReader,
    S,
)
import logging
from chromadb.segment.impl.manager.cache.cache import (
    SegmentLRUCache,
    BasicCache,
    SegmentCache,
)
import os

from chromadb.config import System, get_class
from chromadb.db.system import SysDB
from overrides import override
from chromadb.segment.impl.vector.local_persistent_hnsw import (
    PersistentLocalHnswSegment,
)
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata
from typing import Dict, Type, Sequence, Optional, cast
from uuid import UUID, uuid4
import platform

from chromadb.utils.lru_cache import LRUCache
from chromadb.utils.directory import get_directory_size


if platform.system() != "Windows":
    import resource
elif platform.system() == "Windows":
    import ctypes

SEGMENT_TYPE_IMPLS = {
    SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
    SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
    SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}


class LocalSegmentManager(SegmentManager):
    _sysdb: SysDB
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    _instances: Dict[UUID, SegmentImplementation]
    _vector_instances_file_handle_cache: LRUCache[
        UUID, PersistentLocalHnswSegment
    ]  # LRU cache to manage file handles across vector segment instances
    _vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
    _lock: Lock
    _max_file_handles: int

    def __init__(self, system: System):
        super().__init__(system)
        self._sysdb = self.require(SysDB)
        self._system = system
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        self.logger = logging.getLogger(__name__)
        self._instances = {}
        self.segment_cache: Dict[SegmentScope, SegmentCache] = {
            SegmentScope.METADATA: BasicCache()  # type: ignore[no-untyped-call]
        }
        if (
            system.settings.chroma_segment_cache_policy == "LRU"
            and system.settings.chroma_memory_limit_bytes > 0
        ):
            self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(
                capacity=system.settings.chroma_memory_limit_bytes,
                callback=lambda k, v: self.callback_cache_evict(v),
                size_func=lambda k: self._get_segment_disk_size(k),
            )
        else:
            self.segment_cache[SegmentScope.VECTOR] = BasicCache()  # type: ignore[no-untyped-call]

        self._lock = Lock()

        # TODO: prototyping with distributed segment for now, but this should be a
        # configurable option; we need to think about how to handle this configuration
        if self._system.settings.require("is_persistent"):
            self._vector_segment_type = SegmentType.HNSW_LOCAL_PERSISTED
            if platform.system() != "Windows":
                self._max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
            else:
                self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio()  # type: ignore
            segment_limit = (
                self._max_file_handles
                # This is integer division in Python 3, and not a comment.
                // PersistentLocalHnswSegment.get_file_handle_count()
            )
            self._vector_instances_file_handle_cache = LRUCache(
                segment_limit, callback=lambda _, v: v.close_persistent_index()
            )

    @trace_method(
        "LocalSegmentManager.callback_cache_evict",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def callback_cache_evict(self, segment: Segment) -> None:
        collection_id = segment["collection"]
        self.logger.info(f"LRU cache evict collection {collection_id}")
        instance = self._instance(segment)
        instance.stop()
        del self._instances[segment["id"]]

    @override
    def start(self) -> None:
        for instance in self._instances.values():
            instance.start()
        super().start()

    @override
    def stop(self) -> None:
        for instance in self._instances.values():
            instance.stop()
        super().stop()

    @override
    def reset_state(self) -> None:
        for instance in self._instances.values():
            instance.stop()
            instance.reset_state()
        self._instances = {}
        self.segment_cache[SegmentScope.VECTOR].reset()
        super().reset_state()

    @trace_method(
        "LocalSegmentManager.prepare_segments_for_new_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def prepare_segments_for_new_collection(
        self, collection: Collection
    ) -> Sequence[Segment]:
        vector_segment = _segment(
            self._vector_segment_type, SegmentScope.VECTOR, collection
        )
        metadata_segment = _segment(
            SegmentType.SQLITE, SegmentScope.METADATA, collection
        )
        return [vector_segment, metadata_segment]

    @trace_method(
        "LocalSegmentManager.delete_segments",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        segments = self._sysdb.get_segments(collection=collection_id)
        for segment in segments:
            if segment["id"] in self._instances:
                if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value:
                    instance = self.get_segment(collection_id, VectorReader)
                    instance.delete()
                elif segment["type"] == SegmentType.SQLITE.value:
                    instance = self.get_segment(collection_id, MetadataReader)  # type: ignore[assignment]
                    instance.delete()
                del self._instances[segment["id"]]
            if segment["scope"] is SegmentScope.VECTOR:
                self.segment_cache[SegmentScope.VECTOR].pop(collection_id)
            if segment["scope"] is SegmentScope.METADATA:
                self.segment_cache[SegmentScope.METADATA].pop(collection_id)
        return [s["id"] for s in segments]

    def _get_segment_disk_size(self, collection_id: UUID) -> int:
        segments = self._sysdb.get_segments(
            collection=collection_id, scope=SegmentScope.VECTOR
        )
        if len(segments) == 0:
            return 0
        # With the local segment manager (single-server chroma), a collection always
        # has exactly one vector segment.
        size = get_directory_size(
            os.path.join(
                self._system.settings.require("persist_directory"),
                str(segments[0]["id"]),
            )
        )
        return size

    @trace_method(
        "LocalSegmentManager._get_segment_sysdb",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segment:
        segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
        known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
        # Get the first segment of a known type
        segment = next(filter(lambda s: s["type"] in known_types, segments))
        return segment

    @trace_method(
        "LocalSegmentManager.get_segment",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
        if type == MetadataReader:
            scope = SegmentScope.METADATA
        elif type == VectorReader:
            scope = SegmentScope.VECTOR
        else:
            raise ValueError(f"Invalid segment type: {type}")

        segment = self.segment_cache[scope].get(collection_id)
        if segment is None:
            segment = self._get_segment_sysdb(collection_id, scope)
            self.segment_cache[scope].set(collection_id, segment)

        # Instances must be created atomically, so we use a lock to ensure that
        # only one thread creates the instance.
        with self._lock:
            instance = self._instance(segment)
        return cast(S, instance)

    @trace_method(
        "LocalSegmentManager.hint_use_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        # The local segment manager responds to hints by pre-loading both the
        # metadata and vector segments for the given collection.
        for type in [MetadataReader, VectorReader]:
            # Just use get_segment to load the segment into the cache
            instance = self.get_segment(collection_id, type)
            # If the segment is a vector segment, we need to keep segments in an LRU
            # cache to avoid hitting the OS file handle limit.
            if type == VectorReader and self._system.settings.require("is_persistent"):
                instance = cast(PersistentLocalHnswSegment, instance)
                instance.open_persistent_index()
                self._vector_instances_file_handle_cache.set(collection_id, instance)

    def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
        classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
        cls = get_class(classname, SegmentImplementation)
        return cls

    def _instance(self, segment: Segment) -> SegmentImplementation:
        if segment["id"] not in self._instances:
            cls = self._cls(segment)
            instance = cls(self._system, segment)
            instance.start()
            self._instances[segment["id"]] = instance
        return self._instances[segment["id"]]


def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment:
    """Create a segment record, propagating metadata correctly for the given segment
    type."""
    cls = get_class(SEGMENT_TYPE_IMPLS[type], SegmentImplementation)
    collection_metadata = collection.metadata
    metadata: Optional[Metadata] = None
    if collection_metadata:
        metadata = cls.propagate_collection_metadata(collection_metadata)

    return Segment(
        id=uuid4(),
        type=type.value,
        scope=scope,
        collection=collection.id,
        metadata=metadata,
        file_paths={},
    )
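The file-handle budget above is a straightforward division: the process-wide descriptor limit divided by the number of handles a single persisted HNSW segment keeps open. For example, assuming a soft `RLIMIT_NOFILE` of 1024 and a hypothetical 5 handles per segment:

max_file_handles = 1024   # e.g. resource.getrlimit(resource.RLIMIT_NOFILE)[0]
handles_per_segment = 5   # stand-in for PersistentLocalHnswSegment.get_file_handle_count()
segment_limit = max_file_handles // handles_per_segment
assert segment_limit == 204  # at most 204 vector segments keep their indices open at once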
@@ -0,0 +1,723 @@
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List
from chromadb.segment import MetadataReader
from chromadb.ingest import Consumer
from chromadb.config import System
from chromadb.types import RequestVersionContext, Segment, InclusionExclusionOperator
from chromadb.db.impl.sqlite import SqliteDB
from overrides import override
from chromadb.db.base import (
    Cursor,
    ParameterValue,
    get_sql,
)
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    Where,
    WhereDocument,
    MetadataEmbeddingRecord,
    LogRecord,
    SeqId,
    Operation,
    UpdateMetadata,
    LiteralValue,
    WhereOperator,
)
from uuid import UUID
from pypika import Table, Tables
from pypika.queries import QueryBuilder
import pypika.functions as fn
from pypika.terms import Criterion
from itertools import groupby
from functools import reduce
import sqlite3

import logging

logger = logging.getLogger(__name__)


class SqliteMetadataSegment(MetadataReader):
    _consumer: Consumer
    _db: SqliteDB
    _id: UUID
    _opentelemetry_client: OpenTelemetryClient
    _collection_id: Optional[UUID]
    _subscription: Optional[UUID] = None

    def __init__(self, system: System, segment: Segment):
        self._db = system.instance(SqliteDB)
        self._consumer = system.instance(Consumer)
        self._id = segment["id"]
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        self._collection_id = segment["collection"]

    @trace_method("SqliteMetadataSegment.start", OpenTelemetryGranularity.ALL)
    @override
    def start(self) -> None:
        if self._collection_id:
            seq_id = self.max_seqid()
            self._subscription = self._consumer.subscribe(
                collection_id=self._collection_id,
                consume_fn=self._write_metadata,
                start=seq_id,
            )

    @trace_method("SqliteMetadataSegment.stop", OpenTelemetryGranularity.ALL)
    @override
    def stop(self) -> None:
        if self._subscription:
            self._consumer.unsubscribe(self._subscription)

    @trace_method("SqliteMetadataSegment.max_seqid", OpenTelemetryGranularity.ALL)
    @override
    def max_seqid(self) -> SeqId:
        t = Table("max_seq_id")
        q = (
            self._db.querybuilder()
            .from_(t)
            .select(t.seq_id)
            .where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
        )
        sql, params = get_sql(q)
        with self._db.tx() as cur:
            result = cur.execute(sql, params).fetchone()

            if result is None:
                return self._consumer.min_seqid()
            else:
                return cast(int, result[0])

    @trace_method("SqliteMetadataSegment.count", OpenTelemetryGranularity.ALL)
    @override
    def count(self, request_version_context: RequestVersionContext) -> int:
        embeddings_t = Table("embeddings")
        q = (
            self._db.querybuilder()
            .from_(embeddings_t)
            .where(
                embeddings_t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
            )
            .select(fn.Count(embeddings_t.id))
        )
        sql, params = get_sql(q)
        with self._db.tx() as cur:
            result = cur.execute(sql, params).fetchone()[0]
            return cast(int, result)

    @trace_method("SqliteMetadataSegment.get_metadata", OpenTelemetryGranularity.ALL)
    @override
    def get_metadata(
        self,
        request_version_context: RequestVersionContext,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        include_metadata: bool = True,
    ) -> Sequence[MetadataEmbeddingRecord]:
        """Query for embedding metadata."""
        embeddings_t, metadata_t, fulltext_t = Tables(
            "embeddings", "embedding_metadata", "embedding_fulltext_search"
        )

        limit = limit or 2**63 - 1
        offset = offset or 0

        if limit < 0:
            raise ValueError("Limit cannot be negative")

        select_clause = [
            embeddings_t.id,
            embeddings_t.embedding_id,
            embeddings_t.seq_id,
        ]
        if include_metadata:
            select_clause.extend(
                [
                    metadata_t.key,
                    metadata_t.string_value,
                    metadata_t.int_value,
                    metadata_t.float_value,
                    metadata_t.bool_value,
                ]
            )

        q = (
            (
                self._db.querybuilder()
                .from_(embeddings_t)
                .left_join(metadata_t)
                .on(embeddings_t.id == metadata_t.id)
            )
            .select(*select_clause)
            .orderby(embeddings_t.id)
        )

        # If the query touches the metadata table (i.e. it uses where or
        # where_document filters), we treat this case separately.
        if where is not None or where_document is not None:
            metadata_q = (
                self._db.querybuilder()
                .from_(embeddings_t)
                .select(embeddings_t.id)
                .left_join(metadata_t)
                .on(embeddings_t.id == metadata_t.id)
                .orderby(embeddings_t.id)
                .where(
                    embeddings_t.segment_id
                    == ParameterValue(self._db.uuid_to_db(self._id))
                )
                .distinct()  # These are embedding ids
            )

            if where:
                metadata_q = metadata_q.where(
                    self._where_map_criterion(
                        metadata_q, where, metadata_t, embeddings_t
                    )
                )
            if where_document:
                metadata_q = metadata_q.where(
                    self._where_doc_criterion(
                        metadata_q, where_document, metadata_t, fulltext_t, embeddings_t
                    )
                )
            if ids is not None:
                metadata_q = metadata_q.where(
                    embeddings_t.embedding_id.isin(ParameterValue(ids))
                )

            metadata_q = metadata_q.limit(limit)
            metadata_q = metadata_q.offset(offset)

            q = q.where(embeddings_t.id.isin(metadata_q))
        else:
            # In the case where we don't use the metadata table, we have to apply
            # limit/offset to embeddings and then join with metadata.
            embeddings_q = (
                self._db.querybuilder()
                .from_(embeddings_t)
                .select(embeddings_t.id)
                .where(
                    embeddings_t.segment_id
                    == ParameterValue(self._db.uuid_to_db(self._id))
                )
                .orderby(embeddings_t.id)
                .limit(limit)
                .offset(offset)
            )

            if ids is not None:
                embeddings_q = embeddings_q.where(
                    embeddings_t.embedding_id.isin(ParameterValue(ids))
                )

            q = q.where(embeddings_t.id.isin(embeddings_q))

        with self._db.tx() as cur:
            # Execute the query with the limit and offset already applied
            return list(self._records(cur, q, include_metadata))

    def _records(
        self, cur: Cursor, q: QueryBuilder, include_metadata: bool
    ) -> Generator[MetadataEmbeddingRecord, None, None]:
        """Given a cursor and a QueryBuilder, yield a generator of records. Assumes
        the cursor returns rows in ID order."""

        sql, params = get_sql(q)
        cur.execute(sql, params)

        cur_iterator = iter(cur.fetchone, None)
        group_iterator = groupby(cur_iterator, lambda r: int(r[0]))

        for _, group in group_iterator:
            yield self._record(list(group), include_metadata)

    @trace_method("SqliteMetadataSegment._record", OpenTelemetryGranularity.ALL)
    def _record(
        self, rows: Sequence[Tuple[Any, ...]], include_metadata: bool
    ) -> MetadataEmbeddingRecord:
        """Given a list of DB rows with the same ID, construct a
        MetadataEmbeddingRecord"""
        _, embedding_id, seq_id = rows[0][:3]
        if not include_metadata:
            return MetadataEmbeddingRecord(id=embedding_id, metadata=None)

        metadata = {}
        for row in rows:
            key, string_value, int_value, float_value, bool_value = row[3:]
            if string_value is not None:
                metadata[key] = string_value
            elif int_value is not None:
                metadata[key] = int_value
            elif float_value is not None:
                metadata[key] = float_value
            elif bool_value is not None:
                metadata[key] = bool_value == 1

        return MetadataEmbeddingRecord(
            id=embedding_id,
            metadata=metadata or None,
        )

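Metadata is stored one row per key in `embedding_metadata`, with exactly one of the typed value columns (`string_value`, `int_value`, `float_value`, `bool_value`) populated; `_record` folds those rows back into a dict, and bools live in their own 0/1 column because `isinstance(True, int)` is true in Python, so the writer below must check bool before int. A standalone sketch of the same folding logic, with hand-written rows for illustration:

# (key, string_value, int_value, float_value, bool_value) rows for one embedding:
rows = [
    ("title", "hello", None, None, None),
    ("page", None, 7, None, None),
    ("score", None, None, 0.5, None),
    ("draft", None, None, None, 1),
]
metadata = {}
for key, s, i, f, b in rows:
    if s is not None:
        metadata[key] = s
    elif i is not None:
        metadata[key] = i
    elif f is not None:
        metadata[key] = f
    elif b is not None:
        metadata[key] = bool(b)

assert metadata == {"title": "hello", "page": 7, "score": 0.5, "draft": True}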
@trace_method("SqliteMetadataSegment._insert_record", OpenTelemetryGranularity.ALL)
|
||||
def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None:
|
||||
"""Add or update a single EmbeddingRecord into the DB"""
|
||||
|
||||
t = Table("embeddings")
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.into(t)
|
||||
.columns(t.segment_id, t.embedding_id, t.seq_id)
|
||||
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
|
||||
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
|
||||
).insert(
|
||||
ParameterValue(self._db.uuid_to_db(self._id)),
|
||||
ParameterValue(record["record"]["id"]),
|
||||
ParameterValue(record["log_offset"]),
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
sql = sql + "RETURNING id"
|
||||
try:
|
||||
id = cur.execute(sql, params).fetchone()[0]
|
||||
except sqlite3.IntegrityError:
|
||||
# Can't use INSERT OR REPLACE here because it changes the primary key.
|
||||
if upsert:
|
||||
return self._update_record(cur, record)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Insert of existing embedding ID: {record['record']['id']}"
|
||||
)
|
||||
# We are trying to add for a record that already exists. Fail the call.
|
||||
# We don't throw an exception since this is in principal an async path
|
||||
return
|
||||
|
||||
if record["record"]["metadata"]:
|
||||
self._update_metadata(cur, id, record["record"]["metadata"])
|
||||
|
||||
@trace_method(
|
||||
"SqliteMetadataSegment._update_metadata", OpenTelemetryGranularity.ALL
|
||||
)
|
||||
def _update_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None:
|
||||
"""Update the metadata for a single EmbeddingRecord"""
|
||||
t = Table("embedding_metadata")
|
||||
to_delete = [k for k, v in metadata.items() if v is None]
|
||||
if to_delete:
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.id == ParameterValue(id))
|
||||
.where(t.key.isin(ParameterValue(to_delete)))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
cur.execute(sql, params)
|
||||
|
||||
self._insert_metadata(cur, id, metadata)
|
||||
|
||||
@trace_method(
|
||||
"SqliteMetadataSegment._insert_metadata", OpenTelemetryGranularity.ALL
|
||||
)
|
||||
def _insert_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None:
|
||||
"""Insert or update each metadata row for a single embedding record"""
|
||||
t = Table("embedding_metadata")
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.into(t)
|
||||
.columns(
|
||||
t.id,
|
||||
t.key,
|
||||
t.string_value,
|
||||
t.int_value,
|
||||
t.float_value,
|
||||
t.bool_value,
|
||||
)
|
||||
)
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
q = q.insert(
|
||||
ParameterValue(id),
|
||||
ParameterValue(key),
|
||||
ParameterValue(value),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
# isinstance(True, int) evaluates to True, so we need to check for bools separately
|
||||
elif isinstance(value, bool):
|
||||
q = q.insert(
|
||||
ParameterValue(id),
|
||||
ParameterValue(key),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
ParameterValue(value),
|
||||
)
|
||||
elif isinstance(value, int):
|
||||
q = q.insert(
|
||||
ParameterValue(id),
|
||||
ParameterValue(key),
|
||||
None,
|
||||
ParameterValue(value),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
elif isinstance(value, float):
|
||||
q = q.insert(
|
||||
ParameterValue(id),
|
||||
ParameterValue(key),
|
||||
None,
|
||||
None,
|
||||
ParameterValue(value),
|
||||
None,
|
||||
)
|
||||
|
||||
sql, params = get_sql(q)
|
||||
sql = sql.replace("INSERT", "INSERT OR REPLACE")
|
||||
if sql:
|
||||
cur.execute(sql, params)
|
||||
|
||||
if "chroma:document" in metadata:
|
||||
t = Table("embedding_fulltext_search")
|
||||
|
||||
def insert_into_fulltext_search() -> None:
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.into(t)
|
||||
.columns(t.rowid, t.string_value)
|
||||
.insert(
|
||||
ParameterValue(id),
|
||||
ParameterValue(metadata["chroma:document"]),
|
||||
)
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
cur.execute(sql, params)
|
||||
|
||||
try:
|
||||
insert_into_fulltext_search()
|
||||
except sqlite3.IntegrityError:
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.rowid == ParameterValue(id))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
cur.execute(sql, params)
|
||||
insert_into_fulltext_search()
|
||||
|
||||
@trace_method("SqliteMetadataSegment._delete_record", OpenTelemetryGranularity.ALL)
|
||||
def _delete_record(self, cur: Cursor, record: LogRecord) -> None:
|
||||
"""Delete a single EmbeddingRecord from the DB"""
|
||||
t = Table("embeddings")
|
||||
fts_t = Table("embedding_fulltext_search")
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
|
||||
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
|
||||
.delete()
|
||||
)
|
||||
q_fts = (
|
||||
self._db.querybuilder()
|
||||
.from_(fts_t)
|
||||
.delete()
|
||||
.where(
|
||||
fts_t.rowid.isin(
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.select(t.id)
|
||||
.where(
|
||||
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
|
||||
)
|
||||
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
|
||||
)
|
||||
)
|
||||
)
|
||||
cur.execute(*get_sql(q_fts))
|
||||
sql, params = get_sql(q)
|
||||
sql = sql + " RETURNING id"
|
||||
result = cur.execute(sql, params).fetchone()
|
||||
if result is None:
|
||||
logger.warning(
|
||||
f"Delete of nonexisting embedding ID: {record['record']['id']}"
|
||||
)
|
||||
else:
|
||||
id = result[0]
|
||||
|
||||
# Manually delete metadata; cannot use cascade because
|
||||
# that triggers on replace
|
||||
metadata_t = Table("embedding_metadata")
|
||||
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.from_(metadata_t)
|
||||
.where(metadata_t.id == ParameterValue(id))
|
||||
.delete()
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
cur.execute(sql, params)
|
||||
|
||||
@trace_method("SqliteMetadataSegment._update_record", OpenTelemetryGranularity.ALL)
|
||||
def _update_record(self, cur: Cursor, record: LogRecord) -> None:
|
||||
"""Update a single EmbeddingRecord in the DB"""
|
||||
t = Table("embeddings")
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.update(t)
|
||||
.set(t.seq_id, ParameterValue(record["log_offset"]))
|
||||
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
|
||||
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
sql = sql + " RETURNING id"
|
||||
result = cur.execute(sql, params).fetchone()
|
||||
if result is None:
|
||||
logger.warning(
|
||||
f"Update of nonexisting embedding ID: {record['record']['id']}"
|
||||
)
|
||||
else:
|
||||
id = result[0]
|
||||
if record["record"]["metadata"]:
|
||||
self._update_metadata(cur, id, record["record"]["metadata"])
|
||||
|
||||
@trace_method("SqliteMetadataSegment._write_metadata", OpenTelemetryGranularity.ALL)
|
||||
def _write_metadata(self, records: Sequence[LogRecord]) -> None:
|
||||
"""Write embedding metadata to the database. Care should be taken to ensure
|
||||
records are append-only (that is, that seq-ids should increase monotonically)"""
|
||||
with self._db.tx() as cur:
|
||||
for record in records:
|
||||
if record["record"]["operation"] == Operation.ADD:
|
||||
self._insert_record(cur, record, False)
|
||||
elif record["record"]["operation"] == Operation.UPSERT:
|
||||
self._insert_record(cur, record, True)
|
||||
elif record["record"]["operation"] == Operation.DELETE:
|
||||
self._delete_record(cur, record)
|
||||
elif record["record"]["operation"] == Operation.UPDATE:
|
||||
self._update_record(cur, record)
|
||||
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.into(Table("max_seq_id"))
|
||||
.columns("segment_id", "seq_id")
|
||||
.insert(
|
||||
ParameterValue(self._db.uuid_to_db(self._id)),
|
||||
ParameterValue(record["log_offset"]),
|
||||
)
|
||||
)
|
||||
sql, params = get_sql(q)
|
||||
sql = sql.replace("INSERT", "INSERT OR REPLACE")
|
||||
cur.execute(sql, params)
|
||||
|
||||
@trace_method(
|
||||
"SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL
|
||||
)
|
||||
def _where_map_criterion(
|
||||
self, q: QueryBuilder, where: Where, metadata_t: Table, embeddings_t: Table
|
||||
) -> Criterion:
|
||||
clause: List[Criterion] = []
|
||||
for k, v in where.items():
|
||||
if k == "$and":
|
||||
criteria = [
|
||||
self._where_map_criterion(q, w, metadata_t, embeddings_t)
|
||||
for w in cast(Sequence[Where], v)
|
||||
]
|
||||
clause.append(reduce(lambda x, y: x & y, criteria))
|
||||
elif k == "$or":
|
||||
criteria = [
|
||||
self._where_map_criterion(q, w, metadata_t, embeddings_t)
|
||||
for w in cast(Sequence[Where], v)
|
||||
]
|
||||
clause.append(reduce(lambda x, y: x | y, criteria))
|
||||
else:
|
||||
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v)
|
||||
clause.append(_where_clause(k, expr, q, metadata_t, embeddings_t))
|
||||
return reduce(lambda x, y: x & y, clause)
|
||||
|
||||
@trace_method(
|
||||
"SqliteMetadataSegment._where_doc_criterion", OpenTelemetryGranularity.ALL
|
||||
)
|
||||
def _where_doc_criterion(
|
||||
self,
|
||||
q: QueryBuilder,
|
||||
where: WhereDocument,
|
||||
metadata_t: Table,
|
||||
fulltext_t: Table,
|
||||
embeddings_t: Table,
|
||||
) -> Criterion:
|
||||
for k, v in where.items():
|
||||
if k == "$and":
|
||||
criteria = [
|
||||
self._where_doc_criterion(
|
||||
q, w, metadata_t, fulltext_t, embeddings_t
|
||||
)
|
||||
for w in cast(Sequence[WhereDocument], v)
|
||||
]
|
||||
return reduce(lambda x, y: x & y, criteria)
|
||||
elif k == "$or":
|
||||
criteria = [
|
||||
self._where_doc_criterion(
|
||||
q, w, metadata_t, fulltext_t, embeddings_t
|
||||
)
|
||||
for w in cast(Sequence[WhereDocument], v)
|
||||
]
|
||||
return reduce(lambda x, y: x | y, criteria)
|
||||
elif k in ("$contains", "$not_contains"):
|
||||
v = cast(str, v)
|
||||
search_term = f"%{v}%"
|
||||
|
||||
sq = (
|
||||
self._db.querybuilder()
|
||||
.from_(fulltext_t)
|
||||
.select(fulltext_t.rowid)
|
||||
.where(fulltext_t.string_value.like(ParameterValue(search_term)))
|
||||
)
|
||||
return (
|
||||
embeddings_t.id.isin(sq)
|
||||
if k == "$contains"
|
||||
else embeddings_t.id.notin(sq)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown where_doc operator {k}")
|
||||
raise ValueError("Empty where_doc")
|
||||
|
||||
@trace_method("SqliteMetadataSegment.delete", OpenTelemetryGranularity.ALL)
|
||||
@override
|
||||
def delete(self) -> None:
|
||||
t = Table("embeddings")
|
||||
t1 = Table("embedding_metadata")
|
||||
t2 = Table("embedding_fulltext_search")
|
||||
q0 = (
|
||||
self._db.querybuilder()
|
||||
.from_(t1)
|
||||
.delete()
|
||||
.where(
|
||||
t1.id.isin(
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.select(t.id)
|
||||
.where(
|
||||
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
q = (
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.delete()
|
||||
.where(
|
||||
t.id.isin(
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.select(t.id)
|
||||
.where(
|
||||
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
q_fts = (
|
||||
self._db.querybuilder()
|
||||
.from_(t2)
|
||||
.delete()
|
||||
.where(
|
||||
t2.rowid.isin(
|
||||
self._db.querybuilder()
|
||||
.from_(t)
|
||||
.select(t.id)
|
||||
.where(
|
||||
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
with self._db.tx() as cur:
|
||||
cur.execute(*get_sql(q_fts))
|
||||
cur.execute(*get_sql(q0))
|
||||
cur.execute(*get_sql(q))
|
||||
|
||||
|
||||
def _where_clause(
|
||||
key: str,
|
||||
expr: Union[
|
||||
LiteralValue,
|
||||
Dict[WhereOperator, LiteralValue],
|
||||
Dict[InclusionExclusionOperator, List[LiteralValue]],
|
||||
],
|
||||
metadata_q: QueryBuilder,
|
||||
metadata_t: Table,
|
||||
embeddings_t: Table,
|
||||
) -> Criterion:
|
||||
"""Given a field name, an expression, and a table, construct a Pypika Criterion"""
|
||||
|
||||
# Literal value case
|
||||
if isinstance(expr, (str, int, float, bool)):
|
||||
return _where_clause(
|
||||
key,
|
||||
{cast(WhereOperator, "$eq"): expr},
|
||||
metadata_q,
|
||||
metadata_t,
|
||||
embeddings_t,
|
||||
)
|
||||
|
||||
# Operator dict case
|
||||
operator, value = next(iter(expr.items()))
|
||||
return _value_criterion(key, value, operator, metadata_q, metadata_t, embeddings_t)
|
||||
|
||||
|
||||
def _value_criterion(
|
||||
key: str,
|
||||
value: Union[LiteralValue, List[LiteralValue]],
|
||||
op: Union[WhereOperator, InclusionExclusionOperator],
|
||||
metadata_q: QueryBuilder,
|
||||
metadata_t: Table,
|
||||
embeddings_t: Table,
|
||||
) -> Criterion:
|
||||
"""Creates the filter for a single operator"""
|
||||
|
||||
def is_numeric(obj: object) -> bool:
|
||||
return (not isinstance(obj, bool)) and isinstance(obj, (int, float))
|
||||
|
||||
sub_q = metadata_q.where(metadata_t.key == ParameterValue(key))
|
||||
p_val = ParameterValue(value)
|
||||
|
||||
if is_numeric(value) or (isinstance(value, list) and is_numeric(value[0])):
|
||||
int_col, float_col = metadata_t.int_value, metadata_t.float_value
|
||||
if op in ("$eq", "$ne"):
|
||||
expr = (int_col == p_val) | (float_col == p_val)
|
||||
elif op == "$gt":
|
||||
expr = (int_col > p_val) | (float_col > p_val)
|
||||
elif op == "$gte":
|
||||
expr = (int_col >= p_val) | (float_col >= p_val)
|
||||
elif op == "$lt":
|
||||
expr = (int_col < p_val) | (float_col < p_val)
|
||||
elif op == "$lte":
|
||||
expr = (int_col <= p_val) | (float_col <= p_val)
|
||||
else:
|
||||
expr = int_col.isin(p_val) | float_col.isin(p_val)
|
||||
else:
|
||||
if isinstance(value, bool) or (
|
||||
isinstance(value, list) and isinstance(value[0], bool)
|
||||
):
|
||||
col = metadata_t.bool_value
|
||||
else:
|
||||
col = metadata_t.string_value
|
||||
if op in ("$eq", "$ne"):
|
||||
expr = col == p_val
|
||||
else:
|
||||
expr = col.isin(p_val)
|
||||
|
||||
if op in ("$ne", "$nin"):
|
||||
return embeddings_t.id.notin(sub_q.where(expr))
|
||||
else:
|
||||
return embeddings_t.id.isin(sub_q.where(expr))
|
||||
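
The numeric branch above checks both value columns because a metadata number may have been stored as either an int or a float. A minimal sketch of the resulting criterion, assuming plain pypika with inlined values in place of Chroma's ParameterValue placeholders (table and key names are illustrative):

# Sketch of the dual-column numeric match built by _value_criterion for
# where={"price": {"$gt": 10}}, using plain pypika with inlined values.
from pypika import Query, Table

metadata_t = Table("embedding_metadata")
embeddings_t = Table("embeddings")

# A numeric value may live in either int_value or float_value, so check both.
expr = (metadata_t.int_value > 10) | (metadata_t.float_value > 10)
sub_q = (
    Query.from_(metadata_t)
    .select(metadata_t.id)
    .where(metadata_t.key == "price")
    .where(expr)
)
criterion = embeddings_t.id.isin(sub_q)
print(Query.from_(embeddings_t).select(embeddings_t.id).where(criterion))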
@@ -0,0 +1,106 @@
from typing import Dict, List, Set, cast
from chromadb.types import LogRecord, Operation, Vector


class Batch:
    """Used to model the set of changes as an atomic operation"""

    _ids_to_records: Dict[str, LogRecord]
    _deleted_ids: Set[str]
    _written_ids: Set[str]
    _upsert_add_ids: Set[str]  # IDs that are being added in an upsert
    add_count: int
    update_count: int

    def __init__(self) -> None:
        self._ids_to_records = {}
        self._deleted_ids = set()
        self._written_ids = set()
        self._upsert_add_ids = set()
        self.add_count = 0
        self.update_count = 0

    def __len__(self) -> int:
        """Get the number of changes in this batch"""
        return len(self._written_ids) + len(self._deleted_ids)

    def get_deleted_ids(self) -> List[str]:
        """Get the list of deleted embeddings in this batch"""
        return list(self._deleted_ids)

    def get_written_ids(self) -> List[str]:
        """Get the list of written embeddings in this batch"""
        return list(self._written_ids)

    def get_written_vectors(self, ids: List[str]) -> List[Vector]:
        """Get the list of vectors to write in this batch"""
        return [
            cast(Vector, self._ids_to_records[id]["record"]["embedding"]) for id in ids
        ]

    def get_record(self, id: str) -> LogRecord:
        """Get the record for a given ID"""
        return self._ids_to_records[id]

    def is_deleted(self, id: str) -> bool:
        """Check if a given ID is deleted"""
        return id in self._deleted_ids

    @property
    def delete_count(self) -> int:
        return len(self._deleted_ids)

    def apply(self, record: LogRecord, exists_already: bool = False) -> None:
        """Apply an embedding record to this batch. Records passed to this method are
        assumed to be validated for correctness. For example, a delete or update
        presumes the ID exists in the index. An add presumes the ID does not exist in
        the index. The exists_already flag should be set to True if the ID does exist
        in the index, and False otherwise.
        """
        id = record["record"]["id"]
        if record["record"]["operation"] == Operation.DELETE:
            # If the ID was previously written, remove it from the written set
            # and update the add/update/delete counts
            if id in self._written_ids:
                self._written_ids.remove(id)
                if self._ids_to_records[id]["record"]["operation"] == Operation.ADD:
                    self.add_count -= 1
                elif (
                    self._ids_to_records[id]["record"]["operation"] == Operation.UPDATE
                ):
                    self.update_count -= 1
                    self._deleted_ids.add(id)
                elif (
                    self._ids_to_records[id]["record"]["operation"] == Operation.UPSERT
                ):
                    if id in self._upsert_add_ids:
                        self.add_count -= 1
                        self._upsert_add_ids.remove(id)
                    else:
                        self.update_count -= 1
                        self._deleted_ids.add(id)
            elif id not in self._deleted_ids:
                self._deleted_ids.add(id)

            # Remove the record from the batch
            if id in self._ids_to_records:
                del self._ids_to_records[id]

        else:
            self._ids_to_records[id] = record
            self._written_ids.add(id)

            # If the ID was previously deleted, remove it from the deleted set
            # and update the delete count
            if id in self._deleted_ids:
                self._deleted_ids.remove(id)

            # Update the add/update counts
            if record["record"]["operation"] == Operation.UPSERT:
                if not exists_already:
                    self.add_count += 1
                    self._upsert_add_ids.add(id)
                else:
                    self.update_count += 1
            elif record["record"]["operation"] == Operation.ADD:
                self.add_count += 1
            elif record["record"]["operation"] == Operation.UPDATE:
                self.update_count += 1
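
A short sketch of the netting behavior described in apply(): an ADD followed by a DELETE of the same ID cancels out entirely, so nothing is written to or deleted from the index. The record dicts below are hand-built stand-ins for full LogRecord payloads, carrying only the fields Batch actually reads:

# Hedged sketch: ADD then DELETE of the same ID nets out to no change.
from chromadb.segment.impl.vector.batch import Batch
from chromadb.types import Operation

batch = Batch()
add = {"log_offset": 1, "record": {"id": "a", "embedding": [0.1, 0.2], "operation": Operation.ADD}}
delete = {"log_offset": 2, "record": {"id": "a", "embedding": None, "operation": Operation.DELETE}}

batch.apply(add)     # add_count == 1, "a" is in the written set
batch.apply(delete)  # the pending ADD is retracted rather than recorded as a delete
assert batch.add_count == 0 and len(batch) == 0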
@@ -0,0 +1,151 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Set
import numpy as np
import numpy.typing as npt
from chromadb.types import (
    LogRecord,
    VectorEmbeddingRecord,
    VectorQuery,
    VectorQueryResult,
)

from chromadb.utils import distance_functions
import logging

logger = logging.getLogger(__name__)


class BruteForceIndex:
    """A lightweight, numpy-based brute force index used for batches that have not
    been indexed into hnsw yet. It is not thread safe, and callers should ensure
    that only one thread is accessing it at a time.
    """

    id_to_index: Dict[str, int]
    index_to_id: Dict[int, str]
    id_to_seq_id: Dict[str, int]
    deleted_ids: Set[str]
    free_indices: List[int]
    size: int
    dimensionality: int
    distance_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any]], float]
    vectors: npt.NDArray[Any]

    def __init__(self, size: int, dimensionality: int, space: str = "l2"):
        if space == "l2":
            self.distance_fn = distance_functions.l2
        elif space == "ip":
            self.distance_fn = distance_functions.ip
        elif space == "cosine":
            self.distance_fn = distance_functions.cosine
        else:
            raise Exception(f"Unknown distance function: {space}")

        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids = set()
        self.free_indices = list(range(size))
        self.size = size
        self.dimensionality = dimensionality
        self.vectors = np.zeros((size, dimensionality))

    def __len__(self) -> int:
        return len(self.id_to_index)

    def clear(self) -> None:
        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids.clear()
        self.free_indices = list(range(self.size))
        self.vectors.fill(0)

    def upsert(self, records: List[LogRecord]) -> None:
        if len(records) + len(self) > self.size:
            raise Exception(
                "Index with capacity {} and {} current entries cannot add {} records".format(
                    self.size, len(self), len(records)
                )
            )

        for i, record in enumerate(records):
            id = record["record"]["id"]
            vector = record["record"]["embedding"]
            self.id_to_seq_id[id] = record["log_offset"]
            if id in self.deleted_ids:
                self.deleted_ids.remove(id)

            # TODO: It may be faster to use multi-index selection on the vectors array
            if id in self.id_to_index:
                # Update
                index = self.id_to_index[id]
                self.vectors[index] = vector
            else:
                # Add
                next_index = self.free_indices.pop()
                self.id_to_index[id] = next_index
                self.index_to_id[next_index] = id
                self.vectors[next_index] = vector

    def delete(self, records: List[LogRecord]) -> None:
        for record in records:
            id = record["record"]["id"]
            if id in self.id_to_index:
                index = self.id_to_index[id]
                self.deleted_ids.add(id)
                del self.id_to_index[id]
                del self.index_to_id[index]
                del self.id_to_seq_id[id]
                self.vectors[index].fill(np.nan)
                self.free_indices.append(index)
            else:
                logger.warning(f"Delete of nonexistent embedding ID: {id}")

    def has_id(self, id: str) -> bool:
        """Returns whether the index contains the given ID"""
        return id in self.id_to_index and id not in self.deleted_ids

    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        target_ids = ids or self.id_to_index.keys()

        return [
            VectorEmbeddingRecord(
                id=id,
                embedding=self.vectors[self.id_to_index[id]],
            )
            for id in target_ids
        ]

    def query(self, query: VectorQuery) -> Sequence[Sequence[VectorQueryResult]]:
        np_query = np.array(query["vectors"], dtype=np.float32)
        allowed_ids = (
            None if query["allowed_ids"] is None else set(query["allowed_ids"])
        )
        distances = np.apply_along_axis(
            lambda query: np.apply_along_axis(self.distance_fn, 1, self.vectors, query),
            1,
            np_query,
        )

        indices = np.argsort(distances)
        # Filter out deleted labels
        filtered_results = []
        for i, index_list in enumerate(indices):
            curr_results = []
            for j in index_list:
                # If the index is in the index_to_id map, then it has been added
                if j in self.index_to_id:
                    id = self.index_to_id[j]
                    if id not in self.deleted_ids and (
                        allowed_ids is None or id in allowed_ids
                    ):
                        curr_results.append(
                            VectorQueryResult(
                                id=id,
                                distance=distances[i][j].item(),
                                embedding=self.vectors[j],
                            )
                        )
            filtered_results.append(curr_results)
        return filtered_results
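
A hedged usage sketch of the index above: buffer two records and ask for the nearest neighbor of a probe vector. The dicts again stand in for full LogRecords, and the VectorQuery fields that BruteForceIndex.query never reads are filled with placeholders:

# Hedged sketch: buffer records, then brute-force query them.
from chromadb.segment.impl.vector.brute_force_index import BruteForceIndex
from chromadb.types import Operation, VectorQuery

bf = BruteForceIndex(size=10, dimensionality=2, space="l2")
bf.upsert([
    {"log_offset": 1, "record": {"id": "a", "embedding": [0.0, 0.0], "operation": Operation.UPSERT}},
    {"log_offset": 2, "record": {"id": "b", "embedding": [1.0, 1.0], "operation": Operation.UPSERT}},
])
query = VectorQuery(
    vectors=[[0.1, 0.1]], k=1, allowed_ids=None, include_embeddings=False,
    options=None, request_version_context=None,  # placeholders; unused by this method
)
print(bf.query(query)[0][0]["id"])  # -> "a", the closest buffered vector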
@@ -0,0 +1,88 @@
import multiprocessing
import re
from typing import Any, Callable, Dict, Union

from chromadb.types import Metadata


Validator = Callable[[Union[str, int, float]], bool]

param_validators: Dict[str, Validator] = {
    "hnsw:space": lambda p: bool(re.match(r"^(l2|cosine|ip)$", str(p))),
    "hnsw:construction_ef": lambda p: isinstance(p, int),
    "hnsw:search_ef": lambda p: isinstance(p, int),
    "hnsw:M": lambda p: isinstance(p, int),
    "hnsw:num_threads": lambda p: isinstance(p, int),
    "hnsw:resize_factor": lambda p: isinstance(p, (int, float)),
}

# Extra params used for persistent hnsw
persistent_param_validators: Dict[str, Validator] = {
    "hnsw:batch_size": lambda p: isinstance(p, int) and p > 2,
    "hnsw:sync_threshold": lambda p: isinstance(p, int) and p > 2,
}


class Params:
    @staticmethod
    def _select(metadata: Metadata) -> Dict[str, Any]:
        segment_metadata = {}
        for param, value in metadata.items():
            if param.startswith("hnsw:"):
                segment_metadata[param] = value
        return segment_metadata

    @staticmethod
    def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> None:
        """Validates the metadata"""
        for param, value in metadata.items():
            if param not in validators:
                raise ValueError(f"Unknown HNSW parameter: {param}")
            if not validators[param](value):
                raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")


class HnswParams(Params):
    space: str
    construction_ef: int
    search_ef: int
    M: int
    num_threads: int
    resize_factor: float

    def __init__(self, metadata: Metadata):
        metadata = metadata or {}
        self.space = str(metadata.get("hnsw:space", "l2"))
        self.construction_ef = int(metadata.get("hnsw:construction_ef", 100))
        self.search_ef = int(metadata.get("hnsw:search_ef", 100))
        self.M = int(metadata.get("hnsw:M", 16))
        self.num_threads = int(
            metadata.get("hnsw:num_threads", multiprocessing.cpu_count())
        )
        self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2))

    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Validate and return only the relevant hnsw params"""
        segment_metadata = HnswParams._select(metadata)
        HnswParams._validate(segment_metadata, param_validators)
        return segment_metadata


class PersistentHnswParams(HnswParams):
    batch_size: int
    sync_threshold: int

    def __init__(self, metadata: Metadata):
        super().__init__(metadata)
        self.batch_size = int(metadata.get("hnsw:batch_size", 100))
        self.sync_threshold = int(metadata.get("hnsw:sync_threshold", 1000))

    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Returns only the relevant hnsw params"""
        all_validators = {**param_validators, **persistent_param_validators}
        segment_metadata = PersistentHnswParams._select(metadata)
        PersistentHnswParams._validate(segment_metadata, all_validators)
        return segment_metadata
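
A quick sketch of how collection metadata flows through extract(): only "hnsw:"-prefixed keys survive _select, and each surviving key must pass its validator or a ValueError is raised. Unset params fall back to the defaults in __init__:

# Hedged sketch of param selection, validation, and defaults.
from chromadb.segment.impl.vector.hnsw_params import HnswParams, PersistentHnswParams

metadata = {"hnsw:space": "cosine", "hnsw:M": 32, "topic": "news"}
print(HnswParams.extract(metadata))  # {'hnsw:space': 'cosine', 'hnsw:M': 32}

params = PersistentHnswParams({"hnsw:batch_size": 500})
print(params.batch_size, params.sync_threshold, params.space)  # 500 1000 l2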
@@ -0,0 +1,332 @@
from overrides import override
from typing import Optional, Sequence, Dict, Set, List, cast
from uuid import UUID
from chromadb.segment import VectorReader
from chromadb.ingest import Consumer
from chromadb.config import System, Settings
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import HnswParams
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    LogRecord,
    RequestVersionContext,
    VectorEmbeddingRecord,
    VectorQuery,
    VectorQueryResult,
    SeqId,
    Segment,
    Metadata,
    Operation,
    Vector,
)
from chromadb.errors import InvalidDimensionException
import hnswlib
from chromadb.utils.read_write_lock import ReadWriteLock, ReadRWLock, WriteRWLock
import logging
import numpy as np

logger = logging.getLogger(__name__)

DEFAULT_CAPACITY = 1000


class LocalHnswSegment(VectorReader):
    _id: UUID
    _consumer: Consumer
    _collection: Optional[UUID]
    _subscription: Optional[UUID]
    _settings: Settings
    _params: HnswParams

    _index: Optional[hnswlib.Index]
    _dimensionality: Optional[int]
    _total_elements_added: int
    _max_seq_id: SeqId

    _lock: ReadWriteLock

    _id_to_label: Dict[str, int]
    _label_to_id: Dict[int, str]
    # Note: As of the time of writing, this mapping is no longer needed.
    # We merely keep it around for easy compatibility with the old code and
    # debugging purposes.
    _id_to_seq_id: Dict[str, SeqId]

    _opentelemetry_client: OpenTelemetryClient

    def __init__(self, system: System, segment: Segment):
        self._consumer = system.instance(Consumer)
        self._id = segment["id"]
        self._collection = segment["collection"]
        self._subscription = None
        self._settings = system.settings
        self._params = HnswParams(segment["metadata"] or {})

        self._index = None
        self._dimensionality = None
        self._total_elements_added = 0
        self._max_seq_id = self._consumer.min_seqid()

        self._id_to_seq_id = {}
        self._id_to_label = {}
        self._label_to_id = {}

        self._lock = ReadWriteLock()
        self._opentelemetry_client = system.require(OpenTelemetryClient)

    @staticmethod
    @override
    def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
        # Extract relevant metadata
        segment_metadata = HnswParams.extract(metadata)
        return segment_metadata

    @trace_method("LocalHnswSegment.start", OpenTelemetryGranularity.ALL)
    @override
    def start(self) -> None:
        super().start()
        if self._collection:
            seq_id = self.max_seqid()
            self._subscription = self._consumer.subscribe(
                self._collection, self._write_records, start=seq_id
            )

    @trace_method("LocalHnswSegment.stop", OpenTelemetryGranularity.ALL)
    @override
    def stop(self) -> None:
        super().stop()
        if self._subscription:
            self._consumer.unsubscribe(self._subscription)

    @trace_method("LocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL)
    @override
    def get_vectors(
        self,
        request_version_context: RequestVersionContext,
        ids: Optional[Sequence[str]] = None,
    ) -> Sequence[VectorEmbeddingRecord]:
        if ids is None:
            labels = list(self._label_to_id.keys())
        else:
            labels = []
            for id in ids:
                if id in self._id_to_label:
                    labels.append(self._id_to_label[id])

        results = []
        if self._index is not None:
            vectors = cast(
                Sequence[Vector], np.array(self._index.get_items(labels))
            )  # version 0.8 of hnswlib allows return_type="numpy"

            for label, vector in zip(labels, vectors):
                id = self._label_to_id[label]
                results.append(VectorEmbeddingRecord(id=id, embedding=vector))

        return results

    @trace_method("LocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL)
    @override
    def query_vectors(
        self, query: VectorQuery
    ) -> Sequence[Sequence[VectorQueryResult]]:
        if self._index is None:
            return [[] for _ in range(len(query["vectors"]))]

        k = query["k"]
        size = len(self._id_to_label)

        if k > size:
            logger.warning(
                f"Number of requested results {k} is greater than number of elements in index {size}, updating n_results = {size}"
            )
            k = size

        labels: Set[int] = set()
        ids = query["allowed_ids"]
        if ids is not None:
            labels = {self._id_to_label[id] for id in ids if id in self._id_to_label}
            if len(labels) < k:
                k = len(labels)

        def filter_function(label: int) -> bool:
            return label in labels

        query_vectors = query["vectors"]

        with ReadRWLock(self._lock):
            result_labels, distances = self._index.knn_query(
                np.array(query_vectors, dtype=np.float32),
                k=k,
                filter=filter_function if ids else None,
            )

            # TODO: these casts are not correct, hnswlib returns np
            # distances = cast(List[List[float]], distances)
            # result_labels = cast(List[List[int]], result_labels)

            all_results: List[List[VectorQueryResult]] = []
            for result_i in range(len(result_labels)):
                results: List[VectorQueryResult] = []
                for label, distance in zip(
                    result_labels[result_i], distances[result_i]
                ):
                    id = self._label_to_id[label]
                    if query["include_embeddings"]:
                        embedding = np.array(
                            self._index.get_items([label])[0]
                        )  # version 0.8 of hnswlib allows return_type="numpy"
                    else:
                        embedding = None
                    results.append(
                        VectorQueryResult(
                            id=id,
                            distance=distance.item(),
                            embedding=embedding,
                        )
                    )
                all_results.append(results)

            return all_results

    @override
    def max_seqid(self) -> SeqId:
        return self._max_seq_id

    @override
    def count(self, request_version_context: RequestVersionContext) -> int:
        return len(self._id_to_label)

    @trace_method("LocalHnswSegment._init_index", OpenTelemetryGranularity.ALL)
    def _init_index(self, dimensionality: int) -> None:
        # more comments available at the source: https://github.com/nmslib/hnswlib

        index = hnswlib.Index(
            space=self._params.space, dim=dimensionality
        )  # possible options are l2, cosine or ip
        index.init_index(
            max_elements=DEFAULT_CAPACITY,
            ef_construction=self._params.construction_ef,
            M=self._params.M,
        )
        index.set_ef(self._params.search_ef)
        index.set_num_threads(self._params.num_threads)

        self._index = index
        self._dimensionality = dimensionality

    @trace_method("LocalHnswSegment._ensure_index", OpenTelemetryGranularity.ALL)
    def _ensure_index(self, n: int, dim: int) -> None:
        """Create or resize the index as necessary to accommodate N new records"""
        if not self._index:
            self._dimensionality = dim
            self._init_index(dim)
        else:
            if dim != self._dimensionality:
                raise InvalidDimensionException(
                    f"Dimensionality of ({dim}) does not match index "
                    + f"dimensionality ({self._dimensionality})"
                )

        index = cast(hnswlib.Index, self._index)

        if (self._total_elements_added + n) > index.get_max_elements():
            new_size = int(
                (self._total_elements_added + n) * self._params.resize_factor
            )
            index.resize_index(max(new_size, DEFAULT_CAPACITY))

    @trace_method("LocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL)
    def _apply_batch(self, batch: Batch) -> None:
        """Apply a batch of changes, as atomically as possible."""
        deleted_ids = batch.get_deleted_ids()
        written_ids = batch.get_written_ids()
        vectors_to_write = batch.get_written_vectors(written_ids)
        labels_to_write = [0] * len(vectors_to_write)

        if len(deleted_ids) > 0:
            index = cast(hnswlib.Index, self._index)
            for i in range(len(deleted_ids)):
                id = deleted_ids[i]
                # Never added this id to hnsw, so we can safely ignore it for deletions
                if id not in self._id_to_label:
                    continue
                label = self._id_to_label[id]

                index.mark_deleted(label)
                del self._id_to_label[id]
                del self._label_to_id[label]
                del self._id_to_seq_id[id]

        if len(written_ids) > 0:
            self._ensure_index(batch.add_count, len(vectors_to_write[0]))

            next_label = self._total_elements_added + 1
            for i in range(len(written_ids)):
                if written_ids[i] not in self._id_to_label:
                    labels_to_write[i] = next_label
                    next_label += 1
                else:
                    labels_to_write[i] = self._id_to_label[written_ids[i]]

            index = cast(hnswlib.Index, self._index)

            # First, update the index
            index.add_items(vectors_to_write, labels_to_write)

            # If that succeeds, update the mappings
            for i, id in enumerate(written_ids):
                self._id_to_seq_id[id] = batch.get_record(id)["log_offset"]
                self._id_to_label[id] = labels_to_write[i]
                self._label_to_id[labels_to_write[i]] = id

            # If that succeeds, update the total count
            self._total_elements_added += batch.add_count

    @trace_method("LocalHnswSegment._write_records", OpenTelemetryGranularity.ALL)
    def _write_records(self, records: Sequence[LogRecord]) -> None:
        """Add a batch of embeddings to the index"""
        if not self._running:
            raise RuntimeError("Cannot add embeddings to stopped component")

        # Avoid all sorts of potential problems by ensuring single-threaded access
        with WriteRWLock(self._lock):
            batch = Batch()

            for record in records:
                self._max_seq_id = max(self._max_seq_id, record["log_offset"])
                id = record["record"]["id"]
                op = record["record"]["operation"]
                label = self._id_to_label.get(id, None)

                if op == Operation.DELETE:
                    if label:
                        batch.apply(record)
                    else:
                        logger.warning(f"Delete of nonexistent embedding ID: {id}")

                elif op == Operation.UPDATE:
                    if record["record"]["embedding"] is not None:
                        if label is not None:
                            batch.apply(record)
                        else:
                            logger.warning(
                                f"Update of nonexistent embedding ID: {record['record']['id']}"
                            )
                elif op == Operation.ADD:
                    if not label:
                        batch.apply(record, False)
                    else:
                        logger.warning(f"Add of existing embedding ID: {id}")
                elif op == Operation.UPSERT:
                    batch.apply(record, label is not None)

            self._apply_batch(batch)

    @override
    def delete(self) -> None:
        raise NotImplementedError()
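
The resize rule in _ensure_index amounts to the following arithmetic, shown here as a standalone sketch (next_capacity is an illustrative helper, not part of the segment API):

# Hedged sketch of the growth rule: when pending writes would exceed the index
# capacity, grow to (existing + new) * resize_factor, floored at DEFAULT_CAPACITY.
DEFAULT_CAPACITY = 1000
resize_factor = 1.2

def next_capacity(total_added: int, n_new: int, max_elements: int) -> int:
    if total_added + n_new <= max_elements:
        return max_elements  # still fits; no resize needed
    return max(int((total_added + n_new) * resize_factor), DEFAULT_CAPACITY)

print(next_capacity(total_added=990, n_new=20, max_elements=1000))  # -> 1212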
@@ -0,0 +1,543 @@
import os
import shutil
from overrides import override
import pickle
from typing import Dict, List, Optional, Sequence, Set, cast
from chromadb.config import System
from chromadb.db.base import ParameterValue, get_sql
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.segment.impl.vector.local_hnsw import (
    DEFAULT_CAPACITY,
    LocalHnswSegment,
)
from chromadb.segment.impl.vector.brute_force_index import BruteForceIndex
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    LogRecord,
    Metadata,
    Operation,
    RequestVersionContext,
    Segment,
    SeqId,
    Vector,
    VectorEmbeddingRecord,
    VectorQuery,
    VectorQueryResult,
)
import hnswlib
import logging
from pypika import Table
import numpy as np

from chromadb.utils.read_write_lock import ReadRWLock, WriteRWLock


logger = logging.getLogger(__name__)


class PersistentData:
    """Stores the data and metadata needed for a PersistentLocalHnswSegment"""

    dimensionality: Optional[int]
    total_elements_added: int

    max_seq_id: SeqId
    "This is a legacy field. It is no longer mutated, but kept to allow automatic migration of the `max_seq_id` from the pickled file to the `max_seq_id` table in SQLite."

    id_to_label: Dict[str, int]
    label_to_id: Dict[int, str]
    id_to_seq_id: Dict[str, SeqId]

    def __init__(
        self,
        dimensionality: Optional[int],
        total_elements_added: int,
        id_to_label: Dict[str, int],
        label_to_id: Dict[int, str],
        id_to_seq_id: Dict[str, SeqId],
    ):
        self.dimensionality = dimensionality
        self.total_elements_added = total_elements_added
        self.id_to_label = id_to_label
        self.label_to_id = label_to_id
        self.id_to_seq_id = id_to_seq_id

    @staticmethod
    def load_from_file(filename: str) -> "PersistentData":
        """Load persistent data from a file"""
        with open(filename, "rb") as f:
            ret = cast(PersistentData, pickle.load(f))
            return ret


class PersistentLocalHnswSegment(LocalHnswSegment):
    METADATA_FILE: str = "index_metadata.pickle"
    # How many records to add to the index at once. We batch because crossing the
    # python/c++ boundary is expensive (for add()). Records not yet added to the
    # c++ index are buffered in memory and served via brute force search.
    _batch_size: int
    _brute_force_index: Optional[BruteForceIndex]
    _index_initialized: bool = False
    _curr_batch: Batch
    # How many records to add to the index before syncing to disk
    _sync_threshold: int
    _persist_data: PersistentData
    _persist_directory: str
    _allow_reset: bool

    _db: SqliteDB
    _opentelemetry_client: OpenTelemetryClient

    _num_log_records_since_last_batch: int = 0
    _num_log_records_since_last_persist: int = 0

    def __init__(self, system: System, segment: Segment):
        super().__init__(system, segment)

        self._db = system.instance(SqliteDB)
        self._opentelemetry_client = system.require(OpenTelemetryClient)

        self._params = PersistentHnswParams(segment["metadata"] or {})
        self._batch_size = self._params.batch_size
        self._sync_threshold = self._params.sync_threshold
        self._allow_reset = system.settings.allow_reset
        self._persist_directory = system.settings.require("persist_directory")
        self._curr_batch = Batch()
        self._brute_force_index = None
        if not os.path.exists(self._get_storage_folder()):
            os.makedirs(self._get_storage_folder(), exist_ok=True)
        # Load persist data if it exists already, otherwise create it
        if self._index_exists():
            self._persist_data = PersistentData.load_from_file(
                self._get_metadata_file()
            )
            self._dimensionality = self._persist_data.dimensionality
            self._total_elements_added = self._persist_data.total_elements_added
            self._id_to_label = self._persist_data.id_to_label
            self._label_to_id = self._persist_data.label_to_id
            self._id_to_seq_id = self._persist_data.id_to_seq_id
            # If the index was written to, we need to re-initialize it
            if len(self._id_to_label) > 0:
                self._dimensionality = cast(int, self._dimensionality)
                self._init_index(self._dimensionality)
        else:
            self._persist_data = PersistentData(
                self._dimensionality,
                self._total_elements_added,
                self._id_to_label,
                self._label_to_id,
                self._id_to_seq_id,
            )

        # Hydrate the max_seq_id
        with self._db.tx() as cur:
            t = Table("max_seq_id")
            q = (
                self._db.querybuilder()
                .from_(t)
                .select(t.seq_id)
                .where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
                .limit(1)
            )
            sql, params = get_sql(q)
            cur.execute(sql, params)
            result = cur.fetchone()

            if result:
                self._max_seq_id = result[0]
            elif self._index_exists():
                # Migrate the max_seq_id from the legacy field in the pickled file
                # to the SQLite database
                q = (
                    self._db.querybuilder()
                    .into(Table("max_seq_id"))
                    .columns("segment_id", "seq_id")
                    .insert(
                        ParameterValue(self._db.uuid_to_db(self._id)),
                        ParameterValue(self._persist_data.max_seq_id),
                    )
                )
                sql, params = get_sql(q)
                cur.execute(sql, params)

                self._max_seq_id = self._persist_data.max_seq_id
            else:
                self._max_seq_id = self._consumer.min_seqid()

    @staticmethod
    @override
    def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
        # Extract relevant metadata
        segment_metadata = PersistentHnswParams.extract(metadata)
        return segment_metadata

    def _index_exists(self) -> bool:
        """Check if the index exists via the metadata file"""
        return os.path.exists(self._get_metadata_file())

    def _get_metadata_file(self) -> str:
        """Get the metadata file path"""
        return os.path.join(self._get_storage_folder(), self.METADATA_FILE)

    def _get_storage_folder(self) -> str:
        """Get the storage folder path"""
        folder = os.path.join(self._persist_directory, str(self._id))
        return folder

    @trace_method(
        "PersistentLocalHnswSegment._init_index", OpenTelemetryGranularity.ALL
    )
    @override
    def _init_index(self, dimensionality: int) -> None:
        index = hnswlib.Index(space=self._params.space, dim=dimensionality)
        self._brute_force_index = BruteForceIndex(
            size=self._batch_size,
            dimensionality=dimensionality,
            space=self._params.space,
        )

        # Check if index exists and load it if it does
        if self._index_exists():
            index.load_index(
                self._get_storage_folder(),
                is_persistent_index=True,
                max_elements=int(
                    max(
                        self.count(
                            request_version_context=RequestVersionContext(
                                collection_version=0, log_position=0
                            )
                        )
                        * self._params.resize_factor,
                        DEFAULT_CAPACITY,
                    )
                ),
            )
        else:
            index.init_index(
                max_elements=DEFAULT_CAPACITY,
                ef_construction=self._params.construction_ef,
                M=self._params.M,
                is_persistent_index=True,
                persistence_location=self._get_storage_folder(),
            )

        index.set_ef(self._params.search_ef)
        index.set_num_threads(self._params.num_threads)

        self._index = index
        self._dimensionality = dimensionality
        self._index_initialized = True

    @trace_method("PersistentLocalHnswSegment._persist", OpenTelemetryGranularity.ALL)
    def _persist(self) -> None:
        """Persist the index and data to disk"""
        index = cast(hnswlib.Index, self._index)

        # Persist the index
        index.persist_dirty()

        # Persist the metadata
        self._persist_data.dimensionality = self._dimensionality
        self._persist_data.total_elements_added = self._total_elements_added

        # TODO: This should really be stored in sqlite, the index itself, or a better
        # storage format
        self._persist_data.id_to_label = self._id_to_label
        self._persist_data.label_to_id = self._label_to_id
        self._persist_data.id_to_seq_id = self._id_to_seq_id

        with open(self._get_metadata_file(), "wb") as metadata_file:
            pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)

        with self._db.tx() as cur:
            q = (
                self._db.querybuilder()
                .into(Table("max_seq_id"))
                .columns("segment_id", "seq_id")
                .insert(
                    ParameterValue(self._db.uuid_to_db(self._id)),
                    ParameterValue(self._max_seq_id),
                )
            )
            sql, params = get_sql(q)
            sql = sql.replace("INSERT", "INSERT OR REPLACE")
            cur.execute(sql, params)

        self._num_log_records_since_last_persist = 0

    @trace_method(
        "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
    )
    @override
    def _apply_batch(self, batch: Batch) -> None:
        super()._apply_batch(batch)
        if self._num_log_records_since_last_persist >= self._sync_threshold:
            self._persist()

        self._num_log_records_since_last_batch = 0

    @trace_method(
        "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL
    )
    @override
    def _write_records(self, records: Sequence[LogRecord]) -> None:
        """Add a batch of embeddings to the index"""
        if not self._running:
            raise RuntimeError("Cannot add embeddings to stopped component")
        with WriteRWLock(self._lock):
            for record in records:
                self._num_log_records_since_last_batch += 1
                self._num_log_records_since_last_persist += 1

                if record["record"]["embedding"] is not None:
                    self._ensure_index(len(records), len(record["record"]["embedding"]))
                if not self._index_initialized:
                    # If the index is not initialized here, it means that we have
                    # not yet added any records to the index. So we can just
                    # ignore the record since it was a delete.
                    continue
                self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)

                self._max_seq_id = max(self._max_seq_id, record["log_offset"])
                id = record["record"]["id"]
                op = record["record"]["operation"]

                exists_in_bf_index = self._brute_force_index.has_id(id)
                exists_in_persisted_index = self._id_to_label.get(id, None) is not None
                exists_in_index = exists_in_bf_index or exists_in_persisted_index

                id_is_pending_delete = self._curr_batch.is_deleted(id)

                if op == Operation.DELETE:
                    if exists_in_index:
                        self._curr_batch.apply(record)
                        if exists_in_bf_index:
                            self._brute_force_index.delete([record])
                    else:
                        logger.warning(f"Delete of nonexistent embedding ID: {id}")

                elif op == Operation.UPDATE:
                    if record["record"]["embedding"] is not None:
                        if exists_in_index:
                            self._curr_batch.apply(record)
                            self._brute_force_index.upsert([record])
                        else:
                            logger.warning(
                                f"Update of nonexistent embedding ID: {record['record']['id']}"
                            )
                elif op == Operation.ADD:
                    if record["record"]["embedding"] is not None:
                        if exists_in_index and not id_is_pending_delete:
                            logger.warning(f"Add of existing embedding ID: {id}")
                        else:
                            self._curr_batch.apply(record, not exists_in_index)
                            self._brute_force_index.upsert([record])
                elif op == Operation.UPSERT:
                    if record["record"]["embedding"] is not None:
                        self._curr_batch.apply(record, exists_in_index)
                        self._brute_force_index.upsert([record])

                if self._num_log_records_since_last_batch >= self._batch_size:
                    self._apply_batch(self._curr_batch)
                    self._curr_batch = Batch()
                    self._brute_force_index.clear()

    @override
    def count(self, request_version_context: RequestVersionContext) -> int:
        return (
            len(self._id_to_label)
            + self._curr_batch.add_count
            - self._curr_batch.delete_count
        )

    @trace_method(
        "PersistentLocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL
    )
    @override
    def get_vectors(
        self,
        request_version_context: RequestVersionContext,
        ids: Optional[Sequence[str]] = None,
    ) -> Sequence[VectorEmbeddingRecord]:
        """Get the embeddings from the HNSW index and layered brute force
        batch index."""

        ids_hnsw: Set[str] = set()
        ids_bf: Set[str] = set()

        if self._index is not None:
            ids_hnsw = set(self._id_to_label.keys())
        if self._brute_force_index is not None:
            ids_bf = set(self._curr_batch.get_written_ids())

        target_ids = ids or list(ids_hnsw.union(ids_bf))
        self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)
        hnsw_labels = []

        results: List[Optional[VectorEmbeddingRecord]] = []
        id_to_index: Dict[str, int] = {}
        for i, id in enumerate(target_ids):
            if id in ids_bf:
                results.append(self._brute_force_index.get_vectors([id])[0])
            elif id in ids_hnsw and id not in self._curr_batch._deleted_ids:
                hnsw_labels.append(self._id_to_label[id])
                # Placeholder for hnsw results to be filled in down below so we
                # can batch the hnsw get() call
                results.append(None)
                id_to_index[id] = i

        if len(hnsw_labels) > 0 and self._index is not None:
            vectors = cast(
                Sequence[Vector], np.array(self._index.get_items(hnsw_labels))
            )  # version 0.8 of hnswlib allows return_type="numpy"

            for label, vector in zip(hnsw_labels, vectors):
                id = self._label_to_id[label]
                results[id_to_index[id]] = VectorEmbeddingRecord(
                    id=id, embedding=vector
                )

        return results  # type: ignore ## Python can't cast List with Optional to List with VectorEmbeddingRecord

    @trace_method(
        "PersistentLocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL
    )
    @override
    def query_vectors(
        self, query: VectorQuery
    ) -> Sequence[Sequence[VectorQueryResult]]:
        if self._index is None and self._brute_force_index is None:
            return [[] for _ in range(len(query["vectors"]))]

        k = query["k"]
        if k > self.count(query["request_version_context"]):
            count = self.count(query["request_version_context"])
            logger.warning(
                f"Number of requested results {k} is greater than number of elements in index {count}, updating n_results = {count}"
            )
            k = count

        # Overquery by updated and deleted elements layered on the index because they
        # may hide the real nearest neighbors in the hnsw index
        hnsw_k = k + self._curr_batch.update_count + self._curr_batch.delete_count
        # self._id_to_label contains the ids of the elements in the hnsw index,
        # so its length is the number of elements in the hnsw index
        if hnsw_k > len(self._id_to_label):
            hnsw_k = len(self._id_to_label)
        hnsw_query = VectorQuery(
            vectors=query["vectors"],
            k=hnsw_k,
            allowed_ids=query["allowed_ids"],
            include_embeddings=query["include_embeddings"],
            options=query["options"],
            request_version_context=query["request_version_context"],
        )

        # For each query vector, we want to take the top k results from the
        # combined results of the brute force and hnsw index
        results: List[List[VectorQueryResult]] = []
        self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)
        with ReadRWLock(self._lock):
            bf_results = self._brute_force_index.query(query)
            hnsw_results = super().query_vectors(hnsw_query)
            for i in range(len(query["vectors"])):
                # Merge results into a single list of size k
                bf_pointer: int = 0
                hnsw_pointer: int = 0
                curr_bf_result: Sequence[VectorQueryResult] = bf_results[i]
                curr_hnsw_result: Sequence[VectorQueryResult] = hnsw_results[i]

                # Filter deleted results that haven't yet been removed from the persisted index
                curr_hnsw_result = [
                    x
                    for x in curr_hnsw_result
                    if not self._curr_batch.is_deleted(x["id"])
                ]

                curr_results: List[VectorQueryResult] = []
                # In the case where filters cause the number of results to be less
                # than k, we set k to be the number of results
                total_results = len(curr_bf_result) + len(curr_hnsw_result)
                if total_results == 0:
                    results.append([])
                else:
                    while len(curr_results) < min(k, total_results):
                        if bf_pointer < len(curr_bf_result) and hnsw_pointer < len(
                            curr_hnsw_result
                        ):
                            bf_dist = curr_bf_result[bf_pointer]["distance"]
                            hnsw_dist = curr_hnsw_result[hnsw_pointer]["distance"]
                            if bf_dist <= hnsw_dist:
                                curr_results.append(curr_bf_result[bf_pointer])
                                bf_pointer += 1
                            else:
                                id = curr_hnsw_result[hnsw_pointer]["id"]
                                # Only add the hnsw result if it is not in the brute force index
                                if not self._brute_force_index.has_id(id):
                                    curr_results.append(curr_hnsw_result[hnsw_pointer])
                                hnsw_pointer += 1
                        else:
                            break
                    remaining = min(k, total_results) - len(curr_results)
                    if remaining > 0 and hnsw_pointer < len(curr_hnsw_result):
                        for i in range(
                            hnsw_pointer,
                            min(len(curr_hnsw_result), hnsw_pointer + remaining),
                        ):
                            id = curr_hnsw_result[i]["id"]
                            if not self._brute_force_index.has_id(id):
                                curr_results.append(curr_hnsw_result[i])
                    elif remaining > 0 and bf_pointer < len(curr_bf_result):
                        curr_results.extend(
                            curr_bf_result[bf_pointer : bf_pointer + remaining]
                        )
                    results.append(curr_results)
        return results

    @trace_method(
        "PersistentLocalHnswSegment.reset_state", OpenTelemetryGranularity.ALL
    )
    @override
    def reset_state(self) -> None:
        if self._allow_reset:
            data_path = self._get_storage_folder()
            if os.path.exists(data_path):
                self.close_persistent_index()
                shutil.rmtree(data_path, ignore_errors=True)

    @trace_method("PersistentLocalHnswSegment.delete", OpenTelemetryGranularity.ALL)
    @override
    def delete(self) -> None:
        data_path = self._get_storage_folder()
        if os.path.exists(data_path):
            self.close_persistent_index()
            shutil.rmtree(data_path, ignore_errors=False)

    @staticmethod
    def get_file_handle_count() -> int:
        """Return how many file handles are used by the index"""
        hnswlib_count = hnswlib.Index.file_handle_count
        hnswlib_count = cast(int, hnswlib_count)
        # One extra for the metadata file
        return hnswlib_count + 1  # type: ignore

    def open_persistent_index(self) -> None:
        """Open the persistent index"""
        if self._index is not None:
            self._index.open_file_handles()

    @override
    def stop(self) -> None:
        super().stop()
        self.close_persistent_index()

    def close_persistent_index(self) -> None:
        """Close the persistent index"""
        if self._index is not None:
            self._index.close_file_handles()
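
The two-pointer merge in query_vectors can be read in isolation. A standalone sketch of the same idea over (id, distance) pairs, preferring the brute force entry when an ID appears in both lists (merge_top_k is an illustrative helper, not Chroma API):

# Hedged sketch: merge two distance-sorted result lists, skipping hnsw entries
# whose IDs are shadowed by fresher brute-force-buffered entries.
from typing import List, Tuple

def merge_top_k(bf: List[Tuple[str, float]], hnsw: List[Tuple[str, float]], k: int) -> List[Tuple[str, float]]:
    bf_ids = {id for id, _ in bf}
    merged: List[Tuple[str, float]] = []
    i = j = 0
    while len(merged) < k and (i < len(bf) or j < len(hnsw)):
        take_bf = j >= len(hnsw) or (i < len(bf) and bf[i][1] <= hnsw[j][1])
        if take_bf:
            merged.append(bf[i])
            i += 1
        else:
            if hnsw[j][0] not in bf_ids:  # skip stale duplicates of buffered IDs
                merged.append(hnsw[j])
            j += 1
    return merged

print(merge_top_k([("a", 0.1)], [("a", 0.4), ("b", 0.5)], k=2))  # [('a', 0.1), ('b', 0.5)]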