增加环绕侦察场景适配
This commit is contained in:
@@ -41,13 +41,17 @@ class Python37DeprecationWarning(DeprecationWarning): # pragma: NO COVER
|
||||
pass
|
||||
|
||||
|
||||
# Checks if the current runtime is Python 3.7.
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 7: # pragma: NO COVER
|
||||
message = (
|
||||
"After January 1, 2024, new releases of this library will drop support "
|
||||
"for Python 3.7."
|
||||
)
|
||||
warnings.warn(message, Python37DeprecationWarning)
|
||||
# Raise warnings for deprecated versions
|
||||
eol_message = """
|
||||
You are using a Python version {} past its end of life. Google will update
|
||||
google-auth with critical bug fixes on a best-effort basis, but not
|
||||
with any other fixes or features. Please upgrade your Python version,
|
||||
and then update google-auth.
|
||||
"""
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 8: # pragma: NO COVER
|
||||
warnings.warn(eol_message.format("3.8"), FutureWarning)
|
||||
elif sys.version_info.major == 3 and sys.version_info.minor == 9: # pragma: NO COVER
|
||||
warnings.warn(eol_message.format("3.9"), FutureWarning)
|
||||
|
||||
# Set default logging handler to avoid "No handler found" warnings.
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
@@ -83,7 +83,7 @@ def get_application_default_credentials_path():
|
||||
|
||||
|
||||
def _run_subprocess_ignore_stderr(command):
|
||||
""" Return subprocess.check_output with the given command and ignores stderr."""
|
||||
"""Return subprocess.check_output with the given command and ignores stderr."""
|
||||
with open(os.devnull, "w") as devnull:
|
||||
output = subprocess.check_output(command, stderr=devnull)
|
||||
return output
|
||||
|
||||
@@ -16,17 +16,23 @@
|
||||
|
||||
Implements application default credentials and project ID detection.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Sequence, TYPE_CHECKING
|
||||
import warnings
|
||||
|
||||
from google.auth import environment_vars
|
||||
from google.auth import exceptions
|
||||
import google.auth.transport._http_client
|
||||
|
||||
if TYPE_CHECKING: # pragma: NO COVER
|
||||
from google.auth.credentials import Credentials # noqa: F401
|
||||
from google.auth.transport import Request # noqa: F401
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Valid types accepted for file-based credentials.
|
||||
@@ -538,8 +544,10 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
|
||||
from google.auth import impersonated_credentials
|
||||
|
||||
try:
|
||||
credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info(
|
||||
info, scopes=scopes
|
||||
credentials = (
|
||||
impersonated_credentials.Credentials.from_impersonated_service_account_info(
|
||||
info, scopes=scopes
|
||||
)
|
||||
)
|
||||
except ValueError as caught_exc:
|
||||
msg = "Failed to load impersonated service account credentials from {}".format(
|
||||
@@ -554,8 +562,8 @@ def _get_gdch_service_account_credentials(filename, info):
|
||||
from google.oauth2 import gdch_credentials
|
||||
|
||||
try:
|
||||
credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info(
|
||||
info
|
||||
credentials = (
|
||||
gdch_credentials.ServiceAccountCredentials.from_service_account_info(info)
|
||||
)
|
||||
except ValueError as caught_exc:
|
||||
msg = "Failed to load GDCH service account credentials from {}".format(filename)
|
||||
@@ -586,7 +594,12 @@ def _apply_quota_project_id(credentials, quota_project_id):
|
||||
return credentials
|
||||
|
||||
|
||||
def default(scopes=None, request=None, quota_project_id=None, default_scopes=None):
|
||||
def default(
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
request: Optional["google.auth.transport.Request"] = None,
|
||||
quota_project_id: Optional[str] = None,
|
||||
default_scopes: Optional[Sequence[str]] = None,
|
||||
) -> tuple["google.auth.credentials.Credentials", Optional[str]]:
|
||||
"""Gets the default credentials for the current environment.
|
||||
|
||||
`Application Default Credentials`_ provides an easy way to obtain
|
||||
|
||||
@@ -334,7 +334,8 @@ def is_python_3():
|
||||
Returns:
|
||||
bool: True if the Python interpreter is Python 3 and False otherwise.
|
||||
"""
|
||||
return sys.version_info > (3, 0)
|
||||
|
||||
return sys.version_info > (3, 0) # pragma: NO COVER
|
||||
|
||||
|
||||
def _hash_sensitive_info(data: Union[dict, list]) -> Union[dict, list, str]:
|
||||
|
||||
@@ -127,7 +127,7 @@ _CLASS_CONVERSION_MAP = {
|
||||
oauth2client.contrib.gce.AppAssertionCredentials: _convert_gce_app_assertion_credentials,
|
||||
}
|
||||
|
||||
if _HAS_APPENGINE:
|
||||
if _HAS_APPENGINE: # pragma: no cover
|
||||
_CLASS_CONVERSION_MAP[
|
||||
oauth2client.contrib.appengine.AppAssertionCredentials
|
||||
] = _convert_appengine_app_assertion_credentials
|
||||
|
||||
@@ -61,8 +61,8 @@ class RefreshThreadManager:
|
||||
|
||||
def clear_error(self):
|
||||
"""
|
||||
Removes any errors that were stored from previous background refreshes.
|
||||
"""
|
||||
Removes any errors that were stored from previous background refreshes.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._worker:
|
||||
self._worker._error_info = None
|
||||
|
||||
@@ -56,7 +56,7 @@ def from_dict(data, require=None, use_rsa_signer=True):
|
||||
if use_rsa_signer:
|
||||
signer = crypt.RSASigner.from_service_account_info(data)
|
||||
else:
|
||||
signer = crypt.ES256Signer.from_service_account_info(data)
|
||||
signer = crypt.EsSigner.from_service_account_info(data)
|
||||
|
||||
return signer
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class Request(transport.Request):
|
||||
# Custom aiohttp Session Example:
|
||||
session = session=aiohttp.ClientSession(auto_decompress=False)
|
||||
request = google.auth.aio.transport.aiohttp.Request(session=session)
|
||||
auth_sesion = google.auth.aio.transport.sessions.AsyncAuthorizedSession(auth_request=request)
|
||||
auth_session = google.auth.aio.transport.sessions.AsyncAuthorizedSession(auth_request=request)
|
||||
|
||||
Args:
|
||||
session (aiohttp.ClientSession): An instance :class:`aiohttp.ClientSession` used
|
||||
|
||||
@@ -159,7 +159,7 @@ class AsyncAuthorizedSession:
|
||||
at ``max_allowed_time``. It might take longer, for example, if
|
||||
an underlying request takes a lot of time, but the request
|
||||
itself does not timeout, e.g. if a large file is being
|
||||
transmitted. The timout error will be raised after such
|
||||
transmitted. The timeout error will be raised after such
|
||||
request completes.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -348,10 +348,10 @@ def _generate_authentication_header_map(
|
||||
class AwsSecurityCredentials:
|
||||
"""A class that models AWS security credentials with an optional session token.
|
||||
|
||||
Attributes:
|
||||
access_key_id (str): The AWS security credentials access key id.
|
||||
secret_access_key (str): The AWS security credentials secret access key.
|
||||
session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials.
|
||||
Attributes:
|
||||
access_key_id (str): The AWS security credentials access key id.
|
||||
secret_access_key (str): The AWS security credentials secret access key.
|
||||
session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials.
|
||||
"""
|
||||
|
||||
access_key_id: str
|
||||
@@ -420,7 +420,6 @@ class _DefaultAwsSecurityCredentialsSupplier(AwsSecurityCredentialsSupplier):
|
||||
|
||||
@_helpers.copy_docstring(AwsSecurityCredentialsSupplier)
|
||||
def get_aws_security_credentials(self, context, request):
|
||||
|
||||
# Check environment variables for permanent credentials first.
|
||||
# https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html
|
||||
env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID)
|
||||
@@ -688,8 +687,8 @@ class Credentials(external_account.Credentials):
|
||||
)
|
||||
else:
|
||||
environment_id = credential_source.get("environment_id") or ""
|
||||
self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier(
|
||||
credential_source
|
||||
self._aws_security_credentials_supplier = (
|
||||
_DefaultAwsSecurityCredentialsSupplier(credential_source)
|
||||
)
|
||||
self._cred_verification_url = credential_source.get(
|
||||
"regional_cred_verification_url"
|
||||
@@ -759,8 +758,10 @@ class Credentials(external_account.Credentials):
|
||||
|
||||
# Retrieve the AWS security credentials needed to generate the signed
|
||||
# request.
|
||||
aws_security_credentials = self._aws_security_credentials_supplier.get_aws_security_credentials(
|
||||
self._supplier_context, request
|
||||
aws_security_credentials = (
|
||||
self._aws_security_credentials_supplier.get_aws_security_credentials(
|
||||
self._supplier_context, request
|
||||
)
|
||||
)
|
||||
# Generate the signed request to AWS STS GetCallerIdentity API.
|
||||
# Use the required regional endpoint. Otherwise, the request will fail.
|
||||
|
||||
@@ -24,15 +24,23 @@ import logging
|
||||
import os
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from google.auth import _helpers
|
||||
from google.auth import environment_vars
|
||||
from google.auth import exceptions
|
||||
from google.auth import metrics
|
||||
from google.auth import transport
|
||||
from google.auth._exponential_backoff import ExponentialBackoff
|
||||
from google.auth.compute_engine import _mtls
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_GCE_DEFAULT_MDS_IP = "169.254.169.254"
|
||||
_GCE_DEFAULT_HOST = "metadata.google.internal"
|
||||
_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP]
|
||||
|
||||
# Environment variable GCE_METADATA_HOST is originally named
|
||||
# GCE_METADATA_ROOT. For compatibility reasons, here it checks
|
||||
# the new variable first; if not set, the system falls back
|
||||
@@ -40,15 +48,48 @@ _LOGGER = logging.getLogger(__name__)
|
||||
_GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None)
|
||||
if not _GCE_METADATA_HOST:
|
||||
_GCE_METADATA_HOST = os.getenv(
|
||||
environment_vars.GCE_METADATA_ROOT, "metadata.google.internal"
|
||||
environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST
|
||||
)
|
||||
_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST)
|
||||
|
||||
# This is used to ping the metadata server, it avoids the cost of a DNS
|
||||
# lookup.
|
||||
_METADATA_IP_ROOT = "http://{}".format(
|
||||
os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254")
|
||||
)
|
||||
|
||||
def _validate_gce_mds_configured_environment():
|
||||
"""Validates the GCE metadata server environment configuration for mTLS.
|
||||
|
||||
mTLS is only supported when connecting to the default metadata server hosts.
|
||||
If we are in strict mode (which requires mTLS), ensure that the metadata host
|
||||
has not been overridden to a custom value (which means mTLS will fail).
|
||||
|
||||
Raises:
|
||||
google.auth.exceptions.MutualTLSChannelError: if the environment
|
||||
configuration is invalid for mTLS.
|
||||
"""
|
||||
mode = _mtls._parse_mds_mode()
|
||||
if mode == _mtls.MdsMtlsMode.STRICT:
|
||||
# mTLS is only supported when connecting to the default metadata host.
|
||||
# Raise an exception if we are in strict mode (which requires mTLS)
|
||||
# but the metadata host has been overridden to a custom MDS. (which means mTLS will fail)
|
||||
if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS:
|
||||
raise exceptions.MutualTLSChannelError(
|
||||
"Mutual TLS is required, but the metadata host has been overridden. "
|
||||
"mTLS is only supported when connecting to the default metadata host."
|
||||
)
|
||||
|
||||
|
||||
def _get_metadata_root(use_mtls: bool):
|
||||
"""Returns the metadata server root URL."""
|
||||
|
||||
scheme = "https" if use_mtls else "http"
|
||||
return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST)
|
||||
|
||||
|
||||
def _get_metadata_ip_root(use_mtls: bool):
|
||||
"""Returns the metadata server IP root URL."""
|
||||
scheme = "https" if use_mtls else "http"
|
||||
return "{}://{}".format(
|
||||
scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP)
|
||||
)
|
||||
|
||||
|
||||
_METADATA_FLAVOR_HEADER = "metadata-flavor"
|
||||
_METADATA_FLAVOR_VALUE = "Google"
|
||||
_METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE}
|
||||
@@ -102,6 +143,33 @@ def detect_gce_residency_linux():
|
||||
return content.startswith(_GOOGLE)
|
||||
|
||||
|
||||
def _prepare_request_for_mds(request, use_mtls=False) -> None:
|
||||
"""Prepares a request for the metadata server.
|
||||
|
||||
This will check if mTLS should be used and mount the mTLS adapter if needed.
|
||||
|
||||
Args:
|
||||
request (google.auth.transport.Request): A callable used to make
|
||||
HTTP requests.
|
||||
use_mtls (bool): Whether to use mTLS for the request.
|
||||
|
||||
Returns:
|
||||
google.auth.transport.Request: A request object to use.
|
||||
If mTLS is enabled, the request will have the mTLS adapter mounted.
|
||||
Otherwise, the original request will be returned unchanged.
|
||||
"""
|
||||
# Only modify the request if mTLS is enabled.
|
||||
if use_mtls:
|
||||
# Ensure the request has a session to mount the adapter to.
|
||||
if not request.session:
|
||||
request.session = requests.Session()
|
||||
|
||||
adapter = _mtls.MdsMtlsAdapter()
|
||||
# Mount the adapter for all default GCE metadata hosts.
|
||||
for host in _GCE_DEFAULT_MDS_HOSTS:
|
||||
request.session.mount(f"https://{host}/", adapter)
|
||||
|
||||
|
||||
def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
|
||||
"""Checks to see if the metadata server is available.
|
||||
|
||||
@@ -115,6 +183,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
|
||||
Returns:
|
||||
bool: True if the metadata server is reachable, False otherwise.
|
||||
"""
|
||||
use_mtls = _mtls.should_use_mds_mtls()
|
||||
_prepare_request_for_mds(request, use_mtls=use_mtls)
|
||||
# NOTE: The explicit ``timeout`` is a workaround. The underlying
|
||||
# issue is that resolving an unknown host on some networks will take
|
||||
# 20-30 seconds; making this timeout short fixes the issue, but
|
||||
@@ -129,7 +199,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
|
||||
for attempt in backoff:
|
||||
try:
|
||||
response = request(
|
||||
url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout
|
||||
url=_get_metadata_ip_root(use_mtls),
|
||||
method="GET",
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER)
|
||||
@@ -153,7 +226,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
|
||||
def get(
|
||||
request,
|
||||
path,
|
||||
root=_METADATA_ROOT,
|
||||
root=None,
|
||||
params=None,
|
||||
recursive=False,
|
||||
retry_count=5,
|
||||
@@ -168,7 +241,8 @@ def get(
|
||||
HTTP requests.
|
||||
path (str): The resource to retrieve. For example,
|
||||
``'instance/service-accounts/default'``.
|
||||
root (str): The full path to the metadata server root.
|
||||
root (Optional[str]): The full path to the metadata server root. If not
|
||||
provided, the default root will be used.
|
||||
params (Optional[Mapping[str, str]]): A mapping of query parameter
|
||||
keys to values.
|
||||
recursive (bool): Whether to do a recursive query of metadata. See
|
||||
@@ -189,7 +263,24 @@ def get(
|
||||
Raises:
|
||||
google.auth.exceptions.TransportError: if an error occurred while
|
||||
retrieving metadata.
|
||||
google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment
|
||||
configuration is invalid for mTLS (for example, the metadata host
|
||||
has been overridden in strict mTLS mode).
|
||||
|
||||
"""
|
||||
use_mtls = _mtls.should_use_mds_mtls()
|
||||
# Prepare the request object for mTLS if needed.
|
||||
# This will create a new request object with the mTLS session.
|
||||
_prepare_request_for_mds(request, use_mtls=use_mtls)
|
||||
|
||||
if root is None:
|
||||
root = _get_metadata_root(use_mtls)
|
||||
|
||||
# mTLS is only supported when connecting to the default metadata host.
|
||||
# If we are in strict mode (which requires mTLS), ensure that the metadata host
|
||||
# has not been overridden to a non-default host value (which means mTLS will fail).
|
||||
_validate_gce_mds_configured_environment()
|
||||
|
||||
base_url = urljoin(root, path)
|
||||
query_params = {} if params is None else params
|
||||
|
||||
@@ -203,7 +294,7 @@ def get(
|
||||
url = _helpers.update_query(base_url, query_params)
|
||||
|
||||
backoff = ExponentialBackoff(total_attempts=retry_count)
|
||||
failure_reason = None
|
||||
last_exception = None
|
||||
for attempt in backoff:
|
||||
try:
|
||||
response = request(
|
||||
@@ -217,13 +308,10 @@ def get(
|
||||
retry_count,
|
||||
response.status,
|
||||
)
|
||||
failure_reason = (
|
||||
response.data.decode("utf-8")
|
||||
if hasattr(response.data, "decode")
|
||||
else response.data
|
||||
)
|
||||
last_exception = None
|
||||
continue
|
||||
else:
|
||||
last_exception = None
|
||||
break
|
||||
|
||||
except exceptions.TransportError as e:
|
||||
@@ -234,14 +322,27 @@ def get(
|
||||
retry_count,
|
||||
e,
|
||||
)
|
||||
failure_reason = e
|
||||
last_exception = e
|
||||
else:
|
||||
raise exceptions.TransportError(
|
||||
"Failed to retrieve {} from the Google Compute Engine "
|
||||
"metadata service. Compute Engine Metadata server unavailable due to {}".format(
|
||||
url, failure_reason
|
||||
if last_exception:
|
||||
raise exceptions.TransportError(
|
||||
"Failed to retrieve {} from the Google Compute Engine "
|
||||
"metadata service. Compute Engine Metadata server unavailable. "
|
||||
"Last exception: {}".format(url, last_exception)
|
||||
) from last_exception
|
||||
else:
|
||||
error_details = (
|
||||
response.data.decode("utf-8")
|
||||
if hasattr(response.data, "decode")
|
||||
else response.data
|
||||
)
|
||||
raise exceptions.TransportError(
|
||||
"Failed to retrieve {} from the Google Compute Engine "
|
||||
"metadata service. Compute Engine Metadata server unavailable. "
|
||||
"Response status: {}\nResponse details:\n{}".format(
|
||||
url, response.status, error_details
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
content = _helpers.from_bytes(response.data)
|
||||
|
||||
@@ -360,12 +461,19 @@ def get_service_account_token(request, service_account="default", scopes=None):
|
||||
google.auth.exceptions.TransportError: if an error occurred while
|
||||
retrieving metadata.
|
||||
"""
|
||||
from google.auth import _agent_identity_utils
|
||||
|
||||
params = {}
|
||||
if scopes:
|
||||
if not isinstance(scopes, str):
|
||||
scopes = ",".join(scopes)
|
||||
params = {"scopes": scopes}
|
||||
else:
|
||||
params = None
|
||||
params["scopes"] = scopes
|
||||
|
||||
cert = _agent_identity_utils.get_and_parse_agent_identity_certificate()
|
||||
if cert:
|
||||
if _agent_identity_utils.should_request_bound_token(cert):
|
||||
fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(cert)
|
||||
params["bindCertificateFingerprint"] = fingerprint
|
||||
|
||||
metrics_header = {
|
||||
metrics.API_CLIENT_HEADER: metrics.token_request_access_token_mds()
|
||||
|
||||
@@ -123,7 +123,7 @@ class Credentials(
|
||||
def _metric_header_for_usage(self):
|
||||
return metrics.CRED_TYPE_SA_MDS
|
||||
|
||||
def _refresh_token(self, request):
|
||||
def _perform_refresh_token(self, request):
|
||||
"""Refresh the access token and scopes.
|
||||
|
||||
Args:
|
||||
@@ -135,9 +135,9 @@ class Credentials(
|
||||
service can't be reached if if the instance has not
|
||||
credentials.
|
||||
"""
|
||||
scopes = self._scopes if self._scopes is not None else self._default_scopes
|
||||
try:
|
||||
self._retrieve_info(request)
|
||||
scopes = self._scopes if self._scopes is not None else self._default_scopes
|
||||
# Always fetch token with default service account email.
|
||||
self.token, self.expiry = _metadata.get_service_account_token(
|
||||
request, service_account="default", scopes=scopes
|
||||
@@ -399,7 +399,6 @@ class IDTokenCredentials(
|
||||
|
||||
@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
|
||||
def with_quota_project(self, quota_project_id):
|
||||
|
||||
# since the signer is already instantiated,
|
||||
# the request is not needed
|
||||
if self._use_metadata_identity_endpoint:
|
||||
@@ -423,7 +422,6 @@ class IDTokenCredentials(
|
||||
|
||||
@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
|
||||
def with_token_uri(self, token_uri):
|
||||
|
||||
# since the signer is already instantiated,
|
||||
# the request is not needed
|
||||
if self._use_metadata_identity_endpoint:
|
||||
|
||||
@@ -292,7 +292,7 @@ class CredentialsWithTrustBoundary(Credentials):
|
||||
"""Abstract base for credentials supporting ``with_trust_boundary`` factory"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _refresh_token(self, request):
|
||||
def _perform_refresh_token(self, request):
|
||||
"""Refreshes the access token.
|
||||
|
||||
Args:
|
||||
@@ -303,7 +303,7 @@ class CredentialsWithTrustBoundary(Credentials):
|
||||
google.auth.exceptions.RefreshError: If the credentials could
|
||||
not be refreshed.
|
||||
"""
|
||||
raise NotImplementedError("_refresh_token must be implemented")
|
||||
raise NotImplementedError("_perform_refresh_token must be implemented")
|
||||
|
||||
def with_trust_boundary(self, trust_boundary):
|
||||
"""Returns a copy of these credentials with a modified trust boundary.
|
||||
@@ -362,7 +362,7 @@ class CredentialsWithTrustBoundary(Credentials):
|
||||
This method calls the subclass's token refresh logic and then
|
||||
refreshes the trust boundary if applicable.
|
||||
"""
|
||||
self._refresh_token(request)
|
||||
self._perform_refresh_token(request)
|
||||
self._refresh_trust_boundary(request)
|
||||
|
||||
def _refresh_trust_boundary(self, request):
|
||||
|
||||
@@ -40,13 +40,19 @@ version is at least 1.4.0.
|
||||
from google.auth.crypt import base
|
||||
from google.auth.crypt import rsa
|
||||
|
||||
# google.auth.crypt.es depends on the crytpography module which may not be
|
||||
# successfully imported depending on the system.
|
||||
try:
|
||||
from google.auth.crypt import es
|
||||
from google.auth.crypt import es256
|
||||
except ImportError: # pragma: NO COVER
|
||||
es = None # type: ignore
|
||||
es256 = None # type: ignore
|
||||
|
||||
if es256 is not None: # pragma: NO COVER
|
||||
if es is not None and es256 is not None: # pragma: NO COVER
|
||||
__all__ = [
|
||||
"EsSigner",
|
||||
"EsVerifier",
|
||||
"ES256Signer",
|
||||
"ES256Verifier",
|
||||
"RSASigner",
|
||||
@@ -54,6 +60,11 @@ if es256 is not None: # pragma: NO COVER
|
||||
"Signer",
|
||||
"Verifier",
|
||||
]
|
||||
|
||||
EsSigner = es.EsSigner
|
||||
EsVerifier = es.EsVerifier
|
||||
ES256Signer = es256.ES256Signer
|
||||
ES256Verifier = es256.ES256Verifier
|
||||
else: # pragma: NO COVER
|
||||
__all__ = ["RSASigner", "RSAVerifier", "Signer", "Verifier"]
|
||||
|
||||
@@ -65,10 +76,6 @@ Verifier = base.Verifier
|
||||
RSASigner = rsa.RSASigner
|
||||
RSAVerifier = rsa.RSAVerifier
|
||||
|
||||
if es256 is not None: # pragma: NO COVER
|
||||
ES256Signer = es256.ES256Signer
|
||||
ES256Verifier = es256.ES256Verifier
|
||||
|
||||
|
||||
def verify_signature(message, signature, certs, verifier_cls=rsa.RSAVerifier):
|
||||
"""Verify an RSA or ECDSA cryptographic signature.
|
||||
|
||||
@@ -15,93 +15,22 @@
|
||||
"""ECDSA (ES256) verifier and signer that use the ``cryptography`` library.
|
||||
"""
|
||||
|
||||
from cryptography import utils # type: ignore
|
||||
import cryptography.exceptions
|
||||
from cryptography.hazmat import backends
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature
|
||||
import cryptography.x509
|
||||
|
||||
from google.auth import _helpers
|
||||
from google.auth.crypt import base
|
||||
from google.auth.crypt.es import EsSigner
|
||||
from google.auth.crypt.es import EsVerifier
|
||||
|
||||
|
||||
_CERTIFICATE_MARKER = b"-----BEGIN CERTIFICATE-----"
|
||||
_BACKEND = backends.default_backend()
|
||||
_PADDING = padding.PKCS1v15()
|
||||
|
||||
|
||||
class ES256Verifier(base.Verifier):
|
||||
class ES256Verifier(EsVerifier):
|
||||
"""Verifies ECDSA cryptographic signatures using public keys.
|
||||
|
||||
Args:
|
||||
public_key (
|
||||
cryptography.hazmat.primitives.asymmetric.ec.ECDSAPublicKey):
|
||||
The public key used to verify signatures.
|
||||
public_key (cryptography.hazmat.primitives.asymmetric.ec.ECDSAPublicKey): The public key used to verify
|
||||
signatures.
|
||||
"""
|
||||
|
||||
def __init__(self, public_key):
|
||||
self._pubkey = public_key
|
||||
|
||||
@_helpers.copy_docstring(base.Verifier)
|
||||
def verify(self, message, signature):
|
||||
# First convert (r||s) raw signature to ASN1 encoded signature.
|
||||
sig_bytes = _helpers.to_bytes(signature)
|
||||
if len(sig_bytes) != 64:
|
||||
return False
|
||||
r = (
|
||||
int.from_bytes(sig_bytes[:32], byteorder="big")
|
||||
if _helpers.is_python_3()
|
||||
else utils.int_from_bytes(sig_bytes[:32], byteorder="big")
|
||||
)
|
||||
s = (
|
||||
int.from_bytes(sig_bytes[32:], byteorder="big")
|
||||
if _helpers.is_python_3()
|
||||
else utils.int_from_bytes(sig_bytes[32:], byteorder="big")
|
||||
)
|
||||
asn1_sig = encode_dss_signature(r, s)
|
||||
|
||||
message = _helpers.to_bytes(message)
|
||||
try:
|
||||
self._pubkey.verify(asn1_sig, message, ec.ECDSA(hashes.SHA256()))
|
||||
return True
|
||||
except (ValueError, cryptography.exceptions.InvalidSignature):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, public_key):
|
||||
"""Construct an Verifier instance from a public key or public
|
||||
certificate string.
|
||||
|
||||
Args:
|
||||
public_key (Union[str, bytes]): The public key in PEM format or the
|
||||
x509 public key certificate.
|
||||
|
||||
Returns:
|
||||
Verifier: The constructed verifier.
|
||||
|
||||
Raises:
|
||||
ValueError: If the public key can't be parsed.
|
||||
"""
|
||||
public_key_data = _helpers.to_bytes(public_key)
|
||||
|
||||
if _CERTIFICATE_MARKER in public_key_data:
|
||||
cert = cryptography.x509.load_pem_x509_certificate(
|
||||
public_key_data, _BACKEND
|
||||
)
|
||||
pubkey = cert.public_key()
|
||||
|
||||
else:
|
||||
pubkey = serialization.load_pem_public_key(public_key_data, _BACKEND)
|
||||
|
||||
return cls(pubkey)
|
||||
pass
|
||||
|
||||
|
||||
class ES256Signer(base.Signer, base.FromServiceAccountMixin):
|
||||
class ES256Signer(EsSigner):
|
||||
"""Signs messages with an ECDSA private key.
|
||||
|
||||
Args:
|
||||
@@ -113,63 +42,4 @@ class ES256Signer(base.Signer, base.FromServiceAccountMixin):
|
||||
public key or certificate.
|
||||
"""
|
||||
|
||||
def __init__(self, private_key, key_id=None):
|
||||
self._key = private_key
|
||||
self._key_id = key_id
|
||||
|
||||
@property # type: ignore
|
||||
@_helpers.copy_docstring(base.Signer)
|
||||
def key_id(self):
|
||||
return self._key_id
|
||||
|
||||
@_helpers.copy_docstring(base.Signer)
|
||||
def sign(self, message):
|
||||
message = _helpers.to_bytes(message)
|
||||
asn1_signature = self._key.sign(message, ec.ECDSA(hashes.SHA256()))
|
||||
|
||||
# Convert ASN1 encoded signature to (r||s) raw signature.
|
||||
(r, s) = decode_dss_signature(asn1_signature)
|
||||
return (
|
||||
(r.to_bytes(32, byteorder="big") + s.to_bytes(32, byteorder="big"))
|
||||
if _helpers.is_python_3()
|
||||
else (utils.int_to_bytes(r, 32) + utils.int_to_bytes(s, 32))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, key, key_id=None):
|
||||
"""Construct a RSASigner from a private key in PEM format.
|
||||
|
||||
Args:
|
||||
key (Union[bytes, str]): Private key in PEM format.
|
||||
key_id (str): An optional key id used to identify the private key.
|
||||
|
||||
Returns:
|
||||
google.auth.crypt._cryptography_rsa.RSASigner: The
|
||||
constructed signer.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``key`` is not ``bytes`` or ``str`` (unicode).
|
||||
UnicodeDecodeError: If ``key`` is ``bytes`` but cannot be decoded
|
||||
into a UTF-8 ``str``.
|
||||
ValueError: If ``cryptography`` "Could not deserialize key data."
|
||||
"""
|
||||
key = _helpers.to_bytes(key)
|
||||
private_key = serialization.load_pem_private_key(
|
||||
key, password=None, backend=_BACKEND
|
||||
)
|
||||
return cls(private_key, key_id=key_id)
|
||||
|
||||
def __getstate__(self):
|
||||
"""Pickle helper that serializes the _key attribute."""
|
||||
state = self.__dict__.copy()
|
||||
state["_key"] = self._key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Pickle helper that deserializes the _key attribute."""
|
||||
state["_key"] = serialization.load_pem_private_key(state["_key"], None)
|
||||
self.__dict__.update(state)
|
||||
pass
|
||||
|
||||
@@ -60,6 +60,12 @@ GCE_METADATA_IP = "GCE_METADATA_IP"
|
||||
"""Environment variable providing an alternate ip:port to be used for ip-only
|
||||
GCE metadata requests."""
|
||||
|
||||
GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE"
|
||||
"""Environment variable controlling the mTLS behavior for GCE metadata requests.
|
||||
|
||||
Can be one of "strict", "none", or "default".
|
||||
"""
|
||||
|
||||
GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
|
||||
"""Environment variable controlling whether to use client certificate or not.
|
||||
|
||||
@@ -86,3 +92,12 @@ AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION"
|
||||
GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED = "GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED"
|
||||
"""Environment variable controlling whether to enable trust boundary feature.
|
||||
The default value is false. Users have to explicitly set this value to true."""
|
||||
|
||||
GOOGLE_API_CERTIFICATE_CONFIG = "GOOGLE_API_CERTIFICATE_CONFIG"
|
||||
"""Environment variable defining the location of Google API certificate config
|
||||
file."""
|
||||
|
||||
GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES = (
|
||||
"GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"
|
||||
)
|
||||
"""Environment variable to prevent agent token sharing for GCP services."""
|
||||
|
||||
@@ -98,7 +98,8 @@ class Credentials(
|
||||
is used.
|
||||
When the credential configuration is accepted from an
|
||||
untrusted source, you should validate it before using.
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details."""
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -419,7 +420,10 @@ class Credentials(
|
||||
source credentials and the impersonated credentials. For non-impersonated
|
||||
credentials, it will refresh the access token and the trust boundary.
|
||||
"""
|
||||
self._refresh_token(request)
|
||||
self._perform_refresh_token(request)
|
||||
self._handle_trust_boundary(request)
|
||||
|
||||
def _handle_trust_boundary(self, request):
|
||||
# If we are impersonating, the trust boundary is handled by the
|
||||
# impersonated credentials object. We need to get it from there.
|
||||
if self._service_account_impersonation_url:
|
||||
@@ -428,7 +432,7 @@ class Credentials(
|
||||
# Otherwise, refresh the trust boundary for the external account.
|
||||
self._refresh_trust_boundary(request)
|
||||
|
||||
def _refresh_token(self, request):
|
||||
def _perform_refresh_token(self, request, cert_fingerprint=None):
|
||||
scopes = self._scopes if self._scopes is not None else self._default_scopes
|
||||
|
||||
# Inject client certificate into request.
|
||||
@@ -446,11 +450,15 @@ class Credentials(
|
||||
self.expiry = self._impersonated_credentials.expiry
|
||||
else:
|
||||
now = _helpers.utcnow()
|
||||
additional_options = None
|
||||
additional_options = {}
|
||||
# Do not pass workforce_pool_user_project when client authentication
|
||||
# is used. The client ID is sufficient for determining the user project.
|
||||
if self._workforce_pool_user_project and not self._client_id:
|
||||
additional_options = {"userProject": self._workforce_pool_user_project}
|
||||
additional_options["userProject"] = self._workforce_pool_user_project
|
||||
|
||||
if cert_fingerprint:
|
||||
additional_options["bindCertFingerprint"] = cert_fingerprint
|
||||
|
||||
additional_headers = {
|
||||
metrics.API_CLIENT_HEADER: metrics.byoid_metrics_header(
|
||||
self._metrics_options
|
||||
@@ -464,7 +472,7 @@ class Credentials(
|
||||
audience=self._audience,
|
||||
scopes=scopes,
|
||||
requested_token_type=_STS_REQUESTED_TOKEN_TYPE,
|
||||
additional_options=additional_options,
|
||||
additional_options=additional_options if additional_options else None,
|
||||
additional_headers=additional_headers,
|
||||
)
|
||||
self.token = response_data.get("access_token")
|
||||
|
||||
@@ -70,7 +70,8 @@ class Credentials(
|
||||
is used.
|
||||
When the credential configuration is accepted from an
|
||||
untrusted source, you should validate it before using.
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details."""
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -123,7 +124,7 @@ class Credentials(
|
||||
self.token = token
|
||||
self.expiry = expiry
|
||||
self._audience = audience
|
||||
self._refresh_token_val = refresh_token
|
||||
self._refresh_token = refresh_token
|
||||
self._token_url = token_url
|
||||
self._token_info_url = token_info_url
|
||||
self._client_id = client_id
|
||||
@@ -170,7 +171,7 @@ class Credentials(
|
||||
def constructor_args(self):
|
||||
return {
|
||||
"audience": self._audience,
|
||||
"refresh_token": self._refresh_token_val,
|
||||
"refresh_token": self._refresh_token,
|
||||
"token_url": self._token_url,
|
||||
"token_info_url": self._token_info_url,
|
||||
"client_id": self._client_id,
|
||||
@@ -214,7 +215,7 @@ class Credentials(
|
||||
@property
|
||||
def refresh_token(self):
|
||||
"""Optional[str]: The OAuth 2.0 refresh token."""
|
||||
return self._refresh_token_val
|
||||
return self._refresh_token
|
||||
|
||||
@property
|
||||
def token_url(self):
|
||||
@@ -240,7 +241,7 @@ class Credentials(
|
||||
def can_refresh(self):
|
||||
return all(
|
||||
(
|
||||
self._refresh_token_val,
|
||||
self._refresh_token,
|
||||
self._token_url,
|
||||
self._client_id,
|
||||
self._client_secret,
|
||||
@@ -278,7 +279,7 @@ class Credentials(
|
||||
strip = strip if strip else []
|
||||
return json.dumps({k: v for (k, v) in self.info.items() if k not in strip})
|
||||
|
||||
def _refresh_token(self, request):
|
||||
def _perform_refresh_token(self, request):
|
||||
"""Refreshes the access token.
|
||||
|
||||
Args:
|
||||
@@ -297,7 +298,7 @@ class Credentials(
|
||||
)
|
||||
|
||||
now = _helpers.utcnow()
|
||||
response_data = self._sts_client.refresh_token(request, self._refresh_token_val)
|
||||
response_data = self._sts_client.refresh_token(request, self._refresh_token)
|
||||
|
||||
self.token = response_data.get("access_token")
|
||||
|
||||
@@ -305,7 +306,7 @@ class Credentials(
|
||||
self.expiry = now + lifetime
|
||||
|
||||
if "refresh_token" in response_data:
|
||||
self._refresh_token_val = response_data["refresh_token"]
|
||||
self._refresh_token = response_data["refresh_token"]
|
||||
|
||||
def _build_trust_boundary_lookup_url(self):
|
||||
"""Builds and returns the URL for the trust boundary lookup API."""
|
||||
@@ -321,6 +322,30 @@ class Credentials(
|
||||
universe_domain=self._universe_domain, pool_id=pool_id
|
||||
)
|
||||
|
||||
def revoke(self, request):
|
||||
"""Revokes the refresh token.
|
||||
|
||||
Args:
|
||||
request (google.auth.transport.Request): The object used to make
|
||||
HTTP requests.
|
||||
|
||||
Raises:
|
||||
google.auth.exceptions.OAuthError: If the token could not be
|
||||
revoked.
|
||||
"""
|
||||
if not self._revoke_url or not self._refresh_token:
|
||||
raise exceptions.OAuthError(
|
||||
"The credentials do not contain the necessary fields to "
|
||||
"revoke the refresh token. You must specify revoke_url and "
|
||||
"refresh_token."
|
||||
)
|
||||
|
||||
self._sts_client.revoke_token(
|
||||
request, self._refresh_token, "refresh_token", self._revoke_url
|
||||
)
|
||||
self.token = None
|
||||
self._refresh_token = None
|
||||
|
||||
@_helpers.copy_docstring(credentials.Credentials)
|
||||
def get_cred_info(self):
|
||||
if self._cred_file_path:
|
||||
|
||||
@@ -83,9 +83,9 @@ class SubjectTokenSupplier(metaclass=abc.ABCMeta):
|
||||
|
||||
class _TokenContent(NamedTuple):
|
||||
"""Models the token content response from file and url internal suppliers.
|
||||
Attributes:
|
||||
content (str): The string content of the file or URL response.
|
||||
location (str): The location the content was retrieved from. This will either be a file location or a URL.
|
||||
Attributes:
|
||||
content (str): The string content of the file or URL response.
|
||||
location (str): The location the content was retrieved from. This will either be a file location or a URL.
|
||||
"""
|
||||
|
||||
content: str
|
||||
@@ -93,7 +93,7 @@ class _TokenContent(NamedTuple):
|
||||
|
||||
|
||||
class _FileSupplier(SubjectTokenSupplier):
|
||||
""" Internal implementation of subject token supplier which supports reading a subject token from a file."""
|
||||
"""Internal implementation of subject token supplier which supports reading a subject token from a file."""
|
||||
|
||||
def __init__(self, path, format_type, subject_token_field_name):
|
||||
self._path = path
|
||||
@@ -114,7 +114,7 @@ class _FileSupplier(SubjectTokenSupplier):
|
||||
|
||||
|
||||
class _UrlSupplier(SubjectTokenSupplier):
|
||||
""" Internal implementation of subject token supplier which supports retrieving a subject token by calling a URL endpoint."""
|
||||
"""Internal implementation of subject token supplier which supports retrieving a subject token by calling a URL endpoint."""
|
||||
|
||||
def __init__(self, url, format_type, subject_token_field_name, headers):
|
||||
self._url = url
|
||||
@@ -261,7 +261,8 @@ class Credentials(external_account.Credentials):
|
||||
is used.
|
||||
When the credential configuration is accepted from an
|
||||
untrusted source, you should validate it before using.
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details."""
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -550,3 +551,25 @@ class Credentials(external_account.Credentials):
|
||||
credentials.
|
||||
"""
|
||||
return super(Credentials, cls).from_file(filename, **kwargs)
|
||||
|
||||
def refresh(self, request):
|
||||
"""Refreshes the access token.
|
||||
|
||||
Args:
|
||||
request (google.auth.transport.Request): The object used to make
|
||||
HTTP requests.
|
||||
"""
|
||||
from google.auth import _agent_identity_utils
|
||||
|
||||
cert_fingerprint = None
|
||||
# Check if the credential is X.509 based.
|
||||
if self._credential_source_certificate is not None:
|
||||
cert_bytes = self._get_cert_bytes()
|
||||
cert = _agent_identity_utils.parse_certificate(cert_bytes)
|
||||
if _agent_identity_utils.should_request_bound_token(cert):
|
||||
cert_fingerprint = (
|
||||
_agent_identity_utils.calculate_certificate_fingerprint(cert)
|
||||
)
|
||||
|
||||
self._perform_refresh_token(request, cert_fingerprint=cert_fingerprint)
|
||||
self._handle_trust_boundary(request)
|
||||
|
||||
@@ -272,7 +272,7 @@ class Credentials(
|
||||
def _metric_header_for_usage(self):
|
||||
return metrics.CRED_TYPE_SA_IMPERSONATE
|
||||
|
||||
def _refresh_token(self, request):
|
||||
def _perform_refresh_token(self, request):
|
||||
"""Updates credentials with a new access_token representing
|
||||
the impersonated account.
|
||||
|
||||
@@ -286,7 +286,7 @@ class Credentials(
|
||||
self._source_credentials.token_state == credentials.TokenState.STALE
|
||||
or self._source_credentials.token_state == credentials.TokenState.INVALID
|
||||
):
|
||||
self._source_credentials._refresh_token(request)
|
||||
self._source_credentials.refresh(request)
|
||||
|
||||
body = {
|
||||
"delegates": self._delegates,
|
||||
@@ -640,7 +640,14 @@ class IDTokenCredentials(credentials.CredentialsWithQuotaProject):
|
||||
"Error getting ID token: {}".format(response.json())
|
||||
)
|
||||
|
||||
id_token = response.json()["token"]
|
||||
try:
|
||||
id_token = response.json()["token"]
|
||||
except (KeyError, ValueError) as caught_exc:
|
||||
new_exc = exceptions.RefreshError(
|
||||
"No ID token in response.", response.json()
|
||||
)
|
||||
raise new_exc from caught_exc
|
||||
|
||||
self.token = id_token
|
||||
self.expiry = datetime.utcfromtimestamp(
|
||||
jwt.decode(id_token, verify=False)["exp"]
|
||||
|
||||
@@ -50,8 +50,7 @@ import datetime
|
||||
import json
|
||||
import urllib
|
||||
|
||||
import cachetools
|
||||
|
||||
from google.auth import _cache
|
||||
from google.auth import _helpers
|
||||
from google.auth import _service_account_info
|
||||
from google.auth import crypt
|
||||
@@ -59,17 +58,18 @@ from google.auth import exceptions
|
||||
import google.auth.credentials
|
||||
|
||||
try:
|
||||
from google.auth.crypt import es256
|
||||
from google.auth.crypt import es
|
||||
except ImportError: # pragma: NO COVER
|
||||
es256 = None # type: ignore
|
||||
es = None # type: ignore
|
||||
|
||||
_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
|
||||
_DEFAULT_MAX_CACHE_SIZE = 10
|
||||
_ALGORITHM_TO_VERIFIER_CLASS = {"RS256": crypt.RSAVerifier}
|
||||
_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256"])
|
||||
_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256", "ES384"])
|
||||
|
||||
if es256 is not None: # pragma: NO COVER
|
||||
_ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es256.ES256Verifier # type: ignore
|
||||
if es is not None: # pragma: NO COVER
|
||||
_ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es.EsVerifier # type: ignore
|
||||
_ALGORITHM_TO_VERIFIER_CLASS["ES384"] = es.EsVerifier # type: ignore
|
||||
|
||||
|
||||
def encode(signer, payload, header=None, key_id=None):
|
||||
@@ -95,8 +95,8 @@ def encode(signer, payload, header=None, key_id=None):
|
||||
header.update({"typ": "JWT"})
|
||||
|
||||
if "alg" not in header:
|
||||
if es256 is not None and isinstance(signer, es256.ES256Signer):
|
||||
header.update({"alg": "ES256"})
|
||||
if es is not None and isinstance(signer, es.EsSigner):
|
||||
header.update({"alg": signer.algorithm})
|
||||
else:
|
||||
header.update({"alg": "RS256"})
|
||||
|
||||
@@ -585,7 +585,7 @@ class Credentials(
|
||||
|
||||
@property # type: ignore
|
||||
def additional_claims(self):
|
||||
""" Additional claims the JWT object was created with."""
|
||||
"""Additional claims the JWT object was created with."""
|
||||
return self._additional_claims
|
||||
|
||||
|
||||
@@ -629,7 +629,7 @@ class OnDemandCredentials(
|
||||
token_lifetime (int): The amount of time in seconds for
|
||||
which the token is valid. Defaults to 1 hour.
|
||||
max_cache_size (int): The maximum number of JWT tokens to keep in
|
||||
cache. Tokens are cached using :class:`cachetools.LRUCache`.
|
||||
cache. Tokens are cached using :class:`google.auth._cache.LRUCache`.
|
||||
quota_project_id (Optional[str]): The project ID used for quota
|
||||
and billing.
|
||||
|
||||
@@ -645,7 +645,7 @@ class OnDemandCredentials(
|
||||
additional_claims = {}
|
||||
|
||||
self._additional_claims = additional_claims
|
||||
self._cache = cachetools.LRUCache(maxsize=max_cache_size)
|
||||
self._cache = _cache.LRUCache(maxsize=max_cache_size)
|
||||
|
||||
@classmethod
|
||||
def _from_signer_and_info(cls, signer, info, **kwargs):
|
||||
@@ -759,7 +759,6 @@ class OnDemandCredentials(
|
||||
|
||||
@_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject)
|
||||
def with_quota_project(self, quota_project_id):
|
||||
|
||||
return self.__class__(
|
||||
self._signer,
|
||||
issuer=self._issuer,
|
||||
|
||||
@@ -48,6 +48,7 @@ def python_and_auth_lib_version():
|
||||
|
||||
# Token request metric header values
|
||||
|
||||
|
||||
# x-goog-api-client header value for access token request via metadata server.
|
||||
# Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds"
|
||||
def token_request_access_token_mds():
|
||||
@@ -108,6 +109,7 @@ def token_request_user():
|
||||
|
||||
# Miscellenous metrics
|
||||
|
||||
|
||||
# x-goog-api-client header value for metadata server ping.
|
||||
# Example: "gl-python/3.7 auth/1.1 auth-request-type/mds"
|
||||
def mds_ping():
|
||||
|
||||
@@ -37,6 +37,7 @@ except ImportError: # pragma: NO COVER
|
||||
from collections import Mapping # type: ignore
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
@@ -65,7 +66,8 @@ class Credentials(external_account.Credentials):
|
||||
is used.
|
||||
When the credential configuration is accepted from an
|
||||
untrusted source, you should validate it before using.
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details."""
|
||||
Refer https://cloud.google.com/docs/authentication/external/externally-sourced-credentials for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -128,17 +130,17 @@ class Credentials(external_account.Credentials):
|
||||
raise exceptions.MalformedError(
|
||||
"Missing credential_source. An 'executable' must be provided."
|
||||
)
|
||||
self._credential_source_executable_command = self._credential_source_executable.get(
|
||||
"command"
|
||||
self._credential_source_executable_command = (
|
||||
self._credential_source_executable.get("command")
|
||||
)
|
||||
self._credential_source_executable_timeout_millis = self._credential_source_executable.get(
|
||||
"timeout_millis"
|
||||
self._credential_source_executable_timeout_millis = (
|
||||
self._credential_source_executable.get("timeout_millis")
|
||||
)
|
||||
self._credential_source_executable_interactive_timeout_millis = self._credential_source_executable.get(
|
||||
"interactive_timeout_millis"
|
||||
self._credential_source_executable_interactive_timeout_millis = (
|
||||
self._credential_source_executable.get("interactive_timeout_millis")
|
||||
)
|
||||
self._credential_source_executable_output_file = self._credential_source_executable.get(
|
||||
"output_file"
|
||||
self._credential_source_executable_output_file = (
|
||||
self._credential_source_executable.get("output_file")
|
||||
)
|
||||
|
||||
# Dummy value. This variable is only used via injection, not exposed to ctor
|
||||
@@ -199,11 +201,6 @@ class Credentials(external_account.Credentials):
|
||||
else:
|
||||
return subject_token
|
||||
|
||||
if not _helpers.is_python_3():
|
||||
raise exceptions.RefreshError(
|
||||
"Pluggable auth is only supported for python 3.7+"
|
||||
)
|
||||
|
||||
# Inject env vars.
|
||||
env = os.environ.copy()
|
||||
self._inject_env_variables(env)
|
||||
@@ -220,7 +217,7 @@ class Credentials(external_account.Credentials):
|
||||
exe_stderr = sys.stdout if self.interactive else subprocess.STDOUT
|
||||
|
||||
result = subprocess.run(
|
||||
self._credential_source_executable_command.split(),
|
||||
shlex.split(self._credential_source_executable_command),
|
||||
timeout=exe_timeout,
|
||||
stdin=exe_stdin,
|
||||
stdout=exe_stdout,
|
||||
@@ -261,11 +258,6 @@ class Credentials(external_account.Credentials):
|
||||
)
|
||||
self._validate_running_mode()
|
||||
|
||||
if not _helpers.is_python_3():
|
||||
raise exceptions.RefreshError(
|
||||
"Pluggable auth is only supported for python 3.7+"
|
||||
)
|
||||
|
||||
# Inject variables
|
||||
env = os.environ.copy()
|
||||
self._inject_env_variables(env)
|
||||
@@ -273,7 +265,7 @@ class Credentials(external_account.Credentials):
|
||||
|
||||
# Run executable
|
||||
result = subprocess.run(
|
||||
self._credential_source_executable_command.split(),
|
||||
shlex.split(self._credential_source_executable_command),
|
||||
timeout=self._credential_source_executable_interactive_timeout_millis
|
||||
/ 1000,
|
||||
stdout=subprocess.PIPE,
|
||||
|
||||
@@ -276,7 +276,6 @@ class AuthorizedSession(aiohttp.ClientSession):
|
||||
auto_decompress=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
"""Implementation of Authorized Session aiohttp request.
|
||||
|
||||
Args:
|
||||
@@ -302,7 +301,7 @@ class AuthorizedSession(aiohttp.ClientSession):
|
||||
at ``max_allowed_time``. It might take longer, for example, if
|
||||
an underlying request takes a lot of time, but the request
|
||||
itself does not timeout, e.g. if a large file is being
|
||||
transmitted. The timout error will be raised after such
|
||||
transmitted. The timeout error will be raised after such
|
||||
request completes.
|
||||
"""
|
||||
# Headers come in as bytes which isn't expected behavior, the resumable
|
||||
@@ -358,7 +357,6 @@ class AuthorizedSession(aiohttp.ClientSession):
|
||||
response.status in self._refresh_status_codes
|
||||
and _credential_refresh_attempt < self._max_refresh_attempts
|
||||
):
|
||||
|
||||
requests._LOGGER.info(
|
||||
"Refreshing credentials due to a %s response. Attempt %s/%s.",
|
||||
response.status,
|
||||
|
||||
@@ -100,7 +100,6 @@ class Request(transport.Request):
|
||||
connection = http_client.HTTPConnection(parts.netloc, timeout=timeout)
|
||||
|
||||
try:
|
||||
|
||||
_helpers.request_log(_LOGGER, method, url, body, headers)
|
||||
connection.request(method, path, body=body, headers=headers, **kwargs)
|
||||
response = connection.getresponse()
|
||||
|
||||
@@ -20,11 +20,12 @@ from os import environ, getenv, path
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from google.auth import _agent_identity_utils
|
||||
from google.auth import environment_vars
|
||||
from google.auth import exceptions
|
||||
|
||||
CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
|
||||
CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
|
||||
_CERTIFICATE_CONFIGURATION_ENV = "GOOGLE_API_CERTIFICATE_CONFIG"
|
||||
_CERT_PROVIDER_COMMAND = "cert_provider_command"
|
||||
_CERT_REGEX = re.compile(
|
||||
b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL
|
||||
@@ -47,6 +48,20 @@ _PASSPHRASE_REGEX = re.compile(
|
||||
b"-----BEGIN PASSPHRASE-----(.+)-----END PASSPHRASE-----", re.DOTALL
|
||||
)
|
||||
|
||||
# Temporary patch to accomodate incorrect cert config in Cloud Run prod environment.
|
||||
_WELL_KNOWN_CLOUD_RUN_CERT_PATH = (
|
||||
"/var/run/secrets/workload-spiffe-credentials/certificates.pem"
|
||||
)
|
||||
_WELL_KNOWN_CLOUD_RUN_KEY_PATH = (
|
||||
"/var/run/secrets/workload-spiffe-credentials/private_key.pem"
|
||||
)
|
||||
_INCORRECT_CLOUD_RUN_CERT_PATH = (
|
||||
"/var/lib/volumes/certificate/workload-certificates/certificates.pem"
|
||||
)
|
||||
_INCORRECT_CLOUD_RUN_KEY_PATH = (
|
||||
"/var/lib/volumes/certificate/workload-certificates/private_key.pem"
|
||||
)
|
||||
|
||||
|
||||
def _check_config_path(config_path):
|
||||
"""Checks for config file path. If it exists, returns the absolute path with user expansion;
|
||||
@@ -132,7 +147,7 @@ def _get_cert_config_path(certificate_config_path=None):
|
||||
"""
|
||||
|
||||
if certificate_config_path is None:
|
||||
env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None)
|
||||
env_path = environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, None)
|
||||
if env_path is not None and env_path != "":
|
||||
certificate_config_path = env_path
|
||||
else:
|
||||
@@ -183,6 +198,25 @@ def _get_workload_cert_and_key_paths(config_path):
|
||||
)
|
||||
key_path = workload["key_path"]
|
||||
|
||||
# == BEGIN Temporary Cloud Run PATCH ==
|
||||
# See https://github.com/googleapis/google-auth-library-python/issues/1881
|
||||
if (cert_path == _INCORRECT_CLOUD_RUN_CERT_PATH) and (
|
||||
key_path == _INCORRECT_CLOUD_RUN_KEY_PATH
|
||||
):
|
||||
if not path.exists(cert_path) and not path.exists(key_path):
|
||||
_LOGGER.debug(
|
||||
"Applying Cloud Run certificate path patch. "
|
||||
"Configured paths not found: %s, %s. "
|
||||
"Using well-known paths: %s, %s",
|
||||
cert_path,
|
||||
key_path,
|
||||
_WELL_KNOWN_CLOUD_RUN_CERT_PATH,
|
||||
_WELL_KNOWN_CLOUD_RUN_KEY_PATH,
|
||||
)
|
||||
cert_path = _WELL_KNOWN_CLOUD_RUN_CERT_PATH
|
||||
key_path = _WELL_KNOWN_CLOUD_RUN_KEY_PATH
|
||||
# == END Temporary Cloud Run PATCH ==
|
||||
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
@@ -279,7 +313,7 @@ def _run_cert_provider_command(command, expect_encrypted_key=False):
|
||||
def get_client_ssl_credentials(
|
||||
generate_encrypted_key=False,
|
||||
context_aware_metadata_path=CONTEXT_AWARE_METADATA_PATH,
|
||||
certificate_config_path=CERTIFICATE_CONFIGURATION_DEFAULT_PATH,
|
||||
certificate_config_path=None,
|
||||
):
|
||||
"""Returns the client side certificate, private key and passphrase.
|
||||
|
||||
@@ -306,13 +340,10 @@ def get_client_ssl_credentials(
|
||||
the cert, key and passphrase.
|
||||
"""
|
||||
|
||||
# 1. Check for certificate config json.
|
||||
cert_config_path = _check_config_path(certificate_config_path)
|
||||
if cert_config_path:
|
||||
# Attempt to retrieve X.509 Workload cert and key.
|
||||
cert, key = _get_workload_cert_and_key(cert_config_path)
|
||||
if cert and key:
|
||||
return True, cert, key, None
|
||||
# 1. Attempt to retrieve X.509 Workload cert and key.
|
||||
cert, key = _get_workload_cert_and_key(certificate_config_path)
|
||||
if cert and key:
|
||||
return True, cert, key, None
|
||||
|
||||
# 2. Check for context aware metadata json
|
||||
metadata_path = _check_config_path(context_aware_metadata_path)
|
||||
@@ -444,3 +475,29 @@ def check_use_client_cert():
|
||||
) as e:
|
||||
_LOGGER.debug("error decoding certificate: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
def check_parameters_for_unauthorized_response(cached_cert):
|
||||
"""Returns the cached and current cert fingerprint for reconfiguring mTLS.
|
||||
|
||||
Args:
|
||||
cached_cert(bytes): The cached client certificate.
|
||||
|
||||
Returns:
|
||||
bytes: The client callback cert bytes.
|
||||
bytes: The client callback key bytes.
|
||||
str: The base64-encoded SHA256 cached fingerprint.
|
||||
str: The base64-encoded SHA256 current cert fingerprint.
|
||||
"""
|
||||
call_cert_bytes, call_key_bytes = _agent_identity_utils.call_client_cert_callback()
|
||||
cert_obj = _agent_identity_utils.parse_certificate(call_cert_bytes)
|
||||
current_cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(
|
||||
cert_obj
|
||||
)
|
||||
if cached_cert:
|
||||
cached_fingerprint = _agent_identity_utils.get_cached_cert_fingerprint(
|
||||
cached_cert
|
||||
)
|
||||
else:
|
||||
cached_fingerprint = current_cert_fingerprint
|
||||
return call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint
|
||||
|
||||
@@ -146,7 +146,7 @@ def secure_authorized_channel(
|
||||
regular_ssl_credentials = grpc.ssl_channel_credentials()
|
||||
|
||||
channel = google.auth.transport.grpc.secure_authorized_channel(
|
||||
credentials, regular_endpoint, request,
|
||||
credentials, request, regular_endpoint,
|
||||
ssl_credentials=regular_ssl_credentials)
|
||||
|
||||
Option 2: create a mutual TLS channel by calling a callback which returns
|
||||
@@ -162,7 +162,7 @@ def secure_authorized_channel(
|
||||
|
||||
try:
|
||||
channel = google.auth.transport.grpc.secure_authorized_channel(
|
||||
credentials, mtls_endpoint, request,
|
||||
credentials, request, mtls_endpoint,
|
||||
client_cert_callback=my_client_cert_callback)
|
||||
except MyClientCertFailureException:
|
||||
# handle the exception
|
||||
@@ -186,7 +186,7 @@ def secure_authorized_channel(
|
||||
else:
|
||||
endpoint_to_use = regular_endpoint
|
||||
channel = google.auth.transport.grpc.secure_authorized_channel(
|
||||
credentials, endpoint_to_use, request,
|
||||
credentials, request, endpoint_to_use,
|
||||
ssl_credentials=default_ssl_credentials)
|
||||
|
||||
Option 4: not setting ssl_credentials and client_cert_callback. For devices
|
||||
@@ -200,14 +200,14 @@ def secure_authorized_channel(
|
||||
certificate and key::
|
||||
|
||||
channel = google.auth.transport.grpc.secure_authorized_channel(
|
||||
credentials, regular_endpoint, request)
|
||||
credentials, request, regular_endpoint)
|
||||
|
||||
The following code uses mtls_endpoint, if the created channle is regular,
|
||||
and API mtls_endpoint is confgured to require client SSL credentials, API
|
||||
calls using this channel will be rejected::
|
||||
|
||||
channel = google.auth.transport.grpc.secure_authorized_channel(
|
||||
credentials, mtls_endpoint, request)
|
||||
credentials, request, mtls_endpoint)
|
||||
|
||||
Args:
|
||||
credentials (google.auth.credentials.Credentials): The credentials to
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
"""Utilites for mutual TLS."""
|
||||
|
||||
from os import getenv
|
||||
|
||||
from google.auth import exceptions
|
||||
from google.auth.transport import _mtls_helper
|
||||
|
||||
@@ -36,6 +38,12 @@ def has_default_client_cert_source():
|
||||
is not None
|
||||
):
|
||||
return True
|
||||
cert_config_path = getenv("GOOGLE_API_CERTIFICATE_CONFIG")
|
||||
if (
|
||||
cert_config_path
|
||||
and _mtls_helper._check_config_path(cert_config_path) is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,11 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import functools
|
||||
import http.client as http_client
|
||||
import logging
|
||||
import numbers
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import requests
|
||||
@@ -36,6 +38,7 @@ from requests.packages.urllib3.util.ssl_ import ( # type: ignore
|
||||
from google.auth import _helpers
|
||||
from google.auth import exceptions
|
||||
from google.auth import transport
|
||||
from google.auth.transport import _mtls_helper
|
||||
import google.auth.transport._mtls_helper
|
||||
from google.oauth2 import service_account
|
||||
|
||||
@@ -135,7 +138,7 @@ class Request(transport.Request):
|
||||
.. automethod:: __call__
|
||||
"""
|
||||
|
||||
def __init__(self, session=None):
|
||||
def __init__(self, session: Optional[requests.Session] = None) -> None:
|
||||
if not session:
|
||||
session = requests.Session()
|
||||
|
||||
@@ -463,6 +466,7 @@ class AuthorizedSession(requests.Session):
|
||||
|
||||
if self._is_mtls:
|
||||
mtls_adapter = _MutualTlsAdapter(cert, key)
|
||||
self._cached_cert = cert
|
||||
self.mount("https://", mtls_adapter)
|
||||
except (
|
||||
exceptions.ClientCertError,
|
||||
@@ -500,8 +504,12 @@ class AuthorizedSession(requests.Session):
|
||||
at ``max_allowed_time``. It might take longer, for example, if
|
||||
an underlying request takes a lot of time, but the request
|
||||
itself does not timeout, e.g. if a large file is being
|
||||
transmitted. The timout error will be raised after such
|
||||
transmitted. The timeout error will be raised after such
|
||||
request completes.
|
||||
Raises:
|
||||
google.auth.exceptions.MutualTLSChannelError: If mutual TLS
|
||||
channel creation fails for any reason.
|
||||
ValueError: If the client certificate is invalid.
|
||||
"""
|
||||
# pylint: disable=arguments-differ
|
||||
# Requests has a ton of arguments to request, but only two
|
||||
@@ -551,7 +559,36 @@ class AuthorizedSession(requests.Session):
|
||||
response.status_code in self._refresh_status_codes
|
||||
and _credential_refresh_attempt < self._max_refresh_attempts
|
||||
):
|
||||
|
||||
# Handle unauthorized permission error(401 status code)
|
||||
if response.status_code == http_client.UNAUTHORIZED:
|
||||
if self.is_mtls:
|
||||
(
|
||||
call_cert_bytes,
|
||||
call_key_bytes,
|
||||
cached_fingerprint,
|
||||
current_cert_fingerprint,
|
||||
) = _mtls_helper.check_parameters_for_unauthorized_response(
|
||||
self._cached_cert
|
||||
)
|
||||
if cached_fingerprint != current_cert_fingerprint:
|
||||
try:
|
||||
_LOGGER.info(
|
||||
"Client certificate has changed, reconfiguring mTLS "
|
||||
"channel."
|
||||
)
|
||||
self.configure_mtls_channel(
|
||||
lambda: (call_cert_bytes, call_key_bytes)
|
||||
)
|
||||
except Exception as e:
|
||||
_LOGGER.error("Failed to reconfigure mTLS channel: %s", e)
|
||||
raise exceptions.MutualTLSChannelError(
|
||||
"Failed to reconfigure mTLS channel"
|
||||
) from e
|
||||
else:
|
||||
_LOGGER.info(
|
||||
"Skipping reconfiguration of mTLS channel because the client"
|
||||
" certificate has not changed."
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Refreshing credentials due to a %s response. Attempt %s/%s.",
|
||||
response.status_code,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import http.client as http_client
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
@@ -52,6 +53,7 @@ except ImportError as caught_exc: # pragma: NO COVER
|
||||
from google.auth import _helpers
|
||||
from google.auth import exceptions
|
||||
from google.auth import transport
|
||||
from google.auth.transport import _mtls_helper
|
||||
from google.oauth2 import service_account
|
||||
|
||||
if version.parse(urllib3.__version__) >= version.parse("2.0.0"): # pragma: NO COVER
|
||||
@@ -104,9 +106,8 @@ class Request(transport.Request):
|
||||
credentials.refresh(request)
|
||||
|
||||
Args:
|
||||
http (urllib3.request.RequestMethods): An instance of any urllib3
|
||||
class that implements :class:`~urllib3.request.RequestMethods`,
|
||||
usually :class:`urllib3.PoolManager`.
|
||||
http (urllib3.PoolManager): An instance of a urllib3 class that implements
|
||||
the request interface (e.g. :class:`urllib3.PoolManager`).
|
||||
|
||||
.. automethod:: __call__
|
||||
"""
|
||||
@@ -207,7 +208,7 @@ class AuthorizedHttp(RequestMethods): # type: ignore
|
||||
response = authed_http.request(
|
||||
'GET', 'https://www.googleapis.com/storage/v1/b')
|
||||
|
||||
This class implements :class:`urllib3.request.RequestMethods` and can be
|
||||
This class implements the urllib3 request interface and can be
|
||||
used just like any other :class:`urllib3.PoolManager`.
|
||||
|
||||
The underlying :meth:`urlopen` implementation handles adding the
|
||||
@@ -299,6 +300,7 @@ class AuthorizedHttp(RequestMethods): # type: ignore
|
||||
# Request instance used by internal methods (for example,
|
||||
# credentials.refresh).
|
||||
self._request = Request(self.http)
|
||||
self._is_mtls = False
|
||||
|
||||
# https://google.aip.dev/auth/4111
|
||||
# Attempt to use self-signed JWTs when a service account is used.
|
||||
@@ -335,7 +337,10 @@ class AuthorizedHttp(RequestMethods): # type: ignore
|
||||
"""
|
||||
use_client_cert = transport._mtls_helper.check_use_client_cert()
|
||||
if not use_client_cert:
|
||||
self._is_mtls = False
|
||||
return False
|
||||
else:
|
||||
self._is_mtls = True
|
||||
try:
|
||||
import OpenSSL
|
||||
except ImportError as caught_exc:
|
||||
@@ -349,6 +354,7 @@ class AuthorizedHttp(RequestMethods): # type: ignore
|
||||
|
||||
if found_cert_key:
|
||||
self.http = _make_mutual_tls_http(cert, key)
|
||||
self._cached_cert = cert
|
||||
else:
|
||||
self.http = _make_default_http()
|
||||
except (
|
||||
@@ -381,6 +387,11 @@ class AuthorizedHttp(RequestMethods): # type: ignore
|
||||
if headers is None:
|
||||
headers = self.headers
|
||||
|
||||
use_mtls = False
|
||||
if self._is_mtls:
|
||||
MTLS_URL_PREFIXES = ["mtls.googleapis.com", "mtls.sandbox.googleapis.com"]
|
||||
use_mtls = any([prefix in url for prefix in MTLS_URL_PREFIXES])
|
||||
|
||||
# Make a copy of the headers. They will be modified by the credentials
|
||||
# and we want to pass the original headers if we recurse.
|
||||
request_headers = headers.copy()
|
||||
@@ -402,6 +413,39 @@ class AuthorizedHttp(RequestMethods): # type: ignore
|
||||
response.status in self._refresh_status_codes
|
||||
and _credential_refresh_attempt < self._max_refresh_attempts
|
||||
):
|
||||
if response.status == http_client.UNAUTHORIZED:
|
||||
if use_mtls:
|
||||
(
|
||||
call_cert_bytes,
|
||||
call_key_bytes,
|
||||
cached_fingerprint,
|
||||
current_cert_fingerprint,
|
||||
) = _mtls_helper.check_parameters_for_unauthorized_response(
|
||||
self._cached_cert
|
||||
)
|
||||
if cached_fingerprint != current_cert_fingerprint:
|
||||
try:
|
||||
_LOGGER.info(
|
||||
"Client certificate has changed, reconfiguring mTLS "
|
||||
"channel."
|
||||
)
|
||||
self.configure_mtls_channel(
|
||||
client_cert_callback=lambda: (
|
||||
call_cert_bytes,
|
||||
call_key_bytes,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_LOGGER.error("Failed to reconfigure mTLS channel: %s", e)
|
||||
raise exceptions.MutualTLSChannelError(
|
||||
"Failed to reconfigure mTLS channel"
|
||||
) from e
|
||||
|
||||
else:
|
||||
_LOGGER.info(
|
||||
"Skipping reconfiguration of mTLS channel because the "
|
||||
"client certificate has not changed."
|
||||
)
|
||||
|
||||
_LOGGER.info(
|
||||
"Refreshing credentials due to a %s response. Attempt %s/%s.",
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "2.43.0"
|
||||
__version__ = "2.47.0"
|
||||
|
||||
Reference in New Issue
Block a user