 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
+# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 from __future__ import annotations
 
 import oci
 import oci.data_science
 import oci.util as oci_util
+import yaml
+from oci.data_science import models
 from oci.data_science.models import JobInfrastructureConfigurationDetails
 from oci.exceptions import ServiceError
-import yaml
+
 from ads.common import utils
+from ads.common.decorator.utils import class_or_instance_method
+from ads.common.dsc_file_system import (
+    DSCFileSystemManager,
+    OCIFileStorage,
+    OCIObjectStorage,
+)
 from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin
 from ads.common.oci_logging import OCILog
 from ads.common.oci_resource import ResourceNotFoundError
 from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
 from ads.jobs.builders.infrastructure.dsc_job_runtime import (
+    MULTI_NODE_JOB_SUPPORT,
     ContainerRuntimeHandler,
     DataScienceJobRuntimeManager,
 )
 from ads.jobs.builders.infrastructure.utils import get_value
 from ads.jobs.builders.runtimes.artifact import Artifact
+from ads.jobs.builders.runtimes.base import MultiNodeRuntime
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
 
-from ads.common.dsc_file_system import (
-    OCIFileStorage,
-    DSCFileSystemManager,
-    OCIObjectStorage,
-)
-from ads.common.decorator.utils import class_or_instance_method
-
 logger = logging.getLogger(__name__)
 
 SLEEP_INTERVAL = 3
 WAIT_SECONDS_AFTER_FINISHED = 90
 MAXIMUM_MOUNT_COUNT = 5
 FILE_STORAGE_TYPE = "FILE_STORAGE"
 OBJECT_STORAGE_TYPE = "OBJECT_STORAGE"
+DEFAULT_NODE_GROUP_NAME = "node-group"
 
 
 class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
@@ -284,11 +287,15 @@ def load_properties_from_env(self) -> None:
 
     def load_defaults(self) -> DSCJob:
         self.load_properties_from_env()
+        if getattr(self, "job_node_configuration_details", None):
+            return self
+        # Following are for single node job run only
         if not self.job_infrastructure_configuration_details:
             self.job_infrastructure_configuration_details = {}
+
         # Convert the dict to JobInfrastructureConfigurationDetails object
         if isinstance(self.job_infrastructure_configuration_details, dict):
-            # Default networking
+
             if not self.job_infrastructure_configuration_details.get(
                 "jobInfrastructureType"
             ):
@@ -352,6 +359,7 @@ def create(self) -> DSCJob:
             raise ValueError("Specify compartment ID for data science job.")
         if not self.project_id:
             raise ValueError("Specify project ID for data science job.")
+
         self._create_with_oci_api()
         return self
 
@@ -498,7 +506,9 @@ def run(self, **kwargs) -> DataScienceJobRun:
         keys = list(kwargs.keys())
         for key in keys:
             if key in config_swagger_types:
-                config_kwargs[key] = kwargs.pop(key)
+                val = kwargs.pop(key)
+                if val is not None:
+                    config_kwargs[key] = val
             elif key in env_config_swagger_types:
                 value = kwargs.pop(key)
                 if key in [
@@ -545,6 +555,25 @@ def run(self, **kwargs) -> DataScienceJobRun:
                 env_config_override
             )
 
+        if getattr(self, "job_node_configuration_details", None):
+            job_config_override = kwargs.pop("job_configuration_override_details", None)
+            env_config_override = kwargs.pop(
+                "job_environment_configuration_override_details", None
+            )
+            if job_config_override or env_config_override:
+                node_config = {
+                    "jobNodeType": "MULTI_NODE",
+                    "jobNodeGroupConfigurationDetailsList": [
+                        {
+                            # Node group name must match the node group name in the job.
+                            "name": DEFAULT_NODE_GROUP_NAME,
+                            "JobConfigurationDetails": job_config_override,
+                            "JobEnvironmentConfigurationDetails": env_config_override,
+                        }
+                    ],
+                }
+                kwargs["job_node_configuration_override_details"] = node_config
+
         wait = kwargs.pop("wait", False)
         run = DataScienceJobRun(**kwargs, **self.auth).create()
         if wait:
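
For orientation, a hedged usage sketch follows (not part of this diff; it assumes an ads.jobs.Job named job built on this infrastructure, and the argument values are placeholders): run-level overrides on a multi-node job are repackaged by the branch above into a job_node_configuration_override_details payload whose single node group is named DEFAULT_NODE_GROUP_NAME.

    # Hypothetical call; "experiment-1" and EPOCHS are placeholder values.
    # For a multi-node job, the configuration/environment overrides derived from
    # these arguments end up inside the "node-group" override entry built above.
    run = job.run(
        name="experiment-1",
        env_var={"EPOCHS": "20"},
    )
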
@@ -756,13 +785,11 @@ def stop_condition():
                 return True
             # Stop only if time_finished is over 2 minute ago.
             # This is for the time delay between job run stopped and the logs appear in oci logging.
-            if (
+            return (
                 datetime.datetime.now(self.time_finished.tzinfo)
                 - datetime.timedelta(seconds=wait)
                 > self.time_finished
-            ):
-                return True
-            return False
+            )
 
         if not self.log_id and not self.log_group_id:
             print(
@@ -1471,6 +1498,23 @@ def _update_from_dsc_model(
         }
         self.dsc_job = dsc_job
 
+        # Process multi-node infrastructure config
+        node_groups = get_value(
+            dsc_job,
+            "job_node_configuration_details.job_node_group_configuration_details_list",
+        )
+        if node_groups and len(node_groups) == 1:
+            node_group = node_groups[0]
+            dsc_job.job_infrastructure_configuration_details = (
+                node_group.job_infrastructure_configuration_details
+            )
+            subnet_id = get_value(
+                dsc_job,
+                "job_node_configuration_details.job_network_configuration.subnet_id",
+            )
+            if subnet_id:
+                self.set_spec(self.CONST_SUBNET_ID, subnet_id)
+
         for infra_attr, dsc_attr in self.payload_attribute_map.items():
             value = get_value(dsc_job, dsc_attr)
             if not value:
@@ -1557,10 +1601,13 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
             if value:
                 dsc_job.job_infrastructure_configuration_details[camel_attr] = value
 
-        if not dsc_job.job_infrastructure_configuration_details.get(
-            "shapeName", ""
-        ).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get(
-            "jobShapeConfigDetails"
+        shape = dsc_job.job_infrastructure_configuration_details.get("shapeName", "")
+        if (
+            shape
+            and not str(shape).endswith("Flex")
+            and dsc_job.job_infrastructure_configuration_details.get(
+                "jobShapeConfigDetails"
+            )
         ):
             raise ValueError(
                 "Shape config is not required for non flex shape from user end."
@@ -1583,7 +1630,6 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
         return self
 
     def build(self) -> DataScienceJob:
-        self.dsc_job.load_defaults()
 
         try:
             self.dsc_job.load_defaults()
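
To illustrate the shape-config validation reworked in the _update_job_infra hunk above, here is a hedged sketch using the public DataScienceJob builder (shape name and sizes are placeholder values, not taken from this diff): shape config details are only meaningful for flexible shapes, so pairing them with a fixed shape is rejected.

    # Combining a fixed (non-Flex) shape with shape config details is rejected.
    infra = (
        DataScienceJob()
        .with_shape_name("VM.Standard2.4")  # not a *.Flex shape
        .with_shape_config_details(ocpus=4, memory_in_gbs=64)
    )
    # Creating a job with this infrastructure raises:
    # ValueError: Shape config is not required for non flex shape from user end.
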
@@ -1611,6 +1657,48 @@ def init(self, **kwargs) -> DataScienceJob:
             )
         )
 
+    def _config_multi_node(self, runtime: MultiNodeRuntime):
+        """Configure the payload for multi-node job run."""
+        infra_config: dict = self.dsc_job.job_infrastructure_configuration_details
+        job_config: models.DefaultJobConfigurationDetails = (
+            self.dsc_job.job_configuration_details
+        )
+        env_config = self.dsc_job.job_environment_configuration_details
+        # For multi-node jobs,
+        # the job_infrastructure_configuration_details and job_configuration_details
+        # should be the special EMPTY class.
+        # The job_environment_configuration_details should be None.
+        # The configs will be specified in each node group.
+        self.dsc_job.job_infrastructure_configuration_details = None
+        self.dsc_job.job_configuration_details = None
+        self.dsc_job.job_environment_configuration_details = None
+
+        subnet_id = infra_config.pop("subnetId", None)
+        infra_config["jobInfrastructureType"] = (
+            models.MultiNodeJobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_MULTI_NODE
+        )
+
+        if subnet_id:
+            network_config = models.JobCustomNetworkConfiguration(subnet_id=subnet_id)
+        else:
+            network_config = models.JobDefaultNetworkConfiguration()
+
+        node_group_config: dict = {
+            "name": DEFAULT_NODE_GROUP_NAME,
+            "replicas": runtime.replica,
+            "minimumSuccessReplicas": runtime.replica,
+            "jobInfrastructureConfigurationDetails": infra_config,
+            "jobConfigurationDetails": job_config,
+            "jobEnvironmentConfigurationDetails": env_config,
+        }
+
+        self.dsc_job.job_node_configuration_details = {
+            "jobNodeType": "MULTI_NODE",
+            "startupOrder": "IN_PARALLEL",
+            "jobNetworkConfiguration": network_config,
+            "jobNodeGroupConfigurationDetailsList": [node_group_config],
+        }
+
     def create(self, runtime, **kwargs) -> DataScienceJob:
         """Creates a job with runtime.
 
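
As a usage sketch of the new multi-node path (illustrative only, not part of this diff; it assumes PyTorchDistributedRuntime is a MultiNodeRuntime subclass, and the OCIDs, shape, and repository URL are placeholders), a runtime with replica > 1 makes create() route through _config_multi_node(), which moves the infrastructure, job, and environment configuration into a single node group:

    from ads.jobs import DataScienceJob, Job, PyTorchDistributedRuntime

    job = (
        Job(name="multi-node-training")
        .with_infrastructure(
            DataScienceJob()
            .with_compartment_id("ocid1.compartment.oc1..<unique_id>")
            .with_project_id("ocid1.datascienceproject.oc1..<unique_id>")
            .with_shape_name("VM.GPU.A10.2")
            .with_block_storage_size(256)
        )
        .with_runtime(
            PyTorchDistributedRuntime()
            .with_git(url="https://github.com/<org>/<repo>.git")
            .with_command("torchrun train.py")
            .with_replica(2)  # more than one node triggers the multi-node payload
        )
    )
    job.create()  # builds one node group named "node-group" with replicas=2
    run = job.run()
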
@@ -1635,9 +1723,7 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
 
         if self.name:
             display_name = Template(self.name).safe_substitute(runtime.envs)
-        elif isinstance(runtime, GitPythonRuntime) or isinstance(
-            runtime, ContainerRuntime
-        ):
+        elif isinstance(runtime, (GitPythonRuntime, ContainerRuntime)):
             display_name = utils.get_random_name_for_resource()
         else:
             display_name = None
@@ -1652,11 +1738,22 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
         self.dsc_job = DSCJob(**payload, **self.auth)
         # Set Job infra to user values after DSCJob initialized the defaults
         self._update_job_infra(self.dsc_job)
+        if self.is_multi_node_job(runtime):
+            self._config_multi_node(runtime=runtime)
         self.dsc_job.create()
         # Update the model from infra after job creation.
         self._update_from_dsc_model(self.dsc_job)
         return self
 
+    @staticmethod
+    def is_multi_node_job(runtime):
+        """Check if the job is multi-node job."""
+        return (
+            MULTI_NODE_JOB_SUPPORT
+            and isinstance(runtime, MultiNodeRuntime)
+            and runtime.replica > 1
+        )
+
     def run(
         self,
         name=None,
0 commit comments