 
 from tests.utils.cleanup_utils import safe_remove_directory
 
+
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize=False, add_generation_prompt=False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
-print(f"\n{'='*80}")
+
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 1: Loading Base Model and Initial Training")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 if torch.cuda.is_bf16_supported():
     compute_dtype = torch.bfloat16
-    attn_implementation = 'flash_attention_2'
+    attn_implementation = "flash_attention_2"
 else:
     compute_dtype = torch.float16
-    attn_implementation = 'sdpa'
+    attn_implementation = "sdpa"
 
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name="unsloth/Llama-3.1-8B-Instruct",
@@ -35,7 +42,7 @@ def formatting_prompts_func(examples):
     load_in_4bit=True,
     load_in_8bit=False,
     full_finetuning=False,
-    attn_implementation=attn_implementation
+    attn_implementation=attn_implementation,
 )
 
 tokenizer = get_chat_template(
@@ -44,19 +51,29 @@ def formatting_prompts_func(examples):
 )
 
 # Load small dataset for quick training
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
+dataset_train = load_dataset(
+    "allenai/openassistant-guanaco-reformatted", split="train[:100]"
+)
 dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
 
 print("✅ Base model loaded successfully!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 2: First Fine-tuning")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model = FastLanguageModel.get_peft_model(
     model,
     r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
+    target_modules=[
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
     lora_alpha=16,
     lora_dropout=0,
     bias="none",
@@ -97,21 +114,21 @@ def formatting_prompts_func(examples):
 trainer_stats = trainer.train()
 print("✅ First fine-tuning completed!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 3: Save with Forced 4bit Merge")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model.save_pretrained_merged(
-    save_directory='./test_4bit_model',
+    save_directory="./test_4bit_model",
     tokenizer=tokenizer,
-    save_method="forced_merged_4bit"
+    save_method="forced_merged_4bit",
 )
 
 print("✅ Model saved with forced 4bit merge!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 # Clean up first model
 del model
@@ -137,7 +154,15 @@ def formatting_prompts_func(examples):
 model_4bit = FastLanguageModel.get_peft_model(
     model_4bit,
     r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
+    target_modules=[
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
     lora_alpha=16,
     lora_dropout=0,
     bias="none",
@@ -177,14 +202,14 @@ def formatting_prompts_func(examples):
 trainer_4bit.train()
 print("✅ Second fine-tuning on 4bit model completed!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 try:
     model_4bit.save_pretrained_merged(
-        save_directory='./test_should_fail',
-        tokenizer=tokenizer_4bit
+        save_directory="./test_should_fail",
+        tokenizer=tokenizer_4bit,
         # No save_method specified, should default to regular merge
     )
     assert False, "Expected TypeError but merge succeeded!"
@@ -194,23 +219,23 @@ def formatting_prompts_func(examples):
     print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
     print(f"Error message: {str(e)}")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 try:
     model_4bit.save_pretrained_merged(
-        save_directory='./test_4bit_second',
+        save_directory="./test_4bit_second",
         tokenizer=tokenizer_4bit,
-        save_method="forced_merged_4bit"
+        save_method="forced_merged_4bit",
     )
     print("✅ Successfully saved 4bit model with forced 4bit method!")
 except Exception as e:
     assert False, f"Phase 6 failed unexpectedly: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 CLEANUP")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 # Cleanup
 safe_remove_directory("./outputs")
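For context, the workflow this test exercises, stripped of the phase banners and assertions, is sketched below. The calls and the `save_pretrained_merged` signature are the ones that appear in the diff above; the model name, adapter settings, and save path here are illustrative, and the training step is elided.

```python
from unsloth import FastLanguageModel

# Load a 4-bit quantized base and attach a LoRA adapter to it.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",  # illustrative
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["k_proj", "q_proj", "v_proj", "o_proj",
                    "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
)

# ... SFT training happens here ...

# Merge the adapter back into the base while keeping the weights
# quantized. Phase 5 above asserts that omitting save_method (a plain
# merge) raises TypeError when the base was loaded in 4-bit.
model.save_pretrained_merged(
    save_directory="./merged_4bit",  # illustrative path
    tokenizer=tokenizer,
    save_method="forced_merged_4bit",
)

# The merged checkpoint can then be reloaded for a second round of
# fine-tuning, which is what Phase 4 verifies.
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
    model_name="./merged_4bit",
    load_in_4bit=True,
)
```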