
Commit 28474dd

Adding PyTorchDistributedRuntime to support distributed (multi-node) training (#209)
2 parents d6b94b2 + a46bc76 commit 28474dd

Showing 20 changed files with 1,969 additions and 37 deletions.

ads/common/oci_mixin.py

Lines changed: 2 additions & 2 deletions
@@ -230,7 +230,7 @@ def _parse_kwargs(attribute_map: dict, **kwargs):

         return parsed_kwargs

-    @classmethod
+    @class_or_instance_method
     def deserialize(cls, data, to_cls):
         """De-serialize data from dictionary to an OCI model"""
         if cls.type_mappings is None:

@@ -549,7 +549,7 @@ def from_dict(cls, data):
         """
         return cls.create_instance(**data)

-    @classmethod
+    @class_or_instance_method
     def deserialize(cls, data: dict, to_cls: str = None):
         """Deserialize data

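The swap from @classmethod to class_or_instance_method lets deserialize() be called on either the class or an instance. A minimal sketch of how such a decorator can be written as a descriptor (an illustration of the pattern, not the ADS implementation):

import functools

class class_or_instance_method:
    """Bind the wrapped function to the instance when accessed on an
    instance, otherwise to the class (illustrative sketch only)."""

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, objtype=None):
        target = obj if obj is not None else objtype
        return functools.partial(self.func, target)
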
ads/common/serializer.py

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,7 @@
 import yaml

 from ads.common import logger
+from ads.common.auth import default_signer

 try:
     from yaml import CSafeDumper as dumper

@@ -134,6 +135,14 @@ def _read_from_file(uri: str, **kwargs) -> str:
         -------
         string: Contents in file specified by URI
         """
+        # Add default signer if the uri is an object storage uri, and
+        # the user does not specify config or signer.
+        if (
+            uri.startswith("oci://")
+            and "config" not in kwargs
+            and "signer" not in kwargs
+        ):
+            kwargs.update(default_signer())
         with fsspec.open(uri, "r", **kwargs) as f:
             return f.read()

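A usage sketch of the new guard (the helper function and the bucket URI are illustrative, not an ADS API; reading oci:// URIs through fsspec assumes the ocifs implementation is installed): object storage reads now pick up the default signer automatically unless the caller supplies config or signer.

import fsspec
from ads.common.auth import default_signer

def read_text(uri: str, **kwargs) -> str:
    # Inject default auth only for object storage URIs when the caller
    # has not passed an explicit config or signer.
    if uri.startswith("oci://") and "config" not in kwargs and "signer" not in kwargs:
        kwargs.update(default_signer())
    with fsspec.open(uri, "r", **kwargs) as f:
        return f.read()

# Illustrative bucket URI; passing signer=... here would bypass the default.
text = read_text("oci://my-bucket@my-namespace/spec.yaml")
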
ads/jobs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8; -*-

-# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
+# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


@@ -17,6 +17,7 @@
     DataFlowRuntime,
     DataFlowNotebookRuntime,
 )
+from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.ads_job import Job
 from ads.jobs.builders import infrastructure

@@ -44,6 +45,7 @@
     "NotebookRuntime",
     "ScriptRuntime",
     "ContainerRuntime",
+    "PyTorchDistributedRuntime",
     "DataFlow",
     "DataFlowRun",
     "DataFlowRuntime",

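With the new export, the runtime can be imported from the package root as well as from its module path:

from ads.jobs import PyTorchDistributedRuntime
# equivalent to:
# from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
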
ads/jobs/ads_job.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from ads.jobs.builders.base import Builder
 from ads.jobs.builders.infrastructure.dataflow import DataFlow, DataFlowRun
 from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob, DataScienceJobRun
+from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.builders.runtimes.python_runtime import (
     DataFlowRuntime,

@@ -131,6 +132,7 @@ class Job(Builder):
             ContainerRuntime,
             ScriptRuntime,
             NotebookRuntime,
+            PyTorchDistributedRuntime,
             DataFlowRuntime,
         ]
     }

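With the runtime registered here, a distributed training job can be defined through the same builder pattern as the other runtimes. A hedged sketch of what a definition could look like (the OCIDs, shape, conda slug and repository URL are placeholders, and the builder method names mirror the runtime's spec fields, so treat them as assumptions rather than the documented API):

from ads.jobs import Job, DataScienceJob, PyTorchDistributedRuntime

job = (
    Job(name="pytorch-distributed-example")
    .with_infrastructure(
        DataScienceJob()
        .with_compartment_id("ocid1.compartment.oc1..<unique_id>")
        .with_project_id("ocid1.datascienceproject.oc1..<unique_id>")
        .with_shape_name("VM.GPU.A10.1")
    )
    .with_runtime(
        PyTorchDistributedRuntime()
        .with_service_conda("pytorch110_p38_gpu_v1")  # assumed conda slug
        .with_git(url="https://github.com/<org>/<repo>.git", branch="main")
        .with_command("torchrun train.py")
        .with_replica(2)
    )
)
# job.create() would create the job; job.run() would start the replicas.
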
ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 25 additions & 9 deletions
@@ -35,7 +35,11 @@
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime

-from ads.common.dsc_file_system import OCIFileStorage, DSCFileSystemManager, OCIObjectStorage
+from ads.common.dsc_file_system import (
+    OCIFileStorage,
+    DSCFileSystemManager,
+    OCIObjectStorage,
+)

 logger = logging.getLogger(__name__)


@@ -1454,11 +1458,14 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
             if value:
                 dsc_job.job_infrastructure_configuration_details[camel_attr] = value

-        if (
-            not dsc_job.job_infrastructure_configuration_details.get("shapeName", "").endswith("Flex")
-            and dsc_job.job_infrastructure_configuration_details.get("jobShapeConfigDetails")
+        if not dsc_job.job_infrastructure_configuration_details.get(
+            "shapeName", ""
+        ).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get(
+            "jobShapeConfigDetails"
         ):
-            raise ValueError("Shape config is not required for non flex shape from user end.")
+            raise ValueError(
+                "Shape config is not required for non flex shape from user end."
+            )

         if dsc_job.job_infrastructure_configuration_details.get("subnetId"):
             dsc_job.job_infrastructure_configuration_details[

@@ -1495,7 +1502,10 @@ def init(self) -> DataScienceJob:
             self.build()
             .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
             .with_project_id(self.project_id or "{Provide a project OCID}")
-            .with_subnet_id(self.subnet_id or "{Provide a subnet OCID or remove this field if you use a default networking}")
+            .with_subnet_id(
+                self.subnet_id
+                or "{Provide a subnet OCID or remove this field if you use a default networking}"
+            )
         )

     def create(self, runtime, **kwargs) -> DataScienceJob:

@@ -1552,7 +1562,7 @@ def run(
         freeform_tags=None,
         defined_tags=None,
         wait=False,
-        **kwargs
+        **kwargs,
     ) -> DataScienceJobRun:
         """Runs a job on OCI Data Science job

@@ -1603,15 +1613,21 @@ def run(
             envs.update(env_var)
             name = Template(name).safe_substitute(envs)

-        return self.dsc_job.run(
+        kwargs = dict(
             display_name=name,
             command_line_arguments=args,
             environment_variables=env_var,
             freeform_tags=freeform_tags,
             defined_tags=defined_tags,
             wait=wait,
-            **kwargs
+            **kwargs,
         )
+        # A Runtime class may define customized run() method.
+        # Use the customized method if the run() method is defined by the runtime.
+        # Otherwise, use the default run() method defined in this class.
+        if hasattr(self.runtime, "run"):
+            return self.runtime.run(self.dsc_job, **kwargs)
+        return self.dsc_job.run(**kwargs)

     def delete(self) -> None:
         """Deletes a job"""

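The run() change above introduces a delegation hook: if the attached runtime defines its own run(dsc_job, **kwargs), it takes over job-run creation (presumably how PyTorchDistributedRuntime starts one run per replica). A minimal sketch of the contract with a hypothetical runtime; the NODE_RANK variable and the per-replica loop are illustrative, not the actual PyTorchDistributedRuntime implementation:

class MultiRunRuntime:
    """Hypothetical runtime used only to illustrate the delegation contract."""

    replica = 2

    def run(self, dsc_job, **kwargs):
        # The infrastructure hands over the DSCJob together with the prepared
        # display_name / arguments / environment kwargs; the runtime decides
        # how many job runs to create from them.
        runs = []
        for rank in range(self.replica):
            envs = dict(kwargs.get("environment_variables") or {})
            envs["NODE_RANK"] = str(rank)  # illustrative variable name
            runs.append(dsc_job.run(**{**kwargs, "environment_variables": envs}))
        return runs
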
ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 107 additions & 4 deletions
@@ -29,13 +29,19 @@
     GitPythonRuntime,
 )
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
+from ads.jobs.builders.runtimes.pytorch_runtime import (
+    PyTorchDistributedRuntime,
+    PyTorchDistributedArtifact,
+)
 from ads.jobs.builders.runtimes.artifact import (
     ScriptArtifact,
     NotebookArtifact,
     PythonArtifact,
     GitPythonArtifact,
 )
+from ads.opctl.distributed.common import cluster_config_helper
 from ads.jobs.builders.infrastructure.utils import get_value
+from ads.jobs.templates import driver_utils


 class IncompatibleRuntime(Exception):  # pragma: no cover

@@ -184,7 +190,7 @@ def _translate_config(self, runtime: Runtime) -> dict:
         if runtime.args:
             # shlex.join() is not available until python 3.8
             job_configuration_details["command_line_arguments"] = " ".join(
-                shlex.quote(arg) for arg in runtime.get_spec(runtime.CONST_ARGS)
+                shlex.quote(str(arg)) for arg in runtime.get_spec(runtime.CONST_ARGS)
             )
         return job_configuration_details

@@ -653,7 +659,7 @@ def _translate_env(self, runtime: PythonRuntime) -> dict:

         if runtime.entrypoint:
             envs[self.CONST_CODE_ENTRYPOINT] = runtime.entrypoint
-        else:
+        elif runtime.script_uri:
             envs[self.CONST_CODE_ENTRYPOINT] = os.path.basename(runtime.script_uri)

         envs[self.CONST_JOB_ENTRYPOINT] = PythonArtifact.CONST_DRIVER_SCRIPT

@@ -674,9 +680,13 @@ def _extract_envs(self, dsc_job) -> dict:
         """
         spec = super()._extract_envs(dsc_job)
         envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
-        if self.CONST_CODE_ENTRYPOINT not in envs:
+        if (
+            self.__class__ == PythonRuntimeHandler
+            and self.CONST_CODE_ENTRYPOINT not in envs
+        ):
             raise IncompatibleRuntime()
-        envs.pop(PythonRuntimeHandler.CONST_JOB_ENTRYPOINT)
+        # PyTorchDistributedRuntime does not require entrypoint.
+        envs.pop(PythonRuntimeHandler.CONST_JOB_ENTRYPOINT, None)
         spec.update(self._extract_specs(envs, self.SPEC_MAPPINGS))
         if PythonRuntime.CONST_PYTHON_PATH in spec:
             spec[PythonRuntime.CONST_PYTHON_PATH] = spec[

@@ -1035,6 +1045,98 @@ def _extract_envs(self, dsc_job):
         return spec


+class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
+    RUNTIME_CLASS = PyTorchDistributedRuntime
+    CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
+    CONST_COMMAND = "OCI__LAUNCH_CMD"
+    CONST_DEEPSPEED = "OCI__DEEPSPEED"
+
+    GIT_SPEC_MAPPINGS = {
+        cluster_config_helper.OCI__RUNTIME_URI: GitPythonRuntime.CONST_GIT_URL,
+        cluster_config_helper.OCI__RUNTIME_GIT_BRANCH: GitPythonRuntime.CONST_BRANCH,
+        cluster_config_helper.OCI__RUNTIME_GIT_COMMIT: GitPythonRuntime.CONST_COMMIT,
+        cluster_config_helper.OCI__RUNTIME_GIT_SECRET_ID: GitPythonRuntime.CONST_GIT_SSH_SECRET_ID,
+    }
+
+    SPEC_MAPPINGS = PythonRuntimeHandler.SPEC_MAPPINGS
+    SPEC_MAPPINGS.update(
+        {
+            PyTorchDistributedRuntime.CONST_COMMAND: CONST_COMMAND,
+        }
+    )
+
+    def _translate_artifact(self, runtime: PyTorchDistributedRuntime):
+        return PyTorchDistributedArtifact(runtime.source_uri, runtime)
+
+    def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
+        envs = super()._translate_env(runtime)
+        replica = runtime.replica if runtime.replica else 1
+        # WORKER_COUNT = REPLICA - 1 so that it will be same as distributed training
+        envs[self.CONST_WORKER_COUNT] = str(replica - 1)
+        envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
+        if runtime.inputs:
+            envs[driver_utils.CONST_ENV_INPUT_MAPPINGS] = json.dumps(runtime.inputs)
+        if runtime.git:
+            for env_key, spec_key in self.GIT_SPEC_MAPPINGS.items():
+                if not runtime.git.get(spec_key):
+                    continue
+                envs[env_key] = runtime.git[spec_key]
+        if runtime.dependencies:
+            if PyTorchDistributedRuntime.CONST_PIP_PKG in runtime.dependencies:
+                envs[driver_utils.CONST_ENV_PIP_PKG] = runtime.dependencies[
+                    PyTorchDistributedRuntime.CONST_PIP_PKG
+                ]
+            if PyTorchDistributedRuntime.CONST_PIP_REQ in runtime.dependencies:
+                envs[driver_utils.CONST_ENV_PIP_REQ] = runtime.dependencies[
+                    PyTorchDistributedRuntime.CONST_PIP_REQ
+                ]
+        if runtime.use_deepspeed:
+            envs[self.CONST_DEEPSPEED] = "1"
+        return envs
+
+    def _extract_envs(self, dsc_job) -> dict:
+        spec = super()._extract_envs(dsc_job)
+        envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
+        if self.CONST_WORKER_COUNT not in envs:
+            raise IncompatibleRuntime()
+        # Replicas
+        spec[PyTorchDistributedRuntime.CONST_REPLICA] = (
+            int(envs.pop(self.CONST_WORKER_COUNT)) + 1
+        )
+        # Git
+        if cluster_config_helper.OCI__RUNTIME_URI in envs:
+            git_spec = {}
+            for env_key, spec_key in self.GIT_SPEC_MAPPINGS.items():
+                if env_key in envs:
+                    git_spec[spec_key] = envs.pop(env_key)
+            spec[PyTorchDistributedRuntime.CONST_GIT] = git_spec
+        # Inputs
+        input_mappings = envs.pop(driver_utils.CONST_ENV_INPUT_MAPPINGS, None)
+        if input_mappings:
+            try:
+                spec[PyTorchDistributedRuntime.CONST_INPUT] = json.loads(input_mappings)
+            except ValueError:
+                spec[PyTorchDistributedRuntime.CONST_INPUT] = input_mappings
+        # Dependencies
+        dep = {}
+        if driver_utils.CONST_ENV_PIP_PKG in envs:
+            dep[PyTorchDistributedRuntime.CONST_PIP_PKG] = envs.pop(
+                driver_utils.CONST_ENV_PIP_PKG
+            )
+        if driver_utils.CONST_ENV_PIP_REQ in envs:
+            dep[PyTorchDistributedRuntime.CONST_PIP_REQ] = envs.pop(
+                driver_utils.CONST_ENV_PIP_REQ
+            )
+        if dep:
+            spec[PyTorchDistributedRuntime.CONST_DEP] = dep
+        if envs.pop(self.CONST_DEEPSPEED, None):
+            spec[PyTorchDistributedRuntime.CONST_DEEPSPEED] = True
+        # Envs
+        if envs:
+            spec[PythonRuntime.CONST_ENV_VAR] = envs
+        return spec
+
+
 class DataScienceJobRuntimeManager(RuntimeHandler):
     """This class is used by the DataScienceJob infrastructure to handle the runtime conversion.
     The translate() method determines the actual runtime handler by matching the RUNTIME_CLASS.

@@ -1046,6 +1148,7 @@ class DataScienceJobRuntimeManager(RuntimeHandler):

     runtime_handlers = [
         ContainerRuntimeHandler,
+        PyTorchDistributedRuntimeHandler,
         GitPythonRuntimeHandler,
         NotebookRuntimeHandler,
         PythonRuntimeHandler,

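To make the translation concrete, here is the rough round trip the new handler implements between a runtime spec and the job's environment variables (derived from _translate_env()/_extract_envs() above; the example runtime settings are illustrative):

# A runtime configured with replica=2, command="torchrun train.py" and
# DeepSpeed enabled is translated into environment variables roughly like:
envs = {
    "OCI__WORKER_COUNT": "1",                # replica - 1
    "OCI__LAUNCH_CMD": "torchrun train.py",  # the launch command
    "OCI__DEEPSPEED": "1",                   # present only when DeepSpeed is enabled
}

# _extract_envs() reverses the mapping when a job is loaded back,
# e.g. the replica count is recovered as OCI__WORKER_COUNT + 1:
replica = int(envs["OCI__WORKER_COUNT"]) + 1  # -> 2
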
ads/jobs/builders/runtimes/artifact.py

Lines changed: 11 additions & 6 deletions
@@ -8,7 +8,6 @@
 import os
 import shutil
 import tempfile
-import zipfile
 from io import DEFAULT_BUFFER_SIZE
 from urllib import request
 from urllib.parse import urlparse

@@ -41,7 +40,7 @@ class Artifact:

     def __init__(self, source, runtime=None) -> None:
         # Get the full path of source file if it is local file.
-        if not urlparse(source).scheme:
+        if source and not urlparse(source).scheme:
             self.source = os.path.abspath(os.path.expanduser(source))
         else:
             self.source = source

@@ -203,6 +202,7 @@ class PythonArtifact(Artifact):
     """Represents a PythonRuntime job artifact"""

     CONST_DRIVER_SCRIPT = "driver_python.py"
+    DEFAULT_BASENAME = "artifact"
     # The directory to store user code
     # This directory must match the USER_CODE_DIR in driver_python.py
     USER_CODE_DIR = "code"

@@ -217,7 +217,12 @@ def _copy_artifacts(self, drivers=None):
         """Copies the drivers and artifacts to the temp artifact dir."""
         # The basename of the job artifact,
         # this will be the name of the zip file uploading to OCI
-        self.basename = os.path.basename(str(self.source).rstrip("/")).split(".", 1)[0]
+        if self.source:
+            self.basename = os.path.basename(str(self.source).rstrip("/")).split(
+                ".", 1
+            )[0]
+        else:
+            self.basename = self.DEFAULT_BASENAME
         # The temp dir path for storing the artifacts, including drivers and user code
         self.artifact_dir = os.path.join(self.temp_dir.name, self.basename)
         # The temp dir path for storing the user code

@@ -236,8 +241,9 @@ def _copy_artifacts(self, drivers=None):
             shutil.copy(file_path, os.path.join(self.artifact_dir, filename))

         # Copy user code
-        os.makedirs(self.code_dir, exist_ok=True)
-        Artifact.copy_from_uri(self.source, self.code_dir, unpack=True)
+        if self.source:
+            os.makedirs(self.code_dir, exist_ok=True)
+            Artifact.copy_from_uri(self.source, self.code_dir, unpack=True)

     def _zip_artifacts(self):
         """Create a zip file from the temp artifact dir."""

@@ -312,7 +318,6 @@ def build(self):


 class GitPythonArtifact(Artifact):
-
     CONST_DRIVER_SCRIPT = "driver_oci.py"

     def __init__(self) -> None:

ads/jobs/builders/runtimes/python_runtime.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ class ScriptRuntime(CondaRuntime):
     # Environment variable
     .with_environment_variable(NAME="Welcome to OCI Data Science.")
     # Command line argument
-    .with_argument("100 linux \"hi there\"")
+    .with_argument("100 linux 'hi there'")
     # The entrypoint is applicable only to directory or zip file as source
     # The entrypoint should be a path relative to the working dir.
     # Here my_script.sh is a file in the code_dir/my_package directory
