Skip to content

chore: refactor shared code between services #2395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{% extends '_base.py.j2' %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{% extends '_base.py.j2' %}

{% block content %}


import os

from google.auth.exceptions import MutualTLSChannelError # type: ignore

def _read_environment_variables():
"""Returns the environment variables used by the client.

Returns:
Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE,
GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables.

Raises:
ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not
any of ["true", "false"].
google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT
is not any of ["auto", "never", "always"].
"""
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false").lower()
use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower()
universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN")
if use_client_cert not in ("true", "false"):
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
if use_mtls_endpoint not in ("auto", "never", "always"):
raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
return use_client_cert == "true", use_mtls_endpoint, universe_domain_env

{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ class {{ service.async_client_name }}:
_client: {{ service.client_name }}

# Copy defaults from the synchronous client for use here.
# Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT
_DEFAULT_ENDPOINT_TEMPLATE = {{ service.client_name }}._DEFAULT_ENDPOINT_TEMPLATE
_DEFAULT_UNIVERSE = {{ service.client_name }}._DEFAULT_UNIVERSE
Expand Down Expand Up @@ -118,40 +116,6 @@ class {{ service.async_client_name }}:

from_service_account_json = from_service_account_file

@classmethod
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
"""Return the API endpoint and client cert source for mutual TLS.

The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
client cert source is None.
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
default client cert source exists, use the default one; otherwise the client cert
source is None.

The API endpoint is determined in the following order:
(1) if `client_options.api_endpoint` if provided, use the provided one.
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
default mTLS endpoint; if the environment variable is "never", use the default API
endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
use the default API endpoint.

More details can be found at https://google.aip.dev/auth/4114.

Args:
client_options (google.api_core.client_options.ClientOptions): Custom options for the
client. Only the `api_endpoint` and `client_cert_source` properties may be used
in this method.

Returns:
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
client cert source to use.

Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""
return {{ service.client_name }}.get_mtls_endpoint_and_cert_source(client_options) # type: ignore

@property
def transport(self) -> {{ service.name }}Transport:
"""Returns the transport used by the client instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ except ImportError as e: # pragma: NO COVER
{% endif %}{# if rest_async_io_enabled #}
{% endif %}

from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif %}
{{- api.naming.versioned_module_name }}._compat import legacy_helpers

class {{ service.client_name }}Meta(type):
"""Metaclass for the {{ service.name }} client.
Expand Down Expand Up @@ -172,17 +174,14 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

# Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = {% if service.host %}"{{ service.host }}"{% else %}None{% endif %}

DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)

_DEFAULT_ENDPOINT_TEMPLATE = {% if service.host %}"{{ service.host.replace("googleapis.com", "{UNIVERSE_DOMAIN}") }}"{% else %}None{% endif %}

_DEFAULT_UNIVERSE = "googleapis.com"

DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
_DEFAULT_ENDPOINT_TEMPLATE.format(UNIVERSE_DOMAIN=_DEFAULT_UNIVERSE)
)

@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
Expand Down Expand Up @@ -260,91 +259,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

{% endfor %}{# common resources #}

@classmethod
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None):
"""Deprecated. Return the API endpoint and client cert source for mutual TLS.

The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
client cert source is None.
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
default client cert source exists, use the default one; otherwise the client cert
source is None.

The API endpoint is determined in the following order:
(1) if `client_options.api_endpoint` if provided, use the provided one.
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
default mTLS endpoint; if the environment variable is "never", use the default API
endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
use the default API endpoint.

More details can be found at https://google.aip.dev/auth/4114.

Args:
client_options (google.api_core.client_options.ClientOptions): Custom options for the
client. Only the `api_endpoint` and `client_cert_source` properties may be used
in this method.

Returns:
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
client cert source to use.

Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""

warnings.warn("get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.",
DeprecationWarning)
if client_options is None:
client_options = client_options_lib.ClientOptions()
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
if use_client_cert not in ("true", "false"):
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
if use_mtls_endpoint not in ("auto", "never", "always"):
raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")

# Figure out the client cert source to use.
client_cert_source = None
if use_client_cert == "true":
if client_options.client_cert_source:
client_cert_source = client_options.client_cert_source
elif mtls.has_default_client_cert_source():
client_cert_source = mtls.default_client_cert_source()

# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
api_endpoint = client_options.api_endpoint
elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
else:
api_endpoint = cls.DEFAULT_ENDPOINT

return api_endpoint, client_cert_source

@staticmethod
def _read_environment_variables():
"""Returns the environment variables used by the client.

Returns:
Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE,
GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables.

Raises:
ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not
any of ["true", "false"].
google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT
is not any of ["auto", "never", "always"].
"""
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false").lower()
use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower()
universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN")
if use_client_cert not in ("true", "false"):
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
if use_mtls_endpoint not in ("auto", "never", "always"):
raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
return use_client_cert == "true", use_mtls_endpoint, universe_domain_env

@staticmethod
def _get_client_cert_source(provided_cert_source, use_cert_flag):
"""Return the client cert source to be used by the client.
Expand Down Expand Up @@ -538,7 +452,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

universe_domain_opt = getattr(self._client_options, 'universe_domain', None)

self._use_client_cert, self._use_mtls_endpoint, self._universe_domain_env = {{ service.client_name }}._read_environment_variables()
self._use_client_cert, self._use_mtls_endpoint, self._universe_domain_env = legacy_helpers._read_environment_variables()
self._client_cert_source = {{ service.client_name }}._get_client_cert_source(self._client_options.client_cert_source, self._use_client_cert)
self._universe_domain = {{ service.client_name }}._get_universe_domain(universe_domain_opt, self._universe_domain_env)
self._api_endpoint = None # updated below, depending on `transport`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,6 @@ def async_anonymous_credentials():
return ga_credentials_async.AnonymousCredentials()
return ga_credentials.AnonymousCredentials()

{#TODO(https://github.com/googleapis/gapic-generator-python/issues/1894): Remove this function as part of cleanup when DEFAULT_ENDPOINT is no longer used.#}
# If default endpoint is localhost, then default mtls endpoint will be the same.
# This method modifies the default endpoint so the client can produce a different
# mtls endpoint for endpoint testing purposes.
def modify_default_endpoint(client):
return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT

# If default endpoint template is localhost, then default mtls endpoint will be the same.
# This method modifies the default endpoint template so the client can produce a different
# mtls endpoint for endpoint testing purposes.
Expand All @@ -154,36 +147,6 @@ def test__get_default_mtls_endpoint():
assert {{ service.client_name }}._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi

def test__read_environment_variables():
assert {{ service.client_name }}._read_environment_variables() == (False, "auto", None)

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
assert {{ service.client_name }}._read_environment_variables() == (True, "auto", None)

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}):
assert {{ service.client_name }}._read_environment_variables() == (False, "auto", None)

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
with pytest.raises(ValueError) as excinfo:
{{ service.client_name }}._read_environment_variables()
assert str(excinfo.value) == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
assert {{ service.client_name }}._read_environment_variables() == (False, "never", None)

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
assert {{ service.client_name }}._read_environment_variables() == (False, "always", None)

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
assert {{ service.client_name }}._read_environment_variables() == (False, "auto", None)

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
with pytest.raises(MutualTLSChannelError) as excinfo:
{{ service.client_name }}._read_environment_variables()
assert str(excinfo.value) == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"

with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}):
assert {{ service.client_name }}._read_environment_variables() == (False, "auto", "foo.com")

def test__get_client_cert_source():
mock_provided_cert_source = mock.Mock()
Expand Down Expand Up @@ -632,79 +595,6 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
)


@pytest.mark.parametrize("client_class", [
{% if 'grpc' in opts.transport %}
{{ service.client_name }}, {{ service.async_client_name }}
{% elif 'rest' in opts.transport %}
{{ service.client_name }}
{% endif %}
])
@mock.patch.object({{ service.client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.client_name }}))
{% if 'grpc' in opts.transport %}
@mock.patch.object({{ service.async_client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.async_client_name }}))
{% endif %}
def test_{{ service.client_name|snake_case }}_get_mtls_endpoint_and_cert_source(client_class):
mock_client_cert_source = mock.Mock()

# Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true".
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
mock_api_endpoint = "foo"
options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint)
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options)
assert api_endpoint == mock_api_endpoint
assert cert_source == mock_client_cert_source

# Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false".
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}):
mock_client_cert_source = mock.Mock()
mock_api_endpoint = "foo"
options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint)
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options)
assert api_endpoint == mock_api_endpoint
assert cert_source is None

# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never".
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
assert api_endpoint == client_class.DEFAULT_ENDPOINT
assert cert_source is None

# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always".
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
assert cert_source is None

# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
assert api_endpoint == client_class.DEFAULT_ENDPOINT
assert cert_source is None

# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source):
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
assert cert_source == mock_client_cert_source

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
# unsupported value.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
with pytest.raises(MutualTLSChannelError) as excinfo:
client_class.get_mtls_endpoint_and_cert_source()

assert str(excinfo.value) == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"

# Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
with pytest.raises(ValueError) as excinfo:
client_class.get_mtls_endpoint_and_cert_source()

assert str(excinfo.value) == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"

@pytest.mark.parametrize("client_class", [
{% if 'grpc' in opts.transport %}
{{ service.client_name }}, {{ service.async_client_name }}
Expand Down
Loading
Loading