增加环绕侦察场景适配
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -23,6 +23,7 @@ from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GooglePalmEmbeddingFunction,
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
GoogleGenaiEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
@@ -68,6 +69,10 @@ from chromadb.utils.embedding_functions.mistral_embedding_function import (
|
||||
from chromadb.utils.embedding_functions.morph_embedding_function import (
|
||||
MorphEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.nomic_embedding_function import (
|
||||
NomicEmbeddingFunction,
|
||||
NomicQueryConfig,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
|
||||
HuggingFaceSparseEmbeddingFunction,
|
||||
)
|
||||
@@ -98,11 +103,13 @@ _all_classes: Set[str] = {
|
||||
"GooglePalmEmbeddingFunction",
|
||||
"GoogleGenerativeAiEmbeddingFunction",
|
||||
"GoogleVertexEmbeddingFunction",
|
||||
"GoogleGenaiEmbeddingFunction",
|
||||
"OllamaEmbeddingFunction",
|
||||
"InstructorEmbeddingFunction",
|
||||
"JinaEmbeddingFunction",
|
||||
"MistralEmbeddingFunction",
|
||||
"MorphEmbeddingFunction",
|
||||
"NomicEmbeddingFunction",
|
||||
"VoyageAIEmbeddingFunction",
|
||||
"ONNXMiniLM_L6_V2",
|
||||
"OpenCLIPEmbeddingFunction",
|
||||
@@ -137,11 +144,13 @@ known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignor
|
||||
"google_palm": GooglePalmEmbeddingFunction,
|
||||
"google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google_vertex": GoogleVertexEmbeddingFunction,
|
||||
"google_genai": GoogleGenaiEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"mistral": MistralEmbeddingFunction,
|
||||
"morph": MorphEmbeddingFunction,
|
||||
"nomic": NomicEmbeddingFunction,
|
||||
"voyageai": VoyageAIEmbeddingFunction,
|
||||
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
|
||||
"open_clip": OpenCLIPEmbeddingFunction,
|
||||
@@ -259,12 +268,15 @@ __all__ = [
|
||||
"GooglePalmEmbeddingFunction",
|
||||
"GoogleGenerativeAiEmbeddingFunction",
|
||||
"GoogleVertexEmbeddingFunction",
|
||||
"GoogleGenaiEmbeddingFunction",
|
||||
"OllamaEmbeddingFunction",
|
||||
"InstructorEmbeddingFunction",
|
||||
"JinaEmbeddingFunction",
|
||||
"JinaQueryConfig",
|
||||
"MistralEmbeddingFunction",
|
||||
"MorphEmbeddingFunction",
|
||||
"NomicEmbeddingFunction",
|
||||
"NomicQueryConfig",
|
||||
"VoyageAIEmbeddingFunction",
|
||||
"ONNXMiniLM_L6_V2",
|
||||
"OpenCLIPEmbeddingFunction",
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,6 +2,7 @@ import os
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.schemas import validate_config_schema
|
||||
from typing import Dict, Any, Optional, List
|
||||
from chromadb.api.types import Space
|
||||
import warnings
|
||||
@@ -35,12 +36,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
if os.getenv("BASETEN_API_KEY") is not None:
|
||||
self.api_key_env_var = "BASETEN_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
# Prioritize api_key argument, then environment variable
|
||||
resolved_api_key = api_key or os.getenv(api_key_env_var)
|
||||
resolved_api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not resolved_api_key:
|
||||
raise ValueError(
|
||||
f"API key not provided and {api_key_env_var} environment variable is not set."
|
||||
f"API key not provided and {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
self.api_key = resolved_api_key
|
||||
if not api_base:
|
||||
@@ -96,3 +101,16 @@ class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
|
||||
api_base=api_base,
|
||||
api_key_env_var=api_key_env_var,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_config(config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the configuration using the JSON schema.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValidationError: If the configuration does not match the schema
|
||||
"""
|
||||
validate_config_schema(config, "baseten")
|
||||
|
||||
@@ -231,4 +231,4 @@ class Bm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
Raises:
|
||||
ValidationError: If the configuration does not match the schema
|
||||
"""
|
||||
validate_config_schema(config, "bm25")
|
||||
validate_config_schema(config, "bm25")
|
||||
@@ -23,12 +23,32 @@ DEFAULT_TOKEN_MAX_LENGTH = 40
|
||||
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
|
||||
|
||||
|
||||
class _HashedToken:
|
||||
__slots__ = ("hash", "label")
|
||||
|
||||
def __init__(self, hash: int, label: Optional[str]):
|
||||
self.hash = hash
|
||||
self.label = label
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.hash
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, _HashedToken):
|
||||
return NotImplemented
|
||||
return self.hash == other.hash
|
||||
|
||||
def __lt__(self, other: "_HashedToken") -> bool:
|
||||
return self.hash < other.hash
|
||||
|
||||
|
||||
class ChromaBm25Config(TypedDict, total=False):
|
||||
k: float
|
||||
b: float
|
||||
avg_doc_length: float
|
||||
token_max_length: int
|
||||
stopwords: List[str]
|
||||
include_tokens: bool
|
||||
|
||||
|
||||
class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
@@ -39,6 +59,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
|
||||
token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
|
||||
stopwords: Optional[Iterable[str]] = None,
|
||||
include_tokens: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the BM25 sparse embedding function."""
|
||||
|
||||
@@ -46,38 +67,51 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
self.b = float(b)
|
||||
self.avg_doc_length = float(avg_doc_length)
|
||||
self.token_max_length = int(token_max_length)
|
||||
self.include_tokens = bool(include_tokens)
|
||||
|
||||
if stopwords is not None:
|
||||
self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
|
||||
stopword_list: Iterable[str] = self.stopwords
|
||||
self._stopword_list: Iterable[str] = self.stopwords
|
||||
else:
|
||||
self.stopwords = None
|
||||
stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
|
||||
self._stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS
|
||||
|
||||
stemmer = get_english_stemmer()
|
||||
self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
|
||||
self._hasher = Murmur3AbsHasher()
|
||||
|
||||
def _encode(self, text: str) -> SparseVector:
|
||||
tokens = self._tokenizer.tokenize(text)
|
||||
stemmer = get_english_stemmer()
|
||||
tokenizer = Bm25Tokenizer(stemmer, self._stopword_list, self.token_max_length)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
|
||||
if not tokens:
|
||||
return SparseVector(indices=[], values=[])
|
||||
|
||||
doc_len = float(len(tokens))
|
||||
counts = Counter(self._hasher.hash(token) for token in tokens)
|
||||
counts = Counter(
|
||||
_HashedToken(
|
||||
self._hasher.hash(token), token if self.include_tokens else None
|
||||
)
|
||||
for token in tokens
|
||||
)
|
||||
|
||||
indices = sorted(counts.keys())
|
||||
sorted_keys = sorted(counts.keys())
|
||||
indices: List[int] = []
|
||||
values: List[float] = []
|
||||
for idx in indices:
|
||||
tf = float(counts[idx])
|
||||
labels: Optional[List[str]] = [] if self.include_tokens else None
|
||||
|
||||
for key in sorted_keys:
|
||||
tf = float(counts[key])
|
||||
denominator = tf + self.k * (
|
||||
1 - self.b + (self.b * doc_len) / self.avg_doc_length
|
||||
)
|
||||
score = tf * (self.k + 1) / denominator
|
||||
values.append(score)
|
||||
|
||||
return SparseVector(indices=indices, values=values)
|
||||
indices.append(key.hash)
|
||||
values.append(score)
|
||||
if labels is not None and key.label is not None:
|
||||
labels.append(key.label)
|
||||
|
||||
return SparseVector(indices=indices, values=values, labels=labels)
|
||||
|
||||
def __call__(self, input: Documents) -> SparseVectors:
|
||||
sparse_vectors: SparseVectors = []
|
||||
@@ -99,7 +133,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
|
||||
@staticmethod
|
||||
def build_from_config(
|
||||
config: Dict[str, Any]
|
||||
config: Dict[str, Any],
|
||||
) -> "SparseEmbeddingFunction[Documents]":
|
||||
return ChromaBm25EmbeddingFunction(
|
||||
k=config.get("k", DEFAULT_K),
|
||||
@@ -107,6 +141,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
|
||||
token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
|
||||
stopwords=config.get("stopwords"),
|
||||
include_tokens=config.get("include_tokens", False),
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
@@ -115,6 +150,7 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
"b": self.b,
|
||||
"avg_doc_length": self.avg_doc_length,
|
||||
"token_max_length": self.token_max_length,
|
||||
"include_tokens": self.include_tokens,
|
||||
}
|
||||
|
||||
if self.stopwords is not None:
|
||||
@@ -125,7 +161,14 @@ class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
def validate_config_update(
|
||||
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
|
||||
) -> None:
|
||||
mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
|
||||
mutable_keys = {
|
||||
"k",
|
||||
"b",
|
||||
"avg_doc_length",
|
||||
"token_max_length",
|
||||
"stopwords",
|
||||
"include_tokens",
|
||||
}
|
||||
for key in new_config:
|
||||
if key not in mutable_keys:
|
||||
raise ValueError(f"Updating '{key}' is not supported for {NAME}")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
|
||||
from typing import List, Dict, Any, Union
|
||||
from typing import List, Dict, Any, Union, Optional
|
||||
import os
|
||||
import numpy as np
|
||||
from chromadb.utils.embedding_functions.schemas import validate_config_schema
|
||||
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -32,7 +33,7 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
def __init__(
|
||||
self,
|
||||
model: ChromaCloudQwenEmbeddingModel,
|
||||
task: str,
|
||||
task: Optional[str],
|
||||
instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
|
||||
api_key_env_var: str = "CHROMA_API_KEY",
|
||||
):
|
||||
@@ -41,7 +42,8 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
|
||||
Args:
|
||||
model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
|
||||
task (str): The task for which embeddings are being generated.
|
||||
task (str, optional): The task for which embeddings are being generated. If None or empty,
|
||||
empty instructions will be used for both documents and queries.
|
||||
instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
|
||||
custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
|
||||
api_key_env_var (str, optional): Environment variable name that contains your API key.
|
||||
@@ -55,9 +57,18 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
# First, try to get API key from environment variable
|
||||
self.api_key = os.getenv(api_key_env_var)
|
||||
# If not found in env var, try to get it from existing client instances
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
SharedSystemClient = _get_shared_system_client()
|
||||
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
|
||||
# Raise error if still no API key found
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
f"API key not found in environment variable {api_key_env_var} "
|
||||
f"or in any existing client instances"
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.task = task
|
||||
@@ -102,10 +113,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
if not input:
|
||||
return []
|
||||
|
||||
payload: Dict[str, Union[str, Documents]] = {
|
||||
"instructions": self.instructions[self.task][
|
||||
instruction = ""
|
||||
if self.task and self.task in self.instructions:
|
||||
instruction = self.instructions[self.task][
|
||||
ChromaCloudQwenEmbeddingTarget.DOCUMENTS
|
||||
],
|
||||
]
|
||||
|
||||
payload: Dict[str, Union[str, Documents]] = {
|
||||
"instructions": instruction,
|
||||
"texts": input,
|
||||
}
|
||||
|
||||
@@ -120,10 +135,14 @@ class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
if not input:
|
||||
return []
|
||||
|
||||
payload: Dict[str, Union[str, Documents]] = {
|
||||
"instructions": self.instructions[self.task][
|
||||
instruction = ""
|
||||
if self.task and self.task in self.instructions:
|
||||
instruction = self.instructions[self.task][
|
||||
ChromaCloudQwenEmbeddingTarget.QUERY
|
||||
],
|
||||
]
|
||||
|
||||
payload: Dict[str, Union[str, Documents]] = {
|
||||
"instructions": instruction,
|
||||
"texts": input,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from chromadb.api.types import (
|
||||
SparseEmbeddingFunction,
|
||||
SparseVector,
|
||||
SparseVectors,
|
||||
Documents,
|
||||
)
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List, Optional
|
||||
from enum import Enum
|
||||
from chromadb.utils.embedding_functions.schemas import validate_config_schema
|
||||
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
|
||||
from chromadb.base_types import SparseVector
|
||||
import os
|
||||
from typing import Union
|
||||
from chromadb.utils.embedding_functions.utils import _get_shared_system_client
|
||||
|
||||
|
||||
class ChromaCloudSpladeEmbeddingModel(Enum):
|
||||
@@ -21,6 +22,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
self,
|
||||
api_key_env_var: str = "CHROMA_API_KEY",
|
||||
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
|
||||
include_tokens: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the ChromaCloudSpladeEmbeddingFunction.
|
||||
@@ -36,12 +38,20 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
"The httpx python package is not installed. Please install it with `pip install httpx`"
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
# First, try to get API key from environment variable
|
||||
self.api_key = os.getenv(self.api_key_env_var)
|
||||
# If not found in env var, try to get it from existing client instances
|
||||
if not self.api_key:
|
||||
SharedSystemClient = _get_shared_system_client()
|
||||
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
|
||||
# Raise error if still no API key found
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
f"API key not found in environment variable {self.api_key_env_var}"
|
||||
f"API key not found in environment variable {self.api_key_env_var} "
|
||||
f"or in any existing client instances"
|
||||
)
|
||||
self.model = model
|
||||
self.include_tokens = bool(include_tokens)
|
||||
self._api_url = "https://embed.trychroma.com/embed_sparse"
|
||||
self._session = httpx.Client()
|
||||
self._session.headers.update(
|
||||
@@ -80,6 +90,7 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
"texts": list(input),
|
||||
"task": "",
|
||||
"target": "",
|
||||
"fetch_tokens": "true" if self.include_tokens is True else "false",
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -113,13 +124,17 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
if isinstance(emb, dict):
|
||||
indices = emb.get("indices", [])
|
||||
values = emb.get("values", [])
|
||||
raw_labels = emb.get("labels") if self.include_tokens else None
|
||||
labels: Optional[List[str]] = raw_labels if raw_labels else None
|
||||
else:
|
||||
# Already a SparseVector, extract its data
|
||||
assert isinstance(emb, SparseVector)
|
||||
indices = emb.indices
|
||||
values = emb.values
|
||||
labels = emb.labels if self.include_tokens else None
|
||||
|
||||
normalized_vectors.append(
|
||||
normalize_sparse_vector(indices=indices, values=values)
|
||||
normalize_sparse_vector(indices=indices, values=values, labels=labels)
|
||||
)
|
||||
|
||||
return normalized_vectors
|
||||
@@ -141,18 +156,25 @@ class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
|
||||
return ChromaCloudSpladeEmbeddingFunction(
|
||||
api_key_env_var=api_key_env_var,
|
||||
model=ChromaCloudSpladeEmbeddingModel(model),
|
||||
include_tokens=config.get("include_tokens", False),
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}
|
||||
return {
|
||||
"api_key_env_var": self.api_key_env_var,
|
||||
"model": self.model.value,
|
||||
"include_tokens": self.include_tokens,
|
||||
}
|
||||
|
||||
def validate_config_update(
|
||||
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
|
||||
) -> None:
|
||||
if "model" in new_config:
|
||||
raise ValueError(
|
||||
"model cannot be changed after the embedding function has been initialized"
|
||||
)
|
||||
immutable_keys = {"include_tokens", "model"}
|
||||
for key in immutable_keys:
|
||||
if key in new_config and new_config[key] != old_config.get(key):
|
||||
raise ValueError(
|
||||
f"Updating '{key}' is not supported for chroma-cloud-splade"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_config(config: Dict[str, Any]) -> None:
|
||||
|
||||
@@ -53,12 +53,19 @@ class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
self.model_name = model_name
|
||||
self.account_id = account_id
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
|
||||
if os.getenv("CLOUDFLARE_API_KEY") is not None:
|
||||
self.api_key_env_var = "CLOUDFLARE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
self.gateway_id = gateway_id
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
if self.gateway_id:
|
||||
self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"
|
||||
|
||||
@@ -43,10 +43,16 @@ class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
"Please use environment variables via api_key_env_var for persistent storage.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("COHERE_API_KEY") is not None:
|
||||
self.api_key_env_var = "COHERE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
@@ -7,6 +7,150 @@ from chromadb.utils.embedding_functions.schemas import validate_config_schema
|
||||
import warnings
|
||||
|
||||
|
||||
class GoogleGenaiEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
vertexai: Optional[bool] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
api_key_env_var: str = "GOOGLE_API_KEY",
|
||||
):
|
||||
"""
|
||||
Initialize the GoogleGenaiEmbeddingFunction.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to use for text embeddings.
|
||||
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google GenAI API.
|
||||
Defaults to "GOOGLE_API_KEY".
|
||||
"""
|
||||
try:
|
||||
import google.genai as genai
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"The google-genai python package is not installed. Please install it with `pip install google-genai`"
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.vertexai = vertexai
|
||||
self.project = project
|
||||
self.location = location
|
||||
self.api_key = os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.client = genai.Client(
|
||||
api_key=self.api_key, vertexai=vertexai, project=project, location=location
|
||||
)
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""
|
||||
Generate embeddings for the given documents.
|
||||
|
||||
Args:
|
||||
input: Documents or images to generate embeddings for.
|
||||
|
||||
Returns:
|
||||
Embeddings for the documents.
|
||||
"""
|
||||
if not input:
|
||||
raise ValueError("Input documents cannot be empty")
|
||||
if not isinstance(input, (list, tuple)):
|
||||
raise ValueError("Input must be a list or tuple of documents")
|
||||
if not all(isinstance(doc, str) for doc in input):
|
||||
raise ValueError("All input documents must be strings")
|
||||
|
||||
try:
|
||||
response = self.client.models.embed_content(
|
||||
model=self.model_name, contents=input
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to generate embeddings: {str(e)}") from e
|
||||
|
||||
# Validate response structure
|
||||
if not hasattr(response, "embeddings") or not response.embeddings:
|
||||
raise ValueError("No embeddings returned from the API")
|
||||
|
||||
embeddings_list = []
|
||||
for ce in response.embeddings:
|
||||
if not hasattr(ce, "values"):
|
||||
raise ValueError("Malformed embedding response: missing 'values'")
|
||||
embeddings_list.append(np.array(ce.values, dtype=np.float32))
|
||||
|
||||
return cast(Embeddings, embeddings_list)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "google_genai"
|
||||
|
||||
def default_space(self) -> Space:
|
||||
return "cosine"
|
||||
|
||||
def supported_spaces(self) -> List[Space]:
|
||||
return ["cosine", "l2", "ip"]
|
||||
|
||||
@staticmethod
|
||||
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
|
||||
model_name = config.get("model_name")
|
||||
vertexai = config.get("vertexai")
|
||||
project = config.get("project")
|
||||
location = config.get("location")
|
||||
|
||||
if model_name is None:
|
||||
raise ValueError("The model name is required.")
|
||||
|
||||
return GoogleGenaiEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
vertexai=vertexai,
|
||||
project=project,
|
||||
location=location,
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"vertexai": self.vertexai,
|
||||
"project": self.project,
|
||||
"location": self.location,
|
||||
}
|
||||
|
||||
def validate_config_update(
|
||||
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
|
||||
) -> None:
|
||||
if "model_name" in new_config:
|
||||
raise ValueError(
|
||||
"The model name cannot be changed after the embedding function has been initialized."
|
||||
)
|
||||
if "vertexai" in new_config:
|
||||
raise ValueError(
|
||||
"The vertexai cannot be changed after the embedding function has been initialized."
|
||||
)
|
||||
if "project" in new_config:
|
||||
raise ValueError(
|
||||
"The project cannot be changed after the embedding function has been initialized."
|
||||
)
|
||||
if "location" in new_config:
|
||||
raise ValueError(
|
||||
"The location cannot be changed after the embedding function has been initialized."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_config(config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the configuration using the JSON schema.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValidationError: If the configuration does not match the schema
|
||||
"""
|
||||
validate_config_schema(config, "google_genai")
|
||||
|
||||
|
||||
class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
|
||||
|
||||
@@ -38,10 +182,16 @@ class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Please use environment variables via api_key_env_var for persistent storage.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("GOOGLE_API_KEY") is not None:
|
||||
self.api_key_env_var = "GOOGLE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
@@ -154,10 +304,16 @@ class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Please use environment variables via api_key_env_var for persistent storage.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("GOOGLE_API_KEY") is not None:
|
||||
self.api_key_env_var = "GOOGLE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
self.task_type = task_type
|
||||
@@ -289,10 +445,16 @@ class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Please use environment variables via api_key_env_var for persistent storage.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("GOOGLE_API_KEY") is not None:
|
||||
self.api_key_env_var = "GOOGLE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
self.project_id = project_id
|
||||
|
||||
@@ -40,10 +40,16 @@ class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Please use environment variables via api_key_env_var for persistent storage.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("HUGGINGFACE_API_KEY") is not None:
|
||||
self.api_key_env_var = "HUGGINGFACE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
@@ -160,6 +166,9 @@ class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
|
||||
self.url = url
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
if os.getenv("HUGGINGFACE_API_KEY") is not None:
|
||||
self.api_key_env_var = "HUGGINGFACE_API_KEY"
|
||||
|
||||
if self.api_key_env_var is not None:
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
else:
|
||||
|
||||
@@ -81,10 +81,16 @@ class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("JINA_API_KEY") is not None:
|
||||
self.api_key_env_var = "JINA_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
@@ -57,10 +57,16 @@ class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("OPENAI_API_KEY") is not None:
|
||||
self.api_key_env_var = "OPENAI_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
self.organization_id = organization_id
|
||||
|
||||
@@ -45,10 +45,16 @@ class RoboflowEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
"Please use environment variables via api_key_env_var for persistent storage.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("ROBOFLOW_API_KEY") is not None:
|
||||
self.api_key_env_var = "ROBOFLOW_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.api_url = api_url
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from typing import Iterable, List, Protocol, cast
|
||||
|
||||
|
||||
@@ -213,10 +212,8 @@ class _SnowballStemmerAdapter:
|
||||
return cast(str, self._stemmer.stemWord(token))
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_english_stemmer() -> SnowballStemmer:
|
||||
"""Return a cached Snowball stemmer for English."""
|
||||
|
||||
"""Return a Snowball stemmer for English."""
|
||||
return _SnowballStemmerAdapter()
|
||||
|
||||
|
||||
|
||||
@@ -48,15 +48,16 @@ class TogetherAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
if not self.api_key:
|
||||
self.api_key = os.getenv(self.api_key_env_var)
|
||||
if os.getenv("TOGETHER_API_KEY") is not None:
|
||||
self.api_key_env_var = "TOGETHER_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
f"API key not found in environment variable {self.api_key_env_var}"
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self._session = httpx.Client()
|
||||
|
||||
@@ -47,10 +47,16 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self.api_key_env_var = api_key_env_var
|
||||
self.api_key = api_key or os.getenv(api_key_env_var)
|
||||
if os.getenv("VOYAGE_API_KEY") is not None:
|
||||
self.api_key_env_var = "VOYAGE_API_KEY"
|
||||
else:
|
||||
self.api_key_env_var = api_key_env_var
|
||||
|
||||
self.api_key = api_key or os.getenv(self.api_key_env_var)
|
||||
if not self.api_key:
|
||||
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
|
||||
raise ValueError(
|
||||
f"The {self.api_key_env_var} environment variable is not set."
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
self.input_type = input_type
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
from chromadb.base_types import SparseVector
|
||||
|
||||
|
||||
def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVector:
|
||||
def normalize_sparse_vector(
|
||||
indices: List[int],
|
||||
values: List[float],
|
||||
labels: Optional[List[str]] = None
|
||||
) -> SparseVector:
|
||||
"""Normalize and create a SparseVector by sorting indices and values together.
|
||||
|
||||
This function takes raw indices and values (which may be unsorted or have duplicates)
|
||||
@@ -11,6 +15,7 @@ def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVe
|
||||
Args:
|
||||
indices: List of dimension indices (may be unsorted)
|
||||
values: List of values corresponding to each index
|
||||
labels: Optional list of string labels corresponding to each index
|
||||
|
||||
Returns:
|
||||
SparseVector with indices sorted in ascending order
|
||||
@@ -20,17 +25,25 @@ def normalize_sparse_vector(indices: List[int], values: List[float]) -> SparseVe
|
||||
ValueError: If there are duplicate indices (after sorting)
|
||||
ValueError: If indices are negative
|
||||
ValueError: If values are not numeric
|
||||
ValueError: If labels is provided and has different length than indices
|
||||
"""
|
||||
if not indices:
|
||||
return SparseVector(indices=[], values=[])
|
||||
return SparseVector(indices=[], values=[], labels=None)
|
||||
|
||||
# Sort indices and values together by index
|
||||
sorted_pairs = sorted(zip(indices, values), key=lambda x: x[0])
|
||||
sorted_indices, sorted_values = zip(*sorted_pairs)
|
||||
|
||||
# Create SparseVector which will validate:
|
||||
# - indices are sorted
|
||||
# - no duplicate indices
|
||||
# - indices are non-negative
|
||||
# - values are numeric
|
||||
return SparseVector(indices=list(sorted_indices), values=list(sorted_values))
|
||||
# Sort indices, values, and labels together by index
|
||||
if labels is not None:
|
||||
sorted_triples = sorted(zip(indices, values, labels), key=lambda x: x[0])
|
||||
sorted_indices, sorted_values, sorted_labels = zip(*sorted_triples)
|
||||
return SparseVector(
|
||||
indices=list(sorted_indices),
|
||||
values=list(sorted_values),
|
||||
labels=list(sorted_labels)
|
||||
)
|
||||
else:
|
||||
sorted_pairs = sorted(zip(indices, values), key=lambda x: x[0])
|
||||
sorted_indices, sorted_values = zip(*sorted_pairs)
|
||||
return SparseVector(
|
||||
indices=list(sorted_indices),
|
||||
values=list(sorted_values),
|
||||
labels=None
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user