@@ -8,6 +8,7 @@ import mock
8
8
import grpc
9
9
from grpc.experimental import aio
10
10
{% if "rest" in opts .transport %}
11
+ from collections.abc import Iterable
11
12
import json
12
13
{% endif %}
13
14
import math
@@ -861,8 +862,8 @@ def test_{{ method_name }}_raw_page_lro():
861
862
{% endfor %} {# method in methods for grpc #}
862
863
863
864
{% for method in service .methods .values () if 'rest' in opts .transport %}{% with method_name = method .name |snake_case + "_unary" if method .operation_service else method .name |snake_case %}{% if method .http_options %}
864
- {# TODO(kbandes): remove this if condition when streaming are supported. #}
865
- {% if not ( method .server_streaming or method . client_streaming ) %}
865
+ {# TODO(kbandes): remove this if condition when client streaming are supported. #}
866
+ {% if not method .client_streaming %}
866
867
@pytest.mark.parametrize("request_type", [
867
868
{{ method.input.ident }},
868
869
dict,
@@ -884,8 +885,6 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
884
885
return_value = None
885
886
{% elif method .lro %}
886
887
return_value = operations_pb2.Operation(name='operations/spam')
887
- {% elif method .server_streaming %}
888
- return_value = iter([{{ method.output.ident }}()])
889
888
{% else %}
890
889
return_value = {{ method.output.ident }}(
891
890
{% for field in method .output .fields .values () | rejectattr ('message' )%}
@@ -905,6 +904,8 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
905
904
req.return_value.request = PreparedRequest()
906
905
{% if method .void %}
907
906
json_return_value = ''
907
+ {% elif method .server_streaming %}
908
+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
908
909
{% else %}
909
910
json_return_value = {{ method.output.ident }}.to_json(return_value)
910
911
{% endif %}
@@ -914,6 +915,10 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
914
915
# the request over the wire, so an empty request is fine.
915
916
{% if method .client_streaming %}
916
917
client.{{ method_name }}(iter([requests]))
918
+ {% elif method .server_streaming %}
919
+ with mock.patch.object(response_value, 'iter_content') as iter_content:
920
+ iter_content.return_value = iter(json_return_value)
921
+ response = client.{{ method_name }}(request)
917
922
{% else %}
918
923
client.{{ method_name }}(request)
919
924
{% endif %}
@@ -950,8 +955,6 @@ def test_{{ method.name|snake_case }}_rest(request_type):
950
955
return_value = None
951
956
{% elif method .lro %}
952
957
return_value = operations_pb2.Operation(name='operations/spam')
953
- {% elif method .server_streaming %}
954
- return_value = iter([{{ method.output.ident }}()])
955
958
{% else %}
956
959
return_value = {{ method.output.ident }}(
957
960
{% for field in method .output .fields .values () | rejectattr ('message' )%}
@@ -974,13 +977,19 @@ def test_{{ method.name|snake_case }}_rest(request_type):
974
977
json_return_value = ''
975
978
{% elif method .lro %}
976
979
json_return_value = json_format.MessageToJson(return_value)
980
+ {% elif method .server_streaming %}
981
+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
977
982
{% else %}
978
983
json_return_value = {{ method.output.ident }}.to_json(return_value)
979
984
{% endif %}
980
985
response_value._content = json_return_value.encode('UTF-8')
981
986
req.return_value = response_value
982
987
{% if method .client_streaming %}
983
988
response = client.{{ method.name|snake_case }}(iter(requests))
989
+ {% elif method .server_streaming %}
990
+ with mock.patch.object(response_value, 'iter_content') as iter_content:
991
+ iter_content.return_value = iter(json_return_value)
992
+ response = client.{{ method_name }}(request)
984
993
{% else %}
985
994
response = client.{{ method_name }}(request)
986
995
{% endif %}
@@ -991,6 +1000,11 @@ def test_{{ method.name|snake_case }}_rest(request_type):
991
1000
992
1001
{% endif %}
993
1002
1003
+ {% if method .server_streaming %}
1004
+ assert isinstance(response, Iterable)
1005
+ response = next(response)
1006
+ {% endif %}
1007
+
994
1008
# Establish that the response is the type that we expect.
995
1009
{% if method .void %}
996
1010
assert response is None
@@ -1085,8 +1099,6 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
1085
1099
return_value = None
1086
1100
{% elif method .lro %}
1087
1101
return_value = operations_pb2.Operation(name='operations/spam')
1088
- {% elif method .server_streaming %}
1089
- return_value = iter([{{ method.output.ident }}()])
1090
1102
{% else %}
1091
1103
return_value = {{ method.output.ident }}()
1092
1104
{% endif %}
@@ -1114,6 +1126,8 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
1114
1126
json_return_value = ''
1115
1127
{% elif method .lro %}
1116
1128
json_return_value = json_format.MessageToJson(return_value)
1129
+ {% elif method .server_streaming %}
1130
+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
1117
1131
{% else %}
1118
1132
json_return_value = {{ method.output.ident }}.to_json(return_value)
1119
1133
{% endif %}
@@ -1122,6 +1136,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
1122
1136
1123
1137
{% if method .client_streaming %}
1124
1138
response = client.{{ method.name|snake_case }}(iter(requests))
1139
+ {% elif method .server_streaming %}
1140
+ with mock.patch.object(response_value, 'iter_content') as iter_content:
1141
+ iter_content.return_value = iter(json_return_value)
1142
+ response = client.{{ method_name }}(request)
1125
1143
{% else %}
1126
1144
response = client.{{ method_name }}(request)
1127
1145
{% endif %}
@@ -1248,8 +1266,6 @@ def test_{{ method.name|snake_case }}_rest_flattened():
1248
1266
return_value = None
1249
1267
{% elif method .lro %}
1250
1268
return_value = operations_pb2.Operation(name='operations/spam')
1251
- {% elif method .server_streaming %}
1252
- return_value = iter([{{ method.output.ident }}()])
1253
1269
{% else %}
1254
1270
return_value = {{ method.output.ident }}()
1255
1271
{% endif %}
@@ -1261,6 +1277,8 @@ def test_{{ method.name|snake_case }}_rest_flattened():
1261
1277
json_return_value = ''
1262
1278
{% elif method .lro %}
1263
1279
json_return_value = json_format.MessageToJson(return_value)
1280
+ {% elif method .server_streaming %}
1281
+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
1264
1282
{% else %}
1265
1283
json_return_value = {{ method.output.ident }}.to_json(return_value)
1266
1284
{% endif %}
@@ -1281,7 +1299,14 @@ def test_{{ method.name|snake_case }}_rest_flattened():
1281
1299
{% endfor %}
1282
1300
)
1283
1301
mock_args.update(sample_request)
1302
+
1303
+ {% if method .server_streaming %}
1304
+ with mock.patch.object(response_value, 'iter_content') as iter_content:
1305
+ iter_content.return_value = iter(json_return_value)
1306
+ client.{{ method_name }}(**mock_args)
1307
+ {% else %}
1284
1308
client.{{ method_name }}(**mock_args)
1309
+ {% endif %}
1285
1310
1286
1311
# Establish that the underlying call was made with the expected
1287
1312
# request object values.
@@ -1385,6 +1410,9 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
1385
1410
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
1386
1411
return_values = tuple(Response() for i in response)
1387
1412
for return_val, response_val in zip(return_values, response):
1413
+ {% if method .server_streaming %}
1414
+ response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
1415
+ {% endif %}
1388
1416
return_val._content = response_val.encode('UTF-8')
1389
1417
return_val.status_code = 200
1390
1418
req.side_effect = return_values
0 commit comments