Skip to content

Commit 9968a51

Browse files
authored
Support Running Any Command in PyTorchDistributed Runtime (#513)
2 parents e0a028b + fc3af70 commit 9968a51

File tree

2 files changed

+74
-10
lines changed

2 files changed

+74
-10
lines changed

ads/jobs/templates/driver_pytorch.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
"""This module requires oracle-ads>=2.6.8
77
"""
@@ -341,13 +341,15 @@ def prepare_cmd(self, launch_args: list = None, prefix=""):
341341
launch_args = []
342342
# Append launch cmd args specified by the user.
343343
if self.launch_cmd:
344-
if not self.launch_cmd.startswith(self.LAUNCHER):
345-
raise ValueError(
346-
f"Command not supported: '{self.launch_cmd}'. "
347-
f"The command should start with '{self.LAUNCHER}'."
348-
)
344+
if self.LAUNCHER:
345+
if not self.launch_cmd.startswith(self.LAUNCHER):
346+
raise ValueError(
347+
f"Command not supported: '{self.launch_cmd}'. "
348+
)
349349

350-
launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
350+
launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
351+
else:
352+
launch_args.append(self.launch_cmd)
351353
else:
352354
launch_args.append(self.get_cmd_with_entrypoint_and_args())
353355

@@ -673,7 +675,51 @@ def run(self):
673675
self.run_deepspeed_worker()
674676

675677

678+
class GenericRunner(TorchRunner, DeepSpeedRunner):
679+
"""Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``."""
680+
681+
LAUNCHER = ""
682+
683+
def use_deepspeed(self) -> bool:
684+
"""Indicate if DeepSpeed is used."""
685+
if os.environ.get(CONST_ENV_DEEPSPEED):
686+
return True
687+
return False
688+
689+
def set_env_var(self):
690+
"""Set default environment variables."""
691+
defaults = {
692+
"WORLD_SIZE": self.node_count,
693+
"MASTER_ADDR": self.host_ip,
694+
"MASTER_PORT": self.RDZV_PORT,
695+
}
696+
for k, v in defaults.items():
697+
if k not in os.environ:
698+
os.environ[k] = str(v)
699+
700+
def run(self):
701+
"""Runs the user's command.
702+
Note that for TorchRunner or DeepSpeedRunner,
703+
we automatically add arguments for some settings,
704+
like the number of nodes and the host node address.
705+
706+
This generic runner does not modify the command specified by the user.
707+
User needs to make sure the command can work on all nodes.
708+
User may use the environment variables in the command.
709+
"""
710+
self.set_env_var()
711+
if self.use_deepspeed():
712+
if self.is_host:
713+
self.run_deepspeed_host()
714+
else:
715+
self.run_deepspeed_worker()
716+
else:
717+
self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
718+
719+
676720
class AccelerateRunner(TorchRunner, DeepSpeedRunner):
721+
"""Runner for HuggingFace Accelerate."""
722+
677723
# accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
678724
# https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
679725
DEFAULT_ARGS = [
@@ -704,11 +750,18 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
704750
self.main_process_ip = None
705751

706752
def use_deepspeed(self):
707-
return os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains(
753+
"""Indicate if DeepSpeed is used."""
754+
# Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
755+
if os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains(
708756
"use_deepspeed"
709-
)
757+
):
758+
return True
759+
return False
710760

711761
def accelerate_args(self):
762+
"""Gets the default arguments for the accelerate command.
763+
The value of the default arguments are assigned in ``__init__()``.
764+
"""
712765
args = []
713766
for arg in self.DEFAULT_ARGS:
714767
arg_val = getattr(self, arg, None)
@@ -720,6 +773,7 @@ def accelerate_args(self):
720773
return args
721774

722775
def run_with_torchrun(self):
776+
"""Runs the job with torchrun."""
723777
launch_args = self.accelerate_args()
724778
for arg in self.TORCHRUN_ARGS:
725779
if not self.launch_cmd_contains(arg):
@@ -728,6 +782,7 @@ def run_with_torchrun(self):
728782
self.time_cmd(cmd=cmd)
729783

730784
def run_with_deepspeed(self):
785+
"""Runs the job with DeepSpeed."""
731786
if self.is_host:
732787
launch_args = self.accelerate_args()
733788
if self.num_machines > 1:
@@ -758,6 +813,8 @@ def main():
758813
runner_class = DeepSpeedRunner
759814
elif launch_cmd.startswith("accelerate "):
760815
runner_class = AccelerateRunner
816+
else:
817+
runner_class = GenericRunner
761818

762819
runner = runner_class()
763820
runner: Runner

ads/jobs/templates/driver_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
import contextlib
77
import importlib
@@ -397,6 +397,7 @@ def run_command(
397397
shell=True,
398398
)
399399
# Stream the outputs
400+
logger.debug("Streaming command output from subprocess %s", process.pid)
400401
while True:
401402
output = process.stdout.readline()
402403
if process.poll() is not None and output == b"":
@@ -411,9 +412,15 @@ def run_command(
411412
# logging will add line break
412413
msg = msg.rstrip("\n")
413414
logger.log(level=level, msg=msg)
415+
if "pdsh@" in msg and "ssh exited with exit code 1" in msg:
416+
print("DeepSpeed Failed.")
417+
sys.exit(1)
414418
# Add a small delay so that
415419
# outputs from the subsequent code will have different timestamp for oci logging
416420
time.sleep(0.02)
421+
logger.debug(
422+
"subprocess %s returned exit code %s", process.pid, process.returncode
423+
)
417424
if check and process.returncode != 0:
418425
# If there is an error, exit the main process with the same return code.
419426
sys.exit(process.returncode)

0 commit comments

Comments
 (0)