30
30
from ads .common .oci_resource import ResourceNotFoundError
31
31
from ads .jobs .builders .infrastructure .base import Infrastructure , RunInstance
32
32
from ads .jobs .builders .infrastructure .dsc_job_runtime import (
33
+ ContainerRuntimeHandler ,
33
34
DataScienceJobRuntimeManager ,
34
35
)
35
36
from ads .jobs .builders .infrastructure .utils import get_value
@@ -458,14 +459,19 @@ def run(self, **kwargs) -> DataScienceJobRun:
458
459
----------
459
460
**kwargs :
460
461
Keyword arguments for initializing a Data Science Job Run.
461
- The keys can be any keys in supported by OCI JobConfigurationDetails and JobRun, including:
462
+ The keys can be any keys in supported by OCI JobConfigurationDetails, OcirContainerJobEnvironmentConfigurationDetails and JobRun, including:
462
463
* hyperparameter_values: dict(str, str)
463
464
* environment_variables: dict(str, str)
464
465
* command_line_arguments: str
465
466
* maximum_runtime_in_minutes: int
466
467
* display_name: str
467
468
* freeform_tags: dict(str, str)
468
469
* defined_tags: dict(str, dict(str, object))
470
+ * image: str
471
+ * cmd: list[str]
472
+ * entrypoint: list[str]
473
+ * image_digest: str
474
+ * image_signature_id: str
469
475
470
476
If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
471
477
@@ -478,14 +484,28 @@ def run(self, **kwargs) -> DataScienceJobRun:
478
484
if not self .id :
479
485
self .create ()
480
486
481
- swagger_types = (
487
+ config_swagger_types = (
482
488
oci .data_science .models .DefaultJobConfigurationDetails ().swagger_types .keys ()
483
489
)
490
+ env_config_swagger_types = {}
491
+ if hasattr (oci .data_science .models , "OcirContainerJobEnvironmentConfigurationDetails" ):
492
+ env_config_swagger_types = (
493
+ oci .data_science .models .OcirContainerJobEnvironmentConfigurationDetails ().swagger_types .keys ()
494
+ )
484
495
config_kwargs = {}
496
+ env_config_kwargs = {}
485
497
keys = list (kwargs .keys ())
486
498
for key in keys :
487
- if key in swagger_types :
499
+ if key in config_swagger_types :
488
500
config_kwargs [key ] = kwargs .pop (key )
501
+ elif key in env_config_swagger_types :
502
+ value = kwargs .pop (key )
503
+ if key in [
504
+ ContainerRuntime .CONST_CMD ,
505
+ ContainerRuntime .CONST_ENTRYPOINT
506
+ ] and isinstance (value , str ):
507
+ value = ContainerRuntimeHandler .split_args (value )
508
+ env_config_kwargs [key ] = value
489
509
490
510
# remove timestamp from the job name (added in default names, when display_name not specified by user)
491
511
if self .display_name :
@@ -514,6 +534,12 @@ def run(self, **kwargs) -> DataScienceJobRun:
514
534
config_override .update (config_kwargs )
515
535
kwargs ["job_configuration_override_details" ] = config_override
516
536
537
+ if env_config_kwargs :
538
+ env_config_kwargs ["jobEnvironmentType" ] = "OCIR_CONTAINER"
539
+ env_config_override = kwargs .get ("job_environment_configuration_override_details" , {})
540
+ env_config_override .update (env_config_kwargs )
541
+ kwargs ["job_environment_configuration_override_details" ] = env_config_override
542
+
517
543
wait = kwargs .pop ("wait" , False )
518
544
run = DataScienceJobRun (** kwargs , ** self .auth ).create ()
519
545
if wait :
0 commit comments