Commit 14093cc

Merge branch 'vllm-project:main' into large-block-size-fix
2 parents: 98bc88b + a931b4c · commit 14093cc

File tree

21 files changed: +799, -419 lines

csrc/attention/mla/sm100_cutlass_mla_kernel.cu

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,7 @@ limitations under the License.
  * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
  * by Alcanderian JieXin Liang
  */
+#include "core/registration.h"
 
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -270,4 +271,13 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
 }
 
 #endif
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
+  m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
+}
+
 // clang-format on

csrc/ops.h

Lines changed: 0 additions & 13 deletions
@@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
                         torch::Tensor const& seq_lens,
                         torch::Tensor const& page_table, double scale);
 
-void sm100_cutlass_mla_decode(
-    torch::Tensor const& out, torch::Tensor const& q_nope,
-    torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
-    torch::Tensor const& seq_lens, torch::Tensor const& page_table,
-    torch::Tensor const& workspace, double sm_scale,
-    int64_t num_kv_splits =
-        1 /* Set to 1 to avoid cuda_graph issue by default. */);
-
-int64_t sm100_cutlass_mla_get_workspace_size(
-    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
-    int64_t num_kv_splits =
-        1 /* Set to 1 to avoid cuda_graph issue by default. */);
-
 torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
 
 #ifndef USE_ROCM

csrc/torch_bindings.cpp

Lines changed: 2 additions & 3 deletions
@@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor page_table, Tensor workspace, float "
       "scale,"
      " int num_kv_splits) -> ()");
-  ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
+  // conditionally compiled so impl in source file
 
   // SM100 CUTLASS MLA workspace
   ops.def(
       "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
       " int sm_count, int num_kv_splits) "
       "-> int");
-  ops.impl("sm100_cutlass_mla_get_workspace_size",
-           &sm100_cutlass_mla_get_workspace_size);
+  // conditionally compiled so impl in source file
 
   // Compute NVFP4 block quantized tensor.
   ops.def(
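The two `// conditionally compiled so impl in source file` replacements mean the op schema (`ops.def`) stays in `torch_bindings.cpp`, while the implementation is only registered when the SM100 sources are actually built (via the `TORCH_LIBRARY_IMPL_EXPAND` blocks added above). A minimal Python-side sketch of what that implies for callers, assuming vLLM's usual `torch.ops._C` extension namespace; the namespace and helper name are assumptions for illustration, not something this diff states:

```python
import torch

def mla_workspace_size_or_none(max_seq_len: int, num_batches: int,
                               sm_count: int = 0, num_kv_splits: int = 1):
    """Return the SM100 MLA workspace size, or None if the op was compiled out."""
    # Assumed namespace: vLLM registers its custom ops under torch.ops._C.
    op = getattr(torch.ops._C, "sm100_cutlass_mla_get_workspace_size", None)
    if op is None:
        # The impl is only registered when the SM100 kernel sources were built in.
        return None
    return op(max_seq_len, num_batches, sm_count, num_kv_splits)

print(mla_workspace_size_or_none(max_seq_len=4096, num_batches=8))
```

Because only the registration moves, callers see an op that is either present or absent at runtime instead of a link error at build time.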

docs/assets/deployment/open_webui.png

-10.4 KB

Open WebUI deployment guide (Markdown)

Lines changed: 33 additions & 17 deletions
@@ -1,26 +1,42 @@
 # Open WebUI
 
-1. Install the [Docker](https://docs.docker.com/engine/install/)
+[Open WebUI](https://github.com/open-webui/open-webui) is an extensible, feature-rich,
+and user-friendly self-hosted AI platform designed to operate entirely offline.
+It supports various LLM runners like Ollama and OpenAI-compatible APIs,
+with built-in RAG capabilities, making it a powerful AI deployment solution.
 
-2. Start the vLLM server with the supported chat completion model, e.g.
+To get started with Open WebUI using vLLM, follow these steps:
 
-    ```bash
-    vllm serve qwen/Qwen1.5-0.5B-Chat
-    ```
+1. Install the [Docker](https://docs.docker.com/engine/install/).
 
-1. Start the [Open WebUI](https://github.com/open-webui/open-webui) docker container (replace the vllm serve host and vllm serve port):
+2. Start the vLLM server with a supported chat completion model:
 
-    ```bash
-    docker run -d -p 3000:8080 \
-    --name open-webui \
-    -v open-webui:/app/backend/data \
-    -e OPENAI_API_BASE_URL=http://<vllm serve host>:<vllm serve port>/v1 \
-    --restart always \
-    ghcr.io/open-webui/open-webui:main
-    ```
+    ```console
+    vllm serve Qwen/Qwen3-0.6B-Chat
+    ```
 
-1. Open it in the browser: <http://open-webui-host:3000/>
+    !!! note
+        When starting the vLLM server, be sure to specify the host and port using the `--host` and `--port` flags.
+        For example:
 
-    On the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`.
+        ```console
+        python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000
+        ```
 
-![](../../assets/deployment/open_webui.png)
+3. Start the Open WebUI Docker container:
+
+    ```console
+    docker run -d \
+    --name open-webui \
+    -p 3000:8080 \
+    -v open-webui:/app/backend/data \
+    -e OPENAI_API_BASE_URL=http://0.0.0.0:8000/v1 \
+    --restart always \
+    ghcr.io/open-webui/open-webui:main
+    ```
+
+4. Open it in the browser: <http://open-webui-host:3000/>
+
+    At the top of the page, you should see the model `Qwen/Qwen3-0.6B-Chat`.
+
+    ![Web portal of model Qwen/Qwen3-0.6B-Chat](../../assets/deployment/open_webui.png)
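A hedged sanity check to go with the rewritten guide above: before wiring Open WebUI to `OPENAI_API_BASE_URL`, you can confirm the vLLM server is reachable and lists the expected model. The host, port, and model name below are assumptions matching the example commands, not requirements.

```python
import requests

BASE_URL = "http://localhost:8000/v1"  # assumed address of `vllm serve`

# /v1/models is part of vLLM's OpenAI-compatible API and lists served models.
resp = requests.get(f"{BASE_URL}/models", timeout=5)
resp.raise_for_status()
served = [m["id"] for m in resp.json().get("data", [])]
print("Models served:", served)  # e.g. ['Qwen/Qwen3-0.6B-Chat']
```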

tests/distributed/test_pipeline_parallel.py

Lines changed: 17 additions & 7 deletions
@@ -14,8 +14,9 @@
 
 import pytest
 
-from vllm.config import TaskOption
+from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, TaskOption
 from vllm.logger import init_logger
+from vllm.transformers_utils.config import get_config
 
 from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import compare_two_settings, create_new_process_for_each_test
@@ -158,7 +159,7 @@ def iter_params(self, model_id: str):
     "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
     "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
     "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
-    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(),
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(tp_base=2),
     "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
     "tiiuae/falcon-7b": PPTestSettings.fast(),
     "google/gemma-1.1-2b-it": PPTestSettings.fast(),
@@ -210,9 +211,11 @@ def iter_params(self, model_id: str):
 
 EMBEDDING_MODELS = {  # type: ignore[var-annotated]
     # [Text-only]
-    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
-    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
-    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"),
+    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(task="embed"),
+    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(task="embed"),
+    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
+        load_format="dummy", task="embed"
+    ),
 }
 
 MULTIMODAL_MODELS = {
@@ -248,6 +251,7 @@ def iter_params(self, model_id: str):
     "meta-llama/Llama-3.2-1B-Instruct",
     "ArthurZ/Ilama-3.2-1B",
     "ibm/PowerLM-3b",
+    "deepseek-ai/DeepSeek-V2-Lite-Chat",
     # [LANGUAGE EMBEDDING]
     "intfloat/e5-mistral-7b-instruct",
     "BAAI/bge-multilingual-gemma2",
@@ -287,6 +291,11 @@ def _compare_tp(
     trust_remote_code = model_info.trust_remote_code
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
+    hf_config = get_config(model_id, trust_remote_code)
+
+    dtype = "float16"
+    if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
+        dtype = "bfloat16"
 
     if load_format == "dummy":
         # Avoid OOM
@@ -316,7 +325,7 @@
     common_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "float16",
+        dtype,
        "--max-model-len",
         "2048",
         "--max-num-seqs",
@@ -338,6 +347,7 @@
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
 
     specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
+    testing_ray_compiled_graph = False
     if distributed_backend == "ray" and (vllm_major_version == "1"
                                          or specific_case):
         # For V1, test Ray Compiled Graph for all the tests
@@ -351,6 +361,7 @@
         # Temporary. Currently when zeromq + SPMD is used, it does not properly
         # terminate because of a Ray Compiled Graph issue.
         common_args.append("--disable-frontend-multiprocessing")
+        testing_ray_compiled_graph = True
     elif distributed_backend == "mp":
         # Both V0/V1 of multiprocessing executor support PP
         pp_env = {
@@ -394,7 +405,6 @@
                              tp_env,
                              method=method)
     except Exception:
-        testing_ray_compiled_graph = pp_env is not None
         if testing_ray_compiled_graph and vllm_major_version == "0":
             # Ray Compiled Graph tests are flaky for V0,
             # so we don't want to fail the test
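The core of the `_compare_tp` change is a dtype fallback: the HF config is fetched once and float16 is swapped for bfloat16 whenever the architecture cannot run in half precision. A standalone sketch of that decision, with a hypothetical stand-in for `_FLOAT16_NOT_SUPPORTED_MODELS` (the real set lives in `vllm.config` and its contents are not shown in this diff):

```python
# Hypothetical stand-in for vllm.config._FLOAT16_NOT_SUPPORTED_MODELS;
# the entries here are illustrative, not the library's actual list.
FLOAT16_NOT_SUPPORTED_MODEL_TYPES = {"gemma2", "plamo2"}

def pick_ci_dtype(model_type: str) -> str:
    """Prefer float16 for CI speed, falling back to bfloat16 when unsupported."""
    if model_type in FLOAT16_NOT_SUPPORTED_MODEL_TYPES:
        return "bfloat16"
    return "float16"

common_args = ["--dtype", pick_ci_dtype("gemma2"), "--max-model-len", "2048"]
print(common_args)  # ['--dtype', 'bfloat16', '--max-model-len', '2048']
```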

tests/entrypoints/openai/test_tokenization.py

Lines changed: 104 additions & 0 deletions
@@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
         f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
+        "--enable-tokenizer-info-endpoint",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -283,3 +284,106 @@ async def test_detokenize(
     response.raise_for_status()
 
     assert response.json() == {"prompt": prompt}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_tokenizer_info_basic(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
+    """Test basic tokenizer info endpoint functionality."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    assert "tokenizer_class" in result
+    assert isinstance(result["tokenizer_class"], str)
+    assert result["tokenizer_class"]
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
+    """Test that the response matches expected schema types."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    field_types = {
+        "add_bos_token": bool,
+        "add_prefix_space": bool,
+        "clean_up_tokenization_spaces": bool,
+        "split_special_tokens": bool,
+        "bos_token": str,
+        "eos_token": str,
+        "pad_token": str,
+        "unk_token": str,
+        "chat_template": str,
+        "errors": str,
+        "model_max_length": int,
+        "additional_special_tokens": list,
+        "added_tokens_decoder": dict,
+    }
+    for field, expected_type in field_types.items():
+        if field in result and result[field] is not None:
+            assert isinstance(
+                result[field],
+                expected_type), (f"{field} should be {expected_type.__name__}")
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_added_tokens_structure(
+        server: RemoteOpenAIServer, ):
+    """Test added_tokens_decoder structure if present."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    added_tokens = result.get("added_tokens_decoder")
+    if added_tokens:
+        for token_id, token_info in added_tokens.items():
+            assert isinstance(token_id, str), "Token IDs should be strings"
+            assert isinstance(token_info, dict), "Token info should be a dict"
+            assert "content" in token_info, "Token info should have content"
+            assert "special" in token_info, (
+                "Token info should have special flag")
+            assert isinstance(token_info["special"],
+                              bool), ("Special flag should be boolean")
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_consistency_with_tokenize(
+        server: RemoteOpenAIServer, ):
+    """Test that tokenizer info is consistent with tokenization endpoint."""
+    info_response = requests.get(server.url_for("tokenizer_info"))
+    info_response.raise_for_status()
+    info = info_response.json()
+    tokenize_response = requests.post(
+        server.url_for("tokenize"),
+        json={
+            "model": MODEL_NAME,
+            "prompt": "Hello world!"
+        },
+    )
+    tokenize_response.raise_for_status()
+    tokenize_result = tokenize_response.json()
+    info_max_len = info.get("model_max_length")
+    tokenize_max_len = tokenize_result.get("max_model_len")
+    if info_max_len and tokenize_max_len:
+        assert info_max_len >= tokenize_max_len, (
+            "Info max length should be >= tokenize max length")
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
+    """Test chat template is properly included."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    chat_template = result.get("chat_template")
+    if chat_template:
+        assert isinstance(chat_template,
+                          str), ("Chat template should be a string")
+        assert chat_template.strip(), "Chat template should not be empty"
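For context on what these tests exercise: the route only exists when the server is launched with `--enable-tokenizer-info-endpoint`, which the fixture args above now pass. A minimal client-side sketch, assuming a server already running on `localhost:8000` (the address is an assumption; the fields printed mirror what the tests assert):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed; whatever host/port the server binds to

resp = requests.get(f"{BASE_URL}/tokenizer_info", timeout=5)
resp.raise_for_status()
info = resp.json()

print(info["tokenizer_class"])          # always present per the tests above
print(info.get("model_max_length"))     # optional tokenizer_config.json field
print(bool(info.get("chat_template")))  # True if the tokenizer ships a chat template
```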

vllm/entrypoints/openai/api_server.py

Lines changed: 14 additions & 0 deletions
@@ -522,6 +522,19 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
     assert_never(generator)
 
 
+def maybe_register_tokenizer_info_endpoint(args):
+    """Conditionally register the tokenizer info endpoint if enabled."""
+    if getattr(args, 'enable_tokenizer_info_endpoint', False):
+
+        @router.get("/tokenizer_info")
+        async def get_tokenizer_info(raw_request: Request):
+            """Get comprehensive tokenizer information."""
+            result = await tokenization(raw_request).get_tokenizer_info()
+            return JSONResponse(content=result.model_dump(),
+                                status_code=result.code if isinstance(
+                                    result, ErrorResponse) else 200)
+
+
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
     handler = models(raw_request)
@@ -1692,6 +1705,7 @@ async def run_server_worker(listen_address,
         uvicorn_kwargs['log_config'] = log_config
 
     async with build_async_engine_client(args, client_config) as engine_client:
+        maybe_register_tokenizer_info_endpoint(args)
         app = build_app(args)
 
         vllm_config = await engine_client.get_vllm_config()
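The ordering here matters: the optional route is attached to the module-level router before `build_app(args)` includes it, because FastAPI only copies routes that already exist on a router at `include_router` time. A toy reproduction of that pattern under simplified assumptions (hypothetical names, no vLLM internals):

```python
from argparse import Namespace

from fastapi import APIRouter, FastAPI

router = APIRouter()

def maybe_register_demo_endpoint(args: Namespace) -> None:
    """Attach the optional route only when the flag is set."""
    if getattr(args, "enable_tokenizer_info_endpoint", False):

        @router.get("/tokenizer_info")
        async def tokenizer_info() -> dict:
            return {"tokenizer_class": "PreTrainedTokenizerFast"}

def build_app(args: Namespace) -> FastAPI:
    app = FastAPI()
    # Only routes already present on the router are copied into the app here.
    app.include_router(router)
    return app

args = Namespace(enable_tokenizer_info_endpoint=True)
maybe_register_demo_endpoint(args)  # must run before build_app, as in the diff
app = build_app(args)
print([route.path for route in app.routes if route.path == "/tokenizer_info"])
```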

vllm/entrypoints/openai/cli_args.py

Lines changed: 3 additions & 0 deletions
@@ -182,6 +182,9 @@ class FrontendArgs:
     """If set to True, enable tracking server_load_metrics in the app state."""
     enable_force_include_usage: bool = False
     """If set to True, including usage on every request."""
+    enable_tokenizer_info_endpoint: bool = False
+    """Enable the /get_tokenizer_info endpoint. May expose chat
+    templates and other tokenizer configuration."""
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

vllm/entrypoints/openai/protocol.py

Lines changed: 10 additions & 0 deletions
@@ -1953,6 +1953,16 @@ class DetokenizeResponse(OpenAIBaseModel):
     prompt: str
 
 
+class TokenizerInfoResponse(OpenAIBaseModel):
+    """
+    Response containing tokenizer configuration
+    equivalent to tokenizer_config.json
+    """
+
+    model_config = ConfigDict(extra="allow")
+    tokenizer_class: str
+
+
 class LoadLoRAAdapterRequest(BaseModel):
     lora_name: str
     lora_path: str
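Why `extra="allow"` on this model: the endpoint returns whatever the tokenizer config carries beyond the one required field, so the response model must not drop unknown keys. A standalone sketch of that behavior (illustrative class and values, not vLLM's actual `TokenizerInfoResponse`):

```python
from pydantic import BaseModel, ConfigDict

class TokenizerInfoSketch(BaseModel):
    """Illustrative stand-in for a response model that passes extras through."""
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str

info = TokenizerInfoSketch(tokenizer_class="LlamaTokenizerFast",
                           model_max_length=4096,
                           chat_template="{{ messages }}")
# With extra="allow", the unknown tokenizer_config.json fields survive the dump.
print(info.model_dump())
```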
