Commit 242d62b

LLM batch completions API (#418)
Squashed commits:

* wip
* wip
* wip
* batch run files
* wip
* fix
* nonindex
* fixes
* log
* fix vllm version
* fix config
* aws profile
* add dumb-init and ddtrace
* use batch to start
* fix path
* delete temp s3 files
* fix unit test
* Add unit tests
* mypy
* fix tests
* comments
1 parent (4a3424f) · commit 242d62b

28 files changed (+1114 −14 lines)

charts/model-engine/templates/service_template_config_map.yaml

Lines changed: 3 additions & 0 deletions

@@ -615,6 +615,9 @@ data:
       backoffLimit: 0
       activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME}
       ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED}
+      completions: ${BATCH_JOB_NUM_WORKERS}
+      parallelism: ${BATCH_JOB_NUM_WORKERS}
+      completionMode: "Indexed"
       template:
         metadata:
           labels:
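
The batch Job template now fans out over multiple pods: both completions and parallelism are driven by BATCH_JOB_NUM_WORKERS, and the Job runs in Indexed completion mode so every pod receives a stable completion index. Below is a minimal Python sketch of that substitution; it is illustrative only, not the chart's rendering code, and apply_worker_count is a hypothetical helper.

```python
# Illustrative sketch (not code from this chart or repo): the effect of the three
# new template fields on a rendered Job spec. Both completions and parallelism
# take the worker count, and the Job switches to Indexed completion mode.
from typing import Any, Dict


def apply_worker_count(job_spec: Dict[str, Any], num_workers: int = 1) -> Dict[str, Any]:
    rendered = dict(job_spec)
    rendered["completions"] = num_workers   # ${BATCH_JOB_NUM_WORKERS}
    rendered["parallelism"] = num_workers   # ${BATCH_JOB_NUM_WORKERS}
    rendered["completionMode"] = "Indexed"
    return rendered


print(apply_worker_count({"backoffLimit": 0, "activeDeadlineSeconds": 86400}, num_workers=4))
```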

charts/model-engine/values_circleci.yaml

Lines changed: 1 addition & 0 deletions

@@ -151,6 +151,7 @@ config:
   vllm_repository: "vllm"
   lightllm_repository: "lightllm"
   tensorrt_llm_repository: "tensorrt-llm"
+  batch_inference_vllm_repository: "llm-engine/batch-infer-vllm"
   user_inference_base_repository: "launch/inference"
   user_inference_pytorch_repository: "hosted-model-inference/async-pytorch"
   user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu"

charts/model-engine/values_sample.yaml

Lines changed: 1 addition & 0 deletions

@@ -207,6 +207,7 @@ config:
   vllm_repository: "vllm"
   lightllm_repository: "lightllm"
   tensorrt_llm_repository: "tensorrt-llm"
+  batch_inference_vllm_repository: "llm-engine/batch-infer-vllm"
   user_inference_base_repository: "launch/inference"
   user_inference_pytorch_repository: "launch/inference/pytorch"
   user_inference_tensorflow_repository: "launch/inference/tf"

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 26 additions & 0 deletions

@@ -19,6 +19,8 @@
     CompletionStreamV1Response,
     CompletionSyncV1Request,
     CompletionSyncV1Response,
+    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsResponse,
     CreateFineTuneRequest,
     CreateFineTuneResponse,
     CreateLLMModelEndpointV1Request,
@@ -73,6 +75,7 @@
 from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import (
     CompletionStreamV1UseCase,
     CompletionSyncV1UseCase,
+    CreateBatchCompletionsUseCase,
     CreateLLMModelBundleV1UseCase,
     CreateLLMModelEndpointV1UseCase,
     DeleteLLMEndpointByNameUseCase,
@@ -568,3 +571,26 @@ async def delete_llm_model_endpoint(
             status_code=500,
             detail="deletion of endpoint failed.",
         ) from exc
+
+
+@llm_router_v1.post("/batch-completions", response_model=CreateBatchCompletionsResponse)
+async def create_batch_completions(
+    request: CreateBatchCompletionsRequest,
+    auth: User = Depends(verify_authentication),
+    external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
+) -> CreateBatchCompletionsResponse:
+    logger.info(f"POST /batch-completions with {request} for {auth}")
+    try:
+        use_case = CreateBatchCompletionsUseCase(
+            docker_image_batch_job_gateway=external_interfaces.docker_image_batch_job_gateway,
+            docker_repository=external_interfaces.docker_repository,
+            docker_image_batch_job_bundle_repo=external_interfaces.docker_image_batch_job_bundle_repository,
+        )
+        return await use_case.execute(user=auth, request=request)
+    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
+        raise HTTPException(
+            status_code=404,
+            detail="The specified endpoint could not be found.",
+        ) from exc
+    except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+        raise HTTPException(status_code=400, detail=str(exc))
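
The new route accepts a CreateBatchCompletionsRequest, hands it to CreateBatchCompletionsUseCase, and returns a job_id. A hedged client sketch follows; the host, the "/v1/llm" path prefix, and the basic-auth style are assumptions not shown in this diff, and the model name, labels, and S3 paths are placeholders.

```python
# Hedged client sketch: posts a CreateBatchCompletionsRequest-shaped payload to
# the new route. Only the "/batch-completions" suffix and the payload fields come
# from this diff; host, "/v1/llm" prefix, and auth scheme are assumptions.
import requests

payload = {
    "output_data_path": "s3://my-bucket/batch-output.json",  # placeholder path
    "content": {
        "prompts": ["What is machine learning?"],
        "max_new_tokens": 128,
        "temperature": 0.0,
    },
    "model_config": {
        "model": "llama-2-7b",          # placeholder model name
        "labels": {"team": "my-team"},  # placeholder labels
    },
    "data_parallelism": 2,
}

resp = requests.post(
    "https://my-llm-engine-host/v1/llm/batch-completions",  # assumed host + prefix
    json=payload,
    auth=("my-api-key", ""),  # assumed auth style; adapt to your deployment
    timeout=30,
)
print(resp.json())  # expected shape: {"job_id": "..."}
```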

model-engine/model_engine_server/common/config.py

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ class HostedModelInferenceServiceConfig:
     vllm_repository: str
     lightllm_repository: str
     tensorrt_llm_repository: str
+    batch_inference_vllm_repository: str
     user_inference_base_repository: str
     user_inference_pytorch_repository: str
     user_inference_tensorflow_repository: str
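
batch_inference_vllm_repository names the Docker repository that holds the vLLM batch-inference image. Purely as an illustration of how such a repository-name setting is typically used (the helper below is hypothetical, not code from this repo), it combines with a registry and an image tag to form the full image reference.

```python
# Hypothetical helper, not from this repo: compose a full Docker image reference
# from a registry, a repository name such as batch_inference_vllm_repository,
# and a tag. The registry hostname below is a placeholder ECR-style value.
def image_uri(registry: str, repository: str, tag: str) -> str:
    return f"{registry}/{repository}:{tag}"


print(image_uri(
    "000000000000.dkr.ecr.us-west-2.amazonaws.com",
    "llm-engine/batch-infer-vllm",
    "latest",
))
```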

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 110 additions & 9 deletions

@@ -30,21 +30,21 @@ class CreateLLMModelEndpointV1Request(BaseModel):
     # LLM specific fields
     model_name: str
     source: LLMSource = LLMSource.HUGGING_FACE
-    inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED
-    inference_framework_image_tag: str
+    inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM
+    inference_framework_image_tag: str = "latest"
     num_shards: int = 1
     """
-    Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models
+    Number of shards to distribute the model onto GPUs.
     """

     quantize: Optional[Quantization] = None
     """
-    Whether to quantize the model. Only affect behavior for text-generation-inference models
+    Whether to quantize the model.
     """

     checkpoint_path: Optional[str] = None
     """
-    Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models
+    Path to the checkpoint to load the model from.
     """

     # General endpoint fields
@@ -102,17 +102,17 @@ class UpdateLLMModelEndpointV1Request(BaseModel):
     inference_framework_image_tag: Optional[str]
     num_shards: Optional[int]
     """
-    Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models
+    Number of shards to distribute the model onto GPUs.
     """

     quantize: Optional[Quantization]
     """
-    Whether to quantize the model. Only affect behavior for text-generation-inference models
+    Whether to quantize the model.
     """

     checkpoint_path: Optional[str]
     """
-    Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models
+    Path to the checkpoint to load the model from.
     """

     # General endpoint fields
@@ -220,7 +220,7 @@ class CompletionStreamV1Request(BaseModel):
     """
     return_token_log_probs: Optional[bool] = False
     """
-    Whether to return the log probabilities of the tokens. Only affects behavior for text-generation-inference models
+    Whether to return the log probabilities of the tokens.
     """
     presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
     """
@@ -359,3 +359,104 @@ class ModelDownloadResponse(BaseModel):

 class DeleteLLMEndpointResponse(BaseModel):
     deleted: bool
+
+
+class CreateBatchCompletionsRequestContent(BaseModel):
+    prompts: List[str]
+    max_new_tokens: int
+    temperature: float = Field(ge=0.0, le=1.0)
+    """
+    Temperature of the sampling. Setting to 0 equals to greedy sampling.
+    """
+    stop_sequences: Optional[List[str]] = None
+    """
+    List of sequences to stop the completion at.
+    """
+    return_token_log_probs: Optional[bool] = False
+    """
+    Whether to return the log probabilities of the tokens.
+    """
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm
+    Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty
+    """
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm
+    Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty
+    """
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    """
+    Controls the number of top tokens to consider. -1 means consider all tokens.
+    """
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    """
+    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
+    """
+
+
+class CreateBatchCompletionsModelConfig(BaseModel):
+    model: str
+    checkpoint_path: Optional[str] = None
+    """
+    Path to the checkpoint to load the model from.
+    """
+    labels: Dict[str, str]
+    """
+    Labels to attach to the batch inference job.
+    """
+    num_shards: Optional[int] = 1
+    """
+    Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config.
+    System may decide to use a different number than the given value.
+    """
+    quantize: Optional[Quantization] = None
+    """
+    Whether to quantize the model.
+    """
+    seed: Optional[int] = None
+    """
+    Random seed for the model.
+    """
+
+
+class CreateBatchCompletionsRequest(BaseModel):
+    """
+    Request object for batch completions.
+    """
+
+    input_data_path: Optional[str]
+    output_data_path: str
+    """
+    Path to the output file. The output file will be a JSON file of type List[CompletionOutput].
+    """
+    content: Optional[CreateBatchCompletionsRequestContent] = None
+    """
+    Either `input_data_path` or `content` needs to be provided.
+    When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent.
+    """
+    model_config: CreateBatchCompletionsModelConfig
+    """
+    Model configuration for the batch inference. Hardware configurations are inferred.
+    """
+    data_parallelism: Optional[int] = Field(default=1, ge=1, le=64)
+    """
+    Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference.
+    """
+    max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600)
+    """
+    Maximum runtime of the batch inference in seconds. Default to one day.
+    """
+
+
+class CreateBatchCompletionsResponse(BaseModel):
+    job_id: str
+
+
+class GetBatchCompletionsResponse(BaseModel):
+    progress: float
+    """
+    Progress of the batch inference in percentage from 0 to 100.
+    """
+    finished: bool
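
For context, here is a sketch of building the new request DTOs defined above. It assumes the model-engine package is importable and uses pydantic v1 semantics (Optional fields without an explicit default are treated as None); the model name, checkpoint path, labels, and S3 path are placeholders.

```python
# Sketch: constructing the new batch-completions DTOs (pydantic v1 assumed).
from model_engine_server.common.dtos.llms import (
    CreateBatchCompletionsModelConfig,
    CreateBatchCompletionsRequest,
    CreateBatchCompletionsRequestContent,
)

request = CreateBatchCompletionsRequest(
    output_data_path="s3://my-bucket/output.json",  # placeholder output location
    # Inline content; alternatively set input_data_path to a JSON file of
    # CreateBatchCompletionsRequestContent and leave content unset.
    content=CreateBatchCompletionsRequestContent(
        prompts=["Summarize this document.", "Translate to French: hello"],
        max_new_tokens=64,
        temperature=0.2,
    ),
    model_config=CreateBatchCompletionsModelConfig(
        model="llama-2-7b",                           # placeholder model name
        checkpoint_path="s3://my-bucket/llama-2-7b",  # placeholder checkpoint
        labels={"team": "my-team"},
        num_shards=2,
    ),
    data_parallelism=4,
)
print(request.json())
```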

model-engine/model_engine_server/domain/entities/batch_job_entity.py

Lines changed: 1 addition & 0 deletions

@@ -61,3 +61,4 @@ class DockerImageBatchJob(BaseModel):
     status: BatchJobStatus  # the status map relatively nicely onto BatchJobStatus
     annotations: Optional[Dict[str, str]] = None
     override_job_max_runtime_s: Optional[int] = None
+    num_workers: Optional[int] = 1

model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py

Lines changed: 3 additions & 0 deletions

@@ -26,6 +26,7 @@ async def create_docker_image_batch_job(
         mount_location: Optional[str],
         annotations: Optional[Dict[str, str]] = None,
         override_job_max_runtime_s: Optional[int] = None,
+        num_workers: Optional[int] = 1,
     ) -> str:
         """
         Create a docker image batch job
@@ -42,6 +43,8 @@ async def create_docker_image_batch_job(
             annotations: K8s annotations
             resource_requests: The resource requests for the batch job.
             mount_location: Location on filesystem where runtime-provided file contents get mounted
+            override_job_max_runtime_s: Optional override for the maximum runtime of the job
+            num_workers: num of pods to run in this job. Coordination needs to happen between the workers.

         Returns:
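
Because the underlying Job now runs in Indexed completion mode (see the chart change above), each of the num_workers pods can derive its slice of the work from the JOB_COMPLETION_INDEX environment variable that Kubernetes injects into Indexed Job pods. The sketch below shows one way such coordination could look; it is not the repo's actual batch worker, and NUM_WORKERS is a hypothetical env var.

```python
# Minimal sketch (not the repo's batch worker): shard prompts across the pods of
# an Indexed Job. Kubernetes sets JOB_COMPLETION_INDEX on each pod of an Indexed
# Job; NUM_WORKERS is a hypothetical env var carrying the worker count.
import os
from typing import List


def my_shard(prompts: List[str]) -> List[str]:
    index = int(os.environ.get("JOB_COMPLETION_INDEX", "0"))
    num_workers = int(os.environ.get("NUM_WORKERS", "1"))
    # Contiguous chunks; the last worker picks up any remainder.
    chunk = -(-len(prompts) // num_workers)  # ceiling division
    return prompts[index * chunk:(index + 1) * chunk]


if __name__ == "__main__":
    print(my_shard([f"prompt {i}" for i in range(10)]))
```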
