fix: QwQ and R1 finetune format #294
@@ -8,7 +8,15 @@
 from fastapi.testclient import TestClient
 from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter
 from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat
-from kiln_ai.adapters.ml_model_list import KilnModel, KilnModelProvider
+from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
+    KilnModelProvider,
+    ModelFamily,
+    ModelName,
+    ModelParserID,
+    ModelProviderName,
+    StructuredOutputMode,
+)
 from kiln_ai.datamodel import (
     DatasetSplit,
     Finetune,
@@ -32,6 +40,7 @@
     FinetuneProviderModel,
     connect_fine_tune_api,
     fetch_fireworks_finetune_models,
+    infer_data_strategies_for_model,
     thinking_instructions_from_request,
 )
 
@@ -110,6 +119,18 @@ def client():
     return TestClient(app)
 
 
+def test_finetune_provider_model_defaults():
+    model = FinetuneProviderModel(
+        name="Test Provider",
+        id="test_provider",
+    )
+
+    assert model.data_strategies_supported == [
+        FinetuneDataStrategy.final_only,
+        FinetuneDataStrategy.final_and_intermediate,
+    ]
+
+
 def test_get_dataset_splits(client, mock_task_from_id_disk_backed, test_task):
     response = client.get("/api/projects/project1/tasks/task1/dataset_splits")
 
@@ -1260,3 +1281,141 @@ async def test_fetch_fireworks_finetune_models_http_error(
         await fetch_fireworks_finetune_models()
 
     mock_httpx_client.get.assert_called_once()
+
+
+@pytest.fixture
+def mock_available_models():
+    return [
+        KilnModel(
+            family=ModelFamily.gpt,
+            name=ModelName.gpt_4_1,
+            friendly_name="GPT 4.1",
+            providers=[
+                KilnModelProvider(
+                    name=ModelProviderName.openai,
+                    model_id="gpt-4.1",
+                    provider_finetune_id="gpt-4.1-2025-04-14",
+                ),
+                KilnModelProvider(
+                    name=ModelProviderName.openrouter,
+                    model_id="openai/gpt-4.1",
+                ),
+                KilnModelProvider(
+                    name=ModelProviderName.azure_openai,
+                    model_id="gpt-4.1",
+                ),
+            ],
+        ),
+        KilnModel(
+            family=ModelFamily.gpt,
+            name=ModelName.gpt_4_1_mini,
+            friendly_name="GPT 4.1 Mini",
+            providers=[
+                KilnModelProvider(
+                    name=ModelProviderName.openai,
+                    model_id="gpt-4.1-mini",
+                    provider_finetune_id="gpt-4.1-mini-2025-04-14",
+                ),
+                KilnModelProvider(
+                    name=ModelProviderName.openrouter,
+                    model_id="openai/gpt-4.1-mini",
+                ),
+                KilnModelProvider(
+                    name=ModelProviderName.azure_openai,
+                    model_id="gpt-4.1-mini",
+                ),
+            ],
+        ),
+        KilnModel(
+            family=ModelFamily.qwen,
+            name=ModelName.qwq_32b,
+            friendly_name="QwQ 32B",
+            providers=[
+                KilnModelProvider(
+                    name=ModelProviderName.huggingface,
+                    model_id="qwen/qwq-32b",
+                    provider_finetune_id="qwq-32b-xxx",
+                    parser=ModelParserID.r1_thinking,
+                )
+            ],
+        ),
+    ]
+
+
+@pytest.mark.parametrize(
+    "model_id, provider, expected_data_strategies",
+    [
+        # for built-in models, we can infer the data strategies from the model object itself
+        (
+            # does not have a parser, so should be defaults
+            "gpt-4.1-2025-04-14",
+            "openai",
+            [
+                FinetuneDataStrategy.final_only,
+                FinetuneDataStrategy.final_and_intermediate,
+            ],
+        ),
+        (
+            # does not have a parser, so should be defaults
+            "gpt-4.1-mini-2025-04-14",
+            "openai",
+            [
+                FinetuneDataStrategy.final_only,
+                FinetuneDataStrategy.final_and_intermediate,
+            ],
+        ),
+        (
+            # this model is not in any list, so should be defaults
+            "fake-model-id",
+            "fake-provider",
+            [
+                FinetuneDataStrategy.final_only,
+                FinetuneDataStrategy.final_and_intermediate,
+            ],
+        ),
+        # this model has an R1 parser, so should be R1-compatible
+        (
+            "qwq-32b-xxx",
+            "huggingface",
+            [
+                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            ],
+        ),
+        # for fireworks_ai models, we infer the data strategies from the model name
+        (
+            # does not contain "r1" or "qwq" in the id, so should be defaults
+            "some-model-id",
+            "fireworks_ai",
+            [
+                FinetuneDataStrategy.final_only,
+                FinetuneDataStrategy.final_and_intermediate,
+            ],
+        ),
+        (
+            # contains "r1" in the id, so should be R1-compatible
+            "some-model-with-r1-in-id",
+            "fireworks_ai",
+            [
+                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            ],
+        ),
+        (
+            # contains "qwq" in the id, so should be R1-compatible
+            "some-model-with-qwq-in-id",
+            "fireworks_ai",
+            [
+                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            ],
+        ),
+    ],
+)
+def test_infer_data_strategies(
+    mock_available_models,
+    model_id: str,
+    provider: str,
+    expected_data_strategies: list[FinetuneDataStrategy],
+):
+    assert (
+        infer_data_strategies_for_model(mock_available_models, model_id, provider)
+        == expected_data_strategies
+    )
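
For orientation, here is a minimal sketch of the behavior the parameterized test above encodes. This is an assumption reconstructed from the test expectations, not the PR's actual implementation; `DEFAULT_STRATEGIES` is a name introduced here for illustration only.

```python
# Hedged sketch reconstructed from the test expectations above; the PR's
# real infer_data_strategies_for_model may be structured differently.
from kiln_ai.adapters.ml_model_list import KilnModel
from kiln_ai.datamodel import FinetuneDataStrategy

DEFAULT_STRATEGIES = [  # illustrative constant, not from the PR
    FinetuneDataStrategy.final_only,
    FinetuneDataStrategy.final_and_intermediate,
]


def infer_data_strategies_for_model(
    models: list[KilnModel], finetune_id: str, provider: str
) -> list[FinetuneDataStrategy]:
    # Built-in models: a provider entry with an R1-style parser (like the
    # QwQ 32B fixture above) only supports R1-compatible training data.
    for model in models:
        for p in model.providers:
            # Assumes ModelProviderName is a str enum, so it compares
            # equal to the plain provider string used in the test.
            if p.name == provider and p.provider_finetune_id == finetune_id:
                if p.parser is not None:  # e.g. ModelParserID.r1_thinking
                    return [FinetuneDataStrategy.final_and_intermediate_r1_compatible]
                return DEFAULT_STRATEGIES
    # Fireworks IDs embed the model family, so keyword matching suffices.
    if provider == "fireworks_ai" and ("r1" in finetune_id or "qwq" in finetune_id):
        return [FinetuneDataStrategy.final_and_intermediate_r1_compatible]
    # Unknown models and providers fall back to the defaults.
    return DEFAULT_STRATEGIES
```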
I'd add a parametrized test with expected real values from the API, just on data_strategies_from_finetune_id so the parameters stay short and sweet. Some IDs where we expect R1:

Some where we expect false:

And some new Qwen 3 (TBD, see convo above):
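
A sketch of what such a parameterized test could look like. The IDs below are hypothetical Fireworks-style values standing in for the real ones (which are not shown above), and the import path is assumed from the test file's location:

```python
import pytest
from kiln_ai.datamodel import FinetuneDataStrategy

# Module path assumed, not confirmed by the excerpt above.
from app.desktop.studio_server.finetune_api import data_strategies_from_finetune_id


@pytest.mark.parametrize(
    "finetune_id, expected",
    [
        # Hypothetical IDs for illustration; substitute the real values
        # returned by the Fireworks API.
        (
            "accounts/fireworks/models/deepseek-r1",
            [FinetuneDataStrategy.final_and_intermediate_r1_compatible],
        ),
        (
            "accounts/fireworks/models/llama-v3p1-8b-instruct",
            [
                FinetuneDataStrategy.final_only,
                FinetuneDataStrategy.final_and_intermediate,
            ],
        ),
    ],
)
def test_data_strategies_from_finetune_id(finetune_id, expected):
    assert data_strategies_from_finetune_id(finetune_id) == expected
```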
I've added a test (including the qwen3 case) in commit 86c8512. Note: the validation targets the model's fine-tune ID. Since we only need this for Fireworks.ai currently - and their IDs always include the necessary keywords - validating on the ID is sufficient.
And Qwen 3 already requires a change 😀

We should add "qwen3", and qwen3 should return final_only, final_and_intermediate_r1_compatible (since it can do both thinking and non-thinking with the /think and /no_think directives).

Let's maybe split this into another PR?
I added a pattern match in data_strategies_from_finetune_id to target qwen3 and only allow ['final_and_intermediate_r1_compatible', 'final_only'], just to be able to test the allowed strategies in a parameterized test. I have not implemented any other logic for qwen3 beyond that.
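
A hedged sketch of the qwen3 branch described above; the names and the ordering of the checks are assumptions based on this thread, not the PR's final code:

```python
from kiln_ai.datamodel import FinetuneDataStrategy


def data_strategies_from_finetune_id(finetune_id: str) -> list[FinetuneDataStrategy]:
    # Sketch of the helper as discussed in this thread; the merged code
    # may order the checks or structure the matching differently.
    if "qwen3" in finetune_id:
        # Qwen 3 can train with or without thinking via the /think and
        # /no_think directives, so both strategies are allowed.
        return [
            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
            FinetuneDataStrategy.final_only,
        ]
    if "r1" in finetune_id or "qwq" in finetune_id:
        # R1-style and QwQ models require R1-compatible thinking data.
        return [FinetuneDataStrategy.final_and_intermediate_r1_compatible]
    return [
        FinetuneDataStrategy.final_only,
        FinetuneDataStrategy.final_and_intermediate,
    ]
```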