
Commit 2b4ef4f

Allow latest inference framework tag (#403)

1 parent: c820e72

File tree

11 files changed: +79 -18 lines


clients/python/llmengine/model.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def create(
             Name of the base model

         inference_framework_image_tag (`str`):
-            Image tag for the inference framework
+            Image tag for the inference framework. Use "latest" for the most recent image

         source (`LLMSource`):
             Source of the LLM. Currently only HuggingFace is supported
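
The docstring change above documents the new sentinel value. A hedged usage sketch of the client call it describes (endpoint name and other arguments are illustrative, not taken from this commit):

    from llmengine import Model

    response = Model.create(
        name="llama-2-7b-latest",                # hypothetical endpoint name
        model="llama-2-7b",                      # base model, per the docstring above
        inference_framework_image_tag="latest",  # now resolved server-side to the newest pushed image
        num_shards=1,
    )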

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 2 additions & 0 deletions
@@ -164,6 +164,7 @@ async def create_model_endpoint(
         use_case = CreateLLMModelEndpointV1UseCase(
             create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
             model_endpoint_service=external_interfaces.model_endpoint_service,
+            docker_repository=external_interfaces.docker_repository,
         )
         return await use_case.execute(user=auth, request=request)
     except ObjectAlreadyExistsException as exc:
@@ -265,6 +266,7 @@ async def update_model_endpoint(
             create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
             model_endpoint_service=external_interfaces.model_endpoint_service,
             llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+            docker_repository=external_interfaces.docker_repository,
         )
         return await use_case.execute(
             user=auth, model_endpoint_name=model_endpoint_name, request=request
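
Both use cases now receive the Docker repository via constructor injection, so "latest" resolution hits ECR in production and a stub in unit tests. A minimal wiring sketch, assuming ECRDockerRepository takes no constructor arguments (its initializer is not shown in this diff):

    from model_engine_server.infra.repositories.ecr_docker_repository import ECRDockerRepository

    use_case = CreateLLMModelEndpointV1UseCase(
        create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
        model_endpoint_service=external_interfaces.model_endpoint_service,
        docker_repository=ECRDockerRepository(),  # swapped for a fake in unit tests
    )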

model-engine/model_engine_server/core/docker/ecr.py

Lines changed: 12 additions & 0 deletions
@@ -97,3 +97,15 @@ def ecr_exists_for_repo(repo_name: str, image_tag: Optional[str] = None):
         return True
     except ecr.exceptions.ImageNotFoundException:
         return False
+
+
+def get_latest_image_tag(repository_name: str):
+    ecr = boto3.client("ecr", region_name=infra_config().default_region)
+    images = ecr.describe_images(
+        registryId=infra_config().ml_account_id,
+        repositoryName=repository_name,
+        filter=DEFAULT_FILTER,
+        maxResults=1000,
+    )["imageDetails"]
+    latest_image = max(images, key=lambda image: image["imagePushedAt"])
+    return latest_image["imageTags"][0]
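
Note that describe_images returns at most 1000 results per call, so get_latest_image_tag implicitly assumes the repository holds no more images than that. A hedged variant that paginates instead; the {"tagStatus": "TAGGED"} filter is an assumption standing in for DEFAULT_FILTER, whose definition is not shown in this diff:

    import boto3

    def get_latest_image_tag_paginated(repository_name: str, registry_id: str, region: str) -> str:
        ecr = boto3.client("ecr", region_name=region)
        paginator = ecr.get_paginator("describe_images")
        images = []
        for page in paginator.paginate(
            registryId=registry_id,
            repositoryName=repository_name,
            filter={"tagStatus": "TAGGED"},  # assumption: mirrors DEFAULT_FILTER
        ):
            images.extend(page["imageDetails"])
        # Same selection logic as the helper above: newest push wins, first tag returned.
        latest_image = max(images, key=lambda image: image["imagePushedAt"])
        return latest_image["imageTags"][0]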

model-engine/model_engine_server/domain/repositories/docker_repository.py

Lines changed: 11 additions & 0 deletions
@@ -49,6 +49,17 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
         """
         pass

+    @abstractmethod
+    def get_latest_image_tag(self, repository_name: str) -> str:
+        """
+        Returns the Docker image tag of the most recently pushed image in the given repository
+
+        Args:
+            repository_name: the name of the repository containing the image.
+
+        Returns: the tag of the latest Docker image.
+        """
+
     def is_repo_name(self, repo_name: str):
         # We assume repository names must start with a letter and can only contain lowercase letters, numbers, hyphens, underscores, and forward slashes.
         # Based-off ECR naming standards

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 30 additions & 14 deletions
@@ -89,6 +89,14 @@
 logger = make_logger(logger_name())


+INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = {
+    LLMInferenceFramework.DEEPSPEED: "instant-llm",
+    LLMInferenceFramework.TEXT_GENERATION_INFERENCE: hmi_config.tgi_repository,
+    LLMInferenceFramework.VLLM: hmi_config.vllm_repository,
+    LLMInferenceFramework.LIGHTLLM: hmi_config.lightllm_repository,
+    LLMInferenceFramework.TENSORRT_LLM: hmi_config.tensorrt_llm_repository,
+}
+
 _SUPPORTED_MODELS_BY_FRAMEWORK = {
     LLMInferenceFramework.DEEPSPEED: set(
         [
@@ -332,8 +340,10 @@ async def execute(
         checkpoint_path: Optional[str],
     ) -> ModelBundle:
         if source == LLMSource.HUGGING_FACE:
+            self.check_docker_image_exists_for_image_tag(
+                framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework]
+            )
             if framework == LLMInferenceFramework.DEEPSPEED:
-                self.check_docker_image_exists_for_image_tag(framework_image_tag, "instant-llm")
                 bundle_id = await self.create_deepspeed_bundle(
                     user,
                     model_name,
@@ -342,9 +352,6 @@
                     endpoint_name,
                 )
             elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
-                self.check_docker_image_exists_for_image_tag(
-                    framework_image_tag, hmi_config.tgi_repository
-                )
                 bundle_id = await self.create_text_generation_inference_bundle(
                     user,
                     model_name,
@@ -355,9 +362,6 @@
                     checkpoint_path,
                 )
             elif framework == LLMInferenceFramework.VLLM:
-                self.check_docker_image_exists_for_image_tag(
-                    framework_image_tag, hmi_config.vllm_repository
-                )
                 bundle_id = await self.create_vllm_bundle(
                     user,
                     model_name,
@@ -368,9 +372,6 @@
                     checkpoint_path,
                 )
             elif framework == LLMInferenceFramework.LIGHTLLM:
-                self.check_docker_image_exists_for_image_tag(
-                    framework_image_tag, hmi_config.lightllm_repository
-                )
                 bundle_id = await self.create_lightllm_bundle(
                     user,
                     model_name,
@@ -862,10 +863,12 @@ def __init__(
         self,
         create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase,
         model_endpoint_service: ModelEndpointService,
+        docker_repository: DockerRepository,
     ):
         self.authz_module = LiveAuthorizationModule()
         self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
         self.model_endpoint_service = model_endpoint_service
+        self.docker_repository = docker_repository

     async def execute(
         self, user: User, request: CreateLLMModelEndpointV1Request
@@ -895,6 +898,11 @@ async def execute(
                 f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM."
             )

+        if request.inference_framework_image_tag == "latest":
+            request.inference_framework_image_tag = self.docker_repository.get_latest_image_tag(
+                INFERENCE_FRAMEWORK_REPOSITORY[request.inference_framework]
+            )
+
         bundle = await self.create_llm_model_bundle_use_case.execute(
             user,
             endpoint_name=request.name,
@@ -1059,11 +1067,13 @@ def __init__(
         create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase,
         model_endpoint_service: ModelEndpointService,
         llm_model_endpoint_service: LLMModelEndpointService,
+        docker_repository: DockerRepository,
     ):
         self.authz_module = LiveAuthorizationModule()
         self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
         self.model_endpoint_service = model_endpoint_service
         self.llm_model_endpoint_service = llm_model_endpoint_service
+        self.docker_repository = docker_repository

     async def execute(
         self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request
@@ -1106,12 +1116,18 @@ async def execute(
         llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {})
         inference_framework = llm_metadata["inference_framework"]

+        if request.inference_framework_image_tag == "latest":
+            inference_framework_image_tag = self.docker_repository.get_latest_image_tag(
+                INFERENCE_FRAMEWORK_REPOSITORY[inference_framework]
+            )
+        else:
+            inference_framework_image_tag = (
+                request.inference_framework_image_tag
+                or llm_metadata["inference_framework_image_tag"]
+            )
+
         model_name = request.model_name or llm_metadata["model_name"]
         source = request.source or llm_metadata["source"]
-        inference_framework_image_tag = (
-            request.inference_framework_image_tag
-            or llm_metadata["inference_framework_image_tag"]
-        )
         num_shards = request.num_shards or llm_metadata["num_shards"]
         quantize = request.quantize or llm_metadata.get("quantize")
         checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path")
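
The update path now resolves the tag in three tiers: an explicit "latest" queries the repository, any other explicit tag wins as-is, and an omitted tag falls back to the endpoint's stored metadata. A self-contained sketch of that precedence (DummyRepo and all tag values are illustrative, not from the codebase):

    class DummyRepo:
        def get_latest_image_tag(self, repository_name: str) -> str:
            return "0.2.1"  # pretend this is the newest tag pushed to ECR

    def resolve_tag(requested, repo, repository_name, llm_metadata):
        if requested == "latest":
            return repo.get_latest_image_tag(repository_name)
        return requested or llm_metadata["inference_framework_image_tag"]

    repo = DummyRepo()
    stored = {"inference_framework_image_tag": "0.1.0"}  # tag saved on the endpoint
    assert resolve_tag("latest", repo, "vllm", stored) == "0.2.1"  # queries the repo
    assert resolve_tag("0.1.5", repo, "vllm", stored) == "0.1.5"   # explicit tag wins
    assert resolve_tag(None, repo, "vllm", stored) == "0.1.0"      # falls back to metadata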

model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 from model_engine_server.common.config import hmi_config
 from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
 from model_engine_server.core.config import infra_config
+from model_engine_server.core.docker.ecr import get_latest_image_tag
 from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists
 from model_engine_server.core.docker.remote_build import build_remote_block
 from model_engine_server.core.loggers import logger_name, make_logger
@@ -52,3 +53,6 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
         return BuildImageResponse(
             status=build_result.status, logs=build_result.logs, job_name=build_result.job_name
         )
+
+    def get_latest_image_tag(self, repository_name: str) -> str:
+        return get_latest_image_tag(repository_name=repository_name)

model-engine/model_engine_server/infra/repositories/fake_docker_repository.py

Lines changed: 3 additions & 0 deletions
@@ -19,3 +19,6 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str:

     def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
         raise NotImplementedError("FakeDockerRepository build_image() not implemented")
+
+    def get_latest_image_tag(self, repository_name: str) -> str:
+        raise NotImplementedError("FakeDockerRepository get_latest_image_tag() not implemented")

model-engine/setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ test=pytest
 omit =
     model_engine_server/entrypoints/*
     model_engine_server/api/app.py
+    model_engine_server/core/docker/ecr.py

 # TODO: Fix pylint errors
 # [pylint]

model-engine/tests/unit/conftest.py

Lines changed: 3 additions & 0 deletions
@@ -672,6 +672,9 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
             raise Exception("I hope you're handling this!")
         return BuildImageResponse(status=True, logs="", job_name="test-job-name")

+    def get_latest_image_tag(self, repository_name: str) -> str:
+        return "fake_docker_repository_latest_image_tag"
+

 class FakeModelEndpointCacheRepository(ModelEndpointCacheRepository):
     def __init__(self):

model-engine/tests/unit/domain/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -203,7 +203,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request
         model_name="mpt-7b",
         source="hugging_face",
         inference_framework="deepspeed",
-        inference_framework_image_tag="test_tag",
+        inference_framework_image_tag="latest",
         num_shards=2,
         endpoint_type=ModelEndpointType.ASYNC,
         metadata={},
@@ -252,6 +252,7 @@
 @pytest.fixture
 def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request:
     return UpdateLLMModelEndpointV1Request(
+        inference_framework_image_tag="latest",
         checkpoint_path="s3://test_checkpoint_path",
         memory="4G",
         min_workers=0,
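
With both fixtures requesting "latest", a unit test can assert that the create use case swapped in the fake repository's canned tag. A hedged sketch (fixture wiring and names are illustrative; the repo's actual tests may differ):

    @pytest.mark.asyncio
    async def test_create_resolves_latest_tag(user, use_case, create_llm_model_endpoint_request_async):
        await use_case.execute(user=user, request=create_llm_model_endpoint_request_async)
        # The use case mutates the request in place when the tag is "latest".
        assert (
            create_llm_model_endpoint_request_async.inference_framework_image_tag
            == "fake_docker_repository_latest_image_tag"
        )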
