Skip to content

Commit 020f58a

Browse files
[Core] Support multiple tasks per model (#20771)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent c1acd6d commit 020f58a

File tree

8 files changed

+278
-147
lines changed

8 files changed

+278
-147
lines changed

tests/test_config.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_get_field():
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
         ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
         ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
-        ("openai/whisper-small", "transcription", "transcription"),
+        ("openai/whisper-small", "generate", "transcription"),
     ],
 )
 def test_auto_task(model_id, expected_runner_type, expected_task):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
     )

     assert config.runner_type == expected_runner_type
-    assert config.task == expected_task
+
+    if config.runner_type == "pooling":
+        assert config.task == expected_task
+    else:
+        assert expected_task in config.supported_tasks

@pytest.mark.parametrize(
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
     assert config.task == expected_task


+@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
+                         [
+                             ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
+                         ])
+def test_draft_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        runner="draft",
+        tokenizer=model_id,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
+@pytest.mark.parametrize(
+    ("model_id", "expected_runner_type", "expected_task"),
+    [
+        ("openai/whisper-small", "generate", "transcription"),
+    ],
+)
+def test_transcription_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="transcription",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
 @pytest.mark.parametrize(("model_id", "bad_task"), [
     ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
+    ("Qwen/Qwen3-0.6B", "transcription"),
 ])
 def test_incorrect_task(model_id, bad_task):
-    with pytest.raises(ValueError, match=r"does not support the .* task"):
+    with pytest.raises(ValueError, match=r"does not support task=.*"):
         ModelConfig(
             model_id,
             task=bad_task,

0 commit comments

Comments
 (0)