Skip to content

Commit 99946dc

Browse files
authored
Fix Jobs API independent authentication issue for different objects/threads. (#742)
2 parents c5b657a + 02d78f1 commit 99946dc

File tree

4 files changed

+107
-21
lines changed

4 files changed

+107
-21
lines changed

ads/jobs/ads_job.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import fsspec
1212
import oci
13+
import yaml
1314
from ads.common.auth import default_signer
1415
from ads.common.decorator.utils import class_or_instance_method
1516
from ads.jobs.builders.base import Builder
@@ -263,6 +264,9 @@ def __init__(
263264
Job runtime, by default None.
264265
265266
"""
267+
# Saves a copy of the auth object from the class to the instance.
268+
# Future changes to the class level Job.auth will not affect the auth of existing instances.
269+
self.auth = self.auth.copy()
266270
for key in ["config", "signer", "client_kwargs"]:
267271
if kwargs.get(key):
268272
self.auth[key] = kwargs.pop(key)
@@ -545,6 +549,26 @@ def to_dict(self, **kwargs: Dict) -> Dict:
545549
"spec": spec,
546550
}
547551

552+
@class_or_instance_method
def from_yaml(
    cls,
    yaml_string: str = None,
    uri: str = None,
    loader: callable = yaml.SafeLoader,
    **kwargs,
):
    """Initializes a job from a YAML string or a YAML file at a URI location.

    This works as both a class method and an instance method
    (via ``class_or_instance_method``). When called on an instance, the
    instance's ``auth`` configuration is propagated to the newly created job,
    so per-instance authentication (e.g. a custom endpoint) is preserved.
    When called on the class, the class-level ``auth`` is used.

    Parameters
    ----------
    yaml_string : str, optional
        YAML content defining the job, by default None.
    uri : str, optional
        URI location of a YAML file defining the job, by default None.
        Ignored when ``yaml_string`` is provided.
    loader : callable, optional
        YAML loader class, by default ``yaml.SafeLoader``.
    **kwargs
        Additional keyword arguments passed to ``_read_from_file`` when
        loading from ``uri`` (e.g. storage options for remote files).

    Returns
    -------
    Job
        A job instance initialized from the YAML specification.

    Raises
    ------
    ValueError
        If neither ``yaml_string`` nor ``uri`` is provided.
    """
    # cls is the class itself or an instance, depending on how the method
    # was invoked; construct the new job with the matching auth so that
    # instance-level auth overrides are not lost.
    if inspect.isclass(cls):
        job = cls(**cls.auth)
    else:
        job = cls.__class__(**cls.auth)

    if yaml_string:
        return job.from_dict(yaml.load(yaml_string, Loader=loader))
    if uri:
        yaml_dict = yaml.load(cls._read_from_file(uri=uri, **kwargs), Loader=loader)
        return job.from_dict(yaml_dict)
    raise ValueError("Must provide either YAML string or URI location")
571+
548572
@class_or_instance_method
549573
def from_dict(cls, config: dict) -> "Job":
550574
"""Initializes a job from a dictionary containing the configurations.
@@ -573,9 +597,9 @@ def from_dict(cls, config: dict) -> "Job":
573597
"runtime": cls._RUNTIME_MAPPING,
574598
}
575599
if inspect.isclass(cls):
576-
job = cls()
600+
job = cls(**cls.auth)
577601
else:
578-
job = cls.__class__()
602+
job = cls.__class__(**cls.auth)
579603

580604
for key, value in spec.items():
581605
if key in mappings:

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from __future__ import annotations
77

88
import datetime
9+
import inspect
910
import logging
10-
import oci
1111
import os
1212
import time
1313
import traceback
@@ -17,11 +17,12 @@
1717
from typing import Any, Dict, List, Optional, Union
1818

1919
import fsspec
20+
import oci
2021
import oci.data_science
2122
import oci.util as oci_util
22-
import yaml
2323
from oci.data_science.models import JobInfrastructureConfigurationDetails
2424
from oci.exceptions import ServiceError
25+
import yaml
2526
from ads.common import utils
2627
from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin
2728
from ads.common.oci_logging import OCILog
@@ -782,7 +783,7 @@ def to_yaml(self) -> str:
782783
# Update runtime from job run
783784
from ads.jobs import Job
784785

785-
job = Job.from_dict(job_dict)
786+
job = Job(**self.auth).from_dict(job_dict)
786787
envs = job.runtime.envs
787788
run_config_override = run_dict.get("jobConfigurationOverrideDetails", {})
788789
envs.update(run_config_override.get("environmentVariables", {}))
@@ -811,7 +812,7 @@ def job(self):
811812
"""
812813
from ads.jobs import Job
813814

814-
return Job.from_datascience_job(self.job_id)
815+
return Job(**self.auth).from_datascience_job(self.job_id)
815816

816817
def download(self, to_dir):
817818
"""Downloads files from job run output URI to local.
@@ -953,9 +954,9 @@ def standardize_spec(spec):
953954
if key not in attribute_map and key.lower() in snake_to_camel_map:
954955
value = spec.pop(key)
955956
if isinstance(value, dict):
956-
spec[
957-
snake_to_camel_map[key.lower()]
958-
] = DataScienceJob.standardize_spec(value)
957+
spec[snake_to_camel_map[key.lower()]] = (
958+
DataScienceJob.standardize_spec(value)
959+
)
959960
else:
960961
spec[snake_to_camel_map[key.lower()]] = value
961962
return spec
@@ -971,6 +972,9 @@ def __init__(self, spec: Dict = None, **kwargs) -> None:
971972
Specification as keyword arguments.
972973
If spec contains the same key as the one in kwargs, the value from kwargs will be used.
973974
"""
975+
# Saves a copy of the auth object from the class to the instance.
976+
# Future changes to the class level Job.auth will not affect the auth of existing instances.
977+
self.auth = self.auth.copy()
974978
for key in ["config", "signer", "client_kwargs"]:
975979
if kwargs.get(key):
976980
self.auth[key] = kwargs.pop(key)
@@ -1710,6 +1714,15 @@ def from_id(cls, job_id: str) -> DataScienceJob:
17101714
"""
17111715
return cls.from_dsc_job(DSCJob(**cls.auth).from_ocid(job_id))
17121716

1717+
@class_or_instance_method
def from_dict(cls, obj_dict: dict):
    """Initializes a DataScienceJob from a Python dictionary.

    This works as both a class method and an instance method
    (via ``class_or_instance_method``). The ``auth`` configuration of the
    class or instance it is called on is propagated to the newly created
    object, so authentication settings are not lost when deserializing.

    Parameters
    ----------
    obj_dict : dict
        Dictionary representation of the infrastructure. The specification
        is read from the ``"spec"`` key; a missing key yields an object
        with an empty spec.

    Returns
    -------
    DataScienceJob
        A new instance initialized from ``obj_dict["spec"]``.
    """
    # Resolve the concrete class whether called on the class or an instance.
    if inspect.isclass(cls):
        job_cls = cls
    else:
        job_cls = cls.__class__
    # Pass cls.auth through so the new object keeps the same signer/endpoint.
    return job_cls(spec=obj_dict.get("spec"), **cls.auth)
1725+
17131726
@class_or_instance_method
17141727
def list_jobs(cls, compartment_id: str = None, **kwargs) -> List[DataScienceJob]:
17151728
"""Lists all jobs in a compartment.

ads/pipeline/ads_pipeline_step.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
import copy
77
from typing import List
88

99
from ads.jobs import Job
1010
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob
1111
from ads.jobs.builders.runtimes.base import Runtime
12+
from ads.common.utils import get_random_name_for_resource
1213

1314
PIPELINE_STEP_KIND_TO_OCI_MAP = {
1415
"dataScienceJob": "ML_JOB",
@@ -43,7 +44,7 @@ class PipelineStep(Job):
4344

4445
def __init__(
4546
self,
46-
name: str,
47+
name: str = None,
4748
job_id: str = None,
4849
infrastructure=None,
4950
runtime=None,
@@ -174,7 +175,7 @@ def __init__(
174175

175176
super().__init__()
176177
if not name:
177-
raise ValueError("PipelineStep name must be specified.")
178+
name = get_random_name_for_resource()
178179
elif any(char in PIPELINE_STEP_RESTRICTED_CHAR_SET for char in name):
179180
raise ValueError(
180181
"PipelineStep name can not include any of the "
@@ -521,17 +522,15 @@ def to_dict(self) -> dict:
521522
dict_details["spec"][self.CONST_DESCRIPTION] = self.description
522523
if self.kind == "ML_JOB":
523524
if self.environment_variable:
524-
dict_details["spec"][self.CONST_ENVIRONMENT_VARIABLES] = (
525-
self.environment_variable
526-
)
525+
dict_details["spec"][
526+
self.CONST_ENVIRONMENT_VARIABLES
527+
] = self.environment_variable
527528
if self.argument:
528-
dict_details["spec"][self.CONST_COMMAND_LINE_ARGUMENTS] = (
529-
self.argument
530-
)
529+
dict_details["spec"][self.CONST_COMMAND_LINE_ARGUMENTS] = self.argument
531530
if self.maximum_runtime_in_minutes:
532-
dict_details["spec"][self.CONST_MAXIMUM_RUNTIME_IN_MINUTES] = (
533-
self.maximum_runtime_in_minutes
534-
)
531+
dict_details["spec"][
532+
self.CONST_MAXIMUM_RUNTIME_IN_MINUTES
533+
] = self.maximum_runtime_in_minutes
535534

536535
dict_details["spec"].pop(self.CONST_DEPENDS_ON, None)
537536

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Contains tests for Jobs API authentication."""
2+
3+
from unittest import TestCase
4+
from ads.jobs import Job
5+
6+
JOB_YAML = """
7+
kind: job
8+
apiVersion: v1.0
9+
spec:
10+
name: llama2
11+
infrastructure:
12+
kind: infrastructure
13+
spec:
14+
blockStorageSize: 256
15+
compartmentId: "ocid1.compartment.oc1..aaa"
16+
logGroupId: "ocid1.loggroup.oc1.iad.aaa"
17+
logId: "ocid1.log.oc1.iad.aaa"
18+
projectId: "ocid1.datascienceproject.oc1.iad.aaa"
19+
subnetId: "ocid1.subnet.oc1.iad.aaa"
20+
shapeName: VM.GPU.A10.2
21+
type: dataScienceJob
22+
runtime:
23+
kind: runtime
24+
type: pyTorchDistributed
25+
spec:
26+
replicas: 2
27+
conda:
28+
type: service
29+
slug: pytorch20_p39_gpu_v2
30+
command: >-
31+
torchrun examples/finetuning.py
32+
"""
33+
34+
35+
class JobsAuthTest(TestCase):
    """Contains tests for Jobs API authentication."""

    def test_auth_from_yaml(self):
        """Test using different endpoints for different jobs."""
        endpoint_auths = [
            {"client_kwargs": {"endpoint": "endpoint1.com"}},
            {"client_kwargs": {"endpoint": "endpoint2.com"}},
        ]
        # Jobs created with explicit auth should keep it after from_yaml,
        # on both the job itself and its infrastructure.
        for expected_auth in endpoint_auths:
            job = Job(**expected_auth).from_yaml(JOB_YAML)
            self.assertEqual(job.auth, expected_auth)
            self.assertEqual(job.infrastructure.auth, expected_auth)
        # Loading via the class (no instance-level auth) should leave auth empty.
        default_job = Job.from_yaml(JOB_YAML)
        self.assertEqual(default_job.auth, {})
        self.assertEqual(default_job.infrastructure.auth, {})

0 commit comments

Comments
 (0)