Skip to content

Commit 6a04cca

Browse files
Add tool completion to batch inference (#461)
* test * Impl * remove logging * lints * fix tests * fix * fix * try fix unit test * fix * no cover * stop sequences
1 parent a99baea commit 6a04cca

File tree

14 files changed

+985
-84
lines changed

14 files changed

+985
-84
lines changed

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,30 @@ class CreateBatchCompletionsModelConfig(BaseModel):
429429
"""
430430

431431

432+
class ToolConfig(BaseModel):
433+
"""
434+
Configuration for tool use.
435+
NOTE: This config is highly experimental and its signature will change significantly in future iterations.
436+
"""
437+
438+
name: str
439+
"""
440+
Name of the tool to use for the batch inference.
441+
"""
442+
max_iterations: Optional[int] = 10
443+
"""
444+
Maximum number of iterations to run the tool.
445+
"""
446+
execution_timeout_seconds: Optional[int] = 60
447+
"""
448+
Maximum runtime of the tool in seconds.
449+
"""
450+
should_retry_on_error: Optional[bool] = True
451+
"""
452+
Whether to retry the tool on error.
453+
"""
454+
455+
432456
class CreateBatchCompletionsRequest(BaseModel):
433457
"""
434458
Request object for batch completions.
@@ -456,6 +480,11 @@ class CreateBatchCompletionsRequest(BaseModel):
456480
"""
457481
Maximum runtime of the batch inference in seconds. Defaults to one day.
458482
"""
483+
tool_config: Optional[ToolConfig] = None
484+
"""
485+
Configuration for tool use.
486+
NOTE: This config is highly experimental and its signature will change significantly in future iterations.
487+
"""
459488

460489

461490
class CreateBatchCompletionsResponse(BaseModel):

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,6 +2275,11 @@ async def execute(
22752275
hardware.gpus = max(hardware.gpus, request.model_config.num_shards)
22762276
request.model_config.num_shards = hardware.gpus
22772277

2278+
if request.tool_config and request.tool_config.name != "code_evaluator":
2279+
raise ObjectHasInvalidValueException(
2280+
"Only code_evaluator tool is supported for batch completions."
2281+
)
2282+
22782283
batch_bundle = await self.create_batch_job_bundle(user, request, hardware)
22792284

22802285
validate_resource_requests(
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import json
2+
3+
COMPLETION_PROMPT1 = """\
4+
FYI: you can write code like this:
5+
```python
6+
import math
7+
print(math.sqrt(2))
8+
```
9+
1.41...
10+
>>>
11+
12+
For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer.
13+
14+
### Problem:
15+
16+
What is the 4th digit of pi?
17+
18+
### Answer:
19+
```python
20+
import math
21+
print(math.pi)
22+
```
23+
3.141592653589793
24+
>>>
25+
26+
Final Answer: 1
27+
28+
### Problem:
29+
30+
What is the 4th digit of the square root of 2?
31+
32+
### Answer:
33+
"""
34+
35+
COMPLETION_PROMPT2 = """\
36+
FYI: you can write code like this:
37+
```python
38+
import math
39+
print(math.sqrt(2))
40+
```
41+
1.41...
42+
>>>
43+
44+
For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer.
45+
46+
### Problem:
47+
48+
What is the 4th digit of pi?
49+
50+
### Answer:
51+
```python
52+
import math
53+
print(math.pi)
54+
```
55+
3.141592653589793
56+
>>>
57+
58+
Final Answer: 1
59+
60+
### Problem:
61+
62+
What is the 5th digit of the square root of 2?
63+
64+
### Answer:
65+
"""
66+
67+
data = {
68+
"prompts": [
69+
COMPLETION_PROMPT1,
70+
COMPLETION_PROMPT2,
71+
"what is deep learning",
72+
],
73+
"max_new_tokens": 100,
74+
"temperature": 0.0,
75+
"return_token_log_probs": True,
76+
"stop_sequences": ["</s>", "\n### Problem:\n", ">>>\n"],
77+
}
78+
79+
# Write the sample request payload. The context manager guarantees the file
# is flushed and closed even if serialization raises, instead of relying on
# garbage collection of an anonymous file handle.
with open("sample_data_tool.json", "w") as f:
    json.dump(data, f)

model-engine/model_engine_server/inference/batch_inference/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ vllm==0.2.5
22
pydantic==1.10.13
33
boto3==1.34.15
44
smart-open==6.4.0
5-
ddtrace==2.4.0
5+
ddtrace==2.4.0
6+
docker==7.0.0
7+
func-timeout==4.3.5
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"input_data_path":"./sample_data_tool.json",
3+
"output_data_path":"./sample_output_tool.json",
4+
"model_config":{
5+
"model":"mistral-7b",
6+
"checkpoint_path":"s3://scale-ml/models/mistral-7b",
7+
"num_shards": 1,
8+
"labels": {"team": "my_team"}
9+
},
10+
"data_parallelism":2,
11+
"tool_config": {
12+
"name": "code_evaluator"
13+
}
14+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"prompts": [
3+
"FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 4th digit of the square root of 2?\n\n### Answer: \n",
4+
"FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 5th digit of the square root of 2?\n\n### Answer: \n",
5+
"what is deep learning"
6+
],
7+
"max_new_tokens": 100,
8+
"temperature": 0.0,
9+
"return_token_log_probs": true,
10+
"stop_sequences": [
11+
"</s>",
12+
"\n### Problem:\n",
13+
">>>\n"
14+
]
15+
}

0 commit comments

Comments
 (0)