
Commit dacc46f

[CI] Add validation for MTP and CUDAGraph (#2710)
* set git identity to avoid merge failure in CI
* add ci cases
* [CI] Add validation for MTP and CUDAGraph
1 parent 09ded77 commit dacc46f

4 files changed, +325 -22 lines


test/ci_use/EB_Lite/test_EB_Lite_serving.py

Lines changed: 2 additions & 9 deletions
@@ -90,21 +90,14 @@ def setup_and_run_server():
         "--max-model-len", "32768",
         "--max-num-seqs", "128",
         "--quantization", "wint4",
+        "--use-cudagraph",
+        "--max-capture-batch-size", "1"
     ]
 
-    # Set environment variables
-    env = os.environ.copy()
-    env["ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY"] = "0"
-    env["FLAGS_use_append_attn"] = "1"
-    env["ELLM_DYNAMIC_MODE"] = "1"
-    env["NCCL_ALGO"] = "Ring"
-    env["USE_WORKER_V1"] = "1"
-
     # Start subprocess in new process group
     with open(log_path, "w") as logfile:
         process = subprocess.Popen(
             cmd,
-            env=env,
             stdout=logfile,
             stderr=subprocess.STDOUT,
             start_new_session=True  # Enables killing full group via os.killpg
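One consequence of this change: the hand-built `env` dict (and the `env=env` argument) is gone, so the server subprocess now simply inherits the CI job's environment. As a minimal sketch of the `subprocess` behaviour this relies on, using a hypothetical variable name purely for illustration:

import os
import subprocess
import sys

os.environ["EXAMPLE_FLAG"] = "1"  # hypothetical variable, for illustration only
probe = [sys.executable, "-c", "import os; print(os.environ.get('EXAMPLE_FLAG'))"]

# No env= argument: the child inherits the parent's environment -> prints "1"
print(subprocess.run(probe, capture_output=True, text=True).stdout.strip())

# Explicit env= argument: only the listed variables reach the child -> prints "None"
print(subprocess.run(probe, env={"PATH": os.environ.get("PATH", "")},
                     capture_output=True, text=True).stdout.strip())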
Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import requests
+import time
+import json
+import subprocess
+import socket
+import os
+import signal
+import sys
+import openai
+
+# Read ports from environment variables; use default values if not set
+FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
+FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
+FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
+
+# List of ports to clean before and after tests
+PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT]
+
+def is_port_open(host: str, port: int, timeout=1.0):
+    """
+    Check if a TCP port is open on the given host.
+    Returns True if connection succeeds, False otherwise.
+    """
+    try:
+        with socket.create_connection((host, port), timeout):
+            return True
+    except Exception:
+        return False
+
+def kill_process_on_port(port: int):
+    """
+    Kill processes that are listening on the given port.
+    Uses `lsof` to find process ids and sends SIGKILL.
+    """
+    try:
+        output = subprocess.check_output("lsof -i:{} -t".format(port), shell=True).decode().strip()
+        for pid in output.splitlines():
+            os.kill(int(pid), signal.SIGKILL)
+            print("Killed process on port {}, pid={}".format(port, pid))
+    except subprocess.CalledProcessError:
+        pass
+
+def clean_ports():
+    """
+    Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
+    """
+    for port in PORTS_TO_CLEAN:
+        kill_process_on_port(port)
+
+@pytest.fixture(scope="session", autouse=True)
+def setup_and_run_server():
+    """
+    Pytest fixture that runs once per test session:
+    - Cleans ports before tests
+    - Starts the API server as a subprocess
+    - Waits for server port to open (up to 300 seconds)
+    - Tears down server after all tests finish
+    """
+    print("Pre-test port cleanup...")
+    clean_ports()
+
+    base_path = os.getenv("MODEL_PATH")
+    if base_path:
+        model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle")
+    else:
+        model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
+
+    mtp_model_path = os.path.join(model_path, "mtp")
+    mtp_mode_str = json.dumps({
+        "method": "mtp",
+        "num_speculative_tokens": 1,
+        "model": mtp_model_path
+    })
+
+    log_path = "server.log"
+    cmd = [
+        sys.executable, "-m", "fastdeploy.entrypoints.openai.api_server",
+        "--model", model_path,
+        "--port", str(FD_API_PORT),
+        "--tensor-parallel-size", "1",
+        "--engine-worker-queue-port", str(FD_ENGINE_QUEUE_PORT),
+        "--metrics-port", str(FD_METRICS_PORT),
+        "--max-model-len", "32768",
+        "--max-num-seqs", "128",
+        "--quantization", "wint4",
+        "--speculative-config", mtp_mode_str
+    ]
+
+    # Start subprocess in new process group
+    with open(log_path, "w") as logfile:
+        process = subprocess.Popen(
+            cmd,
+            stdout=logfile,
+            stderr=subprocess.STDOUT,
+            start_new_session=True  # Enables killing full group via os.killpg
+        )
+
+    # Wait up to 300 seconds for API server to be ready
+    for _ in range(300):
+        if is_port_open("127.0.0.1", FD_API_PORT):
+            print("API server is up on port {}".format(FD_API_PORT))
+            break
+        time.sleep(1)
+    else:
+        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
+        try:
+            os.killpg(process.pid, signal.SIGTERM)
+        except Exception as e:
+            print("Failed to kill process group: {}".format(e))
+        raise RuntimeError("API server did not start on port {}".format(FD_API_PORT))
+
+    yield  # Run tests
+
+    print("\n===== Post-test server cleanup... =====")
+    try:
+        os.killpg(process.pid, signal.SIGTERM)
+        print("API server (pid={}) terminated".format(process.pid))
+    except Exception as e:
+        print("Failed to terminate API server: {}".format(e))
+
+
+@pytest.fixture(scope="session")
+def api_url(request):
+    """
+    Returns the API endpoint URL for chat completions.
+    """
+    return "http://0.0.0.0:{}/v1/chat/completions".format(FD_API_PORT)
+
+
+@pytest.fixture(scope="session")
+def metrics_url(request):
+    """
+    Returns the metrics endpoint URL.
+    """
+    return "http://0.0.0.0:{}/metrics".format(FD_METRICS_PORT)
+
+
+@pytest.fixture
+def headers():
+    """
+    Returns common HTTP request headers.
+    """
+    return {"Content-Type": "application/json"}
+
+
+@pytest.fixture
+def consistent_payload():
+    """
+    Returns a fixed payload for consistency testing,
+    including a fixed random seed and temperature.
+    """
+    return {
+        "messages": [{"role": "user", "content": "用一句话介绍 PaddlePaddle"}],
+        "temperature": 0.9,
+        "top_p": 0,  # fix top_p to reduce randomness
+        "seed": 13  # fixed random seed
+    }
+
+# ==========================
+# Helper function to calculate difference rate between two texts
+# ==========================
+def calculate_diff_rate(text1, text2):
+    """
+    Calculate the difference rate between two strings
+    based on the normalized Levenshtein edit distance.
+    Returns a float in [0,1], where 0 means identical.
+    """
+    if text1 == text2:
+        return 0.0
+
+    len1, len2 = len(text1), len(text2)
+    dp = [[0] * (len2 + 1) for _ in range(len1 + 1)]
+
+    for i in range(len1 + 1):
+        for j in range(len2 + 1):
+            if i == 0 or j == 0:
+                dp[i][j] = i + j
+            elif text1[i - 1] == text2[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
+
+    edit_distance = dp[len1][len2]
+    max_len = max(len1, len2)
+    return edit_distance / max_len if max_len > 0 else 0.0
+
+# ==========================
+# Consistency test for repeated runs with fixed payload
+# ==========================
+def test_consistency_between_runs(api_url, headers, consistent_payload):
+    """
+    Test that two runs with the same fixed input produce similar outputs.
+    """
+    # First request
+    resp1 = requests.post(api_url, headers=headers, json=consistent_payload)
+    assert resp1.status_code == 200
+    result1 = resp1.json()
+    content1 = result1["choices"][0]["message"]["content"]
+
+    # Second request
+    resp2 = requests.post(api_url, headers=headers, json=consistent_payload)
+    assert resp2.status_code == 200
+    result2 = resp2.json()
+    content2 = result2["choices"][0]["message"]["content"]
+
+    # Calculate difference rate
+    diff_rate = calculate_diff_rate(content1, content2)
+
+    # Verify that the difference rate is below the threshold
+    assert diff_rate < 0.05, "Output difference too large ({:.4%})".format(diff_rate)
+
+# ==========================
+# OpenAI Client chat.completions Test
+# ==========================
+
+@pytest.fixture
+def openai_client():
+    ip = "0.0.0.0"
+    service_http_port = str(FD_API_PORT)
+    client = openai.Client(
+        base_url="http://{}:{}/v1".format(ip, service_http_port),
+        api_key="EMPTY_API_KEY"
+    )
+    return client
+
+# Non-streaming test
+def test_non_streaming_chat(openai_client):
+    """
+    Test non-streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[
+            {"role": "system", "content": "You are a helpful AI assistant."},
+            {"role": "user", "content": "List 3 countries and their capitals."},
+        ],
+        temperature=1,
+        max_tokens=1024,
+        stream=False,
+    )
+
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert hasattr(response.choices[0], 'message')
+    assert hasattr(response.choices[0].message, 'content')
+
+# Streaming test
+def test_streaming_chat(openai_client, capsys):
+    """
+    Test streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[
+            {"role": "system", "content": "You are a helpful AI assistant."},
+            {"role": "user", "content": "List 3 countries and their capitals."},
+            {"role": "assistant", "content": "China(Beijing), France(Paris), Australia(Canberra)."},
+            {"role": "user", "content": "OK, tell more."},
+        ],
+        temperature=1,
+        max_tokens=1024,
+        stream=True,
+    )
+
+    output = []
+    for chunk in response:
+        if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
+            output.append(chunk.choices[0].delta.content)
+    assert len(output) > 2
+
+# ==========================
+# OpenAI Client completions Test
+# ==========================
+
+def test_non_streaming(openai_client):
+    """
+    Test non-streaming completions functionality with the local service
+    """
+    response = openai_client.completions.create(
+        model="default",
+        prompt="Hello, how are you?",
+        temperature=1,
+        max_tokens=1024,
+        stream=False,
+    )
+
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+
+
+def test_streaming(openai_client, capsys):
+    """
+    Test streaming functionality with the local service
+    """
+    response = openai_client.completions.create(
+        model="default",
+        prompt="Hello, how are you?",
+        temperature=1,
+        max_tokens=1024,
+        stream=True,
+    )
+
+    # Collect streaming output
+    output = []
+    for chunk in response:
+        output.append(chunk.choices[0].text)
+    assert len(output) > 0
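Two details of the new MTP test may be worth illustrating. First, the draft model is enabled purely through `--speculative-config`, which carries a JSON document built with `json.dumps`; a minimal sketch of the round trip (the model path below is illustrative):

import json

# Built by the fixture and passed on the CLI as: --speculative-config '<json>'
mtp_config = json.dumps({
    "method": "mtp",               # multi-token-prediction draft model
    "num_speculative_tokens": 1,   # one draft token per decode step
    "model": "./ernie-4_5-21b-a3b-bf16-paddle/mtp",  # illustrative path
})
assert json.loads(mtp_config)["num_speculative_tokens"] == 1  # round-trips cleanly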

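Second, the consistency check does not require byte-identical output: `calculate_diff_rate` returns a normalized Levenshtein distance, and the test only asserts it stays below 5%. A short usage sketch, assuming the helper defined above is in scope:

assert calculate_diff_rate("hello", "hello") == 0.0            # identical -> early return
assert abs(calculate_diff_rate("abcd", "abcf") - 0.25) < 1e-9  # 1 edit over 4 chars
# test_consistency_between_runs passes only while the rate stays below 0.05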
test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py

Lines changed: 0 additions & 6 deletions
@@ -101,16 +101,10 @@ def setup_and_run_server():
         "--quantization", "wint4"
     ]
 
-    # Set environment variables
-    env = os.environ.copy()
-    env["ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY"] = "0"
-    env["NCCL_ALGO"] = "Ring"
-
     # Start subprocess in new process group
     with open(log_path, "w") as logfile:
         process = subprocess.Popen(
             cmd,
-            env=env,
             stdout=logfile,
             stderr=subprocess.STDOUT,
             start_new_session=True  # Enables killing full group via os.killpg

test/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py

Lines changed: 0 additions & 7 deletions
@@ -92,17 +92,10 @@ def setup_and_run_server():
         "--quantization", "wint4"
     ]
 
-    # Set environment variables
-    env = os.environ.copy()
-    env["ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY"] = "0"
-    env["NCCL_ALGO"] = "Ring"
-    env["FLAG_SAMPLING_CLASS"] = "rejection"
-
     # Start subprocess in new process group
     with open(log_path, "w") as logfile:
        process = subprocess.Popen(
             cmd,
-            env=env,
             stdout=logfile,
             stderr=subprocess.STDOUT,
             start_new_session=True  # Enables killing full group via os.killpg
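All three launch paths keep `start_new_session=True`, which is what makes the `os.killpg` teardown in these fixtures work: the child becomes the leader of a new process group whose id equals its pid, so the API server and any workers it spawns can be signalled together. A minimal sketch of the pattern (POSIX only):

import os
import signal
import subprocess

# start_new_session=True puts the child in its own session/process group,
# so the group id equals the child's pid.
proc = subprocess.Popen(["sleep", "60"], start_new_session=True)

os.killpg(proc.pid, signal.SIGTERM)  # terminate the whole group at once
proc.wait()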
