chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
# This file is used by test_create_http_client.py to test the initialization
|
||||
# of an HttpClient class with auth settings.
|
||||
#
|
||||
# See https://github.com/chroma-core/chroma/issues/1554
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
import sys
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
chromadb.HttpClient(
|
||||
host="localhost",
|
||||
port=8000,
|
||||
settings=Settings(
|
||||
chroma_client_auth_provider="chromadb.auth.basic_authn.BasicAuthClientProvider",
|
||||
chroma_client_auth_credentials="admin:testDb@home2",
|
||||
),
|
||||
)
|
||||
except ValueError:
|
||||
# We don't expect to be able to connect to Chroma. We just want to make sure
|
||||
# there isn't an ImportError.
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,334 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from chromadb import CloudClient
|
||||
from chromadb.errors import ChromaAuthError, NotFoundError
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.types import Tenant, Database
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
def test_valid_key() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database, patch(
|
||||
"chromadb.api.fastapi.FastAPI.heartbeat"
|
||||
) as mock_heartbeat:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="default_tenant", databases=["testdb"]
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="default_tenant")
|
||||
mock_get_database.return_value = Database(
|
||||
id=uuid4(), name="testdb", tenant="default_tenant"
|
||||
)
|
||||
mock_heartbeat.return_value = 1234567890
|
||||
|
||||
client = CloudClient(database="testdb", api_key="valid_token")
|
||||
|
||||
assert client.get_user_identity().user_id == "test_user"
|
||||
assert client.get_user_identity().tenant == "default_tenant"
|
||||
assert client.get_user_identity().databases == ["testdb"]
|
||||
|
||||
settings = client.get_settings()
|
||||
assert settings.chroma_client_auth_credentials == "valid_token"
|
||||
assert (
|
||||
settings.chroma_client_auth_provider
|
||||
== "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
|
||||
assert client.heartbeat() == 1234567890
|
||||
|
||||
|
||||
def test_invalid_key() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.side_effect = ChromaAuthError("Authentication failed")
|
||||
|
||||
with pytest.raises(ChromaAuthError):
|
||||
CloudClient(database="testdb", api_key="invalid_token")
|
||||
|
||||
|
||||
# Scoped API key to 1 database tests
|
||||
def test_scoped_api_key_to_single_db_with_api_key_only() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database:
|
||||
# mock single db scoped api key
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["right-db"]
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="123-456-789")
|
||||
mock_get_database.return_value = Database(
|
||||
id=uuid4(), name="right-db", tenant="123-456-789"
|
||||
)
|
||||
|
||||
client = CloudClient(api_key="valid_token")
|
||||
|
||||
# should resolve to single db
|
||||
assert client.database == "right-db"
|
||||
assert client.tenant == "123-456-789"
|
||||
|
||||
|
||||
def test_scoped_api_key_to_single_db_with_correct_tenant() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["right-db"]
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="123-456-789")
|
||||
mock_get_database.return_value = Database(
|
||||
id=uuid4(), name="right-db", tenant="123-456-789"
|
||||
)
|
||||
|
||||
client = CloudClient(tenant="123-456-789", api_key="valid_token")
|
||||
|
||||
assert client.tenant == "123-456-789"
|
||||
assert client.database == "right-db"
|
||||
|
||||
|
||||
def test_scoped_api_key_to_single_db_with_correct_db() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["right-db"]
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="123-456-789")
|
||||
mock_get_database.return_value = Database(
|
||||
id=uuid4(), name="right-db", tenant="123-456-789"
|
||||
)
|
||||
|
||||
client = CloudClient(database="right-db", api_key="valid_token")
|
||||
|
||||
assert client.tenant == "123-456-789"
|
||||
assert client.database == "right-db"
|
||||
|
||||
|
||||
def test_scoped_api_key_to_single_db_with_correct_tenant_and_db() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["right-db"]
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="123-456-789")
|
||||
mock_get_database.return_value = Database(
|
||||
id=uuid4(), name="right-db", tenant="123-456-789"
|
||||
)
|
||||
|
||||
client = CloudClient(
|
||||
tenant="123-456-789", database="right-db", api_key="valid_token"
|
||||
)
|
||||
|
||||
assert client.tenant == "123-456-789"
|
||||
assert client.database == "right-db"
|
||||
|
||||
|
||||
def test_scoped_api_key_to_single_db_with_wrong_tenant() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["right-db"]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Tenant wrong-tenant does not match 123-456-789 from the server. Are you sure the tenant is correct?",
|
||||
):
|
||||
CloudClient(tenant="wrong-tenant", api_key="valid_token")
|
||||
|
||||
|
||||
def test_scoped_api_key_to_single_db_with_wrong_database() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["right-db"]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Database wrong-db does not match right-db from the server. Are you sure the database is correct?",
|
||||
):
|
||||
CloudClient(database="wrong-db", api_key="valid_token")
|
||||
|
||||
|
||||
def test_scoped_api_key_to_single_db_with_wrong_api_key() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.side_effect = ChromaAuthError("Permission denied.")
|
||||
|
||||
with pytest.raises(ChromaAuthError, match="Permission denied."):
|
||||
CloudClient(database="right-db", api_key="wrong-api-key")
|
||||
|
||||
|
||||
# Scoped API key to multiple databases tests
|
||||
def test_scoped_api_key_to_multiple_dbs_with_wrong_tenant() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user",
|
||||
tenant="123-456-789",
|
||||
databases=["right-db", "another-db"],
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Tenant wrong-tenant does not match 123-456-789 from the server. Are you sure the tenant is correct?",
|
||||
):
|
||||
CloudClient(
|
||||
tenant="wrong-tenant", database="right-db", api_key="valid_token"
|
||||
)
|
||||
|
||||
|
||||
def test_scoped_api_key_to_multiple_dbs_with_correct_tenant_and_db() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user",
|
||||
tenant="123-456-789",
|
||||
databases=["right-db", "another-db"],
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="123-456-789")
|
||||
mock_get_database.return_value = Database(
|
||||
id=uuid4(), name="right-db", tenant="123-456-789"
|
||||
)
|
||||
|
||||
client = CloudClient(
|
||||
tenant="123-456-789", database="right-db", api_key="valid_token"
|
||||
)
|
||||
|
||||
assert client.tenant == "123-456-789"
|
||||
assert client.database == "right-db"
|
||||
|
||||
|
||||
def test_scoped_api_key_to_multiple_dbs_with_nonexistent_database() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity, patch(
|
||||
"chromadb.api.client.AdminClient.get_tenant"
|
||||
) as mock_get_tenant, patch(
|
||||
"chromadb.api.client.AdminClient.get_database"
|
||||
) as mock_get_database:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user",
|
||||
tenant="123-456-789",
|
||||
databases=["right-db", "another-db"],
|
||||
)
|
||||
mock_get_tenant.return_value = Tenant(name="123-456-789")
|
||||
mock_get_database.side_effect = NotFoundError(
|
||||
"Database [wrong-db] not found. Are you sure it exists?"
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
NotFoundError,
|
||||
match="Database \\[wrong-db\\] not found. Are you sure it exists?",
|
||||
):
|
||||
CloudClient(database="wrong-db", api_key="valid_token")
|
||||
|
||||
|
||||
def test_scoped_api_key_to_multiple_dbs_with_api_key_only() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user",
|
||||
tenant="123-456-789",
|
||||
databases=["right-db", "another-db"],
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Could not determine a database name from the current authentication method. Please provide a database name.",
|
||||
):
|
||||
CloudClient(api_key="valid_token")
|
||||
|
||||
|
||||
# Unscoped API key tests
|
||||
def test_api_key_with_unscoped_tenant() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="*", databases=["right-db"]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Could not determine a tenant from the current authentication method. Please provide a tenant.",
|
||||
):
|
||||
CloudClient(api_key="valid_token")
|
||||
|
||||
|
||||
def test_api_key_with_unscoped_db() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=["*"]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Could not determine a database name from the current authentication method. Please provide a database name.",
|
||||
):
|
||||
CloudClient(api_key="valid_token")
|
||||
|
||||
|
||||
def test_api_key_with_no_db_access() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant="123-456-789", databases=[]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Could not determine a database name from the current authentication method. Please provide a database name.",
|
||||
):
|
||||
CloudClient(api_key="valid_token")
|
||||
|
||||
|
||||
def test_api_key_with_no_tenant_access() -> None:
|
||||
with patch(
|
||||
"chromadb.api.fastapi.FastAPI.get_user_identity"
|
||||
) as mock_get_user_identity:
|
||||
mock_get_user_identity.return_value = UserIdentity(
|
||||
user_id="test_user", tenant=None, databases=["right-db"]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChromaAuthError,
|
||||
match="Could not determine a tenant from the current authentication method. Please provide a tenant.",
|
||||
):
|
||||
CloudClient(api_key="valid_token")
|
||||
@@ -0,0 +1,15 @@
|
||||
import subprocess
|
||||
|
||||
# Needs to be a module, not a file, so that local imports work.
|
||||
TEST_MODULE = "chromadb.test.client.create_http_client_with_basic_auth"
|
||||
|
||||
|
||||
def test_main() -> None:
|
||||
# This is the only way to test what we want to test: pytest does a bunch of
|
||||
# importing and other module stuff in the background, so we need a clean
|
||||
# python process to make sure we're not circular-importing.
|
||||
#
|
||||
# See https://github.com/chroma-core/chroma/issues/1554
|
||||
|
||||
res = subprocess.run(["python", "-m", TEST_MODULE])
|
||||
assert res.returncode == 0
|
||||
@@ -0,0 +1,184 @@
|
||||
import pytest
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
|
||||
from chromadb.test.conftest import ClientFactories
|
||||
from chromadb.errors import InvalidArgumentError
|
||||
from chromadb.api.types import GetResult
|
||||
from typing import Dict, Any
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_database_tenant_collections(client_factories: ClientFactories) -> None:
|
||||
client = client_factories.create_client_from_system()
|
||||
client.reset()
|
||||
# Create a new database in the default tenant
|
||||
admin_client = client_factories.create_admin_client_from_system()
|
||||
admin_client.create_database("test_db")
|
||||
|
||||
# Create collections in this new database
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
|
||||
client.create_collection("collection", metadata={"database": "test_db"})
|
||||
|
||||
# Create collections in the default database
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
|
||||
client.create_collection("collection", metadata={"database": DEFAULT_DATABASE})
|
||||
|
||||
# List collections in the default database
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 1
|
||||
assert collections[0].name == "collection"
|
||||
collection = client.get_collection(collections[0].name)
|
||||
assert collection.metadata == {"database": DEFAULT_DATABASE}
|
||||
|
||||
# List collections in the new database
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 1
|
||||
assert collections[0].metadata == {"database": "test_db"}
|
||||
|
||||
# Update the metadata in both databases to different values
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
|
||||
client.list_collections()[0].modify(metadata={"database": "default2"})
|
||||
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
|
||||
client.list_collections()[0].modify(metadata={"database": "test_db2"})
|
||||
|
||||
# Validate that the metadata was updated
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 1
|
||||
assert collections[0].metadata == {"database": "default2"}
|
||||
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 1
|
||||
assert collections[0].metadata == {"database": "test_db2"}
|
||||
|
||||
# Delete the collections and make sure databases are isolated
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
|
||||
client.delete_collection("collection")
|
||||
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 0
|
||||
|
||||
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 1
|
||||
|
||||
client.delete_collection("collection")
|
||||
collections = client.list_collections()
|
||||
assert len(collections) == 0
|
||||
|
||||
|
||||
def test_database_collections_add(client_factories: ClientFactories) -> None:
|
||||
client = client_factories.create_client_from_system()
|
||||
client.reset()
|
||||
|
||||
# Create a new database in the default tenant
|
||||
admin_client = client_factories.create_admin_client_from_system()
|
||||
admin_client.create_database("test_db")
|
||||
|
||||
# Create collections in this new database
|
||||
client.set_database(database="test_db")
|
||||
coll_new = client.create_collection("collection_new")
|
||||
|
||||
# Create collections in the default database
|
||||
client.set_database(database=DEFAULT_DATABASE)
|
||||
coll_default = client.create_collection("collection_default")
|
||||
|
||||
records_new = {
|
||||
"ids": ["a", "b", "c"],
|
||||
"embeddings": [[1.0, 2.0, 3.0] for _ in range(3)],
|
||||
"documents": ["a", "b", "c"],
|
||||
}
|
||||
|
||||
records_default = {
|
||||
"ids": ["c", "d", "e"],
|
||||
"embeddings": [[4.0, 5.0, 6.0] for _ in range(3)],
|
||||
"documents": ["c", "d", "e"],
|
||||
}
|
||||
|
||||
# Add to the new coll
|
||||
coll_new.add(**records_new) # type: ignore
|
||||
|
||||
# Add to the default coll
|
||||
coll_default.add(**records_default) # type: ignore
|
||||
|
||||
# Make sure the collections are isolated
|
||||
res = coll_new.get(include=["embeddings", "documents"]) # type: ignore
|
||||
assert res["ids"] == records_new["ids"]
|
||||
check_embeddings(res=res, records=records_new)
|
||||
assert res["documents"] == records_new["documents"]
|
||||
|
||||
res = coll_default.get(include=["embeddings", "documents"]) # type: ignore
|
||||
assert res["ids"] == records_default["ids"]
|
||||
check_embeddings(res=res, records=records_default)
|
||||
assert res["documents"] == records_default["documents"]
|
||||
|
||||
|
||||
def test_tenant_collections_add(client_factories: ClientFactories) -> None:
|
||||
client = client_factories.create_client_from_system()
|
||||
client.reset()
|
||||
|
||||
# Create two databases with same name in different tenants
|
||||
admin_client = client_factories.create_admin_client_from_system()
|
||||
admin_client.create_tenant("test_tenant1")
|
||||
admin_client.create_tenant("test_tenant2")
|
||||
admin_client.create_database("test_db", tenant="test_tenant1")
|
||||
admin_client.create_database("test_db", tenant="test_tenant2")
|
||||
|
||||
# Create collections in each database with same name
|
||||
client.set_tenant(tenant="test_tenant1", database="test_db")
|
||||
coll_tenant1 = client.create_collection("collection")
|
||||
client.set_tenant(tenant="test_tenant2", database="test_db")
|
||||
coll_tenant2 = client.create_collection("collection")
|
||||
|
||||
records_tenant1 = {
|
||||
"ids": ["a", "b", "c"],
|
||||
"embeddings": [[1.0, 2.0, 3.0] for _ in range(3)],
|
||||
"documents": ["a", "b", "c"],
|
||||
}
|
||||
|
||||
records_tenant2 = {
|
||||
"ids": ["c", "d", "e"],
|
||||
"embeddings": [[4.0, 5.0, 6.0] for _ in range(3)],
|
||||
"documents": ["c", "d", "e"],
|
||||
}
|
||||
|
||||
# Add to the tenant1 coll
|
||||
coll_tenant1.add(**records_tenant1) # type: ignore
|
||||
|
||||
# Add to the tenant2 coll
|
||||
coll_tenant2.add(**records_tenant2) # type: ignore
|
||||
|
||||
# Make sure the collections are isolated
|
||||
res = coll_tenant1.get(include=["embeddings", "documents"]) # type: ignore
|
||||
assert res["ids"] == records_tenant1["ids"]
|
||||
check_embeddings(res=res, records=records_tenant1)
|
||||
assert res["documents"] == records_tenant1["documents"]
|
||||
|
||||
res = coll_tenant2.get(include=["embeddings", "documents"]) # type: ignore
|
||||
assert res["ids"] == records_tenant2["ids"]
|
||||
check_embeddings(res=res, records=records_tenant2)
|
||||
assert res["documents"] == records_tenant2["documents"]
|
||||
|
||||
|
||||
def test_min_len_name(client_factories: ClientFactories) -> None:
|
||||
client = client_factories.create_client_from_system()
|
||||
client.reset()
|
||||
|
||||
# Create a new database in the default tenant with a name of length 1
|
||||
# and expect an error
|
||||
admin_client = client_factories.create_admin_client_from_system()
|
||||
with pytest.raises((Exception, InvalidArgumentError)):
|
||||
admin_client.create_database("a")
|
||||
|
||||
# Create a tenant with a name of length 1 and expect an error
|
||||
with pytest.raises((Exception, InvalidArgumentError)):
|
||||
admin_client.create_tenant("a")
|
||||
|
||||
|
||||
def check_embeddings(res: GetResult, records: Dict[str, Any]) -> None:
|
||||
if res["embeddings"] is not None:
|
||||
assert np.array_equal(res["embeddings"], records["embeddings"])
|
||||
else:
|
||||
assert records["embeddings"] is None
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import Dict
|
||||
from fastapi import HTTPException
|
||||
from overrides import override
|
||||
from chromadb.auth import (
|
||||
AuthzAction,
|
||||
AuthzResource,
|
||||
ServerAuthenticationProvider,
|
||||
ServerAuthorizationProvider,
|
||||
UserIdentity,
|
||||
)
|
||||
from chromadb.config import System
|
||||
|
||||
|
||||
class ExampleAuthenticationProvider(ServerAuthenticationProvider):
|
||||
"""In practice the tenant would likely be resolved from some other opaque value (e.g. key/token). Here, it's just passed directly as a header for simplicity."""
|
||||
|
||||
@override
|
||||
def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
|
||||
return UserIdentity(
|
||||
user_id="test",
|
||||
tenant=headers.get("x-tenant", None),
|
||||
)
|
||||
|
||||
|
||||
class ExampleAuthorizationProvider(ServerAuthorizationProvider):
|
||||
"""A simple authz provider that asserts the user's tenant matches the resource's tenant."""
|
||||
|
||||
def __init__(self, system: System) -> None:
|
||||
super().__init__(system)
|
||||
self._settings = system.settings
|
||||
|
||||
@override
|
||||
def authorize_or_raise(
|
||||
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
|
||||
) -> None:
|
||||
if user.tenant is None:
|
||||
return
|
||||
|
||||
if action == AuthzAction.RESET:
|
||||
return
|
||||
|
||||
if user.tenant != resource.tenant:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized")
|
||||
@@ -0,0 +1,49 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from chromadb.config import DEFAULT_TENANT
|
||||
from chromadb.test.conftest import ClientFactories
|
||||
|
||||
|
||||
def test_multiple_clients_concurrently(client_factories: ClientFactories) -> None:
|
||||
"""Tests running multiple clients, each against their own database, concurrently."""
|
||||
client = client_factories.create_client()
|
||||
client.reset()
|
||||
admin_client = client_factories.create_admin_client_from_system()
|
||||
admin_client.create_database("test_db")
|
||||
|
||||
CLIENT_COUNT = 50
|
||||
COLLECTION_COUNT = 10
|
||||
|
||||
# Each database will create the same collections by name, with differing metadata
|
||||
databases = [f"db{i}" for i in range(CLIENT_COUNT)]
|
||||
for database in databases:
|
||||
admin_client.create_database(database)
|
||||
|
||||
collections = [f"collection{i}" for i in range(COLLECTION_COUNT)]
|
||||
|
||||
# Create N clients, each on a seperate thread, each with their own database
|
||||
def run_target(n: int) -> None:
|
||||
thread_client = client_factories.create_client(
|
||||
tenant=DEFAULT_TENANT,
|
||||
database=databases[n],
|
||||
settings=client._system.settings,
|
||||
)
|
||||
for collection in collections:
|
||||
thread_client.create_collection(
|
||||
collection, metadata={"database": databases[n]}
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor:
|
||||
executor.map(run_target, range(CLIENT_COUNT))
|
||||
executor.shutdown(wait=True)
|
||||
# Create a final client, which will be used to verify the collections were created
|
||||
client = client_factories.create_client(settings=client._system.settings)
|
||||
|
||||
# Verify that the collections were created
|
||||
for database in databases:
|
||||
client.set_database(database)
|
||||
seen_collections = client.list_collections()
|
||||
assert len(seen_collections) == COLLECTION_COUNT
|
||||
for collection in seen_collections:
|
||||
assert collection.name in collections
|
||||
assert collection.metadata == {"database": database}
|
||||
Reference in New Issue
Block a user