Skip to content

Commit fc3af70

Browse files
committed
Update prepare_cmd for generic runner.
1 parent e246154 commit fc3af70

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

ads/jobs/templates/driver_pytorch.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,15 @@ def prepare_cmd(self, launch_args: list = None, prefix=""):
341341
launch_args = []
342342
# Append launch cmd args specified by the user.
343343
if self.launch_cmd:
344-
if not self.launch_cmd.startswith(self.LAUNCHER):
345-
raise ValueError(
346-
f"Command not supported: '{self.launch_cmd}'. "
347-
f"The command should start with '{self.LAUNCHER}'."
348-
)
344+
if self.LAUNCHER:
345+
if not self.launch_cmd.startswith(self.LAUNCHER):
346+
raise ValueError(
347+
f"Command not supported: '{self.launch_cmd}'. "
348+
)
349349

350-
launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
350+
launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
351+
else:
352+
launch_args.append(self.launch_cmd)
351353
else:
352354
launch_args.append(self.get_cmd_with_entrypoint_and_args())
353355

@@ -676,6 +678,8 @@ def run(self):
676678
class GenericRunner(TorchRunner, DeepSpeedRunner):
677679
"""Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``."""
678680

681+
LAUNCHER = ""
682+
679683
def use_deepspeed(self) -> bool:
680684
"""Indicate if DeepSpeed is used."""
681685
if os.environ.get(CONST_ENV_DEEPSPEED):

0 commit comments

Comments
 (0)