
Commit 0797702

Some batch inference improvements (#460)
* Some batch inference improvements
* fix unit test
* coverage
* integration test
* fix
1 parent 06563a1 commit 0797702

6 files changed (+128, -12 lines)

docs/guides/completions.md

Lines changed: 49 additions & 0 deletions

@@ -120,6 +120,55 @@ async def main():
 asyncio.run(main())
 ```
 
+## Batch completions
+
+The Python client also supports batch completions. Batch completions distribute data across multiple workers to accelerate inference and are tuned to maximize throughput, so they typically finish substantially faster than sending individual requests over HTTP. Use [Completion.batch_create](../../api/python_client/#llmengine.completion.Completion.batch_create) to run batch completions.
+
+Some examples of batch completions:
+
+=== "Batch completions with prompts in the request"
+    ```python
+    from llmengine import Completion
+    from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent
+
+    content = CreateBatchCompletionsRequestContent(
+        prompts=["What is deep learning", "What is a neural network"],
+        max_new_tokens=10,
+        temperature=0.0
+    )
+
+    response = Completion.batch_create(
+        output_data_path="s3://my-path",
+        model_config=CreateBatchCompletionsModelConfig(
+            model="llama-2-7b",
+            checkpoint_path="s3://checkpoint-path",
+            labels={"team":"my-team", "product":"my-product"}
+        ),
+        content=content
+    )
+    print(response.job_id)
+    ```
+
+=== "Batch completions with prompts in a file and with 2 parallel jobs"
+    ```python
+    from llmengine import Completion
+    from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent
+
+    # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path"
+
+    response = Completion.batch_create(
+        input_data_path="s3://my-input-path",
+        output_data_path="s3://my-output-path",
+        model_config=CreateBatchCompletionsModelConfig(
+            model="llama-2-7b",
+            checkpoint_path="s3://checkpoint-path",
+            labels={"team":"my-team", "product":"my-product"}
+        ),
+        data_parallelism=2
+    )
+    print(response.job_id)
+    ```
+
 ## Which model should I use?
 
 See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions.
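
The second tab's example assumes the prompts have already been written to `s3://my-input-path`, which the commit does not show. A minimal sketch of producing that file, assuming the batch job expects the JSON serialization of `CreateBatchCompletionsRequestContent` at `input_data_path` and that `boto3` credentials with write access to the bucket are available (bucket and key names are illustrative):

```python
import boto3

from llmengine.data_types import CreateBatchCompletionsRequestContent

# Build the same request content as in the in-request example above.
content = CreateBatchCompletionsRequestContent(
    prompts=["What is deep learning", "What is a neural network"],
    max_new_tokens=10,
    temperature=0.0,
)

# Assumption: the batch job reads JSON-serialized content from input_data_path,
# and CreateBatchCompletionsRequestContent is a pydantic model (hence .json()).
# "my-bucket" and "batch/input.json" are placeholders; the resulting
# s3://my-bucket/batch/input.json would be passed as input_data_path.
s3 = boto3.client("s3")
s3.put_object(
    Bucket="my-bucket",
    Key="batch/input.json",
    Body=content.json().encode("utf-8"),
)
```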

model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 FROM nvcr.io/nvidia/pytorch:23.09-py3
 
 RUN apt-get update && \
-    apt-get install -y dumb-init && \
+    apt-get install -y dumb-init psmisc && \
     apt-get autoremove -y && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean
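
The only change here is installing `psmisc`, which provides the `fuser` binary that the new startup check in `vllm_batch.py` shells out to. A small illustrative check (not part of the commit) that could be run inside the built image to confirm the binary is present:

```python
import shutil
import subprocess

# fuser is provided by the psmisc package added in this Dockerfile change.
fuser_path = shutil.which("fuser")
if fuser_path is None:
    raise RuntimeError("fuser not found; is psmisc installed in the image?")

# Smoke test: print the fuser version (some builds write -V output to stderr,
# so capture both streams).
result = subprocess.run(["fuser", "-V"], capture_output=True, text=True)
print(result.stdout or result.stderr)
```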

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 30 additions & 0 deletions

@@ -33,6 +33,7 @@ def download_model(checkpoint_path, final_weights_folder):
     # Need to override these env vars so s5cmd uses AWS_PROFILE
     env["AWS_ROLE_ARN"] = ""
     env["AWS_WEB_IDENTITY_TOKEN_FILE"] = ""
+    # nosemgrep
     process = subprocess.Popen(
         s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env
     )
@@ -193,6 +194,7 @@ async def generate_with_vllm(request, content, model, job_index):
         tensor_parallel_size=request.model_config.num_shards,
         seed=request.model_config.seed or 0,
         disable_log_requests=True,
+        gpu_memory_utilization=0.8,  # To avoid OOM errors when there's host machine GPU usage
     )
 
     llm = AsyncLLMEngine.from_engine_args(engine_args)
@@ -220,5 +222,33 @@ async def generate_with_vllm(request, content, model, job_index):
     return results_generators
 
 
+def get_gpu_free_memory():  # pragma: no cover
+    """Get GPU free memory using nvidia-smi."""
+    try:
+        output = subprocess.check_output(
+            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"]
+        ).decode("utf-8")
+        gpu_memory = [int(x) for x in output.strip().split("\n")]
+        return gpu_memory
+    except subprocess.CalledProcessError:
+        return None
+
+
+def check_unknown_startup_memory_usage():  # pragma: no cover
+    """Check for unknown memory usage at startup."""
+    gpu_free_memory = get_gpu_free_memory()
+    if gpu_free_memory is not None:
+        min_mem = min(gpu_free_memory)
+        max_mem = max(gpu_free_memory)
+        if max_mem - min_mem > 10:
+            print(
+                f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}."
+            )
+            # nosemgrep
+            output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8")
+            print(f"Processes using GPU: {output}")
+
+
 if __name__ == "__main__":
+    check_unknown_startup_memory_usage()
     asyncio.run(batch_inference())
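
The new `check_unknown_startup_memory_usage` helper flags the job when per-GPU free memory differs by more than 10 MB at startup, on the assumption that a healthy idle node reports nearly identical free memory on every GPU. A standalone illustration of that threshold logic (not the module itself, which needs `nvidia-smi` on a GPU host; the readings below are made-up example values in MB):

```python
# Illustration of the 10 MB imbalance threshold used by
# check_unknown_startup_memory_usage.
def is_unbalanced(gpu_free_memory_mb, threshold_mb=10):
    """Return True when the spread of free memory across GPUs exceeds the threshold."""
    return max(gpu_free_memory_mb) - min(gpu_free_memory_mb) > threshold_mb


balanced = [81000, 81000, 80995, 81000]    # within 10 MB of each other -> OK
unbalanced = [81000, 81000, 74000, 81000]  # one GPU already has ~7 GB in use -> warn

assert not is_unbalanced(balanced)
assert is_unbalanced(unbalanced)
print("balanced:", is_unbalanced(balanced), "| unbalanced:", is_unbalanced(unbalanced))
```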

model-engine/model_engine_server/inference/vllm/vllm_server.py

Lines changed: 1 addition & 0 deletions

@@ -114,6 +114,7 @@ def check_unknown_startup_memory_usage():
             print(
                 f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}."
             )
+            # nosemgrep
             output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8")
             print(f"Processes using GPU: {output}")
 

model-engine/model_engine_server/infra/services/image_cache_service.py

Lines changed: 27 additions & 6 deletions

@@ -3,7 +3,7 @@
 
 import pytz
 from model_engine_server.common.config import hmi_config
-from model_engine_server.common.env_vars import GIT_TAG
+from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG
 from model_engine_server.core.config import infra_config
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState
@@ -69,17 +69,38 @@ def _cache_finetune_llm_images(
         )
 
         istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0")
-        tgi_image = DockerImage(
-            f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3"
+        tgi_image_110 = DockerImage(
+            f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "1.1.0"
         )
-        tgi_image_2 = DockerImage(
-            f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.4"
+        vllm_image_027 = DockerImage(
+            f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.2.7"
+        )
+        vllm_image_032 = DockerImage(
+            f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2"
+        )
+        latest_tag = (
+            self.docker_repository.get_latest_image_tag(
+                f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}"
+            )
+            if not CIRCLECI
+            else "fake_docker_repository_latest_image_tag"
+        )
+        vllm_batch_image_latest = DockerImage(
+            f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}",
+            latest_tag,
         )
         forwarder_image = DockerImage(
             f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG
         )
 
-        for llm_image in [istio_image, tgi_image, tgi_image_2, forwarder_image]:
+        for llm_image in [
+            istio_image,
+            tgi_image_110,
+            vllm_image_027,
+            vllm_image_032,
+            vllm_batch_image_latest,
+            forwarder_image,
+        ]:
            if self.docker_repository.is_repo_name(
                llm_image.repo
            ) and not self.docker_repository.image_exists(llm_image.tag, llm_image.repo):
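
The cache list now includes the batch-inference vLLM image, whose tag is resolved at runtime from the Docker repository rather than hard-coded; on CircleCI the lookup is skipped and a fixed fake tag is substituted so unit tests stay hermetic. A stripped-down sketch of that fallback pattern, using a hypothetical `get_latest_image_tag` stand-in rather than the real `DockerRepository` method:

```python
import os
from typing import Callable

def resolve_latest_tag(
    repo: str,
    get_latest_image_tag: Callable[[str], str],
    in_ci: bool = bool(os.getenv("CIRCLECI")),
) -> str:
    """Return the registry's latest tag, or a deterministic fake tag when running in CI."""
    if in_ci:
        return "fake_docker_repository_latest_image_tag"
    # Hypothetical stand-in: the real method queries the container registry
    # for the newest tag of the given repository.
    return get_latest_image_tag(repo)


# Example usage with a fake registry lookup (repo name is illustrative).
print(resolve_latest_tag("my-prefix/batch-infer-vllm", lambda repo: "0.9.9", in_ci=False))
print(resolve_latest_tag("my-prefix/batch-infer-vllm", lambda repo: "0.9.9", in_ci=True))
```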

model-engine/tests/unit/infra/services/test_image_cache_service.py

Lines changed: 20 additions & 5 deletions

@@ -52,14 +52,29 @@ async def test_caching_finetune_llm_images(
     gateway: Any = fake_image_cache_service.image_cache_gateway
 
     istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0")
-    tgi_image = DockerImage(
-        f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3"
+    tgi_image_110 = DockerImage(
+        f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "1.1.0"
     )
-    tgi_image_2 = DockerImage(
-        f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.4"
+    vllm_image_027 = DockerImage(
+        f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.2.7"
+    )
+    vllm_image_032 = DockerImage(
+        f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2"
+    )
+    latest_tag = "fake_docker_repository_latest_image_tag"
+    vllm_batch_image_latest = DockerImage(
+        f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}",
+        latest_tag,
     )
     forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG)
 
     for key in ["a10", "a100"]:
-        for llm_image in [istio_image, tgi_image, tgi_image_2, forwarder_image]:
+        for llm_image in [
+            istio_image,
+            tgi_image_110,
+            vllm_image_027,
+            vllm_image_032,
+            vllm_batch_image_latest,
+            forwarder_image,
+        ]:
             assert f"{llm_image.repo}:{llm_image.tag}" in gateway.cached_images[key]
