Skip to content

Commit 2cfbeff

Browse files
committed
Updated pr.
1 parent 2dd4293 commit 2cfbeff

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ads.common.oci_resource import ResourceNotFoundError
3131
from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
3232
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
33+
ContainerRuntimeHandler,
3334
DataScienceJobRuntimeManager,
3435
)
3536
from ads.jobs.builders.infrastructure.utils import get_value
@@ -458,14 +459,19 @@ def run(self, **kwargs) -> DataScienceJobRun:
458459
----------
459460
**kwargs :
460461
Keyword arguments for initializing a Data Science Job Run.
461-
The keys can be any keys in supported by OCI JobConfigurationDetails and JobRun, including:
462+
The keys can be any keys in supported by OCI JobConfigurationDetails, OcirContainerJobEnvironmentConfigurationDetails and JobRun, including:
462463
* hyperparameter_values: dict(str, str)
463464
* environment_variables: dict(str, str)
464465
* command_line_arguments: str
465466
* maximum_runtime_in_minutes: int
466467
* display_name: str
467468
* freeform_tags: dict(str, str)
468469
* defined_tags: dict(str, dict(str, object))
470+
* image: str
471+
* cmd: list[str]
472+
* entrypoint: list[str]
473+
* image_digest: str
474+
* image_signature_id: str
469475
470476
If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
471477
@@ -478,14 +484,28 @@ def run(self, **kwargs) -> DataScienceJobRun:
478484
if not self.id:
479485
self.create()
480486

481-
swagger_types = (
487+
config_swagger_types = (
482488
oci.data_science.models.DefaultJobConfigurationDetails().swagger_types.keys()
483489
)
490+
env_config_swagger_types = {}
491+
if hasattr(oci.data_science.models, "OcirContainerJobEnvironmentConfigurationDetails"):
492+
env_config_swagger_types = (
493+
oci.data_science.models.OcirContainerJobEnvironmentConfigurationDetails().swagger_types.keys()
494+
)
484495
config_kwargs = {}
496+
env_config_kwargs = {}
485497
keys = list(kwargs.keys())
486498
for key in keys:
487-
if key in swagger_types:
499+
if key in config_swagger_types:
488500
config_kwargs[key] = kwargs.pop(key)
501+
elif key in env_config_swagger_types:
502+
value = kwargs.pop(key)
503+
if key in [
504+
ContainerRuntime.CONST_CMD,
505+
ContainerRuntime.CONST_ENTRYPOINT
506+
] and isinstance(value, str):
507+
value = ContainerRuntimeHandler.split_args(value)
508+
env_config_kwargs[key] = value
489509

490510
# remove timestamp from the job name (added in default names, when display_name not specified by user)
491511
if self.display_name:
@@ -514,6 +534,12 @@ def run(self, **kwargs) -> DataScienceJobRun:
514534
config_override.update(config_kwargs)
515535
kwargs["job_configuration_override_details"] = config_override
516536

537+
if env_config_kwargs:
538+
env_config_kwargs["jobEnvironmentType"] = "OCIR_CONTAINER"
539+
env_config_override = kwargs.get("job_environment_configuration_override_details", {})
540+
env_config_override.update(env_config_kwargs)
541+
kwargs["job_environment_configuration_override_details"] = env_config_override
542+
517543
wait = kwargs.pop("wait", False)
518544
run = DataScienceJobRun(**kwargs, **self.auth).create()
519545
if wait:

0 commit comments

Comments
 (0)