@@ -50,6 +50,10 @@ def supports_cot(self) -> bool:
50
50
self .thinking_instructions is not None
51
51
and self .thinking is not None
52
52
and self .thinking_final_answer_prompt is not None
53
+ ) or (
54
+ self .thinking_r1_style
55
+ and self .thinking_instructions is not None
56
+ and self .thinking is not None
53
57
)
54
58
55
59
@@ -83,10 +87,12 @@ def build_training_data(
83
87
thinking_final_answer_prompt = None
84
88
parent_task = task_run .parent_task ()
85
89
86
- if (
90
+ include_cot = (
87
91
data_strategy == FinetuneDataStrategy .final_and_intermediate
88
- and task_run .has_thinking_training_data ()
89
- ):
92
+ or data_strategy == FinetuneDataStrategy .final_and_intermediate_r1_compatible
93
+ )
94
+
95
+ if include_cot and task_run .has_thinking_training_data ():
90
96
if not parent_task :
91
97
raise ValueError (
92
98
"TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
@@ -98,7 +104,12 @@ def build_training_data(
98
104
"chain_of_thought"
99
105
)
100
106
101
- thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
107
+ # For R1 style, we don't need the final answer prompt (only used for COT multi-message)
108
+ thinking_final_answer_prompt = (
109
+ COT_FINAL_ANSWER_PROMPT
110
+ if data_strategy == FinetuneDataStrategy .final_and_intermediate
111
+ else None
112
+ )
102
113
103
114
# Always use the passed thinking instructions, but check they are present for COT
104
115
if not thinking_instructions :
@@ -141,10 +152,15 @@ def generate_chat_message_response(
141
152
[
142
153
{
143
154
"role" : "assistant" ,
144
- "content" : serialize_r1_style_message (training_data ),
155
+ "content" : serialize_r1_style_message (
156
+ thinking = training_data .thinking ,
157
+ final_output = training_data .final_output ,
158
+ ),
145
159
}
146
160
]
147
161
)
162
+
163
+ return {"messages" : messages }
148
164
else :
149
165
messages .extend (
150
166
[
@@ -186,7 +202,10 @@ def generate_json_schema_message(
186
202
[
187
203
{
188
204
"role" : "assistant" ,
189
- "content" : serialize_r1_style_message (training_data ),
205
+ "content" : serialize_r1_style_message (
206
+ thinking = training_data .thinking ,
207
+ final_output = training_data .final_output ,
208
+ ),
190
209
}
191
210
]
192
211
)
@@ -227,7 +246,10 @@ def generate_chat_message_toolcall(
227
246
[
228
247
{
229
248
"role" : "assistant" ,
230
- "content" : serialize_r1_style_message (training_data ),
249
+ "content" : serialize_r1_style_message (
250
+ thinking = training_data .thinking ,
251
+ final_output = training_data .final_output ,
252
+ ),
231
253
}
232
254
]
233
255
)
@@ -280,10 +302,15 @@ def generate_huggingface_chat_template(
280
302
[
281
303
{
282
304
"role" : "assistant" ,
283
- "content" : serialize_r1_style_message (training_data ),
305
+ "content" : serialize_r1_style_message (
306
+ thinking = training_data .thinking ,
307
+ final_output = training_data .final_output ,
308
+ ),
284
309
}
285
310
]
286
311
)
312
+
313
+ return {"conversations" : conversations }
287
314
else :
288
315
conversations .extend (
289
316
[
@@ -322,10 +349,24 @@ def generate_huggingface_chat_template_toolcall(
322
349
[
323
350
{
324
351
"role" : "assistant" ,
325
- "content" : serialize_r1_style_message (training_data ),
352
+ "content" : serialize_r1_style_message (
353
+ thinking = training_data .thinking ,
354
+ final_output = training_data .final_output ,
355
+ ),
356
+ "tool_calls" : [
357
+ {
358
+ "type" : "function" ,
359
+ "function" : {
360
+ "name" : "task_response" ,
361
+ "id" : str (uuid4 ()).replace ("-" , "" )[:9 ],
362
+ "arguments" : arguments ,
363
+ },
364
+ }
365
+ ],
326
366
}
327
367
]
328
368
)
369
+ return {"conversations" : conversations }
329
370
else :
330
371
conversations .extend (
331
372
[
@@ -379,11 +420,30 @@ def generate_vertex_gemini(
379
420
contents .extend (
380
421
[
381
422
{
382
- "role" : "assistant" ,
383
- "parts" : [{"text" : serialize_r1_style_message (training_data )}],
423
+ "role" : "model" ,
424
+ "parts" : [
425
+ {
426
+ "text" : serialize_r1_style_message (
427
+ thinking = training_data .thinking ,
428
+ final_output = training_data .final_output ,
429
+ )
430
+ }
431
+ ],
384
432
}
385
433
]
386
434
)
435
+
436
+ return {
437
+ "systemInstruction" : {
438
+ "role" : "system" ,
439
+ "parts" : [
440
+ {
441
+ "text" : training_data .system_message ,
442
+ }
443
+ ],
444
+ },
445
+ "contents" : contents ,
446
+ }
387
447
else :
388
448
contents .extend (
389
449
[
0 commit comments