@@ -13,23 +13,24 @@
 
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
+MODEL_NAME = "internlm/internlm2-1_8b-reward"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
 
 
 @pytest.fixture(scope="module")
 def server():
     args = [
         "--task",
-        "classify",
+        "reward",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--enforce-eager",
         "--max-model-len",
-        "8192",
+        "512",
         "--chat-template",
         DUMMY_CHAT_TEMPLATE,
+        "--trust-remote-code",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
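Note on the expected-shape changes in the hunks below: a model served with `--task classify` returns one fixed-length probability vector per prompt (length 2 for the old model), while a model served with `--task reward` returns one scalar per prompt token, so `len(poolings.data[0].data)` now tracks the token count instead of the class count. A minimal sketch of that behavior, not part of this diff, assuming a server started with the fixture's arguments is reachable at a placeholder URL (in the tests, `RemoteOpenAIServer` picks a free port):

```python
# Sketch only -- BASE_URL is a hypothetical placeholder for the test server.
import requests

BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/pooling",
    json={
        "model": "internlm/internlm2-1_8b-reward",
        "input": ["vLLM is great!"],
        "encoding_format": "float",
    },
)
resp.raise_for_status()
token_rewards = resp.json()["data"][0]["data"]
# One entry per prompt token (each a single-element list), not one per class:
print(len(token_rewards), token_rewards[:2])
```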
@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 8
     assert poolings.usage.completion_tokens == 0
-    assert poolings.usage.prompt_tokens == 7
-    assert poolings.usage.total_tokens == 7
+    assert poolings.usage.prompt_tokens == 8
+    assert poolings.usage.total_tokens == 8
 
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 5
     assert poolings.usage.completion_tokens == 0
     assert poolings.usage.prompt_tokens == 5
     assert poolings.usage.total_tokens == 5
@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 3
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 8
     assert poolings.usage.completion_tokens == 0
-    assert poolings.usage.prompt_tokens == 25
-    assert poolings.usage.total_tokens == 25
+    assert poolings.usage.prompt_tokens == 29
+    assert poolings.usage.total_tokens == 29
 
     # test list[list[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
 
     assert poolings.id is not None
     assert len(poolings.data) == 4
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 5
     assert poolings.usage.completion_tokens == 0
     assert poolings.usage.prompt_tokens == 17
     assert poolings.usage.total_tokens == 17
@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
     chat_response.raise_for_status()
     chat_poolings = PoolingResponse.model_validate(chat_response.json())
 
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(
+        tokenizer_name=model_name,
+        tokenizer_mode="fast",
+        trust_remote_code=True,
+    )
     prompt = tokenizer.apply_chat_template(
         messages,
         chat_template=DUMMY_CHAT_TEMPLATE,
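The `trust_remote_code=True` argument above mirrors the new `--trust-remote-code` server flag: the internlm2 repositories ship a custom tokenizer implementation on the Hub, so loading them without the flag fails. A standalone sketch (an illustration, not from the diff) using `transformers` directly:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "internlm/internlm2-1_8b-reward",
    trust_remote_code=True,  # the repo ships custom tokenizer code
)
print(tokenizer("The cat sat on the mat.").input_ids)
```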
@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     )
     float_response.raise_for_status()
     responses_float = PoolingResponse.model_validate(float_response.json())
+    float_data = [
+        np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
+    ]
 
     base64_response = requests.post(
         server.url_for("pooling"),
@@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
             np.frombuffer(base64.b64decode(data.data),
                           dtype="float32").tolist())
 
-    check_embeddings_close(
-        embeddings_0_lst=[d.data for d in responses_float.data],
-        embeddings_1_lst=decoded_responses_base64_data,
-        name_0="float32",
-        name_1="base64")
+    check_embeddings_close(embeddings_0_lst=float_data,
+                           embeddings_1_lst=decoded_responses_base64_data,
+                           name_0="float32",
+                           name_1="base64")
 
     # Default response is float32 decoded from base64 by OpenAI Client
     default_response = requests.post(
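For reference, a self-contained sketch of the round-trip this test relies on: with `encoding_format="float"` a reward model yields one single-element list per token, while the base64 payload is the raw float32 buffer and decodes to a flat vector, which is why the new `float_data` applies `squeeze(-1)` before comparing. The values here are made up for illustration:

```python
import base64

import numpy as np

per_token = [[0.1], [0.25], [-0.3]]  # "float" output shape for a reward model
flat = np.array(per_token, dtype="float32").squeeze(-1)  # what float_data builds

encoded = base64.b64encode(flat.tobytes()).decode()  # raw float32 buffer, base64'd
decoded = np.frombuffer(base64.b64decode(encoded), dtype="float32")

assert decoded.tolist() == flat.tolist()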
@@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     )
     default_response.raise_for_status()
     responses_default = PoolingResponse.model_validate(default_response.json())
+    default_data = [
+        np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
+    ]
+
+    check_embeddings_close(embeddings_0_lst=float_data,
+                           embeddings_1_lst=default_data,
+                           name_0="float32",
+                           name_1="default")
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "input": input_texts,
+        "encoding_format": "float",
+    }
+
+    completion_response = requests.post(server.url_for("pooling"),
+                                        json=request_args)
+    completion_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion_response.json()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["data"] == invocation_output["data"]
+
+
+@pytest.mark.asyncio
+async def test_invocations_conversation(server: RemoteOpenAIServer):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "encoding_format": "float",
+    }
+
+    chat_response = requests.post(server.url_for("pooling"), json=request_args)
+    chat_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_response.json()
+    invocation_output = invocation_response.json()
 
-    check_embeddings_close(
-        embeddings_0_lst=[d.data for d in responses_default.data],
-        embeddings_1_lst=[d.data for d in responses_default.data],
-        name_0="float32",
-        name_1="base64")
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["data"] == invocation_output["data"]
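The two new tests assert that `/invocations` accepts the same JSON bodies as the task-specific `/pooling` route and returns identical `data`, both for plain text input and for chat messages. An equivalent check outside pytest might look like the sketch below; `BASE_URL` is a placeholder for the fixture's server address:

```python
# Sketch only -- assumes the fixture's server is reachable at BASE_URL.
import requests

BASE_URL = "http://localhost:8000"
payload = {
    "model": "internlm/internlm2-1_8b-reward",
    "input": ["The chef prepared a delicious meal."],
    "encoding_format": "float",
}

pooling = requests.post(f"{BASE_URL}/pooling", json=payload).json()
invocations = requests.post(f"{BASE_URL}/invocations", json=payload).json()
assert pooling["data"] == invocations["data"]
```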