Skip to content

Commit b35706b

Browse files
committed
test: add cases and update existing tests
1 parent db6cf83 commit b35706b

File tree

9 files changed

+437
-32
lines changed

9 files changed

+437
-32
lines changed

app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,6 @@
686686
</div>
687687
{/if}
688688
<div>
689-
<!-- TODO: fix issue when changing the model where data_strategy current value refers to an option that is no longer there -->
690689
<FormElement
691690
label="Model Type / Training Strategy"
692691
description="Should the model be trained on only the final response, or also include intermediate thinking?"

libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def supports_cot(self) -> bool:
5050
self.thinking_instructions is not None
5151
and self.thinking is not None
5252
and self.thinking_final_answer_prompt is not None
53+
) or (
54+
self.thinking_r1_style
55+
and self.thinking_instructions is not None
56+
and self.thinking is not None
5357
)
5458

5559

@@ -83,10 +87,12 @@ def build_training_data(
8387
thinking_final_answer_prompt = None
8488
parent_task = task_run.parent_task()
8589

86-
if (
90+
include_cot = (
8791
data_strategy == FinetuneDataStrategy.final_and_intermediate
88-
and task_run.has_thinking_training_data()
89-
):
92+
or data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible
93+
)
94+
95+
if include_cot and task_run.has_thinking_training_data():
9096
if not parent_task:
9197
raise ValueError(
9298
"TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
@@ -98,7 +104,12 @@ def build_training_data(
98104
"chain_of_thought"
99105
)
100106

101-
thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
107+
# For R1 style, we don't need the final answer prompt (only used for COT multi-message)
108+
thinking_final_answer_prompt = (
109+
COT_FINAL_ANSWER_PROMPT
110+
if data_strategy == FinetuneDataStrategy.final_and_intermediate
111+
else None
112+
)
102113

103114
# Always use the passed thinking instructions, but check they are present for COT
104115
if not thinking_instructions:
@@ -141,10 +152,15 @@ def generate_chat_message_response(
141152
[
142153
{
143154
"role": "assistant",
144-
"content": serialize_r1_style_message(training_data),
155+
"content": serialize_r1_style_message(
156+
thinking=training_data.thinking,
157+
final_output=training_data.final_output,
158+
),
145159
}
146160
]
147161
)
162+
163+
return {"messages": messages}
148164
else:
149165
messages.extend(
150166
[
@@ -186,7 +202,10 @@ def generate_json_schema_message(
186202
[
187203
{
188204
"role": "assistant",
189-
"content": serialize_r1_style_message(training_data),
205+
"content": serialize_r1_style_message(
206+
thinking=training_data.thinking,
207+
final_output=training_data.final_output,
208+
),
190209
}
191210
]
192211
)
@@ -227,7 +246,10 @@ def generate_chat_message_toolcall(
227246
[
228247
{
229248
"role": "assistant",
230-
"content": serialize_r1_style_message(training_data),
249+
"content": serialize_r1_style_message(
250+
thinking=training_data.thinking,
251+
final_output=training_data.final_output,
252+
),
231253
}
232254
]
233255
)
@@ -280,10 +302,15 @@ def generate_huggingface_chat_template(
280302
[
281303
{
282304
"role": "assistant",
283-
"content": serialize_r1_style_message(training_data),
305+
"content": serialize_r1_style_message(
306+
thinking=training_data.thinking,
307+
final_output=training_data.final_output,
308+
),
284309
}
285310
]
286311
)
312+
313+
return {"conversations": conversations}
287314
else:
288315
conversations.extend(
289316
[
@@ -322,10 +349,24 @@ def generate_huggingface_chat_template_toolcall(
322349
[
323350
{
324351
"role": "assistant",
325-
"content": serialize_r1_style_message(training_data),
352+
"content": serialize_r1_style_message(
353+
thinking=training_data.thinking,
354+
final_output=training_data.final_output,
355+
),
356+
"tool_calls": [
357+
{
358+
"type": "function",
359+
"function": {
360+
"name": "task_response",
361+
"id": str(uuid4()).replace("-", "")[:9],
362+
"arguments": arguments,
363+
},
364+
}
365+
],
326366
}
327367
]
328368
)
369+
return {"conversations": conversations}
329370
else:
330371
conversations.extend(
331372
[
@@ -379,11 +420,30 @@ def generate_vertex_gemini(
379420
contents.extend(
380421
[
381422
{
382-
"role": "assistant",
383-
"parts": [{"text": serialize_r1_style_message(training_data)}],
423+
"role": "model",
424+
"parts": [
425+
{
426+
"text": serialize_r1_style_message(
427+
thinking=training_data.thinking,
428+
final_output=training_data.final_output,
429+
)
430+
}
431+
],
384432
}
385433
]
386434
)
435+
436+
return {
437+
"systemInstruction": {
438+
"role": "system",
439+
"parts": [
440+
{
441+
"text": training_data.system_message,
442+
}
443+
],
444+
},
445+
"contents": contents,
446+
}
387447
else:
388448
contents.extend(
389449
[

0 commit comments

Comments
 (0)