Skip to content

Conversation

@jskswamy
Copy link

Problem

Custom DeepSpeed Docker images were losing the mpirun command and falling back to torchrun instead. The get_runtime_trainer function only used a hardcoded ALL_TRAINERS mapping, so any custom images not in this mapping would default to the PyTorch trainer configuration.

Solution

Enhanced trainer detection with regex-based pattern matching as a fallback mechanism:

  1. First priority: Check existing ALL_TRAINERS mapping for exact matches
  2. Second priority: Use regex patterns to detect framework from image names:
    • DeepSpeed: (?i)deepspeed (case-insensitive)
    • MLX: (?i)mlx
    • TorchTune: (?i)torchtune
    • PyTorch: (?i)(pytorch|torch) (but not torchtune)
  3. Fallback: Default to PyTorch trainer if no patterns match

Key Changes

  • Added _detect_trainer_from_image_patterns() function with case-insensitive regex matching
  • Modified _detect_trainer() to use pattern matching as fallback
  • Added copy.deepcopy() to prevent shared state issues between trainer configurations
  • Comprehensive test suite with 76 test cases covering various image name formats

Testing

  • ✅ Known images from ALL_TRAINERS mapping
  • ✅ Custom images with various case formats (lowercase, uppercase, mixed case)
  • ✅ Images with registry prefixes, ports, and complex paths
  • ✅ Edge cases and fallback scenarios
  • ✅ Accelerator count logic with ML policies
  • ✅ State isolation between test runs

This ensures custom DeepSpeed images like my-org/deepspeed-custom:latest correctly use mpirun instead of falling back to torchrun.

This fixes the issue #29

@google-oss-prow
Copy link

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign andreyvelich for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

Copy link
Member

@eoinfennessy eoinfennessy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy thank you for this contribution! And thank you for also adding tests for get_container_devices.

It looks great to me -- I just added two small suggestions.

Comment on lines 75 to 93
# Edge cases - no match (should fall back to default)
("unknown-image:latest", types.TrainerFramework.TORCH),
("", types.TrainerFramework.TORCH),
("nginx:latest", types.TrainerFramework.TORCH),
("ubuntu:20.04", types.TrainerFramework.TORCH),
],
)
def test_trainer_detection_from_image_patterns(
self, image_name, expected_framework
):
"""Test trainer detection using image pattern matching with various case scenarios."""
trainer = utils._detect_trainer_from_image_patterns(image_name)
if expected_framework == types.TrainerFramework.TORCH and trainer is None:
# For unknown images, the _detect_trainer function should return default
# but _detect_trainer_from_image_patterns returns None
assert trainer is None
else:
assert trainer is not None
assert trainer.framework.value == expected_framework.value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small suggestion to replace expected_framework with None for no-match cases. I think this make the behavior of the function being tested clearer for readers.

Suggested change
# Edge cases - no match (should fall back to default)
("unknown-image:latest", types.TrainerFramework.TORCH),
("", types.TrainerFramework.TORCH),
("nginx:latest", types.TrainerFramework.TORCH),
("ubuntu:20.04", types.TrainerFramework.TORCH),
],
)
def test_trainer_detection_from_image_patterns(
self, image_name, expected_framework
):
"""Test trainer detection using image pattern matching with various case scenarios."""
trainer = utils._detect_trainer_from_image_patterns(image_name)
if expected_framework == types.TrainerFramework.TORCH and trainer is None:
# For unknown images, the _detect_trainer function should return default
# but _detect_trainer_from_image_patterns returns None
assert trainer is None
else:
assert trainer is not None
assert trainer.framework.value == expected_framework.value
# Edge cases - no match
("unknown-image:latest", None),
("", None),
("nginx:latest", None),
("ubuntu:20.04", None),
],
)
def test_trainer_detection_from_image_patterns(
self, image_name, expected_framework
):
"""Test trainer detection using image pattern matching with various case scenarios."""
trainer = utils._detect_trainer_from_image_patterns(image_name)
if expected_framework is None:
# For unknown images _detect_trainer_from_image_patterns returns None
assert trainer is None
else:
assert trainer is not None
assert trainer.framework.value == expected_framework.value

Comment on lines 24 to 30
# Trainer framework constants for easy reference
class TrainerFramework(Enum):
"""Trainer framework constants."""
TORCH = "torch"
DEEPSPEED = "deepspeed"
MLX = "mlx"
TORCHTUNE = "torchtune"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is exactly the same as the Framework enum. Can we delete this and use Framework instead?

@jskswamy jskswamy force-pushed the fix/trainer-detection-custom-images branch from ed91f02 to e16b9c6 Compare June 23, 2025 09:15
@jskswamy
Copy link
Author

@eoinfennessy made all the suggested changes.

Copy link
Member

@eoinfennessy eoinfennessy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy LGTM! Thank you!

@google-oss-prow
Copy link

@eoinfennessy: changing LGTM is restricted to collaborators

In response to this:

@jskswamy LGTM! Thank you!

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

Copy link
Member

@andreyvelich andreyvelich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this contribution @jskswamy 🎉

Comment on lines 201 to 200
if ml_policy.torch and ml_policy.torch.num_proc_per_node is not None:
num_proc = ml_policy.torch.num_proc_per_node.actual_instance
if isinstance(num_proc, int):
trainer.accelerator_count = num_proc
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to add is not None here ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1. Torch Policy Check (if trainer_container.accelerator_count is not None)

# Essential: Prevents AttributeError when accessing None.actual_instance
# Without this check: None.actual_instance would raise AttributeError
if trainer_container.accelerator_count is not None:
    if hasattr(trainer_container.accelerator_count, 'actual_instance'):
        trainer.accelerator_count = trainer_container.accelerator_count.actual_instance

2. MPI Policy Check (if trainer_container.mpi_policy is not None)

# Essential: Prevents setting accelerator_count to None when user explicitly sets it
# Without this check: trainer.accelerator_count would be overwritten to None
if trainer_container.mpi_policy is not None:
    trainer.accelerator_count = trainer_container.mpi_policy.num_procs

3. Semantic Correctness

These checks ensure that:

  • User-provided values are preserved and not overwritten
  • We don't attempt operations on None objects
  • The logic follows "only apply changes if the field is actually set"

Code Comments Added:

I've added explanatory comments to each check to make their necessity clear for future maintainers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy I just meant that those 2 lines are the same in Python, isn't ?

elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a subtle but important distinction in Python.

The is not None Check is Necessary

The current code is correct because 0 is a valid and meaningful value for num_proc_per_node:

# Current correct implementation
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node

Why Truthiness Checking Would Break CPU-Only Training

If we used truthiness checking instead:

# This would be problematic
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node

Example Scenarios:

Scenario 1: CPU-Only Training (0 accelerators)

ml_policy.mpi.num_proc_per_node = 0  # Explicitly set to CPU-only

# With truthiness check:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node:  # 0 is falsy!
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node  # ❌ Never executes

# With is not None check:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:  # 0 is not None!
    trainer.accelerator_count = ml_policy.mpi.num_proc_per_node  # ✅ Executes correctly

Scenario 2: GPU Training (4 accelerators)

ml_policy.mpi.num_proc_per_node = 4  # Explicitly set to 4 GPUs

# Both approaches work correctly:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node:  # 4 is truthy ✅
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:  # 4 is not None ✅

Scenario 3: Not Set (defaults to UNKNOWN)

ml_policy.mpi.num_proc_per_node = None  # Not explicitly set

# Both approaches work correctly:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node:  # None is falsy ✅
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:  # None is None ✅

The Key Distinction

The is not None check properly distinguishes between:

  • "Not set" (None) → don't override accelerator count
  • "Explicitly set to 0" (0) → override with 0 (CPU-only training)
  • "Explicitly set to positive number" → override with that number

Copy link
Member

@andreyvelich andreyvelich Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why num_proc_per_node=0 is a valid value ?
We should not allow user to set such value or consider this as None.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy Did you get a chance to check this comment ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for late reply! I've addressed this, kindly check the changes now

return None


def _detect_trainer_from_image_patterns(image_name: str) -> Optional[types.Trainer]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tenzen-y @Electronic-Waste @astefanutti @jskswamy @eoinfennessy @franciscojavierarceo Do we see any concerns with regex approach ? It might be a good and simple method to start with, but I can imagine use cases where it wouldn't work. For example, users might have two DeepSpeed runtimes:

  • One uses torchrun
  • Another uses mpirun.

Perhaps in the future we can support such scenarios.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, there are use cases where image name patterns alone wouldn't be sufficient.

Current Regex Approach — Pragmatic Starting Point

The regex approach implemented serves as a practical starting point that:

  • Works immediately for the majority of common use cases
  • Supports all official Kubeflow trainer images out of the box
  • Provides sensible defaults without requiring users to specify trainer types manually
  • Maintains backward compatibility with existing workflows

Future API-Based Enhancement

For advanced scenarios like your DeepSpeed example (torchrun vs mpirun variants), we can introduce explicit API controls that override the regex detection:

# Option 1: Explicit trainer specification
trainer = Trainer(
    image="custom/deepspeed-runtime",
    trainer_type=TrainerType.DEEPSPEED_MPI,  # Override regex detection
    # ... other configs
)

# Option 2: Runtime configuration
trainer = Trainer(
    image="custom/deepspeed-runtime", 
    runtime_config=DeepSpeedConfig(launcher="mpirun"),  # vs "torchrun"
    # ... other configs
)

Approach

The regex approach handles ~90% of use cases elegantly, while keeping the door open for API-based precision when needed.

Question: Which approach would you prefer to proceed with?

Option A: Keep the current regex-based detection and enhance it incrementally with API overrides when needed

Option B: Move to a more explicit API-first approach where users specify trainer types directly

Option C: Hybrid approach where regex provides defaults, but API allows explicit overrides from day one

I'm happy to implement the changes that would be most valuable for users. What are your thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the regex approach initially, we can have discussion how to cover more complex use-cases.
We just need to ensure we document that in our docs: https://www.kubeflow.org/docs/components/trainer/user-guides/

@jskswamy jskswamy force-pushed the fix/trainer-detection-custom-images branch from ec4a1c2 to e7f4425 Compare June 30, 2025 12:19
jskswamy added 10 commits June 30, 2025 18:17
This commit introduces optional dependencies specifically for testing
purposes. The `pytest` and `pytest-mock` packages are added to the
`pyproject.toml` file under the `optional-dependencies` section,
allowing developers to easily install testing tools when needed.

Additionally, a new `pytest.ini` configuration section is created to
standardize test settings, including options for verbosity and test
discovery patterns.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
This commit introduces a new enumeration `TrainerFramework` to centralize
the definitions of various trainer frameworks used in the Kubeflow SDK.
The trainer configurations have been refactored into a dictionary
`TRAINER_CONFIGS`, which maps each framework to its respective
configuration, reducing duplication and improving maintainability.

Additionally, the trainer detection logic has been enhanced to utilize
image name patterns for identifying the appropriate trainer framework
based on the container image name. This improves the robustness of
trainer type detection and ensures backward compatibility with the
existing `ALL_TRAINERS` mapping.

- Added `TrainerFramework` enum for trainer framework constants.
- Refactored trainer configurations into `TRAINER_CONFIGS`.
- Enhanced trainer detection logic to support image name patterns.
- Added unit tests for the new detection logic and configurations.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Updated the TrainerFramework Enum to a more generic
Framework Enum to improve code maintainability and clarity.
This change simplifies the trainer configurations and
associated functions by using the new Framework Enum,
ensuring consistent references throughout the codebase.

- Replaced TrainerFramework with Framework in types.py
- Updated references in utils.py to reflect the new Enum
- Adjusted test cases in test_utils.py to accommodate changes

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Refactor the test cases in `test_utils.py` to adjust the expected
output for edge cases where no matching framework is found. This
change ensures that the tests handle cases where the image does not
correspond to any known framework by returning `None` instead of a
default framework.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Move test files from tests/ directory to be co-located with source files
and split types-related tests into a separate file:

- tests/test_utils.py → kubeflow/trainer/utils/utils_test.py
- Extract types tests → kubeflow/trainer/types/types_test.py
- Update pyproject.toml testpaths: ["tests"] → ["kubeflow"]
- Remove tests/ directory

This improves code organization by keeping tests next to the code
they validate, making it easier to maintain test coverage when
modifying source files.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Remove underscore prefixes from detect_trainer_from_image_patterns()
and detect_trainer() to follow established codebase conventions.
Analysis shows no other utility functions in the codebase use
underscore prefixes.

Functions renamed:
- _detect_trainer_from_image_patterns → detect_trainer_from_image_patterns
- _detect_trainer → detect_trainer

Update all function calls and tests accordingly.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Remove generic 'torch' pattern matching and require explicit 'pytorch'
in image names for better framework distinction. This prevents
ambiguity between PyTorch and other torch-related libraries.

- Remove regex pattern: r'^torch(?!tune)'
- Keep only: r'pytorch' for PyTorch detection
- Update test case: 'torch-custom:latest' → 'pytorch-torch-custom:latest'
- Add test case: 'torch-custom:latest' now returns None

This ensures clearer separation between PyTorch and TorchTune images.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Add detailed comments explaining why 'is not None' checks are
necessary in ML policy processing:

1. For torch: prevents AttributeError when accessing None.actual_instance
2. For MPI: prevents setting accelerator_count to None
3. Semantically: only override when user explicitly provides values

These checks prevent runtime errors and ensure correct behavior
when ML policies have undefined num_proc_per_node values.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Eliminate ALL_TRAINERS and rely solely on regex pattern matching
for trainer detection. This removes duplication between static mapping
and TRAINER_CONFIGS while maintaining full functionality.

- Remove ALL_TRAINERS from types.py
- Simplify detect_trainer(): regex patterns → DEFAULT_TRAINER fallback
- Update tests to verify official images work with regex patterns

All official Kubeflow images correctly detected by regex, ensuring
no breaking changes while reducing architectural complexity.
The regex patterns now serve as the single source of truth.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
- Remove uv.lock file
- Remove test dependencies from pyproject.toml
- Remove pytest configuration from pyproject.toml
- Keep only core trainer detection improvements and tests

This ensures the PR focuses solely on trainer detection enhancements.

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
@jskswamy jskswamy force-pushed the fix/trainer-detection-custom-images branch from e7f4425 to b3aed48 Compare June 30, 2025 12:47
TRAINER_CONFIGS: Dict[Framework, Trainer] = {
Framework.TORCH: Trainer(
trainer_type=TrainerType.CUSTOM_TRAINER,
framework=Framework.TORCH,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to keep framework argument given that TRAINER_CONFIGS Dict has the Framework type in the Dict key.

Suggested change
framework=Framework.TORCH,

Copy link
Author

@jskswamy jskswamy Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding the framework field in the Trainer class, I'd like to share my thoughts on why this field exists and why it serves a legitimate purpose:

The framework Field Has Critical Importance

After investigating the codebase, I discovered that the Trainer class and framework field were pre-existing before this PR. The field was intentionally designed to serve specific purposes:

Critical Importance for API Design

The framework field is essential for maintaining a clean, self-contained API:

  1. Object Identity: A Trainer object must "know" what framework it represents without external context
  2. API Completeness: When users receive a Trainer object, they can immediately determine its framework without reverse-engineering from other fields
  3. Serialization: The field is crucial for JSON serialization/deserialization of trainer objects
  4. Debugging & Logging: Essential for meaningful error messages and debugging information

Self-Contained Data Structure

The framework field makes Trainer objects self-contained and self-documenting:

# Example: A Trainer object "knows" what framework it represents
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]

# Self-documenting: The object tells us what it is
print(f"Using {trainer.framework} trainer with {trainer.trainer_type}")
# Output: "Using Framework.DEEPSPEED trainer with TrainerType.CUSTOM_TRAINER"

# Without the field, we'd need external context to know what framework this is
# We'd have to track which dictionary key was used to create this trainer

Breaking Changes Would Be Required

Removing the field would require:

  • Modifying any code that relies on the field for framework identification
  • Potentially breaking API consumers who expect this field
  • Adding complex lookup logic to determine framework from other properties

Architectural Integrity

The field maintains the principle of encapsulationTrainer object should contain all information about itself, including what framework it represents.

Why Dictionary Instead of Array?

The choice of using TRAINER_CONFIGS: Dict[Framework, Trainer] instead of an array of trainers was a performance and design optimization:

Performance Benefits

# Current efficient approach with dictionary
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]  # O(1) lookup
framework = trainer.framework  # Direct access

# Alternative inefficient approach with array
def find_trainer_by_framework(framework):
    for trainer in TRAINER_ARRAY:  # O(n) search
        if trainer.framework == framework:
            return trainer

Design Benefits

  1. Fast Lookup: O(1) constant time access instead of O(n) linear search
  2. Type Safety: Dictionary keys ensure we only access valid frameworks
  3. Explicit Mapping: Clear relationship between framework and trainer configuration
  4. Extensibility: Easy to add new frameworks without changing lookup logic

My Take

The framework field serves critical architectural purposes for API design and object encapsulation. The dictionary structure provides performance benefits, but the field itself is essential for maintaining clean, self-contained objects.

Removing the field would break the original design intent, make the API less clean and efficient, and potentially introduce breaking changes. The field was intentionally designed this way for good reasons, and I believe we should keep it to maintain the integrity of the API design.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we should have dict to represent all Trainers where key is the Framework name and value is the Trainer object.
The question is should we also keep framework argument in the Trainer object. This is mostly used to just show users what framework this Trainer is using.

I am fine to keep it for now.

WDYT @szaher @astefanutti @Electronic-Waste ?

Comment on lines 162 to 167
trainer = detect_trainer_from_image_patterns(image_name)
if trainer:
return trainer

# 2. Fall back to DEFAULT_TRAINER
return copy.deepcopy(types.DEFAULT_TRAINER)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, this could be simplified if detect_trainer_from_image_patterns just return copy.deepcopy(types.DEFAULT_TRAINER) instead of None.

Can you just keep all of the required code to extract trainer in the get_trainer_from_image() function, which accepts image_name as input?

That will make our unit tests easier to maintain.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made necessary changes to simplify the detect_trainer_from_image_patterns function

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy Sorry for the late reply, I meant can you just use this code snippet ?

def get_runtime_trainer(....):
....
image_name = trainer_container.image.split(":")[0]
trainer = get_trainer_from_image(image_name)


def get_trainer_from_image(image_name: str) -> types.Trainer:
    """
    Detect trainer type based on image name patterns using regex.
    This method uses pattern matching on the image name to determine
    the likely trainer type.
    Args:
        image_name: The container image name.
    Returns:
        Trainer: Trainer object if detected, otherwise the DEFAULT_TRAINER is returned.
    """
    # DeepSpeed patterns
    if re.search(r"deepspeed", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])

    # MLX patterns
    if re.search(r"mlx", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])

    # TorchTune patterns (check before PyTorch to avoid conflicts)
    if re.search(r"torchtune", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])

    # PyTorch patterns - require explicit "pytorch" in image name for clarity
    if re.search(r"pytorch", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])

    return copy.deepcopy(types.DEFAULT_TRAINER)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've simplified the function as per your suggestion, kindly check

…sulation

- Add optional default parameter to detect_trainer_from_image_patterns()
- Handle copy.deepcopy() internally for better encapsulation
- Remove boilerplate code from detect_trainer() function
- Add comprehensive unit tests with proper separation of concerns
- Maintain backward compatibility with existing behavior

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
@jskswamy jskswamy force-pushed the fix/trainer-detection-custom-images branch from 7f6fa23 to d3c0043 Compare July 4, 2025 06:55
Copy link
Member

@andreyvelich andreyvelich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late review, I think we almost ready to merge this.
@szaher @kramaranya @briangallagher @eoinfennessy Can you take a look as well please ?

Comment on lines 162 to 167
trainer = detect_trainer_from_image_patterns(image_name)
if trainer:
return trainer

# 2. Fall back to DEFAULT_TRAINER
return copy.deepcopy(types.DEFAULT_TRAINER)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy Sorry for the late reply, I meant can you just use this code snippet ?

def get_runtime_trainer(....):
....
image_name = trainer_container.image.split(":")[0]
trainer = get_trainer_from_image(image_name)


def get_trainer_from_image(image_name: str) -> types.Trainer:
    """
    Detect trainer type based on image name patterns using regex.
    This method uses pattern matching on the image name to determine
    the likely trainer type.
    Args:
        image_name: The container image name.
    Returns:
        Trainer: Trainer object if detected, otherwise the DEFAULT_TRAINER is returned.
    """
    # DeepSpeed patterns
    if re.search(r"deepspeed", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])

    # MLX patterns
    if re.search(r"mlx", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])

    # TorchTune patterns (check before PyTorch to avoid conflicts)
    if re.search(r"torchtune", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])

    # PyTorch patterns - require explicit "pytorch" in image name for clarity
    if re.search(r"pytorch", image_name, re.IGNORECASE):
        return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])

    return copy.deepcopy(types.DEFAULT_TRAINER)

Comment on lines 201 to 200
if ml_policy.torch and ml_policy.torch.num_proc_per_node is not None:
num_proc = ml_policy.torch.num_proc_per_node.actual_instance
if isinstance(num_proc, int):
trainer.accelerator_count = num_proc
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jskswamy Did you get a chance to check this comment ?

jskswamy added 2 commits July 23, 2025 14:40
Simplify trainer detection API by removing optional default parameter
and always returning a Trainer object. The function now directly
returns DEFAULT_TRAINER when no regex patterns match, eliminating
the need for None handling in calling code.

Changes:
- Rename function to get_trainer_from_image for clarity
- Remove optional default parameter from function signature
- Always return types.Trainer instead of Optional[types.Trainer]
- Update all test cases to expect DEFAULT_TRAINER for unknown images
- Simplify detect_trainer() function logic

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Changes:
- For torch: check actual_instance value truthiness, not just object existence
- For MPI: already correctly validates the direct value
- Zero values (0) are now ignored (treated as None)
- Negative values are trusted as explicit user input
- Update test cases to reflect new behavior

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Copy link
Contributor

@astefanutti astefanutti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may not have the full context, so please pardon my naive question, why those metadata, mostly the framework type, could not come from the training runtime itself, in the form of an annotation or a label?

For custom images, as a platform admin / user, I would understand I need to provide some hints.

@andreyvelich
Copy link
Member

o please pardon my naive question, why those metadata, mostly the framework type, could not come from the training runtime itself, in the form of an annotation or a label?

Previously we talked with @tenzen-y and @Electronic-Waste about introducing labels to the runtime that define framework type: https://cloud-native.slack.com/archives/C0742LDFZ4K/p1741266604716149?thread_ts=1741263570.091899&cid=C0742LDFZ4K

But we decided to not add more labels since this information can be retrieved from image and APIs.

@astefanutti
Copy link
Contributor

But we decided to not add more labels since this information can be retrieved from image and APIs.

The implicit contract based on the image name might prove fragile for users and downstream projects.

I understand the regex-based heuristics could provide a nice last-resort as it stands, but making that contract explicit seems simpler and more robust. I still fail to see why it couldn't be enforced in the training runtime API, or at least the SDK would only fallback to the regex-based heuristics if that API contract is made optional.

@andreyvelich
Copy link
Member

andreyvelich commented Jul 23, 2025

I understand the regex-based heuristics could provide a nice last-resort as it stands, but making that contract explicit seems simpler and more robust. I still fail to see why it couldn't be enforced in the training runtime API, or at least the SDK would only fallback to the regex-based heuristics if that API contract is made optional.

@astefanutti I think, we need to figure out why we even expose Runtime's trainer to the user.
Looking at the code it is only used to give user information about the Runtime:

Information such as:

  • whether this runtime can be used with CustomTrainer or with BuiltinTrainer. This is important data since users can incorrectly use CustomTrainer with builtin Runtimes cc @Electronic-Waste
  • What ML framework users should use with the runtime (I think, this might be removed if we add get_runtime_packages() API like I showed in the KubeCon previously: https://youtu.be/Fnb1a5Kaxgo?t=556
  • Entrypoint that is used while getting the TrainJob steps:
    trainjob_runtime.trainer.entrypoint
    and while creating the TrainJob using CustomTrainer from function:
    if runtime.trainer.entrypoint is None:
    raise Exception(f"Runtime trainer must have an entrypoint: {runtime.trainer}")

Do we think that we can refactor some of this and remove the TRAINER_CONFIGS list ?
Thoughts @astefanutti @szaher @kramaranya @eoinfennessy ?

@astefanutti
Copy link
Contributor

  • whether this runtime can be used with CustomTrainer or with BuiltinTrainer. This is important data since users can incorrectly use CustomTrainer with builtin Runtimes cc @Electronic-Waste

Right, that equally applies to the BuiltinTrainer.config. For now there is only one TorchTuneConfig, but nothing guarantees it's compatible with the training runtime when there'll be more.

Do we think that we can refactor some of this and remove the TRAINER_CONFIGS list ?

For built-in trainers, it seems there is a tight coupling between the trainer and the runtime, so maybe folding things into runtime as the "source-of-truth" would be better.

@andreyvelich
Copy link
Member

For built-in trainers, it seems there is a tight coupling between the trainer and the runtime, so maybe folding things into runtime as the "source-of-truth" would be better.

So do you mean that Runtime should tell users whether it is meant for CustomTrainer or BuiltinTrainer ?

@astefanutti
Copy link
Contributor

For built-in trainers, it seems there is a tight coupling between the trainer and the runtime, so maybe folding things into runtime as the "source-of-truth" would be better.

So do you mean that Runtime should tell users whether it is meant for CustomTrainer or BuiltinTrainer ?

Yes, one way or another. How a runtime is supposed to be used in the SDK is logically defined by the runtime, that includes the type of trainer (built-in, custom) and the framework (PyTorch, JAX, TorchTune, ...).

@andreyvelich
Copy link
Member

andreyvelich commented Jul 24, 2025

@astefanutti Do you think that framework information is still useful for SDK users if they can always run get_runtime_packages() API ?

@astefanutti
Copy link
Contributor

@astefanutti Do you think that framework information is still useful for SDK users if they can always run get_runtime_packages() API ?

No, though it'd be needed for checking the typed configuration passed by users for built-in trainers is compatible with the training runtime?

@andreyvelich
Copy link
Member

andreyvelich commented Jul 24, 2025

No, though it'd be needed for checking the typed configuration passed by users for built-in trainers is compatible with the training runtime?

This is correct, additionally we can't run the get_runtime_packages() API for BuiltinTrainer runtimes since by default it contains script for fine-tuning: https://github.com/kubeflow/trainer/blob/master/manifests/base/runtimes/torchtune/llama3_2/llama3_2_3B.yaml#L71.
We've done, so users can simple run this to fine-tune LLM:

client.train(
  runtime=Runtime(name="torchtune-llama3.2-3b")
)

Also, I don't think that users needs to know about installed packages in such runtimes, since they can only modify the config (e.g. fine-tuning parameters), but not the runtime packages.

Maybe for BuiltinTrainer runtime we should have two labels:

  • trainer.kubeflow.org/trainer-type: builtin
  • trainer.kubeflow.org/builtin-config: torchtune

If we don't want to introduce 2nd label, we can just tell users to rely on runtime name.

Thoughts @tenzen-y @astefanutti @Electronic-Waste @rudeigerc @szaher @kramaranya ?

@andreyvelich
Copy link
Member

I think, we should refactor our Runtime class: https://github.com/kubeflow/sdk/blob/main/python/kubeflow/trainer/types/types.py#L176-L179

@astefanutti
Copy link
Contributor

Maybe for BuiltinTrainer runtime we should have two labels:

  • trainer.kubeflow.org/trainer-type: builtin
  • trainer.kubeflow.org/builtin-config: torchtune

Yes, labels seem the most straightforward approach. There is already the trainer.kubeflow.org/accelerator label.
I wonder whether that'd make sense to go as far as to enforce those labels during training runtime admission?

@eoinfennessy
Copy link
Member

Agreed that it would be better to add APIs to TrainingRuntime and ClusterTrainingRuntime to specify the framework instead of relying on image names and regex checks. Adding trainer type would also be useful.

But why use labels instead of adding framework and other fields to the runtime spec? This would allow us to use schema-based validation to ensure a valid framework is provided.

One idea for this that would use cross-field validation to ensure one and only one of customTrainerConfig or builtinTrainerConfig is provided. Fields framework and type could use enum-based validation:

spec:
  customTrainerConfig:
    framework: "torch"
  # OR
  builtinTrainerConfig:
    type: "torchtune"

Probably best to consider the exact APIs alongside work on kubeflow/trainer#2752.

@astefanutti
Copy link
Contributor

But why use labels instead of adding framework and other fields to the runtime spec? This would allow us to use schema-based validation to ensure a valid framework is provided.

@eoinfennessy I agree with you it's a possible alternative. Labels are flexible and enable listing runtimes by label selectors, but we could conceptually consider these metadata as part of the spec.

@eoinfennessy
Copy link
Member

Labels are flexible and enable listing runtimes by label selectors

Ah, I hadn't considered that. Yes, that could help improve the UX of the list_runtimes method by filtering results. e.g:

client.list_runtimes(trainerType="custom", framework="torch")

@andreyvelich
Copy link
Member

andreyvelich commented Jul 25, 2025

There is already the trainer.kubeflow.org/accelerator label.

I am not sure if we should continue to maintain this label. IIRC, @tenzen-y has concerns introducing this label in the runtimes.

Labels are flexible and enable listing runtimes by label selectors,

We can also use field selector, if we introduce a new API in the runtime.

What are the pros and cons to add this property under labels or APIs ?

@astefanutti
Copy link
Contributor

astefanutti commented Jul 25, 2025

There is already the trainer.kubeflow.org/accelerator label.

I am not sure if we should continue to maintain this label. IIRC, @tenzen-y has concerns introducing this label in the runtimes.>

You're right, it may not be a good example "semantically".

Labels are flexible and enable listing runtimes by label selectors,

We can also use field selector, if we introduce a new API in the runtime.

I'm not sure custom fields from CRDs are indexed. It might be only few fields from core APIs.

What are the pros and cons to add this property under labels or APIs ?

I would say labels are more "free-form" and not as strictly part of the API contract compared to fields.

@andreyvelich
Copy link
Member

andreyvelich commented Jul 27, 2025

I'm not sure custom fields from CRDs are indexed. It might be only few fields from core APIs.

You are right @astefanutti, here is the list of supported fields: https://kubernetes.io/docs/concepts/overview/working-with-objects/field-selectors/#list-of-supported-fields

You're right, it may not be a good example "semantically".

Let me remove this label for now, unless we design better way to explain users the accelerator types in the Runtime.

@Electronic-Waste @astefanutti @tenzen-y Any concerns to introduce these three labels to our runtimes for now ?

trainer.kubeflow.org/trainer-type: custom
or
trainer.kubeflow.org/trainer-type: builtin
trainer.kubeflow.org/builtin-config: torchtune

Alternatively, we can introduce framework label to the runtimes, but not sure if that is really needed, since users only need to know builtin config type to use while creating TrainJob using BuiltinTrainer.

trainer.kubeflow.org/trainer-type: custom
trainer.kubeflow.org/framework: torch
---
trainer.kubeflow.org/trainer-type: custom
trainer.kubeflow.org/framework: deepspeed
---
trainer.kubeflow.org/trainer-type: builtin
trainer.kubeflow.org/framework: torchtune

@astefanutti
Copy link
Contributor

@Electronic-Waste @astefanutti @tenzen-y Any concerns to introduce these three labels to our runtimes for now ?

trainer.kubeflow.org/trainer-type: custom
or
trainer.kubeflow.org/trainer-type: builtin
trainer.kubeflow.org/builtin-config: torchtune

I think that's a good start. I only wonder whether those should be within the sdk.kubeflow.org prefix since it's metadata meant for the SDK.

Alternatively, we can introduce framework label to the runtimes, but not sure if that is really needed, since users only need to know builtin config type to use while creating TrainJob using BuiltinTrainer.

trainer.kubeflow.org/trainer-type: custom
trainer.kubeflow.org/framework: torch
---
trainer.kubeflow.org/trainer-type: custom
trainer.kubeflow.org/framework: deepspeed
---
trainer.kubeflow.org/trainer-type: builtin
trainer.kubeflow.org/framework: torchtune

Actually framework might be a different kind of metadata than those meant to be hints for the train API in the SDK.
That one would might be a simpler way to display the runtime framework rather than relying on the runtime name.

@andreyvelich
Copy link
Member

I think that's a good start. I only wonder whether those should be within the sdk.kubeflow.org prefix since it's metadata meant for the SDK.

I think, if we keep the sdk.kubeflow.org, we need to ensure that we don't require to use the same label for other Kubeflow projects that we want to integrate into SDK.

Actually framework might be a different kind of metadata than those meant to be hints for the train API in the SDK.

@astefanutti If we establish the contract that builtin configs contain framework name in the DataClass name, the trainer.kubeflow.org/framework: torchtune is sufficient.

@astefanutti
Copy link
Contributor

@astefanutti If we establish the contract that builtin configs contain framework name in the DataClass name, the trainer.kubeflow.org/framework: torchtune is sufficient.

That would be good yes. Having a mapping between framework name and DataClass name in the SDK would be perfectly acceptable I think.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants