Skip to content

Commit e6e9111

Browse files
Revert a broken refactoring (#423)
* Revert a broken refactoring
* fix
1 parent 2fa8fbb commit e6e9111

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -507,14 +507,6 @@ def load_model_weights_sub_commands(
507507
else:
508508
s5cmd = "./s5cmd"
509509

510-
subcommands.extend(
511-
self.get_s5cmd_copy_command(checkpoint_path, final_weights_folder, subcommands, s5cmd)
512-
)
513-
514-
return subcommands
515-
516-
def get_s5cmd_copy_command(self, checkpoint_path, final_weights_folder, s5cmd):
517-
subcommands = []
518510
base_path = checkpoint_path.split("/")[-1]
519511
if base_path.endswith(".tar"):
520512
# If the checkpoint file is a tar file, extract it into final_weights_folder
@@ -535,6 +527,7 @@ def get_s5cmd_copy_command(self, checkpoint_path, final_weights_folder, s5cmd):
535527
subcommands.append(
536528
f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
537529
)
530+
538531
return subcommands
539532

540533
def load_model_files_sub_commands_trt_llm(

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -284,6 +284,56 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
284284
)
285285

286286

287+
def test_load_model_weights_sub_commands(
288+
fake_model_bundle_repository,
289+
fake_model_endpoint_service,
290+
fake_docker_repository_image_always_exists,
291+
fake_model_primitive_gateway,
292+
fake_llm_artifact_gateway,
293+
):
294+
fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
295+
bundle_use_case = CreateModelBundleV2UseCase(
296+
model_bundle_repository=fake_model_bundle_repository,
297+
docker_repository=fake_docker_repository_image_always_exists,
298+
model_primitive_gateway=fake_model_primitive_gateway,
299+
)
300+
llm_bundle_use_case = CreateLLMModelBundleV1UseCase(
301+
create_model_bundle_use_case=bundle_use_case,
302+
model_bundle_repository=fake_model_bundle_repository,
303+
llm_artifact_gateway=fake_llm_artifact_gateway,
304+
docker_repository=fake_docker_repository_image_always_exists,
305+
)
306+
307+
framework = LLMInferenceFramework.VLLM
308+
framework_image_tag = "0.2.7"
309+
checkpoint_path = "fake-checkpoint"
310+
final_weights_folder = "test_folder"
311+
312+
subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
313+
framework, framework_image_tag, checkpoint_path, final_weights_folder
314+
)
315+
316+
expected_result = [
317+
"./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder",
318+
]
319+
assert expected_result == subcommands
320+
321+
framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE
322+
framework_image_tag = "1.0.0"
323+
checkpoint_path = "fake-checkpoint"
324+
final_weights_folder = "test_folder"
325+
326+
subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
327+
framework, framework_image_tag, checkpoint_path, final_weights_folder
328+
)
329+
330+
expected_result = [
331+
"s5cmd > /dev/null || conda install -c conda-forge -y s5cmd",
332+
"s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder",
333+
]
334+
assert expected_result == subcommands
335+
336+
287337
@pytest.mark.asyncio
288338
async def test_create_model_endpoint_trt_llm_use_case_success(
289339
test_api_key: str,

0 commit comments

Comments
 (0)