26 changes: 9 additions & 17 deletions python/kubeflow/trainer/types/types.py
@@ -230,19 +230,16 @@ class Initializer:
model: Optional[HuggingFaceModelInitializer] = None


# The dict where key is the container image and value its representation.
# Each Trainer representation defines trainer parameters (e.g. type, framework, entrypoint).
# TODO (andreyvelich): We should allow user to overrides the default image names.
ALL_TRAINERS: Dict[str, Trainer] = {
# Custom Trainers.
"pytorch/pytorch": Trainer(
# Centralized trainer configurations to eliminate duplication
TRAINER_CONFIGS: Dict[Framework, Trainer] = {
Framework.TORCH: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.TORCH,
Member commented:
Do we really need to keep the framework argument, given that the TRAINER_CONFIGS dict already uses the Framework type as its key?

Suggested change (delete this line):
framework=Framework.TORCH,

Author @jskswamy commented on Jul 4, 2025:

Regarding the framework field in the Trainer class, I'd like to share my thoughts on why this field exists and why it serves a legitimate purpose:

The framework Field Has Critical Importance

After investigating the codebase, I found that the Trainer class and its framework field existed before this PR. The field was intentionally designed to serve specific purposes:

Critical Importance for API Design

The framework field is essential for maintaining a clean, self-contained API:

  1. Object Identity: A Trainer object must "know" what framework it represents without external context
  2. API Completeness: When users receive a Trainer object, they can immediately determine its framework without reverse-engineering from other fields
  3. Serialization: The field is crucial for JSON serialization/deserialization of trainer objects
  4. Debugging & Logging: Essential for meaningful error messages and debugging information
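Point 3 (serialization) can be illustrated with a minimal sketch. The Trainer and Framework below are simplified, hypothetical stand-ins for the real kubeflow types, not the actual classes:

```python
# Hypothetical minimal stand-ins for the real Trainer/Framework types.
from dataclasses import dataclass, asdict
from enum import Enum
import json

class Framework(str, Enum):
    TORCH = "torch"
    DEEPSPEED = "deepspeed"

@dataclass
class Trainer:
    trainer_type: str
    framework: Framework

# Because the object carries its own framework, serialization round-trips
# without any external context (such as the dict key it came from).
trainer = Trainer(trainer_type="custom", framework=Framework.DEEPSPEED)
payload = json.dumps(asdict(trainer))
restored = Trainer(**json.loads(payload))
print(restored.framework == Framework.DEEPSPEED)  # True
```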

Self-Contained Data Structure

The framework field makes Trainer objects self-contained and self-documenting:

# Example: A Trainer object "knows" what framework it represents
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]

# Self-documenting: The object tells us what it is
print(f"Using {trainer.framework} trainer with {trainer.trainer_type}")
# Output: "Using Framework.DEEPSPEED trainer with TrainerType.CUSTOM_TRAINER"

# Without the field, we'd need external context to know what framework this is
# We'd have to track which dictionary key was used to create this trainer

Breaking Changes Would Be Required

Removing the field would require:

  • Modifying any code that relies on the field for framework identification
  • Potentially breaking API consumers who expect this field
  • Adding complex lookup logic to determine framework from other properties

Architectural Integrity

The field maintains the principle of encapsulation: a Trainer object should contain all information about itself, including which framework it represents.

Why Dictionary Instead of Array?

The choice of using TRAINER_CONFIGS: Dict[Framework, Trainer] instead of an array of trainers was a performance and design optimization:

Performance Benefits

# Current efficient approach with dictionary
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]  # O(1) lookup
framework = trainer.framework  # Direct access

# Alternative inefficient approach with an array
def find_trainer_by_framework(framework):
    for trainer in TRAINER_ARRAY:  # O(n) linear search
        if trainer.framework == framework:
            return trainer
    return None  # no match found
Design Benefits

  1. Fast Lookup: O(1) constant time access instead of O(n) linear search
  2. Type Safety: Dictionary keys ensure we only access valid frameworks
  3. Explicit Mapping: Clear relationship between framework and trainer configuration
  4. Extensibility: Easy to add new frameworks without changing lookup logic

My Take

The framework field serves critical architectural purposes for API design and object encapsulation. The dictionary structure provides performance benefits, but the field itself is essential for maintaining clean, self-contained objects.

Removing the field would break the original design intent, make the API less clean and efficient, and potentially introduce breaking changes. The field was intentionally designed this way for good reasons, and I believe we should keep it to maintain the integrity of the API design.

Member replied:

I agree that we should have a dict to represent all Trainers, where the key is the Framework name and the value is the Trainer object.
The question is whether we should also keep the framework argument in the Trainer object; it is mostly used just to show users which framework the Trainer is using.

I am fine to keep it for now.

WDYT @szaher @astefanutti @Electronic-Waste ?

entrypoint=[constants.TORCH_ENTRYPOINT],
),
"ghcr.io/kubeflow/trainer/mlx-runtime": Trainer(
Framework.DEEPSPEED: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.MLX,
framework=Framework.DEEPSPEED,
entrypoint=[
constants.MPI_ENTRYPOINT,
"--hostfile",
@@ -251,9 +248,9 @@ class Initializer:
"-c",
],
),
"ghcr.io/kubeflow/trainer/deepspeed-runtime": Trainer(
Framework.MLX: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.DEEPSPEED,
framework=Framework.MLX,
entrypoint=[
constants.MPI_ENTRYPOINT,
"--hostfile",
@@ -262,20 +259,15 @@ class Initializer:
"-c",
],
),
# Builtin Trainers.
"ghcr.io/kubeflow/trainer/torchtune-trainer": Trainer(
Framework.TORCHTUNE: Trainer(
trainer_type=TrainerType.BUILTIN_TRAINER,
framework=Framework.TORCHTUNE,
entrypoint=constants.DEFAULT_TORCHTUNE_COMMAND,
),
}

# The default trainer configuration when runtime detection fails
DEFAULT_TRAINER = Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.TORCH,
entrypoint=[constants.TORCH_ENTRYPOINT],
)
DEFAULT_TRAINER = TRAINER_CONFIGS[Framework.TORCH]

# The default runtime configuration for the train() API
DEFAULT_RUNTIME = Runtime(
18 changes: 18 additions & 0 deletions python/kubeflow/trainer/types/types_test.py
@@ -0,0 +1,18 @@
from kubeflow.trainer.types import types


class TestTrainerConfigurations:
    """Test cases for trainer configurations and types."""

    def test_centralized_trainer_configs(self):
        """Test that centralized trainer configurations are properly defined."""
        # Verify all trainer frameworks have configurations
        for framework in types.Framework:
            assert framework in types.TRAINER_CONFIGS
            trainer = types.TRAINER_CONFIGS[framework]
            assert trainer.framework == framework

    def test_default_trainer_uses_centralized_config(self):
        """Test that DEFAULT_TRAINER uses centralized configuration."""
        assert types.DEFAULT_TRAINER == types.TRAINER_CONFIGS[types.Framework.TORCH]
        assert types.DEFAULT_TRAINER.framework == types.Framework.TORCH
64 changes: 58 additions & 6 deletions python/kubeflow/trainer/utils/utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
import queue
@@ -107,6 +108,55 @@ def get_runtime_trainer_container(
return None


def get_trainer_from_image(image_name: str) -> types.Trainer:
    """
    Detect trainer type based on image name patterns using regex.

    This method uses pattern matching on the image name to determine
    the likely trainer type.

    Args:
        image_name: The container image name.

    Returns:
        Trainer: Trainer object if detected, otherwise the DEFAULT_TRAINER is returned.
    """
    # DeepSpeed patterns
    if re.search(r"deepspeed", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])

    # MLX patterns
    if re.search(r"mlx", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])

    # TorchTune patterns (check before PyTorch to avoid conflicts)
    if re.search(r"torchtune", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])

    # PyTorch patterns - require an explicit "pytorch" in the image name for clarity
    if re.search(r"pytorch", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])

    return copy.deepcopy(types.DEFAULT_TRAINER)
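The checks above are first-match-wins, so narrower patterns (e.g. "torchtune") must be tested before broader ones. A standalone sketch of the same idea, with hypothetical framework names and patterns:

```python
import re

# First-match-wins pattern table; order matters when an image name could
# match more than one pattern (e.g. "pytorch-torchtune-runtime").
PATTERNS = [
    ("deepspeed", r"deepspeed"),
    ("mlx", r"mlx"),
    ("torchtune", r"torchtune"),  # checked before the broader "pytorch"
    ("torch", r"pytorch"),
]

def detect_framework(image_name: str, default: str = "torch") -> str:
    for framework, pattern in PATTERNS:
        if re.search(pattern, image_name, re.IGNORECASE):
            return framework
    return default  # fallback, mirroring DEFAULT_TRAINER

print(detect_framework("ghcr.io/kubeflow/trainer/deepspeed-runtime"))  # deepspeed
print(detect_framework("pytorch-torchtune-runtime"))                   # torchtune
print(detect_framework("some-registry/unknown"))                       # torch
```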


def detect_trainer(
    trainer_container: models.IoK8sApiCoreV1Container,
) -> types.Trainer:
    """
    Detect trainer type using pattern matching with fallback.

    This method implements the detection logic:
    1. Use image pattern matching to detect the framework.
    2. Fall back to DEFAULT_TRAINER if no patterns match.

    Args:
        trainer_container: The trainer container object.

    Returns:
        Trainer object.
    """
    image_name = trainer_container.image.split(":")[0]
    return get_trainer_from_image(image_name)
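Both helpers return copy.deepcopy of the shared TRAINER_CONFIGS entries. That matters because get_runtime_trainer later mutates the returned object (e.g. accelerator_count). A sketch of the aliasing bug the deep copy prevents, using a hypothetical minimal Trainer:

```python
import copy
from dataclasses import dataclass

@dataclass
class Trainer:  # hypothetical minimal stand-in for the real type
    framework: str
    accelerator_count: int = 1

CONFIGS = {"torch": Trainer("torch")}

# Without a copy, the caller and the registry alias the same object:
aliased = CONFIGS["torch"]
aliased.accelerator_count = 8
print(CONFIGS["torch"].accelerator_count)  # 8 -- shared config was corrupted

# Reset, then take a deep copy as get_trainer_from_image does:
CONFIGS["torch"].accelerator_count = 1
independent = copy.deepcopy(CONFIGS["torch"])
independent.accelerator_count = 8
print(CONFIGS["torch"].accelerator_count)  # still 1
```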


def get_runtime_trainer(
replicated_jobs: List[models.JobsetV1alpha2ReplicatedJob],
ml_policy: models.TrainerV1alpha1MLPolicy,
@@ -121,20 +171,22 @@ def get_runtime_trainer(
if not (trainer_container and trainer_container.image):
raise Exception(f"Runtime doesn't have trainer container {replicated_jobs}")

# Extract image name from the container image to get appropriate Trainer.
image_name = trainer_container.image.split(":")[0]
trainer = types.ALL_TRAINERS.get(image_name, types.DEFAULT_TRAINER)
# Use the new detection logic with fallback
trainer = detect_trainer(trainer_container)

# Get the container devices.
if devices := get_container_devices(trainer_container.resources):
_, trainer.accelerator_count = devices

# Torch and MPI plugins override accelerator count.
if ml_policy.torch and ml_policy.torch.num_proc_per_node:
# NOTE: Using truthiness check handles None/0 automatically
if (ml_policy.torch and
ml_policy.torch.num_proc_per_node and
ml_policy.torch.num_proc_per_node.actual_instance):
num_proc = ml_policy.torch.num_proc_per_node.actual_instance
if isinstance(num_proc, int):
trainer.accelerator_count = num_proc
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif (ml_policy.mpi and ml_policy.mpi.num_proc_per_node):
trainer.accelerator_count = ml_policy.mpi.num_proc_per_node

# Multiply accelerator_count by the number of nodes.
@@ -212,7 +264,7 @@ def get_trainjob_node_step(
# TODO (andreyvelich): We should also override the device_count
# based on OMPI_MCA_orte_set_default_slots value. Right now, it is hard to do
# since we inject this env only to the Launcher Pod.
step.name = f"{constants.NODE}-{job_index+1}"
step.name = f"{constants.NODE}-{job_index + 1}"

if container.env:
for env in container.env: