diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index 63471670448..5dfe494148c 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -214,7 +214,7 @@ def tune( ] = None, retain_trials: bool = False, packages_to_install: List[str] = None, - pip_index_url: str = "https://pypi.org/simple", + pip_index_urls : Optional[List[str]] = ["https://pypi.org/simple"], metrics_collector_config: Dict[str, Any] = {"kind": "StdOut"}, ): """ @@ -351,7 +351,7 @@ class name in this argument. packages_to_install: List of Python packages to install in addition to the base image packages. These packages are installed before executing the objective function. - pip_index_url: The PyPI url from which to install Python packages. + pip_index_urls: List of PyPI urls from which to install Python packages. metrics_collector_config: Specify the config of metrics collector, for example, `metrics_collector_config = {"kind": "Push"}`. Currently, we only support `StdOut` and `Push` metrics collector. @@ -462,7 +462,7 @@ class name in this argument. entrypoint, input_params, packages_to_install, - pip_index_url, + pip_index_urls, ) # Generate container spec for PyTorchJob or Job. diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py index 55a9bc08e6a..08406b4a3a5 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py @@ -541,6 +541,20 @@ def create_experiment( }, TEST_RESULT_SUCCESS, ), + ( + "valid flow with pip_index_urls", + { + "name": "tune_test", + "objective": lambda x: print(f"a={x}"), + "parameters": {"a": katib.search.int(min=10, max=100)}, + "packages_to_install": ["pandas", "numpy"], + "pip_index_urls": [ + "https://pypi.org/simple", + "https://private-repo.com/simple", + ], + }, + TEST_RESULT_SUCCESS, + ), ] @@ -703,6 +717,18 @@ def test_tune(katib_client, test_name, kwargs, expected_output): additional_metric_names=[], ) + elif test_name == "valid flow with pip_index_urls": + # Verify pip install command in container args. + args_content = "".join( + experiment.spec.trial_template.trial_spec.spec.template.spec.containers[ + 0 + ].args + ) + assert ( + "--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple pandas numpy" + in args_content + ) + except Exception as e: assert type(e) is expected_output print("test execution complete") diff --git a/sdk/python/v1beta1/kubeflow/katib/utils/utils.py b/sdk/python/v1beta1/kubeflow/katib/utils/utils.py index 1c0784fa292..067fc608d0c 100644 --- a/sdk/python/v1beta1/kubeflow/katib/utils/utils.py +++ b/sdk/python/v1beta1/kubeflow/katib/utils/utils.py @@ -114,8 +114,14 @@ def validate_objective_function(objective: Callable): f"Current Objective arguments: {objective_signature}" ) +def format_pip_index_urls(pip_index_urls: List[str] = ["https://pypi.org/simple"]) -> str: + index_url = f'--index-url {pip_index_urls[0]}' + for url in pip_index_urls[1:]: + index_url += f' --extra-index-url {url}' + return index_url -def get_script_for_python_packages(packages_to_install, pip_index_url): + +def get_script_for_python_packages(packages_to_install, pip_index_urls=["https://pypi.org/simple"]): packages_str = " ".join([str(package) for package in packages_to_install]) script_for_python_packages = textwrap.dedent( @@ -125,7 +131,7 @@ def get_script_for_python_packages(packages_to_install, pip_index_url): fi PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --prefer-binary --quiet \ - --no-warn-script-location --index-url {pip_index_url} {packages_str} + --no-warn-script-location {format_pip_index_urls(pip_index_urls)} {packages_str} """ ) @@ -228,7 +234,7 @@ def get_exec_script_from_objective( entrypoint: str, input_params: Dict[str, Any], packages_to_install: Optional[List[str]] = None, - pip_index_url: str = "https://pypi.org/simple", + pip_index_urls: Optional[List[str]] = ["https://pypi.org/simple"], ) -> str: """ Get executable script for container args from the given objective function and parameters. @@ -272,7 +278,7 @@ def get_exec_script_from_objective( # Install Python packages if that is required. if packages_to_install is not None: exec_script = ( - get_script_for_python_packages(packages_to_install, pip_index_url) + get_script_for_python_packages(packages_to_install, pip_index_urls) + exec_script ) @@ -350,4 +356,4 @@ def get_trial_template_with_pytorchjob( trial_parameters=trial_parameters, trial_spec=pytorchjob, ) - return trial_template + return trial_template \ No newline at end of file diff --git a/sdk/python/v1beta1/kubeflow/katib/utils/utils_test.py b/sdk/python/v1beta1/kubeflow/katib/utils/utils_test.py new file mode 100644 index 00000000000..bfee6ee8f89 --- /dev/null +++ b/sdk/python/v1beta1/kubeflow/katib/utils/utils_test.py @@ -0,0 +1,18 @@ +import pytest +from kubeflow.katib.utils import utils + +@pytest.mark.parametrize( + "pip_index_urls, expected", + [ + (["https://pypi.org/simple"], + "--index-url https://pypi.org/simple"), + (["https://pypi.org/simple", "https://private-repo.com/simple"], + "--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple"), + (["https://pypi.org/simple", "https://private-repo.com/simple", "https://another-repo.com/simple"], + "--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple --extra-index-url https://another-repo.com/simple"), + (None, + "--index-url https://pypi.org/simple"), + ] +) +def test_format_pip_index_urls(pip_index_urls, expected): + assert utils.format_pip_index_urls(pip_index_urls) == expected \ No newline at end of file