Commit cfc0597

[Fix] Fix ignore_eos support (#2414)
The ignore_eos support was broken during recent refactors. This PR fixes it by moving ignore_eos into the new debug_config request field instead of accepting it as a top-level request parameter.
1 parent cd79b96 commit cfc0597
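
As the diffs below show, clients now request ignore_eos through the debug_config field rather than a top-level parameter, and the test fixtures start the server with --enable-debug, which appears to gate these debug options. A minimal sketch of the updated request, assuming a server already running locally on port 8000 and a placeholder model id:

    import requests

    # Placeholder endpoint and model id; adjust to your deployment.
    url = "http://127.0.0.1:8000/v1/completions"
    payload = {
        "model": "local-model",
        "prompt": "What is the meaning of life?",
        "max_tokens": 32,
        "stream": False,
        # ignore_eos now lives under debug_config instead of the payload top level.
        "debug_config": {"ignore_eos": True},
    }
    response = requests.post(url, json=payload, timeout=180)
    print(response.json())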

9 files changed: 29 additions, 16 deletions

cpp/serve/engine.cc

Lines changed: 1 addition & 1 deletion

@@ -599,7 +599,7 @@ class EngineModule : public ModuleNode {
   void AddRequest(Request request) { return GetEngine()->AddRequest(std::move(request)); }
   /*! \brief Redirection to `Engine::AbortRequest`. */
   void Abort(const String& request_id) { return GetEngine()->AbortRequest(request_id); }
-
+  /*! \brief Create request with given arguments and the engine default generation config. */
   Request CreateRequest(String id, Array<Data> inputs, String generation_cfg_json_str) {
     auto gen_config =
         GenerationConfig::FromJSON(std::move(generation_cfg_json_str), default_generation_config_);

python/mlc_llm/protocol/openai_api_protocol.py

Lines changed: 5 additions & 1 deletion

@@ -391,6 +391,9 @@ def openai_api_get_generation_config(
 ) -> Dict[str, Any]:
     """Create the generation config from the given request."""
     from ..serve.config import ResponseFormat  # pylint: disable=import-outside-toplevel
+    from ..serve.config import (  # pylint: disable=import-outside-toplevel,redefined-outer-name
+        DebugConfig,
+    )

     kwargs: Dict[str, Any] = {}
     arg_names = [
@@ -404,7 +407,6 @@ def openai_api_get_generation_config(
         "top_logprobs",
         "logit_bias",
         "seed",
-        "debug_config",
     ]
     for arg_name in arg_names:
         kwargs[arg_name] = getattr(request, arg_name)
@@ -418,4 +420,6 @@ def openai_api_get_generation_config(
         kwargs["response_format"] = ResponseFormat(
             **request.response_format.model_dump(by_alias=True)
         )
+    if request.debug_config is not None:
+        kwargs["debug_config"] = DebugConfig(**request.debug_config.model_dump())
     return kwargs

python/mlc_llm/serve/config.py

Lines changed: 4 additions & 0 deletions

@@ -33,13 +33,17 @@ def __post_init__(self):
 class DebugConfig:
     """The debug configuration dataclass.Parameters
     ----------
+    ignore_eos : bool
+        When it is true, ignore the eos token and generate tokens until `max_tokens`.
+        Default is set to False.

     pinned_system_prompt : bool
         Whether the input and generated data pinned in engine. Default is set to False.
         This can be used for system prompt or other purpose, if the data is aimed to be
         kept all the time.
     """

+    ignore_eos: bool = False
     pinned_system_prompt: bool = False

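
For programmatic use, the new field can also be set when constructing the config directly. A minimal sketch, assuming DebugConfig behaves as the plain dataclass its docstring describes:

    from mlc_llm.serve.config import DebugConfig

    # ignore_eos defaults to False; enabling it keeps decoding until max_tokens is hit.
    debug_config = DebugConfig(ignore_eos=True)
    print(debug_config.ignore_eos)            # True
    print(debug_config.pinned_system_prompt)  # False, the unchanged default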

python/mlc_llm/serve/server/popen_server.py

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,7 @@ def __init__(  # pylint: disable=too-many-arguments
         mode: Literal["local", "interactive", "server"] = "local",
         engine_config: Optional[EngineConfig] = None,
         enable_tracing: bool = False,
+        enable_debug: bool = False,
         host: str = "127.0.0.1",
         port: int = 8000,
     ) -> None:
@@ -43,6 +44,7 @@ def __init__(  # pylint: disable=too-many-arguments
         self.mode = mode
         self.engine_config = engine_config
         self.enable_tracing = enable_tracing
+        self.enable_debug = enable_debug
         self.host = host
         self.port = port
         self._proc: Optional[subprocess.Popen] = None
@@ -96,6 +98,8 @@ def start(self) -> None:  # pylint: disable=too-many-branches,too-many-statements

         if self.enable_tracing:
             cmd += ["--enable-tracing"]
+        if self.enable_debug:
+            cmd += ["--enable-debug"]

         cmd += ["--host", self.host]
         cmd += ["--port", str(self.port)]
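
With the flag plumbed through, test servers can opt into the debug endpoints. A sketch mirroring the conftest change below, with placeholder model artifacts:

    from mlc_llm.serve import PopenServer

    # Placeholder model id and library path; point these at real compiled artifacts.
    server = PopenServer(
        model="HF://mlc-ai/some-model-MLC",
        model_lib="path/to/model-lib.so",
        enable_tracing=True,
        enable_debug=True,  # adds --enable-debug to the launched server command
    )
    with server:
        pass  # issue requests against the server while it is running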

python/mlc_llm/serve/sync_engine.py

Lines changed: 1 addition & 1 deletion

@@ -307,7 +307,7 @@ def create_request(
         """
         if not isinstance(inputs, list):
             inputs = [inputs]
-        self._ffi["create_request"](request_id, inputs, generation_config.asjson())
+        return self._ffi["create_request"](request_id, inputs, generation_config.asjson())

     def add_request(self, request: Request) -> None:
         """Add a new request to the engine.

tests/python/serve/server/conftest.py

Lines changed: 2 additions & 1 deletion

@@ -4,7 +4,7 @@

 import pytest

-from mlc_llm.serve import EngineConfig, PopenServer
+from mlc_llm.serve import PopenServer


 @pytest.fixture(scope="session")
@@ -27,6 +27,7 @@ def launch_server(served_model):  # pylint: disable=redefined-outer-name
         model=served_model[0],
         model_lib=served_model[1],
         enable_tracing=True,
+        enable_debug=True,
     )

     with server:

tests/python/serve/server/test_server.py

Lines changed: 8 additions & 8 deletions

@@ -256,7 +256,7 @@ def test_openai_v1_completions(
         "prompt": prompt,
         "max_tokens": max_tokens,
         "stream": stream,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -347,7 +347,7 @@ def test_openai_v1_completions_echo(
         "max_tokens": max_tokens,
         "echo": True,
         "stream": stream,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -398,7 +398,7 @@ def test_openai_v1_completions_suffix(
         "max_tokens": max_tokens,
         "suffix": suffix,
         "stream": stream,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -498,7 +498,7 @@ def test_openai_v1_completions_temperature(
         "max_tokens": max_tokens,
         "stream": stream,
         "temperature": 0.0,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -652,7 +652,7 @@ def test_openai_v1_completions_logit_bias(
         "max_tokens": max_tokens,
         "stream": stream,
         "logit_bias": {338: -100},  # 338 is " is" in Llama tokenizer.
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -699,7 +699,7 @@ def test_openai_v1_completions_presence_frequency_penalty(
         "stream": stream,
         "frequency_penalty": 2.0,
         "presence_penalty": 2.0,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -743,7 +743,7 @@ def test_openai_v1_completions_seed(
         "max_tokens": max_tokens,
         "stream": False,
         "seed": 233,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)
@@ -1207,7 +1207,7 @@ def test_openai_v1_chat_completions_ignore_eos(
         "messages": messages,
         "stream": stream,
         "max_tokens": max_tokens,
-        "ignore_eos": True,
+        "debug_config": {"ignore_eos": True},
     }

     response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)
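
The chat endpoint takes the same nested field; with eos ignored, generation is expected to run until max_tokens. A small sketch against the chat completions URL, with placeholder URL and model id:

    import requests

    payload = {
        "model": "local-model",  # placeholder model id
        "messages": [{"role": "user", "content": "Write a short poem."}],
        "stream": False,
        "max_tokens": 64,
        "debug_config": {"ignore_eos": True},
    }
    response = requests.post(
        "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=180
    )
    # With eos ignored, completion_tokens is expected to reach max_tokens.
    print(response.json()["usage"])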

tests/python/serve/test_serve_async_engine.py

Lines changed: 2 additions & 2 deletions

@@ -195,9 +195,9 @@ async def generate_task(prompt: str, request_id: str):
         model=model,
         max_tokens=max_tokens,
         n=n,
-        ignore_eos=True,
         request_id=request_id,
         stream=True,
+        debug_config={"ignore_eos": True},
     ):
         for choice in response.choices:
             output_texts[rid][choice.index] += choice.text
@@ -245,8 +245,8 @@ async def generate_task(prompt: str, request_id: str):
         model=model,
         max_tokens=max_tokens,
         n=n,
-        ignore_eos=True,
         request_id=request_id,
+        debug_config={"ignore_eos": True},
     )
     for choice in response.choices:
         output_texts[rid][choice.index] += choice.text

tests/python/serve/test_serve_engine.py

Lines changed: 2 additions & 2 deletions

@@ -175,9 +175,9 @@ def test_completion(model: str, model_lib: str):
         model=model,
         max_tokens=max_tokens,
         n=n,
-        ignore_eos=True,
         request_id=str(rid),
         stream=True,
+        debug_config={"ignore_eos": True},
     ):
         for choice in response.choices:
             output_texts[rid][choice.index] += choice.text
@@ -212,8 +212,8 @@ def test_completion_non_stream(model: str, model_lib: str):
         model=model,
         max_tokens=max_tokens,
         n=n,
-        ignore_eos=True,
         request_id=str(rid),
+        debug_config={"ignore_eos": True},
     )
     for choice in response.choices:
         output_texts[rid][choice.index] += choice.text
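
The in-process engine tests pass the same keyword. A rough sketch using the OpenAI-style engine API, assuming debug_config is accepted there the same way as in these tests; the model reference is a placeholder:

    from mlc_llm import MLCEngine

    engine = MLCEngine("HF://mlc-ai/some-model-MLC")  # placeholder model reference
    response = engine.chat.completions.create(
        messages=[{"role": "user", "content": "Write a short poem."}],
        max_tokens=64,
        stream=False,
        debug_config={"ignore_eos": True},  # assumed keyword, mirroring the tests above
    )
    print(response.choices[0].message.content)
    engine.terminate()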
