Skip to content

Commit 4d25808

Browse files
authored
Merge pull request #294 from leonardmq/leonard/fix-qwq-fine-tune-format
fix: QwQ and R1 finetune format
2 parents 30dce61 + fb123c8 commit 4d25808

File tree

13 files changed

+910
-78
lines changed

13 files changed

+910
-78
lines changed

app/desktop/studio_server/finetune_api.py

Lines changed: 96 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,26 @@
88
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
99
from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
1010
from kiln_ai.adapters.ml_model_list import (
11+
KilnModel,
12+
KilnModelProvider,
13+
ModelParserID,
1114
ModelProviderName,
1215
built_in_models,
1316
)
1417
from kiln_ai.adapters.prompt_builders import (
1518
chain_of_thought_prompt,
1619
prompt_builder_from_id,
1720
)
18-
from kiln_ai.adapters.provider_tools import (
19-
provider_enabled,
20-
provider_name_from_id,
21-
)
21+
from kiln_ai.adapters.provider_tools import provider_enabled, provider_name_from_id
2222
from kiln_ai.datamodel import (
2323
DatasetSplit,
2424
Finetune,
2525
FinetuneDataStrategy,
2626
FineTuneStatusType,
2727
Task,
2828
)
29-
from kiln_ai.datamodel.dataset_filters import (
30-
DatasetFilterId,
31-
)
29+
from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES
30+
from kiln_ai.datamodel.dataset_filters import DatasetFilterId
3231
from kiln_ai.datamodel.dataset_split import (
3332
AllSplitDefinition,
3433
Train60Test20Val20SplitDefinition,
@@ -38,7 +37,7 @@
3837
from kiln_ai.utils.config import Config
3938
from kiln_ai.utils.name_generator import generate_memorable_name
4039
from kiln_server.task_api import task_from_id
41-
from pydantic import BaseModel
40+
from pydantic import BaseModel, Field, model_validator
4241

4342
logger = logging.getLogger(__name__)
4443

@@ -48,6 +47,12 @@ class FinetuneProviderModel(BaseModel):
4847

4948
name: str
5049
id: str
50+
data_strategies_supported: list[FinetuneDataStrategy] = Field(
51+
default_factory=lambda: [
52+
FinetuneDataStrategy.final_only,
53+
FinetuneDataStrategy.final_and_intermediate,
54+
]
55+
)
5156

5257

5358
class FinetuneProvider(BaseModel):
@@ -101,6 +106,16 @@ class CreateFinetuneRequest(BaseModel):
101106
custom_thinking_instructions: str | None = None
102107
data_strategy: FinetuneDataStrategy
103108

109+
@model_validator(mode="after")
def validate_data_strategy(self) -> "CreateFinetuneRequest":
    """Check that the requested data strategy is allowed for the chosen model.

    The allowed strategies are looked up with
    ``infer_data_strategies_for_model`` from the built-in model list, the
    request's ``base_model_id`` and its ``provider``.

    Returns:
        The request itself, unchanged, when the strategy is allowed.

    Raises:
        ValueError: If ``data_strategy`` is not among the allowed strategies.
    """
    supported_strategies = infer_data_strategies_for_model(
        built_in_models, self.base_model_id, self.provider
    )
    if self.data_strategy in supported_strategies:
        return self

    raise ValueError(
        f"The data strategy {self.data_strategy} is not supported for the provider model {self.base_model_id}"
    )
118+
104119

105120
class FinetuneWithStatus(BaseModel):
106121
"""Finetune with status"""
@@ -198,7 +213,8 @@ async def finetune_providers() -> list[FinetuneProvider]:
198213
provider_models[provider.name] = []
199214
provider_models[provider.name].append(
200215
FinetuneProviderModel(
201-
name=model.friendly_name, id=provider.provider_finetune_id
216+
name=model.friendly_name,
217+
id=provider.provider_finetune_id,
202218
)
203219
)
204220

@@ -212,14 +228,19 @@ async def finetune_providers() -> list[FinetuneProvider]:
212228
# Create provider entries
213229
providers: list[FinetuneProvider] = []
214230
for provider_name, models in provider_models.items():
215-
providers.append(
216-
FinetuneProvider(
217-
name=provider_name_from_id(provider_name),
218-
id=provider_name,
219-
enabled=await provider_enabled(provider_name),
220-
models=models,
231+
# attach the compatible data strategies to each model
232+
for model in models:
233+
model.data_strategies_supported = infer_data_strategies_for_model(
234+
built_in_models, model.id, provider_name
221235
)
236+
237+
provider = FinetuneProvider(
238+
name=provider_name_from_id(provider_name),
239+
id=provider_name,
240+
enabled=await provider_enabled(provider_name),
241+
models=models,
222242
)
243+
providers.append(provider)
223244

224245
return providers
225246

@@ -326,6 +347,7 @@ async def download_dataset_jsonl(
326347
status_code=400,
327348
detail=f"Data strategy '{data_strategy}' not found",
328349
)
350+
329351
data_strategy_typed = FinetuneDataStrategy(data_strategy)
330352

331353
task = task_from_id(project_id, task_id)
@@ -406,10 +428,13 @@ def thinking_instructions_from_request(
406428
data_strategy: FinetuneDataStrategy,
407429
custom_thinking_instructions: str | None,
408430
) -> str | None:
409-
if data_strategy != FinetuneDataStrategy.final_and_intermediate:
431+
if data_strategy not in THINKING_DATA_STRATEGIES:
410432
# Not using COT/Thinking style
411433
return None
412434

435+
if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
436+
return None
437+
413438
if custom_thinking_instructions:
414439
# prefer custom instructions
415440
return custom_thinking_instructions
@@ -477,3 +502,58 @@ async def fetch_fireworks_finetune_models() -> list[FinetuneProviderModel]:
477502
)
478503

479504
return tuneable_models
505+
506+
507+
# Fallback returned by the data-strategy helpers below whenever no
# model-specific rule applies.
DEFAULT_DATA_STRATEGIES = [
    FinetuneDataStrategy.final_only,
    FinetuneDataStrategy.final_and_intermediate,
]
511+
512+
513+
def data_strategies_from_model_provider(
    provider: KilnModelProvider,
) -> list[FinetuneDataStrategy]:
    """Return the fine-tune data strategies allowed for a built-in model provider.

    A provider whose ``parser`` is ``ModelParserID.r1_thinking`` is limited to
    the R1-compatible strategy; every other provider gets the defaults.

    Args:
        provider: The built-in provider entry to inspect.

    Returns:
        A new list on every call, so callers can store or mutate it without
        affecting the module-level defaults.
    """
    if provider.parser == ModelParserID.r1_thinking:
        return [FinetuneDataStrategy.final_and_intermediate_r1_compatible]

    # Copy rather than hand out the shared module-level list: the result is
    # attached to response models, and aliasing would let one caller's
    # mutation change the defaults for everyone.
    return list(DEFAULT_DATA_STRATEGIES)
521+
522+
523+
def data_strategies_from_finetune_id(
    provider_finetune_id: str,
) -> list[FinetuneDataStrategy]:
    """Infer the allowed data strategies from a provider's fine-tune model ID.

    Matching is a case-insensitive substring test on the ID.

    Args:
        provider_finetune_id: The provider-specific ID of the tunable model.

    Returns:
        A new list on every call: ``final_only`` plus the R1-compatible
        strategy for IDs containing ``qwen3``; only the R1-compatible strategy
        for IDs containing ``r1`` or ``qwq``; otherwise a copy of the defaults.
    """
    # Lower-case once instead of once per substring test.
    normalized_id = provider_finetune_id.lower()

    # Checked first: an ID that matches both rules gets the qwen3 result.
    if "qwen3" in normalized_id:
        return [
            FinetuneDataStrategy.final_only,
            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
        ]

    r1_must_include = ("r1", "qwq")
    if any(substring in normalized_id for substring in r1_must_include):
        return [FinetuneDataStrategy.final_and_intermediate_r1_compatible]

    # Copy so callers never hold a reference to the shared default list.
    return list(DEFAULT_DATA_STRATEGIES)
538+
539+
540+
def infer_data_strategies_for_model(
    available_models: list[KilnModel],
    provider_finetune_id: str,
    provider_name: str,
) -> list[FinetuneDataStrategy]:
    """Work out which data strategies a tunable model supports.

    Args:
        available_models: Built-in models to search for a matching provider.
        provider_finetune_id: The provider-specific ID of the tunable model.
        provider_name: The provider the model is tuned on.

    Returns:
        A new list on every call. Fireworks models are resolved from the
        model ID alone; other models use the first built-in provider entry
        matching both ``provider_name`` and ``provider_finetune_id``; anything
        unmatched gets a copy of the defaults.
    """
    # We don't have built-in models for Fireworks models, so the data
    # strategy is inferred from the model name.
    if provider_name == ModelProviderName.fireworks_ai:
        # list(...) keeps the result independent of any shared list the
        # helper may return.
        return list(data_strategies_from_finetune_id(provider_finetune_id))

    # Where we have built-in models, infer the data strategy from the
    # provider object itself.
    for model in available_models:
        for provider in model.providers:
            if (
                provider.name == provider_name
                and provider.provider_finetune_id == provider_finetune_id
            ):
                return list(data_strategies_from_model_provider(provider))

    # For everything else the strategy is unknown, so fall back to a copy of
    # the defaults (never the shared module-level list itself).
    return list(DEFAULT_DATA_STRATEGIES)

0 commit comments

Comments
 (0)