@@ -673,7 +673,49 @@ def run(self):
673
673
self .run_deepspeed_worker ()
674
674
675
675
676
class GenericRunner(TorchRunner, DeepSpeedRunner):
    """Runner for running a command other than ``torchrun``, ``deepspeed`` or ``accelerate``."""

    def use_deepspeed(self) -> bool:
        """Indicate if DeepSpeed is used.

        Returns:
            bool: True when the environment variable named by
            ``CONST_ENV_DEEPSPEED`` is set to a non-empty value.
        """
        return bool(os.environ.get(CONST_ENV_DEEPSPEED))

    def set_env_var(self) -> None:
        """Set default environment variables.

        ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` are populated from
        ``self.node_count``, ``self.host_ip`` and ``self.RDZV_PORT``.
        A variable that already exists in the environment is never overwritten,
        so values provided by the user take precedence.
        """
        defaults = {
            "WORLD_SIZE": self.node_count,
            "MASTER_ADDR": self.host_ip,
            "MASTER_PORT": self.RDZV_PORT,
        }
        for key, value in defaults.items():
            if key in os.environ:
                # Respect the value that is already set.
                continue
            if value is None:
                # Leave the variable unset instead of exporting the string "None".
                continue
            # os.environ accepts only str values; the runner attributes may be
            # non-string (e.g. an integer port), which would raise TypeError.
            os.environ[key] = str(value)

    def run(self) -> None:
        """Runs the user's command.

        Note that for TorchRunner or DeepSpeedRunner,
        we automatically add arguments for some settings,
        like the number of nodes and the host node address.

        This generic runner does not modify the command specified by the user.
        User needs to make sure the command can work on all nodes.
        User may use the environment variables in the command.
        """
        self.set_env_var()
        if self.use_deepspeed():
            # With DeepSpeed, the host and the workers follow different code paths.
            if self.is_host:
                self.run_deepspeed_host()
            else:
                self.run_deepspeed_worker()
        else:
            self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
714
+
715
+
676
716
class AccelerateRunner (TorchRunner , DeepSpeedRunner ):
717
+ """Runner for HuggingFace Accelerate."""
718
+
677
719
# accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
678
720
# https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
679
721
DEFAULT_ARGS = [
@@ -704,11 +746,18 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
704
746
self .main_process_ip = None
705
747
706
748
def use_deepspeed(self):
    """Tell whether the job should be launched with DeepSpeed.

    Returns True when the environment variable named by ``CONST_ENV_DEEPSPEED``
    is set to a non-empty value, or when the launch command contains
    ``use_deepspeed``; otherwise returns False.
    """
    # Accelerate supports DeepSpeed through the "--use_deepspeed" argument.
    # The environment variable is checked first; the launch command is only
    # inspected when the variable is missing or empty.
    return bool(
        os.environ.get(CONST_ENV_DEEPSPEED)
        or self.launch_cmd_contains("use_deepspeed")
    )
710
756
711
757
def accelerate_args (self ):
758
+ """Gets the default arguments for the accelerate command.
759
+ The value of the default arguments are assigned in ``__init__()``.
760
+ """
712
761
args = []
713
762
for arg in self .DEFAULT_ARGS :
714
763
arg_val = getattr (self , arg , None )
@@ -720,6 +769,7 @@ def accelerate_args(self):
720
769
return args
721
770
722
771
def run_with_torchrun (self ):
772
+ """Runs the job with torchrun."""
723
773
launch_args = self .accelerate_args ()
724
774
for arg in self .TORCHRUN_ARGS :
725
775
if not self .launch_cmd_contains (arg ):
@@ -728,6 +778,7 @@ def run_with_torchrun(self):
728
778
self .time_cmd (cmd = cmd )
729
779
730
780
def run_with_deepspeed (self ):
781
+ """Runs the job with DeepSpeed."""
731
782
if self .is_host :
732
783
launch_args = self .accelerate_args ()
733
784
if self .num_machines > 1 :
@@ -758,6 +809,8 @@ def main():
758
809
runner_class = DeepSpeedRunner
759
810
elif launch_cmd .startswith ("accelerate " ):
760
811
runner_class = AccelerateRunner
812
+ else :
813
+ runner_class = GenericRunner
761
814
762
815
runner = runner_class ()
763
816
runner : Runner
0 commit comments