diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 799f7032..f2ff014c 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -79,8 +79,11 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params, used_auth_keys = identify_required_authn_params( - authn_params, auth_token_getters.keys() + authn_params, _, used_auth_keys = identify_required_authn_params( + # TODO: Add schema.authRequired as second arg + authn_params, + [], + auth_token_getters.keys(), ) tool = ToolboxTool( diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 634e479a..0b9c5ce5 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -299,7 +299,10 @@ def add_auth_token_getters( # create a read-only updated for params that are still required new_req_authn_params = MappingProxyType( identify_required_authn_params( - self.__required_authn_params, auth_token_getters.keys() + # TODO: Add authRequired + self.__required_authn_params, + [], + auth_token_getters.keys(), )[0] ) diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index a30dc66b..b2954d37 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -45,41 +45,64 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - def identify_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] -) -> tuple[dict[str, list[str]], set[str]]: + req_authn_params: Mapping[str, list[str]], + req_authz_tokens: list[str], + auth_service_names: Iterable[str], +) -> tuple[dict[str, list[str]], list[str], set[str]]: """ - Identifies authentication parameters that are still required; because they - are not covered by the provided `auth_service_names`, and also returns a - set of all authentication services that were found to be matching. + Identifies authentication parameters and authorization tokens that are still + required because they are not covered by the provided `auth_service_names`. + Also returns a set of all authentication/authorization services from + `auth_service_names` that were found to be matching. - Args: - req_authn_params: A mapping of parameter names to lists of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. + Args: + req_authn_params: A mapping of parameter names to lists of required + authentication services for those parameters. + req_authz_tokens: A list of strings representing all authorization + tokens that are required to invoke the current tool. + auth_service_names: An iterable of authentication/authorization service + names for which token getters are available. Returns: A tuple containing: - - A new dictionary representing the subset of required - authentication parameters that are not covered by the provided - `auth_service_names`. - - A list of authentication service names from `auth_service_names` - that were found to satisfy at least one parameter's requirements. + - required_authn_params: A new dictionary representing the subset of + required authentication parameters that are not covered by the + provided `auth_service_names`. + - required_authz_tokens: A list of required authorization tokens if + no service name in `auth_service_names` matches any token in + `req_authz_tokens`. If any match is found, this list is empty. + - used_services: A set of service names from `auth_service_names` + that were found to satisfy at least one authentication parameter's + requirements or matched one of the `req_authz_tokens`. """ - required_params: dict[str, list[str]] = {} + required_authn_params: dict[str, list[str]] = {} used_services: set[str] = set() + # find which of the required authn params are covered by available services. for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, # the param is still required - matched_services = [s for s in services if s in auth_service_names] + matched_authn_services = [s for s in services if s in auth_service_names] - if matched_services: - used_services.update(matched_services) + if matched_authn_services: + used_services.update(matched_authn_services) else: - required_params[param] = services + required_authn_params[param] = services + + # find which of the required authz tokens are covered by available services. + matched_authz_services = [s for s in auth_service_names if s in req_authz_tokens] + required_authz_tokens: list[str] = [] + + # If a match is found, authorization is met (no remaining required tokens). + # Otherwise, all `req_authz_tokens` are still required. (Handles empty + # `req_authz_tokens` correctly, resulting in no required tokens). + if matched_authz_services: + used_services.update(matched_authz_services) + else: + required_authz_tokens = req_authz_tokens - return required_params, used_services + return required_authn_params, required_authz_tokens, used_services def params_to_pydantic_model( diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 54b925ee..57f8d56c 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -607,7 +607,7 @@ async def test_bind_param_fail(self, tool_name, client): assert len(tool.__signature__.parameters) == 2 assert "argA" in tool.__signature__.parameters - with pytest.raises(ValueError) as e: + with pytest.raises(Exception) as e: tool.bind_param("argC", lambda: 5) assert "unable to bind parameters: no parameter named argC" in str(e.value) diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index 52cdb38f..8c41e2e8 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -83,100 +83,289 @@ def test_create_func_docstring_empty_description(): def test_identify_required_authn_params_none_required(): - """Test when no authentication parameters are required initially.""" - req_authn_params = {} + """Test when no authentication parameters or authorization tokens are required initially.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = {} + expected_params = {} + expected_authz: list[str] = [] expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) def test_identify_required_authn_params_all_covered(): - """Test when all required parameters are covered by available services.""" + """Test when all required authn parameters are covered, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = {} - expected_used = set(auth_service_names) - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"service_a", "service_b"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) def test_identify_required_authn_params_some_covered(): - """Test when some parameters are covered, and some are not.""" + """Test when some authn parameters are covered, and some are not, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = { + expected_params = { "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } - expected_used = set(auth_service_names) - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + expected_authz: list[str] = [] + expected_used = {"service_a", "service_b"} + + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) def test_identify_required_authn_params_none_covered(): - """Test when none of the required parameters are covered.""" + """Test when none of the required authn parameters are covered, no authz tokens.""" req_authn_params = { "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a", "service_b"] - expected = { + expected_params = { "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } + expected_authz: list[str] = [] expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) def test_identify_required_authn_params_no_available_services(): - """Test when no authentication services are available.""" + """Test when no authn services are available, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } - auth_service_names = [] - expected = { + req_authz_tokens: list[str] = [] + auth_service_names: list[str] = [] + expected_params = { "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } + expected_authz: list[str] = [] expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, ) def test_identify_required_authn_params_empty_services_for_param(): - """Test edge case where a param requires an empty list of services.""" + """Test edge case: authn param requires an empty list of services, no authz tokens.""" req_authn_params = { "token_x": [], } + req_authz_tokens: list[str] = [] auth_service_names = ["service_a"] - expected = { + expected_params = { "token_x": [], } + expected_authz: list[str] = [] expected_used = set() - assert identify_required_authn_params(req_authn_params, auth_service_names) == ( - expected, + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, + expected_used, + ) + + +def test_identify_auth_params_only_authz_empty(): + """Test with empty req_authz_tokens and no authn params.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens: list[str] = [] + auth_service_names = ["s1"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, [], set()) + + +def test_identify_auth_params_authz_all_covered(): + """Test when all req_authz_tokens are covered by auth_service_names.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names = ["s1", "s2", "s3"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, [], {"s1", "s2"}) + + +def test_identify_auth_params_authz_partially_covered_by_available(): + """Test when some req_authz_tokens are covered.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names = ["s1", "s3"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, [], {"s1"}) + + +def test_identify_auth_params_authz_none_covered(): + """Test when none of req_authz_tokens are covered by auth_service_names.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names = ["s3", "s4"] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, ["s1", "s2"], set()) + + +def test_identify_auth_params_authz_none_covered_empty_available(): + """Test with req_authz_tokens but no available services.""" + req_authn_params: dict[str, list[str]] = {} + req_authz_tokens = ["s1", "s2"] + auth_service_names: list[str] = [] + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ({}, ["s1", "s2"], set()) + + +def test_identify_auth_params_authn_covered_authz_uncovered(): + """Test authn params covered, but authz tokens are not.""" + req_authn_params = {"param1": ["s_authn1"]} + req_authz_tokens = ["s_authz_needed1", "s_authz_needed2"] + auth_service_names = ["s_authn1", "s_other"] + expected_params = {} + expected_authz: list[str] = ["s_authz_needed1", "s_authz_needed2"] + expected_used = {"s_authn1"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_uncovered_authz_covered(): + """Test authn params not covered, but authz tokens are covered.""" + req_authn_params = {"param1": ["s_authn_needed"]} + req_authz_tokens = ["s_authz1"] + auth_service_names = ["s_authz1", "s_other"] + expected_params = {"param1": ["s_authn_needed"]} + expected_authz: list[str] = [] + expected_used = {"s_authz1"} + + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_and_authz_covered_no_overlap(): + """Test both authn and authz are covered by different services.""" + req_authn_params = {"param1": ["s_authn1"]} + req_authz_tokens = ["s_authz1"] + auth_service_names = ["s_authn1", "s_authz1"] + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"s_authn1", "s_authz1"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_and_authz_covered_with_overlap(): + """Test both authn and authz are covered, with some services overlapping.""" + req_authn_params = {"param1": ["s_common"], "param2": ["s_authn_specific_avail"]} + req_authz_tokens = ["s_common", "s_authz_specific_avail"] + auth_service_names = [ + "s_common", + "s_authz_specific_avail", + "s_authn_specific_avail", + ] + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"s_common", "s_authz_specific_avail", "s_authn_specific_avail"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_authn_and_authz_covered_with_overlap_same_param(): + """Test both authn and authz are covered, with some services overlapping within same param.""" + req_authn_params = {"param1": ["s_common", "s_authn_specific_avail"]} + req_authz_tokens = ["s_common", "s_authz_specific_avail"] + auth_service_names = [ + "s_common", + "s_authz_specific_avail", + "s_authn_specific_avail", + ] + expected_params = {} + expected_authz: list[str] = [] + expected_used = {"s_common", "s_authz_specific_avail", "s_authn_specific_avail"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == (expected_params, expected_authz, expected_used) + + +def test_identify_auth_params_complex_scenario(): + """Test a more complex scenario with partial coverage for both authn and authz.""" + req_authn_params = {"p1": ["s1", "s2"], "p2": ["s3"]} + req_authz_tokens = ["s4", "s6"] + auth_service_names = ["s1", "s4", "s5"] + expected_params = {"p2": ["s3"]} + expected_authz: list[str] = [] + expected_used = {"s1", "s4"} + result = identify_required_authn_params( + req_authn_params, req_authz_tokens, auth_service_names + ) + assert result == ( + expected_params, + expected_authz, expected_used, )