chore: add virtual environment to the repository

- Add the backend_service/venv virtual environment
- Contains all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12,655 files
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions


@@ -0,0 +1,125 @@
from typing import Optional, Sequence, TypeVar
from abc import abstractmethod
from chromadb.types import (
Collection,
MetadataEmbeddingRecord,
Operation,
RequestVersionContext,
VectorEmbeddingRecord,
Where,
WhereDocument,
VectorQuery,
VectorQueryResult,
Segment,
SeqId,
Metadata,
)
from chromadb.config import Component, System
from uuid import UUID
from enum import Enum
class SegmentType(Enum):
SQLITE = "urn:chroma:segment/metadata/sqlite"
HNSW_LOCAL_MEMORY = "urn:chroma:segment/vector/hnsw-local-memory"
HNSW_LOCAL_PERSISTED = "urn:chroma:segment/vector/hnsw-local-persisted"
HNSW_DISTRIBUTED = "urn:chroma:segment/vector/hnsw-distributed"
BLOCKFILE_RECORD = "urn:chroma:segment/record/blockfile"
BLOCKFILE_METADATA = "urn:chroma:segment/metadata/blockfile"
class SegmentImplementation(Component):
@abstractmethod
    def __init__(self, system: System, segment: Segment):
pass
@abstractmethod
def count(self, request_version_context: RequestVersionContext) -> int:
"""Get the number of embeddings in this segment"""
pass
@abstractmethod
def max_seqid(self) -> SeqId:
"""Get the maximum SeqID currently indexed by this segment"""
pass
@staticmethod
def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
"""Given an arbitrary metadata map (e.g, from a collection), validate it and
return metadata (if any) that is applicable and should be applied to the
segment. Validation errors will be reported to the user."""
return None
@abstractmethod
def delete(self) -> None:
"""Delete the segment and all its data"""
...
S = TypeVar("S", bound=SegmentImplementation)
class MetadataReader(SegmentImplementation):
"""Embedding Metadata segment interface"""
@abstractmethod
def get_metadata(
self,
request_version_context: RequestVersionContext,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
include_metadata: bool = True,
) -> Sequence[MetadataEmbeddingRecord]:
"""Query for embedding metadata."""
pass
class VectorReader(SegmentImplementation):
"""Embedding Vector segment interface"""
@abstractmethod
def get_vectors(
self,
request_version_context: RequestVersionContext,
ids: Optional[Sequence[str]] = None,
) -> Sequence[VectorEmbeddingRecord]:
"""Get embeddings from the segment. If no IDs are provided, all embeddings are
returned."""
pass
@abstractmethod
def query_vectors(
self, query: VectorQuery
) -> Sequence[Sequence[VectorQueryResult]]:
"""Given a vector query, return the top-k nearest neighbors for vector in the
query."""
pass
class SegmentManager(Component):
"""Interface for a pluggable strategy for creating, retrieving and instantiating
segments as required"""
@abstractmethod
def prepare_segments_for_new_collection(
self, collection: Collection
) -> Sequence[Segment]:
"""Return the segments required for a new collection. Returns only segment data,
does not persist to the SysDB"""
pass
@abstractmethod
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
"""Delete any local state for all the segments associated with a collection, and
returns a sequence of their IDs. Does not update the SysDB."""
pass
@abstractmethod
def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
"""Signal to the segment manager that a collection is about to be used, so that
it can preload segments as needed. This is only a hint, and implementations are
free to ignore it."""
pass
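
Aside: a minimal sketch of a concrete class against the SegmentImplementation
interface above. InMemorySegment is hypothetical (not a chromadb class) and
assumes Component.__init__ takes the System instance:

class InMemorySegment(SegmentImplementation):
    """Hypothetical, minimal segment satisfying the abstract interface."""

    def __init__(self, system: System, segment: Segment):
        super().__init__(system)  # assumed Component signature
        self._segment = segment
        self._max_seq_id: SeqId = 0

    def count(self, request_version_context: RequestVersionContext) -> int:
        return 0  # nothing indexed in this toy segment

    def max_seqid(self) -> SeqId:
        return self._max_seq_id

    def delete(self) -> None:
        pass  # nothing persisted, so nothing to delete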


@@ -0,0 +1,80 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List
from overrides import EnforceOverrides, overrides
from chromadb.config import Component, System
from chromadb.types import Segment
class SegmentDirectory(Component):
"""A segment directory is a data interface that manages the location of segments. Concretely, this
means that for distributed chroma, it provides the grpc endpoint for a segment."""
@abstractmethod
def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
"""Return the segment residences for a given segment ID. Will return at most n residences.
Should only return less than n residences if there are less than n residences available.
"""
@abstractmethod
def register_updated_segment_callback(
self, callback: Callable[[Segment], None]
) -> None:
"""Register a callback that will be called when a segment is updated"""
pass
@dataclass
class Member:
id: str
ip: str
node: str
Memberlist = List[Member]
class MemberlistProvider(Component, EnforceOverrides):
"""Returns the latest memberlist and provdes a callback for when it changes. This
callback may be called from a different thread than the one that called. Callers should ensure
that they are thread-safe."""
callbacks: List[Callable[[Memberlist], Any]]
def __init__(self, system: System):
self.callbacks = []
super().__init__(system)
@abstractmethod
def get_memberlist(self) -> Memberlist:
"""Returns the latest memberlist"""
pass
@abstractmethod
def set_memberlist_name(self, memberlist: str) -> None:
"""Sets the memberlist that this provider will watch"""
pass
@overrides
def stop(self) -> None:
"""Stops watching the memberlist"""
self.callbacks = []
def register_updated_memberlist_callback(
self, callback: Callable[[Memberlist], Any]
) -> None:
"""Registers a callback that will be called when the memberlist changes. May be called many times
with the same memberlist, so callers should be idempotent. May be called from a different thread.
"""
self.callbacks.append(callback)
def unregister_updated_memberlist_callback(
self, callback: Callable[[Memberlist], Any]
) -> bool:
"""Unregisters a callback that was previously registered. Returns True if the callback was
successfully unregistered, False if it was not ever registered."""
if callback in self.callbacks:
self.callbacks.remove(callback)
return True
return False
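
To make the provider contract concrete, here is a hedged sketch of the simplest
possible MemberlistProvider: a fixed memberlist with nothing to watch.
StaticMemberlistProvider is hypothetical and reuses the callback plumbing from
the base class above:

class StaticMemberlistProvider(MemberlistProvider):
    """Hypothetical provider that serves a fixed memberlist."""

    def __init__(self, system: System, memberlist: Memberlist):
        super().__init__(system)
        self._memberlist = memberlist

    @overrides
    def get_memberlist(self) -> Memberlist:
        return self._memberlist

    @overrides
    def set_memberlist_name(self, memberlist: str) -> None:
        pass  # the memberlist is fixed; there is nothing to watch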


@@ -0,0 +1,337 @@
import threading
import time
from typing import Any, Callable, Dict, List, Optional, cast
from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException
from overrides import EnforceOverrides, override
from chromadb.config import RoutingMode, System
from chromadb.segment.distributed import (
Member,
Memberlist,
MemberlistProvider,
SegmentDirectory,
)
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
add_attributes_to_current_span,
trace_method,
)
from chromadb.types import Segment
from chromadb.utils.rendezvous_hash import assign, murmur3hasher
# These could go in config but given that they will rarely change, they are here for now to avoid
# polluting the config file further.
WATCH_TIMEOUT_SECONDS = 60
KUBERNETES_NAMESPACE = "chroma"
KUBERNETES_GROUP = "chroma.cluster"
HEADLESS_SERVICE = "svc.cluster.local"
class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):
"""A mock memberlist provider for testing"""
_memberlist: Memberlist
def __init__(self, system: System):
super().__init__(system)
self._memberlist = [
Member(id="a", ip="10.0.0.1", node="node1"),
Member(id="b", ip="10.0.0.2", node="node2"),
Member(id="c", ip="10.0.0.3", node="node3"),
]
@override
def get_memberlist(self) -> Memberlist:
return self._memberlist
@override
def set_memberlist_name(self, memberlist: str) -> None:
pass # The mock provider does not need to set the memberlist name
def update_memberlist(self, memberlist: Memberlist) -> None:
"""Updates the memberlist and calls all registered callbacks. This mocks an update from a k8s CR"""
self._memberlist = memberlist
for callback in self.callbacks:
callback(memberlist)
class CustomResourceMemberlistProvider(MemberlistProvider, EnforceOverrides):
"""A memberlist provider that uses a k8s custom resource to store the memberlist"""
_kubernetes_api: client.CustomObjectsApi
_memberlist_name: Optional[str]
_curr_memberlist: Optional[Memberlist]
_curr_memberlist_mutex: threading.Lock
_watch_thread: Optional[threading.Thread]
_kill_watch_thread: threading.Event
_done_waiting_for_reset: threading.Event
def __init__(self, system: System):
super().__init__(system)
config.load_config()
self._kubernetes_api = client.CustomObjectsApi()
self._watch_thread = None
self._memberlist_name = None
self._curr_memberlist = None
self._curr_memberlist_mutex = threading.Lock()
self._kill_watch_thread = threading.Event()
self._done_waiting_for_reset = threading.Event()
@override
def start(self) -> None:
if self._memberlist_name is None:
raise ValueError("Memberlist name must be set before starting")
self.get_memberlist()
self._done_waiting_for_reset.clear()
self._watch_worker_memberlist()
return super().start()
@override
def stop(self) -> None:
self._curr_memberlist = None
self._memberlist_name = None
# Stop the watch thread
self._kill_watch_thread.set()
if self._watch_thread is not None:
self._watch_thread.join()
self._watch_thread = None
self._kill_watch_thread.clear()
self._done_waiting_for_reset.clear()
return super().stop()
@override
def reset_state(self) -> None:
# Reset the memberlist in kubernetes, and wait for it to
# get propagated back again
# Note that the component must be running in order to reset the state
if not self._system.settings.require("allow_reset"):
raise ValueError(
"Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
)
if self._memberlist_name:
self._done_waiting_for_reset.clear()
self._kubernetes_api.patch_namespaced_custom_object(
group=KUBERNETES_GROUP,
version="v1",
namespace=KUBERNETES_NAMESPACE,
plural="memberlists",
name=self._memberlist_name,
body={
"kind": "MemberList",
"spec": {"members": []},
},
)
self._done_waiting_for_reset.wait(5.0)
# TODO: For some reason the above can flake and the memberlist won't be populated
# Given that this is a test harness, just sleep for an additional 500ms for now
# We should understand why this flaps
time.sleep(0.5)
@override
def get_memberlist(self) -> Memberlist:
if self._curr_memberlist is None:
self._curr_memberlist = self._fetch_memberlist()
return self._curr_memberlist
@override
def set_memberlist_name(self, memberlist: str) -> None:
self._memberlist_name = memberlist
def _fetch_memberlist(self) -> Memberlist:
api_response = self._kubernetes_api.get_namespaced_custom_object(
group=KUBERNETES_GROUP,
version="v1",
namespace=KUBERNETES_NAMESPACE,
plural="memberlists",
name=f"{self._memberlist_name}",
)
api_response = cast(Dict[str, Any], api_response)
if "spec" not in api_response:
return []
response_spec = cast(Dict[str, Any], api_response["spec"])
return self._parse_response_memberlist(response_spec)
def _watch_worker_memberlist(self) -> None:
# TODO: We may want to make this watch function a library function that can be used by other
# components that need to watch k8s custom resources.
def run_watch() -> None:
w = watch.Watch()
def do_watch() -> None:
for event in w.stream(
self._kubernetes_api.list_namespaced_custom_object,
group=KUBERNETES_GROUP,
version="v1",
namespace=KUBERNETES_NAMESPACE,
plural="memberlists",
field_selector=f"metadata.name={self._memberlist_name}",
timeout_seconds=WATCH_TIMEOUT_SECONDS,
):
event = cast(Dict[str, Any], event)
response_spec = event["object"]["spec"]
response_spec = cast(Dict[str, Any], response_spec)
with self._curr_memberlist_mutex:
self._curr_memberlist = self._parse_response_memberlist(
response_spec
)
self._notify(self._curr_memberlist)
if (
self._system.settings.require("allow_reset")
and not self._done_waiting_for_reset.is_set()
and len(self._curr_memberlist) > 0
):
self._done_waiting_for_reset.set()
# Watch the custom resource for changes
# Watch with a timeout and retry so we can gracefully stop this if needed
while not self._kill_watch_thread.is_set():
try:
do_watch()
except ApiException as e:
# If status code is 410, the watch has expired and we need to start a new one.
if e.status == 410:
pass
return
if self._watch_thread is None:
thread = threading.Thread(target=run_watch, daemon=True)
thread.start()
self._watch_thread = thread
else:
raise Exception("A watch thread is already running.")
def _parse_response_memberlist(
self, api_response_spec: Dict[str, Any]
) -> Memberlist:
if "members" not in api_response_spec:
return []
parsed = []
for m in api_response_spec["members"]:
id = m["member_id"]
ip = m["member_ip"] if "member_ip" in m else ""
node = m["member_node_name"] if "member_node_name" in m else ""
parsed.append(Member(id=id, ip=ip, node=node))
return parsed
def _notify(self, memberlist: Memberlist) -> None:
for callback in self.callbacks:
callback(memberlist)
class RendezvousHashSegmentDirectory(SegmentDirectory, EnforceOverrides):
_memberlist_provider: MemberlistProvider
_curr_memberlist_mutex: threading.Lock
_curr_memberlist: Optional[Memberlist]
_routing_mode: RoutingMode
def __init__(self, system: System):
super().__init__(system)
self._memberlist_provider = self.require(MemberlistProvider)
memberlist_name = system.settings.require("worker_memberlist_name")
self._memberlist_provider.set_memberlist_name(memberlist_name)
self._routing_mode = system.settings.require(
"chroma_segment_directory_routing_mode"
)
self._curr_memberlist = None
self._curr_memberlist_mutex = threading.Lock()
@override
def start(self) -> None:
self._curr_memberlist = self._memberlist_provider.get_memberlist()
self._memberlist_provider.register_updated_memberlist_callback(
self._update_memberlist
)
return super().start()
@override
def stop(self) -> None:
self._memberlist_provider.unregister_updated_memberlist_callback(
self._update_memberlist
)
return super().stop()
@override
def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
if self._curr_memberlist is None or len(self._curr_memberlist) == 0:
raise ValueError("Memberlist is not initialized")
# assign() will throw an error if n is greater than the number of members
# clamp n to the number of members to align with the contract of this method
# which is to return at most n endpoints
n = min(n, len(self._curr_memberlist))
# Check if all members in the memberlist have a node set,
# if so, route using the node
# NOTE(@hammadb) 1/8/2024: This is to handle the migration between routing
# using the member id and routing using the node name
# We want to route using the node name over the member id
# because the node may have a disk cache that we want a
# stable identifier for over deploys.
        can_use_node_routing = (
            all(m.node != "" for m in self._curr_memberlist)
            and self._routing_mode == RoutingMode.NODE
        )
if can_use_node_routing:
            # If we can use node routing, assign members by node name
assignments = assign(
segment["collection"].hex,
[m.node for m in self._curr_memberlist],
murmur3hasher,
n,
)
else:
# Query to the same collection should end up on the same endpoint
assignments = assign(
segment["collection"].hex,
[m.id for m in self._curr_memberlist],
murmur3hasher,
n,
)
assignments_set = set(assignments)
out_endpoints = []
for member in self._curr_memberlist:
is_chosen_with_node_routing = (
can_use_node_routing and member.node in assignments_set
)
is_chosen_with_id_routing = (
not can_use_node_routing and member.id in assignments_set
)
if is_chosen_with_node_routing or is_chosen_with_id_routing:
                # If the member has an IP, use it; otherwise fall back to the member id
                # with the headless service. This preserves backwards compatibility with
                # the old memberlist, which only had ids.
if member.ip is not None and member.ip != "":
endpoint = f"{member.ip}:50051"
out_endpoints.append(endpoint)
else:
service_name = self.extract_service_name(member.id)
endpoint = f"{member.id}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"
out_endpoints.append(endpoint)
return out_endpoints
@override
def register_updated_segment_callback(
self, callback: Callable[[Segment], None]
) -> None:
raise NotImplementedError()
@trace_method(
"RendezvousHashSegmentDirectory._update_memberlist",
OpenTelemetryGranularity.ALL,
)
def _update_memberlist(self, memberlist: Memberlist) -> None:
with self._curr_memberlist_mutex:
add_attributes_to_current_span(
{"new_memberlist": [m.id for m in memberlist]}
)
self._curr_memberlist = memberlist
def extract_service_name(self, pod_name: str) -> Optional[str]:
# Split the pod name by the hyphen
parts = pod_name.split("-")
# The service name is expected to be the prefix before the last hyphen
if len(parts) > 1:
return "-".join(parts[:-1])
return None
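
Aside: the routing above relies on rendezvous (highest-random-weight) hashing. A
toy sketch of the idea, substituting Python's blake2b for chromadb's
murmur3hasher: every (key, member) pair gets a score, the top-n members win, and
a key's assignment only moves when one of its winners leaves the memberlist.

import hashlib
from typing import List

def toy_hash(s: str) -> int:
    # stand-in for murmur3hasher; any well-mixed hash works for the illustration
    return int.from_bytes(hashlib.blake2b(s.encode(), digest_size=8).digest(), "big")

def toy_assign(key: str, members: List[str], n: int) -> List[str]:
    # score every (key, member) pair; the n highest scores win
    ranked = sorted(members, key=lambda m: toy_hash(f"{key}:{m}"), reverse=True)
    return ranked[:n]

print(toy_assign("collection-uuid-hex", ["node1", "node2", "node3"], n=2))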


@@ -0,0 +1,119 @@
import threading
import uuid
from typing import Any, Callable
from chromadb.types import Segment
from overrides import override
from typing import Dict, Optional
from abc import ABC, abstractmethod
class SegmentCache(ABC):
@abstractmethod
def get(self, key: uuid.UUID) -> Optional[Segment]:
pass
@abstractmethod
def pop(self, key: uuid.UUID) -> Optional[Segment]:
pass
@abstractmethod
def set(self, key: uuid.UUID, value: Segment) -> None:
pass
@abstractmethod
def reset(self) -> None:
pass
class BasicCache(SegmentCache):
def __init__(self):
self.cache: Dict[uuid.UUID, Segment] = {}
self.lock = threading.RLock()
@override
def get(self, key: uuid.UUID) -> Optional[Segment]:
with self.lock:
return self.cache.get(key)
@override
def pop(self, key: uuid.UUID) -> Optional[Segment]:
with self.lock:
return self.cache.pop(key, None)
@override
def set(self, key: uuid.UUID, value: Segment) -> None:
with self.lock:
self.cache[key] = value
@override
def reset(self) -> None:
with self.lock:
self.cache = {}
class SegmentLRUCache(BasicCache):
"""A simple LRU cache implementation that handles objects with dynamic sizes.
The size of each object is determined by a user-provided size function."""
def __init__(
self,
capacity: int,
size_func: Callable[[uuid.UUID], int],
callback: Optional[Callable[[uuid.UUID, Segment], Any]] = None,
):
self.capacity = capacity
self.size_func = size_func
self.cache: Dict[uuid.UUID, Segment] = {}
self.history = []
self.callback = callback
self.lock = threading.RLock()
def _upsert_key(self, key: uuid.UUID):
        if key in self.history:
            self.history.remove(key)
        self.history.append(key)
@override
def get(self, key: uuid.UUID) -> Optional[Segment]:
with self.lock:
self._upsert_key(key)
if key in self.cache:
return self.cache[key]
else:
return None
@override
def pop(self, key: uuid.UUID) -> Optional[Segment]:
with self.lock:
if key in self.history:
self.history.remove(key)
return self.cache.pop(key, None)
@override
def set(self, key: uuid.UUID, value: Segment) -> None:
with self.lock:
if key in self.cache:
return
item_size = self.size_func(key)
            key_sizes = {k: self.size_func(k) for k in self.cache}
total_size = sum(key_sizes.values())
index = 0
# Evict items if capacity is exceeded
while total_size + item_size > self.capacity and len(self.history) > index:
key_delete = self.history[index]
                if key_delete in self.cache:
                    # self.callback is Optional; only invoke it when one was provided
                    if self.callback is not None:
                        self.callback(key_delete, self.cache[key_delete])
                    del self.cache[key_delete]
                    total_size -= key_sizes[key_delete]
index += 1
self.cache[key] = value
self._upsert_key(key)
@override
    def reset(self) -> None:
with self.lock:
self.cache = {}
self.history = []
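
Hypothetical usage of SegmentLRUCache with a constant size function: a capacity
of two "units" where every entry costs one unit, so setting a third key evicts
the least recently used one and fires the eviction callback. The dict values are
stand-ins for real Segment objects.

import uuid

evicted = []
cache = SegmentLRUCache(
    capacity=2,                                    # two units of budget
    size_func=lambda key: 1,                       # every entry costs one unit
    callback=lambda key, seg: evicted.append(key),
)
k1, k2, k3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()
cache.set(k1, {"id": k1})
cache.set(k2, {"id": k2})
cache.set(k3, {"id": k3})  # exceeds capacity: k1 (least recently used) is evicted
# evicted == [k1]; cache.get(k2) and cache.get(k3) still hit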


@@ -0,0 +1,97 @@
from threading import Lock
from typing import Dict, List, Sequence
from uuid import UUID, uuid4
from overrides import override
from chromadb.config import System
from chromadb.db.system import SysDB
from chromadb.segment import (
SegmentImplementation,
SegmentManager,
SegmentType,
)
from chromadb.segment.distributed import SegmentDirectory
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
Collection,
Operation,
Segment,
SegmentScope,
)
class DistributedSegmentManager(SegmentManager):
_sysdb: SysDB
_system: System
_instances: Dict[UUID, SegmentImplementation]
_segment_directory: SegmentDirectory
_lock: Lock
def __init__(self, system: System):
super().__init__(system)
self._sysdb = self.require(SysDB)
self._segment_directory = self.require(SegmentDirectory)
self._system = system
self._instances = {}
self._lock = Lock()
@trace_method(
"DistributedSegmentManager.prepare_segments_for_new_collection",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
@override
def prepare_segments_for_new_collection(
self, collection: Collection
) -> Sequence[Segment]:
vector_segment = Segment(
id=uuid4(),
type=SegmentType.HNSW_DISTRIBUTED.value,
scope=SegmentScope.VECTOR,
collection=collection.id,
metadata=PersistentHnswParams.extract(collection.metadata)
if collection.metadata
else None,
file_paths={},
)
metadata_segment = Segment(
id=uuid4(),
type=SegmentType.BLOCKFILE_METADATA.value,
scope=SegmentScope.METADATA,
collection=collection.id,
metadata=None,
file_paths={},
)
record_segment = Segment(
id=uuid4(),
type=SegmentType.BLOCKFILE_RECORD.value,
scope=SegmentScope.RECORD,
collection=collection.id,
metadata=None,
file_paths={},
)
return [vector_segment, record_segment, metadata_segment]
@override
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        # In distributed mode, delete_collection handles deleting segments.
return []
@trace_method(
"DistributedSegmentManager.get_endpoint",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def get_endpoints(self, segment: Segment, n: int) -> List[str]:
return self._segment_directory.get_segment_endpoints(segment, n)
@trace_method(
"DistributedSegmentManager.hint_use_collection",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
@override
def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
pass
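
Aside: what prepare_segments_for_new_collection hands back in distributed mode
is one segment per scope, all pointing at the same collection id. FakeCollection
below is a hypothetical stand-in carrying only the two fields the method reads:

from dataclasses import dataclass
from typing import Optional
from uuid import UUID, uuid4

@dataclass
class FakeCollection:
    id: UUID
    metadata: Optional[dict] = None

# Given a configured DistributedSegmentManager `manager`:
# segments = manager.prepare_segments_for_new_collection(FakeCollection(id=uuid4()))
# {s["scope"] for s in segments}
# -> {SegmentScope.VECTOR, SegmentScope.RECORD, SegmentScope.METADATA}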


@@ -0,0 +1,269 @@
from threading import Lock
from chromadb.segment import (
SegmentImplementation,
SegmentManager,
MetadataReader,
SegmentType,
VectorReader,
S,
)
import logging
from chromadb.segment.impl.manager.cache.cache import (
SegmentLRUCache,
BasicCache,
SegmentCache,
)
import os
from chromadb.config import System, get_class
from chromadb.db.system import SysDB
from overrides import override
from chromadb.segment.impl.vector.local_persistent_hnsw import (
PersistentLocalHnswSegment,
)
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata
from typing import Dict, Type, Sequence, Optional, cast
from uuid import UUID, uuid4
import platform
from chromadb.utils.lru_cache import LRUCache
from chromadb.utils.directory import get_directory_size
if platform.system() != "Windows":
    import resource
else:
    import ctypes
SEGMENT_TYPE_IMPLS = {
SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}
class LocalSegmentManager(SegmentManager):
_sysdb: SysDB
_system: System
_opentelemetry_client: OpenTelemetryClient
_instances: Dict[UUID, SegmentImplementation]
_vector_instances_file_handle_cache: LRUCache[
UUID, PersistentLocalHnswSegment
] # LRU cache to manage file handles across vector segment instances
_vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
_lock: Lock
_max_file_handles: int
def __init__(self, system: System):
super().__init__(system)
self._sysdb = self.require(SysDB)
self._system = system
self._opentelemetry_client = system.require(OpenTelemetryClient)
self.logger = logging.getLogger(__name__)
self._instances = {}
self.segment_cache: Dict[SegmentScope, SegmentCache] = {
SegmentScope.METADATA: BasicCache() # type: ignore[no-untyped-call]
}
if (
system.settings.chroma_segment_cache_policy == "LRU"
and system.settings.chroma_memory_limit_bytes > 0
):
self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(
capacity=system.settings.chroma_memory_limit_bytes,
callback=lambda k, v: self.callback_cache_evict(v),
size_func=lambda k: self._get_segment_disk_size(k),
)
else:
self.segment_cache[SegmentScope.VECTOR] = BasicCache() # type: ignore[no-untyped-call]
self._lock = Lock()
# TODO: prototyping with distributed segment for now, but this should be a configurable option
# we need to think about how to handle this configuration
if self._system.settings.require("is_persistent"):
self._vector_segment_type = SegmentType.HNSW_LOCAL_PERSISTED
if platform.system() != "Windows":
self._max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
else:
self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio() # type: ignore
segment_limit = (
self._max_file_handles
# This is integer division in Python 3, and not a comment.
// PersistentLocalHnswSegment.get_file_handle_count()
)
self._vector_instances_file_handle_cache = LRUCache(
segment_limit, callback=lambda _, v: v.close_persistent_index()
)
@trace_method(
"LocalSegmentManager.callback_cache_evict",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def callback_cache_evict(self, segment: Segment) -> None:
collection_id = segment["collection"]
self.logger.info(f"LRU cache evict collection {collection_id}")
instance = self._instance(segment)
instance.stop()
del self._instances[segment["id"]]
@override
def start(self) -> None:
for instance in self._instances.values():
instance.start()
super().start()
@override
def stop(self) -> None:
for instance in self._instances.values():
instance.stop()
super().stop()
@override
def reset_state(self) -> None:
for instance in self._instances.values():
instance.stop()
instance.reset_state()
self._instances = {}
self.segment_cache[SegmentScope.VECTOR].reset()
super().reset_state()
@trace_method(
"LocalSegmentManager.prepare_segments_for_new_collection",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
@override
def prepare_segments_for_new_collection(
self, collection: Collection
) -> Sequence[Segment]:
vector_segment = _segment(
self._vector_segment_type, SegmentScope.VECTOR, collection
)
metadata_segment = _segment(
SegmentType.SQLITE, SegmentScope.METADATA, collection
)
return [vector_segment, metadata_segment]
@trace_method(
"LocalSegmentManager.delete_segments",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
@override
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
segments = self._sysdb.get_segments(collection=collection_id)
for segment in segments:
if segment["id"] in self._instances:
if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value:
instance = self.get_segment(collection_id, VectorReader)
instance.delete()
elif segment["type"] == SegmentType.SQLITE.value:
instance = self.get_segment(collection_id, MetadataReader) # type: ignore[assignment]
instance.delete()
del self._instances[segment["id"]]
if segment["scope"] is SegmentScope.VECTOR:
self.segment_cache[SegmentScope.VECTOR].pop(collection_id)
if segment["scope"] is SegmentScope.METADATA:
self.segment_cache[SegmentScope.METADATA].pop(collection_id)
return [s["id"] for s in segments]
def _get_segment_disk_size(self, collection_id: UUID) -> int:
segments = self._sysdb.get_segments(
collection=collection_id, scope=SegmentScope.VECTOR
)
if len(segments) == 0:
return 0
        # With the local segment manager (single-server Chroma), a collection always
        # has exactly one vector segment.
size = get_directory_size(
os.path.join(
self._system.settings.require("persist_directory"),
str(segments[0]["id"]),
)
)
return size
@trace_method(
"LocalSegmentManager._get_segment_sysdb",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segment:
segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
# Get the first segment of a known type
segment = next(filter(lambda s: s["type"] in known_types, segments))
return segment
@trace_method(
"LocalSegmentManager.get_segment",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
if type == MetadataReader:
scope = SegmentScope.METADATA
elif type == VectorReader:
scope = SegmentScope.VECTOR
else:
raise ValueError(f"Invalid segment type: {type}")
segment = self.segment_cache[scope].get(collection_id)
if segment is None:
segment = self._get_segment_sysdb(collection_id, scope)
self.segment_cache[scope].set(collection_id, segment)
# Instances must be atomically created, so we use a lock to ensure that only one thread
# creates the instance.
with self._lock:
instance = self._instance(segment)
return cast(S, instance)
@trace_method(
"LocalSegmentManager.hint_use_collection",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
@override
def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
# The local segment manager responds to hints by pre-loading both the metadata and vector
# segments for the given collection.
for type in [MetadataReader, VectorReader]:
# Just use get_segment to load the segment into the cache
instance = self.get_segment(collection_id, type)
# If the segment is a vector segment, we need to keep segments in an LRU cache
# to avoid hitting the OS file handle limit.
if type == VectorReader and self._system.settings.require("is_persistent"):
instance = cast(PersistentLocalHnswSegment, instance)
instance.open_persistent_index()
self._vector_instances_file_handle_cache.set(collection_id, instance)
def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
cls = get_class(classname, SegmentImplementation)
return cls
def _instance(self, segment: Segment) -> SegmentImplementation:
if segment["id"] not in self._instances:
cls = self._cls(segment)
instance = cls(self._system, segment)
instance.start()
self._instances[segment["id"]] = instance
return self._instances[segment["id"]]
def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment:
"""Create a metadata dict, propagating metadata correctly for the given segment type."""
cls = get_class(SEGMENT_TYPE_IMPLS[type], SegmentImplementation)
collection_metadata = collection.metadata
metadata: Optional[Metadata] = None
if collection_metadata:
metadata = cls.propagate_collection_metadata(collection_metadata)
return Segment(
id=uuid4(),
type=type.value,
scope=scope,
collection=collection.id,
metadata=metadata,
file_paths={},
)
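
Aside: how the file-handle budget above plays out on POSIX. The soft
RLIMIT_NOFILE bound caps open files per process, and dividing it by the
per-segment handle count yields how many persisted HNSW segments can be open at
once. The handle count of 2 here is an assumption for illustration, not
chromadb's actual value.

import platform

if platform.system() != "Windows":
    import resource

    # soft limit on open files for this process, e.g. 1024 on many Linux boxes
    max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    handles_per_segment = 2  # hypothetical per-segment cost
    segment_limit = max_file_handles // handles_per_segment
    print(f"can keep {segment_limit} persisted HNSW segments open at once")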


@@ -0,0 +1,723 @@
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List
from chromadb.segment import MetadataReader
from chromadb.ingest import Consumer
from chromadb.config import System
from chromadb.types import RequestVersionContext, Segment, InclusionExclusionOperator
from chromadb.db.impl.sqlite import SqliteDB
from overrides import override
from chromadb.db.base import (
Cursor,
ParameterValue,
get_sql,
)
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
Where,
WhereDocument,
MetadataEmbeddingRecord,
LogRecord,
SeqId,
Operation,
UpdateMetadata,
LiteralValue,
WhereOperator,
)
from uuid import UUID
from pypika import Table, Tables
from pypika.queries import QueryBuilder
import pypika.functions as fn
from pypika.terms import Criterion
from itertools import groupby
from functools import reduce
import sqlite3
import logging
logger = logging.getLogger(__name__)
class SqliteMetadataSegment(MetadataReader):
_consumer: Consumer
_db: SqliteDB
_id: UUID
_opentelemetry_client: OpenTelemetryClient
_collection_id: Optional[UUID]
_subscription: Optional[UUID] = None
def __init__(self, system: System, segment: Segment):
self._db = system.instance(SqliteDB)
self._consumer = system.instance(Consumer)
self._id = segment["id"]
self._opentelemetry_client = system.require(OpenTelemetryClient)
self._collection_id = segment["collection"]
@trace_method("SqliteMetadataSegment.start", OpenTelemetryGranularity.ALL)
@override
def start(self) -> None:
if self._collection_id:
seq_id = self.max_seqid()
self._subscription = self._consumer.subscribe(
collection_id=self._collection_id,
consume_fn=self._write_metadata,
start=seq_id,
)
@trace_method("SqliteMetadataSegment.stop", OpenTelemetryGranularity.ALL)
@override
def stop(self) -> None:
if self._subscription:
self._consumer.unsubscribe(self._subscription)
@trace_method("SqliteMetadataSegment.max_seqid", OpenTelemetryGranularity.ALL)
@override
def max_seqid(self) -> SeqId:
t = Table("max_seq_id")
q = (
self._db.querybuilder()
.from_(t)
.select(t.seq_id)
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
)
sql, params = get_sql(q)
with self._db.tx() as cur:
result = cur.execute(sql, params).fetchone()
if result is None:
return self._consumer.min_seqid()
else:
return cast(int, result[0])
@trace_method("SqliteMetadataSegment.count", OpenTelemetryGranularity.ALL)
@override
def count(self, request_version_context: RequestVersionContext) -> int:
embeddings_t = Table("embeddings")
q = (
self._db.querybuilder()
.from_(embeddings_t)
.where(
embeddings_t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
.select(fn.Count(embeddings_t.id))
)
sql, params = get_sql(q)
with self._db.tx() as cur:
result = cur.execute(sql, params).fetchone()[0]
return cast(int, result)
@trace_method("SqliteMetadataSegment.get_metadata", OpenTelemetryGranularity.ALL)
@override
def get_metadata(
self,
request_version_context: RequestVersionContext,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
include_metadata: bool = True,
) -> Sequence[MetadataEmbeddingRecord]:
"""Query for embedding metadata."""
embeddings_t, metadata_t, fulltext_t = Tables(
"embeddings", "embedding_metadata", "embedding_fulltext_search"
)
limit = limit or 2**63 - 1
offset = offset or 0
if limit < 0:
raise ValueError("Limit cannot be negative")
select_clause = [
embeddings_t.id,
embeddings_t.embedding_id,
embeddings_t.seq_id,
]
if include_metadata:
select_clause.extend(
[
metadata_t.key,
metadata_t.string_value,
metadata_t.int_value,
metadata_t.float_value,
metadata_t.bool_value,
]
)
q = (
(
self._db.querybuilder()
.from_(embeddings_t)
.left_join(metadata_t)
.on(embeddings_t.id == metadata_t.id)
)
.select(*select_clause)
.orderby(embeddings_t.id)
)
        # If the query touches the metadata table (i.e., it uses where or
        # where_document filters), we treat this case separately.
if where is not None or where_document is not None:
metadata_q = (
self._db.querybuilder()
.from_(embeddings_t)
.select(embeddings_t.id)
.left_join(metadata_t)
.on(embeddings_t.id == metadata_t.id)
.orderby(embeddings_t.id)
.where(
embeddings_t.segment_id
== ParameterValue(self._db.uuid_to_db(self._id))
)
.distinct() # These are embedding ids
)
if where:
metadata_q = metadata_q.where(
self._where_map_criterion(
metadata_q, where, metadata_t, embeddings_t
)
)
if where_document:
metadata_q = metadata_q.where(
self._where_doc_criterion(
metadata_q, where_document, metadata_t, fulltext_t, embeddings_t
)
)
if ids is not None:
metadata_q = metadata_q.where(
embeddings_t.embedding_id.isin(ParameterValue(ids))
)
metadata_q = metadata_q.limit(limit)
metadata_q = metadata_q.offset(offset)
q = q.where(embeddings_t.id.isin(metadata_q))
else:
            # In the case where we don't use the metadata table, we have to apply
            # limit/offset to the embeddings first and then join with metadata.
embeddings_q = (
self._db.querybuilder()
.from_(embeddings_t)
.select(embeddings_t.id)
.where(
embeddings_t.segment_id
== ParameterValue(self._db.uuid_to_db(self._id))
)
.orderby(embeddings_t.id)
.limit(limit)
.offset(offset)
)
if ids is not None:
embeddings_q = embeddings_q.where(
embeddings_t.embedding_id.isin(ParameterValue(ids))
)
q = q.where(embeddings_t.id.isin(embeddings_q))
with self._db.tx() as cur:
# Execute the query with the limit and offset already applied
return list(self._records(cur, q, include_metadata))
def _records(
self, cur: Cursor, q: QueryBuilder, include_metadata: bool
) -> Generator[MetadataEmbeddingRecord, None, None]:
"""Given a cursor and a QueryBuilder, yield a generator of records. Assumes
cursor returns rows in ID order."""
sql, params = get_sql(q)
cur.execute(sql, params)
cur_iterator = iter(cur.fetchone, None)
group_iterator = groupby(cur_iterator, lambda r: int(r[0]))
for _, group in group_iterator:
yield self._record(list(group), include_metadata)
@trace_method("SqliteMetadataSegment._record", OpenTelemetryGranularity.ALL)
def _record(
self, rows: Sequence[Tuple[Any, ...]], include_metadata: bool
) -> MetadataEmbeddingRecord:
"""Given a list of DB rows with the same ID, construct a
MetadataEmbeddingRecord"""
_, embedding_id, seq_id = rows[0][:3]
if not include_metadata:
return MetadataEmbeddingRecord(id=embedding_id, metadata=None)
metadata = {}
for row in rows:
key, string_value, int_value, float_value, bool_value = row[3:]
if string_value is not None:
metadata[key] = string_value
elif int_value is not None:
metadata[key] = int_value
elif float_value is not None:
metadata[key] = float_value
elif bool_value is not None:
if bool_value == 1:
metadata[key] = True
else:
metadata[key] = False
return MetadataEmbeddingRecord(
id=embedding_id,
metadata=metadata or None,
)
@trace_method("SqliteMetadataSegment._insert_record", OpenTelemetryGranularity.ALL)
def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None:
"""Add or update a single EmbeddingRecord into the DB"""
t = Table("embeddings")
q = (
self._db.querybuilder()
.into(t)
.columns(t.segment_id, t.embedding_id, t.seq_id)
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
).insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(record["record"]["id"]),
ParameterValue(record["log_offset"]),
)
sql, params = get_sql(q)
sql = sql + "RETURNING id"
try:
id = cur.execute(sql, params).fetchone()[0]
except sqlite3.IntegrityError:
# Can't use INSERT OR REPLACE here because it changes the primary key.
if upsert:
return self._update_record(cur, record)
else:
logger.warning(
f"Insert of existing embedding ID: {record['record']['id']}"
)
                # We are trying to add a record that already exists. Fail the call.
                # We don't throw an exception since this is in principle an async path.
                return
if record["record"]["metadata"]:
self._update_metadata(cur, id, record["record"]["metadata"])
@trace_method(
"SqliteMetadataSegment._update_metadata", OpenTelemetryGranularity.ALL
)
def _update_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None:
"""Update the metadata for a single EmbeddingRecord"""
t = Table("embedding_metadata")
to_delete = [k for k, v in metadata.items() if v is None]
if to_delete:
q = (
self._db.querybuilder()
.from_(t)
.where(t.id == ParameterValue(id))
.where(t.key.isin(ParameterValue(to_delete)))
.delete()
)
sql, params = get_sql(q)
cur.execute(sql, params)
self._insert_metadata(cur, id, metadata)
@trace_method(
"SqliteMetadataSegment._insert_metadata", OpenTelemetryGranularity.ALL
)
def _insert_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None:
"""Insert or update each metadata row for a single embedding record"""
t = Table("embedding_metadata")
q = (
self._db.querybuilder()
.into(t)
.columns(
t.id,
t.key,
t.string_value,
t.int_value,
t.float_value,
t.bool_value,
)
)
for key, value in metadata.items():
if isinstance(value, str):
q = q.insert(
ParameterValue(id),
ParameterValue(key),
ParameterValue(value),
None,
None,
None,
)
# isinstance(True, int) evaluates to True, so we need to check for bools separately
elif isinstance(value, bool):
q = q.insert(
ParameterValue(id),
ParameterValue(key),
None,
None,
None,
ParameterValue(value),
)
elif isinstance(value, int):
q = q.insert(
ParameterValue(id),
ParameterValue(key),
None,
ParameterValue(value),
None,
None,
)
elif isinstance(value, float):
q = q.insert(
ParameterValue(id),
ParameterValue(key),
None,
None,
ParameterValue(value),
None,
)
sql, params = get_sql(q)
sql = sql.replace("INSERT", "INSERT OR REPLACE")
if sql:
cur.execute(sql, params)
if "chroma:document" in metadata:
t = Table("embedding_fulltext_search")
def insert_into_fulltext_search() -> None:
q = (
self._db.querybuilder()
.into(t)
.columns(t.rowid, t.string_value)
.insert(
ParameterValue(id),
ParameterValue(metadata["chroma:document"]),
)
)
sql, params = get_sql(q)
cur.execute(sql, params)
try:
insert_into_fulltext_search()
except sqlite3.IntegrityError:
q = (
self._db.querybuilder()
.from_(t)
.where(t.rowid == ParameterValue(id))
.delete()
)
sql, params = get_sql(q)
cur.execute(sql, params)
insert_into_fulltext_search()
@trace_method("SqliteMetadataSegment._delete_record", OpenTelemetryGranularity.ALL)
def _delete_record(self, cur: Cursor, record: LogRecord) -> None:
"""Delete a single EmbeddingRecord from the DB"""
t = Table("embeddings")
fts_t = Table("embedding_fulltext_search")
q = (
self._db.querybuilder()
.from_(t)
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
.delete()
)
q_fts = (
self._db.querybuilder()
.from_(fts_t)
.delete()
.where(
fts_t.rowid.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
)
)
)
cur.execute(*get_sql(q_fts))
sql, params = get_sql(q)
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(
f"Delete of nonexisting embedding ID: {record['record']['id']}"
)
else:
id = result[0]
# Manually delete metadata; cannot use cascade because
# that triggers on replace
metadata_t = Table("embedding_metadata")
q = (
self._db.querybuilder()
.from_(metadata_t)
.where(metadata_t.id == ParameterValue(id))
.delete()
)
sql, params = get_sql(q)
cur.execute(sql, params)
@trace_method("SqliteMetadataSegment._update_record", OpenTelemetryGranularity.ALL)
def _update_record(self, cur: Cursor, record: LogRecord) -> None:
"""Update a single EmbeddingRecord in the DB"""
t = Table("embeddings")
q = (
self._db.querybuilder()
.update(t)
.set(t.seq_id, ParameterValue(record["log_offset"]))
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
)
sql, params = get_sql(q)
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
else:
id = result[0]
if record["record"]["metadata"]:
self._update_metadata(cur, id, record["record"]["metadata"])
@trace_method("SqliteMetadataSegment._write_metadata", OpenTelemetryGranularity.ALL)
def _write_metadata(self, records: Sequence[LogRecord]) -> None:
"""Write embedding metadata to the database. Care should be taken to ensure
records are append-only (that is, that seq-ids should increase monotonically)"""
with self._db.tx() as cur:
for record in records:
if record["record"]["operation"] == Operation.ADD:
self._insert_record(cur, record, False)
elif record["record"]["operation"] == Operation.UPSERT:
self._insert_record(cur, record, True)
elif record["record"]["operation"] == Operation.DELETE:
self._delete_record(cur, record)
elif record["record"]["operation"] == Operation.UPDATE:
self._update_record(cur, record)
q = (
self._db.querybuilder()
.into(Table("max_seq_id"))
.columns("segment_id", "seq_id")
.insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(record["log_offset"]),
)
)
sql, params = get_sql(q)
sql = sql.replace("INSERT", "INSERT OR REPLACE")
cur.execute(sql, params)
@trace_method(
"SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL
)
def _where_map_criterion(
self, q: QueryBuilder, where: Where, metadata_t: Table, embeddings_t: Table
) -> Criterion:
clause: List[Criterion] = []
for k, v in where.items():
if k == "$and":
criteria = [
self._where_map_criterion(q, w, metadata_t, embeddings_t)
for w in cast(Sequence[Where], v)
]
clause.append(reduce(lambda x, y: x & y, criteria))
elif k == "$or":
criteria = [
self._where_map_criterion(q, w, metadata_t, embeddings_t)
for w in cast(Sequence[Where], v)
]
clause.append(reduce(lambda x, y: x | y, criteria))
else:
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v)
clause.append(_where_clause(k, expr, q, metadata_t, embeddings_t))
return reduce(lambda x, y: x & y, clause)
@trace_method(
"SqliteMetadataSegment._where_doc_criterion", OpenTelemetryGranularity.ALL
)
def _where_doc_criterion(
self,
q: QueryBuilder,
where: WhereDocument,
metadata_t: Table,
fulltext_t: Table,
embeddings_t: Table,
) -> Criterion:
for k, v in where.items():
if k == "$and":
criteria = [
self._where_doc_criterion(
q, w, metadata_t, fulltext_t, embeddings_t
)
for w in cast(Sequence[WhereDocument], v)
]
return reduce(lambda x, y: x & y, criteria)
elif k == "$or":
criteria = [
self._where_doc_criterion(
q, w, metadata_t, fulltext_t, embeddings_t
)
for w in cast(Sequence[WhereDocument], v)
]
return reduce(lambda x, y: x | y, criteria)
elif k in ("$contains", "$not_contains"):
v = cast(str, v)
search_term = f"%{v}%"
sq = (
self._db.querybuilder()
.from_(fulltext_t)
.select(fulltext_t.rowid)
.where(fulltext_t.string_value.like(ParameterValue(search_term)))
)
return (
embeddings_t.id.isin(sq)
if k == "$contains"
else embeddings_t.id.notin(sq)
)
else:
raise ValueError(f"Unknown where_doc operator {k}")
raise ValueError("Empty where_doc")
@trace_method("SqliteMetadataSegment.delete", OpenTelemetryGranularity.ALL)
@override
def delete(self) -> None:
t = Table("embeddings")
t1 = Table("embedding_metadata")
t2 = Table("embedding_fulltext_search")
q0 = (
self._db.querybuilder()
.from_(t1)
.delete()
.where(
t1.id.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
)
)
)
q = (
self._db.querybuilder()
.from_(t)
.delete()
.where(
t.id.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
)
)
)
q_fts = (
self._db.querybuilder()
.from_(t2)
.delete()
.where(
t2.rowid.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
)
)
)
with self._db.tx() as cur:
cur.execute(*get_sql(q_fts))
cur.execute(*get_sql(q0))
cur.execute(*get_sql(q))
def _where_clause(
key: str,
expr: Union[
LiteralValue,
Dict[WhereOperator, LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
],
metadata_q: QueryBuilder,
metadata_t: Table,
embeddings_t: Table,
) -> Criterion:
"""Given a field name, an expression, and a table, construct a Pypika Criterion"""
# Literal value case
if isinstance(expr, (str, int, float, bool)):
return _where_clause(
key,
{cast(WhereOperator, "$eq"): expr},
metadata_q,
metadata_t,
embeddings_t,
)
# Operator dict case
operator, value = next(iter(expr.items()))
return _value_criterion(key, value, operator, metadata_q, metadata_t, embeddings_t)
def _value_criterion(
key: str,
value: Union[LiteralValue, List[LiteralValue]],
op: Union[WhereOperator, InclusionExclusionOperator],
metadata_q: QueryBuilder,
metadata_t: Table,
embeddings_t: Table,
) -> Criterion:
"""Creates the filter for a single operator"""
def is_numeric(obj: object) -> bool:
return (not isinstance(obj, bool)) and isinstance(obj, (int, float))
sub_q = metadata_q.where(metadata_t.key == ParameterValue(key))
p_val = ParameterValue(value)
if is_numeric(value) or (isinstance(value, list) and is_numeric(value[0])):
int_col, float_col = metadata_t.int_value, metadata_t.float_value
if op in ("$eq", "$ne"):
expr = (int_col == p_val) | (float_col == p_val)
elif op == "$gt":
expr = (int_col > p_val) | (float_col > p_val)
elif op == "$gte":
expr = (int_col >= p_val) | (float_col >= p_val)
elif op == "$lt":
expr = (int_col < p_val) | (float_col < p_val)
elif op == "$lte":
expr = (int_col <= p_val) | (float_col <= p_val)
else:
expr = int_col.isin(p_val) | float_col.isin(p_val)
else:
if isinstance(value, bool) or (
isinstance(value, list) and isinstance(value[0], bool)
):
col = metadata_t.bool_value
else:
col = metadata_t.string_value
if op in ("$eq", "$ne"):
expr = col == p_val
else:
expr = col.isin(p_val)
if op in ("$ne", "$nin"):
return embeddings_t.id.notin(sub_q.where(expr))
else:
return embeddings_t.id.isin(sub_q.where(expr))
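
Aside: a hedged sketch of the metadata filter translation in isolation, using
plain pypika (literal values instead of ParameterValue) to show the SQL shape a
clause like {"price": {"$gt": 10}} turns into. As in _value_criterion above,
numeric values are matched against both the int_value and float_value columns.

from pypika import Query, Table

metadata_t = Table("embedding_metadata")
q = (
    Query.from_(metadata_t)
    .select(metadata_t.id)
    .where(metadata_t.key == "price")
    .where((metadata_t.int_value > 10) | (metadata_t.float_value > 10))
)
print(q.get_sql())
# roughly: SELECT "id" FROM "embedding_metadata"
#          WHERE "key"='price' AND ("int_value">10 OR "float_value">10)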


@@ -0,0 +1,106 @@
from typing import Dict, List, Set, cast
from chromadb.types import LogRecord, Operation, Vector
class Batch:
"""Used to model the set of changes as an atomic operation"""
_ids_to_records: Dict[str, LogRecord]
_deleted_ids: Set[str]
_written_ids: Set[str]
_upsert_add_ids: Set[str] # IDs that are being added in an upsert
add_count: int
update_count: int
def __init__(self) -> None:
self._ids_to_records = {}
self._deleted_ids = set()
self._written_ids = set()
self._upsert_add_ids = set()
self.add_count = 0
self.update_count = 0
def __len__(self) -> int:
"""Get the number of changes in this batch"""
return len(self._written_ids) + len(self._deleted_ids)
def get_deleted_ids(self) -> List[str]:
"""Get the list of deleted embeddings in this batch"""
return list(self._deleted_ids)
def get_written_ids(self) -> List[str]:
"""Get the list of written embeddings in this batch"""
return list(self._written_ids)
def get_written_vectors(self, ids: List[str]) -> List[Vector]:
"""Get the list of vectors to write in this batch"""
return [
cast(Vector, self._ids_to_records[id]["record"]["embedding"]) for id in ids
]
def get_record(self, id: str) -> LogRecord:
"""Get the record for a given ID"""
return self._ids_to_records[id]
def is_deleted(self, id: str) -> bool:
"""Check if a given ID is deleted"""
return id in self._deleted_ids
@property
def delete_count(self) -> int:
return len(self._deleted_ids)
def apply(self, record: LogRecord, exists_already: bool = False) -> None:
"""Apply an embedding record to this batch. Records passed to this method are assumed to be validated for correctness.
For example, a delete or update presumes the ID exists in the index. An add presumes the ID does not exist in the index.
The exists_already flag should be set to True if the ID does exist in the index, and False otherwise.
"""
id = record["record"]["id"]
if record["record"]["operation"] == Operation.DELETE:
# If the ID was previously written, remove it from the written set
# And update the add/update/delete counts
if id in self._written_ids:
self._written_ids.remove(id)
if self._ids_to_records[id]["record"]["operation"] == Operation.ADD:
self.add_count -= 1
elif (
self._ids_to_records[id]["record"]["operation"] == Operation.UPDATE
):
self.update_count -= 1
self._deleted_ids.add(id)
elif (
self._ids_to_records[id]["record"]["operation"] == Operation.UPSERT
):
if id in self._upsert_add_ids:
self.add_count -= 1
self._upsert_add_ids.remove(id)
else:
self.update_count -= 1
self._deleted_ids.add(id)
elif id not in self._deleted_ids:
self._deleted_ids.add(id)
# Remove the record from the batch
if id in self._ids_to_records:
del self._ids_to_records[id]
else:
self._ids_to_records[id] = record
self._written_ids.add(id)
# If the ID was previously deleted, remove it from the deleted set
# And update the delete count
if id in self._deleted_ids:
self._deleted_ids.remove(id)
# Update the add/update counts
if record["record"]["operation"] == Operation.UPSERT:
if not exists_already:
self.add_count += 1
self._upsert_add_ids.add(id)
else:
self.update_count += 1
elif record["record"]["operation"] == Operation.ADD:
self.add_count += 1
elif record["record"]["operation"] == Operation.UPDATE:
self.update_count += 1
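
Hypothetical usage of Batch: an add followed by a delete of the same ID cancels
out entirely. The dict literals assume only the LogRecord keys this class
actually reads ("log_offset" plus the record's id, operation, and embedding).

batch = Batch()
add = {
    "log_offset": 1,
    "record": {"id": "e1", "operation": Operation.ADD, "embedding": [0.1, 0.2]},
}
delete = {
    "log_offset": 2,
    "record": {"id": "e1", "operation": Operation.DELETE, "embedding": None},
}
batch.apply(add)     # add_count == 1; "e1" is in the written set
batch.apply(delete)  # the pending add is cancelled: add_count == 0, len(batch) == 0
# "e1" is not marked deleted, because the add never reached the index.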


@@ -0,0 +1,151 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Set
import numpy as np
import numpy.typing as npt
from chromadb.types import (
LogRecord,
VectorEmbeddingRecord,
VectorQuery,
VectorQueryResult,
)
from chromadb.utils import distance_functions
import logging
logger = logging.getLogger(__name__)
class BruteForceIndex:
"""A lightweight, numpy based brute force index that is used for batches that have not been indexed into hnsw yet. It is not
thread safe and callers should ensure that only one thread is accessing it at a time.
"""
id_to_index: Dict[str, int]
index_to_id: Dict[int, str]
id_to_seq_id: Dict[str, int]
deleted_ids: Set[str]
free_indices: List[int]
size: int
dimensionality: int
distance_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any]], float]
vectors: npt.NDArray[Any]
def __init__(self, size: int, dimensionality: int, space: str = "l2"):
if space == "l2":
self.distance_fn = distance_functions.l2
elif space == "ip":
self.distance_fn = distance_functions.ip
elif space == "cosine":
self.distance_fn = distance_functions.cosine
else:
raise Exception(f"Unknown distance function: {space}")
self.id_to_index = {}
self.index_to_id = {}
self.id_to_seq_id = {}
self.deleted_ids = set()
self.free_indices = list(range(size))
self.size = size
self.dimensionality = dimensionality
self.vectors = np.zeros((size, dimensionality))
def __len__(self) -> int:
return len(self.id_to_index)
def clear(self) -> None:
self.id_to_index = {}
self.index_to_id = {}
self.id_to_seq_id = {}
self.deleted_ids.clear()
self.free_indices = list(range(self.size))
self.vectors.fill(0)
def upsert(self, records: List[LogRecord]) -> None:
if len(records) + len(self) > self.size:
raise Exception(
"Index with capacity {} and {} current entries cannot add {} records".format(
self.size, len(self), len(records)
)
)
for i, record in enumerate(records):
id = record["record"]["id"]
vector = record["record"]["embedding"]
self.id_to_seq_id[id] = record["log_offset"]
if id in self.deleted_ids:
self.deleted_ids.remove(id)
# TODO: It may be faster to use multi-index selection on the vectors array
if id in self.id_to_index:
# Update
index = self.id_to_index[id]
self.vectors[index] = vector
else:
# Add
next_index = self.free_indices.pop()
self.id_to_index[id] = next_index
self.index_to_id[next_index] = id
self.vectors[next_index] = vector
def delete(self, records: List[LogRecord]) -> None:
for record in records:
id = record["record"]["id"]
if id in self.id_to_index:
index = self.id_to_index[id]
self.deleted_ids.add(id)
del self.id_to_index[id]
del self.index_to_id[index]
del self.id_to_seq_id[id]
self.vectors[index].fill(np.nan)
self.free_indices.append(index)
else:
logger.warning(f"Delete of nonexisting embedding ID: {id}")
def has_id(self, id: str) -> bool:
"""Returns whether the index contains the given ID"""
return id in self.id_to_index and id not in self.deleted_ids
def get_vectors(
self, ids: Optional[Sequence[str]] = None
) -> Sequence[VectorEmbeddingRecord]:
target_ids = ids or self.id_to_index.keys()
return [
VectorEmbeddingRecord(
id=id,
embedding=self.vectors[self.id_to_index[id]],
)
for id in target_ids
]
def query(self, query: VectorQuery) -> Sequence[Sequence[VectorQueryResult]]:
np_query = np.array(query["vectors"], dtype=np.float32)
allowed_ids = (
None if query["allowed_ids"] is None else set(query["allowed_ids"])
)
distances = np.apply_along_axis(
            lambda vec: np.apply_along_axis(self.distance_fn, 1, self.vectors, vec),
1,
np_query,
)
indices = np.argsort(distances)
# Filter out deleted labels
filtered_results = []
for i, index_list in enumerate(indices):
curr_results = []
for j in index_list:
# If the index is in the index_to_id map, then it has been added
if j in self.index_to_id:
id = self.index_to_id[j]
if id not in self.deleted_ids and (
allowed_ids is None or id in allowed_ids
):
curr_results.append(
VectorQueryResult(
id=id,
distance=distances[i][j].item(),
embedding=self.vectors[j],
)
)
filtered_results.append(curr_results)
return filtered_results
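
Hypothetical usage of BruteForceIndex: insert two records, then query with a
probe vector. The dicts assume only the LogRecord and VectorQuery keys these
methods actually read; the real types carry more fields.

index = BruteForceIndex(size=10, dimensionality=2, space="l2")
index.upsert([
    {"log_offset": 1, "record": {"id": "a", "embedding": [0.0, 0.0]}},
    {"log_offset": 2, "record": {"id": "b", "embedding": [1.0, 1.0]}},
])
results = index.query({"vectors": [[0.1, 0.0]], "allowed_ids": None})
# results[0] ranks "a" before "b"; each hit carries id, distance, and embedding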


@@ -0,0 +1,88 @@
import multiprocessing
import re
from typing import Any, Callable, Dict, Union
from chromadb.types import Metadata
Validator = Callable[[Union[str, int, float]], bool]
param_validators: Dict[str, Validator] = {
"hnsw:space": lambda p: bool(re.match(r"^(l2|cosine|ip)$", str(p))),
"hnsw:construction_ef": lambda p: isinstance(p, int),
"hnsw:search_ef": lambda p: isinstance(p, int),
"hnsw:M": lambda p: isinstance(p, int),
"hnsw:num_threads": lambda p: isinstance(p, int),
"hnsw:resize_factor": lambda p: isinstance(p, (int, float)),
}
# Extra params used for persistent hnsw
persistent_param_validators: Dict[str, Validator] = {
"hnsw:batch_size": lambda p: isinstance(p, int) and p > 2,
"hnsw:sync_threshold": lambda p: isinstance(p, int) and p > 2,
}
class Params:
@staticmethod
def _select(metadata: Metadata) -> Dict[str, Any]:
segment_metadata = {}
for param, value in metadata.items():
if param.startswith("hnsw:"):
segment_metadata[param] = value
return segment_metadata
@staticmethod
def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> None:
"""Validates the metadata"""
# Validate it
for param, value in metadata.items():
if param not in validators:
raise ValueError(f"Unknown HNSW parameter: {param}")
if not validators[param](value):
raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")
class HnswParams(Params):
space: str
construction_ef: int
search_ef: int
M: int
num_threads: int
resize_factor: float
def __init__(self, metadata: Metadata):
metadata = metadata or {}
self.space = str(metadata.get("hnsw:space", "l2"))
self.construction_ef = int(metadata.get("hnsw:construction_ef", 100))
self.search_ef = int(metadata.get("hnsw:search_ef", 100))
self.M = int(metadata.get("hnsw:M", 16))
self.num_threads = int(
metadata.get("hnsw:num_threads", multiprocessing.cpu_count())
)
self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2))
@staticmethod
def extract(metadata: Metadata) -> Metadata:
"""Validate and return only the relevant hnsw params"""
segment_metadata = HnswParams._select(metadata)
HnswParams._validate(segment_metadata, param_validators)
return segment_metadata
class PersistentHnswParams(HnswParams):
batch_size: int
sync_threshold: int
def __init__(self, metadata: Metadata):
super().__init__(metadata)
self.batch_size = int(metadata.get("hnsw:batch_size", 100))
self.sync_threshold = int(metadata.get("hnsw:sync_threshold", 1000))
@staticmethod
def extract(metadata: Metadata) -> Metadata:
"""Returns only the relevant hnsw params"""
all_validators = {**param_validators, **persistent_param_validators}
segment_metadata = PersistentHnswParams._select(metadata)
PersistentHnswParams._validate(segment_metadata, all_validators)
return segment_metadata
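
# A small usage sketch of the extraction flow above (hypothetical metadata
# values): _select() keeps only "hnsw:"-prefixed keys, then _validate() checks
# them and raises ValueError for unknown or invalid parameters.
if __name__ == "__main__":
    meta = {"hnsw:space": "cosine", "hnsw:batch_size": 50, "foo": "ignored"}
    print(PersistentHnswParams.extract(meta))
    # -> {'hnsw:space': 'cosine', 'hnsw:batch_size': 50}; a key such as
    # "hnsw:bogus" would raise ValueError("Unknown HNSW parameter: ...").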

View File chromadb/segment/impl/vector/local_hnsw.py

@@ -0,0 +1,332 @@
from overrides import override
from typing import Optional, Sequence, Dict, Set, List, cast
from uuid import UUID
from chromadb.segment import VectorReader
from chromadb.ingest import Consumer
from chromadb.config import System, Settings
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import HnswParams
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
LogRecord,
RequestVersionContext,
VectorEmbeddingRecord,
VectorQuery,
VectorQueryResult,
SeqId,
Segment,
Metadata,
Operation,
Vector,
)
from chromadb.errors import InvalidDimensionException
import hnswlib
from chromadb.utils.read_write_lock import ReadWriteLock, ReadRWLock, WriteRWLock
import logging
import numpy as np
logger = logging.getLogger(__name__)
DEFAULT_CAPACITY = 1000
class LocalHnswSegment(VectorReader):
_id: UUID
_consumer: Consumer
_collection: Optional[UUID]
_subscription: Optional[UUID]
_settings: Settings
_params: HnswParams
_index: Optional[hnswlib.Index]
_dimensionality: Optional[int]
_total_elements_added: int
_max_seq_id: SeqId
_lock: ReadWriteLock
_id_to_label: Dict[str, int]
_label_to_id: Dict[int, str]
# Note: As of the time of writing, this mapping is no longer needed.
# We merely keep it around for easy compatibility with the old code and
# debugging purposes.
_id_to_seq_id: Dict[str, SeqId]
_opentelemtry_client: OpenTelemetryClient
def __init__(self, system: System, segment: Segment):
self._consumer = system.instance(Consumer)
self._id = segment["id"]
self._collection = segment["collection"]
self._subscription = None
self._settings = system.settings
self._params = HnswParams(segment["metadata"] or {})
self._index = None
self._dimensionality = None
self._total_elements_added = 0
self._max_seq_id = self._consumer.min_seqid()
self._id_to_seq_id = {}
self._id_to_label = {}
self._label_to_id = {}
self._lock = ReadWriteLock()
self._opentelemtry_client = system.require(OpenTelemetryClient)
@staticmethod
@override
def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
# Extract relevant metadata
segment_metadata = HnswParams.extract(metadata)
return segment_metadata
@trace_method("LocalHnswSegment.start", OpenTelemetryGranularity.ALL)
@override
def start(self) -> None:
super().start()
if self._collection:
seq_id = self.max_seqid()
self._subscription = self._consumer.subscribe(
self._collection, self._write_records, start=seq_id
)
@trace_method("LocalHnswSegment.stop", OpenTelemetryGranularity.ALL)
@override
def stop(self) -> None:
super().stop()
if self._subscription:
self._consumer.unsubscribe(self._subscription)
@trace_method("LocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL)
@override
def get_vectors(
self,
request_version_context: RequestVersionContext,
ids: Optional[Sequence[str]] = None,
) -> Sequence[VectorEmbeddingRecord]:
if ids is None:
labels = list(self._label_to_id.keys())
else:
labels = []
for id in ids:
if id in self._id_to_label:
labels.append(self._id_to_label[id])
results = []
if self._index is not None:
vectors = cast(
Sequence[Vector], np.array(self._index.get_items(labels))
) # version 0.8 of hnswlib allows return_type="numpy"
for label, vector in zip(labels, vectors):
id = self._label_to_id[label]
results.append(VectorEmbeddingRecord(id=id, embedding=vector))
return results
@trace_method("LocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL)
@override
def query_vectors(
self, query: VectorQuery
) -> Sequence[Sequence[VectorQueryResult]]:
if self._index is None:
return [[] for _ in range(len(query["vectors"]))]
k = query["k"]
size = len(self._id_to_label)
if k > size:
logger.warning(
f"Number of requested results {k} is greater than number of elements in index {size}, updating n_results = {size}"
)
k = size
labels: Set[int] = set()
ids = query["allowed_ids"]
if ids is not None:
labels = {self._id_to_label[id] for id in ids if id in self._id_to_label}
if len(labels) < k:
k = len(labels)
def filter_function(label: int) -> bool:
return label in labels
query_vectors = query["vectors"]
with ReadRWLock(self._lock):
result_labels, distances = self._index.knn_query(
np.array(query_vectors, dtype=np.float32),
k=k,
filter=filter_function if ids else None,
)
# TODO: these casts are not correct, hnswlib returns np
# distances = cast(List[List[float]], distances)
# result_labels = cast(List[List[int]], result_labels)
all_results: List[List[VectorQueryResult]] = []
for result_i in range(len(result_labels)):
results: List[VectorQueryResult] = []
for label, distance in zip(
result_labels[result_i], distances[result_i]
):
id = self._label_to_id[label]
if query["include_embeddings"]:
embedding = np.array(
self._index.get_items([label])[0]
) # version 0.8 of hnswlib allows return_type="numpy"
else:
embedding = None
results.append(
VectorQueryResult(
id=id,
distance=distance.item(),
embedding=embedding,
)
)
all_results.append(results)
return all_results
@override
def max_seqid(self) -> SeqId:
return self._max_seq_id
@override
def count(self, request_version_context: RequestVersionContext) -> int:
return len(self._id_to_label)
@trace_method("LocalHnswSegment._init_index", OpenTelemetryGranularity.ALL)
def _init_index(self, dimensionality: int) -> None:
# more comments available at the source: https://github.com/nmslib/hnswlib
index = hnswlib.Index(
space=self._params.space, dim=dimensionality
) # possible options are l2, cosine or ip
index.init_index(
max_elements=DEFAULT_CAPACITY,
ef_construction=self._params.construction_ef,
M=self._params.M,
)
index.set_ef(self._params.search_ef)
index.set_num_threads(self._params.num_threads)
self._index = index
self._dimensionality = dimensionality
@trace_method("LocalHnswSegment._ensure_index", OpenTelemetryGranularity.ALL)
def _ensure_index(self, n: int, dim: int) -> None:
"""Create or resize the index as necessary to accomodate N new records"""
if not self._index:
self._dimensionality = dim
self._init_index(dim)
else:
if dim != self._dimensionality:
raise InvalidDimensionException(
f"Dimensionality of ({dim}) does not match index"
+ f"dimensionality ({self._dimensionality})"
)
index = cast(hnswlib.Index, self._index)
if (self._total_elements_added + n) > index.get_max_elements():
new_size = int(
(self._total_elements_added + n) * self._params.resize_factor
)
index.resize_index(max(new_size, DEFAULT_CAPACITY))
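    # Worked example of the resize arithmetic above (illustrative numbers):
    # assuming the index is still at its initial DEFAULT_CAPACITY of 1000, with
    # resize_factor = 1.2, 900 elements already added and n = 200 incoming,
    # 900 + 200 = 1100 exceeds the capacity, so the index is resized to
    # max(int(1100 * 1.2), 1000) = 1320 slots.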
@trace_method("LocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL)
def _apply_batch(self, batch: Batch) -> None:
"""Apply a batch of changes, as atomically as possible."""
deleted_ids = batch.get_deleted_ids()
written_ids = batch.get_written_ids()
vectors_to_write = batch.get_written_vectors(written_ids)
labels_to_write = [0] * len(vectors_to_write)
if len(deleted_ids) > 0:
index = cast(hnswlib.Index, self._index)
for i in range(len(deleted_ids)):
id = deleted_ids[i]
                # This id was never added to the hnsw index, so the delete can be safely ignored
if id not in self._id_to_label:
continue
label = self._id_to_label[id]
index.mark_deleted(label)
del self._id_to_label[id]
del self._label_to_id[label]
del self._id_to_seq_id[id]
if len(written_ids) > 0:
self._ensure_index(batch.add_count, len(vectors_to_write[0]))
next_label = self._total_elements_added + 1
for i in range(len(written_ids)):
if written_ids[i] not in self._id_to_label:
labels_to_write[i] = next_label
next_label += 1
else:
labels_to_write[i] = self._id_to_label[written_ids[i]]
index = cast(hnswlib.Index, self._index)
# First, update the index
index.add_items(vectors_to_write, labels_to_write)
# If that succeeds, update the mappings
for i, id in enumerate(written_ids):
self._id_to_seq_id[id] = batch.get_record(id)["log_offset"]
self._id_to_label[id] = labels_to_write[i]
self._label_to_id[labels_to_write[i]] = id
# If that succeeds, update the total count
self._total_elements_added += batch.add_count
@trace_method("LocalHnswSegment._write_records", OpenTelemetryGranularity.ALL)
def _write_records(self, records: Sequence[LogRecord]) -> None:
"""Add a batch of embeddings to the index"""
if not self._running:
raise RuntimeError("Cannot add embeddings to stopped component")
# Avoid all sorts of potential problems by ensuring single-threaded access
with WriteRWLock(self._lock):
batch = Batch()
for record in records:
self._max_seq_id = max(self._max_seq_id, record["log_offset"])
id = record["record"]["id"]
op = record["record"]["operation"]
label = self._id_to_label.get(id, None)
if op == Operation.DELETE:
if label:
batch.apply(record)
else:
logger.warning(f"Delete of nonexisting embedding ID: {id}")
elif op == Operation.UPDATE:
if record["record"]["embedding"] is not None:
if label is not None:
batch.apply(record)
else:
logger.warning(
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
elif op == Operation.ADD:
if not label:
batch.apply(record, False)
else:
logger.warning(f"Add of existing embedding ID: {id}")
elif op == Operation.UPSERT:
batch.apply(record, label is not None)
self._apply_batch(batch)
@override
def delete(self) -> None:
raise NotImplementedError()
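
# A minimal standalone sketch of the hnswlib calls used by _init_index() and
# query_vectors() above (toy data; the parameters mirror the HnswParams
# defaults):
if __name__ == "__main__":
    index = hnswlib.Index(space="l2", dim=2)  # possible options are l2, cosine or ip
    index.init_index(max_elements=1000, ef_construction=100, M=16)
    index.set_ef(100)
    index.set_num_threads(1)
    index.add_items(np.array([[0.0, 0.0], [1.0, 1.0]]), [1, 2])  # labels start at 1
    labels, distances = index.knn_query(np.array([[0.1, 0.0]]), k=2)
    print(labels, distances)  # nearest first: labels [[1 2]]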

View File chromadb/segment/impl/vector/local_persistent_hnsw.py

@@ -0,0 +1,543 @@
import os
import shutil
from overrides import override
import pickle
from typing import Dict, List, Optional, Sequence, Set, cast
from chromadb.config import System
from chromadb.db.base import ParameterValue, get_sql
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.segment.impl.vector.local_hnsw import (
DEFAULT_CAPACITY,
LocalHnswSegment,
)
from chromadb.segment.impl.vector.brute_force_index import BruteForceIndex
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
LogRecord,
Metadata,
Operation,
RequestVersionContext,
Segment,
SeqId,
Vector,
VectorEmbeddingRecord,
VectorQuery,
VectorQueryResult,
)
import hnswlib
import logging
from pypika import Table
import numpy as np
from chromadb.utils.read_write_lock import ReadRWLock, WriteRWLock
logger = logging.getLogger(__name__)
class PersistentData:
"""Stores the data and metadata needed for a PersistentLocalHnswSegment"""
dimensionality: Optional[int]
total_elements_added: int
max_seq_id: SeqId
"This is a legacy field. It is no longer mutated, but kept to allow automatic migration of the `max_seq_id` from the pickled file to the `max_seq_id` table in SQLite."
id_to_label: Dict[str, int]
label_to_id: Dict[int, str]
id_to_seq_id: Dict[str, SeqId]
def __init__(
self,
dimensionality: Optional[int],
total_elements_added: int,
id_to_label: Dict[str, int],
label_to_id: Dict[int, str],
id_to_seq_id: Dict[str, SeqId],
):
self.dimensionality = dimensionality
self.total_elements_added = total_elements_added
self.id_to_label = id_to_label
self.label_to_id = label_to_id
self.id_to_seq_id = id_to_seq_id
@staticmethod
def load_from_file(filename: str) -> "PersistentData":
"""Load persistent data from a file"""
with open(filename, "rb") as f:
ret = cast(PersistentData, pickle.load(f))
return ret
class PersistentLocalHnswSegment(LocalHnswSegment):
METADATA_FILE: str = "index_metadata.pickle"
    # How many records to add to the index at once; batching amortizes the expensive python/c++ boundary crossing in add()
# When records are not added to the c++ index, they are buffered in memory and served
# via brute force search.
_batch_size: int
_brute_force_index: Optional[BruteForceIndex]
_index_initialized: bool = False
_curr_batch: Batch
# How many records to add to index before syncing to disk
_sync_threshold: int
_persist_data: PersistentData
_persist_directory: str
_allow_reset: bool
_db: SqliteDB
_opentelemtry_client: OpenTelemetryClient
_num_log_records_since_last_batch: int = 0
_num_log_records_since_last_persist: int = 0
def __init__(self, system: System, segment: Segment):
super().__init__(system, segment)
self._db = system.instance(SqliteDB)
self._opentelemtry_client = system.require(OpenTelemetryClient)
self._params = PersistentHnswParams(segment["metadata"] or {})
self._batch_size = self._params.batch_size
self._sync_threshold = self._params.sync_threshold
self._allow_reset = system.settings.allow_reset
self._persist_directory = system.settings.require("persist_directory")
self._curr_batch = Batch()
self._brute_force_index = None
if not os.path.exists(self._get_storage_folder()):
os.makedirs(self._get_storage_folder(), exist_ok=True)
# Load persist data if it exists already, otherwise create it
if self._index_exists():
self._persist_data = PersistentData.load_from_file(
self._get_metadata_file()
)
self._dimensionality = self._persist_data.dimensionality
self._total_elements_added = self._persist_data.total_elements_added
self._id_to_label = self._persist_data.id_to_label
self._label_to_id = self._persist_data.label_to_id
self._id_to_seq_id = self._persist_data.id_to_seq_id
# If the index was written to, we need to re-initialize it
if len(self._id_to_label) > 0:
self._dimensionality = cast(int, self._dimensionality)
self._init_index(self._dimensionality)
else:
self._persist_data = PersistentData(
self._dimensionality,
self._total_elements_added,
self._id_to_label,
self._label_to_id,
self._id_to_seq_id,
)
# Hydrate the max_seq_id
with self._db.tx() as cur:
t = Table("max_seq_id")
q = (
self._db.querybuilder()
.from_(t)
.select(t.seq_id)
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.limit(1)
)
sql, params = get_sql(q)
cur.execute(sql, params)
result = cur.fetchone()
if result:
self._max_seq_id = result[0]
elif self._index_exists():
# Migrate the max_seq_id from the legacy field in the pickled file to the SQLite database
q = (
self._db.querybuilder()
.into(Table("max_seq_id"))
.columns("segment_id", "seq_id")
.insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(self._persist_data.max_seq_id),
)
)
sql, params = get_sql(q)
cur.execute(sql, params)
self._max_seq_id = self._persist_data.max_seq_id
else:
self._max_seq_id = self._consumer.min_seqid()
@staticmethod
@override
def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
# Extract relevant metadata
segment_metadata = PersistentHnswParams.extract(metadata)
return segment_metadata
def _index_exists(self) -> bool:
"""Check if the index exists via the metadata file"""
return os.path.exists(self._get_metadata_file())
def _get_metadata_file(self) -> str:
"""Get the metadata file path"""
return os.path.join(self._get_storage_folder(), self.METADATA_FILE)
def _get_storage_folder(self) -> str:
"""Get the storage folder path"""
folder = os.path.join(self._persist_directory, str(self._id))
return folder
@trace_method(
"PersistentLocalHnswSegment._init_index", OpenTelemetryGranularity.ALL
)
@override
def _init_index(self, dimensionality: int) -> None:
index = hnswlib.Index(space=self._params.space, dim=dimensionality)
self._brute_force_index = BruteForceIndex(
size=self._batch_size,
dimensionality=dimensionality,
space=self._params.space,
)
# Check if index exists and load it if it does
if self._index_exists():
index.load_index(
self._get_storage_folder(),
is_persistent_index=True,
max_elements=int(
max(
self.count(
request_version_context=RequestVersionContext(
collection_version=0, log_position=0
)
)
* self._params.resize_factor,
DEFAULT_CAPACITY,
)
),
)
else:
index.init_index(
max_elements=DEFAULT_CAPACITY,
ef_construction=self._params.construction_ef,
M=self._params.M,
is_persistent_index=True,
persistence_location=self._get_storage_folder(),
)
index.set_ef(self._params.search_ef)
index.set_num_threads(self._params.num_threads)
self._index = index
self._dimensionality = dimensionality
self._index_initialized = True
@trace_method("PersistentLocalHnswSegment._persist", OpenTelemetryGranularity.ALL)
def _persist(self) -> None:
"""Persist the index and data to disk"""
index = cast(hnswlib.Index, self._index)
# Persist the index
index.persist_dirty()
# Persist the metadata
self._persist_data.dimensionality = self._dimensionality
self._persist_data.total_elements_added = self._total_elements_added
# TODO: This should really be stored in sqlite, the index itself, or a better
# storage format
self._persist_data.id_to_label = self._id_to_label
self._persist_data.label_to_id = self._label_to_id
self._persist_data.id_to_seq_id = self._id_to_seq_id
with open(self._get_metadata_file(), "wb") as metadata_file:
pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)
with self._db.tx() as cur:
q = (
self._db.querybuilder()
.into(Table("max_seq_id"))
.columns("segment_id", "seq_id")
.insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(self._max_seq_id),
)
)
sql, params = get_sql(q)
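            # Rewrite the plain INSERT emitted by the query builder into an
            # upsert, so that repeated persists overwrite the existing row for
            # this segment_id instead of failing on it.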
sql = sql.replace("INSERT", "INSERT OR REPLACE")
cur.execute(sql, params)
self._num_log_records_since_last_persist = 0
@trace_method(
"PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
)
@override
def _apply_batch(self, batch: Batch) -> None:
super()._apply_batch(batch)
if self._num_log_records_since_last_persist >= self._sync_threshold:
self._persist()
self._num_log_records_since_last_batch = 0
@trace_method(
"PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL
)
@override
def _write_records(self, records: Sequence[LogRecord]) -> None:
"""Add a batch of embeddings to the index"""
if not self._running:
raise RuntimeError("Cannot add embeddings to stopped component")
with WriteRWLock(self._lock):
for record in records:
self._num_log_records_since_last_batch += 1
self._num_log_records_since_last_persist += 1
if record["record"]["embedding"] is not None:
self._ensure_index(len(records), len(record["record"]["embedding"]))
if not self._index_initialized:
                    # The index is only initialized once a record with an
                    # embedding has been seen, so reaching this point means the
                    # record carries no embedding (such as a delete) and can be
                    # skipped.
continue
self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)
self._max_seq_id = max(self._max_seq_id, record["log_offset"])
id = record["record"]["id"]
op = record["record"]["operation"]
exists_in_bf_index = self._brute_force_index.has_id(id)
exists_in_persisted_index = self._id_to_label.get(id, None) is not None
exists_in_index = exists_in_bf_index or exists_in_persisted_index
id_is_pending_delete = self._curr_batch.is_deleted(id)
if op == Operation.DELETE:
if exists_in_index:
self._curr_batch.apply(record)
if exists_in_bf_index:
self._brute_force_index.delete([record])
else:
logger.warning(f"Delete of nonexisting embedding ID: {id}")
elif op == Operation.UPDATE:
if record["record"]["embedding"] is not None:
if exists_in_index:
self._curr_batch.apply(record)
self._brute_force_index.upsert([record])
else:
logger.warning(
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
elif op == Operation.ADD:
if record["record"]["embedding"] is not None:
if exists_in_index and not id_is_pending_delete:
logger.warning(f"Add of existing embedding ID: {id}")
else:
self._curr_batch.apply(record, not exists_in_index)
self._brute_force_index.upsert([record])
elif op == Operation.UPSERT:
if record["record"]["embedding"] is not None:
self._curr_batch.apply(record, exists_in_index)
self._brute_force_index.upsert([record])
if self._num_log_records_since_last_batch >= self._batch_size:
self._apply_batch(self._curr_batch)
self._curr_batch = Batch()
self._brute_force_index.clear()
@override
def count(self, request_version_context: RequestVersionContext) -> int:
return (
len(self._id_to_label)
+ self._curr_batch.add_count
- self._curr_batch.delete_count
)
@trace_method(
"PersistentLocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL
)
@override
def get_vectors(
self,
request_version_context: RequestVersionContext,
ids: Optional[Sequence[str]] = None,
) -> Sequence[VectorEmbeddingRecord]:
"""Get the embeddings from the HNSW index and layered brute force
batch index."""
ids_hnsw: Set[str] = set()
ids_bf: Set[str] = set()
if self._index is not None:
ids_hnsw = set(self._id_to_label.keys())
if self._brute_force_index is not None:
ids_bf = set(self._curr_batch.get_written_ids())
target_ids = ids or list(ids_hnsw.union(ids_bf))
self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)
hnsw_labels = []
results: List[Optional[VectorEmbeddingRecord]] = []
id_to_index: Dict[str, int] = {}
for i, id in enumerate(target_ids):
if id in ids_bf:
results.append(self._brute_force_index.get_vectors([id])[0])
elif id in ids_hnsw and id not in self._curr_batch._deleted_ids:
hnsw_labels.append(self._id_to_label[id])
# Placeholder for hnsw results to be filled in down below so we
# can batch the hnsw get() call
results.append(None)
id_to_index[id] = i
if len(hnsw_labels) > 0 and self._index is not None:
vectors = cast(
Sequence[Vector], np.array(self._index.get_items(hnsw_labels))
) # version 0.8 of hnswlib allows return_type="numpy"
for label, vector in zip(hnsw_labels, vectors):
id = self._label_to_id[label]
results[id_to_index[id]] = VectorEmbeddingRecord(
id=id, embedding=vector
)
return results # type: ignore ## Python can't cast List with Optional to List with VectorEmbeddingRecord
@trace_method(
"PersistentLocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL
)
@override
def query_vectors(
self, query: VectorQuery
) -> Sequence[Sequence[VectorQueryResult]]:
if self._index is None and self._brute_force_index is None:
return [[] for _ in range(len(query["vectors"]))]
k = query["k"]
if k > self.count(query["request_version_context"]):
count = self.count(query["request_version_context"])
logger.warning(
f"Number of requested results {k} is greater than number of elements in index {count}, updating n_results = {count}"
)
k = count
        # Over-query by the number of updated and deleted elements layered on
        # top of the index, since those stale entries may hide the true nearest
        # neighbors in the hnsw results
hnsw_k = k + self._curr_batch.update_count + self._curr_batch.delete_count
# self._id_to_label contains the ids of the elements in the hnsw index
# so its length is the number of elements in the hnsw index
if hnsw_k > len(self._id_to_label):
hnsw_k = len(self._id_to_label)
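        # E.g. (illustrative numbers): with k = 10, 3 pending updates and 2
        # pending deletes in the current batch, hnsw_k = 15 candidates are
        # fetched, capped above at the number of elements in the hnsw index.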
hnsw_query = VectorQuery(
vectors=query["vectors"],
k=hnsw_k,
allowed_ids=query["allowed_ids"],
include_embeddings=query["include_embeddings"],
options=query["options"],
request_version_context=query["request_version_context"],
)
# For each query vector, we want to take the top k results from the
# combined results of the brute force and hnsw index
results: List[List[VectorQueryResult]] = []
self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)
with ReadRWLock(self._lock):
bf_results = self._brute_force_index.query(query)
hnsw_results = super().query_vectors(hnsw_query)
for i in range(len(query["vectors"])):
# Merge results into a single list of size k
bf_pointer: int = 0
hnsw_pointer: int = 0
curr_bf_result: Sequence[VectorQueryResult] = bf_results[i]
curr_hnsw_result: Sequence[VectorQueryResult] = hnsw_results[i]
# Filter deleted results that haven't yet been removed from the persisted index
curr_hnsw_result = [
x
for x in curr_hnsw_result
if not self._curr_batch.is_deleted(x["id"])
]
curr_results: List[VectorQueryResult] = []
# In the case where filters cause the number of results to be less than k,
# we set k to be the number of results
total_results = len(curr_bf_result) + len(curr_hnsw_result)
if total_results == 0:
results.append([])
else:
while len(curr_results) < min(k, total_results):
if bf_pointer < len(curr_bf_result) and hnsw_pointer < len(
curr_hnsw_result
):
bf_dist = curr_bf_result[bf_pointer]["distance"]
hnsw_dist = curr_hnsw_result[hnsw_pointer]["distance"]
if bf_dist <= hnsw_dist:
curr_results.append(curr_bf_result[bf_pointer])
bf_pointer += 1
else:
id = curr_hnsw_result[hnsw_pointer]["id"]
# Only add the hnsw result if it is not in the brute force index
if not self._brute_force_index.has_id(id):
curr_results.append(curr_hnsw_result[hnsw_pointer])
hnsw_pointer += 1
else:
break
remaining = min(k, total_results) - len(curr_results)
if remaining > 0 and hnsw_pointer < len(curr_hnsw_result):
                        for j in range(
                            hnsw_pointer,
                            min(len(curr_hnsw_result), hnsw_pointer + remaining),
                        ):
                            id = curr_hnsw_result[j]["id"]
                            if not self._brute_force_index.has_id(id):
                                curr_results.append(curr_hnsw_result[j])
elif remaining > 0 and bf_pointer < len(curr_bf_result):
curr_results.extend(
curr_bf_result[bf_pointer : bf_pointer + remaining]
)
results.append(curr_results)
return results
@trace_method(
"PersistentLocalHnswSegment.reset_state", OpenTelemetryGranularity.ALL
)
@override
def reset_state(self) -> None:
if self._allow_reset:
data_path = self._get_storage_folder()
if os.path.exists(data_path):
self.close_persistent_index()
shutil.rmtree(data_path, ignore_errors=True)
@trace_method("PersistentLocalHnswSegment.delete", OpenTelemetryGranularity.ALL)
@override
def delete(self) -> None:
data_path = self._get_storage_folder()
if os.path.exists(data_path):
self.close_persistent_index()
shutil.rmtree(data_path, ignore_errors=False)
@staticmethod
def get_file_handle_count() -> int:
"""Return how many file handles are used by the index"""
hnswlib_count = hnswlib.Index.file_handle_count
hnswlib_count = cast(int, hnswlib_count)
# One extra for the metadata file
return hnswlib_count + 1 # type: ignore
def open_persistent_index(self) -> None:
"""Open the persistent index"""
if self._index is not None:
self._index.open_file_handles()
@override
def stop(self) -> None:
super().stop()
self.close_persistent_index()
def close_persistent_index(self) -> None:
"""Close the persistent index"""
if self._index is not None:
self._index.close_file_handles()
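
# A self-contained sketch (hypothetical helper, not part of the segment API) of
# the two-pointer merge performed in query_vectors() above: both inputs are
# sorted by ascending distance, brute-force results win ties, and ids already
# present on the brute-force side are skipped when they reappear from hnsw.
def _merge_results_sketch(bf, hnsw, k):
    bf_ids = {r["id"] for r in bf}
    merged, i, j = [], 0, 0
    while len(merged) < k and (i < len(bf) or j < len(hnsw)):
        take_bf = j >= len(hnsw) or (
            i < len(bf) and bf[i]["distance"] <= hnsw[j]["distance"]
        )
        if take_bf:
            merged.append(bf[i])
            i += 1
        else:
            if hnsw[j]["id"] not in bf_ids:
                merged.append(hnsw[j])
            j += 1
    return merged


if __name__ == "__main__":
    bf = [{"id": "a", "distance": 0.1}, {"id": "c", "distance": 0.4}]
    hnsw = [{"id": "a", "distance": 0.2}, {"id": "b", "distance": 0.3}]
    print([r["id"] for r in _merge_results_sketch(bf, hnsw, k=3)])  # ['a', 'b', 'c']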