@@ -54,7 +54,7 @@ def test_get_field():
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
         ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
         ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
-        ("openai/whisper-small", "transcription", "transcription"),
+        ("openai/whisper-small", "generate", "transcription"),
     ],
 )
 def test_auto_task(model_id, expected_runner_type, expected_task):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
     )
 
     assert config.runner_type == expected_runner_type
-    assert config.task == expected_task
+
+    if config.runner_type == "pooling":
+        assert config.task == expected_task
+    else:
+        assert expected_task in config.supported_tasks
 
 
 @pytest.mark.parametrize(
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
     assert config.task == expected_task
 
 
+@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
+                         [
+                             ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
+                         ])
+def test_draft_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        runner="draft",
+        tokenizer=model_id,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
+@pytest.mark.parametrize(
+    ("model_id", "expected_runner_type", "expected_task"),
+    [
+        ("openai/whisper-small", "generate", "transcription"),
+    ],
+)
+def test_transcription_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="transcription",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
 @pytest.mark.parametrize(("model_id", "bad_task"), [
     ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
+    ("Qwen/Qwen3-0.6B", "transcription"),
 ])
 def test_incorrect_task(model_id, bad_task):
-    with pytest.raises(ValueError, match=r"does not support the .* task"):
+    with pytest.raises(ValueError, match=r"does not support task=.*"):
         ModelConfig(
             model_id,
             task=bad_task,
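For reference, a minimal standalone sketch of the behaviour the updated assertions encode, assuming `ModelConfig` is imported from `vllm.config` as in this test module: a generate-runner model such as `openai/whisper-small` now reports the expected task through `supported_tasks`, while pooling models still resolve to a single `task`. The constructor arguments below simply mirror the `test_transcription_task` case added in the diff.

```python
# Sketch only (not part of the diff): mirrors the updated assertion logic,
# assuming `from vllm.config import ModelConfig` as used by these tests.
from vllm.config import ModelConfig

config = ModelConfig(
    "openai/whisper-small",
    task="transcription",
    tokenizer="openai/whisper-small",
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="float16",
)

if config.runner_type == "pooling":
    # Pooling models resolve to exactly one task.
    assert config.task == "transcription"
else:
    # Generate-runner models (e.g. Whisper) expose a set of supported tasks.
    assert "transcription" in config.supported_tasks
```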