Commit de655f7
Odsc 44980: fix unittests (#259)
2 parents: 90f2e30 + f6d6fe9

29 files changed: +2122 additions, -143 deletions

.gitleaks.toml

Lines changed: 4 additions & 2 deletions
@@ -9,11 +9,13 @@ useDefault = true
 # Paths listed in allowlist will not be scanned.
 [allowlist]
 description = "Global allow list"
-stopwords = ["test_password", "sample_key"]
 regexes = [
     '''example-password''',
     '''this-is-not-the-secret''',
-    '''<redacted>'''
+    '''<redacted>''',
+    # NVIDIA_GPGKEY_SUM from public documentation:
+    # https://gitlab.com/nvidia/container-images/cuda/-/blob/master/dist/10.1/centos7/base/Dockerfile
+    '''d0664fbbdb8c32356d45de36c5984617217b2d0bef41b93ccecd326ba3b80c87'''
 ]
 paths = [
     '''tests/integration/tests_configs.yaml'''

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ repos:
       - id: gitleaks
   # Oracle copyright checker
   - repo: https://github.com/oracle-samples/oci-data-science-ai-samples/
-    rev: cbe0136
+    rev: cbe0136f7aaffe463b31ddf3f34b0e16b4b124ff
     hooks:
       - id: check-copyright
         name: check-copyright

ads/common/oci_client.py

Lines changed: 6 additions & 3 deletions
@@ -17,8 +17,6 @@
 from oci.resource_search import ResourceSearchClient
 from oci.secrets import SecretsClient
 from oci.vault import VaultsClient
-from oci.feature_store import FeatureStoreClient
-
 logger = logging.getLogger(__name__)
 
 
@@ -65,10 +63,15 @@ def _client_impl(self, client):
             "ai_language": AIServiceLanguageClient,
             "data_labeling_dp": DataLabelingClient,
             "data_labeling_cp": DataLabelingManagementClient,
-            "feature_store": FeatureStoreClient,
             "resource_search": ResourceSearchClient,
             "data_catalog": DataCatalogClient
         }
+        try:
+            from oci.feature_store import FeatureStoreClient
+            client_map["feature_store"] = FeatureStoreClient
+        except ImportError:
+            logger.warning("OCI SDK with feature store support is not installed")
+            pass
 
         assert (
             client in client_map

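The change above removes the module-level FeatureStoreClient import and registers the client only when the installed OCI SDK actually ships oci.feature_store, so a missing optional module degrades to a warning instead of breaking every ADS import. A minimal, self-contained sketch of the same optional-dependency pattern (the build_client_map helper is illustrative, not part of ADS):

import logging

logger = logging.getLogger(__name__)


def build_client_map() -> dict:
    """Illustrative helper: register optional OCI clients only if importable."""
    client_map = {}  # mandatory clients would be added here unconditionally
    try:
        # oci.feature_store only exists in OCI SDK builds with feature store support.
        from oci.feature_store import FeatureStoreClient

        client_map["feature_store"] = FeatureStoreClient
    except ImportError:
        logger.warning("OCI SDK with feature store support is not installed")
    return client_map
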
ads/common/oci_mixin.py

Lines changed: 3 additions & 3 deletions
@@ -230,7 +230,7 @@ def _parse_kwargs(attribute_map: dict, **kwargs):
 
         return parsed_kwargs
 
-    @classmethod
+    @class_or_instance_method
     def deserialize(cls, data, to_cls):
         """De-serialize data from dictionary to an OCI model"""
         if cls.type_mappings is None:
@@ -549,7 +549,7 @@ def from_dict(cls, data):
         """
         return cls.create_instance(**data)
 
-    @classmethod
+    @class_or_instance_method
     def deserialize(cls, data: dict, to_cls: str = None):
         """Deserialize data
 
@@ -726,7 +726,7 @@ def update_from_oci_model(
         for attr in self.swagger_types.keys():
             if (
                 hasattr(oci_model_instance, attr)
-                and getattr(oci_model_instance, attr) is not None
+                and getattr(oci_model_instance, attr)
                 and (
                     not hasattr(self, attr)
                     or not getattr(self, attr)

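Both deserialize methods switch from @classmethod to the class_or_instance_method decorator, so the same method can bind either to the class or to a concrete instance; the last hunk also relaxes the copy condition from "is not None" to plain truthiness, so empty values from the OCI model are skipped as well. The decorator itself is defined elsewhere in ADS and is not shown in this diff; below is a minimal descriptor sketch of the general idea, not the ADS implementation:

import functools


class class_or_instance_method:
    """Sketch: bind the wrapped function to the instance when accessed on an
    instance, otherwise to the class."""

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, objtype=None):
        owner = obj if obj is not None else objtype
        return functools.partial(self.func, owner)


class Model:
    @class_or_instance_method
    def deserialize(cls_or_self, data):
        # Receives the class when called as Model.deserialize(...),
        # and the instance when called as Model().deserialize(...).
        return cls_or_self, data


Model.deserialize({"a": 1})    # first element of the result is the class
Model().deserialize({"a": 1})  # first element of the result is the instance
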
ads/common/serializer.py

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,7 @@
 import yaml
 
 from ads.common import logger
+from ads.common.auth import default_signer
 
 try:
     from yaml import CSafeDumper as dumper
@@ -134,6 +135,14 @@ def _read_from_file(uri: str, **kwargs) -> str:
         -------
         string: Contents in file specified by URI
         """
+        # Add default signer if the uri is an object storage uri, and
+        # the user does not specify config or signer.
+        if (
+            uri.startswith("oci://")
+            and "config" not in kwargs
+            and "signer" not in kwargs
+        ):
+            kwargs.update(default_signer())
         with fsspec.open(uri, "r", **kwargs) as f:
             return f.read()
 

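The added guard injects ADS's default signer only when the URI points at OCI Object Storage ("oci://...") and the caller has not already passed a config or signer through kwargs, so explicit authentication always wins. A standalone sketch of the same guard (read_text is a hypothetical helper, not an ADS function):

import fsspec

from ads.common.auth import default_signer


def read_text(uri: str, **kwargs) -> str:
    """Hypothetical helper mirroring the guard added in _read_from_file."""
    # Fall back to the default signer only for Object Storage URIs where the
    # caller has not chosen an authentication mechanism explicitly.
    if uri.startswith("oci://") and "config" not in kwargs and "signer" not in kwargs:
        kwargs.update(default_signer())
    with fsspec.open(uri, "r", **kwargs) as f:
        return f.read()
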
ads/jobs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8; -*-
 
-# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
+# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -17,6 +17,7 @@
     DataFlowRuntime,
     DataFlowNotebookRuntime,
 )
+from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.ads_job import Job
 from ads.jobs.builders import infrastructure
@@ -44,6 +45,7 @@
     "NotebookRuntime",
     "ScriptRuntime",
     "ContainerRuntime",
+    "PyTorchDistributedRuntime",
     "DataFlow",
     "DataFlowRun",
     "DataFlowRuntime",

ads/jobs/ads_job.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from ads.jobs.builders.base import Builder
 from ads.jobs.builders.infrastructure.dataflow import DataFlow, DataFlowRun
 from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob, DataScienceJobRun
+from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.builders.runtimes.python_runtime import (
     DataFlowRuntime,
@@ -131,6 +132,7 @@ class Job(Builder):
             ContainerRuntime,
             ScriptRuntime,
             NotebookRuntime,
+            PyTorchDistributedRuntime,
             DataFlowRuntime,
         ]
     }

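Adding PyTorchDistributedRuntime to this mapping appears to register the new runtime with Job's deserialization table, so specs that reference it can be loaded like any other runtime. A hypothetical usage sketch; the builder methods with_replica and with_command are assumptions inferred from the constants introduced in this commit, not confirmed by the diff itself:

from ads.jobs import DataScienceJob, Job, PyTorchDistributedRuntime

# Hypothetical example: exact builder method names are assumptions.
job = (
    Job(name="pytorch-ddp-example")
    .with_infrastructure(
        DataScienceJob()
        .with_compartment_id("ocid1.compartment.oc1..<unique_id>")
        .with_project_id("ocid1.datascienceproject.oc1..<unique_id>")
    )
    .with_runtime(
        PyTorchDistributedRuntime()
        .with_replica(2)                    # becomes OCI__WORKER_COUNT = replica - 1
        .with_command("torchrun train.py")  # stored as OCI__LAUNCH_CMD
    )
)
# job.create() followed by job.run() would submit the distributed job.
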
ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 25 additions & 9 deletions
@@ -35,7 +35,11 @@
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
 from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
 
-from ads.common.dsc_file_system import OCIFileStorage, DSCFileSystemManager, OCIObjectStorage
+from ads.common.dsc_file_system import (
+    OCIFileStorage,
+    DSCFileSystemManager,
+    OCIObjectStorage,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -1454,11 +1458,14 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
             if value:
                 dsc_job.job_infrastructure_configuration_details[camel_attr] = value
 
-        if (
-            not dsc_job.job_infrastructure_configuration_details.get("shapeName", "").endswith("Flex")
-            and dsc_job.job_infrastructure_configuration_details.get("jobShapeConfigDetails")
+        if not dsc_job.job_infrastructure_configuration_details.get(
+            "shapeName", ""
+        ).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get(
+            "jobShapeConfigDetails"
         ):
-            raise ValueError("Shape config is not required for non flex shape from user end.")
+            raise ValueError(
+                "Shape config is not required for non flex shape from user end."
+            )
 
         if dsc_job.job_infrastructure_configuration_details.get("subnetId"):
             dsc_job.job_infrastructure_configuration_details[
@@ -1495,7 +1502,10 @@ def init(self) -> DataScienceJob:
             self.build()
             .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
             .with_project_id(self.project_id or "{Provide a project OCID}")
-            .with_subnet_id(self.subnet_id or "{Provide a subnet OCID or remove this field if you use a default networking}")
+            .with_subnet_id(
+                self.subnet_id
+                or "{Provide a subnet OCID or remove this field if you use a default networking}"
+            )
         )
 
     def create(self, runtime, **kwargs) -> DataScienceJob:
@@ -1552,7 +1562,7 @@ def run(
         freeform_tags=None,
         defined_tags=None,
         wait=False,
-        **kwargs
+        **kwargs,
     ) -> DataScienceJobRun:
         """Runs a job on OCI Data Science job
 
@@ -1603,15 +1613,21 @@
         envs.update(env_var)
         name = Template(name).safe_substitute(envs)
 
-        return self.dsc_job.run(
+        kwargs = dict(
             display_name=name,
             command_line_arguments=args,
             environment_variables=env_var,
             freeform_tags=freeform_tags,
             defined_tags=defined_tags,
             wait=wait,
-            **kwargs
+            **kwargs,
         )
+        # A Runtime class may define customized run() method.
+        # Use the customized method if the run() method is defined by the runtime.
+        # Otherwise, use the default run() method defined in this class.
+        if hasattr(self.runtime, "run"):
+            return self.runtime.run(self.dsc_job, **kwargs)
+        return self.dsc_job.run(**kwargs)
 
     def delete(self) -> None:
         """Deletes a job"""

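The run() refactor first collects the default job-run arguments into kwargs and then dispatches: if the runtime object defines its own run() method, that takes precedence (presumably so the new PyTorch distributed runtime can create its job runs differently); otherwise the single dsc_job.run() call is used as before. A self-contained sketch of that dispatch pattern with stand-in classes (StubDSCJob and MultiNodeRuntime are illustrative only):

class StubDSCJob:
    """Stand-in for DSCJob, used only to show the dispatch."""

    def run(self, **kwargs):
        return f"single run: {kwargs.get('display_name')}"


class MultiNodeRuntime:
    """Illustrative runtime that overrides job-run creation."""

    def run(self, dsc_job, **kwargs):
        # e.g. start one job run per replica instead of a single run
        return [dsc_job.run(**kwargs) for _ in range(2)]


def create_run(runtime, dsc_job, **kwargs):
    # Same dispatch as the change above: prefer the runtime's own run().
    if hasattr(runtime, "run"):
        return runtime.run(dsc_job, **kwargs)
    return dsc_job.run(**kwargs)


print(create_run(object(), StubDSCJob(), display_name="job-a"))            # default path
print(create_run(MultiNodeRuntime(), StubDSCJob(), display_name="job-b"))  # runtime path
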
ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 107 additions & 4 deletions
@@ -29,13 +29,19 @@
     GitPythonRuntime,
 )
 from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
+from ads.jobs.builders.runtimes.pytorch_runtime import (
+    PyTorchDistributedRuntime,
+    PyTorchDistributedArtifact,
+)
 from ads.jobs.builders.runtimes.artifact import (
     ScriptArtifact,
     NotebookArtifact,
     PythonArtifact,
     GitPythonArtifact,
 )
+from ads.opctl.distributed.common import cluster_config_helper
 from ads.jobs.builders.infrastructure.utils import get_value
+from ads.jobs.templates import driver_utils
 
 
 class IncompatibleRuntime(Exception):  # pragma: no cover
@@ -184,7 +190,7 @@ def _translate_config(self, runtime: Runtime) -> dict:
         if runtime.args:
             # shlex.join() is not available until python 3.8
             job_configuration_details["command_line_arguments"] = " ".join(
-                shlex.quote(arg) for arg in runtime.get_spec(runtime.CONST_ARGS)
+                shlex.quote(str(arg)) for arg in runtime.get_spec(runtime.CONST_ARGS)
             )
         return job_configuration_details
 
@@ -653,7 +659,7 @@ def _translate_env(self, runtime: PythonRuntime) -> dict:
 
         if runtime.entrypoint:
             envs[self.CONST_CODE_ENTRYPOINT] = runtime.entrypoint
-        else:
+        elif runtime.script_uri:
             envs[self.CONST_CODE_ENTRYPOINT] = os.path.basename(runtime.script_uri)
 
         envs[self.CONST_JOB_ENTRYPOINT] = PythonArtifact.CONST_DRIVER_SCRIPT
@@ -674,9 +680,13 @@ def _extract_envs(self, dsc_job) -> dict:
         """
         spec = super()._extract_envs(dsc_job)
         envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
-        if self.CONST_CODE_ENTRYPOINT not in envs:
+        if (
+            self.__class__ == PythonRuntimeHandler
+            and self.CONST_CODE_ENTRYPOINT not in envs
+        ):
             raise IncompatibleRuntime()
-        envs.pop(PythonRuntimeHandler.CONST_JOB_ENTRYPOINT)
+        # PyTorchDistributedRuntime does not require entrypoint.
+        envs.pop(PythonRuntimeHandler.CONST_JOB_ENTRYPOINT, None)
         spec.update(self._extract_specs(envs, self.SPEC_MAPPINGS))
         if PythonRuntime.CONST_PYTHON_PATH in spec:
             spec[PythonRuntime.CONST_PYTHON_PATH] = spec[
@@ -1035,6 +1045,98 @@ def _extract_envs(self, dsc_job):
         return spec
 
 
+class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
+    RUNTIME_CLASS = PyTorchDistributedRuntime
+    CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
+    CONST_COMMAND = "OCI__LAUNCH_CMD"
+    CONST_DEEPSPEED = "OCI__DEEPSPEED"
+
+    GIT_SPEC_MAPPINGS = {
+        cluster_config_helper.OCI__RUNTIME_URI: GitPythonRuntime.CONST_GIT_URL,
+        cluster_config_helper.OCI__RUNTIME_GIT_BRANCH: GitPythonRuntime.CONST_BRANCH,
+        cluster_config_helper.OCI__RUNTIME_GIT_COMMIT: GitPythonRuntime.CONST_COMMIT,
+        cluster_config_helper.OCI__RUNTIME_GIT_SECRET_ID: GitPythonRuntime.CONST_GIT_SSH_SECRET_ID,
+    }
+
+    SPEC_MAPPINGS = PythonRuntimeHandler.SPEC_MAPPINGS
+    SPEC_MAPPINGS.update(
+        {
+            PyTorchDistributedRuntime.CONST_COMMAND: CONST_COMMAND,
+        }
+    )
+
+    def _translate_artifact(self, runtime: PyTorchDistributedRuntime):
+        return PyTorchDistributedArtifact(runtime.source_uri, runtime)
+
+    def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
+        envs = super()._translate_env(runtime)
+        replica = runtime.replica if runtime.replica else 1
+        # WORKER_COUNT = REPLICA - 1 so that it will be same as distributed training
+        envs[self.CONST_WORKER_COUNT] = str(replica - 1)
+        envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
+        if runtime.inputs:
+            envs[driver_utils.CONST_ENV_INPUT_MAPPINGS] = json.dumps(runtime.inputs)
+        if runtime.git:
+            for env_key, spec_key in self.GIT_SPEC_MAPPINGS.items():
+                if not runtime.git.get(spec_key):
+                    continue
+                envs[env_key] = runtime.git[spec_key]
+        if runtime.dependencies:
+            if PyTorchDistributedRuntime.CONST_PIP_PKG in runtime.dependencies:
+                envs[driver_utils.CONST_ENV_PIP_PKG] = runtime.dependencies[
+                    PyTorchDistributedRuntime.CONST_PIP_PKG
+                ]
+            if PyTorchDistributedRuntime.CONST_PIP_REQ in runtime.dependencies:
+                envs[driver_utils.CONST_ENV_PIP_REQ] = runtime.dependencies[
+                    PyTorchDistributedRuntime.CONST_PIP_REQ
+                ]
+        if runtime.use_deepspeed:
+            envs[self.CONST_DEEPSPEED] = "1"
+        return envs
+
+    def _extract_envs(self, dsc_job) -> dict:
+        spec = super()._extract_envs(dsc_job)
+        envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
+        if self.CONST_WORKER_COUNT not in envs:
+            raise IncompatibleRuntime()
+        # Replicas
+        spec[PyTorchDistributedRuntime.CONST_REPLICA] = (
+            int(envs.pop(self.CONST_WORKER_COUNT)) + 1
+        )
+        # Git
+        if cluster_config_helper.OCI__RUNTIME_URI in envs:
+            git_spec = {}
+            for env_key, spec_key in self.GIT_SPEC_MAPPINGS.items():
+                if env_key in envs:
+                    git_spec[spec_key] = envs.pop(env_key)
+            spec[PyTorchDistributedRuntime.CONST_GIT] = git_spec
+        # Inputs
+        input_mappings = envs.pop(driver_utils.CONST_ENV_INPUT_MAPPINGS, None)
+        if input_mappings:
+            try:
+                spec[PyTorchDistributedRuntime.CONST_INPUT] = json.loads(input_mappings)
+            except ValueError:
+                spec[PyTorchDistributedRuntime.CONST_INPUT] = input_mappings
+        # Dependencies
+        dep = {}
+        if driver_utils.CONST_ENV_PIP_PKG in envs:
+            dep[PyTorchDistributedRuntime.CONST_PIP_PKG] = envs.pop(
+                driver_utils.CONST_ENV_PIP_PKG
+            )
+        if driver_utils.CONST_ENV_PIP_REQ in envs:
+            dep[PyTorchDistributedRuntime.CONST_PIP_REQ] = envs.pop(
+                driver_utils.CONST_ENV_PIP_REQ
+            )
+        if dep:
+            spec[PyTorchDistributedRuntime.CONST_DEP] = dep
+        if envs.pop(self.CONST_DEEPSPEED, None):
+            spec[PyTorchDistributedRuntime.CONST_DEEPSPEED] = True
+        # Envs
+        if envs:
+            spec[PythonRuntime.CONST_ENV_VAR] = envs
+        return spec
+
+
 class DataScienceJobRuntimeManager(RuntimeHandler):
     """This class is used by the DataScienceJob infrastructure to handle the runtime conversion.
     The translate() method determines the actual runtime handler by matching the RUNTIME_CLASS.
@@ -1046,6 +1148,7 @@ class DataScienceJobRuntimeManager(RuntimeHandler):
 
     runtime_handlers = [
        ContainerRuntimeHandler,
+       PyTorchDistributedRuntimeHandler,
        GitPythonRuntimeHandler,
        NotebookRuntimeHandler,
        PythonRuntimeHandler,

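The new PyTorchDistributedRuntimeHandler serializes the runtime spec into job environment variables on creation (_translate_env) and reverses the mapping when an existing job is loaded (_extract_envs); note the off-by-one convention OCI__WORKER_COUNT = replica - 1, which excludes the main node. A self-contained sketch of that round trip for a few of the fields, with simplified spec keys standing in for the runtime constants:

CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
CONST_LAUNCH_CMD = "OCI__LAUNCH_CMD"
CONST_DEEPSPEED = "OCI__DEEPSPEED"


def translate_env(spec: dict) -> dict:
    """Encode a simplified runtime spec into job environment variables."""
    replica = spec.get("replica") or 1
    envs = {
        # The worker count excludes the main node, hence replica - 1.
        CONST_WORKER_COUNT: str(replica - 1),
        CONST_LAUNCH_CMD: spec.get("command", ""),
    }
    if spec.get("use_deepspeed"):
        envs[CONST_DEEPSPEED] = "1"
    return envs


def extract_envs(envs: dict) -> dict:
    """Rebuild the simplified spec from the environment variables."""
    if CONST_WORKER_COUNT not in envs:
        raise ValueError("not a PyTorch distributed job")  # mirrors IncompatibleRuntime
    spec = {
        "replica": int(envs[CONST_WORKER_COUNT]) + 1,
        "command": envs.get(CONST_LAUNCH_CMD, ""),
    }
    if envs.get(CONST_DEEPSPEED):
        spec["use_deepspeed"] = True
    return spec


example = {"replica": 3, "command": "torchrun train.py", "use_deepspeed": True}
assert extract_envs(translate_env(example)) == example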