chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,237 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Dict,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from chromadb.config import (
|
||||
Component,
|
||||
System,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
class AuthError(Exception):
    """Raised by authentication providers when a request cannot be authenticated."""
|
||||
|
||||
|
||||
# Header name -> secret-wrapped header value, as returned by
# ClientAuthProvider.authenticate().
ClientAuthHeaders = Dict[str, SecretStr]
|
||||
|
||||
|
||||
class ClientAuthProvider(Component):
    """
    Base component for client-side authentication.

    A ClientAuthProvider supplies the authentication headers for client
    requests. Client implementations (in our case, just the FastAPI client)
    must inject these headers into their requests.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authenticate(self) -> ClientAuthHeaders:
        """Return the authentication headers to attach to a request."""
        ...
|
||||
|
||||
|
||||
@dataclass
class UserIdentity:
    """
    Identity of a user as established by an authentication provider.

    Only `user_id` is mandatory. Which of the remaining fields are filled in
    depends on the authentication provider that produced the identity.

    The AuthenticationProvider is expected to populate _all_ information it
    knows about the user; the AuthorizationProvider then makes its decisions
    based on that information.
    """

    user_id: str
    tenant: Optional[str] = None
    databases: Optional[List[str]] = None
    # Free-form bag for any extra auth context that has to travel from the
    # authentication provider to the authorization provider.
    attributes: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ServerAuthenticationProvider(Component):
    """
    ServerAuthenticationProvider is responsible for authenticating requests. If
    a ServerAuthenticationProvider is configured, it will be called by the
    server to authenticate requests. If no ServerAuthenticationProvider is
    configured, all requests will be authenticated.

    The ServerAuthenticationProvider should return a UserIdentity object if the
    request is authenticated for use by the ServerAuthorizationProvider.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)
        # Maps a path to the upper-case verbs that ignore_operation() reports
        # as ignored for that path.
        self._ignore_auth_paths: Dict[
            str, List[str]
        ] = system.settings.chroma_server_auth_ignore_paths
        self.overwrite_singleton_tenant_database_access_from_auth = (
            system.settings.chroma_overwrite_singleton_tenant_database_access_from_auth
        )

    @abstractmethod
    def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
        """Return the identity for the given request headers, or raise."""
        pass

    def ignore_operation(self, verb: str, path: str) -> bool:
        """
        Return True if `path` is listed in the ignore-paths setting and the
        upper-cased `verb` is one of the verbs listed for it.
        """
        if path not in self._ignore_auth_paths:
            return False
        return verb.upper() in self._ignore_auth_paths[path]

    def read_creds_or_creds_file(self) -> List[str]:
        """
        Return the configured credentials as a list of non-empty lines.

        Exactly one of `chroma_server_authn_credentials_file` and
        `chroma_server_authn_credentials` must be set. Both sources are split
        the same way, so the returned lines never carry a trailing newline
        and blank lines are dropped.

        Raises:
            ValueError: if neither or both of the settings are provided.
        """
        settings = self._system.settings
        _creds_file = None
        _creds = None

        if settings.chroma_server_authn_credentials_file:
            _creds_file = str(settings["chroma_server_authn_credentials_file"])
        if settings.chroma_server_authn_credentials:
            _creds = str(settings["chroma_server_authn_credentials"])

        if not _creds_file and not _creds:
            raise ValueError(
                "No credentials file or credentials found in "
                "[chroma_server_authn_credentials]."
            )
        if _creds_file and _creds:
            raise ValueError(
                "Both credentials file and credentials found. "
                "Please provide only one."
            )

        if _creds:
            raw = _creds
        elif _creds_file:
            # Read the whole file and split it exactly like the inline
            # setting. readlines() would keep the trailing "\n" on every
            # line, which breaks exact-match lookups on single-token files.
            with open(_creds_file, "r") as f:
                raw = f.read()
        else:
            raise ValueError("Should never happen")
        return [line for line in raw.split("\n") if line]

    def singleton_tenant_database_if_applicable(
        self, user: Optional[UserIdentity]
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
        is False, or no user is given, this function returns (None, None).

        Otherwise:
        - If the user has a tenant other than "*", that tenant is returned as
          the first value; otherwise the first value is None.
        - If the user has exactly one database and it is not "*", that
          database is returned as the second value; otherwise the second
          value is None.
        """
        if not self.overwrite_singleton_tenant_database_access_from_auth or not user:
            return None, None
        tenant = None
        database = None
        if user.tenant and user.tenant != "*":
            tenant = user.tenant
        if user.databases and len(user.databases) == 1 and user.databases[0] != "*":
            database = user.databases[0]
        return tenant, database
|
||||
|
||||
|
||||
class AuthzAction(str, Enum):
    """
    Every action an authorization provider can be asked to authorize.

    Values follow the pattern "<scope>:<operation>".
    """

    # system scope
    RESET = "system:reset"

    # tenant scope
    CREATE_TENANT = "tenant:create_tenant"
    GET_TENANT = "tenant:get_tenant"

    # database scope
    CREATE_DATABASE = "db:create_database"
    GET_DATABASE = "db:get_database"
    DELETE_DATABASE = "db:delete_database"
    LIST_DATABASES = "db:list_databases"
    LIST_COLLECTIONS = "db:list_collections"
    COUNT_COLLECTIONS = "db:count_collections"
    CREATE_COLLECTION = "db:create_collection"
    GET_OR_CREATE_COLLECTION = "db:get_or_create_collection"

    # collection scope
    GET_COLLECTION = "collection:get_collection"
    DELETE_COLLECTION = "collection:delete_collection"
    UPDATE_COLLECTION = "collection:update_collection"
    ADD = "collection:add"
    DELETE = "collection:delete"
    GET = "collection:get"
    QUERY = "collection:query"
    COUNT = "collection:count"
    UPDATE = "collection:update"
    UPSERT = "collection:upsert"
|
||||
|
||||
|
||||
@dataclass
class AuthzResource:
    """
    Target of an authorization request: the tenant, database and collection
    being accessed. Each field may be None.
    """

    tenant: Optional[str]
    database: Optional[str]
    collection: Optional[str]
|
||||
|
||||
|
||||
class ServerAuthorizationProvider(Component):
    """
    ServerAuthorizationProvider is responsible for authorizing requests. If a
    ServerAuthorizationProvider is configured, it will be called by the server
    to authorize requests. If no ServerAuthorizationProvider is configured, all
    requests will be authorized.

    ServerAuthorizationProvider should raise an exception if the request is not
    authorized.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authorize_or_raise(
        self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
    ) -> None:
        """Raise if `user` may not perform `action` on `resource`."""
        pass

    def read_config_or_config_file(self) -> List[str]:
        """
        Return the authz configuration as a list of non-empty lines.

        Exactly one of `chroma_server_authz_config_file` and
        `chroma_server_authz_config` must be set. Both sources are split the
        same way, so the returned lines never carry a trailing newline and
        blank lines are dropped.

        Raises:
            ValueError: if neither or both of the settings are provided.
        """
        settings = self._system.settings
        _config_file = None
        _config = None
        if settings.chroma_server_authz_config_file:
            _config_file = settings["chroma_server_authz_config_file"]
        if settings.chroma_server_authz_config:
            _config = str(settings["chroma_server_authz_config"])

        if not _config_file and not _config:
            raise ValueError(
                "No authz configuration file or authz configuration found."
            )
        if _config_file and _config:
            raise ValueError(
                "Both authz configuration file and authz configuration found. "
                "Please provide only one."
            )

        if _config:
            raw = _config
        elif _config_file:
            # Split file content exactly like the inline setting; readlines()
            # would leave a trailing "\n" on every returned line.
            with open(_config_file, "r") as f:
                raw = f.read()
        else:
            raise ValueError("Should never happen")
        return [line for line in raw.split("\n") if line]
|
||||
Binary file not shown.
@@ -0,0 +1,146 @@
|
||||
import base64
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import bcrypt
|
||||
import logging
|
||||
|
||||
from overrides import override
|
||||
from pydantic import SecretStr
|
||||
|
||||
from chromadb.auth import (
|
||||
UserIdentity,
|
||||
ServerAuthenticationProvider,
|
||||
ClientAuthProvider,
|
||||
ClientAuthHeaders,
|
||||
AuthError,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.errors import ChromaAuthError
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["BasicAuthenticationServerProvider", "BasicAuthClientProvider"]
|
||||
|
||||
# Header name used for basic auth. The client sends it as-is; the server
# looks it up lower-cased in the incoming headers.
AUTHORIZATION_HEADER = "Authorization"
|
||||
|
||||
|
||||
class BasicAuthClientProvider(ClientAuthProvider):
    """
    Basic-auth client provider. The configured credentials are base64-encoded
    and sent in the Authorization header, prefixed with "Basic ".
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._settings = system.settings
        # Fail early if the client credentials setting is missing.
        system.settings.require("chroma_client_auth_credentials")
        raw_credentials = str(system.settings.chroma_client_auth_credentials)
        self._creds = SecretStr(raw_credentials)

    @override
    def authenticate(self) -> ClientAuthHeaders:
        """Build the Authorization header carrying the encoded credentials."""
        credential_bytes = self._creds.get_secret_value().encode("utf-8")
        encoded = base64.b64encode(credential_bytes).decode("utf-8")
        header_value = SecretStr(f"Basic {encoded}")
        return {AUTHORIZATION_HEADER: header_value}
|
||||
|
||||
|
||||
class BasicAuthenticationServerProvider(ServerAuthenticationProvider):
    """
    Server auth provider for basic auth. The credentials are read via
    `read_creds_or_creds_file()` and each line must be in the format
    <username>:<bcrypt passwd>.

    Expects tokens to be passed as a base64-encoded string in the Authorization
    header prepended with "Basic".
    """

    def __init__(self, system: System) -> None:
        """
        Load and validate the username -> bcrypt hash table.

        Raises:
            ValueError: if an entry is not exactly `<username>:<bcrypt passwd>`
                or if a username appears more than once.
        """
        super().__init__(system)
        self._settings = system.settings

        # username -> bcrypt hash (kept wrapped so it does not leak via repr)
        self._creds: Dict[str, SecretStr] = {}
        creds = self.read_creds_or_creds_file()

        for entry_number, line in enumerate(creds, start=1):
            entry = line.strip()
            if not entry:
                continue
            _raw_creds = entry.split(":")
            # An entry is valid only if it has exactly two non-empty parts.
            if len(_raw_creds) != 2 or not all(_raw_creds):
                # Deliberately do not echo the entry: it contains the
                # password hash and exception text commonly ends up in logs.
                raise ValueError(
                    f"Invalid htpasswd credentials found in entry {entry_number}. "
                    "Lines must be exactly <username>:<bcrypt passwd>."
                )
            username, password = _raw_creds
            if username in self._creds:
                raise ValueError(
                    "Duplicate username found in "
                    "[chroma_server_authn_credentials]. "
                    "Usernames must be unique."
                )
            self._creds[username] = SecretStr(password)

    @trace_method(
        "BasicAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
    )
    @override
    def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
        """
        Authenticate a request from its (lower-cased) headers.

        Returns:
            A UserIdentity whose user_id is the authenticated username.

        Raises:
            ChromaAuthError: on any failure. The cause is only logged, never
                returned to the caller.
        """
        try:
            header_key = AUTHORIZATION_HEADER.lower()
            if header_key not in headers:
                raise AuthError(AUTHORIZATION_HEADER + " header not found")
            _auth_header = headers[header_key]
            _auth_header = re.sub(r"^Basic ", "", _auth_header)
            _auth_header = _auth_header.strip()

            base64_decoded = base64.b64decode(_auth_header).decode("utf-8")
            if ":" not in base64_decoded:
                raise AuthError("Invalid Authorization header format")
            # Split on the first colon only: the password may contain ":".
            username, password = base64_decoded.split(":", 1)
            username = str(username)  # convert to string to prevent header injection
            password = str(password)  # convert to string to prevent header injection
            # NOTE(review): unknown usernames return before the bcrypt check,
            # so they fail faster than known ones - confirm this is acceptable.
            if username not in self._creds:
                raise AuthError("Invalid username or password")

            _pwd_check = bcrypt.checkpw(
                password.encode("utf-8"),
                self._creds[username].get_secret_value().encode("utf-8"),
            )
            if not _pwd_check:
                raise AuthError("Invalid username or password")
            return UserIdentity(user_id=username)
        except AuthError as e:
            logger.error(
                f"BasicAuthenticationServerProvider.authenticate failed: {repr(e)}"
            )
        except Exception as e:
            # Log only the exception type and where it was raised; the
            # message could contain request data.
            tb = traceback.extract_tb(e.__traceback__)
            # Get the last call stack
            last_call_stack = tb[-1]
            line_number = last_call_stack.lineno
            filename = last_call_stack.filename
            logger.error(
                "BasicAuthenticationServerProvider.authenticate failed: "
                f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
            )
        time.sleep(
            random.uniform(0.001, 0.005)
        )  # add some jitter to avoid timing attacks
        raise ChromaAuthError()
|
||||
@@ -0,0 +1,75 @@
|
||||
import logging
|
||||
from typing import Dict, Set
|
||||
from overrides import override
|
||||
import yaml
|
||||
from chromadb.auth import (
|
||||
AuthzAction,
|
||||
AuthzResource,
|
||||
UserIdentity,
|
||||
ServerAuthorizationProvider,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from fastapi import HTTPException
|
||||
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
    """
    Minimal Role-Based Access Control (RBAC) authorization provider.

    The configuration maps users to roles and roles to actions. A request is
    authorized when the acting user's role includes the requested action.

    For an example of an RBAC configuration file, see
    examples/basic_functionality/authz/authz.yaml.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._settings = system.settings
        config_text = "\n".join(self.read_config_or_config_file())
        self._config = yaml.safe_load(config_text)

        # Resolve everything up front so no config parsing happens per
        # request. This provider has no per-resource authorization, so a flat
        # "user id -> allowed actions" table is all we need. Its size is not
        # a concern: every user is listed in the config, and anyone with a
        # gigantic number of users can roll their own AuthorizationProvider.
        self._permissions: Dict[str, Set[str]] = {
            entry["id"]: set(
                self._config["roles_mapping"][entry["role"]]["actions"]
            )
            for entry in self._config["users"]
        }
        logger.info(
            "Authorization Provider SimpleRBACAuthorizationProvider initialized"
        )

    @trace_method(
        "SimpleRBACAuthorizationProvider.authorize",
        OpenTelemetryGranularity.ALL,
    )
    @override
    def authorize_or_raise(
        self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
    ) -> None:
        """
        Raise HTTPException(403) unless `action` is among the actions granted
        to `user`. `resource` is only used for logging.
        """
        granted_actions = self._permissions.get(user.user_id)
        policy_decision = granted_actions is not None and action in granted_actions

        logger.debug(
            f"Authorization decision: Access "
            f"{'granted' if policy_decision else 'denied'} for "
            f"user [{user.user_id}] attempting to "
            f"[{action}] [{resource}]"
        )
        if policy_decision:
            return
        raise HTTPException(status_code=403, detail="Forbidden")
|
||||
@@ -0,0 +1,235 @@
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import cast, Dict, List, Optional, TypedDict, TypeVar
|
||||
|
||||
|
||||
from overrides import override
|
||||
from pydantic import SecretStr
|
||||
import yaml
|
||||
|
||||
from chromadb.auth import (
|
||||
ServerAuthenticationProvider,
|
||||
ClientAuthProvider,
|
||||
ClientAuthHeaders,
|
||||
UserIdentity,
|
||||
AuthError,
|
||||
)
|
||||
from chromadb.config import System
|
||||
from chromadb.errors import ChromaAuthError
|
||||
from chromadb.telemetry.opentelemetry import (
|
||||
OpenTelemetryGranularity,
|
||||
trace_method,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"TokenAuthenticationServerProvider",
|
||||
"TokenAuthClientProvider",
|
||||
"TokenTransportHeader",
|
||||
]
|
||||
|
||||
|
||||
class TokenTransportHeader(str, Enum):
    """
    Acceptable token transport headers.
    """

    # An enum for only two values is a little odd, and so is asking users to
    # pass X_CHROMA_TOKEN in order to configure "x-chroma-token". It does,
    # however, give us a single source of truth, so it stays.
    AUTHORIZATION = "Authorization"
    X_CHROMA_TOKEN = "X-Chroma-Token"
|
||||
|
||||
|
||||
valid_token_chars = set(string.digits + string.ascii_letters + string.punctuation)
|
||||
|
||||
|
||||
def _check_token(token: str) -> None:
|
||||
token_str = str(token)
|
||||
if not all(c in valid_token_chars for c in token_str):
|
||||
raise ValueError(
|
||||
"Invalid token. Must contain only ASCII letters, digits, and punctuation."
|
||||
)
|
||||
|
||||
|
||||
# Header names that may be configured as the token transport header.
allowed_token_headers = [
    TokenTransportHeader.AUTHORIZATION.value,
    TokenTransportHeader.X_CHROMA_TOKEN.value,
]


def _check_allowed_token_headers(token_header: str) -> None:
    """
    Validate that `token_header` is one of `allowed_token_headers`.

    Raises:
        ValueError: if the header name is not in the allowed list.
    """
    if token_header in allowed_token_headers:
        return
    raise ValueError(
        f"Invalid token transport header: {token_header}. "
        f"Must be one of {allowed_token_headers}"
    )
|
||||
|
||||
|
||||
class TokenAuthClientProvider(ClientAuthProvider):
    """
    Token-based client auth provider. The token is sent in either the
    "Authorization" or the "X-Chroma-Token" header, as selected by
    `chroma_auth_token_transport_header`. When "Authorization" is used, the
    token is sent as a bearer token.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._settings = system.settings

        # The token is mandatory and must pass character validation.
        system.settings.require("chroma_client_auth_credentials")
        self._token = SecretStr(str(system.settings.chroma_client_auth_credentials))
        _check_token(self._token.get_secret_value())

        # No configured header means the default, "Authorization".
        configured_header = system.settings.chroma_auth_token_transport_header
        if not configured_header:
            self._token_transport_header = TokenTransportHeader.AUTHORIZATION
        else:
            _check_allowed_token_headers(configured_header)
            self._token_transport_header = TokenTransportHeader(configured_header)

    @override
    def authenticate(self) -> ClientAuthHeaders:
        """Return the single header that carries the token."""
        secret = self._token.get_secret_value()
        uses_bearer = self._token_transport_header == TokenTransportHeader.AUTHORIZATION
        header_value = f"Bearer {secret}" if uses_bearer else secret
        return {self._token_transport_header.value: SecretStr(header_value)}
|
||||
|
||||
|
||||
class User(TypedDict):
    """
    Module-private user record.

    This type carries the user's sensitive tokens, so it must not be used as
    a general-purpose user representation; use UserIdentity for that.
    """

    id: str
    role: str
    tenant: Optional[str]
    databases: Optional[List[str]]
    tokens: List[str]
|
||||
|
||||
|
||||
class TokenAuthenticationServerProvider(ServerAuthenticationProvider):
    """
    Server authentication provider for token-based auth. The provider will
    - On initialization, read the credentials via `read_creds_or_creds_file()`.
      A single line is treated as one bare token. Anything else must be
      well-formed YAML with a top-level array called `users`. Each user must
      have an `id` field and a `tokens` (string array) field.
    - On each request, check the token in the header specified by
      `chroma_auth_token_transport_header`. If the configured header is
      "Authorization", the token is expected to be a bearer token.
    - If the token is valid, the server will return the user identity
      associated with the token.
    """

    def __init__(self, system: System) -> None:
        """
        Build the token -> user lookup table.

        Raises:
            ValueError: if the configured transport header is not allowed, a
                user has no `tokens`, a token contains invalid characters, or
                the same token is assigned to two different users.
        """
        super().__init__(system)
        self._settings = system.settings
        if system.settings.chroma_auth_token_transport_header:
            _check_allowed_token_headers(
                system.settings.chroma_auth_token_transport_header
            )
            self._token_transport_header = TokenTransportHeader(
                system.settings.chroma_auth_token_transport_header
            )
        else:
            self._token_transport_header = TokenTransportHeader.AUTHORIZATION

        self._token_user_mapping: Dict[str, User] = {}
        creds = self.read_creds_or_creds_file()

        # If we only get one cred, assume it's just a valid token.
        if len(creds) == 1:
            # Incoming tokens are stripped before lookup, so the stored key
            # must be stripped too. Otherwise a trailing newline from a
            # credentials file makes the token impossible to match.
            single_token = creds[0].strip()
            self._token_user_mapping[single_token] = User(
                id="anonymous",
                tenant="*",
                databases=["*"],
                role="anonymous",
                tokens=[single_token],
            )
            return

        self._users = cast(List[User], yaml.safe_load("\n".join(creds))["users"])
        for user in self._users:
            if "tokens" not in user:
                raise ValueError("User missing tokens")
            # Users without an explicit scope get access to everything.
            if "tenant" not in user:
                user["tenant"] = "*"
            if "databases" not in user:
                user["databases"] = ["*"]
            for token in user["tokens"]:
                _check_token(token)
                if (
                    token in self._token_user_mapping
                    and self._token_user_mapping[token] != user
                ):
                    # Name only the user ids: the token and the user records
                    # are secrets and must not end up in exception text.
                    raise ValueError(
                        "Token already in use: wanted to use it for "
                        f"user {user['id']} but it's already in use by "
                        f"user {self._token_user_mapping[token]['id']}"
                    )
                self._token_user_mapping[token] = user

    @trace_method(
        "TokenAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
    )
    @override
    def authenticate_or_raise(self, headers: Dict[str, str]) -> UserIdentity:
        """
        Authenticate a request from its (lower-cased) headers.

        Returns:
            The UserIdentity (id, tenant, databases) mapped to the token.

        Raises:
            ChromaAuthError: on any failure. The cause is only logged, never
                returned to the caller.
        """
        try:
            header_key = self._token_transport_header.value.lower()
            if header_key not in headers:
                raise AuthError(
                    f"Authorization header '{self._token_transport_header.value}' not found"
                )
            token = headers[header_key]
            if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
                bearer_prefix = "Bearer "
                if not token.startswith(bearer_prefix):
                    raise AuthError("Bearer not found in Authorization header")
                token = token[len(bearer_prefix):]

            token = token.strip()
            _check_token(token)

            if token not in self._token_user_mapping:
                raise AuthError("Invalid credentials: Token not found")

            matched_user = self._token_user_mapping[token]
            return UserIdentity(
                user_id=matched_user["id"],
                tenant=matched_user["tenant"],
                databases=matched_user["databases"],
            )
        except AuthError as e:
            logger.debug(
                f"TokenAuthenticationServerProvider.authenticate failed: {repr(e)}"
            )
        except Exception as e:
            # Log only the exception type and where it was raised; the
            # message could contain request data.
            tb = traceback.extract_tb(e.__traceback__)
            # Get the last call stack
            last_call_stack = tb[-1]
            line_number = last_call_stack.lineno
            filename = last_call_stack.filename
            logger.debug(
                "TokenAuthenticationServerProvider.authenticate failed: "
                f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
            )
        time.sleep(
            random.uniform(0.001, 0.005)
        )  # add some jitter to avoid timing attacks
        raise ChromaAuthError()
|
||||
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from chromadb.auth import UserIdentity
|
||||
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
|
||||
from chromadb.errors import ChromaAuthError
|
||||
|
||||
|
||||
def _singleton_tenant_database_if_applicable(
|
||||
user_identity: UserIdentity,
|
||||
overwrite_singleton_tenant_database_access_from_auth: bool,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
is False, this function always returns (None, None).
|
||||
|
||||
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
|
||||
is True, follows the following logic:
|
||||
- If the user only has access to a single tenant, this function will
|
||||
return that tenant as its first return value.
|
||||
- If the user only has access to a single database, this function will
|
||||
return that database as its second return value. If the user has
|
||||
access to multiple tenants and/or databases, including "*", this
|
||||
function will return None for the corresponding value(s).
|
||||
- If the user has access to multiple tenants and/or databases this
|
||||
function will return None for the corresponding value(s).
|
||||
"""
|
||||
if not overwrite_singleton_tenant_database_access_from_auth:
|
||||
return None, None
|
||||
tenant = None
|
||||
database = None
|
||||
user_tenant = user_identity.tenant
|
||||
user_databases = user_identity.databases
|
||||
if user_tenant and user_tenant != "*":
|
||||
tenant = user_tenant
|
||||
if user_databases:
|
||||
user_databases_set = set(user_databases)
|
||||
if len(user_databases_set) == 1 and "*" not in user_databases_set:
|
||||
database = list(user_databases_set)[0]
|
||||
return tenant, database
|
||||
|
||||
|
||||
def maybe_set_tenant_and_database(
    user_identity: UserIdentity,
    overwrite_singleton_tenant_database_access_from_auth: bool,
    user_provided_tenant: Optional[str] = None,
    user_provided_database: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Reconcile the caller-supplied tenant/database with those resolved from
    the user's identity.

    For each of tenant and database:
    - If auth resolved nothing, the caller-supplied value is returned as is.
    - If the caller supplied nothing or the default, the auth-resolved value
      replaces it.
    - If the caller supplied a non-default value that differs from the
      auth-resolved one, ChromaAuthError is raised.

    The tenant is checked before the database.
    """
    auth_tenant, auth_database = _singleton_tenant_database_if_applicable(
        user_identity=user_identity,
        overwrite_singleton_tenant_database_access_from_auth=overwrite_singleton_tenant_database_access_from_auth,
    )

    # The only error case is a caller-supplied tenant or database that
    # disagrees with what was resolved from auth. This can wrongly happen
    # when overwrite_singleton_tenant_database_access_from_auth is True but
    # no auth provider is set: tenant/database then resolve to the defaults,
    # which may not match the supplied values. So the flag must only be
    # enabled together with an auth provider.
    tenant_is_explicit = bool(
        user_provided_tenant and user_provided_tenant != DEFAULT_TENANT
    )
    database_is_explicit = bool(
        user_provided_database and user_provided_database != DEFAULT_DATABASE
    )

    if tenant_is_explicit and auth_tenant and auth_tenant != user_provided_tenant:
        raise ChromaAuthError(f"Tenant {user_provided_tenant} does not match {new_tenant_label(auth_tenant)} from the server. Are you sure the tenant is correct?") if False else ChromaAuthError(f"Tenant {user_provided_tenant} does not match {auth_tenant} from the server. Are you sure the tenant is correct?")
    if (
        database_is_explicit
        and auth_database
        and auth_database != user_provided_database
    ):
        raise ChromaAuthError(f"Database {user_provided_database} does not match {auth_database} from the server. Are you sure the database is correct?")

    # Fall back to the auth-resolved values where the caller gave nothing
    # specific.
    if auth_tenant and not tenant_is_explicit:
        user_provided_tenant = auth_tenant
    if auth_database and not database_is_explicit:
        user_provided_database = auth_database

    return user_provided_tenant, user_provided_database
|
||||
Binary file not shown.
Reference in New Issue
Block a user