fix: QwQ and R1 finetune format #294


Merged on May 1, 2025 (20 commits)
Changes shown are from 11 of the 20 commits.

Commits (20):
a0334cc
wip: annotate todos
leonardmq Apr 22, 2025
db6cf83
wip: fix qwq/r1 data strategy
leonardmq Apr 24, 2025
b35706b
test: add cases and update existing tests
leonardmq Apr 25, 2025
096c6db
fix: default values on finetune_model_provider pydantic model
leonardmq Apr 26, 2025
e9b06a9
fix: expose all data strategies in UI dropdown when
leonardmq Apr 26, 2025
26ee24a
fix: dataset formatter duplicate final output message
leonardmq Apr 26, 2025
7c7c637
fix: validation of data strategy, and small refactor and validation t…
leonardmq Apr 27, 2025
7b14cc9
chore: remove obsolete todo
leonardmq Apr 27, 2025
2415c94
refactor: extract valid thinking data strategies into own constant
leonardmq Apr 27, 2025
be33978
fix: raise error in r1 serialization if none or empty thinking
leonardmq Apr 27, 2025
6a1f7d1
refactor: data formatter generators and fixes on COT vs R1
leonardmq Apr 27, 2025
86c8512
test: add tests for data_strategies_from_finetune_id and qwen3 match
leonardmq Apr 30, 2025
2848879
chore: replace error message strings
leonardmq Apr 30, 2025
b25ea6c
fix: formatter fix (newlines, and throw if R1 while vertex)
leonardmq Apr 30, 2025
fca9792
fix: UI to select strategy, clean labels, clean switch default option
leonardmq Apr 30, 2025
defa1ec
ui: error and block submit when no thinking filter dataset for R1 model
leonardmq Apr 30, 2025
1070910
fix: disable R1 data strategy for vertex download, toolcall downloads
leonardmq Apr 30, 2025
f45eb69
chore: remove lingering console log
leonardmq May 1, 2025
ea59f62
fix: string change for warning, and not block submit
leonardmq May 1, 2025
fb123c8
fix: no thinking instructions on r1 model and update validation
leonardmq May 1, 2025
103 changes: 87 additions & 16 deletions app/desktop/studio_server/finetune_api.py
@@ -8,27 +8,26 @@
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import (
KilnModel,
KilnModelProvider,
ModelParserID,
ModelProviderName,
built_in_models,
)
from kiln_ai.adapters.prompt_builders import (
chain_of_thought_prompt,
prompt_builder_from_id,
)
from kiln_ai.adapters.provider_tools import (
provider_enabled,
provider_name_from_id,
)
from kiln_ai.adapters.provider_tools import provider_enabled, provider_name_from_id
from kiln_ai.datamodel import (
DatasetSplit,
Finetune,
FinetuneDataStrategy,
FineTuneStatusType,
Task,
)
from kiln_ai.datamodel.dataset_filters import (
DatasetFilterId,
)
from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
from kiln_ai.datamodel.dataset_filters import DatasetFilterId
from kiln_ai.datamodel.dataset_split import (
AllSplitDefinition,
Train60Test20Val20SplitDefinition,
@@ -38,7 +37,7 @@
from kiln_ai.utils.config import Config
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.task_api import task_from_id
from pydantic import BaseModel
from pydantic import BaseModel, Field, model_validator

logger = logging.getLogger(__name__)

@@ -48,6 +47,12 @@ class FinetuneProviderModel(BaseModel):

name: str
id: str
data_strategies_supported: list[FinetuneDataStrategy] = Field(
default_factory=lambda: [
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
]
)


class FinetuneProvider(BaseModel):
@@ -101,6 +106,16 @@ class CreateFinetuneRequest(BaseModel):
custom_thinking_instructions: str | None = None
data_strategy: FinetuneDataStrategy

@model_validator(mode="after")
def validate_data_strategy(self) -> "CreateFinetuneRequest":
if self.data_strategy not in infer_data_strategies_for_model(
built_in_models, self.base_model_id, self.provider
):
raise ValueError(
f"The data strategy {self.data_strategy} is not supported for the provider model {self.base_model_id}"
)
return self


class FinetuneWithStatus(BaseModel):
"""Finetune with status"""
@@ -198,7 +213,8 @@ async def finetune_providers() -> list[FinetuneProvider]:
provider_models[provider.name] = []
provider_models[provider.name].append(
FinetuneProviderModel(
name=model.friendly_name, id=provider.provider_finetune_id
name=model.friendly_name,
id=provider.provider_finetune_id,
)
)

@@ -212,14 +228,19 @@ async def finetune_providers() -> list[FinetuneProvider]:
# Create provider entries
providers: list[FinetuneProvider] = []
for provider_name, models in provider_models.items():
providers.append(
FinetuneProvider(
name=provider_name_from_id(provider_name),
id=provider_name,
enabled=await provider_enabled(provider_name),
models=models,
# attach the compatible data strategies to each model
for model in models:
model.data_strategies_supported = infer_data_strategies_for_model(
built_in_models, model.id, provider_name
)

provider = FinetuneProvider(
name=provider_name_from_id(provider_name),
id=provider_name,
enabled=await provider_enabled(provider_name),
models=models,
)
providers.append(provider)

return providers

@@ -326,6 +347,7 @@ async def download_dataset_jsonl(
status_code=400,
detail=f"Data strategy '{data_strategy}' not found",
)

data_strategy_typed = FinetuneDataStrategy(data_strategy)

task = task_from_id(project_id, task_id)
@@ -406,7 +428,7 @@ def thinking_instructions_from_request(
data_strategy: FinetuneDataStrategy,
custom_thinking_instructions: str | None,
) -> str | None:
if data_strategy != FinetuneDataStrategy.final_and_intermediate:
if data_strategy not in THINKING_DATA_STRATEGIES:
# Not using COT/Thinking style
return None

@@ -477,3 +499,52 @@ async def fetch_fireworks_finetune_models() -> list[FinetuneProviderModel]:
)

return tuneable_models


DEFAULT_DATA_STRATEGIES = [
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
]


def data_strategies_from_model_provider(
provider: KilnModelProvider,
) -> list[FinetuneDataStrategy]:
if provider.parser == ModelParserID.r1_thinking:
return [
FinetuneDataStrategy.final_and_intermediate_r1_compatible,
]
return DEFAULT_DATA_STRATEGIES


def data_strategies_from_finetune_id(
provider_finetune_id: str,
) -> list[FinetuneDataStrategy]:
r1_must_include = ["r1", "qwq"]
Collaborator:

And Qwen 3 already requires a change 😀

We should add "qwen3", and qwen3 should return final_only, final_and_intermediate_r1_compatible (since it can do both thinking and non-thinking with the /think and /no_think directives).

Let's maybe split this into another PR?

Collaborator (Author):

I added a pattern match in data_strategies_from_finetune_id to target qwen3 and only allow ['final_and_intermediate_r1_compatible', 'final_only'], just to be able to test for the allowed strategies in a parameterized test. I have not implemented any other logic for qwen3 besides that. (A sketch of this kind of match appears after this file's diff.)

if any(substring in provider_finetune_id.lower() for substring in r1_must_include):
return [
FinetuneDataStrategy.final_and_intermediate_r1_compatible,
]
return DEFAULT_DATA_STRATEGIES


def infer_data_strategies_for_model(
available_models: list[KilnModel],
provider_finetune_id: str,
provider_name: str,
) -> list[FinetuneDataStrategy]:
# we don't have built-in models for fireworks models, so we infer the data strategy from the model name
if provider_name == ModelProviderName.fireworks_ai:
return data_strategies_from_finetune_id(provider_finetune_id)

# where we have built-in models, we can infer the data strategy from the object itself
for model in available_models:
for provider in model.providers:
if (
provider.name == provider_name
and provider.provider_finetune_id == provider_finetune_id
):
return data_strategies_from_model_provider(provider)

# for everything else, we don't know what the data strategy is, so we use the default
return DEFAULT_DATA_STRATEGIES
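
Following the review thread above about Qwen 3, here is a minimal sketch of the kind of substring match described for data_strategies_from_finetune_id. This is a hypothetical illustration, not the code merged in this PR; FinetuneDataStrategy and DEFAULT_DATA_STRATEGIES are the names already used in this file, and the qwen3 behaviour follows the author's description in the comment thread.

from kiln_ai.datamodel import FinetuneDataStrategy

# Sketch only: mirrors the behaviour described in the review thread above,
# not necessarily the exact code added in commit 86c8512.
DEFAULT_DATA_STRATEGIES = [
    FinetuneDataStrategy.final_only,
    FinetuneDataStrategy.final_and_intermediate,
]


def data_strategies_from_finetune_id(
    provider_finetune_id: str,
) -> list[FinetuneDataStrategy]:
    model_id = provider_finetune_id.lower()

    # Qwen 3 can produce either plain answers or R1-style thinking
    # (/think vs /no_think), so both strategies are allowed.
    if "qwen3" in model_id:
        return [
            FinetuneDataStrategy.final_only,
            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
        ]

    # QwQ and R1-family models always emit a thinking block, so only the
    # R1-compatible strategy applies.
    if any(substring in model_id for substring in ["r1", "qwq"]):
        return [FinetuneDataStrategy.final_and_intermediate_r1_compatible]

    return DEFAULT_DATA_STRATEGIES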
161 changes: 160 additions & 1 deletion app/desktop/studio_server/test_finetune_api.py
@@ -8,7 +8,15 @@
from fastapi.testclient import TestClient
from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat
from kiln_ai.adapters.ml_model_list import KilnModel, KilnModelProvider
from kiln_ai.adapters.ml_model_list import (
KilnModel,
KilnModelProvider,
ModelFamily,
ModelName,
ModelParserID,
ModelProviderName,
StructuredOutputMode,
)
from kiln_ai.datamodel import (
DatasetSplit,
Finetune,
@@ -32,6 +40,7 @@
FinetuneProviderModel,
connect_fine_tune_api,
fetch_fireworks_finetune_models,
infer_data_strategies_for_model,
thinking_instructions_from_request,
)

@@ -110,6 +119,18 @@ def client():
return TestClient(app)


def test_finetune_provider_model_defaults():
model = FinetuneProviderModel(
name="Test Provider",
id="test_provider",
)

assert model.data_strategies_supported == [
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
]


def test_get_dataset_splits(client, mock_task_from_id_disk_backed, test_task):
response = client.get("/api/projects/project1/tasks/task1/dataset_splits")

@@ -1260,3 +1281,141 @@ async def test_fetch_fireworks_finetune_models_http_error(
await fetch_fireworks_finetune_models()

mock_httpx_client.get.assert_called_once()


@pytest.fixture
def mock_available_models():
return [
KilnModel(
family=ModelFamily.gpt,
name=ModelName.gpt_4_1,
friendly_name="GPT 4.1",
providers=[
KilnModelProvider(
name=ModelProviderName.openai,
model_id="gpt-4.1",
provider_finetune_id="gpt-4.1-2025-04-14",
),
KilnModelProvider(
name=ModelProviderName.openrouter,
model_id="openai/gpt-4.1",
),
KilnModelProvider(
name=ModelProviderName.azure_openai,
model_id="gpt-4.1",
),
],
),
KilnModel(
family=ModelFamily.gpt,
name=ModelName.gpt_4_1_mini,
friendly_name="GPT 4.1 Mini",
providers=[
KilnModelProvider(
name=ModelProviderName.openai,
model_id="gpt-4.1-mini",
provider_finetune_id="gpt-4.1-mini-2025-04-14",
),
KilnModelProvider(
name=ModelProviderName.openrouter,
model_id="openai/gpt-4.1-mini",
),
KilnModelProvider(
name=ModelProviderName.azure_openai,
model_id="gpt-4.1-mini",
),
],
),
KilnModel(
family=ModelFamily.qwen,
name=ModelName.qwq_32b,
friendly_name="QwQ 32B",
providers=[
KilnModelProvider(
name=ModelProviderName.huggingface,
model_id="qwen/qwq-32b",
provider_finetune_id="qwq-32b-xxx",
parser=ModelParserID.r1_thinking,
)
],
),
]


@pytest.mark.parametrize(
"model_id, provider, expected_data_strategies",
[
# for models where we have built-in models, we can infer the data strategies from the object itself
(
# does not have a parser, so should be defaults
"gpt-4.1-2025-04-14",
"openai",
[
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
],
),
(
# does not have a parser, so should be defaults
"gpt-4.1-mini-2025-04-14",
"openai",
[
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
],
),
(
# this model is not in any list, so should be defaults
"fake-model-id",
"fake-provider",
[
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
],
),
# this model has an R1 parser, should be r1 compatible
(
"qwq-32b-xxx",
"huggingface",
[
FinetuneDataStrategy.final_and_intermediate_r1_compatible,
],
),
# for fireworks_ai models, we infer the data strategies from the model name
(
# does not contain r1 or qwq in the id so it should be defaults
"some-model-id",
"fireworks_ai",
[
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
],
),
(
# contains r1 in the id so it should be r1 compatible
"some-model-with-r1-in-id",
"fireworks_ai",
[
FinetuneDataStrategy.final_and_intermediate_r1_compatible,
],
),
(
# contains qwq in the id so it should be r1 compatible
"some-model-with-qwq-in-id",
"fireworks_ai",
[
FinetuneDataStrategy.final_and_intermediate_r1_compatible,
],
),
],
)
def test_infer_data_strategies(
mock_available_models,
model_id: str,
provider: str,
expected_data_strategies: list[FinetuneDataStrategy],
):
assert (
infer_data_strategies_for_model(mock_available_models, model_id, provider)
== expected_data_strategies
)
Collaborator:

I'd add a parameterized test with expected real values from the API, just on data_strategies_from_finetune_id so the parameter list stays short and sweet.

Expect R1:

  • qwq-32b
  • DeepSeek R1 (Fast) (deepseek-r1)
  • DeepSeek R1 (Basic) (deepseek-r1-basic)
  • deepseek-r1-distill-llama-70b
  • Deepseek R1 Distill Llama 8B (deepseek-r1-distill-llama-8b)
  • Deepseek R1 Distill Qwen 14B (deepseek-r1-distill-qwen-14b)
  • Deepseek R1 Distill Qwen 1.5B (deepseek-r1-distill-qwen-1p5b)
  • deepseek-r1-distill-qwen-32b
  • Deepseek R1 Distill Qwen 7B (deepseek-r1-distill-qwen-7b)

Some we expect not to be R1:

  • DeepSeek V3 (deepseek-v3)
  • Deepseek V3 03-24 (deepseek-v3-0324)

And some new Qwen 3 (TBD, see convo above):

  • Qwen 3, 30B-A3B version (qwen3-30b-a3b)

Collaborator (Author):

I’ve added a test (including the qwen3 case) in commit 86c8512

Note: the validation targets the model's id, not its name. Checking the name would require fetching it from Fireworks.ai, which would in turn require an async Pydantic validator (which does not appear to be standard or supported) and extra wiring to attach data strategies per model; alternatively we could cache the Fireworks.ai response, which comes with its own problems.

Since we only need this for Fireworks.ai currently, and their IDs always include the necessary keywords, validating on the id should hold up well enough until we move to the remote JSON model registry.
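
A minimal sketch of such a parameterized test over data_strategies_from_finetune_id, using the IDs listed in the review comment above. This is a hypothetical illustration: the actual test added in commit 86c8512 may differ, and the import path is an assumption that should match however finetune_api is imported elsewhere in this test file.

import pytest

from kiln_ai.datamodel import FinetuneDataStrategy

# Import path assumed; align with the existing imports in test_finetune_api.py.
from app.desktop.studio_server.finetune_api import (
    DEFAULT_DATA_STRATEGIES,
    data_strategies_from_finetune_id,
)

R1_ONLY = [FinetuneDataStrategy.final_and_intermediate_r1_compatible]


@pytest.mark.parametrize(
    "provider_finetune_id, expected",
    [
        # Expect the R1-compatible strategy only
        ("qwq-32b", R1_ONLY),
        ("deepseek-r1", R1_ONLY),
        ("deepseek-r1-basic", R1_ONLY),
        ("deepseek-r1-distill-llama-70b", R1_ONLY),
        ("deepseek-r1-distill-llama-8b", R1_ONLY),
        ("deepseek-r1-distill-qwen-14b", R1_ONLY),
        ("deepseek-r1-distill-qwen-1p5b", R1_ONLY),
        ("deepseek-r1-distill-qwen-32b", R1_ONLY),
        ("deepseek-r1-distill-qwen-7b", R1_ONLY),
        # Expect the defaults (no R1-style thinking)
        ("deepseek-v3", DEFAULT_DATA_STRATEGIES),
        ("deepseek-v3-0324", DEFAULT_DATA_STRATEGIES),
        # A qwen3 case (e.g. "qwen3-30b-a3b") would also go here once its
        # behaviour is settled, expecting final_only plus
        # final_and_intermediate_r1_compatible per the thread above.
    ],
)
def test_data_strategies_from_finetune_id(provider_finetune_id, expected):
    assert data_strategies_from_finetune_id(provider_finetune_id) == expected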
