12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import collections
15
+
16
16
import grpc
17
17
from unittest import mock
18
18
import os
19
19
import pytest
20
20
21
+ from typing import Sequence , Tuple
22
+
21
23
from google .api_core .client_options import ClientOptions # type: ignore
24
+ from google .showcase_v1beta1 .services .echo .transports import EchoRestInterceptor
22
25
23
26
try :
24
27
from google .auth .aio import credentials as ga_credentials_async
42
45
try :
43
46
from google .showcase_v1beta1 .services .echo .transports import (
44
47
AsyncEchoRestTransport ,
48
+ AsyncEchoRestInterceptor ,
45
49
)
46
50
47
51
HAS_ASYNC_REST_ECHO_TRANSPORT = True
@@ -248,7 +252,51 @@ def messaging(use_mtls, request):
248
252
return construct_client (MessagingClient , use_mtls , transport_name = request .param )
249
253
250
254
251
- class MetadataClientInterceptor (
255
+ class EchoMetadataClientRestInterceptor (EchoRestInterceptor ):
256
+ request_metadata : Sequence [Tuple [str , str ]] = []
257
+ response_metadata : Sequence [Tuple [str , str ]] = []
258
+
259
+ def pre_echo (self , request , metadata ):
260
+ self .request_metadata = metadata
261
+ return request , metadata
262
+
263
+ def post_echo_with_metadata (self , request , metadata ):
264
+ self .response_metadata = metadata
265
+ return request , metadata
266
+
267
+ def pre_expand (self , request , metadata ):
268
+ self .request_metadata = metadata
269
+ return request , metadata
270
+
271
+ def post_expand_with_metadata (self , request , metadata ):
272
+ self .response_metadata = metadata
273
+ return request , metadata
274
+
275
+
276
+ if HAS_ASYNC_REST_ECHO_TRANSPORT :
277
+
278
+ class EchoMetadataClientRestAsyncInterceptor (AsyncEchoRestInterceptor ):
279
+ request_metadata : Sequence [Tuple [str , str ]] = []
280
+ response_metadata : Sequence [Tuple [str , str ]] = []
281
+
282
+ async def pre_echo (self , request , metadata ):
283
+ self .request_metadata = metadata
284
+ return request , metadata
285
+
286
+ async def post_echo_with_metadata (self , request , metadata ):
287
+ self .response_metadata = metadata
288
+ return request , metadata
289
+
290
+ async def pre_expand (self , request , metadata ):
291
+ self .request_metadata = metadata
292
+ return request , metadata
293
+
294
+ async def post_expand_with_metadata (self , request , metadata ):
295
+ self .response_metadata = metadata
296
+ return request , metadata
297
+
298
+
299
+ class EchoMetadataClientGrpcInterceptor (
252
300
grpc .UnaryUnaryClientInterceptor ,
253
301
grpc .UnaryStreamClientInterceptor ,
254
302
grpc .StreamUnaryClientInterceptor ,
@@ -257,42 +305,94 @@ class MetadataClientInterceptor(
257
305
def __init__ (self , key , value ):
258
306
self ._key = key
259
307
self ._value = value
308
+ self .request_metadata = []
309
+ self .response_metadata = []
260
310
261
- def _add_metadata (self , client_call_details ):
311
+ def _add_request_metadata (self , client_call_details ):
262
312
if client_call_details .metadata is not None :
263
313
client_call_details .metadata .append ((self ._key , self ._value ))
314
+ self .request_metadata = client_call_details .metadata
264
315
265
316
def intercept_unary_unary (self , continuation , client_call_details , request ):
266
- self ._add_metadata (client_call_details )
317
+ self ._add_request_metadata (client_call_details )
267
318
response = continuation (client_call_details , request )
319
+ metadata = [(k , str (v )) for k , v in response .trailing_metadata ()]
320
+ self .response_metadata = metadata
268
321
return response
269
322
270
323
def intercept_unary_stream (self , continuation , client_call_details , request ):
271
- self ._add_metadata (client_call_details )
324
+ self ._add_request_metadata (client_call_details )
272
325
response_it = continuation (client_call_details , request )
273
326
return response_it
274
327
275
328
def intercept_stream_unary (
276
329
self , continuation , client_call_details , request_iterator
277
330
):
278
- self ._add_metadata (client_call_details )
331
+ self ._add_request_metadata (client_call_details )
279
332
response = continuation (client_call_details , request_iterator )
280
333
return response
281
334
282
335
def intercept_stream_stream (
283
336
self , continuation , client_call_details , request_iterator
284
337
):
285
- self ._add_metadata (client_call_details )
338
+ self ._add_request_metadata (client_call_details )
339
+ response_it = continuation (client_call_details , request_iterator )
340
+ return response_it
341
+
342
+
343
+ class EchoMetadataClientGrpcAsyncInterceptor (
344
+ grpc .aio .UnaryUnaryClientInterceptor ,
345
+ grpc .aio .UnaryStreamClientInterceptor ,
346
+ grpc .aio .StreamUnaryClientInterceptor ,
347
+ grpc .aio .StreamStreamClientInterceptor ,
348
+ ):
349
+ def __init__ (self , key , value ):
350
+ self ._key = key
351
+ self ._value = value
352
+ self .request_metadata = []
353
+ self .response_metadata = []
354
+
355
+ async def _add_request_metadata (self , client_call_details ):
356
+ if client_call_details .metadata is not None :
357
+ client_call_details .metadata .append ((self ._key , self ._value ))
358
+ self .request_metadata = client_call_details .metadata
359
+
360
+ async def intercept_unary_unary (self , continuation , client_call_details , request ):
361
+ await self ._add_request_metadata (client_call_details )
362
+ response = await continuation (client_call_details , request )
363
+ metadata = [(k , str (v )) for k , v in await response .trailing_metadata ()]
364
+ self .response_metadata = metadata
365
+ return response
366
+
367
+ async def intercept_unary_stream (self , continuation , client_call_details , request ):
368
+ self ._add_request_metadata (client_call_details )
369
+ response_it = continuation (client_call_details , request )
370
+ return response_it
371
+
372
+ async def intercept_stream_unary (
373
+ self , continuation , client_call_details , request_iterator
374
+ ):
375
+ self ._add_request_metadata (client_call_details )
376
+ response = continuation (client_call_details , request_iterator )
377
+ return response
378
+
379
+ async def intercept_stream_stream (
380
+ self , continuation , client_call_details , request_iterator
381
+ ):
382
+ self ._add_request_metadata (client_call_details )
286
383
response_it = continuation (client_call_details , request_iterator )
287
384
return response_it
288
385
289
386
290
387
@pytest .fixture
291
- def intercepted_echo (use_mtls ):
388
+ def intercepted_echo_grpc (use_mtls ):
292
389
# The interceptor adds 'showcase-trailer' client metadata. Showcase server
293
- # echos any metadata with key 'showcase-trailer', so the same metadata
390
+ # echoes any metadata with key 'showcase-trailer', so the same metadata
294
391
# should appear as trailing metadata in the response.
295
- interceptor = MetadataClientInterceptor ("showcase-trailer" , "intercepted" )
392
+ interceptor = EchoMetadataClientGrpcInterceptor (
393
+ "showcase-trailer" ,
394
+ "intercepted" ,
395
+ )
296
396
host = "localhost:7469"
297
397
channel = (
298
398
grpc .secure_channel (host , ssl_credentials )
@@ -304,4 +404,58 @@ def intercepted_echo(use_mtls):
304
404
credentials = ga_credentials .AnonymousCredentials (),
305
405
channel = intercept_channel ,
306
406
)
307
- return EchoClient (transport = transport )
407
+ return EchoClient (transport = transport ), interceptor
408
+
409
+
410
+ @pytest .fixture
411
+ def intercepted_echo_grpc_async ():
412
+ # The interceptor adds 'showcase-trailer' client metadata. Showcase server
413
+ # echoes any metadata with key 'showcase-trailer', so the same metadata
414
+ # should appear as trailing metadata in the response.
415
+ interceptor = EchoMetadataClientGrpcAsyncInterceptor (
416
+ "showcase-trailer" ,
417
+ "intercepted" ,
418
+ )
419
+ host = "localhost:7469"
420
+ channel = grpc .aio .insecure_channel (host , interceptors = [interceptor ])
421
+ # intercept_channel = grpc.aio.intercept_channel(channel, interceptor)
422
+ transport = EchoAsyncClient .get_transport_class ("grpc_asyncio" )(
423
+ credentials = ga_credentials .AnonymousCredentials (),
424
+ channel = channel ,
425
+ )
426
+ return EchoAsyncClient (transport = transport ), interceptor
427
+
428
+
429
+ @pytest .fixture
430
+ def intercepted_echo_rest ():
431
+ transport_name = "rest"
432
+ transport_cls = EchoClient .get_transport_class (transport_name )
433
+ interceptor = EchoMetadataClientRestInterceptor ()
434
+
435
+ # The custom host explicitly bypasses https.
436
+ transport = transport_cls (
437
+ credentials = ga_credentials .AnonymousCredentials (),
438
+ host = "localhost:7469" ,
439
+ url_scheme = "http" ,
440
+ interceptor = interceptor ,
441
+ )
442
+ return EchoClient (transport = transport ), interceptor
443
+
444
+
445
+ @pytest .fixture
446
+ def intercepted_echo_rest_async ():
447
+ if not HAS_ASYNC_REST_ECHO_TRANSPORT :
448
+ pytest .skip ("Skipping test with async rest." )
449
+
450
+ transport_name = "rest_asyncio"
451
+ transport_cls = EchoAsyncClient .get_transport_class (transport_name )
452
+ interceptor = EchoMetadataClientRestAsyncInterceptor ()
453
+
454
+ # The custom host explicitly bypasses https.
455
+ transport = transport_cls (
456
+ credentials = async_anonymous_credentials (),
457
+ host = "localhost:7469" ,
458
+ url_scheme = "http" ,
459
+ interceptor = interceptor ,
460
+ )
461
+ return EchoAsyncClient (transport = transport ), interceptor
0 commit comments