Skip to content

Commit 34552b8

Browse files
authored
Add product label validation (#442)
Adds support for a new shared plugin that validates the product and team labels
1 parent eee364f commit 34552b8

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import re
8+
from dataclasses import dataclass
89
from typing import Any, Dict, List, Optional
910

1011
from model_engine_server.common.constants import SUPPORTED_POST_INFERENCE_HOOKS
@@ -46,6 +47,8 @@
4647
CONVERTED_FROM_ARTIFACT_LIKE_KEY = "_CONVERTED_FROM_ARTIFACT_LIKE"
4748
MODEL_BUNDLE_CHANGED_KEY = "_MODEL_BUNDLE_CHANGED"
4849

50+
DEFAULT_DISALLOWED_TEAMS = ["_INVALID_TEAM"]
51+
4952
logger = make_logger(logger_name())
5053

5154

@@ -118,6 +121,20 @@ def validate_deployment_resources(
118121
)
119122

120123

124+
@dataclass
125+
class ValidationResult:
126+
passed: bool
127+
message: str
128+
129+
130+
# Placeholder team and product label validator that only checks for a single invalid team
131+
def simple_team_product_validator(team: str, product: str) -> ValidationResult:
132+
if team in DEFAULT_DISALLOWED_TEAMS:
133+
return ValidationResult(False, "Invalid team")
134+
else:
135+
return ValidationResult(True, "Valid team")
136+
137+
121138
def validate_labels(labels: Dict[str, str]) -> None:
122139
for required_label in REQUIRED_ENDPOINT_LABELS:
123140
if required_label not in labels:
@@ -129,6 +146,7 @@ def validate_labels(labels: Dict[str, str]) -> None:
129146
if restricted_label in labels:
130147
raise EndpointLabelsException(f"Cannot specify '{restricted_label}' in labels")
131148

149+
# TODO: remove after we fully migrate to the new team + product validator
132150
try:
133151
from plugins.known_users import ALLOWED_TEAMS
134152

@@ -138,6 +156,15 @@ def validate_labels(labels: Dict[str, str]) -> None:
138156
except ModuleNotFoundError:
139157
pass
140158

159+
try:
160+
from shared_plugins.team_product_label_validation import validate_team_product_label
161+
except ModuleNotFoundError:
162+
validate_team_product_label = simple_team_product_validator
163+
164+
validation_result = validate_team_product_label(labels["team"], labels["product"])
165+
if not validation_result.passed:
166+
raise EndpointLabelsException(validation_result.message)
167+
141168
# Check k8s will accept the label values
142169
regex_pattern = "(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?" # k8s label regex
143170
for label_value in labels.values():

model-engine/tests/unit/api/test_model_endpoints.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi.testclient import TestClient
66
from model_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response
77
from model_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus
8+
from model_engine_server.domain.use_cases.model_endpoint_use_cases import DEFAULT_DISALLOWED_TEAMS
89

910

1011
def test_create_model_endpoint_success(
@@ -40,7 +41,6 @@ def test_create_model_endpoint_success(
4041
assert response_2.status_code == 200
4142

4243

43-
@pytest.mark.skip(reason="TODO: team validation is currently disabled")
4444
def test_create_model_endpoint_invalid_team_returns_400(
4545
model_bundle_1_v1: Tuple[ModelBundle, Any],
4646
create_model_endpoint_request_sync: Dict[str, Any],
@@ -59,15 +59,16 @@ def test_create_model_endpoint_invalid_team_returns_400(
5959
fake_batch_job_progress_gateway_contents={},
6060
fake_docker_image_batch_job_bundle_repository_contents={},
6161
)
62-
create_model_endpoint_request_sync["labels"]["team"] = "some_invalid_team"
62+
invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0]
63+
create_model_endpoint_request_sync["labels"]["team"] = invalid_team_name
6364
response_1 = client.post(
6465
"/v1/model-endpoints",
6566
auth=(test_api_key, ""),
6667
json=create_model_endpoint_request_sync,
6768
)
6869
assert response_1.status_code == 400
6970

70-
create_model_endpoint_request_async["labels"]["team"] = "some_invalid_team"
71+
create_model_endpoint_request_async["labels"]["team"] = invalid_team_name
7172
response_2 = client.post(
7273
"/v1/model-endpoints",
7374
auth=(test_api_key, ""),
@@ -394,7 +395,6 @@ def test_update_model_endpoint_by_id_success(
394395
assert response.json()["endpoint_creation_task_id"]
395396

396397

397-
@pytest.mark.skip(reason="TODO: team validation is currently disabled")
398398
def test_update_model_endpoint_by_id_invalid_team_returns_400(
399399
model_bundle_1_v1: Tuple[ModelBundle, Any],
400400
model_endpoint_1: Tuple[ModelEndpoint, Any],
@@ -418,8 +418,9 @@ def test_update_model_endpoint_by_id_invalid_team_returns_400(
418418
fake_batch_job_progress_gateway_contents={},
419419
fake_docker_image_batch_job_bundle_repository_contents={},
420420
)
421+
invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0]
421422
update_model_endpoint_request["labels"] = {
422-
"team": "some_invalid_team",
423+
"team": invalid_team_name,
423424
"product": "my_product",
424425
}
425426
response = client.put(

0 commit comments

Comments
 (0)