From 4aa4f1cc8c63ca7667ad07a8eef28b5d2bb18cf0 Mon Sep 17 00:00:00 2001 From: Vasily Gerasimov Date: Thu, 11 Jul 2024 14:20:36 +0000 Subject: [PATCH] Load OAuth 2.0 token exchange credentials provider from config file --- tests/aio/test_credentials.py | 59 +- tests/auth/test_credentials.py | 6 +- .../test_token_exchange.py | 661 +++++++++++++++++- ydb/driver.py | 7 + ydb/oauth2_token_exchange/token_exchange.py | 172 ++++- 5 files changed, 870 insertions(+), 35 deletions(-) diff --git a/tests/aio/test_credentials.py b/tests/aio/test_credentials.py index 6e1fb316..a6f1d170 100644 --- a/tests/aio/test_credentials.py +++ b/tests/aio/test_credentials.py @@ -2,6 +2,9 @@ import time import grpc import threading +import tempfile +import os +import json import tests.auth.test_credentials import tests.oauth2_token_exchange @@ -11,7 +14,7 @@ import ydb.oauth2_token_exchange.token_source -class TestServiceAccountCredentials(ydb.aio.iam.ServiceAccountCredentials): +class ServiceAccountCredentialsForTest(ydb.aio.iam.ServiceAccountCredentials): def _channel_factory(self): return grpc.aio.insecure_channel(self._iam_endpoint) @@ -19,7 +22,7 @@ def get_expire_time(self): return self._expires_in - time.time() -class TestOauth2TokenExchangeCredentials(ydb.aio.oauth2_token_exchange.Oauth2TokenExchangeCredentials): +class Oauth2TokenExchangeCredentialsForTest(ydb.aio.oauth2_token_exchange.Oauth2TokenExchangeCredentials): def get_expire_time(self): return self._expires_in - time.time() @@ -27,7 +30,7 @@ def get_expire_time(self): @pytest.mark.asyncio async def test_yandex_service_account_credentials(): server = tests.auth.test_credentials.IamTokenServiceTestServer() - credentials = TestServiceAccountCredentials( + credentials = ServiceAccountCredentialsForTest( tests.auth.test_credentials.SERVICE_ACCOUNT_ID, tests.auth.test_credentials.ACCESS_KEY_ID, tests.auth.test_credentials.PRIVATE_KEY, @@ -49,7 +52,7 @@ def serve(s): serve_thread = threading.Thread(target=serve, args=(server,)) serve_thread.start() - credentials = TestOauth2TokenExchangeCredentials( + credentials = Oauth2TokenExchangeCredentialsForTest( server.endpoint(), ydb.oauth2_token_exchange.token_source.FixedTokenSource("test_src_token", "test_token_type"), audience=["a1", "a2"], @@ -60,3 +63,51 @@ def serve(s): assert credentials.get_expire_time() <= 42 serve_thread.join() + + +@pytest.mark.asyncio +async def test_oauth2_token_exchange_credentials_file(): + server = tests.oauth2_token_exchange.test_token_exchange.Oauth2TokenExchangeServiceForTest(40124) + + def serve(s): + s.handle_request() + + serve_thread = threading.Thread(target=serve, args=(server,)) + serve_thread.start() + + cfg = { + "subject-credentials": { + "type": "FIXED", + "token": "test_src_token", + "token-type": "test_token_type", + }, + "aud": [ + "a1", + "a2", + ], + "scope": [ + "s1", + "s2", + ], + } + + temp_cfg_file = tempfile.NamedTemporaryFile(delete=False) + cfg_file_name = temp_cfg_file.name + + try: + temp_cfg_file.write(json.dumps(cfg, indent=4).encode("utf-8")) + temp_cfg_file.close() + + credentials = Oauth2TokenExchangeCredentialsForTest.from_file( + cfg_file=cfg_file_name, iam_endpoint=server.endpoint() + ) + + t = (await credentials.auth_metadata())[0][1] + assert t == "Bearer test_dst_token" + assert credentials.get_expire_time() <= 42 + + serve_thread.join() + os.remove(cfg_file_name) + except Exception: + os.remove(cfg_file_name) + raise diff --git a/tests/auth/test_credentials.py b/tests/auth/test_credentials.py index 8b1d3e68..a78040ce 100644 --- a/tests/auth/test_credentials.py +++ b/tests/auth/test_credentials.py @@ -57,7 +57,7 @@ def get_endpoint(self): return "localhost:54321" -class TestServiceAccountCredentials(ydb.iam.ServiceAccountCredentials): +class ServiceAccountCredentialsForTest(ydb.iam.ServiceAccountCredentials): def _channel_factory(self): return grpc.insecure_channel(self._iam_endpoint) @@ -67,7 +67,9 @@ def get_expire_time(self): def test_yandex_service_account_credentials(): server = IamTokenServiceTestServer() - credentials = TestServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint()) + credentials = ServiceAccountCredentialsForTest( + SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint() + ) t = credentials.get_auth_token() assert t == "test_token" assert credentials.get_expire_time() <= 42 diff --git a/tests/oauth2_token_exchange/test_token_exchange.py b/tests/oauth2_token_exchange/test_token_exchange.py index a4bb1bb5..010a5d42 100644 --- a/tests/oauth2_token_exchange/test_token_exchange.py +++ b/tests/oauth2_token_exchange/test_token_exchange.py @@ -4,12 +4,23 @@ import urllib import json import threading +import tempfile +import os +import jwt +import base64 from ydb.oauth2_token_exchange import Oauth2TokenExchangeCredentials, FixedTokenSource +from ydb.driver import credentials_from_env_variables -class TestOauth2TokenExchangeCredentials(Oauth2TokenExchangeCredentials): - def get_expire_time(self): - return self._expires_in - time.time() +TEST_RSA_PRIVATE_KEY_CONTENT = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC75/JS3rMcLJxv\nFgpOzF5+2gH+Yig3RE2MTl9uwC0BZKAv6foYr7xywQyWIK+W1cBhz8R4LfFmZo2j\nM0aCvdRmNBdW0EDSTnHLxCsFhoQWLVq+bI5f5jzkcoiioUtaEpADPqwgVULVtN/n\nnPJiZ6/dU30C3jmR6+LUgEntUtWt3eq3xQIn5lG3zC1klBY/HxtfH5Hu8xBvwRQT\nJnh3UpPLj8XwSmriDgdrhR7o6umWyVuGrMKlLHmeivlfzjYtfzO1MOIMG8t2/zxG\nR+xb4Vwks73sH1KruH/0/JMXU97npwpe+Um+uXhpldPygGErEia7abyZB2gMpXqr\nWYKMo02NAgMBAAECggEAO0BpC5OYw/4XN/optu4/r91bupTGHKNHlsIR2rDzoBhU\nYLd1evpTQJY6O07EP5pYZx9mUwUdtU4KRJeDGO/1/WJYp7HUdtxwirHpZP0lQn77\nuccuX/QQaHLrPekBgz4ONk+5ZBqukAfQgM7fKYOLk41jgpeDbM2Ggb6QUSsJISEp\nzrwpI/nNT/wn+Hvx4DxrzWU6wF+P8kl77UwPYlTA7GsT+T7eKGVH8xsxmK8pt6lg\nsvlBA5XosWBWUCGLgcBkAY5e4ZWbkdd183o+oMo78id6C+PQPE66PLDtHWfpRRmN\nm6XC03x6NVhnfvfozoWnmS4+e4qj4F/emCHvn0GMywKBgQDLXlj7YPFVXxZpUvg/\nrheVcCTGbNmQJ+4cZXx87huqwqKgkmtOyeWsRc7zYInYgraDrtCuDBCfP//ZzOh0\nLxepYLTPk5eNn/GT+VVrqsy35Ccr60g7Lp/bzb1WxyhcLbo0KX7/6jl0lP+VKtdv\nmto+4mbSBXSM1Y5BVVoVgJ3T/wKBgQDsiSvPRzVi5TTj13x67PFymTMx3HCe2WzH\nJUyepCmVhTm482zW95pv6raDr5CTO6OYpHtc5sTTRhVYEZoEYFTM9Vw8faBtluWG\nBjkRh4cIpoIARMn74YZKj0C/0vdX7SHdyBOU3bgRPHg08Hwu3xReqT1kEPSI/B2V\n4pe5fVrucwKBgQCNFgUxUA3dJjyMES18MDDYUZaRug4tfiYouRdmLGIxUxozv6CG\nZnbZzwxFt+GpvPUV4f+P33rgoCvFU+yoPctyjE6j+0aW0DFucPmb2kBwCu5J/856\nkFwCx3blbwFHAco+SdN7g2kcwgmV2MTg/lMOcU7XwUUcN0Obe7UlWbckzQKBgQDQ\nnXaXHL24GGFaZe4y2JFmujmNy1dEsoye44W9ERpf9h1fwsoGmmCKPp90az5+rIXw\nFXl8CUgk8lXW08db/r4r+ma8Lyx0GzcZyplAnaB5/6j+pazjSxfO4KOBy4Y89Tb+\nTP0AOcCi6ws13bgY+sUTa/5qKA4UVw+c5zlb7nRpgwKBgGXAXhenFw1666482iiN\ncHSgwc4ZHa1oL6aNJR1XWH+aboBSwR+feKHUPeT4jHgzRGo/aCNHD2FE5I8eBv33\nof1kWYjAO0YdzeKrW0rTwfvt9gGg+CS397aWu4cy+mTI+MNfBgeDAIVBeJOJXLlX\nhL8bFAuNNVrCOp79TNnNIsh7\n-----END PRIVATE KEY-----\n" # noqa: E501 +TEST_RSA_PUBLIC_KEY_CONTENT = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu+fyUt6zHCycbxYKTsxe\nftoB/mIoN0RNjE5fbsAtAWSgL+n6GK+8csEMliCvltXAYc/EeC3xZmaNozNGgr3U\nZjQXVtBA0k5xy8QrBYaEFi1avmyOX+Y85HKIoqFLWhKQAz6sIFVC1bTf55zyYmev\n3VN9At45kevi1IBJ7VLVrd3qt8UCJ+ZRt8wtZJQWPx8bXx+R7vMQb8EUEyZ4d1KT\ny4/F8Epq4g4Ha4Ue6OrplslbhqzCpSx5nor5X842LX8ztTDiDBvLdv88RkfsW+Fc\nJLO97B9Sq7h/9PyTF1Pe56cKXvlJvrl4aZXT8oBhKxImu2m8mQdoDKV6q1mCjKNN\njQIDAQAB\n-----END PUBLIC KEY-----\n" # noqa: E501 +TEST_EC_PRIVATE_KEY_CONTENT = "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIB6fv25gf7P/7fkjW/2kcKICUhHeOygkFeUJ/ylyU3hloAoGCCqGSM49\nAwEHoUQDQgAEvkKy92hpLiT0GEpzFkYBEWWnkAGTTA6141H0oInA9X30eS0RObAa\nmVY8yD39NI7Nj03hBxEa4Z0tOhrq9cW8eg==\n-----END EC PRIVATE KEY-----\n" # noqa: E501 +TEST_EC_PUBLIC_KEY_CONTENT = "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEvkKy92hpLiT0GEpzFkYBEWWnkAGT\nTA6141H0oInA9X30eS0RObAamVY8yD39NI7Nj03hBxEa4Z0tOhrq9cW8eg==\n-----END PUBLIC KEY-----\n" # noqa: E501 +TEST_HMAC_SECRET_KEY_BASE64_CONTENT = "VGhlIHdvcmxkIGhhcyBjaGFuZ2VkLgpJIHNlZSBpdCBpbiB0aGUgd2F0ZXIuCkkgZmVlbCBpdCBpbiB0aGUgRWFydGguCkkgc21lbGwgaXQgaW4gdGhlIGFpci4KTXVjaCB0aGF0IG9uY2Ugd2FzIGlzIGxvc3QsCkZvciBub25lIG5vdyBsaXZlIHdobyByZW1lbWJlciBpdC4K" # noqa: E501 + + +def get_expire_time(creds): + return creds._expires_in - time.time() TOKEN_EXCHANGE_RESPONSES = { @@ -63,45 +74,197 @@ def get_expire_time(self): } -class OauthExchangeServiceHandler(http.server.BaseHTTPRequestHandler): - def do_POST(self): - assert self.headers["Content-Type"] == "application/x-www-form-urlencoded" - assert self.path == "/token/exchange" - content_length = int(self.headers["Content-Length"]) - post_data = self.rfile.read(content_length).decode("utf8") - print("OauthExchangeServiceHandler.POST data: {}".format(post_data)) - parsed_request = urllib.parse.parse_qs(str(post_data)) +class FixedTokenSourceChecker: + def __init__( + self, + token, + token_type, + ): + self.token = token + self.token_type = token_type + + def check(self, token, token_type): + assert token == self.token + assert token_type == self.token_type + + +class AnyTokenSourceChecker: + def check(self, token, token_type): + assert token != "" + assert token_type != "" + + +class JwtTokenSourceChecker: + def __init__( + self, + alg, + public_key, + key_id=None, + issuer=None, + subject=None, + aud=None, + id=None, + ttl_seconds=3600, + ): + self.alg = alg + self.public_key = public_key + self.key_id = key_id + self.issuer = issuer + self.subject = subject + self.aud = aud + self.id = id + self.ttl_seconds = ttl_seconds + + def check(self, token, token_type): + assert token_type == "urn:ietf:params:oauth:token-type:jwt" + decoded = jwt.decode( + token, + key=self.public_key, + algorithms=[self.alg], + options={ + "require": ["iat", "exp"], + "verify_signature": True, + "verify_aud": False, + "verify_iss": False, + }, + ) + header = jwt.get_unverified_header(token) + assert header.get("kid") == self.key_id + assert header.get("alg") == self.alg + assert decoded.get("iss") == self.issuer + assert decoded.get("sub") == self.subject + assert decoded.get("aud") == self.aud + assert decoded.get("jti") == self.id + assert abs(decoded["iat"] - time.time()) <= 60 + assert abs(decoded["exp"] - decoded["iat"]) == self.ttl_seconds + + +class Oauth2ExchangeServiceChecker: + def __init__( + self, + subject_token_source=None, + actor_token_source=None, + audience=None, + scope=None, + resource=None, + requested_token_type="urn:ietf:params:oauth:token-type:access_token", + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + ): + self.subject_token_source = subject_token_source + self.actor_token_source = actor_token_source + self.audience = audience + self.scope = scope + self.resource = resource + self.requested_token_type = requested_token_type + self.grant_type = grant_type + + def check(self, handler, parsed_request) -> None: + assert handler.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert handler.path == "/token/exchange" + assert len(parsed_request["grant_type"]) == 1 - assert parsed_request["grant_type"][0] == "urn:ietf:params:oauth:grant-type:token-exchange" + assert parsed_request["grant_type"][0] == self.grant_type assert len(parsed_request["requested_token_type"]) == 1 - assert parsed_request["requested_token_type"][0] == "urn:ietf:params:oauth:token-type:access_token" + assert parsed_request["requested_token_type"][0] == self.requested_token_type + + if self.audience is None or len(self.audience) == 0: + assert len(parsed_request.get("audience", [])) == 0 + else: + assert len(parsed_request["audience"]) == len(self.audience) + for i in range(len(self.audience)): + assert parsed_request["audience"][i] == self.audience[i] - assert len(parsed_request["subject_token_type"]) == 1 - assert parsed_request["subject_token_type"][0] == "test_token_type" + if self.scope is None or len(self.scope) == 0: + assert len(parsed_request.get("scope", [])) == 0 + else: + assert len(parsed_request.get("scope", [])) == 1 + assert parsed_request["scope"][0] == " ".join(self.scope) + if self.resource is None or self.resource == "": + assert len(parsed_request.get("resource", [])) == 0 + else: + assert len(parsed_request.get("resource", [])) == 1 + assert parsed_request["resource"][0] == self.resource + + if self.subject_token_source is None: + assert len(parsed_request.get("subject_token", [])) == 0 + assert len(parsed_request.get("subject_token_type", [])) == 0 + else: + assert len(parsed_request.get("subject_token", [])) == 1 + assert len(parsed_request.get("subject_token_type", [])) == 1 + self.subject_token_source.check(parsed_request["subject_token"][0], parsed_request["subject_token_type"][0]) + + if self.actor_token_source is None: + assert len(parsed_request.get("actor_token", [])) == 0 + assert len(parsed_request.get("actor_token_type", [])) == 0 + else: + assert len(parsed_request.get("actor_token", [])) == 1 + assert len(parsed_request.get("actor_token_type", [])) == 1 + self.actor_token_source.check(parsed_request["actor_token"][0], parsed_request["actor_token_type"][0]) + + +class TokenExchangeResponseBySubjectToken: + def __init__(self, responses=TOKEN_EXCHANGE_RESPONSES): + self.responses = responses + + def get_response(self, parsed_request): assert len(parsed_request["subject_token"]) == 1 - responses = TOKEN_EXCHANGE_RESPONSES.get(parsed_request["subject_token"][0]) + responses = self.responses.get(parsed_request["subject_token"][0]) assert responses is not None response_code = responses[0] response = responses[1] + return response_code, response + - assert len(parsed_request["audience"]) == 2 - assert parsed_request["audience"][0] == "a1" - assert parsed_request["audience"][1] == "a2" +class Oauth2TokenExchangeResponse: + def __init__( + self, + response_code, + response, + ): + self.response_code = response_code + self.response = response - assert len(parsed_request["scope"]) == 1 - assert parsed_request["scope"][0] == "s1 s2" + def get_response(self, parsed_request): + return self.response_code, self.response - self.send_response(response_code) - self.send_header("Content-type", "application/json") - self.end_headers() - self.wfile.write(json.dumps(response).encode("utf8")) + +class OauthExchangeServiceHandler(http.server.BaseHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length).decode("utf8") + print("OauthExchangeServiceHandler.POST data: {}".format(post_data)) + parsed_request = urllib.parse.parse_qs(str(post_data)) + try: + self.server.checker.check(self, parsed_request) + self.server.checker.check_successful = True + response_code, response = self.server.response.get_response(parsed_request) + + self.send_response(response_code) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode("utf8")) + except Exception as ex: + self.send_response(500) + self.send_header("Content-type", "text/plain") + self.end_headers() + self.wfile.write("Exception during text check: {}".format(ex).encode("utf8")) class Oauth2TokenExchangeServiceForTest(http.server.HTTPServer): - def __init__(self, port): + def __init__(self, port, checker=None, response=None): http.server.HTTPServer.__init__(self, ("localhost", port), OauthExchangeServiceHandler) + self.checker = checker + if self.checker is None: + self.checker = Oauth2ExchangeServiceChecker( + subject_token_source=AnyTokenSourceChecker(), + audience=["a1", "a2"], + scope=["s1", "s2"], + ) + self.response = response + if self.response is None: + self.response = TokenExchangeResponseBySubjectToken() self.port = port def endpoint(self): @@ -116,7 +279,7 @@ def __init__(self, src_token, dst_token, error_text): def run_check(self, server): try: - credentials = TestOauth2TokenExchangeCredentials( + credentials = Oauth2TokenExchangeCredentials( server.endpoint(), subject_token_source=FixedTokenSource(self.src_token, "test_token_type"), audience=["a1", "a2"], @@ -125,7 +288,7 @@ def run_check(self, server): t = credentials.get_auth_token() assert not self.error_text, "Exception is expected. Test: {}".format(self.src_token) assert t == self.dst_token - assert credentials.get_expire_time() <= 42 + assert get_expire_time(credentials) <= 42 except AssertionError: raise except Exception as ex: @@ -181,3 +344,445 @@ def serve(s): raise serve_thread.join() + + +class DataForConfigTest: + def __init__( + self, + cfg=None, # cfg or cfg text + cfg_file=None, + checker=None, + response=None, + init_error_text_part=None, + get_token_error_text_part=None, + dst_token=None, + dst_expire_time=42, + http_request_is_expected=None, + init_from_env=False, + ): + self.cfg = cfg + self.cfg_file = cfg_file + self.checker = checker + self.response = response + self.init_error_text_part = init_error_text_part + self.get_token_error_text_part = get_token_error_text_part + self.dst_token = dst_token + self.dst_expire_time = dst_expire_time + self.http_request_is_expected = http_request_is_expected + self.init_from_env = init_from_env + + def get_cfg(self): + if isinstance(self.cfg, str): + return self.cfg + else: + return json.dumps(self.cfg, indent=4) + + def expect_http_request(self): + if self.http_request_is_expected is not None: + return self.http_request_is_expected + + if self.init_error_text_part is not None: + return False + return True + + def run_check(self, server): + server.checker = self.checker + if server.checker is not None: + server.checker.check_successful = False + server.response = self.response + + if self.cfg_file: + cfg_file = self.cfg_file + + def rm_file(): + pass + + else: + temp_cfg_file = tempfile.NamedTemporaryFile(delete=False) + cfg_file = temp_cfg_file.name + temp_cfg_file.write(self.get_cfg().encode("utf-8")) + temp_cfg_file.close() + + def rm_file(): + os.remove(cfg_file) + if self.init_from_env: + del os.environ["YDB_OAUTH2_KEY_FILE"] + + try: + if self.init_from_env: + os.environ["YDB_OAUTH2_KEY_FILE"] = cfg_file + credentials = credentials_from_env_variables() + else: + credentials = Oauth2TokenExchangeCredentials.from_file( + cfg_file, + iam_endpoint=server.endpoint(), + ) + assert self.init_error_text_part is None, "Init exception is expected. Test:\n{}".format(self.get_cfg()) + + t = credentials.get_auth_token() + assert not self.get_token_error_text_part, "Exception is expected. Test:\n{}".format(self.get_cfg()) + if self.expect_http_request() and server.checker is not None: + assert server.checker.check_successful + + assert t == self.dst_token + assert get_expire_time(credentials) <= self.dst_expire_time + rm_file() + except AssertionError: + rm_file() + raise + except Exception as ex: + rm_file() + err_text = self.init_error_text_part if self.init_error_text_part else self.get_token_error_text_part + if err_text: + assert err_text in str(ex) + else: + assert False, "Exception is not expected. Test:\n{}. Exception text: {}".format(self.get_cfg(), ex) + + +def test_oauth2_token_exchange_credentials_file(): + server = Oauth2TokenExchangeServiceForTest(40124) + + tests = [ + DataForConfigTest(cfg="not json config", init_error_text_part="Failed to parse json config"), + DataForConfigTest( + init_from_env=True, + cfg={ + "res": "tEst", + "grant-type": "grant", + "requested-token-type": "access_token", + "subject-credentials": { + "type": "fixed", + "token": "test-token", + "token-type": "test-token-type", + }, + }, + init_error_text_part="no token endpoint specified", + ), + DataForConfigTest( + init_from_env=True, + cfg={ + "token-endpoint": server.endpoint(), + "res": "tEst", + "grant-type": "grant", + "requested-token-type": "access_token", + "subject-credentials": { + "type": "fixed", + "token": "test-token", + "token-type": "test-token-type", + }, + }, + checker=Oauth2ExchangeServiceChecker( + subject_token_source=FixedTokenSourceChecker( + token="test-token", + token_type="test-token-type", + ), + grant_type="grant", + requested_token_type="access_token", + resource="tEst", + ), + response=Oauth2TokenExchangeResponse( + 200, + { + "access_token": "test_dst_token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 42, + }, + ), + dst_token="Bearer test_dst_token", + ), + DataForConfigTest( + cfg={ + "aud": "test-aud", + "scope": [ + "s1", + "s2", + ], + "unknown-field": [123], + "actor-credentials": { + "type": "fixed", + "token": "test-token", + "token-type": "test-token-type", + }, + }, + checker=Oauth2ExchangeServiceChecker( + actor_token_source=FixedTokenSourceChecker( + token="test-token", + token_type="test-token-type", + ), + audience=["test-aud"], + scope=["s1", "s2"], + ), + response=Oauth2TokenExchangeResponse( + 200, + { + "access_token": "test_dst_token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 42, + }, + ), + dst_token="Bearer test_dst_token", + ), + DataForConfigTest( + cfg={ + "requested-token-type": "access_token", + "subject-credentials": { + "type": "JWT", + "alg": "ps256", + "private-key": TEST_RSA_PRIVATE_KEY_CONTENT, + "aud": ["a1", "a2"], + "jti": "123", + "sub": "test_subject", + "iss": "test_issuer", + "kid": "test_key_id", + "ttl": "24h", + "unknown_field": "hello!", + }, + }, + checker=Oauth2ExchangeServiceChecker( + subject_token_source=JwtTokenSourceChecker( + alg="PS256", + public_key=TEST_RSA_PUBLIC_KEY_CONTENT, + aud=["a1", "a2"], + id="123", + subject="test_subject", + issuer="test_issuer", + key_id="test_key_id", + ttl_seconds=24 * 3600, + ), + requested_token_type="access_token", + ), + response=Oauth2TokenExchangeResponse( + 200, + { + "access_token": "test_dst_token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "bearer", + "expires_in": 42, + }, + ), + dst_token="Bearer test_dst_token", + ), + DataForConfigTest( + cfg={ + "actor-credentials": { + "type": "JWT", + "alg": "es256", + "private-key": TEST_EC_PRIVATE_KEY_CONTENT, + "ttl": "3m", + "unknown_field": "hello!", + }, + }, + checker=Oauth2ExchangeServiceChecker( + actor_token_source=JwtTokenSourceChecker( + alg="ES256", + public_key=TEST_EC_PUBLIC_KEY_CONTENT, + ttl_seconds=180, + ), + ), + response=Oauth2TokenExchangeResponse( + 200, + { + "access_token": "test_dst_token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "bearer", + "expires_in": 42, + }, + ), + dst_token="Bearer test_dst_token", + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "hs512", + "private-key": TEST_HMAC_SECRET_KEY_BASE64_CONTENT, + }, + }, + checker=Oauth2ExchangeServiceChecker( + subject_token_source=JwtTokenSourceChecker( + alg="HS512", + public_key=base64.b64decode(TEST_HMAC_SECRET_KEY_BASE64_CONTENT), + ), + ), + response=Oauth2TokenExchangeResponse( + 200, + { + "access_token": "test_dst_token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "bearer", + "expires_in": 42, + }, + ), + dst_token="Bearer test_dst_token", + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "rs512", + "private-key": TEST_HMAC_SECRET_KEY_BASE64_CONTENT, + }, + }, + http_request_is_expected=False, + get_token_error_text_part="Could not deserialize key data.", + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "es512", + "private-key": TEST_HMAC_SECRET_KEY_BASE64_CONTENT, + }, + }, + http_request_is_expected=False, + get_token_error_text_part="Could not deserialize key data.", + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "es512", + "private-key": TEST_RSA_PRIVATE_KEY_CONTENT, + }, + }, + http_request_is_expected=False, + get_token_error_text_part="sign() missing 1 required positional argument", + ), + DataForConfigTest( + cfg_file="~/unknown-file.cfg", + init_error_text_part="No such file or directory", + ), + DataForConfigTest( + cfg={ + "actor-credentials": "", + }, + init_error_text_part='Key "actor-credentials" is expected to be a json map', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "RS256", + "private-key": "123", + "ttl": 123, + }, + }, + init_error_text_part='Key "ttl" is expected to be a string', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "RS256", + "private-key": "123", + "ttl": "-3h", + }, + }, + init_error_text_part="-3: negative duration is not allowed", + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "private-key": "123", + }, + }, + init_error_text_part='Key "alg" is expected to be a nonempty string', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "HS384", + }, + }, + init_error_text_part='Key "private-key" is expected to be a nonempty string', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "unknown", + "private-key": "123", + }, + }, + http_request_is_expected=False, + get_token_error_text_part="Algorithm not supported.", + ), + DataForConfigTest( + cfg={ + "aud": { + "value": "wrong_format of aud: not string and not list", + }, + "subject-credentials": { + "type": "FIXED", + "token": "test-token", + "token-type": "test-token-type", + }, + }, + init_error_text_part='Key "aud" is expected to be a single string or list of nonempty strings', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "JWT", + "alg": "RS256", + "private-key": "123", + "aud": { + "value": "wrong_format of aud: not string and not list", + }, + }, + }, + init_error_text_part='Key "aud" is expected to be a single string or list of nonempty strings', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "unknown", + }, + }, + init_error_text_part='"subject-credentials": unknown token source type: "unknown"', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "token": "test", + }, + }, + init_error_text_part='Key "type" is expected to be a nonempty string', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "FIXED", + "token": "test", + }, + }, + init_error_text_part='Key "token-type" is expected to be a nonempty string', + ), + DataForConfigTest( + cfg={ + "subject-credentials": { + "type": "FIXED", + "token-type": "test", + }, + }, + init_error_text_part='Key "token" is expected to be a nonempty string', + ), + ] + + def serve(s): + for t in tests: + # one request per test + if t.expect_http_request(): + s.handle_request() + + serve_thread = threading.Thread(target=serve, args=(server,)) + serve_thread.start() + + for t in tests: + t.run_check(server) + + serve_thread.join() diff --git a/ydb/driver.py b/ydb/driver.py index 89109b9b..ecd3319e 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -54,6 +54,13 @@ def credentials_from_env_variables(tracer=None): ctx.trace({"credentials.access_token": True}) return credentials_impl.AuthTokenCredentials(access_token) + oauth2_key_file = os.getenv("YDB_OAUTH2_KEY_FILE") + if oauth2_key_file: + ctx.trace({"credentials.oauth2_key_file": True}) + import ydb.oauth2_token_exchange + + return ydb.oauth2_token_exchange.Oauth2TokenExchangeCredentials.from_file(oauth2_key_file) + ctx.trace( { "credentials.env_default": True, diff --git a/ydb/oauth2_token_exchange/token_exchange.py b/ydb/oauth2_token_exchange/token_exchange.py index e4d1db94..8f16619d 100644 --- a/ydb/oauth2_token_exchange/token_exchange.py +++ b/ydb/oauth2_token_exchange/token_exchange.py @@ -2,6 +2,8 @@ import typing import json import abc +import os +import base64 try: import requests @@ -9,7 +11,24 @@ requests = None from ydb import credentials, tracing, issues -from .token_source import TokenSource +from .token_source import TokenSource, FixedTokenSource, JwtTokenSource + + +# method -> is HMAC +_supported_uppercase_jwt_algs = { + "HS256": True, + "HS384": True, + "HS512": True, + "RS256": False, + "RS384": False, + "RS512": False, + "PS256": False, + "PS384": False, + "PS512": False, + "ES256": False, + "ES384": False, + "ES512": False, +} class Oauth2TokenExchangeCredentialsBase(abc.ABC): @@ -94,6 +113,157 @@ def _make_token_request_params(self): return params + @classmethod + def _jwt_token_source_from_config(cls, cfg_json): + signing_method = cls._required_string_from_config(cfg_json, "alg") + is_hmac = _supported_uppercase_jwt_algs.get(signing_method.upper(), None) + if is_hmac is not None: # we know this method => do uppercase + signing_method = signing_method.upper() + private_key = cls._required_string_from_config(cfg_json, "private-key") + if is_hmac: # decode from base64 + private_key = base64.b64decode(private_key + "===") # to allow unpadded strings + return JwtTokenSource( + signing_method=signing_method, + private_key=private_key, + key_id=cls._string_with_default_from_config(cfg_json, "kid", None), + issuer=cls._string_with_default_from_config(cfg_json, "iss", None), + subject=cls._string_with_default_from_config(cfg_json, "sub", None), + audience=cls._list_of_strings_or_single_from_config(cfg_json, "aud"), + id=cls._string_with_default_from_config(cfg_json, "jti", None), + token_ttl_seconds=cls._duration_seconds_from_config(cfg_json, "ttl", 3600), + ) + + @classmethod + def _fixed_token_source_from_config(cls, cfg_json): + return FixedTokenSource( + cls._required_string_from_config(cfg_json, "token"), + cls._required_string_from_config(cfg_json, "token-type"), + ) + + @classmethod + def _token_source_from_config(cls, cfg_json, key_name): + value = cfg_json.get(key_name, None) + if value is None: + return None + if not isinstance(value, dict): + raise Exception('Key "{}" is expected to be a json map'.format(key_name)) + + source_type = cls._required_string_from_config(value, "type") + if source_type.upper() == "FIXED": + return cls._fixed_token_source_from_config(value) + if source_type.upper() == "JWT": + return cls._jwt_token_source_from_config(value) + raise Exception('"{}": unknown token source type: "{}"'.format(key_name, source_type)) + + @classmethod + def _list_of_strings_or_single_from_config(cls, cfg_json, key_name): + value = cfg_json.get(key_name, None) + if value is None: + return None + if isinstance(value, list): + for val in value: + if not isinstance(val, str) or not val: + raise Exception( + 'Key "{}" is expected to be a single string or list of nonempty strings'.format(key_name) + ) + return value + else: + if isinstance(value, str): + return value + raise Exception('Key "{}" is expected to be a single string or list of nonempty strings'.format(key_name)) + + @classmethod + def _required_string_from_config(cls, cfg_json, key_name): + value = cfg_json.get(key_name, None) + if value is None or not isinstance(value, str) or not value: + raise Exception('Key "{}" is expected to be a nonempty string'.format(key_name)) + return value + + @classmethod + def _string_with_default_from_config(cls, cfg_json, key_name, default_value): + value = cfg_json.get(key_name, None) + if value is None: + return default_value + if not isinstance(value, str): + raise Exception('Key "{}" is expected to be a string'.format(key_name)) + return value + + @classmethod + def _duration_seconds_from_config(cls, cfg_json, key_name, default_value): + value = cfg_json.get(key_name, None) + if value is None: + return default_value + if not isinstance(value, str): + raise Exception('Key "{}" is expected to be a string'.format(key_name)) + multiplier = 1 + if value.endswith("s"): + multiplier = 1 + value = value[:-1] + elif value.endswith("m"): + multiplier = 60 + value = value[:-1] + elif value.endswith("h"): + multiplier = 3600 + value = value[:-1] + elif value.endswith("d"): + multiplier = 3600 * 24 + value = value[:-1] + elif value.endswith("ms"): + multiplier = 1.0 / 1000 + value = value[:-2] + elif value.endswith("us"): + multiplier = 1.0 / 1000000 + value = value[:-2] + elif value.endswith("ns"): + multiplier = 1.0 / 1000000000 + value = value[:-2] + f = float(value) + if f < 0.0: + raise Exception("{}: negative duration is not allowed".format(value)) + return int(f * multiplier) + + @classmethod + def from_file(cls, cfg_file, iam_endpoint=None): + with open(os.path.expanduser(cfg_file), "r") as r: + cfg = r.read() + + return cls.from_content(cfg, iam_endpoint=iam_endpoint) + + @classmethod + def from_content(cls, cfg, iam_endpoint=None): + try: + cfg_json = json.loads(cfg) + except Exception as ex: + raise Exception("Failed to parse json config: {}".format(ex)) + + if iam_endpoint is not None: + token_endpoint = iam_endpoint + else: + token_endpoint = cfg_json.get("token-endpoint", "") + + subject_token_source = cls._token_source_from_config(cfg_json, "subject-credentials") + actor_token_source = cls._token_source_from_config(cfg_json, "actor-credentials") + audience = cls._list_of_strings_or_single_from_config(cfg_json, "aud") + scope = cls._list_of_strings_or_single_from_config(cfg_json, "scope") + resource = cls._string_with_default_from_config(cfg_json, "res", None) + grant_type = cls._string_with_default_from_config( + cfg_json, "grant-type", "urn:ietf:params:oauth:grant-type:token-exchange" + ) + requested_token_type = cls._string_with_default_from_config( + cfg_json, "requested-token-type", "urn:ietf:params:oauth:token-type:access_token" + ) + + return cls( + token_endpoint=token_endpoint, + subject_token_source=subject_token_source, + actor_token_source=actor_token_source, + audience=audience, + scope=scope, + resource=resource, + grant_type=grant_type, + requested_token_type=requested_token_type, + ) + class Oauth2TokenExchangeCredentials(credentials.AbstractExpiringTokenCredentials, Oauth2TokenExchangeCredentialsBase): def __init__(