Skip to content

chore: refactor wrap method helper into a macro #2111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,53 @@ def _get_response(
raise core_exceptions.from_http_response(response)

{% endmacro %}


{% macro prep_wrapped_messages_async_method(service) %}
def _prep_wrapped_messages(self, client_info):
""" Precompute the wrapped methods, overriding the base class method to use async wrappers."""
self._wrapped_methods = {
{% for method in service.methods.values() %}
self.{{ method.transport_safe_name|snake_case }}: self._wrap_method(
self.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.AsyncRetry(
{% if method.retry.initial_backoff %}
initial={{ method.retry.initial_backoff }},
{% endif %}
{% if method.retry.max_backoff %}
maximum={{ method.retry.max_backoff }},
{% endif %}
{% if method.retry.backoff_multiplier %}
multiplier={{ method.retry.backoff_multiplier }},
{% endif %}
predicate=retries.if_exception_type(
{% for ex in method.retry.retryable_exceptions|sort(attribute='__name__') %}
core_exceptions.{{ ex.__name__ }},
{% endfor %}
),
deadline={{ method.timeout }},
),
{% endif %}
default_timeout={{ method.timeout }},
client_info=client_info,
),
{% endfor %}{# service.methods.values() #}
}
{% endmacro %}

{# TODO: This helper logic to check whether `kind` needs to be configured in wrap_method
can be removed once we require the correct version of the google-api-core dependency to
avoid having a gRPC code path in an async REST call.
See related issue: https://github.com/googleapis/python-api-core/issues/661.
In the meantime, if an older version of the dependency is installed (which has a wrap_method with
no kind parameter), then an async gRPC call will work correctly and async REST transport
will not be available as a transport.
See related issue: https://github.com/googleapis/gapic-generator-python/issues/2119. #}
{% macro wrap_async_method_macro() %}
def _wrap_method(self, func, *args, **kwargs):
{# TODO: Remove `pragma: NO COVER` once https://github.com/googleapis/python-api-core/pull/688 is merged. #}
if self._wrap_with_kind: # pragma: NO COVER
kwargs["kind"] = self.kind
return gapic_v1.method_async.wrap_method(func, *args, **kwargs)
{% endmacro %}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{% extends '_base.py.j2' %}

{% block content %}
{% import "%namespace/%name_%version/%sub/services/%service/_shared_macros.j2" as shared_macros %}

import inspect
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -241,6 +243,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
)

# Wrap messages. This must be done after self._grpc_channel exists
self._wrap_with_kind = "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters
self._prep_wrapped_messages(client_info)

@property
Expand Down Expand Up @@ -385,39 +388,16 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
return self._stubs["test_iam_permissions"]
{% endif %}

def _prep_wrapped_messages(self, client_info):
""" Precompute the wrapped methods, overriding the base class method to use async wrappers."""
self._wrapped_methods = {
{% for method in service.methods.values() %}
self.{{ method.transport_safe_name|snake_case }}: gapic_v1.method_async.wrap_method(
self.{{ method.transport_safe_name|snake_case }},
{% if method.retry %}
default_retry=retries.AsyncRetry(
{% if method.retry.initial_backoff %}
initial={{ method.retry.initial_backoff }},
{% endif %}
{% if method.retry.max_backoff %}
maximum={{ method.retry.max_backoff }},
{% endif %}
{% if method.retry.backoff_multiplier %}
multiplier={{ method.retry.backoff_multiplier }},
{% endif %}
predicate=retries.if_exception_type(
{% for ex in method.retry.retryable_exceptions|sort(attribute='__name__') %}
core_exceptions.{{ ex.__name__ }},
{% endfor %}
),
deadline={{ method.timeout }},
),
{% endif %}
default_timeout={{ method.timeout }},
client_info=client_info,
),
{% endfor %} {# service.methods.values() #}
}
{{ shared_macros.prep_wrapped_messages_async_method(service)|indent(4) }}

{{ shared_macros.wrap_async_method_macro()|indent(4) }}

def close(self):
return self.grpc_channel.close()

@property
def kind(self) -> str:
return "grpc_asyncio"

{% include '%namespace/%name_%version/%sub/services/%service/transports/_mixins.py.j2' %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1041,19 +1041,9 @@ def test_transport_adc(transport_class):
transport_class()
adc.assert_called_once()

@pytest.mark.parametrize("transport_name", [
{% if "grpc" in opts.transport %}
"grpc",
{% endif %}
{% if "rest" in opts.transport %}
"rest",
{% endif %}
])
def test_transport_kind(transport_name):
transport = {{ service.client_name }}.get_transport_class(transport_name)(
credentials=ga_credentials.AnonymousCredentials(),
)
assert transport.kind == transport_name
{{ test_macros.transport_kind_test(service, opts) }}

{{ test_macros.transport_kind_test(service, opts, is_async=True) }}

{% if 'grpc' in opts.transport %}
def test_transport_grpc_default():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1876,3 +1876,34 @@ def test_{{ method_name }}_empty_call():
{% endwith %}{# method_settings #}
assert args[0] == {{ method.input.ident }}()
{% endmacro %}


{% macro transport_kind_test(service, opts, is_async=False) %}
@pytest.mark.parametrize("transport_name", [
{% if is_async %}
{% if "grpc" in opts.transport %}
"grpc_asyncio",
{% endif %}
{% else %}{# if not is_async #}
{% if "grpc" in opts.transport%}
"grpc",
{% endif %}
{% if "rest" in opts.transport %}
"rest",
{% endif %}
{% endif %}{# is_async #}
])
{% if is_async %}
@pytest.mark.asyncio
async def test_transport_kind_async(transport_name):
transport = {{ service.async_client_name }}.get_transport_class(transport_name)(
credentials=async_anonymous_credentials(),
)
{% else %}
def test_transport_kind(transport_name):
transport = {{ service.client_name }}.get_transport_class(transport_name)(
credentials=ga_credentials.AnonymousCredentials(),
)
{% endif %}
assert transport.kind == transport_name
{% endmacro %}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -227,6 +228,7 @@ def __init__(self, *,
)

# Wrap messages. This must be done after self._grpc_channel exists
self._wrap_with_kind = "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters
self._prep_wrapped_messages(client_info)

@property
Expand Down Expand Up @@ -941,17 +943,17 @@ def analyze_org_policy_governed_assets(self) -> Callable[
def _prep_wrapped_messages(self, client_info):
""" Precompute the wrapped methods, overriding the base class method to use async wrappers."""
self._wrapped_methods = {
self.export_assets: gapic_v1.method_async.wrap_method(
self.export_assets: self._wrap_method(
self.export_assets,
default_timeout=60.0,
client_info=client_info,
),
self.list_assets: gapic_v1.method_async.wrap_method(
self.list_assets: self._wrap_method(
self.list_assets,
default_timeout=None,
client_info=client_info,
),
self.batch_get_assets_history: gapic_v1.method_async.wrap_method(
self.batch_get_assets_history: self._wrap_method(
self.batch_get_assets_history,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -966,12 +968,12 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=60.0,
client_info=client_info,
),
self.create_feed: gapic_v1.method_async.wrap_method(
self.create_feed: self._wrap_method(
self.create_feed,
default_timeout=60.0,
client_info=client_info,
),
self.get_feed: gapic_v1.method_async.wrap_method(
self.get_feed: self._wrap_method(
self.get_feed,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -986,7 +988,7 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=60.0,
client_info=client_info,
),
self.list_feeds: gapic_v1.method_async.wrap_method(
self.list_feeds: self._wrap_method(
self.list_feeds,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -1001,12 +1003,12 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=60.0,
client_info=client_info,
),
self.update_feed: gapic_v1.method_async.wrap_method(
self.update_feed: self._wrap_method(
self.update_feed,
default_timeout=60.0,
client_info=client_info,
),
self.delete_feed: gapic_v1.method_async.wrap_method(
self.delete_feed: self._wrap_method(
self.delete_feed,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -1021,7 +1023,7 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=60.0,
client_info=client_info,
),
self.search_all_resources: gapic_v1.method_async.wrap_method(
self.search_all_resources: self._wrap_method(
self.search_all_resources,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -1036,7 +1038,7 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=15.0,
client_info=client_info,
),
self.search_all_iam_policies: gapic_v1.method_async.wrap_method(
self.search_all_iam_policies: self._wrap_method(
self.search_all_iam_policies,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -1051,7 +1053,7 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=15.0,
client_info=client_info,
),
self.analyze_iam_policy: gapic_v1.method_async.wrap_method(
self.analyze_iam_policy: self._wrap_method(
self.analyze_iam_policy,
default_retry=retries.AsyncRetry(
initial=0.1,
Expand All @@ -1065,71 +1067,80 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=300.0,
client_info=client_info,
),
self.analyze_iam_policy_longrunning: gapic_v1.method_async.wrap_method(
self.analyze_iam_policy_longrunning: self._wrap_method(
self.analyze_iam_policy_longrunning,
default_timeout=60.0,
client_info=client_info,
),
self.analyze_move: gapic_v1.method_async.wrap_method(
self.analyze_move: self._wrap_method(
self.analyze_move,
default_timeout=None,
client_info=client_info,
),
self.query_assets: gapic_v1.method_async.wrap_method(
self.query_assets: self._wrap_method(
self.query_assets,
default_timeout=None,
client_info=client_info,
),
self.create_saved_query: gapic_v1.method_async.wrap_method(
self.create_saved_query: self._wrap_method(
self.create_saved_query,
default_timeout=None,
client_info=client_info,
),
self.get_saved_query: gapic_v1.method_async.wrap_method(
self.get_saved_query: self._wrap_method(
self.get_saved_query,
default_timeout=None,
client_info=client_info,
),
self.list_saved_queries: gapic_v1.method_async.wrap_method(
self.list_saved_queries: self._wrap_method(
self.list_saved_queries,
default_timeout=None,
client_info=client_info,
),
self.update_saved_query: gapic_v1.method_async.wrap_method(
self.update_saved_query: self._wrap_method(
self.update_saved_query,
default_timeout=None,
client_info=client_info,
),
self.delete_saved_query: gapic_v1.method_async.wrap_method(
self.delete_saved_query: self._wrap_method(
self.delete_saved_query,
default_timeout=None,
client_info=client_info,
),
self.batch_get_effective_iam_policies: gapic_v1.method_async.wrap_method(
self.batch_get_effective_iam_policies: self._wrap_method(
self.batch_get_effective_iam_policies,
default_timeout=None,
client_info=client_info,
),
self.analyze_org_policies: gapic_v1.method_async.wrap_method(
self.analyze_org_policies: self._wrap_method(
self.analyze_org_policies,
default_timeout=None,
client_info=client_info,
),
self.analyze_org_policy_governed_containers: gapic_v1.method_async.wrap_method(
self.analyze_org_policy_governed_containers: self._wrap_method(
self.analyze_org_policy_governed_containers,
default_timeout=None,
client_info=client_info,
),
self.analyze_org_policy_governed_assets: gapic_v1.method_async.wrap_method(
self.analyze_org_policy_governed_assets: self._wrap_method(
self.analyze_org_policy_governed_assets,
default_timeout=None,
client_info=client_info,
),
}
}

def _wrap_method(self, func, *args, **kwargs):
if self._wrap_with_kind: # pragma: NO COVER
kwargs["kind"] = self.kind
return gapic_v1.method_async.wrap_method(func, *args, **kwargs)

def close(self):
return self.grpc_channel.close()

@property
def kind(self) -> str:
return "grpc_asyncio"

@property
def get_operation(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16415,6 +16415,18 @@ def test_transport_kind(transport_name):
)
assert transport.kind == transport_name


@pytest.mark.parametrize("transport_name", [
"grpc_asyncio",
])
@pytest.mark.asyncio
async def test_transport_kind_async(transport_name):
transport = AssetServiceAsyncClient.get_transport_class(transport_name)(
credentials=async_anonymous_credentials(),
)
assert transport.kind == transport_name


def test_transport_grpc_default():
# A client should use the gRPC transport by default.
client = AssetServiceClient(
Expand Down
Loading