Add adaptation for the orbit reconnaissance scenario

2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -23,6 +23,7 @@ from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
GoogleGenaiEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
@@ -68,6 +69,10 @@ from chromadb.utils.embedding_functions.mistral_embedding_function import (
from chromadb.utils.embedding_functions.morph_embedding_function import (
MorphEmbeddingFunction,
)
from chromadb.utils.embedding_functions.nomic_embedding_function import (
NomicEmbeddingFunction,
NomicQueryConfig,
)
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
HuggingFaceSparseEmbeddingFunction,
)
@@ -98,11 +103,13 @@ _all_classes: Set[str] = {
"GooglePalmEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"GoogleGenaiEmbeddingFunction",
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"NomicEmbeddingFunction",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OpenCLIPEmbeddingFunction",
@@ -137,11 +144,13 @@ known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignor
"google_palm": GooglePalmEmbeddingFunction,
"google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
"google_vertex": GoogleVertexEmbeddingFunction,
"google_genai": GoogleGenaiEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"mistral": MistralEmbeddingFunction,
"morph": MorphEmbeddingFunction,
"nomic": NomicEmbeddingFunction,
"voyageai": VoyageAIEmbeddingFunction,
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
"open_clip": OpenCLIPEmbeddingFunction,
@@ -259,12 +268,15 @@ __all__ = [
"GooglePalmEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
"GoogleVertexEmbeddingFunction",
"GoogleGenaiEmbeddingFunction",
"OllamaEmbeddingFunction",
"InstructorEmbeddingFunction",
"JinaEmbeddingFunction",
"JinaQueryConfig",
"MistralEmbeddingFunction",
"MorphEmbeddingFunction",
"NomicEmbeddingFunction",
"NomicQueryConfig",
"VoyageAIEmbeddingFunction",
"ONNXMiniLM_L6_V2",
"OpenCLIPEmbeddingFunction",

View File

@@ -2,6 +2,7 @@ import os
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import Dict, Any, Optional, List
from chromadb.api.types import Space
import warnings
@@ -35,12 +36,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
if os.getenv("BASETEN_API_KEY") is not None:
self.api_key_env_var = "BASETEN_API_KEY"
else:
self.api_key_env_var = api_key_env_var
# Prioritize api_key argument, then environment variable
resolved_api_key = api_key or os.getenv(api_key_env_var)
resolved_api_key = api_key or os.getenv(self.api_key_env_var)
if not resolved_api_key:
raise ValueError(
f"API key not provided and {api_key_env_var} environment variable is not set."
f"API key not provided and {self.api_key_env_var} environment variable is not set."
)
self.api_key = resolved_api_key
if not api_base:
@@ -96,3 +101,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
api_base=api_base,
api_key_env_var=api_key_env_var,
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "baseten")

View File

@@ -231,4 +231,4 @@ class Bm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "bm25")
validate_config_schema(config, "bm25")

View File

@@ -23,12 +23,32 @@ DEFAULT_TOKEN_MAX_LENGTH = 40
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
class _HashedToken:
__slots__ = ("hash", "label")
def __init__(self, hash: int, label: Optional[str]):
self.hash = hash
self.label = label
def __hash__(self) -> int:
return self.hash
def __eq__(self, other: object) -> bool:
if not isinstance(other, _HashedToken):
return NotImplemented
return self.hash == other.hash
def __lt__(self, other: "_HashedToken") -> bool:
return self.hash < other.hash
class ChromaBm25Config(TypedDict, total=False):
k: float
b: float
avg_doc_length: float
token_max_length: int
stopwords: List[str]
include_tokens: bool
class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@@ -39,6 +59,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
stopwords: Optional[Iterable[str]] = None,
include_tokens: bool = False,
) -> None:
"""Initialize the BM25 sparse embedding function."""
@@ -46,38 +67,51 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
self.b = float(b)
self.avg_doc_length = float(avg_doc_length)
self.token_max_length = int(token_max_length)
self.include_tokens = bool(include_tokens)
if stopwords is not None:
self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
stopword_list: Iterable[str] = self.stopwords
self._stopword_list: Iterable[str] = self.stopwords
else:
self.stopwords = None
stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
self._stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
stemmer = get_english_stemmer()
self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
self._hasher = Murmur3AbsHasher()
def _encode(self, text: str) -> SparseVector:
tokens = self._tokenizer.tokenize(text)
stemmer = get_english_stemmer()
tokenizer = Bm25Tokenizer(stemmer, self._stopword_list, self.token_max_length)
tokens = tokenizer.tokenize(text)
if not tokens:
return SparseVector(indices=[], values=[])
doc_len = float(len(tokens))
counts = Counter(self._hasher.hash(token) for token in tokens)
counts = Counter(
_HashedToken(
self._hasher.hash(token), token if self.include_tokens else None
)
for token in tokens
)
indices = sorted(counts.keys())
sorted_keys = sorted(counts.keys())
indices: List[int] = []
values: List[float] = []
for idx in indices:
tf = float(counts[idx])
labels: Optional[List[str]] = [] if self.include_tokens else None
for key in sorted_keys:
tf = float(counts[key])
denominator = tf + self.k * (
1 - self.b + (self.b * doc_len) / self.avg_doc_length
)
score = tf * (self.k + 1) / denominator
values.append(score)
return SparseVector(indices=indices, values=values)
indices.append(key.hash)
values.append(score)
if labels is not None and key.label is not None:
labels.append(key.label)
return SparseVector(indices=indices, values=values, labels=labels)
def __call__(self, input: Documents) -> SparseVectors:
sparse_vectors: SparseVectors = []
@@ -99,7 +133,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@staticmethod
def build_from_config(
config: Dict[str, Any]
config: Dict[str, Any],
) -> "SparseEmbeddingFunction[Documents]":
return ChromaBm25EmbeddingFunction(
k=config.get("k", DEFAULT_K),
@@ -107,6 +141,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
stopwords=config.get("stopwords"),
include_tokens=config.get("include_tokens", False),
)
def get_config(self) -> Dict[str, Any]:
@@ -115,6 +150,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
"b": self.b,
"avg_doc_length": self.avg_doc_length,
"token_max_length": self.token_max_length,
"include_tokens": self.include_tokens,
}
if self.stopwords is not None:
@@ -125,7 +161,14 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
mutable_keys = {
"k",
"b",
"avg_doc_length",
"token_max_length",
"stopwords",
"include_tokens",
}
for key in new_config:
if key not in mutable_keys:
raise ValueError(f"Updating '{key}' is not supported for {NAME}")

View File

@@ -1,8 +1,9 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any, Union
from typing import List, Dict, Any, Union, Optional
import os
import numpy as np
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
from enum import Enum
@@ -32,7 +33,7 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model: ChromaCloudQwenEmbeddingModel,
task: str,
task: Optional[str],
instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
api_key_env_var: str = "CHROMA_API_KEY",
):
@@ -41,7 +42,8 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
Args:
model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
task (str): The task for which embeddings are being generated.
task (str, optional): The task for which embeddings are being generated. If None or empty,
empty instructions will be used for both documents and queries.
instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
api_key_env_var (str, optional): Environment variable name that contains your API key.
@@ -55,9 +57,18 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
)
self.api_key_env_var = api_key_env_var
# First, try to get API key from environment variable
self.api_key = os.getenv(api_key_env_var)
# If not found in env var, try to get it from existing client instances
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
SharedSystemClient = _get_shared_system_client()
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
# Raise error if still no API key found
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {api_key_env_var} "
f"or in any existing client instances"
)
self.model = model
self.task = task
@@ -102,10 +113,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
instruction = ""
if self.task and self.task in self.instructions:
instruction = self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.DOCUMENTS
],
]
payload: Dict[str, Union[str, Documents]] = {
"instructions": instruction,
"texts": input,
}
@@ -120,10 +135,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
if not input:
return []
payload: Dict[str, Union[str, Documents]] = {
"instructions": self.instructions[self.task][
instruction = ""
if self.task and self.task in self.instructions:
instruction = self.instructions[self.task][
ChromaCloudQwenEmbeddingTarget.QUERY
],
]
payload: Dict[str, Union[str, Documents]] = {
"instructions": instruction,
"texts": input,
}
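
Per this change, `task` is now optional; a construction sketch under stated assumptions follows. The model is picked generically because the `ChromaCloudQwenEmbeddingModel` members are not listed in this view, and the constructor still needs `httpx` installed plus `CHROMA_API_KEY` (or an already-initialized Chroma Cloud client) to resolve an API key.

```python
# Minimal sketch: task=None now sends empty instructions instead of failing
# at call time. Assumes httpx is installed, CHROMA_API_KEY is set (or a
# Chroma Cloud client already exists in the process), and that the class is
# importable from the package root like the other embedding functions.
from chromadb.utils.embedding_functions import (
    ChromaCloudQwenEmbeddingFunction,
    ChromaCloudQwenEmbeddingModel,
)

model = next(iter(ChromaCloudQwenEmbeddingModel))  # any available Qwen model
ef = ChromaCloudQwenEmbeddingFunction(model=model, task=None)
# Calling ef([...]) from here would hit the Chroma Cloud embedding endpoint.
```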

View File

@@ -1,15 +1,16 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseVector,
SparseVectors,
Documents,
)
from typing import Dict, Any
from typing import Dict, Any, List, Optional
from enum import Enum
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
from chromadb.base_types import SparseVector
import os
from typing import Union
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
class ChromaCloudSpladeEmbeddingModel(Enum):
@@ -21,6 +22,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
self,
api_key_env_var: str = "CHROMA_API_KEY",
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
include_tokens: bool = False,
):
"""
Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -36,12 +38,20 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
self.api_key_env_var = api_key_env_var
# First, try to get API key from environment variable
self.api_key = os.getenv(self.api_key_env_var)
# If not found in env var, try to get it from existing client instances
if not self.api_key:
SharedSystemClient = _get_shared_system_client()
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
# Raise error if still no API key found
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {self.api_key_env_var}"
f"API key not found in environment variable {self.api_key_env_var} "
f"or in any existing client instances"
)
self.model = model
self.include_tokens = bool(include_tokens)
self._api_url = "https://embed.trychroma.com/embed_sparse"
self._session = httpx.Client()
self._session.headers.update(
@@ -80,6 +90,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
"texts": list(input),
"task": "",
"target": "",
"fetch_tokens": "true" if self.include_tokens is True else "false",
}
try:
@@ -113,13 +124,17 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
if isinstance(emb, dict):
indices = emb.get("indices", [])
values = emb.get("values", [])
raw_labels = emb.get("labels") if self.include_tokens else None
labels: Optional[List[str]] = raw_labels if raw_labels else None
else:
# Already a SparseVector, extract its data
assert isinstance(emb, SparseVector)
indices = emb.indices
values = emb.values
labels = emb.labels if self.include_tokens else None
normalized_vectors.append(
normalize_sparse_vector(indices=indices, values=values)
normalize_sparse_vector(indices=indices, values=values, labels=labels)
)
return normalized_vectors
@@ -141,18 +156,25 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
return ChromaCloudSpladeEmbeddingFunction(
api_key_env_var=api_key_env_var,
model=ChromaCloudSpladeEmbeddingModel(model),
include_tokens=config.get("include_tokens", False),
)
def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}
return {
"api_key_env_var": self.api_key_env_var,
"model": self.model.value,
"include_tokens": self.include_tokens,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model" in new_config:
raise ValueError(
"model cannot be changed after the embedding function has been initialized"
)
immutable_keys = {"include_tokens", "model"}
for key in immutable_keys:
if key in new_config and new_config[key] != old_config.get(key):
raise ValueError(
f"Updating '{key}' is not supported for chroma-cloud-splade"
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
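
A short sketch of the `include_tokens` flag and its config round trip, assuming `httpx` is installed, `CHROMA_API_KEY` is set (or a Chroma Cloud client already exists), and the class is importable from the package root like the other functions.

```python
# Minimal sketch: include_tokens is persisted through get_config(), honored
# by build_from_config(), and treated as immutable by validate_config_update().
# Assumes CHROMA_API_KEY is set and httpx is installed; no network call is
# made until the function is actually invoked on documents.
from chromadb.utils.embedding_functions import ChromaCloudSpladeEmbeddingFunction

ef = ChromaCloudSpladeEmbeddingFunction(include_tokens=True)
config = ef.get_config()
assert config["include_tokens"] is True

rebuilt = ChromaCloudSpladeEmbeddingFunction.build_from_config(config)
assert rebuilt.include_tokens is True
```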

View File

@@ -53,12 +53,19 @@ class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
)
self.model_name = model_name
self.account_id = account_id
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("CLOUDFLARE_API_KEY") is not None:
self.api_key_env_var = "CLOUDFLARE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
self.gateway_id = gateway_id
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
if self.gateway_id:
self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"

View File

@@ -43,10 +43,16 @@ class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("COHERE_API_KEY") is not None:
self.api_key_env_var = "COHERE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name

View File

@@ -7,6 +7,150 @@ from chromadb.utils.embedding_functions.schemas import validate_config_schema
import warnings
class GoogleGenaiEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model_name: str,
vertexai: Optional[bool] = None,
project: Optional[str] = None,
location: Optional[str] = None,
api_key_env_var: str = "GOOGLE_API_KEY",
):
"""
Initialize the GoogleGenaiEmbeddingFunction.
Args:
model_name (str): The name of the model to use for text embeddings.
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google GenAI API.
Defaults to "GOOGLE_API_KEY".
"""
try:
import google.genai as genai
except ImportError:
raise ValueError(
"The google-genai python package is not installed. Please install it with `pip install google-genai`"
)
self.model_name = model_name
self.api_key_env_var = api_key_env_var
self.vertexai = vertexai
self.project = project
self.location = location
self.api_key = os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.client = genai.Client(
api_key=self.api_key, vertexai=vertexai, project=project, location=location
)
def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.
Args:
input: Documents or images to generate embeddings for.
Returns:
Embeddings for the documents.
"""
if not input:
raise ValueError("Input documents cannot be empty")
if not isinstance(input, (list, tuple)):
raise ValueError("Input must be a list or tuple of documents")
if not all(isinstance(doc, str) for doc in input):
raise ValueError("All input documents must be strings")
try:
response = self.client.models.embed_content(
model=self.model_name, contents=input
)
except Exception as e:
raise ValueError(f"Failed to generate embeddings: {str(e)}") from e
# Validate response structure
if not hasattr(response, "embeddings") or not response.embeddings:
raise ValueError("No embeddings returned from the API")
embeddings_list = []
for ce in response.embeddings:
if not hasattr(ce, "values"):
raise ValueError("Malformed embedding response: missing 'values'")
embeddings_list.append(np.array(ce.values, dtype=np.float32))
return cast(Embeddings, embeddings_list)
@staticmethod
def name() -> str:
return "google_genai"
def default_space(self) -> Space:
return "cosine"
def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]
@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model_name = config.get("model_name")
vertexai = config.get("vertexai")
project = config.get("project")
location = config.get("location")
if model_name is None:
raise ValueError("The model name is required.")
return GoogleGenaiEmbeddingFunction(
model_name=model_name,
vertexai=vertexai,
project=project,
location=location,
)
def get_config(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"vertexai": self.vertexai,
"project": self.project,
"location": self.location,
}
def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)
if "vertexai" in new_config:
raise ValueError(
"The vertexai cannot be changed after the embedding function has been initialized."
)
if "project" in new_config:
raise ValueError(
"The project cannot be changed after the embedding function has been initialized."
)
if "location" in new_config:
raise ValueError(
"The location cannot be changed after the embedding function has been initialized."
)
@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.
Args:
config: Configuration to validate
Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "google_genai")
class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
@@ -38,10 +182,16 @@ class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("GOOGLE_API_KEY") is not None:
self.api_key_env_var = "GOOGLE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
@@ -154,10 +304,16 @@ class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("GOOGLE_API_KEY") is not None:
self.api_key_env_var = "GOOGLE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
self.task_type = task_type
@@ -289,10 +445,16 @@ class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("GOOGLE_API_KEY") is not None:
self.api_key_env_var = "GOOGLE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
self.project_id = project_id
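
A minimal usage sketch of the new `GoogleGenaiEmbeddingFunction`. It requires the `google-genai` package and `GOOGLE_API_KEY` in the environment; the model name below is only an example and should be replaced with whichever embedding model your Google project exposes. Vertex AI users can additionally pass `vertexai=True` with `project` and `location`.

```python
# Minimal sketch: requires `pip install google-genai` and GOOGLE_API_KEY.
# "text-embedding-004" is an example model name; substitute the embedding
# model available to your account.
from chromadb.utils.embedding_functions import GoogleGenaiEmbeddingFunction

ef = GoogleGenaiEmbeddingFunction(model_name="text-embedding-004")
embeddings = ef(["a short document", "another short document"])
print(len(embeddings), embeddings[0].shape)  # 2, (dimension,)
```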

View File

@@ -40,10 +40,16 @@ class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("HUGGINGFACE_API_KEY") is not None:
self.api_key_env_var = "HUGGINGFACE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
@@ -160,6 +166,9 @@ class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
self.url = url
self.api_key_env_var = api_key_env_var
if os.getenv("HUGGINGFACE_API_KEY") is not None:
self.api_key_env_var = "HUGGINGFACE_API_KEY"
if self.api_key_env_var is not None:
self.api_key = api_key or os.getenv(self.api_key_env_var)
else:

View File

@@ -81,10 +81,16 @@ class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("JINA_API_KEY") is not None:
self.api_key_env_var = "JINA_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name

View File

@@ -57,10 +57,16 @@ class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("OPENAI_API_KEY") is not None:
self.api_key_env_var = "OPENAI_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
self.organization_id = organization_id

View File

@@ -45,10 +45,16 @@ class RoboflowEmbeddingFunction(EmbeddingFunction[Embeddable]):
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("ROBOFLOW_API_KEY") is not None:
self.api_key_env_var = "ROBOFLOW_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.api_url = api_url

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import re
from functools import lru_cache
from typing import Iterable, List, Protocol, cast
@@ -213,10 +212,8 @@ class _SnowballStemmerAdapter:
return cast(str, self._stemmer.stemWord(token))
@lru_cache(maxsize=1)
def get_english_stemmer() -> SnowballStemmer:
"""Return a cached Snowball stemmer for English."""
"""Return a Snowball stemmer for English."""
return _SnowballStemmerAdapter()

View File

@@ -48,15 +48,16 @@ class TogetherAIEmbeddingFunction(EmbeddingFunction[Documents]):
)
self.model_name = model_name
self.api_key = api_key
self.api_key_env_var = api_key_env_var
if not self.api_key:
self.api_key = os.getenv(self.api_key_env_var)
if os.getenv("TOGETHER_API_KEY") is not None:
self.api_key_env_var = "TOGETHER_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {self.api_key_env_var}"
f"The {self.api_key_env_var} environment variable is not set."
)
self._session = httpx.Client()

View File

@@ -47,10 +47,16 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
DeprecationWarning,
)
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if os.getenv("VOYAGE_API_KEY") is not None:
self.api_key_env_var = "VOYAGE_API_KEY"
else:
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
raise ValueError(
f"The {self.api_key_env_var} environment variable is not set."
)
self.model_name = model_name
self.input_type = input_type

View File

@@ -1,8 +1,12 @@
from typing import List, Tuple
from typing import List, Optional, Tuple
from chromadb.base_types import SparseVector
def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVector:
def normalize_sparse_vector(
indices: List[int],
values: List[float],
labels: Optional[List[str]] = None
) -> SparseVector:
"""Normalize and create a SparseVector by sorting indices and values together.
This function takes raw indices and values (which may be unsorted or have duplicates)
@@ -11,6 +15,7 @@ def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVe
Args:
indices: List of dimension indices (may be unsorted)
values: List of values corresponding to each index
labels: Optional list of string labels corresponding to each index
Returns:
SparseVector with indices sorted in ascending order
@@ -20,17 +25,25 @@ def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVe
ValueError: If there are duplicate indices (after sorting)
ValueError: If indices are negative
ValueError: If values are not numeric
ValueError: If labels is provided and has different length than indices
"""
if not indices:
return SparseVector(indices=[], values=[])
return SparseVector(indices=[], values=[], labels=None)
# Sort indices and values together by index
sorted_pairs = sorted(zip(indices, values), key=lambda x: x[0])
sorted_indices, sorted_values = zip(*sorted_pairs)
# Create SparseVector which will validate:
# - indices are sorted
# - no duplicate indices
# - indices are non-negative
# - values are numeric
return SparseVector(indices=list(sorted_indices), values=list(sorted_values))
# Sort indices, values, and labels together by index
if labels is not None:
sorted_triples = sorted(zip(indices, values, labels), key=lambda x: x[0])
sorted_indices, sorted_values, sorted_labels = zip(*sorted_triples)
return SparseVector(
indices=list(sorted_indices),
values=list(sorted_values),
labels=list(sorted_labels)
)
else:
sorted_pairs = sorted(zip(indices, values), key=lambda x: x[0])
sorted_indices, sorted_values = zip(*sorted_pairs)
return SparseVector(
indices=list(sorted_indices),
values=list(sorted_values),
labels=None
)
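
A minimal sketch of the extended helper: labels are reordered together with their indices and values, so position i still refers to the same dimension after sorting. The import path matches the one used in the splade diff above.

```python
# Minimal sketch: labels follow their indices/values through the sort.
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector

vec = normalize_sparse_vector(
    indices=[42, 7, 19],
    values=[0.5, 1.25, 0.75],
    labels=["cat", "dog", "fish"],
)
print(vec.indices)  # [7, 19, 42]
print(vec.values)   # [1.25, 0.75, 0.5]
print(vec.labels)   # ['dog', 'fish', 'cat']
```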