
fix: QwQ and R1 finetune format #294


Merged
merged 20 commits on May 1, 2025

Commits (20)
a0334cc  wip: annotate todos (leonardmq, Apr 22, 2025)
db6cf83  wip: fix qwq/r1 data strategy (leonardmq, Apr 24, 2025)
b35706b  test: add cases and update existing tests (leonardmq, Apr 25, 2025)
096c6db  fix: default values on finetune_model_provider pydantic model (leonardmq, Apr 26, 2025)
e9b06a9  fix: expose all data strategies in UI dropdown when (leonardmq, Apr 26, 2025)
26ee24a  fix: dataset formatter duplicate final output message (leonardmq, Apr 26, 2025)
7c7c637  fix: validation of data strategy, and small refactor and validation t… (leonardmq, Apr 27, 2025)
7b14cc9  chore: remove obsolete todo (leonardmq, Apr 27, 2025)
2415c94  refactor: extract valid thinking data strategies into own constant (leonardmq, Apr 27, 2025)
be33978  fix: raise error in r1 serialization if none or empty thinking (leonardmq, Apr 27, 2025)
6a1f7d1  refactor: data formatter generators and fixes on COT vs R1 (leonardmq, Apr 27, 2025)
86c8512  test: add tests for data_strategies_from_finetune_id and qwen3 match (leonardmq, Apr 30, 2025)
2848879  chore: replace error message strings (leonardmq, Apr 30, 2025)
b25ea6c  fix: formatter fix (newlines, and throw if R1 while vertex) (leonardmq, Apr 30, 2025)
fca9792  fix: UI to select strategy, clean labels, clean switch default option (leonardmq, Apr 30, 2025)
defa1ec  ui: error and block submit when no thinking filter dataset for R1 model (leonardmq, Apr 30, 2025)
1070910  fix: disable R1 data strategy for vertex download, toolcall downloads (leonardmq, Apr 30, 2025)
f45eb69  chore: remove lingering console log (leonardmq, May 1, 2025)
ea59f62  fix: string change for warning, and not block submit (leonardmq, May 1, 2025)
fb123c8  fix: no thinking instructions on r1 model and update validation (leonardmq, May 1, 2025)
112 changes: 96 additions & 16 deletions app/desktop/studio_server/finetune_api.py
@@ -8,27 +8,26 @@
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import (
    KilnModel,
    KilnModelProvider,
    ModelParserID,
    ModelProviderName,
    built_in_models,
)
from kiln_ai.adapters.prompt_builders import (
    chain_of_thought_prompt,
    prompt_builder_from_id,
)
-from kiln_ai.adapters.provider_tools import (
-    provider_enabled,
-    provider_name_from_id,
-)
+from kiln_ai.adapters.provider_tools import provider_enabled, provider_name_from_id
from kiln_ai.datamodel import (
    DatasetSplit,
    Finetune,
    FinetuneDataStrategy,
    FineTuneStatusType,
    Task,
)
-from kiln_ai.datamodel.dataset_filters import (
-    DatasetFilterId,
-)
+from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
+from kiln_ai.datamodel.dataset_filters import DatasetFilterId
from kiln_ai.datamodel.dataset_split import (
    AllSplitDefinition,
    Train60Test20Val20SplitDefinition,
@@ -38,7 +37,7 @@
from kiln_ai.utils.config import Config
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.task_api import task_from_id
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, model_validator

logger = logging.getLogger(__name__)

Expand All @@ -48,6 +47,12 @@ class FinetuneProviderModel(BaseModel):

name: str
id: str
data_strategies_supported: list[FinetuneDataStrategy] = Field(
default_factory=lambda: [
FinetuneDataStrategy.final_only,
FinetuneDataStrategy.final_and_intermediate,
]
)


class FinetuneProvider(BaseModel):
@@ -101,6 +106,16 @@ class CreateFinetuneRequest(BaseModel):
    custom_thinking_instructions: str | None = None
    data_strategy: FinetuneDataStrategy

    @model_validator(mode="after")
    def validate_data_strategy(self) -> "CreateFinetuneRequest":
        if self.data_strategy not in infer_data_strategies_for_model(
            built_in_models, self.base_model_id, self.provider
        ):
            raise ValueError(
                f"The data strategy {self.data_strategy} is not supported for the provider model {self.base_model_id}"
            )
        return self


class FinetuneWithStatus(BaseModel):
"""Finetune with status"""
@@ -198,7 +213,8 @@ async def finetune_providers() -> list[FinetuneProvider]:
                provider_models[provider.name] = []
                provider_models[provider.name].append(
                    FinetuneProviderModel(
-                        name=model.friendly_name, id=provider.provider_finetune_id
+                        name=model.friendly_name,
+                        id=provider.provider_finetune_id,
                    )
                )

@@ -212,14 +228,19 @@ async def finetune_providers() -> list[FinetuneProvider]:
    # Create provider entries
    providers: list[FinetuneProvider] = []
    for provider_name, models in provider_models.items():
-        providers.append(
-            FinetuneProvider(
-                name=provider_name_from_id(provider_name),
-                id=provider_name,
-                enabled=await provider_enabled(provider_name),
-                models=models,
-            )
-        )
+        # attach the compatible data strategies to each model
+        for model in models:
+            model.data_strategies_supported = infer_data_strategies_for_model(
+                built_in_models, model.id, provider_name
+            )
+
+        provider = FinetuneProvider(
+            name=provider_name_from_id(provider_name),
+            id=provider_name,
+            enabled=await provider_enabled(provider_name),
+            models=models,
+        )
+        providers.append(provider)

    return providers

@@ -326,6 +347,7 @@ async def download_dataset_jsonl(
            status_code=400,
            detail=f"Data strategy '{data_strategy}' not found",
        )

    data_strategy_typed = FinetuneDataStrategy(data_strategy)

    task = task_from_id(project_id, task_id)
@@ -406,10 +428,13 @@ def thinking_instructions_from_request(
    data_strategy: FinetuneDataStrategy,
    custom_thinking_instructions: str | None,
) -> str | None:
-    if data_strategy != FinetuneDataStrategy.final_and_intermediate:
+    if data_strategy not in THINKING_DATA_STRATEGIES:
        # Not using COT/Thinking style
        return None

    if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
        return None

    if custom_thinking_instructions:
        # prefer custom instructions
        return custom_thinking_instructions
@@ -477,3 +502,58 @@ async def fetch_fireworks_finetune_models() -> list[FinetuneProviderModel]:
        )

    return tuneable_models


DEFAULT_DATA_STRATEGIES = [
    FinetuneDataStrategy.final_only,
    FinetuneDataStrategy.final_and_intermediate,
]


def data_strategies_from_model_provider(
    provider: KilnModelProvider,
) -> list[FinetuneDataStrategy]:
    if provider.parser == ModelParserID.r1_thinking:
        return [
            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
        ]
    return DEFAULT_DATA_STRATEGIES


def data_strategies_from_finetune_id(
    provider_finetune_id: str,
) -> list[FinetuneDataStrategy]:
    if "qwen3" in provider_finetune_id.lower():
        return [
            FinetuneDataStrategy.final_only,
            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
        ]

    r1_must_include = ["r1", "qwq"]

Collaborator (inline review comment):

And Qwen 3 already requires a change 😀

We should add "qwen3", and qwen3 should return final_only, final_and_intermediate_r1_compatible (since it can do both thinking and non-thinking with the /think and /no_think directives).

Let's maybe split this into another PR?

Collaborator (PR author):

I added a pattern match in data_strategies_from_finetune_id to target qwen3 and only allow ['final_and_intermediate_r1_compatible', 'final_only'], just so the allowed strategies can be asserted in a parameterized test (a sketch of such a test appears after the diff). I have not implemented any other qwen3-specific logic beyond that.

    if any(substring in provider_finetune_id.lower() for substring in r1_must_include):
        return [
            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
        ]
    return DEFAULT_DATA_STRATEGIES


def infer_data_strategies_for_model(
    available_models: list[KilnModel],
    provider_finetune_id: str,
    provider_name: str,
) -> list[FinetuneDataStrategy]:
    # we don't have built-in models for fireworks models, so we infer the data strategy from the model name
    if provider_name == ModelProviderName.fireworks_ai:
        return data_strategies_from_finetune_id(provider_finetune_id)

    # where we have built-in models, we can infer the data strategy from the object itself
    for model in available_models:
        for provider in model.providers:
            if (
                provider.name == provider_name
                and provider.provider_finetune_id == provider_finetune_id
            ):
                return data_strategies_from_model_provider(provider)

    # for everything else, we don't know what the data strategy is, so we use the default
    return DEFAULT_DATA_STRATEGIES
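
The parameterized test mentioned in the author's reply above might look roughly like the sketch below. This is illustrative only: the import path and the model ID strings are assumptions, not taken from the PR. Since data_strategies_from_finetune_id does a plain substring match on the lowercased ID, any ID containing "qwen3", "r1", or "qwq" exercises the corresponding branch.

import pytest

from kiln_ai.datamodel import FinetuneDataStrategy

# Assumed import path for the module shown in this diff
from app.desktop.studio_server.finetune_api import (
    DEFAULT_DATA_STRATEGIES,
    data_strategies_from_finetune_id,
)


@pytest.mark.parametrize(
    "finetune_id,expected",
    [
        # qwen3 can train with or without thinking (/think and /no_think directives)
        (
            "qwen3-30b-a3b",  # hypothetical ID; only the "qwen3" substring matters
            [
                FinetuneDataStrategy.final_only,
                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
            ],
        ),
        # R1 and QwQ are thinking-only, so only the R1-compatible strategy applies
        ("deepseek-r1", [FinetuneDataStrategy.final_and_intermediate_r1_compatible]),
        ("qwq-32b", [FinetuneDataStrategy.final_and_intermediate_r1_compatible]),
        # anything else falls back to the defaults
        ("llama-v3-8b-instruct", DEFAULT_DATA_STRATEGIES),
    ],
)
def test_data_strategies_from_finetune_id(finetune_id, expected):
    assert data_strategies_from_finetune_id(finetune_id) == expected

Ordering note: the qwen3 check runs before the r1/qwq substring check, so a hypothetical ID containing both (for example "qwen3-r1-distill") would still get the qwen3 strategies.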