Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def test_get_runtime_packages(trainer_client, test_case):
func=lambda: print("Hello World"),
func_args={"learning_rate": 0.001, "batch_size": 32},
packages_to_install=["torch", "numpy"],
pip_index_url=constants.DEFAULT_PIP_INDEX_URL,
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
num_nodes=2,
)
},
Expand All @@ -741,7 +741,7 @@ def test_get_runtime_packages(trainer_client, test_case):
func=lambda: print("Hello World"),
func_args={"learning_rate": 0.001, "batch_size": 32},
packages_to_install=["torch", "numpy"],
pip_index_url=constants.DEFAULT_PIP_INDEX_URL,
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
num_nodes=2,
env={
"TEST_ENV": "test_value",
Expand Down
10 changes: 9 additions & 1 deletion kubeflow/trainer/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,15 @@
)

# The default PIP index URL to download Python packages.
DEFAULT_PIP_INDEX_URL = os.getenv("DEFAULT_PIP_INDEX_URL", "https://pypi.org/simple")
DEFAULT_PYPI_URL = "https://pypi.org/simple"
# Handle environment variable for multiple URLs (comma-separated)
DEFAULT_PIP_INDEX_URLS = (
os.getenv("DEFAULT_PIP_INDEX_URLS", DEFAULT_PYPI_URL).split(",")
if os.getenv("DEFAULT_PIP_INDEX_URLS")
else [DEFAULT_PYPI_URL]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could simplify this so that `os.getenv` is not called twice.

# Keep backward compatibility
DEFAULT_PIP_INDEX_URL = DEFAULT_PYPI_URL

# The exec script to embed training function into container command.
# __ENTRYPOINT__ depends on the MLPolicy, func_code and func_file is substituted in the `train` API.
Expand Down
4 changes: 2 additions & 2 deletions kubeflow/trainer/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CustomTrainer:
func_args (`Optional[Dict]`): The arguments to pass to the function.
packages_to_install (`Optional[List[str]]`):
A list of Python packages to install before running the function.
pip_index_url (`Optional[str]`): The PyPI URL from which to install Python packages.
pip_index_urls (`Optional[list[str]]`): The PyPI URLs from which to install Python packages.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment noting that the first URL is used as the default index (`--index-url`) and the remaining URLs are used as additional indexes (`--extra-index-url`).

num_nodes (`Optional[int]`): The number of nodes to use for training.
resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes.
Expand All @@ -41,7 +41,7 @@ class CustomTrainer:
func: Callable
func_args: Optional[Dict] = None
packages_to_install: Optional[list[str]] = None
pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL
pip_index_urls: list[str] = field(default_factory=lambda: constants.DEFAULT_PIP_INDEX_URLS)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you want to use `field` here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we simply use `constants.DEFAULT_PIP_INDEX_URLS` as the default, then whenever someone modifies `constants.DEFAULT_PIP_INDEX_URLS` it will affect every instance we have created. So the `field` with a default factory ensures that we get a fresh copy of it each time. This is optional; I can remove it as well!

Copy link
Member

@andreyvelich andreyvelich Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then we probably need to do this, so that a new copy is returned every time an object is created?

pip_index_urls: list[str] = field(default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly

num_nodes: Optional[int] = None
resources_per_node: Optional[Dict] = None
env: Optional[Dict[str, str]] = None
Expand Down
31 changes: 20 additions & 11 deletions kubeflow/trainer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,28 +255,37 @@ def get_resources_per_node(

def get_script_for_python_packages(
packages_to_install: list[str],
pip_index_url: str,
is_mpi: bool,
pip_index_urls: list[str] = constants.DEFAULT_PIP_INDEX_URLS,
is_mpi: bool = False,
include_pypi: bool = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you want to introduce this flag?
I would let the user explicitly set the PyPI index in the list if they need it:

pip_index_urls=["pypi.custom.com/simple", "https://pypi.org/simple"]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just thinking about the case where the user forgets to include the default PyPI index. For instance, they choose pip_index_urls = ["repoA", "repoB"] but the package isn't there, so we would have a fallback, which is the default PyPI index. This is also optional and I can remove it :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add it in the future if we get any complaints from users.
I would just define a single constant that can be overridden by an environment variable:

DEFAULT_PIP_INDEX_URLS = os.getenv(
    "DEFAULT_PIP_INDEX_URLS", "https://pypi.org/simple"
).split(",")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with @andreyvelich

) -> str:
"""
Get init script to install Python packages from the given pip index URL.
Get init script to install Python packages from the given pip index URLs.
"""
# packages_str = " ".join([str(package) for package in packages_to_install])
packages_str = " ".join(packages_to_install)

if include_pypi and constants.DEFAULT_PYPI_URL not in pip_index_urls:
pip_index_urls.append(constants.DEFAULT_PYPI_URL)

options = [f"--index-url {pip_index_urls[0]}"]
options.extend(f"--extra-index-url {extra_index_url}" for extra_index_url in pip_index_urls[1:])
# For the OpenMPI, the packages must be installed for the mpiuser.
if is_mpi:
options.append("--user")

options_str = " ".join(options)

script_for_python_packages = textwrap.dedent(
"""
if ! [ -x "$(command -v pip)" ]; then
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
fi

PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
--no-warn-script-location --index-url {} {} {}
--no-warn-script-location {} {}
""".format(
pip_index_url,
options_str,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can directly set it here.

Suggested change
options_str,
" ".join(options),

packages_str,
# For the OpenMPI, the packages must be installed for the mpiuser.
"--user" if is_mpi else "",
)
)

Expand All @@ -287,7 +296,7 @@ def get_command_using_train_func(
runtime: types.Runtime,
train_func: Callable,
train_func_parameters: Optional[Dict[str, Any]],
pip_index_url: str,
pip_index_urls: list[str] = constants.DEFAULT_PIP_INDEX_URLS,
packages_to_install: Optional[list[str]] = None,
) -> list[str]:
"""
Expand Down Expand Up @@ -333,7 +342,7 @@ def get_command_using_train_func(
if packages_to_install:
install_packages = get_script_for_python_packages(
packages_to_install,
pip_index_url,
pip_index_urls,
is_mpi,
)

Expand Down Expand Up @@ -374,7 +383,7 @@ def get_trainer_crd_from_custom_trainer(
runtime,
trainer.func,
trainer.func_args,
trainer.pip_index_url,
trainer.pip_index_urls,
trainer.packages_to_install,
)

Expand Down
Loading