Skip to content

Commit 884b88f

Browse files
Finetuning and MD default params dict (#850)
2 parents c51aa4a + 3f0edfa commit 884b88f

File tree

4 files changed

+28
-48
lines changed

4 files changed

+28
-48
lines changed

ads/aqua/finetuning/finetuning.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import os
88
from dataclasses import asdict, fields
9-
from typing import Dict, List
9+
from typing import Dict
1010

1111
from oci.data_science.models import (
1212
Metadata,
@@ -559,7 +559,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
559559
entry_point="plugin=finetuning&action=get_finetuning_default_params",
560560
name="aqua",
561561
)
562-
def get_finetuning_default_params(self, model_id: str) -> List[str]:
562+
def get_finetuning_default_params(self, model_id: str) -> Dict:
563563
"""Gets the default params set in the finetuning configs for the given model. Only the fields that are
564564
available in AquaFineTuningParams will be accessible for user overrides.
565565
@@ -570,21 +570,19 @@ def get_finetuning_default_params(self, model_id: str) -> List[str]:
570570
571571
Returns
572572
-------
573-
List[str]:
574-
List of parameters from the loaded from finetuning config json file. If config information is available,
575-
then an empty list is returned.
573+
Dict:
574+
Dict of parameters loaded from the finetuning config json file. If config information is not available,
575+
then an empty dict is returned.
576576
"""
577-
default_params = []
577+
default_params = {"params": {}}
578578
finetuning_config = self.get_finetuning_config(model_id)
579579
config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT)
580580
dataclass_fields = {field.name for field in fields(AquaFineTuningParams)}
581581
for name, value in config_parameters.items():
582582
if name == "micro_batch_size":
583583
name = "batch_size"
584-
if name == "lora_target_modules":
585-
value = ",".join(str(k) for k in value)
586584
if name in dataclass_fields:
587-
default_params.append(f"--{name} {str(value).lower()}")
585+
default_params["params"][name] = value
588586

589587
return default_params
590588

ads/aqua/modeldeployment/deployment.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def get_deployment_default_params(
527527
self,
528528
model_id: str,
529529
instance_shape: str,
530-
) -> Dict:
530+
) -> List[str]:
531531
"""Gets the default params set in the deployment configs for the given model and instance shape.
532532
533533
Parameters
@@ -546,8 +546,6 @@ def get_deployment_default_params(
546546
547547
"""
548548
default_params = []
549-
container_type = UNKNOWN
550-
551549
model = DataScienceModel.from_id(model_id)
552550
try:
553551
container_type_key = model.custom_metadata_list.get(
@@ -572,12 +570,10 @@ def get_deployment_default_params(
572570
params = config_parameters.get(
573571
InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN
574572
)
575-
container_type = InferenceContainerType.CONTAINER_TYPE_VLLM
576573
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_key:
577574
params = config_parameters.get(
578575
InferenceContainerParamType.PARAM_TYPE_TGI, UNKNOWN
579576
)
580-
container_type = InferenceContainerType.CONTAINER_TYPE_TGI
581577
else:
582578
params = UNKNOWN
583579
logger.debug(
@@ -588,10 +584,7 @@ def get_deployment_default_params(
588584
# account for param that can have --arg but no values, e.g. --trust-remote-code
589585
default_params.extend(get_params_list(params))
590586

591-
return dict(
592-
container_type=container_type,
593-
params=default_params,
594-
)
587+
return default_params
595588

596589
def validate_deployment_params(
597590
self, model_id: str, params: List[str] = None

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -423,38 +423,29 @@ def yaml_to_json(input_file):
423423
[
424424
(
425425
"VLLM_PARAMS",
426-
"vllm",
427426
"odsc-vllm-serving",
428427
["--max-model-len 4096", "--seed 42", "--trust-remote-code"],
429428
),
430429
(
431430
"VLLM_PARAMS",
432-
"vllm",
433431
"odsc-vllm-serving",
434432
[],
435433
),
436434
(
437435
"TGI_PARAMS",
438-
"tgi",
439436
"odsc-tgi-serving",
440437
["--sharded true", "--trust-remote-code"],
441438
),
442439
(
443440
"CUSTOM_PARAMS",
444-
"",
445441
"custom-container-key",
446442
["--max-model-len 4096", "--seed 42", "--trust-remote-code"],
447443
),
448444
]
449445
)
450446
@patch("ads.model.datascience_model.DataScienceModel.from_id")
451447
def test_get_deployment_default_params(
452-
self,
453-
container_params_field,
454-
container_type,
455-
container_type_key,
456-
params,
457-
mock_from_id,
448+
self, container_params_field, container_type_key, params, mock_from_id
458449
):
459450
"""Test for fetching config details for a given deployment."""
460451

@@ -480,12 +471,10 @@ def test_get_deployment_default_params(
480471
result = self.app.get_deployment_default_params(
481472
TestDataset.MODEL_ID, TestDataset.DEPLOYMENT_SHAPE_NAME
482473
)
483-
484-
assert result["container_type"] == container_type
485474
if container_params_field == "CUSTOM_PARAMS":
486-
assert result["params"] == []
475+
assert result == []
487476
else:
488-
assert result["params"] == params
477+
assert result == params
489478

490479
@parameterized.expand(
491480
[

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -237,32 +237,32 @@ def test_build_oci_launch_cmd(self):
237237
def test_get_finetuning_default_params(self):
238238
"""Test for fetching finetuning config params for a given model."""
239239

240+
params_dict = {
241+
"params": {
242+
"batch_size": 1,
243+
"sequence_len": 2048,
244+
"sample_packing": True,
245+
"pad_to_sequence_len": True,
246+
"learning_rate": 0.0002,
247+
"lora_r": 32,
248+
"lora_alpha": 16,
249+
"lora_dropout": 0.05,
250+
"lora_target_linear": True,
251+
"lora_target_modules": ["q_proj", "k_proj"],
252+
}
253+
}
240254
config_json = os.path.join(self.curr_dir, "test_data/finetuning/ft_config.json")
241255
with open(config_json, "r") as _file:
242256
config = json.load(_file)
243257

244258
self.app.get_finetuning_config = MagicMock(return_value=config)
245259
result = self.app.get_finetuning_default_params(model_id="test_model_id")
246-
self.assertCountEqual(
247-
result,
248-
[
249-
"--batch_size 1",
250-
"--sequence_len 2048",
251-
"--sample_packing true",
252-
"--pad_to_sequence_len true",
253-
"--learning_rate 0.0002",
254-
"--lora_r 32",
255-
"--lora_alpha 16",
256-
"--lora_dropout 0.05",
257-
"--lora_target_linear true",
258-
"--lora_target_modules q_proj,k_proj",
259-
],
260-
)
260+
assert result == params_dict
261261

262262
# check when config json is not available
263263
self.app.get_finetuning_config = MagicMock(return_value={})
264264
result = self.app.get_finetuning_default_params(model_id="test_model_id")
265-
assert result == []
265+
assert result == {"params": {}}
266266

267267
@parameterized.expand(
268268
[

0 commit comments

Comments (0)