Skip to content

Commit 9cc430b

Browse files
Add an `ignore_model_artifact_check` flag to skip model artifact validation while registering a model
1 parent 612bf71 commit 9cc430b

File tree

4 files changed

+113
-52
lines changed

4 files changed

+113
-52
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
133133
ignore_patterns = input_data.get("ignore_patterns")
134134
freeform_tags = input_data.get("freeform_tags")
135135
defined_tags = input_data.get("defined_tags")
136+
ignore_model_artifact_check = (
137+
str(input_data.get("ignore_model_artifact_check", "false")).lower()
138+
== "true"
139+
)
136140

137141
return self.finish(
138142
AquaModelApp().register(
@@ -149,6 +153,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
149153
ignore_patterns=ignore_patterns,
150154
freeform_tags=freeform_tags,
151155
defined_tags=defined_tags,
156+
ignore_model_artifact_check=ignore_model_artifact_check,
152157
)
153158
)
154159

ads/aqua/model/entities.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin):
293293
ignore_patterns: Optional[List[str]] = None
294294
freeform_tags: Optional[dict] = None
295295
defined_tags: Optional[dict] = None
296+
ignore_model_artifact_check: Optional[bool] = None
296297

297298
def __post_init__(self):
298299
self._command = "model register"

ads/aqua/model/model.py

Lines changed: 46 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -972,6 +972,9 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
972972
# todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
973973
# are grouped in one category and validation checks for config.json files only.
974974
if model_format == ModelFormat.SAFETENSORS:
975+
model_files.extend(
976+
list_os_files_with_extension(oss_path=os_path, extension=".safetensors")
977+
)
975978
try:
976979
load_config(
977980
file_path=os_path,
@@ -1022,10 +1025,12 @@ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]:
10221025

10231026
for model_sibling in model_siblings:
10241027
extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper()
1025-
if model_format == ModelFormat.SAFETENSORS:
1026-
if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG:
1027-
model_files.append(model_sibling.rfilename)
1028-
elif extension == model_format.value:
1028+
if (
1029+
model_format == ModelFormat.SAFETENSORS
1030+
and model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG
1031+
):
1032+
model_files.append(model_sibling.rfilename)
1033+
if extension == model_format.value:
10291034
model_files.append(model_sibling.rfilename)
10301035

10311036
return model_files
@@ -1061,7 +1066,10 @@ def _validate_model(
10611066
safetensors_model_files = self.get_hf_model_files(
10621067
model_name, ModelFormat.SAFETENSORS
10631068
)
1064-
if safetensors_model_files:
1069+
if (
1070+
safetensors_model_files
1071+
and AQUA_MODEL_ARTIFACT_CONFIG in safetensors_model_files
1072+
):
10651073
hf_download_config_present = True
10661074
gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF)
10671075
else:
@@ -1173,14 +1181,20 @@ def _validate_safetensor_format(
11731181
model_name: str = None,
11741182
):
11751183
if import_model_details.download_from_hf:
1176-
# validates config.json exists for safetensors model from hugginface
1177-
if not hf_download_config_present:
1184+
# validates config.json exists for safetensors model from huggingface
1185+
if not (
1186+
hf_download_config_present
1187+
or import_model_details.ignore_model_artifact_check
1188+
):
11781189
raise AquaRuntimeError(
11791190
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
11801191
f"by {ModelFormat.SAFETENSORS.value} format model."
11811192
f" Please check if the model name is correct in Hugging Face repository."
11821193
)
1194+
validation_result.telemetry_model_name = model_name
11831195
else:
1196+
# validate if config.json is available from object storage, and get model name for telemetry
1197+
model_config = None
11841198
try:
11851199
model_config = load_config(
11861200
file_path=import_model_details.os_path,
@@ -1191,22 +1205,25 @@ def _validate_safetensor_format(
11911205
f"Exception occurred while loading config file from {import_model_details.os_path}"
11921206
f"Exception message: {ex}"
11931207
)
1194-
raise AquaRuntimeError(
1195-
f"The model path {import_model_details.os_path} does not contain the file config.json. "
1196-
f"Please check if the path is correct or the model artifacts are available at this location."
1197-
) from ex
1198-
else:
1208+
if not import_model_details.ignore_model_artifact_check:
1209+
raise AquaRuntimeError(
1210+
f"The model path {import_model_details.os_path} does not contain the file config.json. "
1211+
f"Please check if the path is correct or the model artifacts are available at this location."
1212+
) from ex
1213+
1214+
if verified_model:
1215+
# model_type validation, log message if metadata field doesn't match.
11991216
try:
12001217
metadata_model_type = verified_model.custom_metadata_list.get(
12011218
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
12021219
).value
1203-
if metadata_model_type:
1220+
if metadata_model_type and model_config is not None:
12041221
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
12051222
if (
12061223
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
12071224
!= metadata_model_type
12081225
):
1209-
raise AquaRuntimeError(
1226+
logger.debug(
12101227
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
12111228
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
12121229
f"the model {model_name}. Please check if the path is correct or "
@@ -1219,21 +1236,22 @@ def _validate_safetensor_format(
12191236
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
12201237
)
12211238
except Exception:
1239+
# todo: raise exception if model_type doesn't match. Currently log message and pass since service
1240+
# models do not have this metadata.
12221241
pass
1223-
if verified_model:
1224-
validation_result.telemetry_model_name = verified_model.display_name
1225-
elif (
1226-
model_config is not None
1227-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1228-
):
1229-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1230-
elif (
1231-
model_config is not None
1232-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1233-
):
1234-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1235-
else:
1236-
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
1242+
validation_result.telemetry_model_name = verified_model.display_name
1243+
elif (
1244+
model_config is not None
1245+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1246+
):
1247+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1248+
elif (
1249+
model_config is not None
1250+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1251+
):
1252+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1253+
else:
1254+
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
12371255

12381256
@staticmethod
12391257
def _validate_gguf_format(

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 61 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -920,10 +920,18 @@ def test_import_model_with_project_compartment_override(
920920
assert model.project_id == project_override
921921

922922
@pytest.mark.parametrize(
923-
"download_from_hf",
924-
[True, False],
923+
("ignore_artifact_check", "download_from_hf"),
924+
[
925+
(True, True),
926+
(True, False),
927+
(False, True),
928+
(False, False),
929+
(None, False),
930+
(None, True),
931+
],
925932
)
926933
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
934+
@patch("ads.model.datascience_model.DataScienceModel.sync")
927935
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
928936
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
929937
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
@@ -936,45 +944,65 @@ def test_import_model_with_missing_config(
936944
mock_load_config,
937945
mock_list_objects,
938946
mock_upload_artifact,
947+
mock_sync,
939948
mock_ocidsc_create,
940-
mock_get_container_config,
949+
ignore_artifact_check,
941950
download_from_hf,
942951
mock_get_hf_model_info,
943952
mock_init_client,
944953
):
945-
"""Test for validating if error is returned when model artifacts are incomplete or not available."""
946-
947-
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
948-
model_name = "oracle/aqua-1t-mega-model"
954+
my_model = "oracle/aqua-1t-mega-model"
949955
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
950-
mock_list_objects.return_value = MagicMock(objects=[])
951-
reload(ads.aqua.model.model)
952-
app = AquaModelApp()
953-
app.list = MagicMock(return_value=[])
956+
# set object list from OSS without config.json
957+
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
954958

959+
# set object list from HF without config.json
955960
if download_from_hf:
956-
with pytest.raises(AquaValueError):
957-
mock_get_hf_model_info.return_value.siblings = []
958-
with tempfile.TemporaryDirectory() as tmpdir:
959-
model: AquaModel = app.register(
960-
model=model_name,
961-
os_path=os_path,
962-
local_dir=str(tmpdir),
963-
download_from_hf=True,
964-
)
961+
mock_get_hf_model_info.return_value.siblings = [
962+
MagicMock(rfilename="model.safetensors")
963+
]
965964
else:
966-
with pytest.raises(AquaRuntimeError):
965+
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
966+
obj1.name = f"prefix/path/model.safetensors"
967+
objects = [obj1]
968+
mock_list_objects.return_value = MagicMock(objects=objects)
969+
970+
reload(ads.aqua.model.model)
971+
app = AquaModelApp()
972+
with patch.object(AquaModelApp, "list") as aqua_model_mock_list:
973+
aqua_model_mock_list.return_value = [
974+
AquaModelSummary(
975+
id="test_id1",
976+
name="organization1/name1",
977+
organization="organization1",
978+
)
979+
]
980+
981+
if ignore_artifact_check:
967982
model: AquaModel = app.register(
968-
model=model_name,
983+
model=my_model,
969984
os_path=os_path,
970-
download_from_hf=False,
985+
inference_container="odsc-vllm-or-tgi-container",
986+
finetuning_container="odsc-llm-fine-tuning",
987+
download_from_hf=download_from_hf,
988+
ignore_model_artifact_check=ignore_artifact_check,
971989
)
990+
assert model.ready_to_deploy is True
991+
else:
992+
with pytest.raises(AquaRuntimeError):
993+
model: AquaModel = app.register(
994+
model=my_model,
995+
os_path=os_path,
996+
inference_container="odsc-vllm-or-tgi-container",
997+
finetuning_container="odsc-llm-fine-tuning",
998+
download_from_hf=download_from_hf,
999+
ignore_model_artifact_check=ignore_artifact_check,
1000+
)
9721001

9731002
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
9741003
@patch("ads.model.datascience_model.DataScienceModel.sync")
9751004
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
9761005
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
977-
@patch.object(HfApi, "model_info")
9781006
@patch("ads.aqua.common.utils.load_config", return_value={})
9791007
def test_import_any_model_smc_container(
9801008
self,
@@ -1230,6 +1258,15 @@ def test_import_model_with_input_tags(
12301258
"--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags "
12311259
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
12321260
),
1261+
(
1262+
{
1263+
"os_path": "oci://aqua-bkt@aqua-ns/path",
1264+
"model": "oracle/oracle-1it",
1265+
"inference_container": "odsc-vllm-serving",
1266+
"ignore_model_artifact_check": True,
1267+
},
1268+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving --ignore_model_artifact_check True",
1269+
),
12331270
],
12341271
)
12351272
def test_import_cli(self, data, expected_output):

0 commit comments

Comments
 (0)