-
Notifications
You must be signed in to change notification settings - Fork 41
Fix trainer detection for custom Docker images with regex pattern matching #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
2b08be7
c5cf2bc
e57b003
f3f3bcf
9ef621e
82ecdf3
76eba6d
c0a7d93
baf87ba
b3aed48
d3c0043
48aa98b
fa5778b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| from kubeflow.trainer.types import types | ||
|
|
||
|
|
||
class TestTrainerConfigurations:
    """Test cases for trainer configurations and types."""

    def test_centralized_trainer_configs(self):
        """Test that centralized trainer configurations are properly defined."""
        # Every framework must have an entry in TRAINER_CONFIGS, and the
        # stored trainer must report the framework it is keyed under.
        for fw in types.Framework:
            assert fw in types.TRAINER_CONFIGS
            assert types.TRAINER_CONFIGS[fw].framework == fw

    def test_default_trainer_uses_centralized_config(self):
        """Test that DEFAULT_TRAINER uses centralized configuration."""
        torch_config = types.TRAINER_CONFIGS[types.Framework.TORCH]
        assert types.DEFAULT_TRAINER == torch_config
        assert types.DEFAULT_TRAINER.framework == types.Framework.TORCH
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import copy | ||
| import inspect | ||
| import os | ||
| import queue | ||
|
|
@@ -107,6 +108,66 @@ def get_runtime_trainer_container( | |
| return None | ||
|
|
||
|
|
||
def detect_trainer_from_image_patterns(
    image_name: str, default: Optional[types.Trainer] = None
) -> Optional[types.Trainer]:
    """
    Detect trainer type based on image name patterns using regex.

    This method uses pattern matching on the image name to determine
    the likely trainer type.

    Args:
        image_name: The container image name
        default: Optional default trainer to return if no patterns match

    Returns:
        Trainer object if detected, default if provided, None otherwise
    """
    # Ordered (regex, framework) table. Order matters: more specific names
    # must be tried before generic ones (e.g. "torchtune" before "pytorch"
    # to avoid conflicts). PyTorch requires an explicit "pytorch" in the
    # image name for clarity.
    detection_rules = (
        (r"deepspeed", types.Framework.DEEPSPEED),
        (r"mlx", types.Framework.MLX),
        (r"torchtune", types.Framework.TORCHTUNE),
        (r"pytorch", types.Framework.TORCH),
    )

    for pattern, framework in detection_rules:
        if re.search(pattern, image_name, re.IGNORECASE):
            # Deep copy so callers can mutate the result without touching
            # the shared TRAINER_CONFIGS entries.
            return copy.deepcopy(types.TRAINER_CONFIGS[framework])

    # No pattern matched; the deep copy of the default is handled here
    # internally so callers never receive a shared object.
    return copy.deepcopy(default) if default is not None else None
|
|
||
|
|
||
def detect_trainer(
    trainer_container: models.IoK8sApiCoreV1Container,
) -> types.Trainer:
    """
    Detect trainer type using pattern matching with fallback.

    This method implements the detection logic:
    1. Strip the image tag (if any) from the container image reference.
    2. Use image pattern matching to detect the framework.
    3. Fall back to DEFAULT_TRAINER if no patterns match.

    Args:
        trainer_container: The trainer container object. Callers are
            expected to have validated that ``image`` is set.

    Returns:
        Trainer object (a deep copy, safe for the caller to mutate).
    """
    image_ref = trainer_container.image
    # Strip the tag only when the text after the last ":" contains no "/".
    # A plain split(":")[0] would truncate references that include a
    # registry port (e.g. "registry:5000/org/pytorch-image" -> "registry"),
    # losing the repository path that the patterns match against.
    name, sep, tail = image_ref.rpartition(":")
    image_name = name if sep and "/" not in tail else image_ref

    # Use image pattern matching with default fallback.
    return detect_trainer_from_image_patterns(image_name, types.DEFAULT_TRAINER)
|
|
||
|
|
||
| def get_runtime_trainer( | ||
| replicated_jobs: List[models.JobsetV1alpha2ReplicatedJob], | ||
| ml_policy: models.TrainerV1alpha1MLPolicy, | ||
|
|
@@ -121,20 +182,23 @@ def get_runtime_trainer( | |
| if not (trainer_container and trainer_container.image): | ||
| raise Exception(f"Runtime doesn't have trainer container {replicated_jobs}") | ||
|
|
||
| # Extract image name from the container image to get appropriate Trainer. | ||
| image_name = trainer_container.image.split(":")[0] | ||
| trainer = types.ALL_TRAINERS.get(image_name, types.DEFAULT_TRAINER) | ||
| # Use the new detection logic with fallback | ||
| trainer = detect_trainer(trainer_container) | ||
|
|
||
| # Get the container devices. | ||
| if devices := get_container_devices(trainer_container.resources): | ||
| _, trainer.accelerator_count = devices | ||
|
|
||
| # Torch and MPI plugins override accelerator count. | ||
| if ml_policy.torch and ml_policy.torch.num_proc_per_node: | ||
| # NOTE: The 'is not None' checks are essential because: | ||
| # 1. For torch: prevents AttributeError when accessing None.actual_instance | ||
| # 2. For MPI: prevents setting accelerator_count to None | ||
| # 3. Semantically: only override when user explicitly provides num_proc_per_node | ||
| if ml_policy.torch and ml_policy.torch.num_proc_per_node is not None: | ||
| num_proc = ml_policy.torch.num_proc_per_node.actual_instance | ||
| if isinstance(num_proc, int): | ||
| trainer.accelerator_count = num_proc | ||
| elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node: | ||
| elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None: | ||
|
||
| trainer.accelerator_count = ml_policy.mpi.num_proc_per_node | ||
|
|
||
| # Multiply accelerator_count by the number of nodes. | ||
|
|
@@ -212,7 +276,7 @@ def get_trainjob_node_step( | |
| # TODO (andreyvelich): We should also override the device_count | ||
| # based on OMPI_MCA_orte_set_default_slots value. Right now, it is hard to do | ||
| # since we inject this env only to the Launcher Pod. | ||
| step.name = f"{constants.NODE}-{job_index+1}" | ||
| step.name = f"{constants.NODE}-{job_index + 1}" | ||
|
|
||
| if container.env: | ||
| for env in container.env: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need to keep the `framework` argument, given that the
`TRAINER_CONFIGS` dict already has the `Framework` type as its key?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding the `framework` field in the `Trainer` class, I'd like to share my thoughts on why this field exists and why it serves a legitimate purpose.

The `framework` field has critical importance. After investigating the codebase, I discovered that the `Trainer` class and the `framework` field were pre-existing before this PR. The field was intentionally designed to serve specific purposes.

Critical importance for API design
The
frameworkfield is essential for maintaining a clean, self-contained API:Trainerobject must "know" what framework it represents without external contextTrainerobject, they can immediately determine its framework without reverse-engineering from other fieldsSelf-Contained Data Structure
The
frameworkfield makesTrainerobjects self-contained and self-documenting:Breaking Changes Would Be Required
Removing the field would require:
Architectural Integrity
The field maintains the principle of encapsulation —
Trainerobject should contain all information about itself, including what framework it represents.Why Dictionary Instead of Array?
The choice of using
TRAINER_CONFIGS: Dict[Framework, Trainer]instead of an array of trainers was a performance and design optimization:Performance Benefits
Design Benefits
My Take
The
frameworkfield serves critical architectural purposes for API design and object encapsulation. The dictionary structure provides performance benefits, but the field itself is essential for maintaining clean, self-contained objects.Removing the field would break the original design intent, make the API less clean and efficient, and potentially introduce breaking changes. The field was intentionally designed this way for good reasons, and I believe we should keep it to maintain the integrity of the API design.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that we should have dict to represent all Trainers where key is the Framework name and value is the Trainer object.
The question is whether we should also keep the `framework` argument in the `Trainer` object. This is mostly used just to show users which framework the Trainer is using. I am fine with keeping it for now.
WDYT @szaher @astefanutti @Electronic-Waste ?