1
1
#!/usr/bin/env python
2
2
# -*- coding: utf-8; -*-
3
3
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5
5
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
"""This module requires oracle-ads>=2.6.8
7
7
"""
@@ -341,13 +341,15 @@ def prepare_cmd(self, launch_args: list = None, prefix=""):
341
341
launch_args = []
342
342
# Append launch cmd args specified by the user.
343
343
if self .launch_cmd :
344
- if not self .launch_cmd . startswith ( self . LAUNCHER ) :
345
- raise ValueError (
346
- f"Command not supported: ' { self . launch_cmd } '. "
347
- f"The command should start with '{ self .LAUNCHER } '."
348
- )
344
+ if self .LAUNCHER :
345
+ if not self . launch_cmd . startswith ( self . LAUNCHER ):
346
+ raise ValueError (
347
+ f"Command not supported: '{ self .launch_cmd } '. "
348
+ )
349
349
350
- launch_args .append (self .launch_cmd [len (self .LAUNCHER ) + 1 :])
350
+ launch_args .append (self .launch_cmd [len (self .LAUNCHER ) + 1 :])
351
+ else :
352
+ launch_args .append (self .launch_cmd )
351
353
else :
352
354
launch_args .append (self .get_cmd_with_entrypoint_and_args ())
353
355
@@ -673,7 +675,51 @@ def run(self):
673
675
self .run_deepspeed_worker ()
674
676
675
677
678
class GenericRunner(TorchRunner, DeepSpeedRunner):
    """Runner for commands other than ``torchrun``, ``deepspeed`` or ``accelerate``."""

    # No launcher prefix is associated with this runner.
    LAUNCHER = ""

    def use_deepspeed(self) -> bool:
        """Indicate if DeepSpeed is used.

        Returns True when the environment variable named by
        ``CONST_ENV_DEEPSPEED`` holds a non-empty value, otherwise False.
        """
        return bool(os.environ.get(CONST_ENV_DEEPSPEED))

    def set_env_var(self):
        """Set default environment variables.

        ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` are only written
        when they are not already present in ``os.environ``, so values
        supplied by the user are never overwritten.
        """
        fallback_values = (
            ("WORLD_SIZE", self.node_count),
            ("MASTER_ADDR", self.host_ip),
            ("MASTER_PORT", self.RDZV_PORT),
        )
        for name, value in fallback_values:
            if name in os.environ:
                # Keep whatever the user already exported.
                continue
            os.environ[name] = str(value)

    def run(self):
        """Runs the user's command.

        For TorchRunner or DeepSpeedRunner, arguments for some settings
        (like the number of nodes and the host node address) are added
        automatically.

        This generic runner does not modify the command specified by the user.
        The user needs to make sure the command can work on all nodes,
        and may use the environment variables in the command.
        """
        self.set_env_var()

        if not self.use_deepspeed():
            # Plain command: run it with only the LD_PRELOAD prefix applied.
            self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
            return

        # DeepSpeed: the host node and the worker nodes follow different paths.
        if self.is_host:
            self.run_deepspeed_host()
        else:
            self.run_deepspeed_worker()
718
+
719
+
676
720
class AccelerateRunner (TorchRunner , DeepSpeedRunner ):
721
+ """Runner for HuggingFace Accelerate."""
722
+
677
723
# accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
678
724
# https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
679
725
DEFAULT_ARGS = [
@@ -704,11 +750,18 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
704
750
self .main_process_ip = None
705
751
706
752
def use_deepspeed(self):
    """Indicate if DeepSpeed is used.

    Returns True when the environment variable named by
    ``CONST_ENV_DEEPSPEED`` holds a non-empty value, or when
    ``launch_cmd_contains("use_deepspeed")`` is truthy; otherwise False.
    """
    # Accelerate supports DeepSpeed through the "--use_deepspeed" argument.
    enabled_by_env = os.environ.get(CONST_ENV_DEEPSPEED)
    # The launch command is only inspected when the env variable is not set.
    return bool(enabled_by_env or self.launch_cmd_contains("use_deepspeed"))
710
760
711
761
def accelerate_args (self ):
762
+ """Gets the default arguments for the accelerate command.
763
+ The values of the default arguments are assigned in ``__init__()``.
764
+ """
712
765
args = []
713
766
for arg in self .DEFAULT_ARGS :
714
767
arg_val = getattr (self , arg , None )
@@ -720,6 +773,7 @@ def accelerate_args(self):
720
773
return args
721
774
722
775
def run_with_torchrun (self ):
776
+ """Runs the job with torchrun."""
723
777
launch_args = self .accelerate_args ()
724
778
for arg in self .TORCHRUN_ARGS :
725
779
if not self .launch_cmd_contains (arg ):
@@ -728,6 +782,7 @@ def run_with_torchrun(self):
728
782
self .time_cmd (cmd = cmd )
729
783
730
784
def run_with_deepspeed (self ):
785
+ """Runs the job with DeepSpeed."""
731
786
if self .is_host :
732
787
launch_args = self .accelerate_args ()
733
788
if self .num_machines > 1 :
@@ -758,6 +813,8 @@ def main():
758
813
runner_class = DeepSpeedRunner
759
814
elif launch_cmd .startswith ("accelerate " ):
760
815
runner_class = AccelerateRunner
816
+ else :
817
+ runner_class = GenericRunner
761
818
762
819
runner = runner_class ()
763
820
runner : Runner
0 commit comments