8
8
from kiln_ai .adapters .fine_tune .dataset_formatter import DatasetFormat , DatasetFormatter
9
9
from kiln_ai .adapters .fine_tune .finetune_registry import finetune_registry
10
10
from kiln_ai .adapters .ml_model_list import (
11
+ KilnModel ,
12
+ KilnModelProvider ,
13
+ ModelParserID ,
11
14
ModelProviderName ,
12
15
built_in_models ,
13
16
)
14
17
from kiln_ai .adapters .prompt_builders import (
15
18
chain_of_thought_prompt ,
16
19
prompt_builder_from_id ,
17
20
)
18
- from kiln_ai .adapters .provider_tools import (
19
- provider_enabled ,
20
- provider_name_from_id ,
21
- )
21
+ from kiln_ai .adapters .provider_tools import provider_enabled , provider_name_from_id
22
22
from kiln_ai .datamodel import (
23
23
DatasetSplit ,
24
24
Finetune ,
25
25
FinetuneDataStrategy ,
26
26
FineTuneStatusType ,
27
27
Task ,
28
28
)
29
- from kiln_ai .datamodel .dataset_filters import (
30
- DatasetFilterId ,
31
- )
29
+ from kiln_ai .datamodel .datamodel_enums import THINKING_DATA_STRATEGIES
30
+ from kiln_ai .datamodel .dataset_filters import DatasetFilterId
32
31
from kiln_ai .datamodel .dataset_split import (
33
32
AllSplitDefinition ,
34
33
Train60Test20Val20SplitDefinition ,
38
37
from kiln_ai .utils .config import Config
39
38
from kiln_ai .utils .name_generator import generate_memorable_name
40
39
from kiln_server .task_api import task_from_id
41
- from pydantic import BaseModel
40
+ from pydantic import BaseModel , Field , model_validator
42
41
43
42
logger = logging .getLogger (__name__ )
44
43
@@ -48,6 +47,12 @@ class FinetuneProviderModel(BaseModel):
48
47
49
48
name : str
50
49
id : str
50
+ data_strategies_supported : list [FinetuneDataStrategy ] = Field (
51
+ default_factory = lambda : [
52
+ FinetuneDataStrategy .final_only ,
53
+ FinetuneDataStrategy .final_and_intermediate ,
54
+ ]
55
+ )
51
56
52
57
53
58
class FinetuneProvider (BaseModel ):
@@ -101,6 +106,16 @@ class CreateFinetuneRequest(BaseModel):
101
106
custom_thinking_instructions : str | None = None
102
107
data_strategy : FinetuneDataStrategy
103
108
109
+ @model_validator (mode = "after" )
110
+ def validate_data_strategy (self ) -> "CreateFinetuneRequest" :
111
+ if self .data_strategy not in infer_data_strategies_for_model (
112
+ built_in_models , self .base_model_id , self .provider
113
+ ):
114
+ raise ValueError (
115
+ f"The data strategy { self .data_strategy } is not supported for the provider model { self .base_model_id } "
116
+ )
117
+ return self
118
+
104
119
105
120
class FinetuneWithStatus (BaseModel ):
106
121
"""Finetune with status"""
@@ -198,7 +213,8 @@ async def finetune_providers() -> list[FinetuneProvider]:
198
213
provider_models [provider .name ] = []
199
214
provider_models [provider .name ].append (
200
215
FinetuneProviderModel (
201
- name = model .friendly_name , id = provider .provider_finetune_id
216
+ name = model .friendly_name ,
217
+ id = provider .provider_finetune_id ,
202
218
)
203
219
)
204
220
@@ -212,14 +228,19 @@ async def finetune_providers() -> list[FinetuneProvider]:
212
228
# Create provider entries
213
229
providers : list [FinetuneProvider ] = []
214
230
for provider_name , models in provider_models .items ():
215
- providers .append (
216
- FinetuneProvider (
217
- name = provider_name_from_id (provider_name ),
218
- id = provider_name ,
219
- enabled = await provider_enabled (provider_name ),
220
- models = models ,
231
+ # attach the compatible data strategies to each model
232
+ for model in models :
233
+ model .data_strategies_supported = infer_data_strategies_for_model (
234
+ built_in_models , model .id , provider_name
221
235
)
236
+
237
+ provider = FinetuneProvider (
238
+ name = provider_name_from_id (provider_name ),
239
+ id = provider_name ,
240
+ enabled = await provider_enabled (provider_name ),
241
+ models = models ,
222
242
)
243
+ providers .append (provider )
223
244
224
245
return providers
225
246
@@ -326,6 +347,7 @@ async def download_dataset_jsonl(
326
347
status_code = 400 ,
327
348
detail = f"Data strategy '{ data_strategy } ' not found" ,
328
349
)
350
+
329
351
data_strategy_typed = FinetuneDataStrategy (data_strategy )
330
352
331
353
task = task_from_id (project_id , task_id )
@@ -406,10 +428,13 @@ def thinking_instructions_from_request(
406
428
data_strategy : FinetuneDataStrategy ,
407
429
custom_thinking_instructions : str | None ,
408
430
) -> str | None :
409
- if data_strategy != FinetuneDataStrategy . final_and_intermediate :
431
+ if data_strategy not in THINKING_DATA_STRATEGIES :
410
432
# Not using COT/Thinking style
411
433
return None
412
434
435
+ if data_strategy == FinetuneDataStrategy .final_and_intermediate_r1_compatible :
436
+ return None
437
+
413
438
if custom_thinking_instructions :
414
439
# prefer custom instructions
415
440
return custom_thinking_instructions
@@ -477,3 +502,58 @@ async def fetch_fireworks_finetune_models() -> list[FinetuneProviderModel]:
477
502
)
478
503
479
504
return tuneable_models
505
+
506
+
507
+ DEFAULT_DATA_STRATEGIES = [
508
+ FinetuneDataStrategy .final_only ,
509
+ FinetuneDataStrategy .final_and_intermediate ,
510
+ ]
511
+
512
+
513
+ def data_strategies_from_model_provider (
514
+ provider : KilnModelProvider ,
515
+ ) -> list [FinetuneDataStrategy ]:
516
+ if provider .parser == ModelParserID .r1_thinking :
517
+ return [
518
+ FinetuneDataStrategy .final_and_intermediate_r1_compatible ,
519
+ ]
520
+ return DEFAULT_DATA_STRATEGIES
521
+
522
+
523
+ def data_strategies_from_finetune_id (
524
+ provider_finetune_id : str ,
525
+ ) -> list [FinetuneDataStrategy ]:
526
+ if "qwen3" in provider_finetune_id .lower ():
527
+ return [
528
+ FinetuneDataStrategy .final_only ,
529
+ FinetuneDataStrategy .final_and_intermediate_r1_compatible ,
530
+ ]
531
+
532
+ r1_must_include = ["r1" , "qwq" ]
533
+ if any (substring in provider_finetune_id .lower () for substring in r1_must_include ):
534
+ return [
535
+ FinetuneDataStrategy .final_and_intermediate_r1_compatible ,
536
+ ]
537
+ return DEFAULT_DATA_STRATEGIES
538
+
539
+
540
+ def infer_data_strategies_for_model (
541
+ available_models : list [KilnModel ],
542
+ provider_finetune_id : str ,
543
+ provider_name : str ,
544
+ ) -> list [FinetuneDataStrategy ]:
545
+ # we don't have built-in models for fireworks models, so we infer the data strategy from the model name
546
+ if provider_name == ModelProviderName .fireworks_ai :
547
+ return data_strategies_from_finetune_id (provider_finetune_id )
548
+
549
+ # where we have built-in models, we can infer the data strategy from the object itself
550
+ for model in available_models :
551
+ for provider in model .providers :
552
+ if (
553
+ provider .name == provider_name
554
+ and provider .provider_finetune_id == provider_finetune_id
555
+ ):
556
+ return data_strategies_from_model_provider (provider )
557
+
558
+ # for everything else, we don't know what the data strategy is, so we use the default
559
+ return DEFAULT_DATA_STRATEGIES
0 commit comments