Skip to content

Commit 3eb3240

Browse files
authored
ODSC-50353: Allow to override the config values in the operator specification (#456)
2 parents 01b783a + 5531a79 commit 3eb3240

File tree

8 files changed

+343
-7
lines changed

8 files changed

+343
-7
lines changed

ads/opctl/config/merger.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,10 @@ def _fill_config_with_defaults(self, ads_config_path: str) -> None:
124124
exec_config.get("auth") or AuthType.API_KEY
125125
)
126126
# determine profile
127-
if self.config["execution"]["auth"] != AuthType.API_KEY:
127+
if self.config["execution"]["auth"] in (
128+
AuthType.RESOURCE_PRINCIPAL,
129+
AuthType.INSTANCE_PRINCIPAL,
130+
):
128131
profile = self.config["execution"]["auth"].upper()
129132
exec_config.pop("oci_profile", None)
130133
self.config["execution"]["oci_profile"] = None

ads/opctl/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
# OPERATOR
3333
OPERATOR_MODULE_PATH = "ads.opctl.operator.lowcode"
3434
OPERATOR_IMAGE_WORK_DIR = "/etc/operator"
35+
OVERRIDE_KWARGS = "override_kwargs"
3536

3637

3738
class RUNTIME_TYPE(ExtendedEnum):

ads/opctl/decorator/common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
from functools import wraps
88
from typing import Callable, Dict, List
99

10+
import click
11+
1012
from ads.common.auth import AuthContext
1113
from ads.opctl import logger
1214
from ads.opctl.config.base import ConfigProcessor
1315
from ads.opctl.config.merger import ConfigMerger
16+
from ads.opctl.constants import OVERRIDE_KWARGS
1417

1518
RUN_ID_FIELD = "run_id"
1619

@@ -101,3 +104,26 @@ def wrapper(*args, **kwargs) -> Dict:
101104
return func(*args, **kwargs)
102105

103106
return wrapper
107+
108+
109+
def with_click_unknown_args(func: Callable) -> Callable:
110+
"""The decorator to parse the click unknown arguments and put them into kwargs."""
111+
112+
@wraps(func)
113+
def wrapper(*args, **kwargs) -> Dict:
114+
kwargs[OVERRIDE_KWARGS] = {}
115+
try:
116+
click_context = next(
117+
item for item in args if isinstance(item, click.core.Context)
118+
)
119+
kwargs[OVERRIDE_KWARGS] = {
120+
key[2:]: value
121+
for key, value in zip(click_context.args[::2], click_context.args[1::2])
122+
}
123+
except Exception as ex:
124+
logger.debug(ex)
125+
pass
126+
127+
return func(*args, **kwargs)
128+
129+
return wrapper

ads/opctl/operator/cli.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ads.common.auth import AuthType
1414
from ads.common.object_storage_details import ObjectStorageDetails
1515
from ads.opctl.constants import BACKEND_NAME, RUNTIME_TYPE
16-
from ads.opctl.decorator.common import click_options, with_auth
16+
from ads.opctl.decorator.common import click_options, with_auth, with_click_unknown_args
1717
from ads.opctl.utils import suppress_traceback
1818

1919
from .__init__ import __operators__
@@ -271,7 +271,12 @@ def publish_conda(debug: bool, **kwargs: Dict[str, Any]) -> None:
271271
suppress_traceback(debug)(cmd_publish_conda)(**kwargs)
272272

273273

274-
@commands.command()
274+
@commands.command(
275+
context_settings=dict(
276+
ignore_unknown_options=True,
277+
allow_extra_args=True,
278+
)
279+
)
275280
@click_options(DEBUG_OPTION + ADS_CONFIG_OPTION + AUTH_TYPE_OPTION)
276281
@click.option(
277282
"--file",
@@ -303,8 +308,10 @@ def publish_conda(debug: bool, **kwargs: Dict[str, Any]) -> None:
303308
is_flag=True,
304309
help="During dry run, the actual operation is not performed, only the steps are enumerated.",
305310
)
311+
@click.pass_context
312+
@with_click_unknown_args
306313
@with_auth
307-
def run(debug: bool, **kwargs: Dict[str, Any]) -> None:
314+
def run(ctx: click.core.Context, debug: bool, **kwargs: Dict[str, Any]) -> None:
308315
"""
309316
Runs the operator with the given specification on the targeted backend.
310317
"""

ads/opctl/operator/cmd.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,9 @@ def create(
565565
raise NotImplementedError()
566566

567567

568-
def run(config: Dict, backend: Union[Dict, str] = None, **kwargs) -> None:
568+
def run(
569+
config: Dict, backend: Union[Dict, str] = None, **kwargs: Dict[str, Any]
570+
) -> None:
569571
"""
570572
Runs the operator with the given specification on the targeted backend.
571573
@@ -575,7 +577,7 @@ def run(config: Dict, backend: Union[Dict, str] = None, **kwargs) -> None:
575577
The operator's config.
576578
backend: (Union[Dict, str], optional)
577579
The backend config or backend name to run the operator.
578-
kwargs: (Dict, optional)
580+
kwargs: (Dict[str, Any], optional)
579581
Optional key value arguments to run the operator.
580582
"""
581583
BackendFactory.backend(

ads/opctl/operator/common/backend_factory.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
from ads.opctl.constants import (
2525
BACKEND_NAME,
2626
DEFAULT_ADS_CONFIG_FOLDER,
27+
OVERRIDE_KWARGS,
2728
RESOURCE_TYPE,
2829
RUNTIME_TYPE,
2930
)
3031
from ads.opctl.operator.common.const import PACK_TYPE
32+
from ads.opctl.operator.common.dictionary_merger import DictionaryMerger
3133
from ads.opctl.operator.common.operator_loader import OperatorInfo, OperatorLoader
3234

3335

@@ -202,7 +204,10 @@ def backend(
202204
{**backend, **{"execution": {"backend": backend_kind}}}
203205
).step(ConfigMerger, **kwargs)
204206

205-
config.config["runtime"] = backend
207+
# merge backend with the override parameters
208+
config.config["runtime"] = DictionaryMerger(
209+
updates=kwargs.get(OVERRIDE_KWARGS)
210+
).merge(backend)
206211
config.config["infrastructure"] = p_backend.config["infrastructure"]
207212
config.config["execution"] = p_backend.config["execution"]
208213

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import copy
8+
from typing import Any, Dict, List
9+
10+
11+
class DictionaryMerger:
12+
"""
13+
A class to update dictionary values for specified keys and
14+
then merge these updates back into the original dictionary.
15+
16+
Example
17+
-------
18+
>>> updates = {
19+
... "infrastructure.blockStorageSize": "20",
20+
... "infrastructure.projectId": "my_new_project_id"
21+
... "runtime.conda": "my_conda"
22+
... }
23+
>>> updater = DictionaryMerger(updates)
24+
>>> source_data = {
25+
... "infrastructure": {
26+
... "blockStorageSize": "10",
27+
... "projectId": "old_project_id",
28+
... },
29+
... "runtime": {
30+
... "conda": "conda",
31+
... },
32+
... }
33+
>>> result = updater.dispatch(source_data)
34+
... {
35+
... "infrastructure": {
36+
... "blockStorageSize": "20",
37+
... "projectId": "my_new_project_id",
38+
... },
39+
... "runtime": {
40+
... "conda": "my_conda",
41+
... },
42+
... }
43+
44+
Attributes
45+
----------
46+
updates: Dict[str, Any]
47+
A dictionary containing the keys with their new values for the update.
48+
"""
49+
50+
_SYSTEM_KEYS = set(("kind", "type", "spec", "infrastructure", "runtime"))
51+
52+
def __init__(self, updates: Dict[str, Any], system_keys: List[str] = None):
53+
"""
54+
Initializes the DictionaryMerger with a dictionary of updates.
55+
56+
Parameters
57+
----------
58+
updates Dict[str, Any]
59+
A dictionary with keys that need to be updated and their new values.
60+
system_keys: List[str]
61+
The list of keys that cannot be replaced in the source dictionary.
62+
"""
63+
self.updates = updates
64+
self.system_keys = set(system_keys or []).union(self._SYSTEM_KEYS)
65+
66+
def _update_keys(
67+
self, dict_to_update: Dict[str, Any], parent_key: str = ""
68+
) -> None:
69+
"""
70+
Recursively updates the values of given keys in a dictionary.
71+
72+
Parameters
73+
----------
74+
dict_to_update: Dict[str, Any]
75+
The dictionary whose values are to be updated.
76+
parent_key: (str, optional)
77+
The current path in the dictionary being processed, used for nested dictionaries.
78+
79+
Returns
80+
-------
81+
None
82+
The method updates the dict_to_update in place.
83+
"""
84+
for key, value in dict_to_update.items():
85+
new_key = f"{parent_key}.{key}" if parent_key else key
86+
if isinstance(value, dict):
87+
self._update_keys(value, new_key)
88+
elif new_key in self.updates and key not in self.system_keys:
89+
dict_to_update[key] = self.updates[new_key]
90+
91+
def _merge_updates(
92+
self,
93+
original_dict: Dict[str, Any],
94+
updated_dict: Dict[str, Any],
95+
parent_key: str = "",
96+
) -> None:
97+
"""
98+
Merges updated values from the updated_dict into the original_dict based on the provided keys.
99+
100+
Parameters
101+
----------
102+
original_dict: Dict[str, Any]
103+
The original dictionary to merge updates into.
104+
updated_dict: Dict[str, Any]
105+
The updated dictionary with new values.
106+
parent_key: str
107+
The base key path for recursive merging.
108+
109+
Returns
110+
-------
111+
None
112+
The method updates the original_dict in place.
113+
"""
114+
for key, value in updated_dict.items():
115+
new_key = f"{parent_key}.{key}" if parent_key else key
116+
if isinstance(value, dict) and key in original_dict:
117+
self._merge_updates(original_dict[key], value, new_key)
118+
elif new_key in self.updates:
119+
original_dict[key] = value
120+
121+
def merge(self, src_dict: Dict[str, Any]) -> Dict[str, Any]:
122+
"""
123+
Updates the dictionary with new values for specified keys and merges
124+
these changes back into the original dictionary.
125+
126+
Parameters
127+
----------
128+
src_dict: Dict[str, Any]
129+
The dictionary to be updated and merged.
130+
131+
Returns
132+
-------
133+
Dict[str, Any]
134+
The updated and merged dictionary.
135+
"""
136+
if not self.updates:
137+
return src_dict
138+
139+
original_dict = copy.deepcopy(src_dict)
140+
updated_dict = copy.deepcopy(src_dict)
141+
142+
# Update the dictionary with the new values
143+
self._update_keys(updated_dict)
144+
145+
# Merge the updates back into the original dictionary
146+
self._merge_updates(original_dict, updated_dict)
147+
148+
return original_dict

0 commit comments

Comments
 (0)