diff --git a/noxfile.py b/noxfile.py index b4f8b203de..3bedc1ddb2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -322,6 +322,7 @@ def showcase_library( f"google/showcase/v1beta1/echo.proto", f"google/showcase/v1beta1/identity.proto", f"google/showcase/v1beta1/messaging.proto", + f"google/showcase/v1beta1/sequence.proto", ) session.run( *cmd_tup, diff --git a/tests/system/conftest.py b/tests/system/conftest.py index a7967d4f5d..287852b9a7 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -33,6 +33,7 @@ import google.auth from google.auth import credentials as ga_credentials from google.showcase import EchoClient +from google.showcase import SequenceServiceClient from google.showcase import IdentityClient from google.showcase import MessagingClient @@ -41,6 +42,7 @@ import asyncio from google.showcase import EchoAsyncClient from google.showcase import IdentityAsyncClient + from google.showcase import SequenceServiceAsyncClient try: from google.showcase_v1beta1.services.echo.transports import ( @@ -59,6 +61,14 @@ HAS_ASYNC_REST_IDENTITY_TRANSPORT = True except: HAS_ASYNC_REST_IDENTITY_TRANSPORT = False + try: + from google.showcase_v1beta1.services.seuence.transports import ( + AsyncSequenceServiceRestTransport, + ) + + HAS_ASYNC_REST_SEQUENCE_TRANSPORT = True + except: + HAS_ASYNC_REST_SEQUENCE_TRANSPORT = False # TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. # See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. @@ -109,6 +119,21 @@ def async_identity(use_mtls, request, event_loop): credentials=async_anonymous_credentials(), ) + @pytest.fixture(params=["grpc_asyncio", "rest_asyncio"]) + def async_sequence(use_mtls, request, event_loop): + transport = request.param + if transport == "rest_asyncio" and not HAS_ASYNC_REST_SEQUENCE_TRANSPORT: + pytest.skip("Skipping test with async rest.") + return construct_client( + SequenceServiceAsyncClient, + use_mtls, + transport_name=transport, + channel_creator=( + aio.insecure_channel if request.param == "grpc_asyncio" else None + ), + credentials=async_anonymous_credentials(), + ) + dir = os.path.dirname(__file__) with open(os.path.join(dir, "../cert/mtls.crt"), "rb") as fh: @@ -247,6 +272,13 @@ def identity(use_mtls, request): return construct_client(IdentityClient, use_mtls, transport_name=request.param) +@pytest.fixture(params=["grpc", "rest"]) +def sequence(use_mtls, request): + return construct_client( + SequenceServiceClient, use_mtls, transport_name=request.param + ) + + @pytest.fixture(params=["grpc", "rest"]) def messaging(use_mtls, request): return construct_client(MessagingClient, use_mtls, transport_name=request.param) diff --git a/tests/system/test_retry_streaming.py b/tests/system/test_retry_streaming.py new file mode 100644 index 0000000000..c0327861e8 --- /dev/null +++ b/tests/system/test_retry_streaming.py @@ -0,0 +1,314 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest import mock +from google.rpc.status_pb2 import Status +from datetime import timedelta +from google.api_core import retry as retries +from google.api_core import exceptions as core_exceptions +from google.api_core.version import __version__ as api_core_version + +if [int(n) for n in api_core_version.split(".")] < [2, 16, 0]: + pytest.skip("streaming retries requires api_core v2.16.0+", allow_module_level=True) + + +def _code_from_exc(exc): + """ + return the grpc code from an exception + """ + return exc.grpc_status_code.value[0] + + +def test_streaming_retry_success(sequence): + """ + Test a stream with a sigle success response + """ + retry = retries.StreamingRetry(predicate=retries.if_exception_type()) + content = ["hello", "world"] + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + # single response with entire stream content + "responses": [{"status": Status(code=0), "response_index": len(content)}], + } + ) + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + results = [pb.content for pb in it] + assert results == content + # verify streaming report + report = sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 1 + assert report.attempts[0].status == Status(code=0) + + +def test_streaming_non_retryable_error(sequence): + """ + Test a retryable stream failing with non-retryable error + """ + retry = retries.StreamingRetry(predicate=retries.if_exception_type()) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="expected error", + ) + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": [{"status": error, "response_index": 0}], + } + ) + with pytest.raises(core_exceptions.ServiceUnavailable): + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + next(it) + # verify streaming report + report = sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 1 + assert report.attempts[0].status == error + + +def test_streaming_transient_retryable(sequence): + """ + Server returns a retryable error a number of times before success. + Retryable errors should not be presented to the end user. + """ + retry = retries.StreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=0, + maximum=0, + timeout=1, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + responses = [{"status": error, "response_index": 0} for _ in range(3)] + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + results = [pb.content for pb in it] + assert results == content + # verify streaming report + report = sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 4 + assert report.attempts[0].status == error + assert report.attempts[1].status == error + assert report.attempts[2].status == error + assert report.attempts[3].status == Status(code=0) + + +def test_streaming_transient_retryable_partial_data(sequence): + """ + Server stream yields some data before failing with a retryable error a number of times before success. + Wrapped stream should contain data from all attempts + + TODO: + Test is currently skipped for rest clients due to known issue: + https://github.com/googleapis/gapic-generator-python/issues/2375 + """ + from google.protobuf.duration_pb2 import Duration + + if isinstance(sequence.transport, type(sequence).get_transport_class("rest")): + pytest.skip("Skipping due to known streaming issue in rest client") + + retry = retries.StreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=0, + maximum=0, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + transient_error_list = [ + {"status": error, "response_index": 1, "delay": Duration(seconds=30)} + ] * 3 + + responses = transient_error_list + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + results = [pb.content for pb in it] + assert results == ["hello", "hello", "hello", "hello", "world"] + # verify streaming report + report = sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 4 + assert report.attempts[0].status == error + assert report.attempts[1].status == error + assert report.attempts[2].status == error + assert report.attempts[3].status == Status(code=0) + + +def test_streaming_retryable_eventual_timeout(sequence): + """ + Server returns a retryable error a number of times before reaching timeout. + Should raise a retry error. + """ + retry = retries.StreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=0, + maximum=0, + timeout=0.35, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + transient_error_list = [ + {"status": error, "response_index": 1, "delay": timedelta(seconds=0.15)} + ] * 10 + responses = transient_error_list + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + with pytest.raises(core_exceptions.RetryError) as exc_info: + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + [pb.content for pb in it] + cause = exc_info.value.__cause__ + assert isinstance(cause, core_exceptions.ServiceUnavailable) + # verify streaming report + report = sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 3 + assert report.attempts[0].status == error + assert report.attempts[1].status == error + assert report.attempts[2].status == error + + +def test_streaming_retry_on_error(sequence): + """ + on_error should be called for all retryable errors as they are encountered + """ + encountered_excs = [] + + def on_error(exc): + encountered_excs.append(exc) + + retry = retries.StreamingRetry( + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, core_exceptions.GatewayTimeout + ), + initial=0, + maximum=0, + on_error=on_error, + ) + content = ["hello", "world"] + errors = [ + core_exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.NotFound, + ] + responses = [{"status": Status(code=_code_from_exc(exc))} for exc in errors] + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + with pytest.raises(core_exceptions.NotFound): + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + [pb.content for pb in it] + # on_error should have been called on the first two errors, but not the terminal one + assert len(encountered_excs) == 2 + assert isinstance(encountered_excs[0], core_exceptions.ServiceUnavailable) + # rest raises superclass GatewayTimeout in place of DeadlineExceeded + assert isinstance( + encountered_excs[1], + (core_exceptions.DeadlineExceeded, core_exceptions.GatewayTimeout), + ) + + +@pytest.mark.parametrize( + "initial,multiplier,maximum,expected", + [ + (0.1, 1.0, 0.5, [0.1, 0.1, 0.1]), + (0, 2.0, 0.5, [0, 0]), + (0.1, 2.0, 0.5, [0.1, 0.2, 0.4, 0.5, 0.5]), + (1, 1.5, 5, [1, 1.5, 2.25, 3.375, 5, 5]), + ], +) +def test_streaming_retry_sleep_generator( + sequence, initial, multiplier, maximum, expected +): + """ + should be able to pass in sleep generator to control backoff + """ + retry = retries.StreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=initial, + maximum=maximum, + multiplier=multiplier, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + transient_error_list = [{"status": error}] * len(expected) + responses = transient_error_list + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + with mock.patch("random.uniform") as mock_uniform: + # make sleep generator deterministic + mock_uniform.side_effect = lambda a, b: b + with mock.patch("time.sleep") as mock_sleep: + it = sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + [pb.content for pb in it] + assert mock_sleep.call_count == len(expected) + # ensure that sleep times match expected + assert mock_sleep.call_args_list == [ + mock.call(sleep_time) for sleep_time in expected + ] diff --git a/tests/system/test_retry_streaming_async.py b/tests/system/test_retry_streaming_async.py new file mode 100644 index 0000000000..39e07d3377 --- /dev/null +++ b/tests/system/test_retry_streaming_async.py @@ -0,0 +1,308 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sys +from unittest import mock +from google.rpc.status_pb2 import Status +from datetime import timedelta +from google.api_core import retry as retries +from google.api_core import exceptions as core_exceptions +from google.api_core.version import __version__ as api_core_version + +from test_retry_streaming import _code_from_exc + +if [int(n) for n in api_core_version.split(".")] < [2, 16, 0]: + pytest.skip("streaming retries requires api_core v2.16.0+", allow_module_level=True) + + +@pytest.mark.asyncio +async def test_async_streaming_retry_success(async_sequence): + """ + Test a stream with a sigle success response + """ + retry = retries.AsyncStreamingRetry(predicate=retries.if_exception_type()) + content = ["hello", "world"] + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + # single response with entire stream content + "responses": [{"status": Status(code=0), "response_index": len(content)}], + } + ) + it = await async_sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + results = [pb.content async for pb in it] + assert results == content + # verify streaming report + report = await async_sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 1 + assert report.attempts[0].status == Status(code=0) + + +@pytest.mark.asyncio +async def test_async_streaming_non_retryable_error(async_sequence): + """ + Test a retryable stream failing with non-retryable error + """ + retry = retries.AsyncStreamingRetry(predicate=retries.if_exception_type()) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="expected error", + ) + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": [{"status": error, "response_index": 0}], + } + ) + with pytest.raises(core_exceptions.ServiceUnavailable): + it = await async_sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + await it.__anext__() + # verify streaming report + report = await async_sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 1 + assert report.attempts[0].status == error + + +@pytest.mark.asyncio +async def test_async_streaming_transient_retryable(async_sequence): + """ + Server returns a retryable error a number of times before success. + Retryable errors should not be presented to the end user. + """ + retry = retries.AsyncStreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=0, + maximum=0, + timeout=1, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + responses = [{"status": error, "response_index": 0} for _ in range(3)] + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + it = await async_sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + results = [pb.content async for pb in it] + assert results == content + # verify streaming report + report = await async_sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 4 + assert report.attempts[0].status == error + assert report.attempts[1].status == error + assert report.attempts[2].status == error + assert report.attempts[3].status == Status(code=0) + + +@pytest.mark.asyncio +async def test_async_streaming_transient_retryable_partial_data(async_sequence): + """ + Server stream yields some data before failing with a retryable error a number of times before success. + Wrapped stream should contain data from all attempts + """ + retry = retries.AsyncStreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=0, + maximum=0, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + transient_error_list = [{"status": error, "response_index": 1}] * 3 + responses = transient_error_list + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + it = await async_sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + results = [pb.content async for pb in it] + assert results == ["hello"] * len(transient_error_list) + ["hello", "world"] + # verify streaming report + report = await async_sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 4 + assert report.attempts[0].status == error + assert report.attempts[1].status == error + assert report.attempts[2].status == error + assert report.attempts[3].status == Status(code=0) + + +@pytest.mark.asyncio +async def test_async_streaming_retryable_eventual_timeout(async_sequence): + """ + Server returns a retryable error a number of times before reaching timeout. + Should raise a retry error. + """ + retry = retries.AsyncStreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=0, + maximum=0, + timeout=0.35, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + transient_error_list = [ + {"status": error, "response_index": 1, "delay": timedelta(seconds=0.15)} + ] * 10 + responses = transient_error_list + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + with pytest.raises(core_exceptions.RetryError) as exc_info: + it = await async_sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + [pb.content async for pb in it] + cause = exc_info.value.__cause__ + assert isinstance(cause, core_exceptions.ServiceUnavailable) + # verify streaming report + report = await async_sequence.get_streaming_sequence_report( + name=f"{seq.name}/streamingSequenceReport" + ) + assert len(report.attempts) == 3 + assert report.attempts[0].status == error + assert report.attempts[1].status == error + assert report.attempts[2].status == error + + +@pytest.mark.asyncio +async def test_async_streaming_retry_on_error(async_sequence): + """ + on_error should be called for all retryable errors as they are encountered + """ + encountered_excs = [] + + def on_error(exc): + encountered_excs.append(exc) + + retry = retries.AsyncStreamingRetry( + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, core_exceptions.GatewayTimeout + ), + initial=0, + maximum=0, + on_error=on_error, + ) + content = ["hello", "world"] + errors = [ + core_exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.NotFound, + ] + responses = [{"status": Status(code=_code_from_exc(exc))} for exc in errors] + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + with pytest.raises(core_exceptions.NotFound): + it = await async_sequence.attempt_streaming_sequence(name=seq.name, retry=retry) + [pb.content async for pb in it] + # on_error should have been called on the first two errors, but not the terminal one + assert len(encountered_excs) == 2 + assert isinstance(encountered_excs[0], core_exceptions.ServiceUnavailable) + # rest raises superclass GatewayTimeout in place of DeadlineExceeded + assert isinstance( + encountered_excs[1], + (core_exceptions.DeadlineExceeded, core_exceptions.GatewayTimeout), + ) + + +@pytest.mark.parametrize( + "initial,multiplier,maximum,expected", + [ + (0.1, 1.0, 0.5, [0.1, 0.1, 0.1]), + (0, 2.0, 0.5, [0, 0]), + (0.1, 2.0, 0.5, [0.1, 0.2, 0.4, 0.5, 0.5]), + (1, 1.5, 5, [1, 1.5, 2.25, 3.375, 5, 5]), + ], +) +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires 3.8") +async def test_async_streaming_retry_sleep_generator( + async_sequence, initial, multiplier, maximum, expected +): + """ + should be able to pass in sleep generator to control backoff + """ + retry = retries.AsyncStreamingRetry( + predicate=retries.if_exception_type(core_exceptions.ServiceUnavailable), + initial=initial, + maximum=maximum, + multiplier=multiplier, + ) + content = ["hello", "world"] + error = Status( + code=_code_from_exc(core_exceptions.ServiceUnavailable), + message="transient error", + ) + transient_error_list = [{"status": error}] * len(expected) + responses = transient_error_list + [ + {"status": Status(code=0), "response_index": len(content)} + ] + seq = await async_sequence.create_streaming_sequence( + streaming_sequence={ + "name": __name__, + "content": " ".join(content), + "responses": responses, + } + ) + with mock.patch("random.uniform") as mock_uniform: + # make sleep generator deterministic + mock_uniform.side_effect = lambda a, b: b + with mock.patch("asyncio.sleep", mock.AsyncMock()) as mock_sleep: + it = await async_sequence.attempt_streaming_sequence( + name=seq.name, retry=retry + ) + [pb.content async for pb in it] + assert mock_sleep.call_count == len(expected) + # ensure that sleep times match expected + assert mock_sleep.call_args_list == [ + mock.call(sleep_time) for sleep_time in expected + ]