增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,5 +1,5 @@
import pytest
from typing import List, Any, Callable
from typing import List, Any, Callable, Dict
from jsonschema import ValidationError
from unittest.mock import MagicMock, create_autospec
from chromadb.utils.embedding_functions.schemas import (
@@ -7,7 +7,10 @@ from chromadb.utils.embedding_functions.schemas import (
load_schema,
get_available_schemas,
)
from chromadb.utils.embedding_functions import known_embedding_functions
from chromadb.utils.embedding_functions import (
known_embedding_functions,
sparse_known_embedding_functions,
)
from chromadb.api.types import Documents, Embeddings
from pytest import MonkeyPatch
@@ -143,3 +146,306 @@ class TestEmbeddingFunctionSchemas:
if schema.get("additionalProperties", True) is False:
with pytest.raises(ValidationError):
validate_config_schema(test_config, schema_name)
def _create_valid_config_from_schema(
self, schema: Dict[str, Any]
) -> Dict[str, Any]:
"""Create a valid config from a schema by filling in required fields"""
config: Dict[str, Any] = {}
if "required" in schema and "properties" in schema:
for field in schema["required"]:
if field in schema["properties"]:
field_schema = schema["properties"][field]
config[field] = self._get_value_from_field_schema(field_schema)
return config
def _get_value_from_field_schema(self, field_schema: Dict[str, Any]) -> Any:
"""Get a valid value from a field schema"""
# Handle enums - use first enum value
if "enum" in field_schema:
return field_schema["enum"][0]
# Handle type (could be a list or single value)
field_type = field_schema.get("type")
if field_type is None:
return "dummy" # Fallback if no type specified
if isinstance(field_type, list):
# If null is in the type list, prefer non-null type
non_null_types = [t for t in field_type if t != "null"]
field_type = non_null_types[0] if non_null_types else field_type[0]
if field_type == "object":
# Handle nested objects
nested_config = {}
if "properties" in field_schema:
nested_required = field_schema.get("required", [])
for prop in nested_required:
if prop in field_schema["properties"]:
nested_config[prop] = self._get_value_from_field_schema(
field_schema["properties"][prop]
)
return nested_config if nested_config else {}
if field_type == "array":
# Return empty array for arrays
return []
# Use the existing dummy value method for primitive types
return self._get_dummy_value(field_type)
def _has_custom_validation(self, ef_class: Any) -> bool:
"""Check if validate_config actually validates (not just base implementation)"""
try:
# Try with an obviously invalid config - if it doesn't raise, it's base implementation
invalid_config = {"__invalid_test_config__": True}
try:
ef_class.validate_config(invalid_config)
# If we get here without exception, it's using base implementation
return False
except (ValidationError, ValueError, FileNotFoundError):
# If it raises any validation-related error, it's actually validating
return True
except Exception:
# Any other exception means it's trying to validate (e.g., schema not found)
return True
def _setup_env_vars_for_ef(
self, ef_name: str, mock_common_deps: MonkeyPatch
) -> None:
"""Set up environment variables needed for embedding function instantiation"""
# Map of embedding function names to their default API key environment variable names
api_key_env_vars = {
"cohere": "CHROMA_COHERE_API_KEY",
"openai": "CHROMA_OPENAI_API_KEY",
"huggingface": "CHROMA_HUGGINGFACE_API_KEY",
"huggingface_server": "CHROMA_HUGGINGFACE_API_KEY",
"google_palm": "CHROMA_GOOGLE_PALM_API_KEY",
"google_generative_ai": "CHROMA_GOOGLE_GENAI_API_KEY",
"google_vertex": "CHROMA_GOOGLE_VERTEX_API_KEY",
"jina": "CHROMA_JINA_API_KEY",
"mistral": "MISTRAL_API_KEY",
"morph": "MORPH_API_KEY",
"voyageai": "CHROMA_VOYAGE_API_KEY",
"cloudflare_workers_ai": "CHROMA_CLOUDFLARE_API_KEY",
"together_ai": "CHROMA_TOGETHER_AI_API_KEY",
"baseten": "CHROMA_BASETEN_API_KEY",
"roboflow": "CHROMA_ROBOFLOW_API_KEY",
"amazon_bedrock": "AWS_ACCESS_KEY_ID", # AWS uses different env vars
"chroma-cloud-qwen": "CHROMA_API_KEY",
# Sparse embedding functions
"chroma-cloud-splade": "CHROMA_API_KEY",
}
# Set API key environment variable if needed
if ef_name in api_key_env_vars:
mock_common_deps.setenv(api_key_env_vars[ef_name], "test-api-key")
# Special cases that need additional environment variables
if ef_name == "amazon_bedrock":
mock_common_deps.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
mock_common_deps.setenv("AWS_REGION", "us-east-1")
def _create_ef_instance(
self, ef_name: str, ef_class: Any, mock_common_deps: MonkeyPatch
) -> Any:
"""Create an embedding function instance, handling special cases"""
# Set up environment variables first
self._setup_env_vars_for_ef(ef_name, mock_common_deps)
# Mock missing modules that are imported inside __init__ methods
import sys
# Create mock modules
mock_pil = MagicMock()
mock_pil_image = MagicMock()
mock_google_genai = MagicMock()
mock_vertexai = MagicMock()
mock_vertexai_lm = MagicMock()
mock_boto3 = MagicMock()
mock_jina = MagicMock()
mock_mistralai = MagicMock()
# Mock boto3.Session for amazon_bedrock
mock_boto3_session = MagicMock()
mock_session_instance = MagicMock()
mock_session_instance.region_name = "us-east-1"
mock_session_instance.profile_name = None
mock_session_instance.client.return_value = MagicMock()
mock_boto3_session.return_value = mock_session_instance
mock_boto3.Session = mock_boto3_session
# Mock vertexai.init and TextEmbeddingModel
mock_text_embedding_model = MagicMock()
mock_text_embedding_model.from_pretrained.return_value = MagicMock()
mock_vertexai_lm.TextEmbeddingModel = mock_text_embedding_model
mock_vertexai.language_models = mock_vertexai_lm
mock_vertexai.init = MagicMock()
# Mock google.generativeai - need to set up google module first
mock_google = MagicMock()
mock_google_genai.configure = MagicMock() # For palm.configure()
mock_google_genai.GenerativeModel = MagicMock(return_value=MagicMock())
mock_google.generativeai = mock_google_genai
# Mock jina Client
mock_jina.Client = MagicMock()
# Mock mistralai
mock_mistral_client = MagicMock()
mock_mistral_client.return_value.embeddings.create.return_value.data = [
MagicMock(embedding=[0.1, 0.2, 0.3])
]
mock_mistralai.Mistral = mock_mistral_client
# Add missing modules to sys.modules using monkeypatch
modules_to_mock = {
"PIL": mock_pil,
"PIL.Image": mock_pil_image,
"google": mock_google,
"google.generativeai": mock_google_genai,
"vertexai": mock_vertexai,
"vertexai.language_models": mock_vertexai_lm,
"boto3": mock_boto3,
"jina": mock_jina,
"mistralai": mock_mistralai,
}
for module_name, mock_module in modules_to_mock.items():
mock_common_deps.setitem(sys.modules, module_name, mock_module)
# Special cases that need additional arguments
if ef_name == "cloudflare_workers_ai":
return ef_class(
model_name="test-model",
account_id="test-account-id",
)
elif ef_name == "baseten":
# Baseten needs api_key explicitly passed even with env var
return ef_class(
api_key="test-api-key",
api_base="https://test.api.baseten.co",
)
elif ef_name == "amazon_bedrock":
# Amazon Bedrock needs a boto3 session - create a mock session
# boto3 is already mocked in sys.modules above
mock_session = mock_boto3.Session(region_name="us-east-1")
return ef_class(
session=mock_session,
model_name="amazon.titan-embed-text-v1",
)
elif ef_name == "huggingface_server":
return ef_class(url="http://localhost:8080")
elif ef_name == "google_vertex":
return ef_class(project_id="test-project", region="us-central1")
elif ef_name == "mistral":
return ef_class(model="mistral-embed")
elif ef_name == "roboflow":
return ef_class() # No model_name needed
elif ef_name == "chroma-cloud-qwen":
from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
ChromaCloudQwenEmbeddingModel,
)
return ef_class(
model=ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
task="nl_to_code",
)
else:
# Try with no args first
try:
return ef_class()
except Exception:
# If that fails, try with common minimal args
return ef_class(model_name="test-model")
@pytest.mark.parametrize("ef_name", get_embedding_function_names())
def test_validate_config_with_schema(
self,
ef_name: str,
mock_embeddings: Callable[[Documents], Embeddings],
mock_common_deps: MonkeyPatch,
) -> None:
"""Test that validate_config works correctly with actual configs from embedding functions"""
ef_class = known_embedding_functions[ef_name]
# Skip if the embedding function doesn't have a validate_config method
if not hasattr(ef_class, "validate_config"):
pytest.skip(f"{ef_name} does not have validate_config method")
# Check if it's callable (static methods are callable on the class)
if not callable(getattr(ef_class, "validate_config", None)):
pytest.skip(f"{ef_name} validate_config is not callable")
# Skip if using base implementation (doesn't actually validate)
if not self._has_custom_validation(ef_class):
pytest.skip(
f"{ef_name} uses base validate_config implementation (no validation)"
)
# Create a real instance to get the actual config
# We'll mock __call__ to avoid needing to actually generate embeddings
try:
ef_instance = self._create_ef_instance(ef_name, ef_class, mock_common_deps)
except Exception as e:
pytest.skip(
f"{ef_name} requires arguments that we cannot provide without external deps: {e}"
)
# Mock only __call__ to avoid needing to actually generate embeddings
mock_call = MagicMock(return_value=mock_embeddings(["test"]))
mock_common_deps.setattr(ef_instance, "__call__", mock_call)
# Get the actual config from the embedding function (this uses the real get_config method)
config = ef_instance.get_config()
# Filter out None values - optional fields with None shouldn't be included in validation
# This matches common JSON schema practice where optional fields are omitted rather than null
config = {k: v for k, v in config.items() if v is not None}
# Validate the actual config using the embedding function's validate_config method
ef_class.validate_config(config)
def test_validate_config_sparse_embedding_functions(
self,
mock_embeddings: Callable[[Documents], Embeddings],
mock_common_deps: MonkeyPatch,
) -> None:
"""Test validate_config for sparse embedding functions with actual configs"""
for ef_name, ef_class in sparse_known_embedding_functions.items():
# Skip if the embedding function doesn't have a validate_config method
if not hasattr(ef_class, "validate_config"):
continue
# Check if it's callable (static methods are callable on the class)
if not callable(getattr(ef_class, "validate_config", None)):
continue
# Skip if using base implementation (doesn't actually validate)
if not self._has_custom_validation(ef_class):
continue
# Create a real instance to get the actual config
# We'll mock __call__ to avoid needing to actually generate embeddings
try:
ef_instance = self._create_ef_instance(
ef_name, ef_class, mock_common_deps
)
except Exception:
continue # Skip if we can't create instance
# Mock only __call__ to avoid needing to actually generate embeddings
mock_call = MagicMock(return_value=mock_embeddings(["test"]))
mock_common_deps.setattr(ef_instance, "__call__", mock_call)
# Get the actual config from the embedding function (this uses the real get_config method)
config = ef_instance.get_config()
# Filter out None values - optional fields with None shouldn't be included in validation
# This matches common JSON schema practice where optional fields are omitted rather than null
config = {k: v for k, v in config.items() if v is not None}
# Validate the actual config using the embedding function's validate_config method
ef_class.validate_config(config)