Skip to content

Commit 812cf3e

Browse files
authored
feat: adds REST server-streaming support. (googleapis#1120)
1 parent 8078961 commit 812cf3e

File tree

8 files changed

+170
-61
lines changed

8 files changed

+170
-61
lines changed

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ from google.auth import credentials as ga_credentials # type: ignore
1111
from google.api_core import exceptions as core_exceptions
1212
from google.api_core import retry as retries
1313
from google.api_core import rest_helpers
14+
from google.api_core import rest_streaming
1415
from google.api_core import path_template
1516
from google.api_core import gapic_v1
17+
1618
{% if service.has_lro %}
1719
from google.api_core import operations_v1
1820
from google.protobuf import json_format
@@ -66,7 +68,7 @@ class {{ service.name }}RestInterceptor:
6668

6769
.. code-block:
6870
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
69-
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
71+
{% for _, method in service.methods|dictsort if not method.client_streaming %}
7072
def pre_{{ method.name|snake_case }}(request, metadata):
7173
logging.log(f"Received request: {request}")
7274
return request, metadata
@@ -82,7 +84,7 @@ class {{ service.name }}RestInterceptor:
8284

8385

8486
"""
85-
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
87+
{% for method in service.methods.values()|sort(attribute="name") if not method.client_streaming %}
8688
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
8789
"""Pre-rpc interceptor for {{ method.name|snake_case }}
8890

@@ -92,7 +94,11 @@ class {{ service.name }}RestInterceptor:
9294
return request, metadata
9395

9496
{% if not method.void %}
97+
{% if not method.server_streaming %}
9598
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
99+
{% else %}
100+
def post_{{ method.name|snake_case }}(self, response: rest_streaming.ResponseIterator) -> rest_streaming.ResponseIterator:
101+
{% endif %}
96102
"""Post-rpc interceptor for {{ method.name|snake_case }}
97103

98104
Override in a subclass to manipulate the response
@@ -248,8 +254,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
248254
def __hash__(self):
249255
return hash("{{method.name}}")
250256

251-
252-
{% if not (method.server_streaming or method.client_streaming) %}
257+
{% if not method.client_streaming %}
253258
{% if method.input.required_fields %}
254259
__REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {
255260
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
@@ -262,15 +267,15 @@ class {{service.name}}RestTransport({{service.name}}Transport):
262267
def _get_unset_required_fields(cls, message_dict):
263268
return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
264269
{% endif %}{# required fields #}
265-
{% endif %}{# not (method.server_streaming or method.client_streaming) #}
270+
{% endif %}{# not method.client_streaming #}
266271

267272
def __call__(self,
268273
request: {{method.input.ident}}, *,
269274
retry: OptionalRetry=gapic_v1.method.DEFAULT,
270275
timeout: float=None,
271276
metadata: Sequence[Tuple[str, str]]=(),
272-
){% if not method.void %} -> {{method.output.ident}}{% endif %}:
273-
{% if method.http_options and not (method.server_streaming or method.client_streaming) %}
277+
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming.ResponseIterator{% endif %}{% endif %}:
278+
{% if method.http_options and not method.client_streaming %}
274279
r"""Call the {{- ' ' -}}
275280
{{ (method.name|snake_case).replace('_',' ')|wrap(
276281
width=70, offset=45, indent=8) }}
@@ -360,6 +365,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
360365
{% if method.lro %}
361366
resp = operations_pb2.Operation()
362367
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
368+
{% elif method.server_streaming %}
369+
resp = rest_streaming.ResponseIterator(response, {{method.output.ident}})
363370
{% else %}
364371
resp = {{method.output.ident}}.from_json(
365372
response.content,
@@ -370,14 +377,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
370377
return resp
371378

372379
{% endif %}{# method.void #}
373-
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
380+
{% else %}{# method.http_options and not method.client_streaming #}
374381
{% if not method.http_options %}
375382
raise RuntimeError(
376383
"Cannot define a method without a valid 'google.api.http' annotation.")
377384

378-
{% elif method.server_streaming or method.client_streaming %}
385+
{% elif method.client_streaming %}
379386
raise NotImplementedError(
380-
"Streaming over REST is not yet defined for python client")
387+
"Client streaming over REST is not yet defined for python client")
381388

382389
{% else %}
383390
raise NotImplementedError()

gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import mock
88
import grpc
99
from grpc.experimental import aio
1010
{% if "rest" in opts.transport %}
11+
from collections.abc import Iterable
1112
import json
1213
{% endif %}
1314
import math
@@ -861,8 +862,8 @@ def test_{{ method_name }}_raw_page_lro():
861862
{% endfor %} {# method in methods for grpc #}
862863

863864
{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
864-
{# TODO(kbandes): remove this if condition when streaming are supported. #}
865-
{% if not (method.server_streaming or method.client_streaming) %}
865+
{# TODO(kbandes): remove this if condition when client streaming are supported. #}
866+
{% if not method.client_streaming %}
866867
@pytest.mark.parametrize("request_type", [
867868
{{ method.input.ident }},
868869
dict,
@@ -884,8 +885,6 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
884885
return_value = None
885886
{% elif method.lro %}
886887
return_value = operations_pb2.Operation(name='operations/spam')
887-
{% elif method.server_streaming %}
888-
return_value = iter([{{ method.output.ident }}()])
889888
{% else %}
890889
return_value = {{ method.output.ident }}(
891890
{% for field in method.output.fields.values() | rejectattr('message')%}
@@ -905,6 +904,8 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
905904
req.return_value.request = PreparedRequest()
906905
{% if method.void %}
907906
json_return_value = ''
907+
{% elif method.server_streaming %}
908+
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
908909
{% else %}
909910
json_return_value = {{ method.output.ident }}.to_json(return_value)
910911
{% endif %}
@@ -914,6 +915,10 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
914915
# the request over the wire, so an empty request is fine.
915916
{% if method.client_streaming %}
916917
client.{{ method_name }}(iter([requests]))
918+
{% elif method.server_streaming %}
919+
with mock.patch.object(response_value, 'iter_content') as iter_content:
920+
iter_content.return_value = iter(json_return_value)
921+
response = client.{{ method_name }}(request)
917922
{% else %}
918923
client.{{ method_name }}(request)
919924
{% endif %}
@@ -950,8 +955,6 @@ def test_{{ method.name|snake_case }}_rest(request_type):
950955
return_value = None
951956
{% elif method.lro %}
952957
return_value = operations_pb2.Operation(name='operations/spam')
953-
{% elif method.server_streaming %}
954-
return_value = iter([{{ method.output.ident }}()])
955958
{% else %}
956959
return_value = {{ method.output.ident }}(
957960
{% for field in method.output.fields.values() | rejectattr('message')%}
@@ -974,13 +977,19 @@ def test_{{ method.name|snake_case }}_rest(request_type):
974977
json_return_value = ''
975978
{% elif method.lro %}
976979
json_return_value = json_format.MessageToJson(return_value)
980+
{% elif method.server_streaming %}
981+
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
977982
{% else %}
978983
json_return_value = {{ method.output.ident }}.to_json(return_value)
979984
{% endif %}
980985
response_value._content = json_return_value.encode('UTF-8')
981986
req.return_value = response_value
982987
{% if method.client_streaming %}
983988
response = client.{{ method.name|snake_case }}(iter(requests))
989+
{% elif method.server_streaming %}
990+
with mock.patch.object(response_value, 'iter_content') as iter_content:
991+
iter_content.return_value = iter(json_return_value)
992+
response = client.{{ method_name }}(request)
984993
{% else %}
985994
response = client.{{ method_name }}(request)
986995
{% endif %}
@@ -991,6 +1000,11 @@ def test_{{ method.name|snake_case }}_rest(request_type):
9911000

9921001
{% endif %}
9931002

1003+
{% if method.server_streaming %}
1004+
assert isinstance(response, Iterable)
1005+
response = next(response)
1006+
{% endif %}
1007+
9941008
# Establish that the response is the type that we expect.
9951009
{% if method.void %}
9961010
assert response is None
@@ -1085,8 +1099,6 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
10851099
return_value = None
10861100
{% elif method.lro %}
10871101
return_value = operations_pb2.Operation(name='operations/spam')
1088-
{% elif method.server_streaming %}
1089-
return_value = iter([{{ method.output.ident }}()])
10901102
{% else %}
10911103
return_value = {{ method.output.ident }}()
10921104
{% endif %}
@@ -1114,6 +1126,8 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
11141126
json_return_value = ''
11151127
{% elif method.lro %}
11161128
json_return_value = json_format.MessageToJson(return_value)
1129+
{% elif method.server_streaming %}
1130+
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
11171131
{% else %}
11181132
json_return_value = {{ method.output.ident }}.to_json(return_value)
11191133
{% endif %}
@@ -1122,6 +1136,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
11221136

11231137
{% if method.client_streaming %}
11241138
response = client.{{ method.name|snake_case }}(iter(requests))
1139+
{% elif method.server_streaming %}
1140+
with mock.patch.object(response_value, 'iter_content') as iter_content:
1141+
iter_content.return_value = iter(json_return_value)
1142+
response = client.{{ method_name }}(request)
11251143
{% else %}
11261144
response = client.{{ method_name }}(request)
11271145
{% endif %}
@@ -1248,8 +1266,6 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12481266
return_value = None
12491267
{% elif method.lro %}
12501268
return_value = operations_pb2.Operation(name='operations/spam')
1251-
{% elif method.server_streaming %}
1252-
return_value = iter([{{ method.output.ident }}()])
12531269
{% else %}
12541270
return_value = {{ method.output.ident }}()
12551271
{% endif %}
@@ -1261,6 +1277,8 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12611277
json_return_value = ''
12621278
{% elif method.lro %}
12631279
json_return_value = json_format.MessageToJson(return_value)
1280+
{% elif method.server_streaming %}
1281+
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
12641282
{% else %}
12651283
json_return_value = {{ method.output.ident }}.to_json(return_value)
12661284
{% endif %}
@@ -1281,7 +1299,14 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12811299
{% endfor %}
12821300
)
12831301
mock_args.update(sample_request)
1302+
1303+
{% if method.server_streaming %}
1304+
with mock.patch.object(response_value, 'iter_content') as iter_content:
1305+
iter_content.return_value = iter(json_return_value)
1306+
client.{{ method_name }}(**mock_args)
1307+
{% else %}
12841308
client.{{ method_name }}(**mock_args)
1309+
{% endif %}
12851310

12861311
# Establish that the underlying call was made with the expected
12871312
# request object values.
@@ -1385,6 +1410,9 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
13851410
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
13861411
return_values = tuple(Response() for i in response)
13871412
for return_val, response_val in zip(return_values, response):
1413+
{% if method.server_streaming %}
1414+
response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
1415+
{% endif %}
13881416
return_val._content = response_val.encode('UTF-8')
13891417
return_val.status_code = 200
13901418
req.side_effect = return_values

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ from google.auth import credentials as ga_credentials # type: ignore
1111
from google.api_core import exceptions as core_exceptions
1212
from google.api_core import retry as retries
1313
from google.api_core import rest_helpers
14+
from google.api_core import rest_streaming
1415
from google.api_core import path_template
1516
from google.api_core import gapic_v1
17+
1618
{% if service.has_lro %}
1719
from google.api_core import operations_v1
1820
from google.protobuf import json_format
@@ -66,7 +68,7 @@ class {{ service.name }}RestInterceptor:
6668

6769
.. code-block:
6870
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
69-
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
71+
{% for _, method in service.methods|dictsort if not method.client_streaming %}
7072
def pre_{{ method.name|snake_case }}(request, metadata):
7173
logging.log(f"Received request: {request}")
7274
return request, metadata
@@ -82,7 +84,7 @@ class {{ service.name }}RestInterceptor:
8284

8385

8486
"""
85-
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
87+
{% for method in service.methods.values()|sort(attribute="name") if not method.client_streaming %}
8688
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
8789
"""Pre-rpc interceptor for {{ method.name|snake_case }}
8890

@@ -92,7 +94,11 @@ class {{ service.name }}RestInterceptor:
9294
return request, metadata
9395

9496
{% if not method.void %}
97+
{% if not method.server_streaming %}
9598
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
99+
{% else %}
100+
def post_{{ method.name|snake_case }}(self, response: rest_streaming.ResponseIterator) -> rest_streaming.ResponseIterator:
101+
{% endif %}
96102
"""Post-rpc interceptor for {{ method.name|snake_case }}
97103

98104
Override in a subclass to manipulate the response
@@ -248,8 +254,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
248254
def __hash__(self):
249255
return hash("{{method.name}}")
250256

251-
252-
{% if not (method.server_streaming or method.client_streaming) %}
257+
{% if not method.client_streaming %}
253258
{% if method.input.required_fields %}
254259
__REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {
255260
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
@@ -262,15 +267,15 @@ class {{service.name}}RestTransport({{service.name}}Transport):
262267
def _get_unset_required_fields(cls, message_dict):
263268
return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
264269
{% endif %}{# required fields #}
265-
{% endif %}{# not (method.server_streaming or method.client_streaming) #}
270+
{% endif %}{# not method.client_streaming #}
266271

267272
def __call__(self,
268273
request: {{method.input.ident}}, *,
269274
retry: OptionalRetry=gapic_v1.method.DEFAULT,
270275
timeout: float=None,
271276
metadata: Sequence[Tuple[str, str]]=(),
272-
){% if not method.void %} -> {{method.output.ident}}{% endif %}:
273-
{% if method.http_options and not (method.server_streaming or method.client_streaming) %}
277+
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming.ResponseIterator{% endif %}{% endif %}:
278+
{% if method.http_options and not method.client_streaming %}
274279
r"""Call the {{- ' ' -}}
275280
{{ (method.name|snake_case).replace('_',' ')|wrap(
276281
width=70, offset=45, indent=8) }}
@@ -360,6 +365,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
360365
{% if method.lro %}
361366
resp = operations_pb2.Operation()
362367
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
368+
{% elif method.server_streaming %}
369+
resp = rest_streaming.ResponseIterator(response, {{method.output.ident}})
363370
{% else %}
364371
resp = {{method.output.ident}}.from_json(
365372
response.content,
@@ -370,14 +377,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
370377
return resp
371378

372379
{% endif %}{# method.void #}
373-
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
380+
{% else %}{# method.http_options and not method.client_streaming #}
374381
{% if not method.http_options %}
375382
raise RuntimeError(
376383
"Cannot define a method without a valid 'google.api.http' annotation.")
377384

378-
{% elif method.server_streaming or method.client_streaming %}
385+
{% elif method.client_streaming %}
379386
raise NotImplementedError(
380-
"Streaming over REST is not yet defined for python client")
387+
"Client streaming over REST is not yet defined for python client")
381388

382389
{% else %}
383390
raise NotImplementedError()

gapic/templates/setup.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ setuptools.setup(
2929
install_requires=(
3030
{# TODO(dovs): remove when 1.x deprecation is complete #}
3131
{% if 'rest' in opts.transport %}
32-
'google-api-core[grpc] >= 2.3.0, < 3.0.0dev',
32+
'google-api-core[grpc] >= 2.4.0, < 3.0.0dev',
3333
{% else %}
3434
'google-api-core[grpc] >= 1.28.0, < 3.0.0dev',
3535
{% endif %}

0 commit comments

Comments
 (0)