Skip to content

Commit fa59003

Browse files
committed
Add GenericRunner to allow running arbitrary command with PyTorchDistributed runtime.
1 parent 7257dfb commit fa59003

File tree

1 file changed

+55
-2
lines changed

1 file changed

+55
-2
lines changed

ads/jobs/templates/driver_pytorch.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,49 @@ def run(self):
673673
self.run_deepspeed_worker()
674674

675675

676+
class GenericRunner(TorchRunner, DeepSpeedRunner):
    """Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``.

    The user's command is executed without any launcher-specific arguments
    being injected. Default rendezvous settings are exposed through
    environment variables so the command can reference them.
    """

    def use_deepspeed(self) -> bool:
        """Indicate if DeepSpeed is used.

        Returns
        -------
        bool
            True if the environment variable named by ``CONST_ENV_DEEPSPEED``
            is set to a non-empty value, otherwise False.
        """
        return bool(os.environ.get(CONST_ENV_DEEPSPEED))

    def set_env_var(self) -> None:
        """Set default environment variables.

        ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` are populated from
        ``self.node_count``, ``self.host_ip`` and ``self.RDZV_PORT``.
        A variable that already exists in the environment is left untouched,
        so values supplied by the user take precedence.
        """
        defaults = {
            "WORLD_SIZE": self.node_count,
            "MASTER_ADDR": self.host_ip,
            "MASTER_PORT": self.RDZV_PORT,
        }
        for k, v in defaults.items():
            if k not in os.environ:
                # os.environ only accepts str values and raises TypeError
                # otherwise. The node count and port are presumably numeric
                # (defined outside this class), so convert explicitly.
                os.environ[k] = str(v)

    def run(self):
        """Runs the user's command.
        Note that for TorchRunner or DeepSpeedRunner,
        we automatically add arguments for some settings,
        like the number of nodes and the host node address.

        This generic runner does not modify the command specified by the user.
        User needs to make sure the command can work on all nodes.
        User may use the environment variables in the command.
        """
        # Export the defaults first so the command can reference them.
        self.set_env_var()
        if self.use_deepspeed():
            # DeepSpeed needs different handling on the host and the workers;
            # both entry points are inherited from the parent runners.
            if self.is_host:
                self.run_deepspeed_host()
            else:
                self.run_deepspeed_worker()
        else:
            self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
714+
715+
676716
class AccelerateRunner(TorchRunner, DeepSpeedRunner):
717+
"""Runner for HuggingFace Accelerate."""
718+
677719
# accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
678720
# https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
679721
DEFAULT_ARGS = [
@@ -704,11 +746,18 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
704746
self.main_process_ip = None
705747

706748
def use_deepspeed(self):
707-
return os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains(
749+
"""Indicate if DeepSpeed is used."""
750+
# Accelerate supports using DeepSpeed by adding the "--use_deepspeed" argument.
751+
if os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains(
708752
"use_deepspeed"
709-
)
753+
):
754+
return True
755+
return False
710756

711757
def accelerate_args(self):
758+
"""Gets the default arguments for the accelerate command.
759+
The values of the default arguments are assigned in ``__init__()``.
760+
"""
712761
args = []
713762
for arg in self.DEFAULT_ARGS:
714763
arg_val = getattr(self, arg, None)
@@ -720,6 +769,7 @@ def accelerate_args(self):
720769
return args
721770

722771
def run_with_torchrun(self):
772+
"""Runs the job with torchrun."""
723773
launch_args = self.accelerate_args()
724774
for arg in self.TORCHRUN_ARGS:
725775
if not self.launch_cmd_contains(arg):
@@ -728,6 +778,7 @@ def run_with_torchrun(self):
728778
self.time_cmd(cmd=cmd)
729779

730780
def run_with_deepspeed(self):
781+
"""Runs the job with DeepSpeed."""
731782
if self.is_host:
732783
launch_args = self.accelerate_args()
733784
if self.num_machines > 1:
@@ -758,6 +809,8 @@ def main():
758809
runner_class = DeepSpeedRunner
759810
elif launch_cmd.startswith("accelerate "):
760811
runner_class = AccelerateRunner
812+
else:
813+
runner_class = GenericRunner
761814

762815
runner = runner_class()
763816
runner: Runner

0 commit comments

Comments
 (0)