# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import subprocess
import traceback

import pytest

from fastdeploy import LLM, SamplingParams

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))

def format_chat_prompt(messages):
    """
    Format a multi-turn conversation into a prompt string suitable for chat models.
    Uses the Qwen2 style with <|im_start|> / <|im_end|> tokens.
    """
    prompt = ""
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role == "user":
            prompt += "<|im_start|>user\n{content}<|im_end|>\n".format(content=content)
        elif role == "assistant":
            prompt += "<|im_start|>assistant\n{content}<|im_end|>\n".format(content=content)
    prompt += "<|im_start|>assistant\n"
    return prompt
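
# Illustrative example for format_chat_prompt above (not used by the tests below): for
#     messages = [{"role": "user", "content": "你好"}]
# the function returns
#     "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n"
# i.e. the prompt ends with an opened assistant turn for the model to complete.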


@pytest.fixture(scope="module")
def model_path():
    """
    Get the model path from the MODEL_PATH environment variable,
    defaulting to "./Qwen2-7B-Instruct" when it is not set.
    """
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        return os.path.join(base_path, "Qwen2-7B-Instruct")
    else:
        return "./Qwen2-7B-Instruct"


@pytest.fixture(scope="module")
def llm(model_path):
    """
    Fixture to initialize the LLM model from the given model path.
    """
    # Free the engine queue port if a previous run left a process listening on it.
    try:
        output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
    except subprocess.CalledProcessError:
        pass

    try:
        llm = LLM(
            model=model_path,
            tensor_parallel_size=1,
            engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
            max_model_len=4096,
        )
        print("Model loaded successfully from {}.".format(model_path))
        yield llm
    except Exception:
        print("Failed to load model from {}.".format(model_path))
        traceback.print_exc()
        pytest.fail("Failed to initialize LLM model from {}".format(model_path))


def test_generate_prompts(llm):
    """
    Test basic prompt generation.
    """
    # Only one prompt is enabled for testing currently.
    prompts = [
        "请介绍一下中国的四大发明。",
        # "太阳和地球之间的距离是多少?",
        # "写一首关于春天的古风诗。",
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )

    try:
        outputs = llm.generate(prompts, sampling_params)

        # Verify basic properties of the outputs.
        assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"

        for i, output in enumerate(outputs):
            assert output.prompt == prompts[i], "Prompt mismatch for case {}".format(i + 1)
            assert isinstance(output.outputs.text, str), "Output text should be string for case {}".format(i + 1)
            assert len(output.outputs.text) > 0, "Generated text should not be empty for case {}".format(i + 1)
            assert isinstance(output.finished, bool), "'finished' should be boolean for case {}".format(i + 1)
            assert output.metrics.model_execute_time > 0, "Execution time should be positive for case {}".format(i + 1)

            print("=== Prompt generation Case {} Passed ===".format(i + 1))

    except Exception:
        print("Failed during prompt generation.")
        traceback.print_exc()
        pytest.fail("Prompt generation test failed")


def test_chat_completion(llm):
    """
    Test chat completion with multiple turns.
    """
    chat_cases = [
        [
            {"role": "user", "content": "你好,请介绍一下你自己。"},
        ],
        [
            {"role": "user", "content": "你知道地球到月球的距离是多少吗?"},
            {"role": "assistant", "content": "大约是38万公里左右。"},
            {"role": "user", "content": "那太阳到地球的距离是多少?"},
        ],
        [
            {"role": "user", "content": "请给我起一个中文名。"},
            {"role": "assistant", "content": "好的,你可以叫“星辰”。"},
            {"role": "user", "content": "再起一个。"},
            {"role": "assistant", "content": "那就叫“大海”吧。"},
            {"role": "user", "content": "再来三个。"},
        ],
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )

    for i, case in enumerate(chat_cases):
        prompt = format_chat_prompt(case)
        try:
            outputs = llm.generate(prompt, sampling_params)

            # Verify chat completion properties.
            assert len(outputs) == 1, "Should return one output per prompt"
            assert isinstance(outputs[0].outputs.text, str), "Output text should be string"
            assert len(outputs[0].outputs.text) > 0, "Generated text should not be empty"
            assert outputs[0].metrics.model_execute_time > 0, "Execution time should be positive"

            print("=== Chat Case {} Passed ===".format(i + 1))

        except Exception:
            print("[ERROR] Chat Case {} failed.".format(i + 1))
            traceback.print_exc()
            pytest.fail("Chat case {} failed".format(i + 1))


if __name__ == "__main__":
    # Main entry point for the test script.
    pytest.main(["-sv", __file__])