Skip to content

Commit cdfab98

Browse files
committed
Merge commit 'bbd905400d4b32329aa54596b7993678c0663b56' from jobs/update
2 parents fe82ef1 + bbd9054 commit cdfab98

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

ads/jobs/builders/runtimes/pytorch_runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ def run(self, dsc_job, **kwargs):
205205
if not envs:
206206
envs = {}
207207
# Huggingface accelerate requires machine rank
208-
envs["RANK"] = str(i)
208+
# Here we use NODE_RANK to store the machine rank
209+
envs["NODE_RANK"] = str(i)
209210
envs["WORLD_SIZE"] = str(replicas)
210211
if main_run:
211212
envs["MAIN_JOB_RUN_OCID"] = main_run.id

ads/jobs/templates/driver_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
694694
# --multi_gpu will be set automatically if there is more than 1 GPU
695695
# self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1)
696696
self.num_machines = self.node_count
697-
self.machine_rank = os.environ["RANK"]
697+
self.machine_rank = os.environ["NODE_RANK"]
698698
# Total number of processes across all nodes
699699
# Here we assume all nodes are having the same shape
700700
self.num_processes = (self.gpu_count if self.gpu_count else 1) * self.node_count

ads/jobs/templates/driver_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def copy_inputs(mappings: dict = None):
276276
return
277277

278278
for src, dest in mappings.items():
279-
logger.debug("Copying %s to %s", src, dest)
279+
logger.debug("Copying %s to %s", src, os.path.abspath(dest))
280280
# Create the dest dir if one does not exist.
281281
if str(dest).endswith("/"):
282282
dest_dir = dest
@@ -439,6 +439,10 @@ def install_pip_packages(self, packages: str = None):
439439
packages = os.environ.get(CONST_ENV_PIP_PKG)
440440
if not packages:
441441
return self
442+
# The package requirement may contain special character like '>'.
443+
# Here we wrap each package requirement with single quote to make sure they can be installed correctly
444+
package_list = shlex.split(packages)
445+
packages = " ".join([f"'{package}'" for package in package_list])
442446
self.run_command(
443447
f"pip install {packages}", conda_prefix=self.conda_prefix, check=True
444448
)

0 commit comments

Comments
 (0)