1
1
#!/usr/bin/env python
2
2
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
6
- from copy import deepcopy
7
6
from typing import Any , Dict , List , Optional
8
7
9
8
from pydantic import Field
10
9
11
10
from ads .aqua .config .utils .serializer import Serializable
12
11
13
12
14
- class ModelParamsOverrides (Serializable ):
15
- """Defines overrides for model parameters, including exclusions and additional inclusions."""
16
-
17
- exclude : Optional [List [str ]] = Field (default_factory = list )
18
- include : Optional [Dict [str , Any ]] = Field (default_factory = dict )
19
-
20
- class Config :
21
- extra = "ignore"
22
-
23
-
24
- class ModelParamsVersion (Serializable ):
25
- """Handles version-specific model parameter overrides."""
26
-
27
- overrides : Optional [ModelParamsOverrides ] = Field (
28
- default_factory = ModelParamsOverrides
29
- )
30
-
31
- class Config :
32
- extra = "ignore"
33
-
34
-
35
- class ModelParamsContainer (Serializable ):
36
- """Represents a container's model configuration, including tasks, defaults, and versions."""
37
-
38
- name : Optional [str ] = None
39
- default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
40
- versions : Optional [Dict [str , ModelParamsVersion ]] = Field (default_factory = dict )
41
-
42
- class Config :
43
- extra = "ignore"
44
-
45
-
46
- class InferenceParams (Serializable ):
47
- """Contains inference-related parameters with defaults."""
48
-
49
- class Config :
50
- extra = "allow"
51
-
52
-
53
- class InferenceContainer (Serializable ):
54
- """Represents the inference parameters specific to a container."""
55
-
56
- name : Optional [str ] = None
57
- params : Optional [Dict [str , Any ]] = Field (default_factory = dict )
58
-
59
- class Config :
60
- extra = "ignore"
61
-
62
-
63
- class ReportParams (Serializable ):
64
- """Handles the report-related parameters."""
65
-
66
- default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
67
-
68
- class Config :
69
- extra = "ignore"
70
-
71
-
72
- class InferenceParamsConfig (Serializable ):
73
- """Combines default inference parameters with container-specific configurations."""
74
-
75
- default : Optional [InferenceParams ] = Field (default_factory = InferenceParams )
76
- containers : Optional [List [InferenceContainer ]] = Field (default_factory = list )
77
-
78
- def get_merged_params (self , container_name : str ) -> InferenceParams :
79
- """
80
- Merges default inference params with those specific to the given container.
81
-
82
- Parameters
83
- ----------
84
- container_name (str): The name of the container.
85
-
86
- Returns
87
- -------
88
- InferenceParams: The merged inference parameters.
89
- """
90
- merged_params = self .default .to_dict ()
91
- for containers in self .containers :
92
- if containers .name .lower () == container_name .lower ():
93
- merged_params .update (containers .params or {})
94
- break
95
- return InferenceParams (** merged_params )
96
-
97
- class Config :
98
- extra = "ignore"
99
-
100
-
101
- class InferenceModelParamsConfig (Serializable ):
102
- """Encapsulates the model parameters for different containers."""
103
-
104
- default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
105
- containers : Optional [List [ModelParamsContainer ]] = Field (default_factory = list )
106
-
107
- def get_merged_model_params (
108
- self ,
109
- container_name : str ,
110
- version : Optional [str ] = None ,
111
- ) -> Dict [str , Any ]:
112
- """
113
- Gets the model parameters for a given container, version,
114
- merged with the defaults.
115
-
116
- Parameters
117
- ----------
118
- container_name (str): The name of the container.
119
- version (Optional[str]): The specific version of the container.
120
-
121
- Returns
122
- -------
123
- Dict[str, Any]: The merged model parameters.
124
- """
125
- params = deepcopy (self .default )
126
-
127
- for container in self .containers :
128
- if container .name .lower () == container_name .lower ():
129
- params .update (container .default )
130
-
131
- if version and version in container .versions :
132
- version_overrides = container .versions [version ].overrides
133
- if version_overrides :
134
- if version_overrides .include :
135
- params .update (version_overrides .include )
136
- if version_overrides .exclude :
137
- for key in version_overrides .exclude :
138
- params .pop (key , None )
139
- break
140
-
141
- return params
142
-
143
- class Config :
144
- extra = "ignore"
145
-
146
-
147
13
class ShapeFilterConfig (Serializable ):
148
14
"""Represents the filtering options for a specific shape."""
149
15
150
16
evaluation_container : Optional [List [str ]] = Field (default_factory = list )
151
17
evaluation_target : Optional [List [str ]] = Field (default_factory = list )
152
18
153
19
class Config :
154
- extra = "ignore "
20
+ extra = "allow "
155
21
156
22
157
23
class ShapeConfig (Serializable ):
@@ -178,7 +44,7 @@ class MetricConfig(Serializable):
178
44
tags : Optional [List [str ]] = Field (default_factory = list )
179
45
180
46
class Config :
181
- extra = "ignore "
47
+ extra = "allow "
182
48
183
49
184
50
class ModelParamsConfig (Serializable ):
@@ -223,7 +89,7 @@ def search_shapes(
223
89
]
224
90
225
91
class Config :
226
- extra = "ignore "
92
+ extra = "allow "
227
93
protected_namespaces = ()
228
94
229
95
@@ -235,49 +101,7 @@ class EvaluationServiceConfig(Serializable):
235
101
236
102
version : Optional [str ] = "1.0"
237
103
kind : Optional [str ] = "evaluation_service_config"
238
- report_params : Optional [ReportParams ] = Field (default_factory = ReportParams )
239
- inference_params : Optional [InferenceParamsConfig ] = Field (
240
- default_factory = InferenceParamsConfig
241
- )
242
- inference_model_params : Optional [InferenceModelParamsConfig ] = Field (
243
- default_factory = InferenceModelParamsConfig
244
- )
245
104
ui_config : Optional [UIConfig ] = Field (default_factory = UIConfig )
246
105
247
- def get_merged_inference_params (self , container_name : str ) -> InferenceParams :
248
- """
249
- Merges default inference params with those specific to the given container.
250
-
251
- Params
252
- ------
253
- container_name (str): The name of the container.
254
-
255
- Returns
256
- -------
257
- InferenceParams: The merged inference parameters.
258
- """
259
- return self .inference_params .get_merged_params (container_name = container_name )
260
-
261
- def get_merged_inference_model_params (
262
- self ,
263
- container_name : str ,
264
- version : Optional [str ] = None ,
265
- ) -> Dict [str , Any ]:
266
- """
267
- Gets the model parameters for a given container, version, and task, merged with the defaults.
268
-
269
- Parameters
270
- ----------
271
- container_name (str): The name of the container.
272
- version (Optional[str]): The specific version of the container.
273
-
274
- Returns
275
- -------
276
- Dict[str, Any]: The merged model parameters.
277
- """
278
- return self .inference_model_params .get_merged_model_params (
279
- container_name = container_name , version = version
280
- )
281
-
282
106
class Config :
283
- extra = "ignore "
107
+ extra = "allow "
0 commit comments