@@ -284,6 +284,56 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
284
284
)
285
285
286
286
287
+ def test_load_model_weights_sub_commands (
288
+ fake_model_bundle_repository ,
289
+ fake_model_endpoint_service ,
290
+ fake_docker_repository_image_always_exists ,
291
+ fake_model_primitive_gateway ,
292
+ fake_llm_artifact_gateway ,
293
+ ):
294
+ fake_model_endpoint_service .model_bundle_repository = fake_model_bundle_repository
295
+ bundle_use_case = CreateModelBundleV2UseCase (
296
+ model_bundle_repository = fake_model_bundle_repository ,
297
+ docker_repository = fake_docker_repository_image_always_exists ,
298
+ model_primitive_gateway = fake_model_primitive_gateway ,
299
+ )
300
+ llm_bundle_use_case = CreateLLMModelBundleV1UseCase (
301
+ create_model_bundle_use_case = bundle_use_case ,
302
+ model_bundle_repository = fake_model_bundle_repository ,
303
+ llm_artifact_gateway = fake_llm_artifact_gateway ,
304
+ docker_repository = fake_docker_repository_image_always_exists ,
305
+ )
306
+
307
+ framework = LLMInferenceFramework .VLLM
308
+ framework_image_tag = "0.2.7"
309
+ checkpoint_path = "fake-checkpoint"
310
+ final_weights_folder = "test_folder"
311
+
312
+ subcommands = llm_bundle_use_case .load_model_weights_sub_commands (
313
+ framework , framework_image_tag , checkpoint_path , final_weights_folder
314
+ )
315
+
316
+ expected_result = [
317
+ "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder" ,
318
+ ]
319
+ assert expected_result == subcommands
320
+
321
+ framework = LLMInferenceFramework .TEXT_GENERATION_INFERENCE
322
+ framework_image_tag = "1.0.0"
323
+ checkpoint_path = "fake-checkpoint"
324
+ final_weights_folder = "test_folder"
325
+
326
+ subcommands = llm_bundle_use_case .load_model_weights_sub_commands (
327
+ framework , framework_image_tag , checkpoint_path , final_weights_folder
328
+ )
329
+
330
+ expected_result = [
331
+ "s5cmd > /dev/null || conda install -c conda-forge -y s5cmd" ,
332
+ "s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder" ,
333
+ ]
334
+ assert expected_result == subcommands
335
+
336
+
287
337
@pytest .mark .asyncio
288
338
async def test_create_model_endpoint_trt_llm_use_case_success (
289
339
test_api_key : str ,
0 commit comments