Skip to content

Commit a99baea

Browse files
Fix vllm batch docker image (#463)
* Fix vllm batch docker image * try again with 0.2.5
1 parent 4a37de4 commit a99baea

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@ RUN apt-get update && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean
 
-RUN pip uninstall torch -y
 COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt
 RUN pip install -r requirements.txt
 
+RUN pip uninstall torch -y
+RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121
+
+RUN pip uninstall xformers -y
+RUN pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121
+
 RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz
 RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz
1520

model-engine/model_engine_server/inference/batch_inference/requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
-ray==2.6.3
-#git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm
 vllm==0.2.5
 pydantic==1.10.13
 boto3==1.34.15

model-engine/model_engine_server/inference/batch_inference/sample_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
     "labels": {"team": "my_team"}
   },
   "data_parallelism":2
-}
+}
(only change: a trailing newline was added at end of file)

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
 
 CONFIG_FILE = os.getenv("CONFIG_FILE")
 AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
+MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights")
 
 os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default")
 
@@ -118,15 +119,15 @@ async def batch_inference():
     request = CreateBatchCompletionsRequest.parse_file(CONFIG_FILE)
 
     if request.model_config.checkpoint_path is not None:
-        download_model(request.model_config.checkpoint_path, "./model_weights")
+        download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER)
 
     content = request.content
     if content is None:
         with smart_open.open(request.input_data_path, "r") as f:
             content = CreateBatchCompletionsRequestContent.parse_raw(f.read())
 
     model = (
-        "./model_weights" if request.model_config.checkpoint_path else request.model_config.model
+        MODEL_WEIGHTS_FOLDER if request.model_config.checkpoint_path else request.model_config.model
     )
 
     results_generators = await generate_with_vllm(request, content, model, job_index)

Comments (0)