4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
6
6
from copy import deepcopy
7
- from typing import Any , Dict , List , Optional , Union
7
+ from typing import Any , Dict , List , Optional
8
8
9
9
from pydantic import Field
10
10
19
19
INFERENCE_DELAY = 0
20
20
21
21
22
- class ModelParamItem (Serializable ):
23
- """Represents min, max, and default values for a model parameter."""
24
-
25
- min : Optional [Union [int , float ]] = None
26
- max : Optional [Union [int , float ]] = None
27
- default : Optional [Union [int , float ]] = None
28
-
29
- class Config :
30
- extra = "ignore"
31
-
32
-
33
22
class ModelParamsOverrides (Serializable ):
34
23
"""Defines overrides for model parameters, including exclusions and additional inclusions."""
35
24
@@ -51,28 +40,11 @@ class Config:
51
40
extra = "ignore"
52
41
53
42
54
- class ModelDefaultParams (Serializable ):
55
- """Defines default parameters for a model within a specific framework."""
56
-
57
- model : Optional [str ] = None
58
- max_tokens : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
59
- temperature : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
60
- top_p : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
61
- top_k : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
62
- presence_penalty : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
63
- frequency_penalty : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
64
- stop : List [str ] = Field (default_factory = list )
43
+ class ModelParamsContainer (Serializable ):
44
+ """Represents a container's model configuration, including tasks, defaults, and versions."""
65
45
66
- class Config :
67
- extra = "allow"
68
-
69
-
70
- class ModelFramework (Serializable ):
71
- """Represents a framework's model configuration, including tasks, defaults, and versions."""
72
-
73
- framework : Optional [str ] = None
74
- task : Optional [List [str ]] = Field (default_factory = list )
75
- default : Optional [ModelDefaultParams ] = Field (default_factory = ModelDefaultParams )
46
+ name : Optional [str ] = None
47
+ default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
76
48
versions : Optional [Dict [str , ModelParamsVersion ]] = Field (default_factory = dict )
77
49
78
50
class Config :
@@ -93,10 +65,10 @@ class Config:
93
65
extra = "allow"
94
66
95
67
96
- class InferenceFramework (Serializable ):
97
- """Represents the inference parameters specific to a framework ."""
68
+ class InferenceContainer (Serializable ):
69
+ """Represents the inference parameters specific to a container ."""
98
70
99
- framework : Optional [str ] = None
71
+ name : Optional [str ] = None
100
72
params : Optional [Dict [str , Any ]] = Field (default_factory = dict )
101
73
102
74
class Config :
@@ -113,70 +85,66 @@ class Config:
113
85
114
86
115
87
class InferenceParamsConfig (Serializable ):
116
- """Combines default inference parameters with framework -specific configurations."""
88
+ """Combines default inference parameters with container -specific configurations."""
117
89
118
90
default : Optional [InferenceParams ] = Field (default_factory = InferenceParams )
119
- frameworks : Optional [List [InferenceFramework ]] = Field (default_factory = list )
91
+ containers : Optional [List [InferenceContainer ]] = Field (default_factory = list )
120
92
121
- def get_merged_params (self , framework_name : str ) -> InferenceParams :
93
+ def get_merged_params (self , container_name : str ) -> InferenceParams :
122
94
"""
123
- Merges default inference params with those specific to the given framework .
95
+ Merges default inference params with those specific to the given container .
124
96
125
97
Parameters
126
98
----------
127
- framework_name (str): The name of the framework .
99
+ container_name (str): The name of the container .
128
100
129
101
Returns
130
102
-------
131
103
InferenceParams: The merged inference parameters.
132
104
"""
133
105
merged_params = self .default .to_dict ()
134
- for framework in self .frameworks :
135
- if framework . framework .lower () == framework_name .lower ():
136
- merged_params .update (framework .params or {})
106
+ for containers in self .containers :
107
+ if containers . name .lower () == container_name .lower ():
108
+ merged_params .update (containers .params or {})
137
109
break
138
110
return InferenceParams (** merged_params )
139
111
140
112
class Config :
141
113
extra = "ignore"
142
114
143
115
144
- class ModelParamsConfig (Serializable ):
145
- """Encapsulates the model parameters for different frameworks ."""
116
+ class InferenceModelParamsConfig (Serializable ):
117
+ """Encapsulates the model parameters for different containers ."""
146
118
147
119
default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
148
- frameworks : Optional [List [ModelFramework ]] = Field (default_factory = list )
120
+ containers : Optional [List [ModelParamsContainer ]] = Field (default_factory = list )
149
121
150
- def get_model_params (
122
+ def get_merged_model_params (
151
123
self ,
152
- framework_name : str ,
124
+ container_name : str ,
153
125
version : Optional [str ] = None ,
154
- task : Optional [str ] = None ,
155
126
) -> Dict [str , Any ]:
156
127
"""
157
- Gets the model parameters for a given framework , version, and tasks ,
128
+ Gets the model parameters for a given container , version,
158
129
merged with the defaults.
159
130
160
131
Parameters
161
132
----------
162
- framework_name (str): The name of the framework.
163
- version (Optional[str]): The specific version of the framework.
164
- task (Optional[str]): The specific task.
133
+ container_name (str): The name of the container.
134
+ version (Optional[str]): The specific version of the container.
165
135
166
136
Returns
167
137
-------
168
138
Dict[str, Any]: The merged model parameters.
169
139
"""
170
140
params = deepcopy (self .default )
171
141
172
- for framework in self .frameworks :
173
- if framework .framework .lower () == framework_name .lower () and (
174
- not task or task .lower () in framework .task
175
- ):
176
- params .update (framework .default .to_dict ())
142
+ for container in self .containers :
143
+ if container .name .lower () == container_name .lower ():
144
+ params .update (container .default )
177
145
178
- if version and version in framework .versions :
179
- version_overrides = framework .versions [version ].overrides
146
+ if version and version in container .versions :
147
+ version_overrides = container .versions [version ].overrides
180
148
if version_overrides :
181
149
if version_overrides .include :
182
150
params .update (version_overrides .include )
@@ -228,59 +196,17 @@ class Config:
228
196
extra = "ignore"
229
197
230
198
231
- class EvaluationServiceConfig (Serializable ):
232
- """
233
- Root configuration class for evaluation setup including model,
234
- inference, and shape configurations.
235
- """
199
+ class ModelParamsConfig (Serializable ):
200
+ """Encapsulates the default model parameters."""
236
201
237
- version : Optional [str ] = "1.0"
238
- kind : Optional [str ] = "evaluation"
239
- report_params : Optional [ReportParams ] = Field (default_factory = ReportParams )
240
- inference_params : Optional [InferenceParamsConfig ] = Field (
241
- default_factory = InferenceParamsConfig
242
- )
202
+ default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
203
+
204
+
205
+ class UIConfig (Serializable ):
243
206
model_params : Optional [ModelParamsConfig ] = Field (default_factory = ModelParamsConfig )
244
207
shapes : List [ShapeConfig ] = Field (default_factory = list )
245
208
metrics : List [MetricConfig ] = Field (default_factory = list )
246
209
247
- def get_merged_inference_params (self , framework_name : str ) -> InferenceParams :
248
- """
249
- Merges default inference params with those specific to the given framework.
250
-
251
- Params
252
- ------
253
- framework_name (str): The name of the framework.
254
-
255
- Returns
256
- -------
257
- InferenceParams: The merged inference parameters.
258
- """
259
- return self .inference_params .get_merged_params (framework_name = framework_name )
260
-
261
- def get_merged_model_params (
262
- self ,
263
- framework_name : str ,
264
- version : Optional [str ] = None ,
265
- task : Optional [str ] = None ,
266
- ) -> Dict [str , Any ]:
267
- """
268
- Gets the model parameters for a given framework, version, and task, merged with the defaults.
269
-
270
- Parameters
271
- ----------
272
- framework_name (str): The name of the framework.
273
- version (Optional[str]): The specific version of the framework.
274
- task (Optional[str]): The task.
275
-
276
- Returns
277
- -------
278
- Dict[str, Any]: The merged model parameters.
279
- """
280
- return self .model_params .get_model_params (
281
- framework_name = framework_name , version = version , task = task
282
- )
283
-
284
210
def search_shapes (
285
211
self ,
286
212
evaluation_container : Optional [str ] = None ,
@@ -315,3 +241,59 @@ def search_shapes(
315
241
316
242
class Config :
317
243
extra = "ignore"
244
+
245
+
246
+ class EvaluationServiceConfig (Serializable ):
247
+ """
248
+ Root configuration class for evaluation setup including model,
249
+ inference, and shape configurations.
250
+ """
251
+
252
+ version : Optional [str ] = "1.0"
253
+ kind : Optional [str ] = "evaluation"
254
+ report_params : Optional [ReportParams ] = Field (default_factory = ReportParams )
255
+ inference_params : Optional [InferenceParamsConfig ] = Field (
256
+ default_factory = InferenceParamsConfig
257
+ )
258
+ inference_model_params : Optional [InferenceModelParamsConfig ] = Field (
259
+ default_factory = InferenceModelParamsConfig
260
+ )
261
+ ui_config : Optional [UIConfig ] = Field (default_factory = UIConfig )
262
+
263
+ def get_merged_inference_params (self , container_name : str ) -> InferenceParams :
264
+ """
265
+ Merges default inference params with those specific to the given container.
266
+
267
+ Params
268
+ ------
269
+ container_name (str): The name of the container.
270
+
271
+ Returns
272
+ -------
273
+ InferenceParams: The merged inference parameters.
274
+ """
275
+ return self .inference_params .get_merged_params (container_name = container_name )
276
+
277
+ def get_merged_inference_model_params (
278
+ self ,
279
+ container_name : str ,
280
+ version : Optional [str ] = None ,
281
+ ) -> Dict [str , Any ]:
282
+ """
283
+ Gets the model parameters for a given container, version, and task, merged with the defaults.
284
+
285
+ Parameters
286
+ ----------
287
+ container_name (str): The name of the container.
288
+ version (Optional[str]): The specific version of the container.
289
+
290
+ Returns
291
+ -------
292
+ Dict[str, Any]: The merged model parameters.
293
+ """
294
+ return self .inference_model_params .get_merged_model_params (
295
+ container_name = container_name , version = version
296
+ )
297
+
298
+ class Config :
299
+ extra = "ignore"
0 commit comments