Changes from 10 commits
26 changes: 9 additions & 17 deletions python/kubeflow/trainer/types/types.py
@@ -230,19 +230,16 @@ class Initializer:
model: Optional[HuggingFaceModelInitializer] = None


# The dict where key is the container image and value its representation.
# Each Trainer representation defines trainer parameters (e.g. type, framework, entrypoint).
# TODO (andreyvelich): We should allow user to overrides the default image names.
ALL_TRAINERS: Dict[str, Trainer] = {
# Custom Trainers.
"pytorch/pytorch": Trainer(
# Centralized trainer configurations to eliminate duplication
TRAINER_CONFIGS: Dict[Framework, Trainer] = {
Framework.TORCH: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.TORCH,
Member:

Do we really need to keep the framework argument, given that the TRAINER_CONFIGS dict already has Framework as its key?

Suggested change
framework=Framework.TORCH,

Author @jskswamy (Jul 4, 2025):

Regarding the framework field in the Trainer class, I'd like to share my thoughts on why this field exists and why it serves a legitimate purpose:

The framework Field Has Critical Importance

After investigating the codebase, I discovered that the Trainer class and framework field were pre-existing before this PR. The field was intentionally designed to serve specific purposes:

Critical Importance for API Design

The framework field is essential for maintaining a clean, self-contained API:

  1. Object Identity: A Trainer object must "know" what framework it represents without external context
  2. API Completeness: When users receive a Trainer object, they can immediately determine its framework without reverse-engineering from other fields
  3. Serialization: The field is crucial for JSON serialization/deserialization of trainer objects
  4. Debugging & Logging: Essential for meaningful error messages and debugging information

Self-Contained Data Structure

The framework field makes Trainer objects self-contained and self-documenting:

# Example: A Trainer object "knows" what framework it represents
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]

# Self-documenting: The object tells us what it is
print(f"Using {trainer.framework} trainer with {trainer.trainer_type}")
# Output: "Using Framework.DEEPSPEED trainer with TrainerType.CUSTOM_TRAINER"

# Without the field, we'd need external context to know what framework this is
# We'd have to track which dictionary key was used to create this trainer

Breaking Changes Would Be Required

Removing the field would require:

  • Modifying any code that relies on the field for framework identification
  • Potentially breaking API consumers who expect this field
  • Adding complex lookup logic to determine framework from other properties

Architectural Integrity

The field maintains the principle of encapsulation: a Trainer object should contain all information about itself, including what framework it represents.

Why Dictionary Instead of Array?

The choice of using TRAINER_CONFIGS: Dict[Framework, Trainer] instead of an array of trainers was a performance and design optimization:

Performance Benefits

# Current efficient approach with dictionary
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]  # O(1) lookup
framework = trainer.framework  # Direct access

# Alternative inefficient approach with array
def find_trainer_by_framework(framework):
    for trainer in TRAINER_ARRAY:  # O(n) search
        if trainer.framework == framework:
            return trainer

Design Benefits

  1. Fast Lookup: O(1) constant time access instead of O(n) linear search
  2. Type Safety: Dictionary keys ensure we only access valid frameworks
  3. Explicit Mapping: Clear relationship between framework and trainer configuration
  4. Extensibility: Easy to add new frameworks without changing lookup logic

My Take

The framework field serves critical architectural purposes for API design and object encapsulation. The dictionary structure provides performance benefits, but the field itself is essential for maintaining clean, self-contained objects.

Removing the field would break the original design intent, make the API less clean and efficient, and potentially introduce breaking changes. The field was intentionally designed this way for good reasons, and I believe we should keep it to maintain the integrity of the API design.
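The self-containment argument above suggests an invariant worth enforcing: the dict key and the framework field must never drift apart. A minimal sketch with stand-in Framework and Trainer definitions (reduced from the real ones in types.py; the entrypoints here are placeholder strings):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List

class Framework(Enum):
    TORCH = "torch"
    DEEPSPEED = "deepspeed"

@dataclass
class Trainer:
    framework: Framework
    entrypoint: List[str]

TRAINER_CONFIGS: Dict[Framework, Trainer] = {
    Framework.TORCH: Trainer(Framework.TORCH, ["torchrun"]),
    Framework.DEEPSPEED: Trainer(Framework.DEEPSPEED, ["mpirun"]),
}

# One loop guards against the dict key and the framework field drifting apart.
for key, trainer in TRAINER_CONFIGS.items():
    assert trainer.framework is key, f"key/field mismatch for {key}"
```

This is essentially the invariant that test_centralized_trainer_configs in types_test.py checks for the real configurations.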

Member:

I agree that we should have a dict to represent all Trainers, where the key is the Framework name and the value is the Trainer object.
The question is whether we should also keep the framework argument in the Trainer object. It is mostly used just to show users what framework this Trainer is using.

I am fine to keep it for now.

WDYT @szaher @astefanutti @Electronic-Waste ?

entrypoint=[constants.TORCH_ENTRYPOINT],
),
"ghcr.io/kubeflow/trainer/mlx-runtime": Trainer(
Framework.DEEPSPEED: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.MLX,
framework=Framework.DEEPSPEED,
entrypoint=[
constants.MPI_ENTRYPOINT,
"--hostfile",
@@ -251,9 +248,9 @@ class Initializer:
"-c",
],
),
"ghcr.io/kubeflow/trainer/deepspeed-runtime": Trainer(
Framework.MLX: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.DEEPSPEED,
framework=Framework.MLX,
entrypoint=[
constants.MPI_ENTRYPOINT,
"--hostfile",
@@ -262,20 +259,15 @@ class Initializer:
"-c",
],
),
# Builtin Trainers.
"ghcr.io/kubeflow/trainer/torchtune-trainer": Trainer(
Framework.TORCHTUNE: Trainer(
trainer_type=TrainerType.BUILTIN_TRAINER,
framework=Framework.TORCHTUNE,
entrypoint=constants.DEFAULT_TORCHTUNE_COMMAND,
),
}

# The default trainer configuration when runtime detection fails
DEFAULT_TRAINER = Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.TORCH,
entrypoint=[constants.TORCH_ENTRYPOINT],
)
DEFAULT_TRAINER = TRAINER_CONFIGS[Framework.TORCH]

# The default runtime configuration for the train() API
DEFAULT_RUNTIME = Runtime(
18 changes: 18 additions & 0 deletions python/kubeflow/trainer/types/types_test.py
@@ -0,0 +1,18 @@
from kubeflow.trainer.types import types


class TestTrainerConfigurations:
"""Test cases for trainer configurations and types."""

def test_centralized_trainer_configs(self):
"""Test that centralized trainer configurations are properly defined."""
# Verify all trainer frameworks have configurations
for framework in types.Framework:
assert framework in types.TRAINER_CONFIGS
trainer = types.TRAINER_CONFIGS[framework]
assert trainer.framework == framework

def test_default_trainer_uses_centralized_config(self):
"""Test that DEFAULT_TRAINER uses centralized configuration."""
assert types.DEFAULT_TRAINER == types.TRAINER_CONFIGS[types.Framework.TORCH]
assert types.DEFAULT_TRAINER.framework == types.Framework.TORCH
75 changes: 69 additions & 6 deletions python/kubeflow/trainer/utils/utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
import queue
@@ -107,6 +108,65 @@ def get_runtime_trainer_container(
return None


def detect_trainer_from_image_patterns(image_name: str) -> Optional[types.Trainer]:
"""
Detect trainer type based on image name patterns using regex.

This method uses pattern matching on the image name to determine
the likely trainer type.

Args:
image_name: The container image name

Returns:
Trainer object if detected, None otherwise
"""
# DeepSpeed patterns
if re.search(r"deepspeed", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])

# MLX patterns
if re.search(r"mlx", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])

# TorchTune patterns (check before PyTorch to avoid conflicts)
if re.search(r"torchtune", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])

# PyTorch patterns - require explicit "pytorch" in image name for clarity
if re.search(r"pytorch", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])

return None


def detect_trainer(
trainer_container: models.IoK8sApiCoreV1Container,
) -> types.Trainer:
"""
Detect trainer type using pattern matching with fallback.

This method implements the detection logic:
1. Use image pattern matching to detect framework
2. Fall back to DEFAULT_TRAINER if no patterns match

Args:
trainer_container: The trainer container object

Returns:
Trainer object
"""
image_name = trainer_container.image.split(":")[0]

# 1. Use image pattern matching
trainer = detect_trainer_from_image_patterns(image_name)
if trainer:
return trainer

# 2. Fall back to DEFAULT_TRAINER
return copy.deepcopy(types.DEFAULT_TRAINER)
Member:

I think this could be simplified if detect_trainer_from_image_patterns just returned copy.deepcopy(types.DEFAULT_TRAINER) instead of None.

Can you just keep all of the required code to extract trainer in the get_trainer_from_image() function, which accepts image_name as input?

That will make our unit tests easier to maintain.

Author:

Made the necessary changes to simplify the detect_trainer_from_image_patterns function.

Member:

@jskswamy Sorry for the late reply, I meant can you just use this code snippet ?

def get_runtime_trainer(....):
....
image_name = trainer_container.image.split(":")[0]
trainer = get_trainer_from_image(image_name)


def get_trainer_from_image(image_name: str) -> types.Trainer:
    """
    Detect trainer type based on image name patterns using regex.
    This method uses pattern matching on the image name to determine
    the likely trainer type.
    Args:
        image_name: The container image name.
    Returns:
        Trainer: Trainer object if detected, otherwise the DEFAULT_TRAINER is returned.
    """
    # DeepSpeed patterns
    if re.search(r"deepspeed", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])

    # MLX patterns
    if re.search(r"mlx", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])

    # TorchTune patterns (check before PyTorch to avoid conflicts)
    if re.search(r"torchtune", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])

    # PyTorch patterns - require explicit "pytorch" in image name for clarity
    if re.search(r"pytorch", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])

    return copy.deepcopy(types.DEFAULT_TRAINER)

Author:

I've simplified the function as per your suggestion; kindly check.
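The first-match-wins ordering in the suggested get_trainer_from_image can be sketched in isolation; the framework names are reduced to plain strings here, and the pattern table mirrors the checks quoted above:

```python
import re

# First-match-wins table mirroring the suggested get_trainer_from_image;
# the torchtune entry is kept before the pytorch one, matching the PR's ordering.
PATTERNS = [
    (r"deepspeed", "DEEPSPEED"),
    (r"mlx", "MLX"),
    (r"torchtune", "TORCHTUNE"),
    (r"pytorch", "TORCH"),
]

def guess_framework(image_name: str) -> str:
    for pattern, framework in PATTERNS:
        if re.search(pattern, image_name, re.IGNORECASE):
            return framework
    return "TORCH"  # fallback mirrors DEFAULT_TRAINER

print(guess_framework("ghcr.io/kubeflow/trainer/torchtune-trainer"))  # TORCHTUNE
print(guess_framework("pytorch/pytorch"))                             # TORCH
print(guess_framework("unknown/image"))                               # TORCH (fallback)
```

The real function returns deep copies of the Trainer objects rather than strings; the deep copy matters because get_runtime_trainer later mutates trainer.accelerator_count, and copies keep TRAINER_CONFIGS pristine.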



def get_runtime_trainer(
replicated_jobs: List[models.JobsetV1alpha2ReplicatedJob],
ml_policy: models.TrainerV1alpha1MLPolicy,
@@ -121,20 +181,23 @@ def get_runtime_trainer(
if not (trainer_container and trainer_container.image):
raise Exception(f"Runtime doesn't have trainer container {replicated_jobs}")

# Extract image name from the container image to get appropriate Trainer.
image_name = trainer_container.image.split(":")[0]
trainer = types.ALL_TRAINERS.get(image_name, types.DEFAULT_TRAINER)
# Use the new detection logic with fallback
trainer = detect_trainer(trainer_container)

# Get the container devices.
if devices := get_container_devices(trainer_container.resources):
_, trainer.accelerator_count = devices

# Torch and MPI plugins override accelerator count.
if ml_policy.torch and ml_policy.torch.num_proc_per_node:
# NOTE: The 'is not None' checks are essential because:
# 1. For torch: prevents AttributeError when accessing None.actual_instance
# 2. For MPI: prevents setting accelerator_count to None
# 3. Semantically: only override when user explicitly provides num_proc_per_node
if ml_policy.torch and ml_policy.torch.num_proc_per_node is not None:
num_proc = ml_policy.torch.num_proc_per_node.actual_instance
if isinstance(num_proc, int):
trainer.accelerator_count = num_proc
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:
Member:

Why do we need to add is not None here?

Author:

1. Torch Policy Check (if trainer_container.accelerator_count is not None)

# Essential: Prevents AttributeError when accessing None.actual_instance
# Without this check: None.actual_instance would raise AttributeError
if trainer_container.accelerator_count is not None:
    if hasattr(trainer_container.accelerator_count, 'actual_instance'):
        trainer.accelerator_count = trainer_container.accelerator_count.actual_instance

2. MPI Policy Check (if trainer_container.mpi_policy is not None)

# Essential: Prevents setting accelerator_count to None when user explicitly sets it
# Without this check: trainer.accelerator_count would be overwritten to None
if trainer_container.mpi_policy is not None:
    trainer.accelerator_count = trainer_container.mpi_policy.num_procs

3. Semantic Correctness

These checks ensure that:

  • User-provided values are preserved and not overwritten
  • We don't attempt operations on None objects
  • The logic follows "only apply changes if the field is actually set"

Code Comments Added:

I've added explanatory comments to each check to make their necessity clear for future maintainers.

Member:

@jskswamy I just meant that those two lines are the same in Python, aren't they?

elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:

Author:

I think it's a subtle but important distinction in Python.

The is not None Check is Necessary

The current code is correct because 0 is a valid and meaningful value for num_proc_per_node:

# Current correct implementation
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node

Why Truthiness Checking Would Break CPU-Only Training

If we used truthiness checking instead:

# This would be problematic
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node

Example Scenarios:

Scenario 1: CPU-Only Training (0 accelerators)

ml_policy.mpi.num_proc_per_node = 0  # Explicitly set to CPU-only

# With truthiness check:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node:  # 0 is falsy!
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node  # ❌ Never executes

# With is not None check:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:  # 0 is not None!
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node  # ✅ Executes correctly

Scenario 2: GPU Training (4 accelerators)

ml_policy.mpi.num_proc_per_node = 4  # Explicitly set to 4 GPUs

# Both approaches work correctly:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node:  # 4 is truthy ✅
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:  # 4 is not None ✅

Scenario 3: Not Set (defaults to UNKNOWN)

ml_policy.mpi.num_proc_per_node = None  # Not explicitly set

# Both approaches work correctly:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node:  # None is falsy ✅
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:  # None is None ✅

The Key Distinction

The is not None check properly distinguishes between:

  • "Not set" (None) → don't override accelerator count
  • "Explicitly set to 0" (0) → override with 0 (CPU-only training)
  • "Explicitly set to positive number" → override with that number
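The distinction can be demonstrated with a few lines of plain Python (should_override is a hypothetical helper for illustration, not part of the SDK):

```python
def should_override(num_proc_per_node):
    # Truthiness check: 0 is falsy, so an explicit CPU-only setting is skipped.
    truthy = bool(num_proc_per_node)
    # Identity check: only None (unset) is skipped; an explicit 0 still overrides.
    explicit = num_proc_per_node is not None
    return truthy, explicit

print(should_override(None))  # (False, False) -> neither form overrides
print(should_override(0))     # (False, True)  -> only `is not None` overrides
print(should_override(4))     # (True, True)   -> both forms override
```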

Member @andreyvelich (Jul 7, 2025):

But why is num_proc_per_node=0 a valid value?
We should not allow the user to set such a value, or we should treat it as None.

Member:

@jskswamy Did you get a chance to check this comment?

Author:

Sorry for the late reply! I've addressed this; kindly check the changes now.

trainer.accelerator_count = ml_policy.mpi.num_proc_per_node

# Multiply accelerator_count by the number of nodes.
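The isinstance guard on actual_instance in the torch branch above can be isolated into a sketch; NumProcPerNode here is a hypothetical stand-in for the generated anyOf wrapper, whose actual_instance may hold either an int or a string such as "auto":

```python
class NumProcPerNode:
    """Hypothetical stand-in for the generated anyOf model wrapper."""
    def __init__(self, actual_instance):
        self.actual_instance = actual_instance

def resolve_accelerator_count(current, num_proc):
    # Mirrors the torch branch: only an explicit integer overrides the
    # detected count; a string like "auto" and an unset value (None) leave it alone.
    if num_proc is not None and isinstance(num_proc.actual_instance, int):
        return num_proc.actual_instance
    return current

print(resolve_accelerator_count(2, NumProcPerNode(4)))       # 4
print(resolve_accelerator_count(2, NumProcPerNode("auto")))  # 2
print(resolve_accelerator_count(2, None))                    # 2
```

The `is not None` guard also prevents an AttributeError: without it, an unset policy would lead to `None.actual_instance`.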
@@ -212,7 +275,7 @@ def get_trainjob_node_step(
# TODO (andreyvelich): We should also override the device_count
# based on OMPI_MCA_orte_set_default_slots value. Right now, it is hard to do
# since we inject this env only to the Launcher Pod.
step.name = f"{constants.NODE}-{job_index+1}"
step.name = f"{constants.NODE}-{job_index + 1}"

if container.env:
for env in container.env: