Skip to content

Commit 4aa4f1c

Browse files
committed
Load OAuth 2.0 token exchange credentials provider from config file
1 parent 36ef09b commit 4aa4f1c

File tree

5 files changed

+870
-35
lines changed

5 files changed

+870
-35
lines changed

tests/aio/test_credentials.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import time
33
import grpc
44
import threading
5+
import tempfile
6+
import os
7+
import json
58

69
import tests.auth.test_credentials
710
import tests.oauth2_token_exchange
@@ -11,23 +14,23 @@
1114
import ydb.oauth2_token_exchange.token_source
1215

1316

14-
class TestServiceAccountCredentials(ydb.aio.iam.ServiceAccountCredentials):
17+
class ServiceAccountCredentialsForTest(ydb.aio.iam.ServiceAccountCredentials):
1518
def _channel_factory(self):
1619
return grpc.aio.insecure_channel(self._iam_endpoint)
1720

1821
def get_expire_time(self):
1922
return self._expires_in - time.time()
2023

2124

22-
class TestOauth2TokenExchangeCredentials(ydb.aio.oauth2_token_exchange.Oauth2TokenExchangeCredentials):
25+
class Oauth2TokenExchangeCredentialsForTest(ydb.aio.oauth2_token_exchange.Oauth2TokenExchangeCredentials):
2326
def get_expire_time(self):
2427
return self._expires_in - time.time()
2528

2629

2730
@pytest.mark.asyncio
2831
async def test_yandex_service_account_credentials():
2932
server = tests.auth.test_credentials.IamTokenServiceTestServer()
30-
credentials = TestServiceAccountCredentials(
33+
credentials = ServiceAccountCredentialsForTest(
3134
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
3235
tests.auth.test_credentials.ACCESS_KEY_ID,
3336
tests.auth.test_credentials.PRIVATE_KEY,
@@ -49,7 +52,7 @@ def serve(s):
4952
serve_thread = threading.Thread(target=serve, args=(server,))
5053
serve_thread.start()
5154

52-
credentials = TestOauth2TokenExchangeCredentials(
55+
credentials = Oauth2TokenExchangeCredentialsForTest(
5356
server.endpoint(),
5457
ydb.oauth2_token_exchange.token_source.FixedTokenSource("test_src_token", "test_token_type"),
5558
audience=["a1", "a2"],
@@ -60,3 +63,51 @@ def serve(s):
6063
assert credentials.get_expire_time() <= 42
6164

6265
serve_thread.join()
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_oauth2_token_exchange_credentials_file():
70+
server = tests.oauth2_token_exchange.test_token_exchange.Oauth2TokenExchangeServiceForTest(40124)
71+
72+
def serve(s):
73+
s.handle_request()
74+
75+
serve_thread = threading.Thread(target=serve, args=(server,))
76+
serve_thread.start()
77+
78+
cfg = {
79+
"subject-credentials": {
80+
"type": "FIXED",
81+
"token": "test_src_token",
82+
"token-type": "test_token_type",
83+
},
84+
"aud": [
85+
"a1",
86+
"a2",
87+
],
88+
"scope": [
89+
"s1",
90+
"s2",
91+
],
92+
}
93+
94+
temp_cfg_file = tempfile.NamedTemporaryFile(delete=False)
95+
cfg_file_name = temp_cfg_file.name
96+
97+
try:
98+
temp_cfg_file.write(json.dumps(cfg, indent=4).encode("utf-8"))
99+
temp_cfg_file.close()
100+
101+
credentials = Oauth2TokenExchangeCredentialsForTest.from_file(
102+
cfg_file=cfg_file_name, iam_endpoint=server.endpoint()
103+
)
104+
105+
t = (await credentials.auth_metadata())[0][1]
106+
assert t == "Bearer test_dst_token"
107+
assert credentials.get_expire_time() <= 42
108+
109+
serve_thread.join()
110+
os.remove(cfg_file_name)
111+
except Exception:
112+
os.remove(cfg_file_name)
113+
raise

tests/auth/test_credentials.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_endpoint(self):
5757
return "localhost:54321"
5858

5959

60-
class TestServiceAccountCredentials(ydb.iam.ServiceAccountCredentials):
60+
class ServiceAccountCredentialsForTest(ydb.iam.ServiceAccountCredentials):
6161
def _channel_factory(self):
6262
return grpc.insecure_channel(self._iam_endpoint)
6363

@@ -67,7 +67,9 @@ def get_expire_time(self):
6767

6868
def test_yandex_service_account_credentials():
6969
server = IamTokenServiceTestServer()
70-
credentials = TestServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint())
70+
credentials = ServiceAccountCredentialsForTest(
71+
SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint()
72+
)
7173
t = credentials.get_auth_token()
7274
assert t == "test_token"
7375
assert credentials.get_expire_time() <= 42

0 commit comments

Comments
 (0)