2
2
import time
3
3
import grpc
4
4
import threading
5
+ import tempfile
6
+ import os
7
+ import json
5
8
6
9
import tests .auth .test_credentials
7
10
import tests .oauth2_token_exchange
11
14
import ydb .oauth2_token_exchange .token_source
12
15
13
16
14
- class TestServiceAccountCredentials (ydb .aio .iam .ServiceAccountCredentials ):
17
+ class ServiceAccountCredentialsForTest (ydb .aio .iam .ServiceAccountCredentials ):
15
18
def _channel_factory (self ):
16
19
return grpc .aio .insecure_channel (self ._iam_endpoint )
17
20
18
21
def get_expire_time (self ):
19
22
return self ._expires_in - time .time ()
20
23
21
24
22
- class TestOauth2TokenExchangeCredentials (ydb .aio .oauth2_token_exchange .Oauth2TokenExchangeCredentials ):
25
+ class Oauth2TokenExchangeCredentialsForTest (ydb .aio .oauth2_token_exchange .Oauth2TokenExchangeCredentials ):
23
26
def get_expire_time (self ):
24
27
return self ._expires_in - time .time ()
25
28
26
29
27
30
@pytest .mark .asyncio
28
31
async def test_yandex_service_account_credentials ():
29
32
server = tests .auth .test_credentials .IamTokenServiceTestServer ()
30
- credentials = TestServiceAccountCredentials (
33
+ credentials = ServiceAccountCredentialsForTest (
31
34
tests .auth .test_credentials .SERVICE_ACCOUNT_ID ,
32
35
tests .auth .test_credentials .ACCESS_KEY_ID ,
33
36
tests .auth .test_credentials .PRIVATE_KEY ,
@@ -49,7 +52,7 @@ def serve(s):
49
52
serve_thread = threading .Thread (target = serve , args = (server ,))
50
53
serve_thread .start ()
51
54
52
- credentials = TestOauth2TokenExchangeCredentials (
55
+ credentials = Oauth2TokenExchangeCredentialsForTest (
53
56
server .endpoint (),
54
57
ydb .oauth2_token_exchange .token_source .FixedTokenSource ("test_src_token" , "test_token_type" ),
55
58
audience = ["a1" , "a2" ],
@@ -60,3 +63,51 @@ def serve(s):
60
63
assert credentials .get_expire_time () <= 42
61
64
62
65
serve_thread .join ()
66
+
67
+
68
+ @pytest .mark .asyncio
69
+ async def test_oauth2_token_exchange_credentials_file ():
70
+ server = tests .oauth2_token_exchange .test_token_exchange .Oauth2TokenExchangeServiceForTest (40124 )
71
+
72
+ def serve (s ):
73
+ s .handle_request ()
74
+
75
+ serve_thread = threading .Thread (target = serve , args = (server ,))
76
+ serve_thread .start ()
77
+
78
+ cfg = {
79
+ "subject-credentials" : {
80
+ "type" : "FIXED" ,
81
+ "token" : "test_src_token" ,
82
+ "token-type" : "test_token_type" ,
83
+ },
84
+ "aud" : [
85
+ "a1" ,
86
+ "a2" ,
87
+ ],
88
+ "scope" : [
89
+ "s1" ,
90
+ "s2" ,
91
+ ],
92
+ }
93
+
94
+ temp_cfg_file = tempfile .NamedTemporaryFile (delete = False )
95
+ cfg_file_name = temp_cfg_file .name
96
+
97
+ try :
98
+ temp_cfg_file .write (json .dumps (cfg , indent = 4 ).encode ("utf-8" ))
99
+ temp_cfg_file .close ()
100
+
101
+ credentials = Oauth2TokenExchangeCredentialsForTest .from_file (
102
+ cfg_file = cfg_file_name , iam_endpoint = server .endpoint ()
103
+ )
104
+
105
+ t = (await credentials .auth_metadata ())[0 ][1 ]
106
+ assert t == "Bearer test_dst_token"
107
+ assert credentials .get_expire_time () <= 42
108
+
109
+ serve_thread .join ()
110
+ os .remove (cfg_file_name )
111
+ except Exception :
112
+ os .remove (cfg_file_name )
113
+ raise
0 commit comments