Commit 69f8bcb

Batch inference client / doc (#424)
* batch inference client / doc
* fix
* fixes
1 parent d4be9b9 commit 69f8bcb

6 files changed: +242 -1 lines changed


clients/python/llmengine/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,10 @@
     CompletionStreamOutput,
     CompletionStreamResponse,
     CompletionSyncResponse,
+    CreateBatchCompletionsModelConfig,
+    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsRequestContent,
+    CreateBatchCompletionsResponse,
     CreateFineTuneRequest,
     CreateFineTuneResponse,
     DeleteFileResponse,
@@ -51,6 +55,10 @@
     "CompletionStreamOutput",
     "CompletionStreamResponse",
     "CompletionSyncResponse",
+    "CreateBatchCompletionsModelConfig",
+    "CreateBatchCompletionsRequest",
+    "CreateBatchCompletionsRequestContent",
+    "CreateBatchCompletionsResponse",
     "CreateFineTuneRequest",
     "CreateFineTuneResponse",
     "DeleteFileResponse",

clients/python/llmengine/completion.py

Lines changed: 98 additions & 0 deletions
@@ -6,9 +6,14 @@
     CompletionStreamV1Request,
     CompletionSyncResponse,
     CompletionSyncV1Request,
+    CreateBatchCompletionsModelConfig,
+    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsRequestContent,
+    CreateBatchCompletionsResponse,
 )
 
 COMPLETION_TIMEOUT = 300
+HTTP_TIMEOUT = 60
 
 
 class Completion(APIEngine):
@@ -397,3 +402,96 @@ def _create_stream(**kwargs):
                 timeout=timeout,
             )
             return CompletionSyncResponse.parse_obj(response)
+
+    @classmethod
+    def batch_create(
+        cls,
+        output_data_path: str,
+        model_config: CreateBatchCompletionsModelConfig,
+        content: Optional[CreateBatchCompletionsRequestContent] = None,
+        input_data_path: Optional[str] = None,
+        data_parallelism: int = 1,
+        max_runtime_sec: int = 24 * 3600,
+    ) -> CreateBatchCompletionsResponse:
+        """
+        Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint.
+
+        Prompts can be passed in from an input file, or as a part of the request.
+
+        Args:
+            output_data_path (str):
+                The path to the output file. The output file will be a JSON file containing the completions.
+
+            model_config (CreateBatchCompletionsModelConfig):
+                The model configuration to use for the batch completion.
+
+            content (Optional[CreateBatchCompletionsRequestContent]):
+                The content to use for the batch completion. Either `content` or `input_data_path` must be provided.
+
+            input_data_path (Optional[str]):
+                The path to the input file. The input file should be a JSON file with data of type `CreateBatchCompletionsRequestContent`. Either `content` or `input_data_path` must be provided.
+
+            data_parallelism (int):
+                The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1.
+
+            max_runtime_sec (int):
+                The maximum runtime of the batch completion in seconds. Defaults to 24 hours.
+
+        Returns:
+            response (CreateBatchCompletionsResponse): The response containing the job id.
+
+        === "Batch completions with prompts in the request"
+            ```python
+            from llmengine import Completion
+            from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent
+
+            response = Completion.batch_create(
+                output_data_path="s3://my-path",
+                model_config=CreateBatchCompletionsModelConfig(
+                    model="llama-2-7b",
+                    checkpoint_path="s3://checkpoint-path",
+                    labels={"team":"my-team", "product":"my-product"}
+                ),
+                content=CreateBatchCompletionsRequestContent(
+                    prompts=["What is deep learning", "What is a neural network"],
+                    max_new_tokens=10,
+                    temperature=0.0
+                )
+            )
+            print(response.json())
+            ```
+
+        === "Batch completions with prompts in a file and with 2 parallel jobs"
+            ```python
+            from llmengine import Completion
+            from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent
+
+            # Store CreateBatchCompletionsRequestContent data into the input file "s3://my-input-path"
+
+            response = Completion.batch_create(
+                input_data_path="s3://my-input-path",
+                output_data_path="s3://my-output-path",
+                model_config=CreateBatchCompletionsModelConfig(
+                    model="llama-2-7b",
+                    checkpoint_path="s3://checkpoint-path",
+                    labels={"team":"my-team", "product":"my-product"}
+                ),
+                data_parallelism=2
+            )
+            print(response.json())
+            ```
+        """
+        data = CreateBatchCompletionsRequest(
+            model_config=model_config,
+            content=content,
+            input_data_path=input_data_path,
+            output_data_path=output_data_path,
+            data_parallelism=data_parallelism,
+            max_runtime_sec=max_runtime_sec,
+        ).dict()
+        response = cls.post_sync(
+            resource_name="v1/llm/batch-completions",
+            data=data,
+            timeout=HTTP_TIMEOUT,
+        )
+        return CreateBatchCompletionsResponse.parse_obj(response)
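
The docstring describes `input_data_path` as pointing to a JSON file with data of type `CreateBatchCompletionsRequestContent`, but producing that file is left to the caller. Below is a minimal sketch of one way to do it, assuming `boto3` is available and that the file is simply the JSON serialization of the content model; the bucket and key names are placeholders, not part of this commit:

```python
import boto3  # assumption: boto3 is installed; any S3-compatible client would do
from llmengine import CreateBatchCompletionsRequestContent

# Build the same content object that could otherwise be passed inline via `content=...`.
content = CreateBatchCompletionsRequestContent(
    prompts=["What is deep learning", "What is a neural network"],
    max_new_tokens=10,
    temperature=0.0,
)

# Serialize to JSON and upload to the location later passed as `input_data_path`.
boto3.client("s3").put_object(
    Bucket="my-input-bucket",        # placeholder bucket
    Key="batch-inputs/input.json",   # placeholder key
    Body=content.json().encode("utf-8"),
)

# input_data_path would then be "s3://my-input-bucket/batch-inputs/input.json"
```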

clients/python/llmengine/data_types.py

Lines changed: 96 additions & 0 deletions
@@ -591,3 +591,99 @@ class GetFileContentResponse(BaseModel):
 
     content: str = Field(..., description="File content.")
     """File content."""
+
+
+class CreateBatchCompletionsRequestContent(BaseModel):
+    prompts: List[str]
+    max_new_tokens: int
+    temperature: float = Field(ge=0.0, le=1.0)
+    """
+    Temperature of the sampling. Setting it to 0 is equivalent to greedy sampling.
+    """
+    stop_sequences: Optional[List[str]] = None
+    """
+    List of sequences to stop the completion at.
+    """
+    return_token_log_probs: Optional[bool] = False
+    """
+    Whether to return the log probabilities of the tokens.
+    """
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm.
+    Penalizes new tokens based on whether they appear in the text so far. 0.0 means no penalty.
+    """
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm.
+    Penalizes new tokens based on their existing frequency in the text so far. 0.0 means no penalty.
+    """
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    """
+    Controls the number of top tokens to consider. -1 means consider all tokens.
+    """
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    """
+    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
+    """
+
+
+class CreateBatchCompletionsModelConfig(BaseModel):
+    model: str
+    checkpoint_path: Optional[str] = None
+    """
+    Path to the checkpoint to load the model from.
+    """
+    labels: Dict[str, str]
+    """
+    Labels to attach to the batch inference job.
+    """
+    num_shards: Optional[int] = 1
+    """
+    Suggested number of shards to distribute the model across. When not specified, the number of shards is inferred from the model config.
+    The system may decide to use a different number than the given value.
+    """
+    quantize: Optional[Quantization] = None
+    """
+    Whether to quantize the model.
+    """
+    seed: Optional[int] = None
+    """
+    Random seed for the model.
+    """
+
+
+class CreateBatchCompletionsRequest(BaseModel):
+    """
+    Request object for batch completions.
+    """
+
+    input_data_path: Optional[str]
+    output_data_path: str
+    """
+    Path to the output file. The output file will be a JSON file of type List[CompletionOutput].
+    """
+    content: Optional[CreateBatchCompletionsRequestContent] = None
+    """
+    Either `input_data_path` or `content` needs to be provided.
+    When `input_data_path` is provided, the input file should be a JSON file of type CreateBatchCompletionsRequestContent.
+    """
+    model_config: CreateBatchCompletionsModelConfig
+    """
+    Model configuration for the batch inference. Hardware configurations are inferred.
+    """
+    data_parallelism: Optional[int] = Field(default=1, ge=1, le=64)
+    """
+    Number of replicas to run the batch inference. More replicas are slower to schedule but faster at inference.
+    """
+    max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600)
+    """
+    Maximum runtime of the batch inference in seconds. Defaults to one day.
+    """
+
+
+class CreateBatchCompletionsResponse(BaseModel):
+    job_id: str
+    """
+    The ID of the batch completions job.
+    """

docs/api/data_types.md

Lines changed: 38 additions & 0 deletions
@@ -110,3 +110,41 @@
     options:
         members:
             - deleted
+
+::: llmengine.CreateBatchCompletionsRequestContent
+    options:
+        members:
+            - prompts
+            - max_new_tokens
+            - temperature
+            - stop_sequences
+            - return_token_log_probs
+            - presence_penalty
+            - frequency_penalty
+            - top_k
+            - top_p
+
+::: llmengine.CreateBatchCompletionsModelConfig
+    options:
+        members:
+            - model
+            - checkpoint_path
+            - labels
+            - num_shards
+            - quantize
+            - seed
+
+::: llmengine.CreateBatchCompletionsRequest
+    options:
+        members:
+            - input_data_path
+            - output_data_path
+            - content
+            - model_config
+            - data_parallelism
+            - max_runtime_sec
+
+::: llmengine.CreateBatchCompletionsResponse
+    options:
+        members:
+            - job_id

docs/api/python_client.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
         members:
             - create
            - acreate
+            - batch_create
 
 ::: llmengine.FineTune
     options:

docs/contributing.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ pip install -r requirements-docs.txt
 Our Python client API reference is autogenerated from our client. You can install the client in editable mode with
 
 ```
-pip install -r clients/python
+pip install -e clients/python
 ```
 
 ### Step 4: Run Locally
