From 40e995a242de140706b04a99b7dbd6d56ab25bb7 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 2 May 2025 22:08:27 +0530 Subject: [PATCH 1/6] chore: Add unit test coverage. --- packages/toolbox-core/tests/test_client.py | 132 +++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 54b925ee..6b6033de 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -699,6 +699,138 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client): bound_async_callable.assert_awaited_once() + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_param("argA", 5) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_callable_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_param("argA", lambda: 5) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception) as e: + tool.bind_param("argC", lambda: 5) + assert "unable to bind parameters: no parameter named argC" in str(e.value) + + @pytest.mark.asyncio + async def test_rebind_param_fail(self, tool_name, client): + """ + Tests that 'bind_param' fails when attempting to re-bind a + parameter that has already been bound. + """ + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool_with_bound_param = tool.bind_param("argA", lambda: 10) + + assert len(tool_with_bound_param.__signature__.parameters) == 1 + assert "argA" not in tool_with_bound_param.__signature__.parameters + + with pytest.raises(ValueError) as e: + tool_with_bound_param.bind_param("argA", lambda: 20) + + assert "cannot re-bind parameter: parameter 'argA' is already bound" in str( + e.value + ) + + @pytest.mark.asyncio + async def test_bind_param_static_value_success(self, tool_name, client): + """ + Tests bind_param method with a static value. + """ + + bound_value = "Test value" + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_value) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value} + + @pytest.mark.asyncio + async def test_bind_param_sync_callable_value_success(self, tool_name, client): + """ + Tests bind_param method with a sync callable value. + """ + + bound_value_result = True + bound_sync_callable = Mock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_sync_callable) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_sync_callable.assert_called_once() + + @pytest.mark.asyncio + async def test_bind_param_async_callable_value_success(self, tool_name, client): + """ + Tests bind_param method with an async callable value. + """ + + bound_value_result = True + bound_async_callable = AsyncMock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_async_callable) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_async_callable.assert_awaited_once() + + class TestUnusedParameterValidation: """ Tests for validation errors related to unused auth tokens or bound From a48389ab7275ef085d435b7831584fe2b711d10d Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 5 May 2025 13:35:59 +0530 Subject: [PATCH 2/6] chore: Consolidate auth header creation logic Post adding the feature of adding client-level auth headers (#178), we have the logic for creating an auth header, from the given auth token getter name, in 3 different places. This PR unifies all of that logic into a single helper to improve maintenance, and make it easier to change the way we add suffix/prefix, and reduces WET code. --- packages/toolbox-core/tests/test_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 6b6033de..15176544 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -698,7 +698,6 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client): assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} bound_async_callable.assert_awaited_once() - @pytest.mark.asyncio async def test_bind_param_success(self, tool_name, client): """Tests 'bind_param' with a bound parameter specified.""" From bf10677eda0ebc2918df528a17fa73f004713c55 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 12 May 2025 10:35:28 +0530 Subject: [PATCH 3/6] docs: Add names to return values for better readability --- packages/toolbox-core/src/toolbox_core/utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index a30dc66b..ddbcfffd 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -60,11 +60,15 @@ def identify_required_authn_params( Returns: A tuple containing: - - A new dictionary representing the subset of required - authentication parameters that are not covered by the provided - `auth_service_names`. - - A list of authentication service names from `auth_service_names` - that were found to satisfy at least one parameter's requirements. + - required_authn_params: A new dictionary representing the subset of + required authentication parameters that are not covered by the + provided `auth_service_names`. + - required_authz_tokens: A list of required authorization tokens if + no service name in `auth_service_names` matches any token in + `req_authz_tokens`. If any match is found, this list is empty. + - used_services: A set of service names from `auth_service_names` + that were found to satisfy at least one authentication parameter's + requirements or matched one of the `req_authz_tokens`. """ required_params: dict[str, list[str]] = {} used_services: set[str] = set() From 1e8d265f069b1236689b6f598c86dda7d1135c24 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 12 May 2025 11:07:16 +0530 Subject: [PATCH 4/6] chore: Improve readability --- packages/toolbox-core/tests/test_utils.py | 209 ++++++++++++++++++++-- 1 file changed, 194 insertions(+), 15 deletions(-) diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index 52cdb38f..03a1c915 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -88,8 +88,12 @@ def test_identify_required_authn_params_none_required(): auth_service_names = ["service_a", "service_b"] expected = {} expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) @@ -101,10 +105,15 @@ def test_identify_required_authn_params_all_covered(): "token_b": ["service_b", "service_c"], } auth_service_names = ["service_a", "service_b"] - expected = {} - expected_used = set(auth_service_names) - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"service_a", "service_b"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) @@ -122,9 +131,15 @@ def test_identify_required_authn_params_some_covered(): "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } - expected_used = set(auth_service_names) - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + expected_authz: list[str] = [] + expected_used = {"service_a", "service_b"} + + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) @@ -141,8 +156,12 @@ def test_identify_required_authn_params_none_covered(): "token_e": ["service_e", "service_f"], } expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) @@ -159,8 +178,12 @@ def test_identify_required_authn_params_no_available_services(): "token_b": ["service_b", "service_c"], } expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) @@ -175,8 +198,164 @@ def test_identify_required_authn_params_empty_services_for_param(): "token_x": [], } expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, + expected_used, + ) + + +def test_identify_auth_params_only_authz_empty(): + """Test with empty req_authz_tokens and no authn params.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens: list[str] = [] + auth_service_names = ["s1"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, [], set()) + + +def test_identify_auth_params_authz_all_covered(): + """Test when all req_authz_tokens are covered by auth_service_names.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names = ["s1", "s2", "s3"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, [], {"s1", "s2"}) + + +def test_identify_auth_params_authz_partially_covered_by_available(): + """Test when some req_authz_tokens are covered.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names = ["s1", "s3"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, [], {"s1"}) + + +def test_identify_auth_params_authz_none_covered(): + """Test when none of req_authz_tokens are covered by auth_service_names.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names = ["s3", "s4"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, ["s1", "s2"], set()) + + +def test_identify_auth_params_authz_none_covered_empty_available(): + """Test with req_authz_tokens but no available services.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names: list[str] = [] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, ["s1", "s2"], set()) + + +def test_identify_auth_params_authn_covered_authz_uncovered(): + """Test authn params covered, but authz tokens are not.""" + req_authn_params = {"param1": ["s_authn1"]} + req_authz_tokens = ["s_authz_needed1", "s_authz_needed2"] + auth_service_names = ["s_authn1", "s_other"] + expected_params = {} + expected_authz: list[str] = ["s_authz_needed1", "s_authz_needed2"] + expected_used = {"s_authn1"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_uncovered_authz_covered(): + """Test authn params not covered, but authz tokens are covered.""" + req_authn_params = {"param1": ["s_authn_needed"]} + req_authz_tokens = ["s_authz1"] + auth_service_names = ["s_authz1", "s_other"] + expected_params = {"param1": ["s_authn_needed"]} + expected_authz: list[str] = [] + expected_used = {"s_authz1"} + + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_and_authz_covered_no_overlap(): + """Test both authn and authz are covered by different services.""" + req_authn_params = {"param1": ["s_authn1"]} + req_authz_tokens = ["s_authz1"] + auth_service_names = ["s_authn1", "s_authz1"] + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"s_authn1", "s_authz1"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_and_authz_covered_with_overlap(): + """Test both authn and authz are covered, with some services overlapping.""" + req_authn_params = {"param1": ["s_common"], "param2": ["s_authn_specific_avail"]} + req_authz_tokens = ["s_common", "s_authz_specific_avail"] + auth_service_names = [ + "s_common", + "s_authz_specific_avail", + "s_authn_specific_avail", + ] + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"s_common", "s_authz_specific_avail", "s_authn_specific_avail"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_and_authz_covered_with_overlap_same_param(): + """Test both authn and authz are covered, with some services overlapping within same param.""" + req_authn_params = {"param1": ["s_common", "s_authn_specific_avail"]} + req_authz_tokens = ["s_common", "s_authz_specific_avail"] + auth_service_names = [ + "s_common", + "s_authz_specific_avail", + "s_authn_specific_avail", + ] + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"s_common", "s_authz_specific_avail", "s_authn_specific_avail"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_complex_scenario(): + """Test a more complex scenario with partial coverage for both authn and authz.""" + req_authn_params = {"p1": ["s1", "s2"], "p2": ["s3"]} + req_authz_tokens = ["s4", "s6"] + auth_service_names = ["s1", "s4", "s5"] + expected_params = {"p2": ["s3"]} + expected_authz: list[str] = [] + expected_used = {"s1", "s4"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) From 2795b7aeeb3241ed01798855374a3c010e280f9b Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 12 May 2025 19:21:29 +0530 Subject: [PATCH 5/6] chore: Remove redundant tests --- packages/toolbox-core/tests/test_client.py | 131 --------------------- 1 file changed, 131 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 15176544..57f8d56c 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -599,137 +599,6 @@ async def test_bind_callable_param_success(self, tool_name, client): res = await tool(True) assert "argA" in res - @pytest.mark.asyncio - async def test_bind_param_fail(self, tool_name, client): - """Tests 'bind_param' with a bound parameter that doesn't exist.""" - tool = await client.load_tool(tool_name) - - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters - - with pytest.raises(ValueError) as e: - tool.bind_param("argC", lambda: 5) - assert "unable to bind parameters: no parameter named argC" in str(e.value) - - @pytest.mark.asyncio - async def test_rebind_param_fail(self, tool_name, client): - """ - Tests that 'bind_param' fails when attempting to re-bind a - parameter that has already been bound. - """ - tool = await client.load_tool(tool_name) - - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters - - tool_with_bound_param = tool.bind_param("argA", lambda: 10) - - assert len(tool_with_bound_param.__signature__.parameters) == 1 - assert "argA" not in tool_with_bound_param.__signature__.parameters - - with pytest.raises(ValueError) as e: - tool_with_bound_param.bind_param("argA", lambda: 20) - - assert "cannot re-bind parameter: parameter 'argA' is already bound" in str( - e.value - ) - - @pytest.mark.asyncio - async def test_bind_param_static_value_success(self, tool_name, client): - """ - Tests bind_param method with a static value. - """ - - bound_value = "Test value" - - tool = await client.load_tool(tool_name) - bound_tool = tool.bind_param("argB", bound_value) - - assert bound_tool is not tool - assert "argB" not in bound_tool.__signature__.parameters - assert "argA" in bound_tool.__signature__.parameters - - passed_value_a = 42 - res_payload = await bound_tool(argA=passed_value_a) - - assert res_payload == {"argA": passed_value_a, "argB": bound_value} - - @pytest.mark.asyncio - async def test_bind_param_sync_callable_value_success(self, tool_name, client): - """ - Tests bind_param method with a sync callable value. - """ - - bound_value_result = True - bound_sync_callable = Mock(return_value=bound_value_result) - - tool = await client.load_tool(tool_name) - bound_tool = tool.bind_param("argB", bound_sync_callable) - - assert bound_tool is not tool - assert "argB" not in bound_tool.__signature__.parameters - assert "argA" in bound_tool.__signature__.parameters - - passed_value_a = 42 - res_payload = await bound_tool(argA=passed_value_a) - - assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} - bound_sync_callable.assert_called_once() - - @pytest.mark.asyncio - async def test_bind_param_async_callable_value_success(self, tool_name, client): - """ - Tests bind_param method with an async callable value. - """ - - bound_value_result = True - bound_async_callable = AsyncMock(return_value=bound_value_result) - - tool = await client.load_tool(tool_name) - bound_tool = tool.bind_param("argB", bound_async_callable) - - assert bound_tool is not tool - assert "argB" not in bound_tool.__signature__.parameters - assert "argA" in bound_tool.__signature__.parameters - - passed_value_a = 42 - res_payload = await bound_tool(argA=passed_value_a) - - assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} - bound_async_callable.assert_awaited_once() - - @pytest.mark.asyncio - async def test_bind_param_success(self, tool_name, client): - """Tests 'bind_param' with a bound parameter specified.""" - tool = await client.load_tool(tool_name) - - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters - - tool = tool.bind_param("argA", 5) - - assert len(tool.__signature__.parameters) == 1 - assert "argA" not in tool.__signature__.parameters - - res = await tool(True) - assert "argA" in res - - @pytest.mark.asyncio - async def test_bind_callable_param_success(self, tool_name, client): - """Tests 'bind_param' with a bound parameter specified.""" - tool = await client.load_tool(tool_name) - - assert len(tool.__signature__.parameters) == 2 - assert "argA" in tool.__signature__.parameters - - tool = tool.bind_param("argA", lambda: 5) - - assert len(tool.__signature__.parameters) == 1 - assert "argA" not in tool.__signature__.parameters - - res = await tool(True) - assert "argA" in res - @pytest.mark.asyncio async def test_bind_param_fail(self, tool_name, client): """Tests 'bind_param' with a bound parameter that doesn't exist.""" From 8149616e2ec4601bfad5aec2739b771e9e85d605 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Tue, 6 May 2025 13:18:40 +0530 Subject: [PATCH 6/6] feat: Introduce identifying used authz token getters This PR adds the feature in `identify_required_authn_params` helper to determine which of the provided auth token getters are actively used to satisfy a tool's authorization requirements (as defined by the `authRequired` key in its manifest). This is a foundational step towards future validation in `ToolboxTool.add_auth_token_getters`, which will ensure no configured auth token getters remain unused, thereby preventing potential misconfigurations. --- .../toolbox-core/src/toolbox_core/client.py | 7 ++- .../toolbox-core/src/toolbox_core/tool.py | 5 +- .../toolbox-core/src/toolbox_core/utils.py | 51 +++++++++++++------ packages/toolbox-core/tests/test_utils.py | 36 ++++++++----- 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 799f7032..f2ff014c 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -79,8 +79,11 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params, used_auth_keys = identify_required_authn_params( - authn_params, auth_token_getters.keys() + authn_params, _, used_auth_keys = identify_required_authn_params( + # TODO: Add schema.authRequired as second arg + authn_params, + [], + auth_token_getters.keys(), ) tool = ToolboxTool( diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 634e479a..0b9c5ce5 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -299,7 +299,10 @@ def add_auth_token_getters( # create a read-only updated for params that are still required new_req_authn_params = MappingProxyType( identify_required_authn_params( - self.__required_authn_params, auth_token_getters.keys() + # TODO: Add authRequired + self.__required_authn_params, + [], + auth_token_getters.keys(), )[0] ) diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index ddbcfffd..b2954d37 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -45,18 +45,23 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - def identify_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] -) -> tuple[dict[str, list[str]], set[str]]: + req_authn_params: Mapping[str, list[str]], + req_authz_tokens: list[str], + auth_service_names: Iterable[str], +) -> tuple[dict[str, list[str]], list[str], set[str]]: """ - Identifies authentication parameters that are still required; because they - are not covered by the provided `auth_service_names`, and also returns a - set of all authentication services that were found to be matching. + Identifies authentication parameters and authorization tokens that are still + required because they are not covered by the provided `auth_service_names`. + Also returns a set of all authentication/authorization services from + `auth_service_names` that were found to be matching. - Args: - req_authn_params: A mapping of parameter names to lists of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. + Args: + req_authn_params: A mapping of parameter names to lists of required + authentication services for those parameters. + req_authz_tokens: A list of strings representing all authorization + tokens that are required to invoke the current tool. + auth_service_names: An iterable of authentication/authorization service + names for which token getters are available. Returns: A tuple containing: @@ -70,20 +75,34 @@ def identify_required_authn_params( that were found to satisfy at least one authentication parameter's requirements or matched one of the `req_authz_tokens`. """ - required_params: dict[str, list[str]] = {} + required_authn_params: dict[str, list[str]] = {} used_services: set[str] = set() + # find which of the required authn params are covered by available services. for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, # the param is still required - matched_services = [s for s in services if s in auth_service_names] + matched_authn_services = [s for s in services if s in auth_service_names] - if matched_services: - used_services.update(matched_services) + if matched_authn_services: + used_services.update(matched_authn_services) else: - required_params[param] = services + required_authn_params[param] = services + + # find which of the required authz tokens are covered by available services. + matched_authz_services = [s for s in auth_service_names if s in req_authz_tokens] + required_authz_tokens: list[str] = [] + + # If a match is found, authorization is met (no remaining required tokens). + # Otherwise, all `req_authz_tokens` are still required. (Handles empty + # `req_authz_tokens` correctly, resulting in no required tokens). + if matched_authz_services: + used_services.update(matched_authz_services) + else: + required_authz_tokens = req_authz_tokens - return required_params, used_services + return required_authn_params, required_authz_tokens, used_services def params_to_pydantic_model( diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index 03a1c915..8c41e2e8 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -83,10 +83,12 @@ def test_create_func_docstring_empty_description(): def test_identify_required_authn_params_none_required(): - """Test when no authentication parameters are required initially.""" - req_authn_params = {} + """Test when no authentication parameters or authorization tokens are required initially.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = {} + expected_params = {} + expected_authz: list[str] = [] expected_used = set() result = identify_required_authn_params( req_authn_params, req_authz_tokens, auth_service_names @@ -99,11 +101,12 @@ def test_identify_required_authn_params_none_required(): def test_identify_required_authn_params_all_covered(): - """Test when all required parameters are covered by available services.""" + """Test when all required authn parameters are covered, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] expected_params = {} expected_authz: list[str] = [] @@ -119,15 +122,16 @@ def test_identify_required_authn_params_all_covered(): def test_identify_required_authn_params_some_covered(): - """Test when some parameters are covered, and some are not.""" + """Test when some authn parameters are covered, and some are not, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = { + expected_params = { "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } @@ -145,16 +149,18 @@ def test_identify_required_authn_params_some_covered(): def test_identify_required_authn_params_none_covered(): - """Test when none of the required parameters are covered.""" + """Test when none of the required authn parameters are covered, no authz tokens.""" req_authn_params = { "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = { + expected_params = { "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } + expected_authz: list[str] = [] expected_used = set() result = identify_required_authn_params( req_authn_params, req_authz_tokens, auth_service_names @@ -167,16 +173,18 @@ def test_identify_required_authn_params_none_covered(): def test_identify_required_authn_params_no_available_services(): - """Test when no authentication services are available.""" + """Test when no authn services are available, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } - auth_service_names = [] - expected = { + req_authz_tokens: list[str] = [] + auth_service_names: list[str] = [] + expected_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } + expected_authz: list[str] = [] expected_used = set() result = identify_required_authn_params( req_authn_params, req_authz_tokens, auth_service_names @@ -189,14 +197,16 @@ def test_identify_required_authn_params_no_available_services(): def test_identify_required_authn_params_empty_services_for_param(): - """Test edge case where a param requires an empty list of services.""" + """Test edge case: authn param requires an empty list of services, no authz tokens.""" req_authn_params = { "token_x": [], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a"] - expected = { + expected_params = { "token_x": [], } + expected_authz: list[str] = [] expected_used = set() result = identify_required_authn_params( req_authn_params, req_authz_tokens, auth_service_names