Commit ed7ec19

Commit message: apply
1 parent 5fbb1b8, commit ed7ec19

77 files changed: 12,822 additions & 7,709 deletions

tests/saving/gpt-oss-merge/test_merged_model.py

Lines changed: 13 additions & 6 deletions
@@ -6,6 +6,7 @@
 import os
 import shutil
 
+
 def safe_remove_directory(path):
     try:
         if os.path.exists(path) and os.path.isdir(path):
@@ -17,6 +18,8 @@ def safe_remove_directory(path):
     except Exception as e:
         print(f"Failed to remove directory {path}: {e}")
         return False
+
+
 pass
 
 print("🔥 Loading the 16-bit merged model from disk...")
@@ -35,13 +38,15 @@ def safe_remove_directory(path):
 ]
 inputs = merged_tokenizer.apply_chat_template(
     messages,
-    add_generation_prompt = True,
-    return_tensors = "pt",
-    return_dict = True,
-    reasoning_effort = "low", # **NEW!** Set reasoning effort to low, medium or high
+    add_generation_prompt=True,
+    return_tensors="pt",
+    return_dict=True,
+    reasoning_effort="low",  # **NEW!** Set reasoning effort to low, medium or high
 ).to(merged_model.device)
 
-_ = merged_model.generate(**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer))
+_ = merged_model.generate(
+    **inputs, max_new_tokens=512, streamer=TextStreamer(merged_tokenizer)
+)
 print("\n✅ Inference complete.")
 
 # --- Final Cleanup ---
@@ -51,5 +56,7 @@ def safe_remove_directory(path):
 gc.collect()
 
 safe_remove_directory("./gpt-oss-finetuned-merged")
-safe_remove_directory("./unsloth_compiled_cache") # Clean up cache created by this process
+safe_remove_directory(
+    "./unsloth_compiled_cache"
+)  # Clean up cache created by this process
 print("✅ Final cleanup complete. Exiting inference script.")

tests/saving/gpt-oss-merge/train_and_merge.py

Lines changed: 41 additions & 8 deletions
@@ -7,6 +7,7 @@
 import os
 import shutil
 
+
 def safe_remove_directory(path):
     try:
         if os.path.exists(path) and os.path.isdir(path):
@@ -18,15 +19,25 @@ def safe_remove_directory(path):
     except Exception as e:
         print(f"Failed to remove directory {path}: {e}")
         return False
+
+
 pass
 
 # This tokenizer will be used by the mapping function
 tokenizer = None
+
+
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize=False, add_generation_prompt=False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
+
 # --- Load 4-bit Model and Train ---
 print("Loading 4-bit Mxfp4 gpt-oss model for training...")
 max_seq_length = 1024
@@ -39,15 +50,33 @@ def formatting_prompts_func(examples):
 )
 
 model = FastLanguageModel.get_peft_model(
-    model, r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
-    lora_alpha=16, use_gradient_checkpointing="unsloth", random_state=3407,
+    model,
+    r=8,
+    target_modules=[
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+    ],
+    lora_alpha=16,
+    use_gradient_checkpointing="unsloth",
+    random_state=3407,
 )
 
 trainer = SFTTrainer(
-    model=model, tokenizer=tokenizer, train_dataset=dataset,
+    model=model,
+    tokenizer=tokenizer,
+    train_dataset=dataset,
     args=SFTConfig(
-        per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=10,
-        learning_rate=2e-4, output_dir="outputs", report_to="none",
+        per_device_train_batch_size=1,
+        gradient_accumulation_steps=4,
+        max_steps=10,
+        learning_rate=2e-4,
+        output_dir="outputs",
+        report_to="none",
     ),
 )
 
@@ -57,7 +86,9 @@ def formatting_prompts_func(examples):
 
 # --- Merge and Save ---
 print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
-model.save_pretrained_merged(save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer)
+model.save_pretrained_merged(
+    save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer
+)
 print("✅ Model merged and saved.")
 
 # --- Cleanup ---
@@ -67,5 +98,7 @@ def formatting_prompts_func(examples):
 gc.collect()
 
 safe_remove_directory("./outputs")
-safe_remove_directory("./unsloth_compiled_cache") # Clean up the cache created by this process
+safe_remove_directory(
+    "./unsloth_compiled_cache"
+)  # Clean up the cache created by this process
 print("✅ Cleanup complete. Exiting training script.")

tests/saving/language_models/test_merge_4bit_validation.py

Lines changed: 52 additions & 27 deletions
@@ -12,21 +12,28 @@
 
 from tests.utils.cleanup_utils import safe_remove_directory
 
+
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize=False, add_generation_prompt=False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
-print(f"\n{'='*80}")
+
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 1: Loading Base Model and Initial Training")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 if torch.cuda.is_bf16_supported():
     compute_dtype = torch.bfloat16
-    attn_implementation = 'flash_attention_2'
+    attn_implementation = "flash_attention_2"
 else:
     compute_dtype = torch.float16
-    attn_implementation = 'sdpa'
+    attn_implementation = "sdpa"
 
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name="unsloth/Llama-3.1-8B-Instruct",
@@ -35,7 +42,7 @@ def formatting_prompts_func(examples):
     load_in_4bit=True,
     load_in_8bit=False,
     full_finetuning=False,
-    attn_implementation=attn_implementation
+    attn_implementation=attn_implementation,
 )
 
 tokenizer = get_chat_template(
@@ -44,19 +51,29 @@ def formatting_prompts_func(examples):
 )
 
 # Load small dataset for quick training
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
+dataset_train = load_dataset(
+    "allenai/openassistant-guanaco-reformatted", split="train[:100]"
+)
 dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
 
 print("✅ Base model loaded successfully!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 2: First Fine-tuning")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model = FastLanguageModel.get_peft_model(
     model,
     r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
+    target_modules=[
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
     lora_alpha=16,
     lora_dropout=0,
     bias="none",
@@ -97,21 +114,21 @@ def formatting_prompts_func(examples):
 trainer_stats = trainer.train()
 print("✅ First fine-tuning completed!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 3: Save with Forced 4bit Merge")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model.save_pretrained_merged(
-    save_directory='./test_4bit_model',
+    save_directory="./test_4bit_model",
     tokenizer=tokenizer,
-    save_method="forced_merged_4bit"
+    save_method="forced_merged_4bit",
 )
 
 print("✅ Model saved with forced 4bit merge!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 # Clean up first model
 del model
@@ -137,7 +154,15 @@ def formatting_prompts_func(examples):
 model_4bit = FastLanguageModel.get_peft_model(
     model_4bit,
     r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
+    target_modules=[
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
     lora_alpha=16,
     lora_dropout=0,
     bias="none",
@@ -177,14 +202,14 @@ def formatting_prompts_func(examples):
 trainer_4bit.train()
 print("✅ Second fine-tuning on 4bit model completed!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 try:
     model_4bit.save_pretrained_merged(
-        save_directory='./test_should_fail',
-        tokenizer=tokenizer_4bit
+        save_directory="./test_should_fail",
+        tokenizer=tokenizer_4bit,
         # No save_method specified, should default to regular merge
     )
     assert False, "Expected TypeError but merge succeeded!"
@@ -194,23 +219,23 @@ def formatting_prompts_func(examples):
 print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
 print(f"Error message: {str(e)}")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 try:
     model_4bit.save_pretrained_merged(
-        save_directory='./test_4bit_second',
+        save_directory="./test_4bit_second",
         tokenizer=tokenizer_4bit,
-        save_method="forced_merged_4bit"
+        save_method="forced_merged_4bit",
     )
     print("✅ Successfully saved 4bit model with forced 4bit method!")
 except Exception as e:
     assert False, f"Phase 6 failed unexpectedly: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("🔍 CLEANUP")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 # Cleanup
 safe_remove_directory("./outputs")
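
Phases 5 and 6 pin down the save contract: once the base weights are 4-bit, a default merge must raise TypeError, while an explicit save_method="forced_merged_4bit" must succeed. The same contract as a pytest-style sketch, where model_4bit and tokenizer_4bit are assumed fixtures built the way the script builds them:

import pytest

def test_4bit_merge_contract(model_4bit, tokenizer_4bit):
    # Default merge on a 4-bit base is refused; the script expects a TypeError here.
    with pytest.raises(TypeError):
        model_4bit.save_pretrained_merged(
            save_directory="./test_should_fail",
            tokenizer=tokenizer_4bit,
        )
    # Explicitly opting into a 4-bit merge is the supported path.
    model_4bit.save_pretrained_merged(
        save_directory="./test_4bit_second",
        tokenizer=tokenizer_4bit,
        save_method="forced_merged_4bit",
    )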
