Skip to content

Commit 8df7d77

Browse files
authored
Add support to use multi-node job run (DTv2) API for distributed training (#1165)
2 parents 2e5cd96 + fecea3e commit 8df7d77

File tree

9 files changed

+941
-254
lines changed

9 files changed

+941
-254
lines changed

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 121 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65
from __future__ import annotations
76

@@ -21,37 +20,41 @@
2120
import oci
2221
import oci.data_science
2322
import oci.util as oci_util
23+
import yaml
24+
from oci.data_science import models
2425
from oci.data_science.models import JobInfrastructureConfigurationDetails
2526
from oci.exceptions import ServiceError
26-
import yaml
27+
2728
from ads.common import utils
29+
from ads.common.decorator.utils import class_or_instance_method
30+
from ads.common.dsc_file_system import (
31+
DSCFileSystemManager,
32+
OCIFileStorage,
33+
OCIObjectStorage,
34+
)
2835
from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin
2936
from ads.common.oci_logging import OCILog
3037
from ads.common.oci_resource import ResourceNotFoundError
3138
from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
3239
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
40+
MULTI_NODE_JOB_SUPPORT,
3341
ContainerRuntimeHandler,
3442
DataScienceJobRuntimeManager,
3543
)
3644
from ads.jobs.builders.infrastructure.utils import get_value
3745
from ads.jobs.builders.runtimes.artifact import Artifact
46+
from ads.jobs.builders.runtimes.base import MultiNodeRuntime
3847
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
3948
from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
4049

41-
from ads.common.dsc_file_system import (
42-
OCIFileStorage,
43-
DSCFileSystemManager,
44-
OCIObjectStorage,
45-
)
46-
from ads.common.decorator.utils import class_or_instance_method
47-
4850
logger = logging.getLogger(__name__)
4951

5052
SLEEP_INTERVAL = 3
5153
WAIT_SECONDS_AFTER_FINISHED = 90
5254
MAXIMUM_MOUNT_COUNT = 5
5355
FILE_STORAGE_TYPE = "FILE_STORAGE"
5456
OBJECT_STORAGE_TYPE = "OBJECT_STORAGE"
57+
DEFAULT_NODE_GROUP_NAME = "node-group"
5558

5659

5760
class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
@@ -284,11 +287,15 @@ def load_properties_from_env(self) -> None:
284287

285288
def load_defaults(self) -> DSCJob:
286289
self.load_properties_from_env()
290+
if getattr(self, "job_node_configuration_details", None):
291+
return self
292+
# The following applies to single-node job runs only.
287293
if not self.job_infrastructure_configuration_details:
288294
self.job_infrastructure_configuration_details = {}
295+
289296
# Convert the dict to JobInfrastructureConfigurationDetails object
290297
if isinstance(self.job_infrastructure_configuration_details, dict):
291-
# Default networking
298+
292299
if not self.job_infrastructure_configuration_details.get(
293300
"jobInfrastructureType"
294301
):
@@ -352,6 +359,7 @@ def create(self) -> DSCJob:
352359
raise ValueError("Specify compartment ID for data science job.")
353360
if not self.project_id:
354361
raise ValueError("Specify project ID for data science job.")
362+
355363
self._create_with_oci_api()
356364
return self
357365

@@ -498,7 +506,9 @@ def run(self, **kwargs) -> DataScienceJobRun:
498506
keys = list(kwargs.keys())
499507
for key in keys:
500508
if key in config_swagger_types:
501-
config_kwargs[key] = kwargs.pop(key)
509+
val = kwargs.pop(key)
510+
if val is not None:
511+
config_kwargs[key] = val
502512
elif key in env_config_swagger_types:
503513
value = kwargs.pop(key)
504514
if key in [
@@ -545,6 +555,25 @@ def run(self, **kwargs) -> DataScienceJobRun:
545555
env_config_override
546556
)
547557

558+
if getattr(self, "job_node_configuration_details", None):
559+
job_config_override = kwargs.pop("job_configuration_override_details", None)
560+
env_config_override = kwargs.pop(
561+
"job_environment_configuration_override_details", None
562+
)
563+
if job_config_override or env_config_override:
564+
node_config = {
565+
"jobNodeType": "MULTI_NODE",
566+
"jobNodeGroupConfigurationDetailsList": [
567+
{
568+
# Node group name must match the node group name in the job.
569+
"name": DEFAULT_NODE_GROUP_NAME,
570+
"JobConfigurationDetails": job_config_override,
571+
"JobEnvironmentConfigurationDetails": env_config_override,
572+
}
573+
],
574+
}
575+
kwargs["job_node_configuration_override_details"] = node_config
576+
548577
wait = kwargs.pop("wait", False)
549578
run = DataScienceJobRun(**kwargs, **self.auth).create()
550579
if wait:
@@ -756,13 +785,11 @@ def stop_condition():
756785
return True
757786
# Stop only once time_finished is more than `wait` seconds in the past.
758787
# This is for the time delay between job run stopped and the logs appear in oci logging.
759-
if (
788+
return (
760789
datetime.datetime.now(self.time_finished.tzinfo)
761790
- datetime.timedelta(seconds=wait)
762791
> self.time_finished
763-
):
764-
return True
765-
return False
792+
)
766793

767794
if not self.log_id and not self.log_group_id:
768795
print(
@@ -1471,6 +1498,23 @@ def _update_from_dsc_model(
14711498
}
14721499
self.dsc_job = dsc_job
14731500

1501+
# Process multi-node infrastructure config
1502+
node_groups = get_value(
1503+
dsc_job,
1504+
"job_node_configuration_details.job_node_group_configuration_details_list",
1505+
)
1506+
if node_groups and len(node_groups) == 1:
1507+
node_group = node_groups[0]
1508+
dsc_job.job_infrastructure_configuration_details = (
1509+
node_group.job_infrastructure_configuration_details
1510+
)
1511+
subnet_id = get_value(
1512+
dsc_job,
1513+
"job_node_configuration_details.job_network_configuration.subnet_id",
1514+
)
1515+
if subnet_id:
1516+
self.set_spec(self.CONST_SUBNET_ID, subnet_id)
1517+
14741518
for infra_attr, dsc_attr in self.payload_attribute_map.items():
14751519
value = get_value(dsc_job, dsc_attr)
14761520
if not value:
@@ -1557,10 +1601,13 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
15571601
if value:
15581602
dsc_job.job_infrastructure_configuration_details[camel_attr] = value
15591603

1560-
if not dsc_job.job_infrastructure_configuration_details.get(
1561-
"shapeName", ""
1562-
).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get(
1563-
"jobShapeConfigDetails"
1604+
shape = dsc_job.job_infrastructure_configuration_details.get("shapeName", "")
1605+
if (
1606+
shape
1607+
and not str(shape).endswith("Flex")
1608+
and dsc_job.job_infrastructure_configuration_details.get(
1609+
"jobShapeConfigDetails"
1610+
)
15641611
):
15651612
raise ValueError(
15661613
"Shape config is not required for non flex shape from user end."
@@ -1583,7 +1630,6 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
15831630
return self
15841631

15851632
def build(self) -> DataScienceJob:
1586-
self.dsc_job.load_defaults()
15871633

15881634
try:
15891635
self.dsc_job.load_defaults()
@@ -1611,6 +1657,48 @@ def init(self, **kwargs) -> DataScienceJob:
16111657
)
16121658
)
16131659

1660+
def _config_multi_node(self, runtime: MultiNodeRuntime):
    """Re-shape the job payload on ``self.dsc_job`` into a multi-node payload.

    The infrastructure, job and environment configurations currently set at
    the top level of ``self.dsc_job`` are captured, cleared (set to ``None``)
    and attached to a single node group named ``DEFAULT_NODE_GROUP_NAME``.
    The resulting structure is stored in
    ``self.dsc_job.job_node_configuration_details``.

    Parameters
    ----------
    runtime : MultiNodeRuntime
        Runtime whose ``replica`` value is used for both ``replicas`` and
        ``minimumSuccessReplicas`` of the node group.
    """
    dsc_job = self.dsc_job

    # Capture the single-node configs before they are removed from the job.
    # NOTE(review): the infrastructure config is used as a dict below
    # (``pop`` / item assignment) -- confirm callers always provide a dict.
    group_infra = dsc_job.job_infrastructure_configuration_details
    group_job_config = dsc_job.job_configuration_details
    group_env_config = dsc_job.job_environment_configuration_details

    # For a multi-node job the configs live in the node group, so all three
    # top-level config attributes are cleared here.
    dsc_job.job_infrastructure_configuration_details = None
    dsc_job.job_configuration_details = None
    dsc_job.job_environment_configuration_details = None

    # The subnet is taken out of the infrastructure config and expressed
    # through the job-level network configuration instead.
    subnet_id = group_infra.pop("subnetId", None)
    group_infra["jobInfrastructureType"] = (
        models.MultiNodeJobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_MULTI_NODE
    )

    # Custom networking when a subnet was given, default networking otherwise.
    network_config = (
        models.JobCustomNetworkConfiguration(subnet_id=subnet_id)
        if subnet_id
        else models.JobDefaultNetworkConfiguration()
    )

    # Every replica must succeed: minimumSuccessReplicas equals replicas.
    replicas = runtime.replica
    single_group: dict = {
        "name": DEFAULT_NODE_GROUP_NAME,
        "replicas": replicas,
        "minimumSuccessReplicas": replicas,
        "jobInfrastructureConfigurationDetails": group_infra,
        "jobConfigurationDetails": group_job_config,
        "jobEnvironmentConfigurationDetails": group_env_config,
    }

    dsc_job.job_node_configuration_details = {
        "jobNodeType": "MULTI_NODE",
        "startupOrder": "IN_PARALLEL",
        "jobNetworkConfiguration": network_config,
        "jobNodeGroupConfigurationDetailsList": [single_group],
    }
1701+
16141702
def create(self, runtime, **kwargs) -> DataScienceJob:
16151703
"""Creates a job with runtime.
16161704
@@ -1635,9 +1723,7 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
16351723

16361724
if self.name:
16371725
display_name = Template(self.name).safe_substitute(runtime.envs)
1638-
elif isinstance(runtime, GitPythonRuntime) or isinstance(
1639-
runtime, ContainerRuntime
1640-
):
1726+
elif isinstance(runtime, (GitPythonRuntime, ContainerRuntime)):
16411727
display_name = utils.get_random_name_for_resource()
16421728
else:
16431729
display_name = None
@@ -1652,11 +1738,22 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
16521738
self.dsc_job = DSCJob(**payload, **self.auth)
16531739
# Set Job infra to user values after DSCJob initialized the defaults
16541740
self._update_job_infra(self.dsc_job)
1741+
if self.is_multi_node_job(runtime):
1742+
self._config_multi_node(runtime=runtime)
16551743
self.dsc_job.create()
16561744
# Update the model from infra after job creation.
16571745
self._update_from_dsc_model(self.dsc_job)
16581746
return self
16591747

1748+
@staticmethod
def is_multi_node_job(runtime):
    """Check whether the job should be created as a multi-node job.

    A job is treated as multi-node only when all of the following hold:
    ``MULTI_NODE_JOB_SUPPORT`` is truthy, ``runtime`` is a
    ``MultiNodeRuntime``, and the runtime requests more than one replica.

    Parameters
    ----------
    runtime
        The runtime the job is being created with.

    Returns
    -------
    bool
        ``True`` if the multi-node payload should be used, ``False`` otherwise.
    """
    if not MULTI_NODE_JOB_SUPPORT or not isinstance(runtime, MultiNodeRuntime):
        return False
    # NOTE(review): presumably ``replica`` is None when the user did not
    # configure it -- confirm against MultiNodeRuntime. Comparing None with
    # an int raises TypeError, so an unset value is treated as single-node.
    replica = runtime.replica
    return replica is not None and replica > 1
1756+
16601757
def run(
16611758
self,
16621759
name=None,

0 commit comments

Comments
 (0)